This is an automated email from the ASF dual-hosted git repository. dsoumis pushed a commit to branch 11.0.x in repository https://gitbox.apache.org/repos/asf/tomcat.git
commit 48343a64f24258be7e273d49849ee866463be67c Author: Dimitris Soumis <[email protected]> AuthorDate: Wed Apr 8 14:52:53 2026 +0300 Add more tests and minor fixes for FilterValve, ProxyErrorReportValve and SemaphoreValve --- .../apache/catalina/valves/TestFilterValve.java | 99 ++++---- .../catalina/valves/TestProxyErrorReportValve.java | 121 +++++----- .../apache/catalina/valves/TestSemaphoreValve.java | 254 ++++++++++++++++++--- 3 files changed, 333 insertions(+), 141 deletions(-) diff --git a/test/org/apache/catalina/valves/TestFilterValve.java b/test/org/apache/catalina/valves/TestFilterValve.java index cfc3a5abb7..dd2d918c5c 100644 --- a/test/org/apache/catalina/valves/TestFilterValve.java +++ b/test/org/apache/catalina/valves/TestFilterValve.java @@ -17,16 +17,16 @@ package org.apache.catalina.valves; import java.io.IOException; -import java.nio.charset.StandardCharsets; import java.util.Collections; +import java.util.List; import jakarta.servlet.Filter; import jakarta.servlet.FilterChain; import jakarta.servlet.ServletException; import jakarta.servlet.ServletRequest; import jakarta.servlet.ServletResponse; -import jakarta.servlet.http.HttpServlet; import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletRequestWrapper; import jakarta.servlet.http.HttpServletResponse; import org.junit.Assert; @@ -47,8 +47,8 @@ public class TestFilterValve extends TomcatBaseTest { Context ctx = getProgrammaticRootContext(); - Tomcat.addServlet(ctx, "ok", new OkServlet()); - ctx.addServletMappingDecoded("/", "ok"); + Tomcat.addServlet(ctx, "hello", new HelloWorldServlet()); + ctx.addServletMappingDecoded("/", "hello"); FilterValve valve = new FilterValve(); valve.setFilterClass(PassthroughFilter.class.getName()); @@ -57,11 +57,10 @@ public class TestFilterValve extends TomcatBaseTest { tomcat.start(); ByteChunk res = new ByteChunk(); - res.setCharset(StandardCharsets.UTF_8); int rc = getUrl("http://localhost:" + getPort(), res, null); 
Assert.assertEquals(HttpServletResponse.SC_OK, rc); - Assert.assertEquals("OK", res.toString()); + Assert.assertEquals(HelloWorldServlet.RESPONSE_TEXT, res.toString()); } @@ -71,8 +70,8 @@ public class TestFilterValve extends TomcatBaseTest { Context ctx = getProgrammaticRootContext(); - Tomcat.addServlet(ctx, "ok", new OkServlet()); - ctx.addServletMappingDecoded("/", "ok"); + Tomcat.addServlet(ctx, "hello", new HelloWorldServlet()); + ctx.addServletMappingDecoded("/", "hello"); FilterValve valve = new FilterValve(); valve.setFilterClass(BlockingFilter.class.getName()); @@ -81,14 +80,33 @@ public class TestFilterValve extends TomcatBaseTest { tomcat.start(); ByteChunk res = new ByteChunk(); - res.setCharset(StandardCharsets.UTF_8); int rc = getUrl("http://localhost:" + getPort(), res, null); Assert.assertEquals(HttpServletResponse.SC_FORBIDDEN, rc); } - @Test + public void testFilterWrappingRequestThrows() throws Exception { + Tomcat tomcat = getTomcatInstance(); + + Context ctx = getProgrammaticRootContext(); + + Tomcat.addServlet(ctx, "hello", new HelloWorldServlet()); + ctx.addServletMappingDecoded("/", "hello"); + + FilterValve valve = new FilterValve(); + valve.setFilterClass(WrappingFilter.class.getName()); + ctx.getPipeline().addValve(valve); + + tomcat.start(); + + int rc = getUrl("http://localhost:" + getPort(), new ByteChunk(), null); + + Assert.assertEquals(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, rc); + } + + + @Test(expected = LifecycleException.class) public void testNullFilterClassThrowsOnStart() throws Exception { Tomcat tomcat = getTomcatInstance(); @@ -98,18 +116,11 @@ public class TestFilterValve extends TomcatBaseTest { // Do NOT set filterClassName ctx.getPipeline().addValve(valve); - boolean threw = false; - try { - tomcat.start(); - } catch (LifecycleException e) { - threw = true; - } - - Assert.assertTrue("Should throw LifecycleException for null filter class", threw); + tomcat.start(); } - @Test + @Test(expected = 
LifecycleException.class) public void testInvalidFilterClassThrowsOnStart() throws Exception { Tomcat tomcat = getTomcatInstance(); @@ -119,26 +130,19 @@ public class TestFilterValve extends TomcatBaseTest { valve.setFilterClass("com.nonexistent.FakeFilter"); ctx.getPipeline().addValve(valve); - boolean threw = false; - try { - tomcat.start(); - } catch (LifecycleException e) { - threw = true; - } - - Assert.assertTrue("Should throw LifecycleException for invalid filter class", threw); + tomcat.start(); } @Test - public void testGetFilterNameReturnsNull() throws Exception { + public void testGetFilterNameReturnsNull() { FilterValve valve = new FilterValve(); Assert.assertNull(valve.getFilterName()); } @Test - public void testInitParams() throws Exception { + public void testInitParams() { FilterValve valve = new FilterValve(); valve.addInitParam("key1", "value1"); @@ -148,7 +152,7 @@ public class TestFilterValve extends TomcatBaseTest { Assert.assertEquals("value2", valve.getInitParameter("key2")); Assert.assertNull(valve.getInitParameter("nonexistent")); - java.util.List<String> names = Collections.list(valve.getInitParameterNames()); + List<String> names = Collections.list(valve.getInitParameterNames()); Assert.assertEquals(2, names.size()); Assert.assertTrue(names.contains("key1")); Assert.assertTrue(names.contains("key2")); @@ -156,7 +160,7 @@ public class TestFilterValve extends TomcatBaseTest { @Test - public void testInitParamsEmpty() throws Exception { + public void testInitParamsEmpty() { FilterValve valve = new FilterValve(); Assert.assertNull(valve.getInitParameter("anything")); @@ -165,7 +169,7 @@ public class TestFilterValve extends TomcatBaseTest { @Test - public void testGetSetFilterClassName() throws Exception { + public void testGetSetFilterClassName() { FilterValve valve = new FilterValve(); Assert.assertNull(valve.getFilterClassName()); @@ -173,11 +177,16 @@ public class TestFilterValve extends TomcatBaseTest { 
valve.setFilterClassName("com.example.MyFilter"); Assert.assertEquals("com.example.MyFilter", valve.getFilterClassName()); - // setFilterClass is an alias valve.setFilterClass("com.example.OtherFilter"); Assert.assertEquals("com.example.OtherFilter", valve.getFilterClassName()); } + @Test(expected = IllegalStateException.class) + public void testGetServletContextThrowsBeforeStart() { + FilterValve valve = new FilterValve(); + valve.getServletContext(); + } + /** * A Filter that passes the request through to the next element in the chain. @@ -186,35 +195,35 @@ public class TestFilterValve extends TomcatBaseTest { @Override public void doFilter(ServletRequest request, ServletResponse response, - FilterChain chain) throws IOException, ServletException { + FilterChain chain) throws IOException, ServletException { chain.doFilter(request, response); } } /** - * A Filter that blocks the request by sending a 403 response without - * calling chain.doFilter(). + * A Filter that blocks the request by sending a 403 response without calling chain.doFilter(). */ public static final class BlockingFilter implements Filter { @Override public void doFilter(ServletRequest request, ServletResponse response, - FilterChain chain) throws IOException, ServletException { + FilterChain chain) throws IOException, ServletException { ((HttpServletResponse) response).sendError(HttpServletResponse.SC_FORBIDDEN); } } - - private static final class OkServlet extends HttpServlet { - - private static final long serialVersionUID = 1L; + /** + * A Filter that wraps the request before calling chain.doFilter(), which FilterValve explicitly forbids. 
+ */ + public static final class WrappingFilter implements Filter { @Override - protected void doGet(HttpServletRequest req, HttpServletResponse resp) - throws ServletException, IOException { - resp.setContentType("text/plain"); - resp.getWriter().print("OK"); + public void doFilter(ServletRequest request, ServletResponse response, + FilterChain chain) throws IOException, ServletException { + HttpServletRequestWrapper wrapped = new HttpServletRequestWrapper((HttpServletRequest) request); + chain.doFilter(wrapped, response); } } + } diff --git a/test/org/apache/catalina/valves/TestProxyErrorReportValve.java b/test/org/apache/catalina/valves/TestProxyErrorReportValve.java index 98a92fe84f..8829fa2d63 100644 --- a/test/org/apache/catalina/valves/TestProxyErrorReportValve.java +++ b/test/org/apache/catalina/valves/TestProxyErrorReportValve.java @@ -17,12 +17,8 @@ package org.apache.catalina.valves; import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.io.Serial; -import jakarta.servlet.ServletException; import jakarta.servlet.http.HttpServlet; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; @@ -35,7 +31,6 @@ import org.apache.catalina.core.StandardHost; import org.apache.catalina.startup.Tomcat; import org.apache.catalina.startup.TomcatBaseTest; import org.apache.tomcat.util.buf.ByteChunk; -import org.apache.tomcat.util.descriptor.web.ErrorPage; public class TestProxyErrorReportValve extends TomcatBaseTest { @@ -46,7 +41,8 @@ public class TestProxyErrorReportValve extends TomcatBaseTest { @Test public void testRedirectMode() throws Exception { Tomcat tomcat = getTomcatInstance(); - ((StandardHost) tomcat.getHost()).setErrorReportValveClass(PROXY_VALVE); + StandardHost host = (StandardHost) tomcat.getHost(); + host.setErrorReportValveClass(PROXY_VALVE); Context ctx = getProgrammaticRootContext(); @@ -54,28 +50,49 @@ 
public class TestProxyErrorReportValve extends TomcatBaseTest { HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Server broke")); ctx.addServletMappingDecoded("/", "error"); - // Register an error page that the valve will redirect to + // Map the servlet that serves the redirect target; the valve's errorCode.500 + // property (set after start) makes findErrorPage() return this URL + Tomcat.addServlet(ctx, "errorPage", new ErrorPageServlet()); + ctx.addServletMappingDecoded("/error-page", "errorPage"); + + tomcat.start(); + + ProxyErrorReportValve valve = (ProxyErrorReportValve) host.getPipeline().getFirst(); + valve.setProperty("errorCode." + HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + "http://localhost:" + getPort() + "/error-page"); + + int rc = getUrl("http://localhost:" + getPort(), new ByteChunk(), false); + + Assert.assertEquals(HttpServletResponse.SC_FOUND, rc); + } + + @Test + public void testProxyMode() throws Exception { + Tomcat tomcat = getTomcatInstance(); + StandardHost host = (StandardHost) tomcat.getHost(); + host.setErrorReportValveClass(PROXY_VALVE); + + Context ctx = getProgrammaticRootContext(); + + Tomcat.addServlet(ctx, "error", new SendErrorServlet( + HttpServletResponse.SC_NOT_FOUND, "Not found")); + ctx.addServletMappingDecoded("/", "error"); + Tomcat.addServlet(ctx, "errorPage", new ErrorPageServlet()); ctx.addServletMappingDecoded("/error-page", "errorPage"); - ErrorPage errorPage = new ErrorPage(); - errorPage.setErrorCode(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); - errorPage.setLocation("/error-page"); - ctx.addErrorPage(errorPage); tomcat.start(); + ProxyErrorReportValve valve = (ProxyErrorReportValve) host.getPipeline().getFirst(); + valve.setUseRedirect(false); + valve.setProperty("errorCode."
+ HttpServletResponse.SC_NOT_FOUND, + "http://localhost:" + getPort() + "/error-page"); + ByteChunk res = new ByteChunk(); - res.setCharset(StandardCharsets.UTF_8); - Map<String, List<String>> resHead = new HashMap<>(); - // Don't follow redirects - int rc = getUrl("http://localhost:" + getPort(), res, resHead); - - // ProxyErrorReportValve uses error pages from context — but since - // it calls findErrorPage() which uses Host-level error pages, - // the context error page might not be found and it falls back to - // the superclass. The test verifies the valve is loaded correctly. - Assert.assertTrue("Status should indicate an error", - rc >= 400 || rc == 302); + int rc = getUrl("http://localhost:" + getPort(), res, null); + + Assert.assertEquals(HttpServletResponse.SC_NOT_FOUND, rc); + Assert.assertTrue(res.toString().contains("ERROR_PAGE_OK")); } @@ -90,20 +107,18 @@ public class TestProxyErrorReportValve extends TomcatBaseTest { HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "No page configured")); ctx.addServletMappingDecoded("/", "error"); - // No error page configured — should fall back to ErrorReportValve's report() tomcat.start(); ByteChunk res = new ByteChunk(); - res.setCharset(StandardCharsets.UTF_8); int rc = getUrl("http://localhost:" + getPort(), res, null); Assert.assertEquals(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, rc); String body = res.toString(); Assert.assertNotNull(body); - // The default ErrorReportValve produces HTML Assert.assertTrue("Should contain HTML error report", - body.contains("<html>") || body.contains("<h1>")); + body.contains("html") && + body.contains(String.valueOf(HttpServletResponse.SC_INTERNAL_SERVER_ERROR))); } @@ -114,17 +129,16 @@ public class TestProxyErrorReportValve extends TomcatBaseTest { Context ctx = getProgrammaticRootContext(); - Tomcat.addServlet(ctx, "ok", new OkServlet()); - ctx.addServletMappingDecoded("/", "ok"); + Tomcat.addServlet(ctx, "hello", new HelloWorldServlet()); + 
ctx.addServletMappingDecoded("/", "hello"); tomcat.start(); ByteChunk res = new ByteChunk(); - res.setCharset(StandardCharsets.UTF_8); int rc = getUrl("http://localhost:" + getPort(), res, null); Assert.assertEquals(HttpServletResponse.SC_OK, rc); - Assert.assertEquals("OK", res.toString()); + Assert.assertEquals(HelloWorldServlet.RESPONSE_TEXT, res.toString()); } @@ -142,28 +156,24 @@ public class TestProxyErrorReportValve extends TomcatBaseTest { tomcat.start(); ByteChunk res = new ByteChunk(); - res.setCharset(StandardCharsets.UTF_8); int rc = getUrl("http://localhost:" + getPort(), res, null); Assert.assertEquals(HttpServletResponse.SC_NOT_FOUND, rc); String body = res.toString(); Assert.assertNotNull(body); - // Falls back to parent ErrorReportValve HTML Assert.assertTrue("Should contain error report", - body.contains("404") || body.contains("Not Found")); + body.contains(String.valueOf(HttpServletResponse.SC_NOT_FOUND))); } @Test - public void testGetSetProperties() throws Exception { + public void testGetSetProperties() { ProxyErrorReportValve valve = new ProxyErrorReportValve(); - // Defaults Assert.assertTrue(valve.getUseRedirect()); Assert.assertFalse(valve.getUsePropertiesFile()); - // Setters valve.setUseRedirect(false); Assert.assertFalse(valve.getUseRedirect()); @@ -174,19 +184,19 @@ public class TestProxyErrorReportValve extends TomcatBaseTest { @Test public void testMessageInErrorReport() throws Exception { + final String customErrorMessage = "Custom error message"; Tomcat tomcat = getTomcatInstance(); ((StandardHost) tomcat.getHost()).setErrorReportValveClass(PROXY_VALVE); Context ctx = getProgrammaticRootContext(); Tomcat.addServlet(ctx, "error", new SendErrorServlet( - HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Custom error message")); + HttpServletResponse.SC_INTERNAL_SERVER_ERROR, customErrorMessage)); ctx.addServletMappingDecoded("/", "error"); tomcat.start(); ByteChunk res = new ByteChunk(); - res.setCharset(StandardCharsets.UTF_8); int 
rc = getUrl("http://localhost:" + getPort(), res, null); Assert.assertEquals(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, rc); @@ -194,8 +204,7 @@ public class TestProxyErrorReportValve extends TomcatBaseTest { String body = res.toString(); Assert.assertNotNull(body); // Falls back to super.report() which includes the message - Assert.assertTrue("Should contain the custom error message", - body.contains("Custom error message")); + Assert.assertTrue(body.contains(customErrorMessage)); } @@ -212,21 +221,21 @@ public class TestProxyErrorReportValve extends TomcatBaseTest { tomcat.start(); ByteChunk res = new ByteChunk(); - res.setCharset(StandardCharsets.UTF_8); int rc = getUrl("http://localhost:" + getPort(), res, null); Assert.assertEquals(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, rc); String body = res.toString(); Assert.assertNotNull(body); - Assert.assertTrue("Should contain exception info", - body.contains("RuntimeException")); + Assert.assertTrue(body.contains("RuntimeException")); } private static final class SendErrorServlet extends HttpServlet { + @Serial private static final long serialVersionUID = 1L; + private final int statusCode; private final String message; @@ -237,33 +246,19 @@ public class TestProxyErrorReportValve extends TomcatBaseTest { @Override protected void doGet(HttpServletRequest req, HttpServletResponse resp) - throws ServletException, IOException { + throws IOException { resp.sendError(statusCode, message); } } - - private static final class OkServlet extends HttpServlet { - - private static final long serialVersionUID = 1L; - - @Override - protected void doGet(HttpServletRequest req, HttpServletResponse resp) - throws ServletException, IOException { - resp.setContentType("text/plain"); - resp.getWriter().print("OK"); - } - } - - private static final class ErrorPageServlet extends HttpServlet { + @Serial private static final long serialVersionUID = 1L; @Override protected void doGet(HttpServletRequest req, HttpServletResponse resp) - 
throws ServletException, IOException { - resp.setContentType("text/plain"); + throws IOException { resp.getWriter().print("ERROR_PAGE_OK"); } } @@ -271,11 +266,11 @@ public class TestProxyErrorReportValve extends TomcatBaseTest { private static final class ExceptionServlet extends HttpServlet { + @Serial private static final long serialVersionUID = 1L; @Override - public void service(jakarta.servlet.ServletRequest request, - jakarta.servlet.ServletResponse response) throws IOException { + protected void doGet(HttpServletRequest req, HttpServletResponse resp) { throw new RuntimeException("Test exception"); } } diff --git a/test/org/apache/catalina/valves/TestSemaphoreValve.java b/test/org/apache/catalina/valves/TestSemaphoreValve.java index eefe4d4acb..f5a619e6d4 100644 --- a/test/org/apache/catalina/valves/TestSemaphoreValve.java +++ b/test/org/apache/catalina/valves/TestSemaphoreValve.java @@ -17,12 +17,14 @@ package org.apache.catalina.valves; import java.io.IOException; -import java.nio.charset.StandardCharsets; +import java.io.Serial; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; -import jakarta.servlet.ServletException; import jakarta.servlet.http.HttpServlet; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; @@ -31,6 +33,8 @@ import org.junit.Assert; import org.junit.Test; import org.apache.catalina.Context; +import org.apache.catalina.connector.Request; +import org.apache.catalina.connector.Response; import org.apache.catalina.startup.Tomcat; import org.apache.catalina.startup.TomcatBaseTest; import org.apache.tomcat.util.buf.ByteChunk; @@ -44,8 +48,8 @@ public class TestSemaphoreValve extends TomcatBaseTest { Context ctx = getProgrammaticRootContext(); - Tomcat.addServlet(ctx, "ok", 
new OkServlet()); - ctx.addServletMappingDecoded("/", "ok"); + Tomcat.addServlet(ctx, "hello", new HelloWorldServlet()); + ctx.addServletMappingDecoded("/", "hello"); SemaphoreValve valve = new SemaphoreValve(); valve.setConcurrency(10); @@ -54,11 +58,33 @@ public class TestSemaphoreValve extends TomcatBaseTest { tomcat.start(); ByteChunk res = new ByteChunk(); - res.setCharset(StandardCharsets.UTF_8); int rc = getUrl("http://localhost:" + getPort(), res, null); Assert.assertEquals(HttpServletResponse.SC_OK, rc); - Assert.assertEquals("OK", res.toString()); + Assert.assertEquals(HelloWorldServlet.RESPONSE_TEXT, res.toString()); + } + + @Test + public void testInterruptedConcurrency() throws Exception { + Tomcat tomcat = getTomcatInstance(); + + Context ctx = getProgrammaticRootContext(); + + Tomcat.addServlet(ctx, "hello", new HelloWorldServlet()); + ctx.addServletMappingDecoded("/", "hello"); + + SemaphoreValve valve = new SemaphoreValve(); + valve.setConcurrency(10); + valve.setInterruptible(true); + ctx.getPipeline().addValve(valve); + + tomcat.start(); + + ByteChunk res = new ByteChunk(); + int rc = getUrl("http://localhost:" + getPort(), res, null); + + Assert.assertEquals(HttpServletResponse.SC_OK, rc); + Assert.assertEquals(HelloWorldServlet.RESPONSE_TEXT, res.toString()); } @@ -76,7 +102,7 @@ public class TestSemaphoreValve extends TomcatBaseTest { SemaphoreValve valve = new SemaphoreValve(); valve.setConcurrency(1); valve.setBlock(false); - valve.setHighConcurrencyStatus(503); + valve.setHighConcurrencyStatus(HttpServletResponse.SC_SERVICE_UNAVAILABLE); ctx.getPipeline().addValve(valve); tomcat.start(); @@ -85,9 +111,7 @@ public class TestSemaphoreValve extends TomcatBaseTest { AtomicInteger firstRc = new AtomicInteger(); Thread firstThread = new Thread(() -> { try { - ByteChunk r = new ByteChunk(); - r.setCharset(StandardCharsets.UTF_8); - firstRc.set(getUrl("http://localhost:" + getPort(), r, null)); + firstRc.set(getUrl("http://localhost:" + getPort(), 
new ByteChunk(), null)); } catch (IOException e) { // Ignore } @@ -99,9 +123,7 @@ public class TestSemaphoreValve extends TomcatBaseTest { insideServlet.await(10, TimeUnit.SECONDS)); // Second request — should be denied because concurrency=1 and block=false - ByteChunk res2 = new ByteChunk(); - res2.setCharset(StandardCharsets.UTF_8); - int rc2 = getUrl("http://localhost:" + getPort(), res2, null); + int rc2 = getUrl("http://localhost:" + getPort(), new ByteChunk(), null); Assert.assertEquals(HttpServletResponse.SC_SERVICE_UNAVAILABLE, rc2); @@ -135,8 +157,7 @@ public class TestSemaphoreValve extends TomcatBaseTest { // First request holds the permit Thread firstThread = new Thread(() -> { try { - ByteChunk r = new ByteChunk(); - getUrl("http://localhost:" + getPort(), r, null); + getUrl("http://localhost:" + getPort(), new ByteChunk(), null); } catch (IOException e) { // Ignore } @@ -147,10 +168,9 @@ public class TestSemaphoreValve extends TomcatBaseTest { insideServlet.await(10, TimeUnit.SECONDS)); // Second request — denied but no error status is sent - ByteChunk res2 = new ByteChunk(); - int rc2 = getUrl("http://localhost:" + getPort(), res2, null); + int rc2 = getUrl("http://localhost:" + getPort(), new ByteChunk(), null); - // With no highConcurrencyStatus, response is 200 with no body + // With no highConcurrencyStatus, response is 200 without body Assert.assertEquals(HttpServletResponse.SC_OK, rc2); canReturn.countDown(); @@ -159,7 +179,7 @@ public class TestSemaphoreValve extends TomcatBaseTest { @Test - public void testGetSetProperties() throws Exception { + public void testGetSetProperties() { SemaphoreValve valve = new SemaphoreValve(); // Defaults @@ -193,8 +213,8 @@ public class TestSemaphoreValve extends TomcatBaseTest { Context ctx = getProgrammaticRootContext(); - Tomcat.addServlet(ctx, "ok", new OkServlet()); - ctx.addServletMappingDecoded("/", "ok"); + Tomcat.addServlet(ctx, "hello", new HelloWorldServlet()); + ctx.addServletMappingDecoded("/", 
"hello"); SemaphoreValve valve = new SemaphoreValve(); valve.setConcurrency(5); @@ -203,30 +223,178 @@ public class TestSemaphoreValve extends TomcatBaseTest { tomcat.start(); + Assert.assertNotNull(valve.semaphore); + Assert.assertTrue(valve.semaphore.isFair()); + Assert.assertEquals(5, valve.semaphore.availablePermits()); + ByteChunk res = new ByteChunk(); - res.setCharset(StandardCharsets.UTF_8); int rc = getUrl("http://localhost:" + getPort(), res, null); Assert.assertEquals(HttpServletResponse.SC_OK, rc); - Assert.assertEquals("OK", res.toString()); + Assert.assertEquals(HelloWorldServlet.RESPONSE_TEXT, res.toString()); } + @Test + public void testBlockingWaitsForPermit() throws Exception { + Tomcat tomcat = getTomcatInstance(); - private static final class OkServlet extends HttpServlet { + Context ctx = getProgrammaticRootContext(); - private static final long serialVersionUID = 1L; + CountDownLatch insideServlet = new CountDownLatch(1); + CountDownLatch canReturn = new CountDownLatch(1); + Tomcat.addServlet(ctx, "slow", new SlowServlet(insideServlet, canReturn)); + ctx.addServletMappingDecoded("/", "slow"); - @Override - protected void doGet(HttpServletRequest req, HttpServletResponse resp) - throws ServletException, IOException { - resp.setContentType("text/plain"); - resp.getWriter().print("OK"); - } + SemaphoreValve valve = new SemaphoreValve(); + valve.setConcurrency(1); + valve.setBlock(true); + ctx.getPipeline().addValve(valve); + + tomcat.start(); + + AtomicReference<Throwable> firstError = new AtomicReference<>(); + Thread firstThread = new Thread(() -> { + try { + getUrl("http://localhost:" + getPort(), new ByteChunk(), null); + } catch (IOException e) { + firstError.set(e); + } + }); + firstThread.start(); + + Assert.assertTrue("First request should reach servlet", + insideServlet.await(10, TimeUnit.SECONDS)); + + AtomicInteger secondRc = new AtomicInteger(); + AtomicReference<Throwable> secondError = new AtomicReference<>(); + Thread secondThread 
= new Thread(() -> { + try { + secondRc.set(getUrl("http://localhost:" + getPort(), new ByteChunk(), null)); + } catch (IOException e) { + secondError.set(e); + } + }); + secondThread.start(); + + // Give the second request time to arrive and block on the semaphore + Thread.sleep(500); + + Assert.assertTrue("Second request should be blocked waiting for permit", secondThread.isAlive()); + + canReturn.countDown(); + firstThread.join(10000); + Assert.assertNull(firstError.get()); + + secondThread.join(10000); + Assert.assertFalse(secondThread.isAlive()); + Assert.assertNull(secondError.get()); + Assert.assertEquals(HttpServletResponse.SC_OK, secondRc.get()); + } + + @Test + public void testControlConcurrencyBypass() throws Exception { + Tomcat tomcat = getTomcatInstance(); + + Context ctx = getProgrammaticRootContext(); + + CountDownLatch insideServlet = new CountDownLatch(1); + CountDownLatch canReturn = new CountDownLatch(1); + Tomcat.addServlet(ctx, "slow", new SlowServlet(insideServlet, canReturn)); + ctx.addServletMappingDecoded("/slow", "slow"); + + Tomcat.addServlet(ctx, "hello", new HelloWorldServlet()); + ctx.addServletMappingDecoded("/bypass", "hello"); + + SemaphoreValve valve = new SemaphoreValve() { + @Override + public boolean controlConcurrency(Request request, Response response) { + return !request.getDecodedRequestURI().equals("/bypass"); + } + }; + valve.setConcurrency(1); + valve.setBlock(false); + valve.setHighConcurrencyStatus(HttpServletResponse.SC_SERVICE_UNAVAILABLE); + ctx.getPipeline().addValve(valve); + + tomcat.start(); + + Thread firstThread = new Thread(() -> { + try { + getUrl("http://localhost:" + getPort() + "/slow", new ByteChunk(), null); + } catch (IOException e) { + // Ignored + } + }); + firstThread.start(); + + Assert.assertTrue("First request should reach servlet", + insideServlet.await(10, TimeUnit.SECONDS)); + + // Request to /bypass should succeed despite all permits being held, + // because controlConcurrency() returns false 
for this path + int bypassRc = getUrl("http://localhost:" + getPort() + "/bypass", new ByteChunk(), null); + Assert.assertEquals(HttpServletResponse.SC_OK, bypassRc); + + int deniedRc = getUrl("http://localhost:" + getPort() + "/slow", new ByteChunk(), null); + Assert.assertEquals(HttpServletResponse.SC_SERVICE_UNAVAILABLE, deniedRc); + + canReturn.countDown(); + firstThread.join(10000); } + @Test + public void testInterruptibleDenied() throws Exception { + SemaphoreValve semaphoreValve = new SemaphoreValve(); + semaphoreValve.setConcurrency(1); + semaphoreValve.setBlock(true); + semaphoreValve.setInterruptible(true); + semaphoreValve.setHighConcurrencyStatus(HttpServletResponse.SC_SERVICE_UNAVAILABLE); + + semaphoreValve.semaphore = new Semaphore(1, false); + + AtomicBoolean nextInvoked = new AtomicBoolean(false); + semaphoreValve.setNext(new ValveBase() { + @Override + public void invoke(Request request, Response response) { + nextInvoked.set(true); + } + }); + + MockResponse response = new MockResponse(); + + semaphoreValve.semaphore.acquire(); + + // On a new thread, valve will block on semaphore.acquire() because the permit is already held. 
+ CountDownLatch invokeStarted = new CountDownLatch(1); + Thread blocked = new Thread(() -> { + invokeStarted.countDown(); + try { + semaphoreValve.invoke(null, response); + } catch (Throwable t) { + // Ignored + } + }); + blocked.start(); + + Assert.assertTrue(invokeStarted.await(10, TimeUnit.SECONDS)); + Thread.sleep(200); + + blocked.interrupt(); + blocked.join(10000); + Assert.assertFalse(blocked.isAlive()); + + Assert.assertEquals(HttpServletResponse.SC_SERVICE_UNAVAILABLE, response.getStatus()); + + Assert.assertFalse("Next valve should not be invoked when permit denied", nextInvoked.get()); + + Assert.assertEquals(0, semaphoreValve.semaphore.availablePermits()); + + semaphoreValve.semaphore.release(); + } private static final class SlowServlet extends HttpServlet { + @Serial private static final long serialVersionUID = 1L; private final CountDownLatch insideServlet; private final CountDownLatch canReturn; @@ -238,10 +406,10 @@ public class TestSemaphoreValve extends TomcatBaseTest { @Override protected void doGet(HttpServletRequest req, HttpServletResponse resp) - throws ServletException, IOException { + throws IOException { insideServlet.countDown(); try { - canReturn.await(30, TimeUnit.SECONDS); + Assert.assertTrue(canReturn.await(30, TimeUnit.SECONDS)); } catch (InterruptedException e) { // Ignore } @@ -249,4 +417,24 @@ public class TestSemaphoreValve extends TomcatBaseTest { resp.getWriter().print("OK"); } } + + public static class MockResponse extends Response { + + public MockResponse() { + super(null); + } + + private int status = HttpServletResponse.SC_OK; + + @Override + public void sendError(int status) throws IOException { + this.status = status; + } + + @Override + public int getStatus() { + return status; + } + } + } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
