我使用Tomcat 8.5.4构建了一个发布/订阅websocket,但是我在发送多个异步消息时遇到了困难。这是我的端点实现剥离其原始的本质:
package com.example.websocket.server;
import javax.websocket.*;
import javax.websocket.server.ServerEndpoint;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.Enumeration;
import java.util.concurrent.ConcurrentHashMap;
@ServerEndpoint("/repeater")
public class SocketEndpoint {
/**
* session ID(String) : Session
*/
private static final ConcurrentHashMap<String, Session> activeConnections = new ConcurrentHashMap<>();
@OnMessage
public void onData(InputStream data, Session session) {
try {
Byte newData = (byte)data.read();
System.out.println("received data: " + newData);
Enumeration<Session> connections = activeConnections.elements();
ByteBuffer notification = ByteBuffer.wrap(new byte[] { newData });
int notificationCount = 0;
while(connections.hasMoreElements() ) {
Session connection = connections.nextElement();
notificationCount++;
connection.getAsyncRemote().sendBinary(notification, new SendHandler() {
@Override public void onResult(SendResult sendResult) {
System.out.println("getAsyncRemote().sendBinary() OK=" + sendResult.isOK());
}
});
//try {
// connection.getBasicRemote().sendBinary(notification);
//} catch(IOException ioe) {
// System.err.println(ioe);
//}
}
System.out.println("sent " + notificationCount + " notifications");
}
catch (Exception e) {
System.out.println(e);
}
}
@OnOpen
public void myOnOpen(Session session) {
System.out.println("session opened: " + session.getId());
activeConnections.put(session.getId(), session);
}
@OnClose
public void myOnClose(CloseReason reason, Session session) {
System.out.println("Closing a session: " + reason.getCloseCode());
activeConnections.remove(session.getId());
}
}
我遇到的问题是,如果有多个活动连接,则只有一个客户端会收到通知,以响应其中一个客户端发送的1字节消息。 当我切换到使用BasicRemote时,所有连接的客户端都会按预期通知。
我更喜欢在BasicRemote上使用AsyncRemote,以便在继续在onData()中工作的同时在后台发送通知。 似乎通过AsyncRemote发送的消息正在成功发送,但客户端没有收到它们。我误解了AsyncRemote吗?
我还写了一个确认这种行为的JUnit测试:
ServerTests.java
package com.example.websocket.server;
import org.apache.catalina.Context;
import org.apache.catalina.authenticator.jaspic.AuthConfigFactoryImpl;
import org.apache.catalina.servlets.DefaultServlet;
import org.apache.catalina.startup.Tomcat;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import javax.security.auth.message.config.AuthConfigFactory;
import javax.websocket.ContainerProvider;
import javax.websocket.WebSocketContainer;
import javax.websocket.server.ServerEndpoint;
import java.net.URI;
import java.nio.ByteBuffer;
import static org.hamcrest.core.Is.is;
import static org.junit.Assert.assertThat;
public class ServerTests {
private URI uri;
private Tomcat tomcat;
public URI getUri() {
return uri;
}
@Before
public void setUp() throws Exception {
AuthConfigFactory.setFactory(new AuthConfigFactoryImpl());
tomcat = new Tomcat();
tomcat.setBaseDir("./target/server");
final int port = 6000;
tomcat.setPort(port);
// No file system docBase required
Context ctx = tomcat.addContext("", null);
//add SocketEndpoint
ctx.addApplicationListener(SocketEndpointSetup.class.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMapping("/", "default");
WebSocketContainer wsContainer =
ContainerProvider.getWebSocketContainer();
tomcat.start();
this.uri = new URI("ws://localhost:" + port + SocketEndpoint.class.getAnnotation(ServerEndpoint.class).value());
}
@After
public void tearDown() throws Exception {
tomcat.stop();
tomcat.destroy();
tomcat = null;
}
@Test
public void testNotifications() throws Throwable {
//client for sending message
TesterMessageCountClient testClient1 = new TesterMessageCountClient(this.getUri());
testClient1.connect();
testClient1.listenForResponse();
//clients for receiving notification messages
TesterMessageCountClient testClient2 = new TesterMessageCountClient(this.getUri());
testClient2.connect();
testClient2.listenForResponse();
TesterMessageCountClient testClient3 = new TesterMessageCountClient(this.getUri());
testClient3.connect();
testClient3.listenForResponse();
final byte newData = 2;
testClient1.getSession().getBasicRemote().sendBinary(ByteBuffer.wrap(new byte[]{newData}));
ByteBuffer notification = testClient1.awaitResponse();
assertThat(notification.get(0), is(newData));
notification = testClient2.awaitResponse();
assertThat(notification.get(0), is(newData));
notification = testClient3.awaitResponse();
assertThat(notification.get(0), is(newData));
testClient1.disconnect();
testClient2.disconnect();
testClient3.disconnect();
}
}
SocketEndpointSetup.java
package com.example.websocket.server;
import org.apache.tomcat.websocket.server.Constants;
import org.apache.tomcat.websocket.server.WsContextListener;
import javax.servlet.ServletContextEvent;
import javax.websocket.DeploymentException;
import javax.websocket.server.ServerContainer;
public class SocketEndpointSetup extends WsContextListener {
@Override
public void contextInitialized(ServletContextEvent sce) {
super.contextInitialized(sce);
ServerContainer sc =
(ServerContainer) sce.getServletContext().getAttribute(
Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE);
try {
sc.addEndpoint(SocketEndpoint.class);
} catch (DeploymentException e) {
throw new IllegalStateException(e);
}
}
}
TesterMessageCountClient.java
package com.example.websocket.server;
import org.junit.Assert;
import javax.websocket.*;
import java.io.IOException;
import java.net.URI;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
public class TesterMessageCountClient {
public interface TesterEndpoint {
void setLatch(CountDownLatch latch);
}
@ClientEndpoint
public static class TesterAnnotatedEndpoint implements TesterEndpoint {
private CountDownLatch latch = null;
@Override
public void setLatch(CountDownLatch latch) {
this.latch = latch;
}
@OnClose
public void onClose() {
clearLatch();
}
@OnError
public void onError(@SuppressWarnings("unused") Throwable throwable) {
clearLatch();
}
private void clearLatch() {
if (latch != null) {
while (latch.getCount() > 0) {
latch.countDown();
}
}
}
@OnOpen
public void onOpen(Session session) {
session.getUserProperties().put("endpoint", this);
}
}
public abstract static class BasicHandler<T>
implements MessageHandler.Whole<T> {
private final CountDownLatch latch;
private final Queue<T> messages = new LinkedBlockingQueue<>();
public BasicHandler(CountDownLatch latch) {
this.latch = latch;
}
public CountDownLatch getLatch() {
return latch;
}
public Queue<T> getMessages() {
return messages;
}
}
public static class BasicBinary extends BasicHandler<ByteBuffer> {
public BasicBinary(CountDownLatch latch) {
super(latch);
}
@Override
public void onMessage(ByteBuffer message) {
getMessages().add(message);
if (getLatch() != null) {
getLatch().countDown();
}
}
}
private final URI uri;
private Session session;
private BasicBinary handler;
public TesterMessageCountClient(URI serverUri) {
this.uri = serverUri;
}
public Session getSession() {
return this.session;
}
public void connect() throws Exception {
System.out.println("connect to " + uri);
this.session = ContainerProvider.getWebSocketContainer()
.connectToServer(TesterMessageCountClient.TesterAnnotatedEndpoint.class, uri);
}
public void disconnect() throws IOException {
this.session.close();
}
/**
* Same as #setExpectedResponseCount(1);
*/
public void listenForResponse() {
setExpectedResponseCount(1);
}
/**
* Call this to start listening for websocket messages. The messages can then be retreived using
* #awaitResponses()
* @param expectedResponseCount instruct awaitResponses() to block until this number of messages have been received
*/
public void setExpectedResponseCount(int expectedResponseCount) {
if(this.handler != null) {
getSession().removeMessageHandler(this.handler);
}
CountDownLatch latch = new CountDownLatch(expectedResponseCount);
TesterMessageCountClient.TesterEndpoint tep =
(TesterMessageCountClient.TesterEndpoint) getSession().getUserProperties().get("endpoint");
tep.setLatch(latch);
this.handler = new TesterMessageCountClient.BasicBinary(latch);
getSession().addMessageHandler(this.handler);
}
public ByteBuffer awaitResponse() throws InterruptedException, IOException {
return ByteBuffer.wrap(awaitResponses().get(0).array());
}
public List<ByteBuffer> awaitResponses() throws InterruptedException {
if(this.handler == null) {
throw new IllegalStateException("call setExpectedResponseCount() first");
}
boolean latchResult = handler.getLatch().await(10, TimeUnit.SECONDS);
Assert.assertTrue(latchResult);
List<ByteBuffer> responses = new ArrayList<>();
Queue<ByteBuffer> messages = handler.getMessages();
responses.addAll(messages);
return responses;
}
}