This is an automated email from the ASF dual-hosted git repository.

markt pushed a commit to branch 9.0.x
in repository https://gitbox.apache.org/repos/asf/tomcat.git

commit 52d6650e062d880704898d7d8c1b2b7a3efe8068
Author: Mark Thomas <ma...@apache.org>
AuthorDate: Thu Jan 18 11:32:43 2024 +0000

    Refactor WebSocket close for suspend/resume
    
    Ensure that WebSocket connection closure completes if the connection is
    closed when the server side has used the proprietary suspend/resume
    feature to suspend the connection.
---
 java/org/apache/tomcat/websocket/Constants.java    |  6 ++
 java/org/apache/tomcat/websocket/WsSession.java    | 67 +++++++++++++--
 .../tomcat/websocket/WsWebSocketContainer.java     |  9 +-
 .../tomcat/websocket/server/WsServerContainer.java |  2 +-
 .../websocket/TestWsSessionSuspendResume.java      | 99 ++++++++++++++++++++++
 webapps/docs/changelog.xml                         |  5 ++
 webapps/docs/web-socket-howto.xml                  |  7 ++
 7 files changed, 187 insertions(+), 8 deletions(-)

diff --git a/java/org/apache/tomcat/websocket/Constants.java 
b/java/org/apache/tomcat/websocket/Constants.java
index d0a96e706d..d03e21abc8 100644
--- a/java/org/apache/tomcat/websocket/Constants.java
+++ b/java/org/apache/tomcat/websocket/Constants.java
@@ -19,6 +19,7 @@ package org.apache.tomcat.websocket;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
+import java.util.concurrent.TimeUnit;
 
 import javax.websocket.Extension;
 
@@ -107,6 +108,11 @@ public class Constants {
     // Milliseconds so this is 20 seconds
     public static final long DEFAULT_BLOCKING_SEND_TIMEOUT = 20 * 1000;
 
+    // Name of the WebSocket session user property that overrides the session close timeout (Long, milliseconds)
+    public static final String SESSION_CLOSE_TIMEOUT_PROPERTY = "org.apache.tomcat.websocket.SESSION_CLOSE_TIMEOUT";
+    // Default session close timeout (30 seconds, in milliseconds) used when no valid property value is set
+    public static final long DEFAULT_SESSION_CLOSE_TIMEOUT = TimeUnit.SECONDS.toMillis(30);
+
     // Configuration for read idle timeout on WebSocket session
     public static final String READ_IDLE_TIMEOUT_MS = 
"org.apache.tomcat.websocket.READ_IDLE_TIMEOUT_MS";
 
diff --git a/java/org/apache/tomcat/websocket/WsSession.java 
b/java/org/apache/tomcat/websocket/WsSession.java
index 907f93abc9..e71b719e27 100644
--- a/java/org/apache/tomcat/websocket/WsSession.java
+++ b/java/org/apache/tomcat/websocket/WsSession.java
@@ -27,6 +27,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.AtomicReference;
 
@@ -114,6 +115,7 @@ public class WsSession implements Session {
     private volatile long lastActiveRead = System.currentTimeMillis();
     private volatile long lastActiveWrite = System.currentTimeMillis();
     private Map<FutureToSendHandler, FutureToSendHandler> futures = new 
ConcurrentHashMap<>();
+    private volatile Long sessionCloseTimeoutExpiry; // System.nanoTime() based close deadline; null if not set
 
 
     /**
@@ -676,7 +678,14 @@ public class WsSession implements Session {
              */
             state.set(State.CLOSED);
             // ... and close the network connection.
-            wsRemoteEndpoint.close();
+            closeConnection();
+        } else {
+            /*
+             * Set close timeout. If the client fails to send a close message 
response within the timeout, the session
+             * and the connection will be closed when the timeout expires.
+             */
+            sessionCloseTimeoutExpiry =
+                    Long.valueOf(System.nanoTime() + 
TimeUnit.MILLISECONDS.toNanos(getSessionCloseTimeout()));
         }
 
         // Fail any uncompleted messages.
@@ -715,7 +724,7 @@ public class WsSession implements Session {
             state.set(State.CLOSED);
 
             // Close the network connection.
-            wsRemoteEndpoint.close();
+            closeConnection();
         } else if (state.compareAndSet(State.OUTPUT_CLOSING, State.CLOSING)) {
             /*
              * The local endpoint sent a close message the the same time as 
the remote endpoint. The local close is
@@ -727,12 +736,55 @@ public class WsSession implements Session {
              * The local endpoint sent the first close message. The remote 
endpoint has now responded with its own close
              * message so mark the session as fully closed and close the 
network connection.
              */
-            wsRemoteEndpoint.close();
+            closeConnection();
         }
         // CLOSING and CLOSED are NO-OPs
     }
 
 
+    private void closeConnection() {
+        /*
+         * Close the network connection first and only then unregister the session from the container.
+         */
+        wsRemoteEndpoint.close();
+        /*
+         * Unregistering is deferred until now because the container only checks the close timeout
+         * (checkCloseTimeout()) of registered sessions. Unregistering earlier would disable that check.
+         */
+        webSocketContainer.unregisterSession(getSessionMapKey(), this);
+    }
+
+
+    /*
+     * Returns the session close timeout in milliseconds. The SESSION_CLOSE_TIMEOUT_PROPERTY user property is used
+     * if it is a Long greater than zero; otherwise (absent, another type, zero or negative) the default
+     * DEFAULT_SESSION_CLOSE_TIMEOUT is returned.
+     */
+    protected long getSessionCloseTimeout() {
+        Object obj = userProperties.get(Constants.SESSION_CLOSE_TIMEOUT_PROPERTY);
+        // Use the full long value - narrowing to int would corrupt timeouts larger than Integer.MAX_VALUE ms
+        long result = (obj instanceof Long) ? ((Long) obj).longValue() : 0;
+        if (result <= 0) {
+            result = Constants.DEFAULT_SESSION_CLOSE_TIMEOUT;
+        }
+        return result;
+    }
+
+
+    protected void checkCloseTimeout() {
+        // No expiry is recorded until this session is waiting for the peer's close message in response.
+        if (sessionCloseTimeoutExpiry != null) {
+            // Subtract rather than compare directly so the check stays correct if System.nanoTime() wraps.
+            if (System.nanoTime() - sessionCloseTimeoutExpiry.longValue() > 0) {
+                // Only close if still OUTPUT_CLOSED; the CAS fails harmlessly if the close completed elsewhere.
+                if (state.compareAndSet(State.OUTPUT_CLOSED, State.CLOSED)) {
+                    closeConnection();
+                }
+            }
+        }
+    }
+
+
     private void fireEndpointOnClose(CloseReason closeReason) {
 
         // Fire the onClose event
@@ -805,7 +857,7 @@ public class WsSession implements Session {
             if (log.isDebugEnabled()) {
                 log.debug(sm.getString("wsSession.sendCloseFail", id), e);
             }
-            wsRemoteEndpoint.close();
+            closeConnection();
             // Failure to send a close message is not unexpected in the case of
             // an abnormal closure (usually triggered by a failure to 
read/write
             // from/to the client. In this case do not trigger the endpoint's
@@ -813,8 +865,6 @@ public class WsSession implements Session {
             if (closeCode != CloseCodes.CLOSED_ABNORMALLY) {
                 localEndpoint.onError(this, e);
             }
-        } finally {
-            webSocketContainer.unregisterSession(getSessionMapKey(), this);
         }
     }
 
@@ -947,6 +997,11 @@ public class WsSession implements Session {
     @Override
     public Principal getUserPrincipal() {
         checkState();
+        return getUserPrincipalInternal();
+    }
+
+    /** Returns the user principal without calling checkState(); used where it is needed during session close. */
+    public Principal getUserPrincipalInternal() {
         return userPrincipal;
     }
 
diff --git a/java/org/apache/tomcat/websocket/WsWebSocketContainer.java 
b/java/org/apache/tomcat/websocket/WsWebSocketContainer.java
index 73625ca732..56dfa3e8f2 100644
--- a/java/org/apache/tomcat/websocket/WsWebSocketContainer.java
+++ b/java/org/apache/tomcat/websocket/WsWebSocketContainer.java
@@ -609,7 +609,12 @@ public class WsWebSocketContainer implements 
WebSocketContainer, BackgroundProce
         synchronized (endPointSessionMapLock) {
             Set<WsSession> sessions = endpointSessionMap.get(key);
             if (sessions != null) {
-                result.addAll(sessions);
+                // Some sessions may be in the process of closing
+                for (WsSession session : sessions) {
+                    if (session.isOpen()) {
+                        result.add(session);
+                    }
+                }
             }
         }
         return result;
@@ -1052,8 +1057,10 @@ public class WsWebSocketContainer implements 
WebSocketContainer, BackgroundProce
         if (backgroundProcessCount >= processPeriod) {
             backgroundProcessCount = 0;
 
+            // Check all registered sessions.
             for (WsSession wsSession : sessions.keySet()) {
                 wsSession.checkExpiration();
+                wsSession.checkCloseTimeout();
             }
         }
 
diff --git a/java/org/apache/tomcat/websocket/server/WsServerContainer.java 
b/java/org/apache/tomcat/websocket/server/WsServerContainer.java
index 06afbee06d..afde6d658d 100644
--- a/java/org/apache/tomcat/websocket/server/WsServerContainer.java
+++ b/java/org/apache/tomcat/websocket/server/WsServerContainer.java
@@ -432,7 +432,7 @@ public class WsServerContainer extends WsWebSocketContainer 
implements ServerCon
      */
     @Override
     protected void unregisterSession(Object key, WsSession wsSession) {
-        if (wsSession.getUserPrincipal() != null && 
wsSession.getHttpSessionId() != null) {
+        if (wsSession.getUserPrincipalInternal() != null && 
wsSession.getHttpSessionId() != null) {
             unregisterAuthenticatedSession(wsSession, 
wsSession.getHttpSessionId());
         }
         super.unregisterSession(key, wsSession);
diff --git a/test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java 
b/test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java
index 170222ad38..2cfad10103 100644
--- a/test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java
+++ b/test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java
@@ -23,6 +23,8 @@ import java.util.List;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 
+import javax.servlet.ServletContextEvent;
+import javax.servlet.ServletContextListener;
 import javax.websocket.ClientEndpointConfig;
 import javax.websocket.CloseReason;
 import javax.websocket.ContainerProvider;
@@ -39,7 +41,9 @@ import org.apache.catalina.Context;
 import org.apache.catalina.servlets.DefaultServlet;
 import org.apache.catalina.startup.Tomcat;
 import 
org.apache.tomcat.websocket.TesterMessageCountClient.TesterProgrammaticEndpoint;
+import org.apache.tomcat.websocket.server.Constants;
 import org.apache.tomcat.websocket.server.TesterEndpointConfig;
+import org.apache.tomcat.websocket.server.WsServerContainer;
 
 public class TestWsSessionSuspendResume extends WebSocketBaseTest {
 
@@ -141,4 +145,99 @@ public class TestWsSessionSuspendResume extends 
WebSocketBaseTest {
             }
         }
     }
+
+    // Closing a suspended server session must still result in a fully closed session (via the close timeout).
+    @Test
+    public void testSuspendThenClose() throws Exception {
+        Tomcat tomcat = getTomcatInstance();
+
+        Context ctx = getProgrammaticRootContext();
+        ctx.addApplicationListener(SuspendCloseConfig.class.getName());
+        ctx.addApplicationListener(WebSocketFastServerTimeout.class.getName());
+
+        Tomcat.addServlet(ctx, "default", new DefaultServlet());
+        ctx.addServletMappingDecoded("/", "default");
+
+        tomcat.start();
+
+        WebSocketContainer wsContainer = ContainerProvider.getWebSocketContainer();
+
+        ClientEndpointConfig clientEndpointConfig = ClientEndpointConfig.Builder.create().build();
+        Session wsSession = wsContainer.connectToServer(TesterProgrammaticEndpoint.class, clientEndpointConfig,
+                new URI("ws://localhost:" + getPort() + SuspendCloseConfig.PATH));
+
+        wsSession.getBasicRemote().sendText("start test");
+
+        // Wait (up to 5 seconds) for the server session to become fully closed
+        int count = 0;
+        while (count < 50 && !SuspendCloseEndpoint.isServerSessionFullyClosed()) {
+            Thread.sleep(100);
+            count ++;
+        }
+        Assert.assertTrue(SuspendCloseEndpoint.isServerSessionFullyClosed());
+    }
+
+    /** Provides the server endpoint configuration mapping SuspendCloseEndpoint to PATH. */
+    public static final class SuspendCloseConfig extends TesterEndpointConfig {
+        private static final String PATH = "/echo";
+
+        @Override
+        protected Class<?> getEndpointClass() {
+            return SuspendCloseEndpoint.class;
+        }
+
+        @Override
+        protected ServerEndpointConfig getServerEndpointConfig() {
+            return ServerEndpointConfig.Builder.create(getEndpointClass(), PATH).build();
+        }
+    }
+
+    /** Server endpoint that, on receipt of any text message, suspends its session and then closes it. */
+    public static final class SuspendCloseEndpoint extends Endpoint {
+
+        // Yes, a static variable is a hack. It is volatile as it is set on a server thread but read by the test.
+        private static volatile WsSession serverSession;
+
+        @Override
+        public void onOpen(Session session, EndpointConfig epc) {
+            serverSession = (WsSession) session;
+            // Set a short session close timeout (milliseconds)
+            serverSession.getUserProperties().put(
+                    org.apache.tomcat.websocket.Constants.SESSION_CLOSE_TIMEOUT_PROPERTY, Long.valueOf(2000));
+            // Any message will trigger the suspend then close
+            serverSession.addMessageHandler(String.class, message -> {
+                try {
+                    serverSession.getBasicRemote().sendText("server session open");
+                    serverSession.getBasicRemote().sendText("suspending server session");
+                    serverSession.suspend();
+                    serverSession.getBasicRemote().sendText("closing server session");
+                    serverSession.close();
+                } catch (IOException ioe) {
+                    ioe.printStackTrace();
+                    // Attempt to make the failure more obvious
+                    throw new RuntimeException(ioe);
+                }
+            });
+        }
+
+        @Override
+        public void onError(Session session, Throwable t) {
+            t.printStackTrace();
+        }
+
+        public static boolean isServerSessionFullyClosed() {
+            return serverSession != null && serverSession.isClosed();
+        }
+    }
+
+    /** Sets the server container process period to 0 so its periodic session checks run on every invocation. */
+    public static class WebSocketFastServerTimeout implements ServletContextListener {
+
+        @Override
+        public void contextInitialized(ServletContextEvent sce) {
+            WsServerContainer container = (WsServerContainer) sce.getServletContext().getAttribute(
+                    Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE);
+            container.setProcessPeriod(0);
+        }
+    }
 }
\ No newline at end of file
diff --git a/webapps/docs/changelog.xml b/webapps/docs/changelog.xml
index 52ade80bb5..27781eb925 100644
--- a/webapps/docs/changelog.xml
+++ b/webapps/docs/changelog.xml
@@ -205,6 +205,11 @@
         Review usage of debug logging and downgrade trace or data dumping
         operations from debug level to trace. (remm)
       </fix>
+      <fix>
+        Ensure that WebSocket connection closure completes if the connection is
+        closed when the server side has used the proprietary suspend/resume
+        feature to suspend the connection. (markt)
+      </fix>
     </changelog>
   </subsection>
   <subsection name="Web applications">
diff --git a/webapps/docs/web-socket-howto.xml 
b/webapps/docs/web-socket-howto.xml
index 46266eb033..7eda457be2 100644
--- a/webapps/docs/web-socket-howto.xml
+++ b/webapps/docs/web-socket-howto.xml
@@ -63,6 +63,13 @@
    the timeout to use in milliseconds. For an infinite timeout, use
    <code>-1</code>.</p>
 
+<p>The session close timeout limits how long a WebSocket session that has sent
+   a close message waits for the peer's close message in response before the
+   connection is closed. It defaults to 30000 milliseconds (30 seconds) and may
+   be changed by setting <code>org.apache.tomcat.websocket.SESSION_CLOSE_TIMEOUT</code>
+   in the user properties of the WebSocket session. The value must be a <code>Long</code>
+   (milliseconds); other types and values less than or equal to zero are ignored.</p>
+
 <p>In addition to the <code>Session.setMaxIdleTimeout(long)</code> method which
    is part of the Java WebSocket API, Tomcat provides greater control of the
    timing out the session due to lack of activity. Setting the property


---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscr...@tomcat.apache.org
For additional commands, e-mail: dev-h...@tomcat.apache.org

Reply via email to