This is an automated email from the ASF dual-hosted git repository.

dsoumis pushed a commit to branch 11.0.x
in repository https://gitbox.apache.org/repos/asf/tomcat.git

commit 48343a64f24258be7e273d49849ee866463be67c
Author: Dimitris Soumis <[email protected]>
AuthorDate: Wed Apr 8 14:52:53 2026 +0300

    Add more tests and minor fixes for FilterValve, ProxyErrorReportValve and SemaphoreValve
---
 .../apache/catalina/valves/TestFilterValve.java    |  99 ++++----
 .../catalina/valves/TestProxyErrorReportValve.java | 121 +++++-----
 .../apache/catalina/valves/TestSemaphoreValve.java | 254 ++++++++++++++++++---
 3 files changed, 333 insertions(+), 141 deletions(-)

diff --git a/test/org/apache/catalina/valves/TestFilterValve.java b/test/org/apache/catalina/valves/TestFilterValve.java
index cfc3a5abb7..dd2d918c5c 100644
--- a/test/org/apache/catalina/valves/TestFilterValve.java
+++ b/test/org/apache/catalina/valves/TestFilterValve.java
@@ -17,16 +17,16 @@
 package org.apache.catalina.valves;
 
 import java.io.IOException;
-import java.nio.charset.StandardCharsets;
 import java.util.Collections;
+import java.util.List;
 
 import jakarta.servlet.Filter;
 import jakarta.servlet.FilterChain;
 import jakarta.servlet.ServletException;
 import jakarta.servlet.ServletRequest;
 import jakarta.servlet.ServletResponse;
-import jakarta.servlet.http.HttpServlet;
 import jakarta.servlet.http.HttpServletRequest;
+import jakarta.servlet.http.HttpServletRequestWrapper;
 import jakarta.servlet.http.HttpServletResponse;
 
 import org.junit.Assert;
@@ -47,8 +47,8 @@ public class TestFilterValve extends TomcatBaseTest {
 
         Context ctx = getProgrammaticRootContext();
 
-        Tomcat.addServlet(ctx, "ok", new OkServlet());
-        ctx.addServletMappingDecoded("/", "ok");
+        Tomcat.addServlet(ctx, "hello", new HelloWorldServlet());
+        ctx.addServletMappingDecoded("/", "hello");
 
         FilterValve valve = new FilterValve();
         valve.setFilterClass(PassthroughFilter.class.getName());
@@ -57,11 +57,10 @@ public class TestFilterValve extends TomcatBaseTest {
         tomcat.start();
 
         ByteChunk res = new ByteChunk();
-        res.setCharset(StandardCharsets.UTF_8);
         int rc = getUrl("http://localhost:"; + getPort(), res, null);
 
         Assert.assertEquals(HttpServletResponse.SC_OK, rc);
-        Assert.assertEquals("OK", res.toString());
+        Assert.assertEquals(HelloWorldServlet.RESPONSE_TEXT, res.toString());
     }
 
 
@@ -71,8 +70,8 @@ public class TestFilterValve extends TomcatBaseTest {
 
         Context ctx = getProgrammaticRootContext();
 
-        Tomcat.addServlet(ctx, "ok", new OkServlet());
-        ctx.addServletMappingDecoded("/", "ok");
+        Tomcat.addServlet(ctx, "hello", new HelloWorldServlet());
+        ctx.addServletMappingDecoded("/", "hello");
 
         FilterValve valve = new FilterValve();
         valve.setFilterClass(BlockingFilter.class.getName());
@@ -81,14 +80,33 @@ public class TestFilterValve extends TomcatBaseTest {
         tomcat.start();
 
         ByteChunk res = new ByteChunk();
-        res.setCharset(StandardCharsets.UTF_8);
         int rc = getUrl("http://localhost:"; + getPort(), res, null);
 
         Assert.assertEquals(HttpServletResponse.SC_FORBIDDEN, rc);
     }
 
-
     @Test
+    public void testFilterWrappingRequestThrows() throws Exception {
+        Tomcat tomcat = getTomcatInstance();
+
+        Context ctx = getProgrammaticRootContext();
+
+        Tomcat.addServlet(ctx, "hello", new HelloWorldServlet());
+        ctx.addServletMappingDecoded("/", "hello");
+
+        FilterValve valve = new FilterValve();
+        valve.setFilterClass(WrappingFilter.class.getName());
+        ctx.getPipeline().addValve(valve);
+
+        tomcat.start();
+
+        int rc = getUrl("http://localhost:"; + getPort(), new ByteChunk(), 
null);
+
+        Assert.assertEquals(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, rc);
+    }
+
+
+    @Test(expected = LifecycleException.class)
     public void testNullFilterClassThrowsOnStart() throws Exception {
         Tomcat tomcat = getTomcatInstance();
 
@@ -98,18 +116,11 @@ public class TestFilterValve extends TomcatBaseTest {
         // Do NOT set filterClassName
         ctx.getPipeline().addValve(valve);
 
-        boolean threw = false;
-        try {
-            tomcat.start();
-        } catch (LifecycleException e) {
-            threw = true;
-        }
-
-        Assert.assertTrue("Should throw LifecycleException for null filter 
class", threw);
+        tomcat.start();
     }
 
 
-    @Test
+    @Test(expected = LifecycleException.class)
     public void testInvalidFilterClassThrowsOnStart() throws Exception {
         Tomcat tomcat = getTomcatInstance();
 
@@ -119,26 +130,19 @@ public class TestFilterValve extends TomcatBaseTest {
         valve.setFilterClass("com.nonexistent.FakeFilter");
         ctx.getPipeline().addValve(valve);
 
-        boolean threw = false;
-        try {
-            tomcat.start();
-        } catch (LifecycleException e) {
-            threw = true;
-        }
-
-        Assert.assertTrue("Should throw LifecycleException for invalid filter 
class", threw);
+        tomcat.start();
     }
 
 
     @Test
-    public void testGetFilterNameReturnsNull() throws Exception {
+    public void testGetFilterNameReturnsNull() {
         FilterValve valve = new FilterValve();
         Assert.assertNull(valve.getFilterName());
     }
 
 
     @Test
-    public void testInitParams() throws Exception {
+    public void testInitParams() {
         FilterValve valve = new FilterValve();
 
         valve.addInitParam("key1", "value1");
@@ -148,7 +152,7 @@ public class TestFilterValve extends TomcatBaseTest {
         Assert.assertEquals("value2", valve.getInitParameter("key2"));
         Assert.assertNull(valve.getInitParameter("nonexistent"));
 
-        java.util.List<String> names = 
Collections.list(valve.getInitParameterNames());
+        List<String> names = Collections.list(valve.getInitParameterNames());
         Assert.assertEquals(2, names.size());
         Assert.assertTrue(names.contains("key1"));
         Assert.assertTrue(names.contains("key2"));
@@ -156,7 +160,7 @@ public class TestFilterValve extends TomcatBaseTest {
 
 
     @Test
-    public void testInitParamsEmpty() throws Exception {
+    public void testInitParamsEmpty() {
         FilterValve valve = new FilterValve();
 
         Assert.assertNull(valve.getInitParameter("anything"));
@@ -165,7 +169,7 @@ public class TestFilterValve extends TomcatBaseTest {
 
 
     @Test
-    public void testGetSetFilterClassName() throws Exception {
+    public void testGetSetFilterClassName() {
         FilterValve valve = new FilterValve();
 
         Assert.assertNull(valve.getFilterClassName());
@@ -173,11 +177,16 @@ public class TestFilterValve extends TomcatBaseTest {
         valve.setFilterClassName("com.example.MyFilter");
         Assert.assertEquals("com.example.MyFilter", 
valve.getFilterClassName());
 
-        // setFilterClass is an alias
         valve.setFilterClass("com.example.OtherFilter");
         Assert.assertEquals("com.example.OtherFilter", 
valve.getFilterClassName());
     }
 
+    @Test(expected = IllegalStateException.class)
+    public void testGetServletContextThrowsBeforeStart() {
+        FilterValve valve = new FilterValve();
+        valve.getServletContext();
+    }
+
 
     /**
      * A Filter that passes the request through to the next element in the 
chain.
@@ -186,35 +195,35 @@ public class TestFilterValve extends TomcatBaseTest {
 
         @Override
         public void doFilter(ServletRequest request, ServletResponse response,
-                FilterChain chain) throws IOException, ServletException {
+                             FilterChain chain) throws IOException, 
ServletException {
             chain.doFilter(request, response);
         }
     }
 
 
     /**
-     * A Filter that blocks the request by sending a 403 response without
-     * calling chain.doFilter().
+     * A Filter that blocks the request by sending a 403 response without 
calling chain.doFilter().
      */
     public static final class BlockingFilter implements Filter {
 
         @Override
         public void doFilter(ServletRequest request, ServletResponse response,
-                FilterChain chain) throws IOException, ServletException {
+                             FilterChain chain) throws IOException, 
ServletException {
             ((HttpServletResponse) 
response).sendError(HttpServletResponse.SC_FORBIDDEN);
         }
     }
 
-
-    private static final class OkServlet extends HttpServlet {
-
-        private static final long serialVersionUID = 1L;
+    /**
+     * A Filter that wraps the request before calling chain.doFilter(), which 
FilterValve explicitly forbids.
+     */
+    public static final class WrappingFilter implements Filter {
 
         @Override
-        protected void doGet(HttpServletRequest req, HttpServletResponse resp)
-                throws ServletException, IOException {
-            resp.setContentType("text/plain");
-            resp.getWriter().print("OK");
+        public void doFilter(ServletRequest request, ServletResponse response,
+                             FilterChain chain) throws IOException, 
ServletException {
+            HttpServletRequestWrapper wrapped = new 
HttpServletRequestWrapper((HttpServletRequest) request);
+            chain.doFilter(wrapped, response);
         }
     }
+
 }
diff --git a/test/org/apache/catalina/valves/TestProxyErrorReportValve.java b/test/org/apache/catalina/valves/TestProxyErrorReportValve.java
index 98a92fe84f..8829fa2d63 100644
--- a/test/org/apache/catalina/valves/TestProxyErrorReportValve.java
+++ b/test/org/apache/catalina/valves/TestProxyErrorReportValve.java
@@ -17,12 +17,8 @@
 package org.apache.catalina.valves;
 
 import java.io.IOException;
-import java.nio.charset.StandardCharsets;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
+import java.io.Serial;
 
-import jakarta.servlet.ServletException;
 import jakarta.servlet.http.HttpServlet;
 import jakarta.servlet.http.HttpServletRequest;
 import jakarta.servlet.http.HttpServletResponse;
@@ -35,7 +31,6 @@ import org.apache.catalina.core.StandardHost;
 import org.apache.catalina.startup.Tomcat;
 import org.apache.catalina.startup.TomcatBaseTest;
 import org.apache.tomcat.util.buf.ByteChunk;
-import org.apache.tomcat.util.descriptor.web.ErrorPage;
 
 public class TestProxyErrorReportValve extends TomcatBaseTest {
 
@@ -46,7 +41,8 @@ public class TestProxyErrorReportValve extends TomcatBaseTest {
     @Test
     public void testRedirectMode() throws Exception {
         Tomcat tomcat = getTomcatInstance();
-        ((StandardHost) 
tomcat.getHost()).setErrorReportValveClass(PROXY_VALVE);
+        StandardHost host = (StandardHost) tomcat.getHost();
+        host.setErrorReportValveClass(PROXY_VALVE);
 
         Context ctx = getProgrammaticRootContext();
 
@@ -54,28 +50,49 @@ public class TestProxyErrorReportValve extends TomcatBaseTest {
                 HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Server broke"));
         ctx.addServletMappingDecoded("/", "error");
 
-        // Register an error page that the valve will redirect to
+        // Register an error page at the Host's error report valve level
+        // so findErrorPage() returns a URL for the redirect
+        Tomcat.addServlet(ctx, "errorPage", new ErrorPageServlet());
+        ctx.addServletMappingDecoded("/error-page", "errorPage");
+
+        tomcat.start();
+
+        ProxyErrorReportValve valve = (ProxyErrorReportValve) 
host.getPipeline().getFirst();
+        valve.setProperty("errorCode." + 
HttpServletResponse.SC_INTERNAL_SERVER_ERROR,
+                "http://localhost:"; + getPort() + "/error-page");
+
+        int rc = getUrl("http://localhost:"; + getPort(), new ByteChunk(), 
false);
+
+        Assert.assertEquals(HttpServletResponse.SC_FOUND, rc);
+    }
+
+    @Test
+    public void testProxyMode() throws Exception {
+        Tomcat tomcat = getTomcatInstance();
+        StandardHost host = (StandardHost) tomcat.getHost();
+        host.setErrorReportValveClass(PROXY_VALVE);
+
+        Context ctx = getProgrammaticRootContext();
+
+        Tomcat.addServlet(ctx, "error", new SendErrorServlet(
+                HttpServletResponse.SC_NOT_FOUND, "Not found"));
+        ctx.addServletMappingDecoded("/", "error");
+
         Tomcat.addServlet(ctx, "errorPage", new ErrorPageServlet());
         ctx.addServletMappingDecoded("/error-page", "errorPage");
-        ErrorPage errorPage = new ErrorPage();
-        errorPage.setErrorCode(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
-        errorPage.setLocation("/error-page");
-        ctx.addErrorPage(errorPage);
 
         tomcat.start();
 
+        ProxyErrorReportValve valve = (ProxyErrorReportValve) 
host.getPipeline().getFirst();
+        valve.setUseRedirect(false);
+        valve.setProperty("errorCode." + HttpServletResponse.SC_NOT_FOUND,
+                "http://localhost:"; + getPort() + "/error-page");
+
         ByteChunk res = new ByteChunk();
-        res.setCharset(StandardCharsets.UTF_8);
-        Map<String, List<String>> resHead = new HashMap<>();
-        // Don't follow redirects
-        int rc = getUrl("http://localhost:"; + getPort(), res, resHead);
-
-        // ProxyErrorReportValve uses error pages from context — but since
-        // it calls findErrorPage() which uses Host-level error pages,
-        // the context error page might not be found and it falls back to
-        // the superclass. The test verifies the valve is loaded correctly.
-        Assert.assertTrue("Status should indicate an error",
-                rc >= 400 || rc == 302);
+        int rc = getUrl("http://localhost:"; + getPort(), res, null);
+
+        Assert.assertEquals(HttpServletResponse.SC_NOT_FOUND, rc);
+        Assert.assertTrue(res.toString().contains("ERROR_PAGE_OK"));
     }
 
 
@@ -90,20 +107,18 @@ public class TestProxyErrorReportValve extends TomcatBaseTest {
                 HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "No page 
configured"));
         ctx.addServletMappingDecoded("/", "error");
 
-        // No error page configured — should fall back to ErrorReportValve's 
report()
         tomcat.start();
 
         ByteChunk res = new ByteChunk();
-        res.setCharset(StandardCharsets.UTF_8);
         int rc = getUrl("http://localhost:"; + getPort(), res, null);
 
         Assert.assertEquals(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, rc);
 
         String body = res.toString();
         Assert.assertNotNull(body);
-        // The default ErrorReportValve produces HTML
         Assert.assertTrue("Should contain HTML error report",
-                body.contains("<html>") || body.contains("<h1>"));
+                body.contains("html") &&
+                    
body.contains(String.valueOf(HttpServletResponse.SC_INTERNAL_SERVER_ERROR)));
     }
 
 
@@ -114,17 +129,16 @@ public class TestProxyErrorReportValve extends TomcatBaseTest {
 
         Context ctx = getProgrammaticRootContext();
 
-        Tomcat.addServlet(ctx, "ok", new OkServlet());
-        ctx.addServletMappingDecoded("/", "ok");
+        Tomcat.addServlet(ctx, "hello", new HelloWorldServlet());
+        ctx.addServletMappingDecoded("/", "hello");
 
         tomcat.start();
 
         ByteChunk res = new ByteChunk();
-        res.setCharset(StandardCharsets.UTF_8);
         int rc = getUrl("http://localhost:"; + getPort(), res, null);
 
         Assert.assertEquals(HttpServletResponse.SC_OK, rc);
-        Assert.assertEquals("OK", res.toString());
+        Assert.assertEquals(HelloWorldServlet.RESPONSE_TEXT, res.toString());
     }
 
 
@@ -142,28 +156,24 @@ public class TestProxyErrorReportValve extends TomcatBaseTest {
         tomcat.start();
 
         ByteChunk res = new ByteChunk();
-        res.setCharset(StandardCharsets.UTF_8);
         int rc = getUrl("http://localhost:"; + getPort(), res, null);
 
         Assert.assertEquals(HttpServletResponse.SC_NOT_FOUND, rc);
 
         String body = res.toString();
         Assert.assertNotNull(body);
-        // Falls back to parent ErrorReportValve HTML
         Assert.assertTrue("Should contain error report",
-                body.contains("404") || body.contains("Not Found"));
+                
body.contains(String.valueOf(HttpServletResponse.SC_NOT_FOUND)));
     }
 
 
     @Test
-    public void testGetSetProperties() throws Exception {
+    public void testGetSetProperties() {
         ProxyErrorReportValve valve = new ProxyErrorReportValve();
 
-        // Defaults
         Assert.assertTrue(valve.getUseRedirect());
         Assert.assertFalse(valve.getUsePropertiesFile());
 
-        // Setters
         valve.setUseRedirect(false);
         Assert.assertFalse(valve.getUseRedirect());
 
@@ -174,19 +184,19 @@ public class TestProxyErrorReportValve extends TomcatBaseTest {
 
     @Test
     public void testMessageInErrorReport() throws Exception {
+        final String customErrorMessage = "Custom error message";
         Tomcat tomcat = getTomcatInstance();
         ((StandardHost) 
tomcat.getHost()).setErrorReportValveClass(PROXY_VALVE);
 
         Context ctx = getProgrammaticRootContext();
 
         Tomcat.addServlet(ctx, "error", new SendErrorServlet(
-                HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Custom error 
message"));
+                HttpServletResponse.SC_INTERNAL_SERVER_ERROR, 
customErrorMessage));
         ctx.addServletMappingDecoded("/", "error");
 
         tomcat.start();
 
         ByteChunk res = new ByteChunk();
-        res.setCharset(StandardCharsets.UTF_8);
         int rc = getUrl("http://localhost:"; + getPort(), res, null);
 
         Assert.assertEquals(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, rc);
@@ -194,8 +204,7 @@ public class TestProxyErrorReportValve extends TomcatBaseTest {
         String body = res.toString();
         Assert.assertNotNull(body);
         // Falls back to super.report() which includes the message
-        Assert.assertTrue("Should contain the custom error message",
-                body.contains("Custom error message"));
+        Assert.assertTrue(body.contains(customErrorMessage));
     }
 
 
@@ -212,21 +221,21 @@ public class TestProxyErrorReportValve extends TomcatBaseTest {
         tomcat.start();
 
         ByteChunk res = new ByteChunk();
-        res.setCharset(StandardCharsets.UTF_8);
         int rc = getUrl("http://localhost:"; + getPort(), res, null);
 
         Assert.assertEquals(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, rc);
 
         String body = res.toString();
         Assert.assertNotNull(body);
-        Assert.assertTrue("Should contain exception info",
-                body.contains("RuntimeException"));
+        Assert.assertTrue(body.contains("RuntimeException"));
     }
 
 
     private static final class SendErrorServlet extends HttpServlet {
 
+        @Serial
         private static final long serialVersionUID = 1L;
+
         private final int statusCode;
         private final String message;
 
@@ -237,33 +246,19 @@ public class TestProxyErrorReportValve extends TomcatBaseTest {
 
         @Override
         protected void doGet(HttpServletRequest req, HttpServletResponse resp)
-                throws ServletException, IOException {
+                throws IOException {
             resp.sendError(statusCode, message);
         }
     }
 
-
-    private static final class OkServlet extends HttpServlet {
-
-        private static final long serialVersionUID = 1L;
-
-        @Override
-        protected void doGet(HttpServletRequest req, HttpServletResponse resp)
-                throws ServletException, IOException {
-            resp.setContentType("text/plain");
-            resp.getWriter().print("OK");
-        }
-    }
-
-
     private static final class ErrorPageServlet extends HttpServlet {
 
+        @Serial
         private static final long serialVersionUID = 1L;
 
         @Override
         protected void doGet(HttpServletRequest req, HttpServletResponse resp)
-                throws ServletException, IOException {
-            resp.setContentType("text/plain");
+                throws IOException {
             resp.getWriter().print("ERROR_PAGE_OK");
         }
     }
@@ -271,11 +266,11 @@ public class TestProxyErrorReportValve extends TomcatBaseTest {
 
     private static final class ExceptionServlet extends HttpServlet {
 
+        @Serial
         private static final long serialVersionUID = 1L;
 
         @Override
-        public void service(jakarta.servlet.ServletRequest request,
-                jakarta.servlet.ServletResponse response) throws IOException {
+        protected void doGet(HttpServletRequest req, HttpServletResponse resp) 
{
             throw new RuntimeException("Test exception");
         }
     }
diff --git a/test/org/apache/catalina/valves/TestSemaphoreValve.java b/test/org/apache/catalina/valves/TestSemaphoreValve.java
index eefe4d4acb..f5a619e6d4 100644
--- a/test/org/apache/catalina/valves/TestSemaphoreValve.java
+++ b/test/org/apache/catalina/valves/TestSemaphoreValve.java
@@ -17,12 +17,14 @@
 package org.apache.catalina.valves;
 
 import java.io.IOException;
-import java.nio.charset.StandardCharsets;
+import java.io.Serial;
 import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.Semaphore;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
 
-import jakarta.servlet.ServletException;
 import jakarta.servlet.http.HttpServlet;
 import jakarta.servlet.http.HttpServletRequest;
 import jakarta.servlet.http.HttpServletResponse;
@@ -31,6 +33,8 @@ import org.junit.Assert;
 import org.junit.Test;
 
 import org.apache.catalina.Context;
+import org.apache.catalina.connector.Request;
+import org.apache.catalina.connector.Response;
 import org.apache.catalina.startup.Tomcat;
 import org.apache.catalina.startup.TomcatBaseTest;
 import org.apache.tomcat.util.buf.ByteChunk;
@@ -44,8 +48,8 @@ public class TestSemaphoreValve extends TomcatBaseTest {
 
         Context ctx = getProgrammaticRootContext();
 
-        Tomcat.addServlet(ctx, "ok", new OkServlet());
-        ctx.addServletMappingDecoded("/", "ok");
+        Tomcat.addServlet(ctx, "hello", new HelloWorldServlet());
+        ctx.addServletMappingDecoded("/", "hello");
 
         SemaphoreValve valve = new SemaphoreValve();
         valve.setConcurrency(10);
@@ -54,11 +58,33 @@ public class TestSemaphoreValve extends TomcatBaseTest {
         tomcat.start();
 
         ByteChunk res = new ByteChunk();
-        res.setCharset(StandardCharsets.UTF_8);
         int rc = getUrl("http://localhost:"; + getPort(), res, null);
 
         Assert.assertEquals(HttpServletResponse.SC_OK, rc);
-        Assert.assertEquals("OK", res.toString());
+        Assert.assertEquals(HelloWorldServlet.RESPONSE_TEXT, res.toString());
+    }
+
+    @Test
+    public void testInterruptedConcurrency() throws Exception {
+        Tomcat tomcat = getTomcatInstance();
+
+        Context ctx = getProgrammaticRootContext();
+
+        Tomcat.addServlet(ctx, "hello", new HelloWorldServlet());
+        ctx.addServletMappingDecoded("/", "hello");
+
+        SemaphoreValve valve = new SemaphoreValve();
+        valve.setConcurrency(10);
+        valve.setInterruptible(true);
+        ctx.getPipeline().addValve(valve);
+
+        tomcat.start();
+
+        ByteChunk res = new ByteChunk();
+        int rc = getUrl("http://localhost:"; + getPort(), res, null);
+
+        Assert.assertEquals(HttpServletResponse.SC_OK, rc);
+        Assert.assertEquals(HelloWorldServlet.RESPONSE_TEXT, res.toString());
     }
 
 
@@ -76,7 +102,7 @@ public class TestSemaphoreValve extends TomcatBaseTest {
         SemaphoreValve valve = new SemaphoreValve();
         valve.setConcurrency(1);
         valve.setBlock(false);
-        valve.setHighConcurrencyStatus(503);
+        
valve.setHighConcurrencyStatus(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
         ctx.getPipeline().addValve(valve);
 
         tomcat.start();
@@ -85,9 +111,7 @@ public class TestSemaphoreValve extends TomcatBaseTest {
         AtomicInteger firstRc = new AtomicInteger();
         Thread firstThread = new Thread(() -> {
             try {
-                ByteChunk r = new ByteChunk();
-                r.setCharset(StandardCharsets.UTF_8);
-                firstRc.set(getUrl("http://localhost:"; + getPort(), r, null));
+                firstRc.set(getUrl("http://localhost:"; + getPort(), new 
ByteChunk(), null));
             } catch (IOException e) {
                 // Ignore
             }
@@ -99,9 +123,7 @@ public class TestSemaphoreValve extends TomcatBaseTest {
                 insideServlet.await(10, TimeUnit.SECONDS));
 
         // Second request — should be denied because concurrency=1 and 
block=false
-        ByteChunk res2 = new ByteChunk();
-        res2.setCharset(StandardCharsets.UTF_8);
-        int rc2 = getUrl("http://localhost:"; + getPort(), res2, null);
+        int rc2 = getUrl("http://localhost:"; + getPort(), new ByteChunk(), 
null);
 
         Assert.assertEquals(HttpServletResponse.SC_SERVICE_UNAVAILABLE, rc2);
 
@@ -135,8 +157,7 @@ public class TestSemaphoreValve extends TomcatBaseTest {
         // First request holds the permit
         Thread firstThread = new Thread(() -> {
             try {
-                ByteChunk r = new ByteChunk();
-                getUrl("http://localhost:"; + getPort(), r, null);
+                getUrl("http://localhost:"; + getPort(), new ByteChunk(), null);
             } catch (IOException e) {
                 // Ignore
             }
@@ -147,10 +168,9 @@ public class TestSemaphoreValve extends TomcatBaseTest {
                 insideServlet.await(10, TimeUnit.SECONDS));
 
         // Second request — denied but no error status is sent
-        ByteChunk res2 = new ByteChunk();
-        int rc2 = getUrl("http://localhost:"; + getPort(), res2, null);
+        int rc2 = getUrl("http://localhost:"; + getPort(), new ByteChunk(), 
null);
 
-        // With no highConcurrencyStatus, response is 200 with no body
+        // With no highConcurrencyStatus, response is 200 without body
         Assert.assertEquals(HttpServletResponse.SC_OK, rc2);
 
         canReturn.countDown();
@@ -159,7 +179,7 @@ public class TestSemaphoreValve extends TomcatBaseTest {
 
 
     @Test
-    public void testGetSetProperties() throws Exception {
+    public void testGetSetProperties() {
         SemaphoreValve valve = new SemaphoreValve();
 
         // Defaults
@@ -193,8 +213,8 @@ public class TestSemaphoreValve extends TomcatBaseTest {
 
         Context ctx = getProgrammaticRootContext();
 
-        Tomcat.addServlet(ctx, "ok", new OkServlet());
-        ctx.addServletMappingDecoded("/", "ok");
+        Tomcat.addServlet(ctx, "hello", new HelloWorldServlet());
+        ctx.addServletMappingDecoded("/", "hello");
 
         SemaphoreValve valve = new SemaphoreValve();
         valve.setConcurrency(5);
@@ -203,30 +223,178 @@ public class TestSemaphoreValve extends TomcatBaseTest {
 
         tomcat.start();
 
+        Assert.assertNotNull(valve.semaphore);
+        Assert.assertTrue(valve.semaphore.isFair());
+        Assert.assertEquals(5, valve.semaphore.availablePermits());
+
         ByteChunk res = new ByteChunk();
-        res.setCharset(StandardCharsets.UTF_8);
         int rc = getUrl("http://localhost:"; + getPort(), res, null);
 
         Assert.assertEquals(HttpServletResponse.SC_OK, rc);
-        Assert.assertEquals("OK", res.toString());
+        Assert.assertEquals(HelloWorldServlet.RESPONSE_TEXT, res.toString());
     }
 
+    @Test
+    public void testBlockingWaitsForPermit() throws Exception {
+        Tomcat tomcat = getTomcatInstance();
 
-    private static final class OkServlet extends HttpServlet {
+        Context ctx = getProgrammaticRootContext();
 
-        private static final long serialVersionUID = 1L;
+        CountDownLatch insideServlet = new CountDownLatch(1);
+        CountDownLatch canReturn = new CountDownLatch(1);
+        Tomcat.addServlet(ctx, "slow", new SlowServlet(insideServlet, 
canReturn));
+        ctx.addServletMappingDecoded("/", "slow");
 
-        @Override
-        protected void doGet(HttpServletRequest req, HttpServletResponse resp)
-                throws ServletException, IOException {
-            resp.setContentType("text/plain");
-            resp.getWriter().print("OK");
-        }
+        SemaphoreValve valve = new SemaphoreValve();
+        valve.setConcurrency(1);
+        valve.setBlock(true);
+        ctx.getPipeline().addValve(valve);
+
+        tomcat.start();
+
+        AtomicReference<Throwable> firstError = new AtomicReference<>();
+        Thread firstThread = new Thread(() -> {
+            try {
+                getUrl("http://localhost:"; + getPort(), new ByteChunk(), null);
+            } catch (IOException e) {
+                firstError.set(e);
+            }
+        });
+        firstThread.start();
+
+        Assert.assertTrue("First request should reach servlet",
+                insideServlet.await(10, TimeUnit.SECONDS));
+
+        AtomicInteger secondRc = new AtomicInteger();
+        AtomicReference<Throwable> secondError = new AtomicReference<>();
+        Thread secondThread = new Thread(() -> {
+            try {
+                secondRc.set(getUrl("http://localhost:"; + getPort(), new 
ByteChunk(), null));
+            } catch (IOException e) {
+                secondError.set(e);
+            }
+        });
+        secondThread.start();
+
+        // Give the second request time to arrive and block on the semaphore
+        Thread.sleep(500);
+
+        Assert.assertTrue("Second request should be blocked waiting for 
permit", secondThread.isAlive());
+
+        canReturn.countDown();
+        firstThread.join(10000);
+        Assert.assertNull(firstError.get());
+
+        secondThread.join(10000);
+        Assert.assertFalse(secondThread.isAlive());
+        Assert.assertNull(secondError.get());
+        Assert.assertEquals(HttpServletResponse.SC_OK, secondRc.get());
+    }
+
+    @Test
+    public void testControlConcurrencyBypass() throws Exception {
+        Tomcat tomcat = getTomcatInstance();
+
+        Context ctx = getProgrammaticRootContext();
+
+        CountDownLatch insideServlet = new CountDownLatch(1);
+        CountDownLatch canReturn = new CountDownLatch(1);
+        Tomcat.addServlet(ctx, "slow", new SlowServlet(insideServlet, 
canReturn));
+        ctx.addServletMappingDecoded("/slow", "slow");
+
+        Tomcat.addServlet(ctx, "hello", new HelloWorldServlet());
+        ctx.addServletMappingDecoded("/bypass", "hello");
+
+        SemaphoreValve valve = new SemaphoreValve() {
+            @Override
+            public boolean controlConcurrency(Request request, Response 
response) {
+                return !request.getDecodedRequestURI().equals("/bypass");
+            }
+        };
+        valve.setConcurrency(1);
+        valve.setBlock(false);
+        
valve.setHighConcurrencyStatus(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
+        ctx.getPipeline().addValve(valve);
+
+        tomcat.start();
+
+        Thread firstThread = new Thread(() -> {
+            try {
+                getUrl("http://localhost:"; + getPort() + "/slow", new 
ByteChunk(), null);
+            } catch (IOException e) {
+                // Ignored
+            }
+        });
+        firstThread.start();
+
+        Assert.assertTrue("First request should reach servlet",
+                insideServlet.await(10, TimeUnit.SECONDS));
+
+        // Request to /bypass should succeed despite all permits being held,
+        // because controlConcurrency() returns false for this path
+        int bypassRc = getUrl("http://localhost:"; + getPort() + "/bypass", new 
ByteChunk(), null);
+        Assert.assertEquals(HttpServletResponse.SC_OK, bypassRc);
+
+        int deniedRc = getUrl("http://localhost:"; + getPort() + "/slow", new 
ByteChunk(), null);
+        Assert.assertEquals(HttpServletResponse.SC_SERVICE_UNAVAILABLE, 
deniedRc);
+
+        canReturn.countDown();
+        firstThread.join(10000);
     }
 
+    @Test
+    public void testInterruptibleDenied() throws Exception {
+        SemaphoreValve semaphoreValve = new SemaphoreValve();
+        semaphoreValve.setConcurrency(1);
+        semaphoreValve.setBlock(true);
+        semaphoreValve.setInterruptible(true);
+        
semaphoreValve.setHighConcurrencyStatus(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
+
+        semaphoreValve.semaphore = new Semaphore(1, false);
+
+        AtomicBoolean nextInvoked = new AtomicBoolean(false);
+        semaphoreValve.setNext(new ValveBase() {
+            @Override
+            public void invoke(Request request, Response response) {
+                nextInvoked.set(true);
+            }
+        });
+
+        MockResponse response = new MockResponse();
+
+        semaphoreValve.semaphore.acquire();
+
+        // On a new thread, valve will block on semaphore.acquire() because 
the permit is already held.
+        CountDownLatch invokeStarted = new CountDownLatch(1);
+        Thread blocked = new Thread(() -> {
+            invokeStarted.countDown();
+            try {
+                semaphoreValve.invoke(null, response);
+            } catch (Throwable t) {
+                // Ignored
+            }
+        });
+        blocked.start();
+
+        Assert.assertTrue(invokeStarted.await(10, TimeUnit.SECONDS));
+        Thread.sleep(200);
+
+        blocked.interrupt();
+        blocked.join(10000);
+        Assert.assertFalse(blocked.isAlive());
+
+        Assert.assertEquals(HttpServletResponse.SC_SERVICE_UNAVAILABLE, 
response.getStatus());
+
+        Assert.assertFalse("Next valve should not be invoked when permit 
denied", nextInvoked.get());
+
+        Assert.assertEquals(0, semaphoreValve.semaphore.availablePermits());
+
+        semaphoreValve.semaphore.release();
+    }
 
     private static final class SlowServlet extends HttpServlet {
 
+        @Serial
         private static final long serialVersionUID = 1L;
         private final CountDownLatch insideServlet;
         private final CountDownLatch canReturn;
@@ -238,10 +406,10 @@ public class TestSemaphoreValve extends TomcatBaseTest {
 
         @Override
         protected void doGet(HttpServletRequest req, HttpServletResponse resp)
-                throws ServletException, IOException {
+                throws IOException {
             insideServlet.countDown();
             try {
-                canReturn.await(30, TimeUnit.SECONDS);
+                Assert.assertTrue(canReturn.await(30, TimeUnit.SECONDS));
             } catch (InterruptedException e) {
                 // Ignore
             }
@@ -249,4 +417,24 @@ public class TestSemaphoreValve extends TomcatBaseTest {
             resp.getWriter().print("OK");
         }
     }
+
+    public static class MockResponse extends Response {
+
+        public MockResponse() {
+            super(null);
+        }
+
+        private int status = HttpServletResponse.SC_OK;
+
+        @Override
+        public void sendError(int status) throws IOException {
+            this.status = status;
+        }
+
+        @Override
+        public int getStatus() {
+            return status;
+        }
+    }
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to