我的管道如下所示
// Handlers are appended in order: HTTP codec, HTTP object aggregator,
// then the stock WebSocket protocol handler bound to WEBSOCKET_PATH.
ChannelPipeline pipeline = ch.pipeline();
pipeline
        .addLast(new HttpServerCodec())
        .addLast(new HttpObjectAggregator(65536))
        .addLast(new WebSocketServerProtocolHandler(WEBSOCKET_PATH, null, true));
我想在握手响应中添加 Set-Cookie HTTP 标头。RFC 6455 中的说明如下:
来自服务器的握手响应如下所示:
Connection:upgrade
Sec-Websocket-Accept:T1UGQ4HhT3dvLNq5Yi+i/gfASi8=
Upgrade:websocket
Set-Cookie: ccc=22; path=/; HttpOnly
在这两种情况下,前导行之后都是一组无序的头字段。这些头字段的含义在本文档第4节中规定。还可以存在其他附加的头字段,例如 cookie [RFC6265]。
答案 0 :(得分:0)
我找不到好方法。最后我通过反射调用私有方法来做到这一点。
Netty 4.1.2.Final
首先找到 WebSocketServerProtocolHandshakeHandler 类的源代码。
这个类是非公开的,所以需要复制一份该类,并在副本的基础上进行修改:
/**
 * Handshake handler that performs the WebSocket opening handshake for requests
 * addressed to {@code websocketPath} and attaches a {@code Set-Cookie} header
 * (cookie name {@code "cid"}) to the handshake response.
 *
 * <p>After a handshake has been started, this handler registers the handshaker
 * on the channel and replaces itself in the pipeline with the handler returned
 * by {@code WebSocketServerProtocolHandler.forbiddenHttpRequestResponder()}.
 * Both of those methods are reached reflectively through {@link MethodHandle}s
 * because they are not accessible from this class.
 */
class CustomWebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapter {

    /** Name of the cookie that carries the connection id. */
    private static final String CONNECTION_ID_COOKIE = "cid";

    /** Handle for {@code WebSocketServerProtocolHandler.setHandshaker(Channel, WebSocketServerHandshaker)}. */
    static final MethodHandle setHandshakerMethod = getSetHandshakerMethod();

    /** Handle for {@code WebSocketServerProtocolHandler.forbiddenHttpRequestResponder()}. */
    static final MethodHandle forbiddenHttpRequestResponderMethod = getForbiddenHttpRequestResponderMethod();

    private final String websocketPath;
    private final String subprotocols;
    private final boolean allowExtensions;
    private final int maxFramePayloadSize;
    private final boolean allowMaskMismatch;

    /**
     * Looks up {@code WebSocketServerProtocolHandler.setHandshaker} reflectively.
     *
     * @return a method handle for the lookup result
     * @throws IllegalStateException if the method cannot be found or made accessible
     */
    static MethodHandle getSetHandshakerMethod() {
        try {
            Method method = WebSocketServerProtocolHandler.class.getDeclaredMethod(
                    "setHandshaker", Channel.class, WebSocketServerHandshaker.class);
            method.setAccessible(true);
            return MethodHandles.lookup().unreflect(method);
        } catch (ReflectiveOperationException | RuntimeException e) {
            // Fail class initialisation with the cause attached instead of terminating the JVM.
            throw new IllegalStateException(
                    "Cannot access WebSocketServerProtocolHandler.setHandshaker", e);
        }
    }

    /**
     * Looks up {@code WebSocketServerProtocolHandler.forbiddenHttpRequestResponder} reflectively.
     *
     * @return a method handle for the lookup result
     * @throws IllegalStateException if the method cannot be found or made accessible
     */
    static MethodHandle getForbiddenHttpRequestResponderMethod() {
        try {
            Method method = WebSocketServerProtocolHandler.class.getDeclaredMethod(
                    "forbiddenHttpRequestResponder");
            method.setAccessible(true);
            return MethodHandles.lookup().unreflect(method);
        } catch (ReflectiveOperationException | RuntimeException e) {
            // Fail class initialisation with the cause attached instead of terminating the JVM.
            throw new IllegalStateException(
                    "Cannot access WebSocketServerProtocolHandler.forbiddenHttpRequestResponder", e);
        }
    }

    /**
     * @param websocketPath     request URI that triggers the handshake (compared with {@code equals})
     * @param subprotocols      sub-protocols passed to the handshaker factory
     * @param allowExtensions   passed to the handshaker factory
     * @param maxFrameSize      maximum frame payload size passed to the handshaker factory
     * @param allowMaskMismatch passed to the handshaker factory
     */
    public CustomWebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols,
            boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch) {
        this.websocketPath = websocketPath;
        this.subprotocols = subprotocols;
        this.allowExtensions = allowExtensions;
        this.maxFramePayloadSize = maxFrameSize;
        this.allowMaskMismatch = allowMaskMismatch;
    }

    /**
     * Starts the WebSocket handshake for a {@link FullHttpRequest} whose URI equals
     * {@code websocketPath}; every other message is forwarded unchanged.
     *
     * <p>Non-GET requests to the WebSocket path are answered with 403 FORBIDDEN.
     * A request that is handled here is released before the method returns.
     *
     * @throws IllegalStateException if the reflective calls that install the handshaker fail
     */
    @Override
    public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception {
        // Anything that is not a full HTTP request is not ours to handle; pass it on
        // instead of failing with a ClassCastException.
        if (!(msg instanceof FullHttpRequest)) {
            ctx.fireChannelRead(msg);
            return;
        }
        final FullHttpRequest req = (FullHttpRequest) msg;
        if (!websocketPath.equals(req.uri())) {
            ctx.fireChannelRead(msg);
            return;
        }

        try {
            // Compare by value: identity comparison is only correct for cached HttpMethod instances.
            if (!GET.equals(req.method())) {
                sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN));
                return;
            }

            final WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
                    getWebSocketLocation(ctx.pipeline(), req, websocketPath), subprotocols,
                    allowExtensions, maxFramePayloadSize, allowMaskMismatch);
            final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);
            if (handshaker == null) {
                WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
                return;
            }

            final Channel channel = ctx.channel();
            // The extra response headers carry the Set-Cookie header.
            final ChannelFuture handshakeFuture =
                    handshaker.handshake(channel, req, getResponseHeaders(req), channel.newPromise());
            handshakeFuture.addListener(new ChannelFutureListener() {
                @Override
                public void operationComplete(ChannelFuture future) throws Exception {
                    if (!future.isSuccess()) {
                        ctx.fireExceptionCaught(future.cause());
                    } else {
                        ctx.fireUserEventTriggered(
                                WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE);
                    }
                }
            });

            try {
                // invokeExact requires the static argument/return types at the call site to match
                // the target exactly: (Channel, WebSocketServerHandshaker)void and ()ChannelHandler.
                setHandshakerMethod.invokeExact(channel, handshaker);
                ChannelHandler responder = (ChannelHandler) forbiddenHttpRequestResponderMethod.invokeExact();
                ctx.pipeline().replace(this, "WS403Responder", responder);
            } catch (RuntimeException | Error e) {
                throw e;
            } catch (Throwable e) {
                // Surface the failure through the pipeline rather than terminating the JVM.
                throw new IllegalStateException("Failed to install WebSocket handshaker", e);
            }
        } finally {
            req.release();
        }
    }

    /**
     * Writes and flushes {@code res}; closes the channel afterwards unless the request is
     * keep-alive and the response status is 200.
     */
    private static void sendHttpResponse(ChannelHandlerContext ctx, HttpRequest req, HttpResponse res) {
        ChannelFuture f = ctx.channel().writeAndFlush(res);
        if (!isKeepAlive(req) || res.status().code() != 200) {
            f.addListener(ChannelFutureListener.CLOSE);
        }
    }

    /**
     * Builds the WebSocket URL from the request's Host header and {@code path}, using
     * {@code wss} when an {@link SslHandler} is present in the pipeline and {@code ws} otherwise.
     */
    private static String getWebSocketLocation(ChannelPipeline cp, HttpRequest req, String path) {
        String protocol = "ws";
        if (cp.get(SslHandler.class) != null) {
            // SSL in use so use Secure WebSockets
            protocol = "wss";
        }
        return protocol + "://" + req.headers().get(HttpHeaderNames.HOST) + path;
    }

    /**
     * Builds the extra handshake response headers: a single {@code Set-Cookie} header for the
     * connection-id cookie.
     *
     * <p>The id from the request's cookie (name matched case-insensitively) is reused when its
     * length is between 16 and 50 characters inclusive; otherwise a fresh id is generated from a
     * random UUID with the dashes removed.
     *
     * @param req the handshake request whose Cookie header is inspected
     * @return headers containing the Set-Cookie entry
     */
    private static HttpHeaders getResponseHeaders(FullHttpRequest req) {
        String connectionID = null;
        final String cookieString = req.headers().get(HttpHeaderNames.COOKIE);
        if (cookieString != null && cookieString.length() > 0) {
            for (Cookie cookie : ServerCookieDecoder.LAX.decode(cookieString)) {
                if (CONNECTION_ID_COOKIE.equalsIgnoreCase(cookie.name())) {
                    connectionID = cookie.value();
                    break;
                }
            }
        }
        if (connectionID == null || connectionID.length() < 16 || connectionID.length() > 50) {
            // Literal replacement; no regular expression needed.
            connectionID = UUID.randomUUID().toString().replace("-", "");
        }

        final DefaultCookie cookie = new DefaultCookie(CONNECTION_ID_COOKIE, connectionID);
        cookie.setPath("/");
        cookie.setHttpOnly(true);
        cookie.setSecure(false);

        final DefaultHttpHeaders httpHeaders = new DefaultHttpHeaders();
        httpHeaders.add(HttpHeaderNames.SET_COOKIE, ServerCookieEncoder.LAX.encode(cookie));
        return httpHeaders;
    }
}
然后添加一个继承自 WebSocketServerProtocolHandler 的类:
/**
 * {@link WebSocketServerProtocolHandler} variant that, when added to a pipeline, installs a
 * {@link CustomWebSocketServerProtocolHandshakeHandler} (and a {@link Utf8FrameValidator})
 * in front of itself if the pipeline does not already contain one.
 */
class CustomWebSocketServerProtocolHandler extends WebSocketServerProtocolHandler {

    // Kept locally so they can be handed to the custom handshake handler in handlerAdded.
    private final String websocketPath;
    private final String subprotocols;
    private final boolean allowExtensions;
    private final int maxFramePayloadLength;
    private final boolean allowMaskMismatch;

    /**
     * Convenience constructor using a maximum frame size of 65536 and
     * {@code allowMaskMismatch = false}.
     */
    public CustomWebSocketServerProtocolHandler(String websocketPath, String subprotocols,
            boolean allowExtensions) {
        this(websocketPath, subprotocols, allowExtensions, 65536, false);
    }

    /**
     * Passes all arguments to the superclass constructor and also stores them for the
     * handshake handler created in {@link #handlerAdded(ChannelHandlerContext)}.
     */
    public CustomWebSocketServerProtocolHandler(String websocketPath, String subprotocols,
            boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch) {
        super(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch);
        this.websocketPath = websocketPath;
        this.subprotocols = subprotocols;
        this.allowExtensions = allowExtensions;
        this.maxFramePayloadLength = maxFrameSize;
        this.allowMaskMismatch = allowMaskMismatch;
    }

    /**
     * Inserts the custom handshake handler and the UTF-8 frame validator before this handler,
     * each only when no handler of that class is present in the pipeline yet.
     */
    @Override
    public void handlerAdded(ChannelHandlerContext ctx) {
        final ChannelPipeline cp = ctx.pipeline();
        final String selfName = ctx.name();

        if (cp.get(CustomWebSocketServerProtocolHandshakeHandler.class) == null) {
            // The handshake handler has to sit in front of this handler.
            final CustomWebSocketServerProtocolHandshakeHandler handshakeHandler =
                    new CustomWebSocketServerProtocolHandshakeHandler(websocketPath, subprotocols,
                            allowExtensions, maxFramePayloadLength, allowMaskMismatch);
            cp.addBefore(selfName,
                    CustomWebSocketServerProtocolHandshakeHandler.class.getName(), handshakeHandler);
        }

        if (cp.get(Utf8FrameValidator.class) == null) {
            // The UTF-8 check also goes in front of this handler.
            final Utf8FrameValidator utf8Validator = new Utf8FrameValidator();
            cp.addBefore(selfName, Utf8FrameValidator.class.getName(), utf8Validator);
        }
    }
}
最后将它们放入管道:
pipeline.addLast(new HttpServerCodec());
pipeline.addLast(new HttpObjectAggregator(65536));
pipeline.addLast(new WebSocketServerCompressionHandler());
pipeline.addLast(new CustomWebSocketServerProtocolHandler(WEBSOCKET_PATH, "*", true));