This is an automated email from the ASF dual-hosted git repository. markt pushed a commit to branch 8.5.x in repository https://gitbox.apache.org/repos/asf/tomcat.git
commit 3631adb1342d8bbd8598802a12b63ad02c37d591 Author: Mark Thomas <ma...@apache.org> AuthorDate: Thu Jan 18 11:32:43 2024 +0000 Refactor WebSocket close for suspend/resume Ensure that WebSocket connection closure completes if the connection is closed when the server side has used the proprietary suspend/resume feature to suspend the connection. --- java/org/apache/tomcat/websocket/Constants.java | 6 ++ java/org/apache/tomcat/websocket/WsSession.java | 67 +++++++++++-- .../tomcat/websocket/WsWebSocketContainer.java | 9 +- .../tomcat/websocket/server/WsServerContainer.java | 2 +- .../websocket/TestWsSessionSuspendResume.java | 107 +++++++++++++++++++++ webapps/docs/changelog.xml | 5 + webapps/docs/web-socket-howto.xml | 7 ++ 7 files changed, 195 insertions(+), 8 deletions(-) diff --git a/java/org/apache/tomcat/websocket/Constants.java b/java/org/apache/tomcat/websocket/Constants.java index 23bb378c3f..a912b0a5ab 100644 --- a/java/org/apache/tomcat/websocket/Constants.java +++ b/java/org/apache/tomcat/websocket/Constants.java @@ -19,6 +19,7 @@ package org.apache.tomcat.websocket; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.concurrent.TimeUnit; import javax.websocket.Extension; @@ -107,6 +108,11 @@ public class Constants { // Milliseconds so this is 20 seconds public static final long DEFAULT_BLOCKING_SEND_TIMEOUT = 20 * 1000; + // Configuration for session close timeout + public static final String SESSION_CLOSE_TIMEOUT_PROPERTY = "org.apache.tomcat.websocket.SESSION_CLOSE_TIMEOUT"; + // Default is 30 seconds - setting is in milliseconds + public static final long DEFAULT_SESSION_CLOSE_TIMEOUT = TimeUnit.SECONDS.toMillis(30); + // Configuration for read idle timeout on WebSocket session public static final String READ_IDLE_TIMEOUT_MS = "org.apache.tomcat.websocket.READ_IDLE_TIMEOUT_MS"; diff --git a/java/org/apache/tomcat/websocket/WsSession.java b/java/org/apache/tomcat/websocket/WsSession.java index 907f93abc9..e71b719e27 100644 --- a/java/org/apache/tomcat/websocket/WsSession.java +++ b/java/org/apache/tomcat/websocket/WsSession.java @@ -27,6 +27,7 @@ import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; @@ -114,6 +115,7 @@ public class WsSession implements Session { private volatile long lastActiveRead = System.currentTimeMillis(); private volatile long lastActiveWrite = System.currentTimeMillis(); private Map<FutureToSendHandler, FutureToSendHandler> futures = new ConcurrentHashMap<>(); + private volatile Long sessionCloseTimeoutExpiry; /** @@ -676,7 +678,14 @@ public class WsSession implements Session { */ state.set(State.CLOSED); // ... and close the network connection. - wsRemoteEndpoint.close(); + closeConnection(); + } else { + /* + * Set close timeout. If the client fails to send a close message response within the timeout, the session + * and the connection will be closed when the timeout expires. + */ + sessionCloseTimeoutExpiry = + Long.valueOf(System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(getSessionCloseTimeout())); } // Fail any uncompleted messages. @@ -715,7 +724,7 @@ public class WsSession implements Session { state.set(State.CLOSED); // Close the network connection. - wsRemoteEndpoint.close(); + closeConnection(); } else if (state.compareAndSet(State.OUTPUT_CLOSING, State.CLOSING)) { /* * The local endpoint sent a close message the the same time as the remote endpoint. The local close is @@ -727,12 +736,55 @@ public class WsSession implements Session { * The local endpoint sent the first close message. The remote endpoint has now responded with its own close * message so mark the session as fully closed and close the network connection. */ - wsRemoteEndpoint.close(); + closeConnection(); } // CLOSING and CLOSED are NO-OPs } + private void closeConnection() { + /* + * Close the network connection. + */ + wsRemoteEndpoint.close(); + /* + * Don't unregister the session until the connection is fully closed since webSocketContainer is responsible for + * tracking the session close timeout. + */ + webSocketContainer.unregisterSession(getSessionMapKey(), this); + } + + + /* + * Returns the session close timeout in milliseconds + */ + protected long getSessionCloseTimeout() { + long result = 0; + Object obj = userProperties.get(Constants.SESSION_CLOSE_TIMEOUT_PROPERTY); + if (obj instanceof Long) { + result = ((Long) obj).intValue(); + } + if (result <= 0) { + result = Constants.DEFAULT_SESSION_CLOSE_TIMEOUT; + } + return result; + } + + + protected void checkCloseTimeout() { + // Skip the check if no session close timeout has been set. + if (sessionCloseTimeoutExpiry != null) { + // Check if the timeout has expired. + if (System.nanoTime() - sessionCloseTimeoutExpiry.longValue() > 0) { + // Check if the session has been closed in another thread while the timeout was being processed. + if (state.compareAndSet(State.OUTPUT_CLOSED, State.CLOSED)) { + closeConnection(); + } + } + } + } + + private void fireEndpointOnClose(CloseReason closeReason) { // Fire the onClose event @@ -805,7 +857,7 @@ public class WsSession implements Session { if (log.isDebugEnabled()) { log.debug(sm.getString("wsSession.sendCloseFail", id), e); } - wsRemoteEndpoint.close(); + closeConnection(); // Failure to send a close message is not unexpected in the case of // an abnormal closure (usually triggered by a failure to read/write // from/to the client. In this case do not trigger the endpoint's @@ -813,8 +865,6 @@ public class WsSession implements Session { if (closeCode != CloseCodes.CLOSED_ABNORMALLY) { localEndpoint.onError(this, e); } - } finally { - webSocketContainer.unregisterSession(getSessionMapKey(), this); } } @@ -947,6 +997,11 @@ public class WsSession implements Session { @Override public Principal getUserPrincipal() { checkState(); + return getUserPrincipalInternal(); + } + + + public Principal getUserPrincipalInternal() { return userPrincipal; } diff --git a/java/org/apache/tomcat/websocket/WsWebSocketContainer.java b/java/org/apache/tomcat/websocket/WsWebSocketContainer.java index b5fc142698..4f075bd64f 100644 --- a/java/org/apache/tomcat/websocket/WsWebSocketContainer.java +++ b/java/org/apache/tomcat/websocket/WsWebSocketContainer.java @@ -614,7 +614,12 @@ public class WsWebSocketContainer implements WebSocketContainer, BackgroundProce synchronized (endPointSessionMapLock) { Set<WsSession> sessions = endpointSessionMap.get(key); if (sessions != null) { - result.addAll(sessions); + // Some sessions may be in the process of closing + for (WsSession session : sessions) { + if (session.isOpen()) { + result.add(session); + } + } } } return result; @@ -1061,8 +1066,10 @@ public class WsWebSocketContainer implements WebSocketContainer, BackgroundProce if (backgroundProcessCount >= processPeriod) { backgroundProcessCount = 0; + // Check all registered sessions. for (WsSession wsSession : sessions.keySet()) { wsSession.checkExpiration(); + wsSession.checkCloseTimeout(); } } diff --git a/java/org/apache/tomcat/websocket/server/WsServerContainer.java b/java/org/apache/tomcat/websocket/server/WsServerContainer.java index 1ef2a41c6a..e4146eba96 100644 --- a/java/org/apache/tomcat/websocket/server/WsServerContainer.java +++ b/java/org/apache/tomcat/websocket/server/WsServerContainer.java @@ -429,7 +429,7 @@ public class WsServerContainer extends WsWebSocketContainer implements ServerCon */ @Override protected void unregisterSession(Object key, WsSession wsSession) { - if (wsSession.getUserPrincipal() != null && wsSession.getHttpSessionId() != null) { + if (wsSession.getUserPrincipalInternal() != null && wsSession.getHttpSessionId() != null) { unregisterAuthenticatedSession(wsSession, wsSession.getHttpSessionId()); } super.unregisterSession(key, wsSession); diff --git a/test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java b/test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java index 313411bd23..10c3a8204e 100644 --- a/test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java +++ b/test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java @@ -23,6 +23,8 @@ import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import javax.servlet.ServletContextEvent; +import javax.servlet.ServletContextListener; import javax.websocket.ClientEndpointConfig; import javax.websocket.CloseReason; import javax.websocket.ContainerProvider; @@ -40,7 +42,9 @@ import org.apache.catalina.Context; import org.apache.catalina.servlets.DefaultServlet; import org.apache.catalina.startup.Tomcat; import org.apache.tomcat.websocket.TesterMessageCountClient.TesterProgrammaticEndpoint; +import org.apache.tomcat.websocket.server.Constants; import org.apache.tomcat.websocket.server.TesterEndpointConfig; +import org.apache.tomcat.websocket.server.WsServerContainer; public class TestWsSessionSuspendResume extends WebSocketBaseTest { @@ -152,4 +156,107 @@ public class TestWsSessionSuspendResume extends WebSocketBaseTest { } } } + + + @Test + public void testSuspendThenClose() throws Exception { + Tomcat tomcat = getTomcatInstance(); + + Context ctx = getProgrammaticRootContext(); + ctx.addApplicationListener(SuspendCloseConfig.class.getName()); + ctx.addApplicationListener(WebSocketFastServerTimeout.class.getName()); + + Tomcat.addServlet(ctx, "default", new DefaultServlet()); + ctx.addServletMappingDecoded("/", "default"); + + tomcat.start(); + + WebSocketContainer wsContainer = ContainerProvider.getWebSocketContainer(); + + ClientEndpointConfig clientEndpointConfig = ClientEndpointConfig.Builder.create().build(); + Session wsSession = wsContainer.connectToServer(TesterProgrammaticEndpoint.class, clientEndpointConfig, + new URI("ws://localhost:" + getPort() + SuspendResumeConfig.PATH)); + + wsSession.getBasicRemote().sendText("start test"); + + // Wait for the client response to be received by the server + int count = 0; + while (count < 50 && !SuspendCloseEndpoint.isServerSessionFullyClosed()) { + Thread.sleep(100); + count ++; + } + Assert.assertTrue(SuspendCloseEndpoint.isServerSessionFullyClosed()); + } + + + public static final class SuspendCloseConfig extends TesterEndpointConfig { + private static final String PATH = "/echo"; + + @Override + protected Class<?> getEndpointClass() { + return SuspendCloseEndpoint.class; + } + + @Override + protected ServerEndpointConfig getServerEndpointConfig() { + return ServerEndpointConfig.Builder.create(getEndpointClass(), PATH).build(); + } + } + + + public static final class SuspendCloseEndpoint extends Endpoint { + + // Yes, a static variable is a hack. + private static WsSession serverSession; + + @Override + public void onOpen(Session session, EndpointConfig epc) { + serverSession = (WsSession) session; + // Set a short session close timeout (milliseconds) + serverSession.getUserProperties().put( + org.apache.tomcat.websocket.Constants.SESSION_CLOSE_TIMEOUT_PROPERTY, Long.valueOf(2000)); + // Any message will trigger the suspend then close + serverSession.addMessageHandler(String.class, new MessageHandler.Whole<String>() { + @Override + public void onMessage(String message) { + try { + serverSession.getBasicRemote().sendText("server session open"); + serverSession.getBasicRemote().sendText("suspending server session"); + serverSession.suspend(); + serverSession.getBasicRemote().sendText("closing server session"); + serverSession.close(); + } catch (IOException ioe) { + ioe.printStackTrace(); + // Attempt to make the failure more obvious + throw new RuntimeException(ioe); + } + } + }); + } + + @Override + public void onError(Session session, Throwable t) { + t.printStackTrace(); + } + + public static boolean isServerSessionFullyClosed() { + return serverSession.isClosed(); + } + } + + + public static class WebSocketFastServerTimeout implements ServletContextListener { + + @Override + public void contextInitialized(ServletContextEvent sce) { + WsServerContainer container = (WsServerContainer) sce.getServletContext().getAttribute( + Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE); + container.setProcessPeriod(0); + } + + @Override + public void contextDestroyed(ServletContextEvent sce) { + // NO-OP + } + } } \ No newline at end of file diff --git a/webapps/docs/changelog.xml b/webapps/docs/changelog.xml index dd4d2b2d74..5f02398080 100644 --- a/webapps/docs/changelog.xml +++ b/webapps/docs/changelog.xml @@ -199,6 +199,11 @@ Review usage of debug logging and downgrade trace or data dumping operations from debug level to trace. (remm) </fix> + <fix> + Ensure that WebSocket connection closure completes if the connection is + closed when the server side has used the proprietary suspend/resume + feature to suspend the connection. (markt) + </fix> </changelog> </subsection> <subsection name="Web applications"> diff --git a/webapps/docs/web-socket-howto.xml b/webapps/docs/web-socket-howto.xml index 9405199010..1ffcb2ca61 100644 --- a/webapps/docs/web-socket-howto.xml +++ b/webapps/docs/web-socket-howto.xml @@ -63,6 +63,13 @@ the timeout to use in milliseconds. For an infinite timeout, use <code>-1</code>.</p> +<p>The session close timeout defaults to 30000 milliseconds (30 seconds). This + may be changed by setting the property + <code>org.apache.tomcat.websocket.SESSION_CLOSE_TIMEOUT</code> in the user + properties collection attached to the WebSocket session. The value assigned + to this property should be a <code>Long</code> and represents the timeout to + use in milliseconds. Values less than or equal to zero will be ignored.</p> + <p>In addition to the <code>Session.setMaxIdleTimeout(long)</code> method which is part of the Java WebSocket API, Tomcat provides greater control of the timing out the session due to lack of activity. Setting the property --------------------------------------------------------------------- To unsubscribe, e-mail: dev-unsubscr...@tomcat.apache.org For additional commands, e-mail: dev-h...@tomcat.apache.org