This is an automated email from the ASF dual-hosted git repository. markt pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tomcat.git
commit b0e3b1bd78de270d53e319d7cb79eb282aa53cb9 Author: Mark Thomas <ma...@apache.org> AuthorDate: Thu Jan 18 11:32:43 2024 +0000 Refactor WebSocket close for suspend/resume Ensure that WebSocket connection closure completes if the connection is closed when the server side has used the proprietary suspend/resume feature to suspend the connection. --- java/org/apache/tomcat/websocket/Constants.java | 6 ++ java/org/apache/tomcat/websocket/WsSession.java | 67 +++++++++++++-- .../tomcat/websocket/WsWebSocketContainer.java | 9 +- .../tomcat/websocket/server/WsServerContainer.java | 2 +- .../websocket/TestWsSessionSuspendResume.java | 99 ++++++++++++++++++++++ webapps/docs/changelog.xml | 5 ++ webapps/docs/web-socket-howto.xml | 7 ++ 7 files changed, 187 insertions(+), 8 deletions(-) diff --git a/java/org/apache/tomcat/websocket/Constants.java b/java/org/apache/tomcat/websocket/Constants.java index 7ec14131bd..f619c59642 100644 --- a/java/org/apache/tomcat/websocket/Constants.java +++ b/java/org/apache/tomcat/websocket/Constants.java @@ -19,6 +19,7 @@ package org.apache.tomcat.websocket; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.concurrent.TimeUnit; import jakarta.websocket.Extension; @@ -94,6 +95,11 @@ public class Constants { // Milliseconds so this is 20 seconds public static final long DEFAULT_BLOCKING_SEND_TIMEOUT = 20 * 1000; + // Configuration for session close timeout + public static final String SESSION_CLOSE_TIMEOUT_PROPERTY = "org.apache.tomcat.websocket.SESSION_CLOSE_TIMEOUT"; + // Default is 30 seconds - setting is in milliseconds + public static final long DEFAULT_SESSION_CLOSE_TIMEOUT = TimeUnit.SECONDS.toMillis(30); + // Configuration for read idle timeout on WebSocket session public static final String READ_IDLE_TIMEOUT_MS = "org.apache.tomcat.websocket.READ_IDLE_TIMEOUT_MS"; diff --git a/java/org/apache/tomcat/websocket/WsSession.java b/java/org/apache/tomcat/websocket/WsSession.java index f85c9c21ac..be16756bf4 100644 --- a/java/org/apache/tomcat/websocket/WsSession.java +++ b/java/org/apache/tomcat/websocket/WsSession.java @@ -27,6 +27,7 @@ import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; @@ -115,6 +116,7 @@ public class WsSession implements Session { private volatile long lastActiveRead = System.currentTimeMillis(); private volatile long lastActiveWrite = System.currentTimeMillis(); private Map<FutureToSendHandler, FutureToSendHandler> futures = new ConcurrentHashMap<>(); + private volatile Long sessionCloseTimeoutExpiry; /** @@ -593,7 +595,14 @@ public class WsSession implements Session { */ state.set(State.CLOSED); // ... and close the network connection. - wsRemoteEndpoint.close(); + closeConnection(); + } else { + /* + * Set close timeout. If the client fails to send a close message response within the timeout, the session + * and the connection will be closed when the timeout expires. + */ + sessionCloseTimeoutExpiry = + Long.valueOf(System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(getSessionCloseTimeout())); } // Fail any uncompleted messages. @@ -632,7 +641,7 @@ public class WsSession implements Session { state.set(State.CLOSED); // Close the network connection. - wsRemoteEndpoint.close(); + closeConnection(); } else if (state.compareAndSet(State.OUTPUT_CLOSING, State.CLOSING)) { /* * The local endpoint sent a close message the the same time as the remote endpoint. The local close is @@ -644,12 +653,55 @@ public class WsSession implements Session { * The local endpoint sent the first close message. The remote endpoint has now responded with its own close * message so mark the session as fully closed and close the network connection. */ - wsRemoteEndpoint.close(); + closeConnection(); } // CLOSING and CLOSED are NO-OPs } + private void closeConnection() { + /* + * Close the network connection. + */ + wsRemoteEndpoint.close(); + /* + * Don't unregister the session until the connection is fully closed since webSocketContainer is responsible for + * tracking the session close timeout. + */ + webSocketContainer.unregisterSession(getSessionMapKey(), this); + } + + + /* + * Returns the session close timeout in milliseconds + */ + protected long getSessionCloseTimeout() { + long result = 0; + Object obj = userProperties.get(Constants.SESSION_CLOSE_TIMEOUT_PROPERTY); + if (obj instanceof Long) { + result = ((Long) obj).intValue(); + } + if (result <= 0) { + result = Constants.DEFAULT_SESSION_CLOSE_TIMEOUT; + } + return result; + } + + + protected void checkCloseTimeout() { + // Skip the check if no session close timeout has been set. + if (sessionCloseTimeoutExpiry != null) { + // Check if the timeout has expired. + if (System.nanoTime() - sessionCloseTimeoutExpiry.longValue() > 0) { + // Check if the session has been closed in another thread while the timeout was being processed. + if (state.compareAndSet(State.OUTPUT_CLOSED, State.CLOSED)) { + closeConnection(); + } + } + } + } + + private void fireEndpointOnClose(CloseReason closeReason) { // Fire the onClose event @@ -722,7 +774,7 @@ public class WsSession implements Session { if (log.isDebugEnabled()) { log.debug(sm.getString("wsSession.sendCloseFail", id), e); } - wsRemoteEndpoint.close(); + closeConnection(); // Failure to send a close message is not unexpected in the case of // an abnormal closure (usually triggered by a failure to read/write // from/to the client. In this case do not trigger the endpoint's @@ -730,8 +782,6 @@ public class WsSession implements Session { if (closeCode != CloseCodes.CLOSED_ABNORMALLY) { localEndpoint.onError(this, e); } - } finally { - webSocketContainer.unregisterSession(getSessionMapKey(), this); } } @@ -864,6 +914,11 @@ public class WsSession implements Session { @Override public Principal getUserPrincipal() { checkState(); + return getUserPrincipalInternal(); + } + + + public Principal getUserPrincipalInternal() { return userPrincipal; } diff --git a/java/org/apache/tomcat/websocket/WsWebSocketContainer.java b/java/org/apache/tomcat/websocket/WsWebSocketContainer.java index 034e30d2a0..e6376ce4b2 100644 --- a/java/org/apache/tomcat/websocket/WsWebSocketContainer.java +++ b/java/org/apache/tomcat/websocket/WsWebSocketContainer.java @@ -604,7 +604,12 @@ public class WsWebSocketContainer implements WebSocketContainer, BackgroundProce synchronized (endPointSessionMapLock) { Set<WsSession> sessions = endpointSessionMap.get(key); if (sessions != null) { - result.addAll(sessions); + // Some sessions may be in the process of closing + for (WsSession session : sessions) { + if (session.isOpen()) { + result.add(session); + } + } } } return result; @@ -1019,8 +1024,10 @@ public class WsWebSocketContainer implements WebSocketContainer, BackgroundProce if (backgroundProcessCount >= processPeriod) { backgroundProcessCount = 0; + // Check all registered sessions. for (WsSession wsSession : sessions.keySet()) { wsSession.checkExpiration(); + wsSession.checkCloseTimeout(); } } diff --git a/java/org/apache/tomcat/websocket/server/WsServerContainer.java b/java/org/apache/tomcat/websocket/server/WsServerContainer.java index 8fb4eb967c..b3b37ca456 100644 --- a/java/org/apache/tomcat/websocket/server/WsServerContainer.java +++ b/java/org/apache/tomcat/websocket/server/WsServerContainer.java @@ -349,7 +349,7 @@ public class WsServerContainer extends WsWebSocketContainer implements ServerCon */ @Override protected void unregisterSession(Object key, WsSession wsSession) { - if (wsSession.getUserPrincipal() != null && wsSession.getHttpSessionId() != null) { + if (wsSession.getUserPrincipalInternal() != null && wsSession.getHttpSessionId() != null) { unregisterAuthenticatedSession(wsSession, wsSession.getHttpSessionId()); } super.unregisterSession(key, wsSession); diff --git a/test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java b/test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java index cb54821662..f624f5c87c 100644 --- a/test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java +++ b/test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java @@ -23,6 +23,8 @@ import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import jakarta.servlet.ServletContextEvent; +import jakarta.servlet.ServletContextListener; import jakarta.websocket.ClientEndpointConfig; import jakarta.websocket.CloseReason; import jakarta.websocket.ContainerProvider; @@ -39,7 +41,9 @@ import org.apache.catalina.Context; import org.apache.catalina.servlets.DefaultServlet; import org.apache.catalina.startup.Tomcat; import org.apache.tomcat.websocket.TesterMessageCountClient.TesterProgrammaticEndpoint; +import org.apache.tomcat.websocket.server.Constants; import org.apache.tomcat.websocket.server.TesterEndpointConfig; +import org.apache.tomcat.websocket.server.WsServerContainer; public class TestWsSessionSuspendResume extends WebSocketBaseTest { @@ -141,4 +145,99 @@ public class TestWsSessionSuspendResume extends WebSocketBaseTest { } } } + + + @Test + public void testSuspendThenClose() throws Exception { + Tomcat tomcat = getTomcatInstance(); + + Context ctx = getProgrammaticRootContext(); + ctx.addApplicationListener(SuspendCloseConfig.class.getName()); + ctx.addApplicationListener(WebSocketFastServerTimeout.class.getName()); + + Tomcat.addServlet(ctx, "default", new DefaultServlet()); + ctx.addServletMappingDecoded("/", "default"); + + tomcat.start(); + + WebSocketContainer wsContainer = ContainerProvider.getWebSocketContainer(); + + ClientEndpointConfig clientEndpointConfig = ClientEndpointConfig.Builder.create().build(); + Session wsSession = wsContainer.connectToServer(TesterProgrammaticEndpoint.class, clientEndpointConfig, + new URI("ws://localhost:" + getPort() + SuspendResumeConfig.PATH)); + + wsSession.getBasicRemote().sendText("start test"); + + // Wait for the client response to be received by the server + int count = 0; + while (count < 50 && !SuspendCloseEndpoint.isServerSessionFullyClosed()) { + Thread.sleep(100); + count ++; + } + Assert.assertTrue(SuspendCloseEndpoint.isServerSessionFullyClosed()); + } + + + public static final class SuspendCloseConfig extends TesterEndpointConfig { + private static final String PATH = "/echo"; + + @Override + protected Class<?> getEndpointClass() { + return SuspendCloseEndpoint.class; + } + + @Override + protected ServerEndpointConfig getServerEndpointConfig() { + return ServerEndpointConfig.Builder.create(getEndpointClass(), PATH).build(); + } + } + + + public static final class SuspendCloseEndpoint extends Endpoint { + + // Yes, a static variable is a hack. + private static WsSession serverSession; + + @Override + public void onOpen(Session session, EndpointConfig epc) { + serverSession = (WsSession) session; + // Set a short session close timeout (milliseconds) + serverSession.getUserProperties().put( + org.apache.tomcat.websocket.Constants.SESSION_CLOSE_TIMEOUT_PROPERTY, Long.valueOf(2000)); + // Any message will trigger the suspend then close + serverSession.addMessageHandler(String.class, message -> { + try { + serverSession.getBasicRemote().sendText("server session open"); + serverSession.getBasicRemote().sendText("suspending server session"); + serverSession.suspend(); + serverSession.getBasicRemote().sendText("closing server session"); + serverSession.close(); + } catch (IOException ioe) { + ioe.printStackTrace(); + // Attempt to make the failure more obvious + throw new RuntimeException(ioe); + } + }); + } + + @Override + public void onError(Session session, Throwable t) { + t.printStackTrace(); + } + + public static boolean isServerSessionFullyClosed() { + return serverSession.isClosed(); + } + } + + + public static class WebSocketFastServerTimeout implements ServletContextListener { + + @Override + public void contextInitialized(ServletContextEvent sce) { + WsServerContainer container = (WsServerContainer) sce.getServletContext().getAttribute( + Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE); + container.setProcessPeriod(0); + } + } } \ No newline at end of file diff --git a/webapps/docs/changelog.xml b/webapps/docs/changelog.xml index 377e08e7a3..227c4cb3dc 100644 --- a/webapps/docs/changelog.xml +++ b/webapps/docs/changelog.xml @@ -220,6 +220,11 @@ Review usage of debug logging and downgrade trace or data dumping operations from debug level to trace. (remm) </fix> + <fix> + Ensure that WebSocket connection closure completes if the connection is + closed when the server side has used the proprietary suspend/resume + feature to suspend the connection. (markt) + </fix> </changelog> </subsection> <subsection name="Web applications"> diff --git a/webapps/docs/web-socket-howto.xml b/webapps/docs/web-socket-howto.xml index 49d155bd25..60231694cb 100644 --- a/webapps/docs/web-socket-howto.xml +++ b/webapps/docs/web-socket-howto.xml @@ -64,6 +64,13 @@ the timeout to use in milliseconds. For an infinite timeout, use <code>-1</code>.</p> +<p>The session close timeout defaults to 30000 milliseconds (30 seconds). This + may be changed by setting the property + <code>org.apache.tomcat.websocket.SESSION_CLOSE_TIMEOUT</code> in the user + properties collection attached to the WebSocket session. The value assigned + to this property should be a <code>Long</code> and represents the timeout to + use in milliseconds. Values less than or equal to zero will be ignored.</p> + <p>In addition to the <code>Session.setMaxIdleTimeout(long)</code> method which is part of the Jakarta WebSocket API, Tomcat provides greater control of the timing out the session due to lack of activity. Setting the property --------------------------------------------------------------------- To unsubscribe, e-mail: dev-unsubscr...@tomcat.apache.org For additional commands, e-mail: dev-h...@tomcat.apache.org