package com.github.netty.springboot.server; import com.github.netty.protocol.servlet.util.HttpHeaderConstants; import com.github.netty.core.util.Wrapper; import com.github.netty.protocol.servlet.ServletChannelHandler; import com.github.netty.protocol.servlet.ServletHttpServletRequest; import com.github.netty.protocol.servlet.util.ServletUtil; import com.github.netty.protocol.servlet.websocket.NettyMessageToWebSocketRunnable; import com.github.netty.protocol.servlet.websocket.WebSocketServerContainer; import com.github.netty.protocol.servlet.websocket.WebSocketServerHandshaker13Extension; import com.github.netty.protocol.servlet.websocket.WebSocketSession; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker; import io.netty.handler.codec.http.websocketx.WebSocketVersion; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.web.socket.server.HandshakeFailureException; import org.springframework.web.socket.server.standard.AbstractStandardUpgradeStrategy; import org.springframework.web.socket.server.standard.ServerEndpointRegistration; import javax.servlet.ServletContext; import javax.servlet.http.HttpServletRequest; import javax.websocket.Endpoint; import javax.websocket.Extension; import javax.websocket.server.ServerEndpointConfig; import java.security.Principal; import java.util.*; /** * Websocket version number: the version number of draft 8 to draft 12 is 8, and the version number of draft 13 and later is the same as the draft number * @author wangzihao */ public class NettyRequestUpgradeStrategy extends AbstractStandardUpgradeStrategy { public static final String SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE = "javax.websocket.server.ServerContainer"; private int maxFramePayloadLength; public NettyRequestUpgradeStrategy() { this(64 * 1024); } public NettyRequestUpgradeStrategy(int maxFramePayloadLength) { this.maxFramePayloadLength = maxFramePayloadLength; } @Override public String[] getSupportedVersions() { return new String[]{WebSocketVersion.V13.toHttpHeaderValue()}; } @Override protected void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, String selectedProtocol, List<Extension> selectedExtensions, Endpoint endpoint) throws HandshakeFailureException { HttpServletRequest servletRequest = getHttpServletRequest(request); ServletHttpServletRequest httpServletRequest = ServletUtil.unWrapper(servletRequest); if(httpServletRequest == null) { throw new HandshakeFailureException( "Servlet request failed to upgrade to WebSocket: " + servletRequest.getRequestURL()); } WebSocketServerContainer serverContainer = getContainer(servletRequest); Principal principal = request.getPrincipal(); Map<String, String> pathParams = new LinkedHashMap<>(3); ServerEndpointRegistration endpointConfig = new ServerEndpointRegistration(servletRequest.getRequestURI(), endpoint); endpointConfig.setSubprotocols(Arrays.asList(WebSocketServerHandshaker.SUB_PROTOCOL_WILDCARD,selectedProtocol)); if(selectedExtensions != null) { endpointConfig.setExtensions(selectedExtensions); } try { handshakeToWebsocket(httpServletRequest, selectedProtocol, maxFramePayloadLength, principal, selectedExtensions, pathParams, endpoint, endpointConfig, serverContainer); } catch (Exception e) { throw new HandshakeFailureException( "Servlet request failed to upgrade to WebSocket: " + servletRequest.getRequestURL(), e); } } @Override protected WebSocketServerContainer getContainer(HttpServletRequest request) { ServletContext servletContext = request.getServletContext(); Object websocketServerContainer = servletContext.getAttribute(SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE); if (!(websocketServerContainer instanceof WebSocketServerContainer)) { websocketServerContainer = new WebSocketServerContainer(); servletContext.setAttribute(SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE, websocketServerContainer); } return (WebSocketServerContainer) websocketServerContainer; } /** * The WebSocket handshake * @param servletRequest servletRequest * @param subprotocols subprotocols * @param maxFramePayloadLength maxFramePayloadLength * @param userPrincipal userPrincipal * @param negotiatedExtensions negotiatedExtensions * @param pathParameters pathParameters * @param localEndpoint localEndpoint * @param endpointConfig endpointConfig * @param webSocketContainer webSocketContainer */ protected void handshakeToWebsocket(ServletHttpServletRequest servletRequest, String subprotocols, int maxFramePayloadLength, Principal userPrincipal, List<Extension> negotiatedExtensions, Map<String, String> pathParameters, Endpoint localEndpoint, ServerEndpointConfig endpointConfig, WebSocketServerContainer webSocketContainer){ FullHttpRequest nettyRequest = servletRequest.getNettyRequest(); ChannelHandlerContext channelContext = Wrapper.unwrap(servletRequest.getServletHttpExchange().getChannelHandlerContext()); String queryString = servletRequest.getQueryString(); String httpSessionId = servletRequest.getSession().getId(); String webSocketURL = getWebSocketLocation(servletRequest); Map<String,List<String>> requestParameterMap = getRequestParameterMap(servletRequest); WebSocketServerHandshaker13Extension wsHandshaker = new WebSocketServerHandshaker13Extension(webSocketURL,subprotocols,true,maxFramePayloadLength); ChannelFuture handshakelFuture = wsHandshaker.handshake(channelContext.channel(), nettyRequest); handshakelFuture.addListener((ChannelFutureListener) future -> { if(future.isSuccess()) { Channel channel = future.channel(); ServletChannelHandler.setMessageToRunnable(channel, new NettyMessageToWebSocketRunnable(ServletChannelHandler.getMessageToRunnable(channel))); WebSocketSession websocketSession = new WebSocketSession( channel, webSocketContainer, wsHandshaker, requestParameterMap, queryString, userPrincipal, httpSessionId, negotiatedExtensions, pathParameters, localEndpoint,endpointConfig); WebSocketSession.setSession(channel, websocketSession); localEndpoint.onOpen(websocketSession, endpointConfig); }else { logger.error("The Websocket handshake failed : "+ webSocketURL, future.cause()); } }); } protected Map<String,List<String>> getRequestParameterMap(HttpServletRequest request){ MultiValueMap<String,String> requestParameterMap = new LinkedMultiValueMap<>(); for(Map.Entry<String,String[]> entry : request.getParameterMap().entrySet()){ for(String value : entry.getValue()){ requestParameterMap.add(entry.getKey(),value); } } return requestParameterMap; } protected String getWebSocketLocation(HttpServletRequest req) { String host = req.getHeader(HttpHeaderConstants.HOST.toString()); if(host == null || host.isEmpty()){ host = req.getServerName(); } String scheme = req.isSecure()? "wss://" : "ws://"; return scheme + host + req.getRequestURI(); } public void setMaxFramePayloadLength(int maxFramePayloadLength) { this.maxFramePayloadLength = maxFramePayloadLength; } public int getMaxFramePayloadLength() { return maxFramePayloadLength; } }