Netty:WebSocket握手中的Set-Cookie

时间:2016-07-11 11:49:28

标签: netty

我的管道(pipeline)如下:

// Pipeline from the question: HTTP codec, then aggregation into FullHttpRequest
// (max content length 65536), then Netty's stock WebSocket protocol handler.
ChannelPipeline pipeline = ch.pipeline();
pipeline.addLast(
        new HttpServerCodec(),
        new HttpObjectAggregator(65536),
        new WebSocketServerProtocolHandler(WEBSOCKET_PATH, null, true));

我想在握手响应中添加 Set-Cookie HTTP 标头,RFC 6455 是允许这样做的(见下文引用)。

期望服务器返回的握手响应如下所示:

    Connection:upgrade
    Sec-Websocket-Accept:T1UGQ4HhT3dvLNq5Yi+i/gfASi8=
    Upgrade:websocket
    Set-Cookie: ccc=22; path=/; HttpOnly
  

(引自 RFC 6455)在这两种情况下(客户端与服务器的握手),首行之后都是一组无序的头字段。这些头字段的含义在该文档第 4 节中规定。还可以存在其他头字段,例如 cookie [RFC6265]。

1 个答案:

答案(得分:0)

我没有找到好的办法,最后是通过反射调用 WebSocketServerProtocolHandler 中非 public 的方法实现的。

适用版本:Netty 4.1.2.Final

首先找到 WebSocketServerProtocolHandshakeHandler 类的源代码。这个类不是 public 的,无法直接继承或复用,所以复制一份源码,并在副本上进行修改:

/**
 * Copy of Netty 4.1.2.Final's non-public {@code WebSocketServerProtocolHandshakeHandler},
 * modified so that the WebSocket handshake response carries a {@code Set-Cookie} header.
 *
 * <p>Requests whose URI equals {@code websocketPath} are handshaken; any other message is passed
 * on unchanged. After starting a handshake this handler replaces itself in the pipeline with the
 * responder returned by {@code WebSocketServerProtocolHandler.forbiddenHttpRequestResponder()}.
 *
 * <p>Two non-public helpers of {@link WebSocketServerProtocolHandler} are reached via reflection,
 * so this class is tied to the Netty version it was copied from. If they cannot be found, class
 * initialization fails with an exception instead of terminating the JVM.
 */
class CustomWebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapter {

    /** Name of the cookie carrying the connection id. */
    private static final String COOKIE_NAME = "cid";
    /** A client-supplied id shorter than this is replaced by a freshly generated one. */
    private static final int MIN_CONNECTION_ID_LENGTH = 16;
    /** A client-supplied id longer than this is replaced by a freshly generated one. */
    private static final int MAX_CONNECTION_ID_LENGTH = 50;

    private final String websocketPath;
    private final String subprotocols;
    private final boolean allowExtensions;
    private final int maxFramePayloadSize;
    private final boolean allowMaskMismatch;

    /** Handle to {@code WebSocketServerProtocolHandler.setHandshaker(Channel, WebSocketServerHandshaker)}. */
    static final MethodHandle setHandshakerMethod = getSetHandshakerMethod();
    /** Handle to {@code WebSocketServerProtocolHandler.forbiddenHttpRequestResponder()}. */
    static final MethodHandle forbiddenHttpRequestResponderMethod = getForbiddenHttpRequestResponderMethod();

    static MethodHandle getSetHandshakerMethod() {
        return findHandlerMethod("setHandshaker", Channel.class, WebSocketServerHandshaker.class);
    }

    static MethodHandle getForbiddenHttpRequestResponderMethod() {
        return findHandlerMethod("forbiddenHttpRequestResponder");
    }

    /**
     * Looks up a declared (possibly non-public) method of {@link WebSocketServerProtocolHandler},
     * makes it accessible and returns a method handle for it.
     *
     * @param name           method name
     * @param parameterTypes formal parameter types of the method
     * @return a handle bound to the method
     * @throws IllegalStateException if the method does not exist or cannot be made accessible
     *         (e.g. after a Netty upgrade); during static initialization this surfaces as an
     *         {@code ExceptionInInitializerError} rather than killing the process
     */
    private static MethodHandle findHandlerMethod(String name, Class<?>... parameterTypes) {
        try {
            Method method = WebSocketServerProtocolHandler.class.getDeclaredMethod(name, parameterTypes);
            method.setAccessible(true);
            return MethodHandles.lookup().unreflect(method);
        } catch (ReflectiveOperationException | RuntimeException e) {
            throw new IllegalStateException(
                    "Cannot access WebSocketServerProtocolHandler." + name
                            + " via reflection; is the Netty version the one this class was copied from?", e);
        }
    }

    public CustomWebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols,
            boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch) {
        this.websocketPath = websocketPath;
        this.subprotocols = subprotocols;
        this.allowExtensions = allowExtensions;
        maxFramePayloadSize = maxFrameSize;
        this.allowMaskMismatch = allowMaskMismatch;
    }

    /**
     * Performs the WebSocket handshake for requests to {@code websocketPath}, adding the
     * {@code Set-Cookie} header built by {@link #getResponseHeaders(FullHttpRequest)}.
     *
     * <p>Non-GET requests to the path get a 403 response. An unsupported WebSocket version gets
     * Netty's "unsupported version" response. The request is always released once handled here.
     * Failures while installing the handshaker are propagated to Netty's exception handling.
     */
    @Override
    public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception {
        FullHttpRequest req = (FullHttpRequest) msg;
        if (!websocketPath.equals(req.uri())) {
            ctx.fireChannelRead(msg);
            return;
        }

        try {
            if (req.method() != GET) {
                sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN));
                return;
            }

            final WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
                    getWebSocketLocation(ctx.pipeline(), req, websocketPath), subprotocols,
                    allowExtensions, maxFramePayloadSize, allowMaskMismatch);
            final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);
            if (handshaker == null) {
                WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
                return;
            }

            Channel channel = ctx.channel();
            // Passing extra response headers is the one change compared to Netty's original handler.
            final ChannelFuture handshakeFuture =
                    handshaker.handshake(channel, req, getResponseHeaders(req), channel.newPromise());

            handshakeFuture.addListener(new ChannelFutureListener() {
                @Override
                public void operationComplete(ChannelFuture future) throws Exception {
                    if (!future.isSuccess()) {
                        ctx.fireExceptionCaught(future.cause());
                    } else {
                        ctx.fireUserEventTriggered(
                                WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE);
                    }
                }
            });

            installHandshaker(ctx, channel, handshaker);
        } finally {
            req.release();
        }
    }

    /**
     * Registers {@code handshaker} on the channel and replaces this handler with the responder
     * returned by {@code forbiddenHttpRequestResponder()}, as Netty's original handler does.
     *
     * @throws Exception any exception raised by the reflective calls or the pipeline replacement;
     *         a non-Exception, non-Error throwable is wrapped in {@link IllegalStateException}
     */
    private void installHandshaker(ChannelHandlerContext ctx, Channel channel,
            WebSocketServerHandshaker handshaker) throws Exception {
        try {
            // invokeExact requires the static argument types to match the handle's type exactly.
            setHandshakerMethod.invokeExact(channel, handshaker);

            ChannelHandler responder = (ChannelHandler) forbiddenHttpRequestResponderMethod.invokeExact();
            ctx.pipeline().replace(this, "WS403Responder", responder);
        } catch (Exception | Error e) {
            throw e;
        } catch (Throwable t) {
            // MethodHandle.invokeExact is declared to throw Throwable.
            throw new IllegalStateException("Failed to install WebSocket handshaker", t);
        }
    }

    /**
     * Writes {@code res} and closes the connection afterwards unless the request is keep-alive
     * and the response status is 200.
     */
    private static void sendHttpResponse(ChannelHandlerContext ctx, HttpRequest req, HttpResponse res) {
        ChannelFuture f = ctx.channel().writeAndFlush(res);
        if (!isKeepAlive(req) || res.status().code() != 200) {
            f.addListener(ChannelFutureListener.CLOSE);
        }
    }

    /**
     * Builds the WebSocket URL from the request's Host header and {@code path}. The scheme is
     * {@code wss} when an {@link SslHandler} is in the pipeline, otherwise {@code ws}.
     */
    private static String getWebSocketLocation(ChannelPipeline cp, HttpRequest req, String path) {
        String protocol = "ws";
        if (cp.get(SslHandler.class) != null) {
            // SSL in use so use Secure WebSockets
            protocol = "wss";
        }
        return protocol + "://" + req.headers().get(HttpHeaderNames.HOST) + path;
    }

    /**
     * Builds the extra handshake response headers: a single {@code Set-Cookie} for the connection
     * id cookie ({@code Path=/}, {@code HttpOnly}).
     *
     * <p>The id from the request's cookie is reused when its length is within
     * [{@value #MIN_CONNECTION_ID_LENGTH}, {@value #MAX_CONNECTION_ID_LENGTH}]. Otherwise a
     * random UUID with its dashes removed is used.
     */
    private static HttpHeaders getResponseHeaders(FullHttpRequest req) {
        String connectionId = findCookieValue(req, COOKIE_NAME);
        // NOTE(review): a client-chosen id of acceptable length is echoed back as-is; if the id is
        // security-relevant (e.g. used as a session key), also validate its format/origin.
        if (connectionId == null
                || connectionId.length() < MIN_CONNECTION_ID_LENGTH
                || connectionId.length() > MAX_CONNECTION_ID_LENGTH) {
            connectionId = UUID.randomUUID().toString().replace("-", "");
        }

        DefaultCookie cookie = new DefaultCookie(COOKIE_NAME, connectionId);
        cookie.setPath("/");
        cookie.setHttpOnly(true);
        // Secure is not set, so the browser also sends the cookie over plain ws/http.
        cookie.setSecure(false);

        HttpHeaders headers = new DefaultHttpHeaders();
        headers.add(HttpHeaderNames.SET_COOKIE, ServerCookieEncoder.LAX.encode(cookie));
        return headers;
    }

    /**
     * Returns the value of the first cookie in the request's Cookie header whose name equals
     * {@code name} (case-insensitively), or {@code null} if there is no such cookie.
     */
    private static String findCookieValue(HttpRequest req, String name) {
        String cookieHeader = req.headers().get(HttpHeaderNames.COOKIE);
        if (cookieHeader == null || cookieHeader.isEmpty()) {
            return null;
        }
        for (Cookie cookie : ServerCookieDecoder.LAX.decode(cookieHeader)) {
            if (name.equalsIgnoreCase(cookie.name())) {
                return cookie.value();
            }
        }
        return null;
    }
}

然后添加一个继承自 WebSocketServerProtocolHandler 的新类,让它在 handlerAdded 中安装上面的自定义握手处理器:
/**
 * {@link WebSocketServerProtocolHandler} that installs
 * {@code CustomWebSocketServerProtocolHandshakeHandler} instead of Netty's built-in handshake
 * handler, so that the handshake response can carry extra headers such as {@code Set-Cookie}.
 *
 * <p>The handshake settings are kept in this class as well because the handshake handler is
 * created here, in {@link #handlerAdded(ChannelHandlerContext)}.
 */
class CustomWebSocketServerProtocolHandler extends WebSocketServerProtocolHandler {

    /** Maximum frame payload length used when the caller does not specify one. */
    private static final int DEFAULT_MAX_FRAME_SIZE = 65536;

    private final String websocketPath;
    private final String subprotocols;
    private final boolean allowExtensions;
    private final int maxFramePayloadLength;
    private final boolean allowMaskMismatch;

    public CustomWebSocketServerProtocolHandler(String websocketPath) {
        this(websocketPath, null, false);
    }

    public CustomWebSocketServerProtocolHandler(String websocketPath, String subprotocols) {
        this(websocketPath, subprotocols, false);
    }

    public CustomWebSocketServerProtocolHandler(String websocketPath, String subprotocols,
            boolean allowExtensions) {
        this(websocketPath, subprotocols, allowExtensions, DEFAULT_MAX_FRAME_SIZE, false);
    }

    public CustomWebSocketServerProtocolHandler(String websocketPath, String subprotocols,
            boolean allowExtensions, int maxFrameSize) {
        this(websocketPath, subprotocols, allowExtensions, maxFrameSize, false);
    }

    public CustomWebSocketServerProtocolHandler(String websocketPath,
            String subprotocols, boolean allowExtensions, int maxFrameSize,
            boolean allowMaskMismatch) {
        super(websocketPath, subprotocols, allowExtensions, maxFrameSize,
                allowMaskMismatch);

        this.websocketPath = websocketPath;
        this.subprotocols = subprotocols;
        this.allowExtensions = allowExtensions;
        maxFramePayloadLength = maxFrameSize;
        this.allowMaskMismatch = allowMaskMismatch;
    }

    /**
     * Inserts the custom handshake handler and a {@link Utf8FrameValidator} before this handler,
     * unless the pipeline already contains them.
     *
     * <p>{@code super.handlerAdded} is intentionally not called; the body below replaces it, so
     * that the custom handshake handler is added here.
     */
    @Override
    public void handlerAdded(ChannelHandlerContext ctx) {
        ChannelPipeline cp = ctx.pipeline();
        if (cp.get(CustomWebSocketServerProtocolHandshakeHandler.class) == null) {
            // Add the custom handshake handler before this one.
            cp.addBefore(ctx.name(), CustomWebSocketServerProtocolHandshakeHandler.class.getName(),
                    new CustomWebSocketServerProtocolHandshakeHandler(websocketPath, subprotocols,
                            allowExtensions, maxFramePayloadLength, allowMaskMismatch));
        }
        if (cp.get(Utf8FrameValidator.class) == null) {
            // Add the UTF-8 frame checking before this one.
            cp.addBefore(ctx.name(), Utf8FrameValidator.class.getName(),
                    new Utf8FrameValidator());
        }
    }
}

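最后把 CustomWebSocketServerProtocolHandler 加入管道(它会在 handlerAdded 中自行把 CustomWebSocketServerProtocolHandshakeHandler 插到自己前面):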

    pipeline.addLast(new HttpServerCodec());
    pipeline.addLast(new HttpObjectAggregator(65536));
    pipeline.addLast(new WebSocketServerCompressionHandler());
    pipeline.addLast(new CustomWebSocketServerProtocolHandler(WEBSOCKET_PATH, "*", true));