This is an automated email from the ASF dual-hosted git repository.

snoopdave pushed a commit to branch one-time-salt
in repository https://gitbox.apache.org/repos/asf/roller.git

commit e2dbf4ddd14ad01e7a9bae0ca1e5e7e6dd9d0d73
Author: David M. Johnson <[email protected]>
AuthorDate: Sat Oct 5 16:43:46 2024 -0400

    Implement one-time-salt use and add comprehensive tests
---
 .../weblogger/ui/core/filters/LoadSaltFilter.java  |  15 ++-
 .../ui/core/filters/ValidateSaltFilter.java        |  24 ++--
 .../ui/core/filters/LoadSaltFilterTest.java        |  88 +++++++++++++
 .../ui/core/filters/ValidateSaltFilterTest.java    | 144 +++++++++++++++++++++
 4 files changed, 255 insertions(+), 16 deletions(-)

diff --git a/app/src/main/java/org/apache/roller/weblogger/ui/core/filters/LoadSaltFilter.java b/app/src/main/java/org/apache/roller/weblogger/ui/core/filters/LoadSaltFilter.java
index e6416ce96..b2e63915d 100644
--- a/app/src/main/java/org/apache/roller/weblogger/ui/core/filters/LoadSaltFilter.java
+++ b/app/src/main/java/org/apache/roller/weblogger/ui/core/filters/LoadSaltFilter.java
@@ -37,13 +37,14 @@ public class LoadSaltFilter implements Filter {
         throws IOException, ServletException {
 
         HttpServletRequest httpReq = (HttpServletRequest) request;
-        RollerSession rses = RollerSession.getRollerSession(httpReq);
-        String userId = rses != null && rses.getAuthenticatedUser() != null ? rses.getAuthenticatedUser().getId() : "";
-
-        SaltCache saltCache = SaltCache.getInstance();
-        String salt = RandomStringUtils.random(20, 0, 0, true, true, null, new SecureRandom());
-        saltCache.put(salt, userId);
-        httpReq.setAttribute("salt", salt);
+        RollerSession rollerSession = RollerSession.getRollerSession(httpReq);
+        if (rollerSession != null) {
+            String userId = rollerSession.getAuthenticatedUser() != null ? rollerSession.getAuthenticatedUser().getId() : "";
+            SaltCache saltCache = SaltCache.getInstance();
+            String salt = RandomStringUtils.random(20, 0, 0, true, true, null, new SecureRandom());
+            saltCache.put(salt, userId);
+            httpReq.setAttribute("salt", salt);
+        }
 
         chain.doFilter(request, response);
     }
diff --git a/app/src/main/java/org/apache/roller/weblogger/ui/core/filters/ValidateSaltFilter.java b/app/src/main/java/org/apache/roller/weblogger/ui/core/filters/ValidateSaltFilter.java
index 275ccd328..3ab6b80fd 100644
--- a/app/src/main/java/org/apache/roller/weblogger/ui/core/filters/ValidateSaltFilter.java
+++ b/app/src/main/java/org/apache/roller/weblogger/ui/core/filters/ValidateSaltFilter.java
@@ -52,16 +52,24 @@ public class ValidateSaltFilter implements Filter {
         HttpServletRequest httpReq = (HttpServletRequest) request;
 
         if ("POST".equals(httpReq.getMethod()) && !isIgnoredURL(httpReq.getServletPath())) {
-            RollerSession rses = RollerSession.getRollerSession(httpReq);
-            String userId = rses != null && rses.getAuthenticatedUser() != null ? rses.getAuthenticatedUser().getId() : "";
+            RollerSession rollerSession = RollerSession.getRollerSession(httpReq);
+            if (rollerSession != null) {
+                String userId = rollerSession.getAuthenticatedUser() != null ? rollerSession.getAuthenticatedUser().getId() : "";
 
-            String salt = httpReq.getParameter("salt");
-            SaltCache saltCache = SaltCache.getInstance();
-            if (salt == null || !Objects.equals(saltCache.get(salt), userId)) {
+                String salt = httpReq.getParameter("salt");
+                SaltCache saltCache = SaltCache.getInstance();
+                if (salt == null || !Objects.equals(saltCache.get(salt), userId)) {
+                    if (log.isDebugEnabled()) {
+                        log.debug("Valid salt value not found on POST to URL : " + httpReq.getServletPath());
+                    }
+                    throw new ServletException("Security Violation");
+                }
+
+                // Remove salt from cache after successful validation
+                saltCache.remove(salt);
                 if (log.isDebugEnabled()) {
-                    log.debug("Valid salt value not found on POST to URL : " + httpReq.getServletPath());
+                    log.debug("Salt used and invalidated: " + salt);
                 }
-                throw new ServletException("Security Violation");
             }
         }
 
@@ -70,8 +78,6 @@ public class ValidateSaltFilter implements Filter {
 
     @Override
     public void init(FilterConfig filterConfig) throws ServletException {
-
-        // Construct our list of ignored urls
         String urls = WebloggerConfig.getProperty("salt.ignored.urls");
         ignored = Set.of(StringUtils.stripAll(StringUtils.split(urls, ",")));
     }
diff --git a/app/src/test/java/org/apache/roller/weblogger/ui/core/filters/LoadSaltFilterTest.java b/app/src/test/java/org/apache/roller/weblogger/ui/core/filters/LoadSaltFilterTest.java
new file mode 100644
index 000000000..5ace927a2
--- /dev/null
+++ b/app/src/test/java/org/apache/roller/weblogger/ui/core/filters/LoadSaltFilterTest.java
@@ -0,0 +1,88 @@
+package org.apache.roller.weblogger.ui.core.filters;
+
+import org.apache.roller.weblogger.pojos.User;
+import org.apache.roller.weblogger.ui.core.RollerSession;
+import org.apache.roller.weblogger.ui.rendering.util.cache.SaltCache;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mock;
+import org.mockito.MockedStatic;
+import org.mockito.MockitoAnnotations;
+
+import javax.servlet.FilterChain;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+
+import static org.mockito.Mockito.*;
+
+public class LoadSaltFilterTest {
+
+    private LoadSaltFilter filter;
+
+    @Mock
+    private HttpServletRequest request;
+
+    @Mock
+    private HttpServletResponse response;
+
+    @Mock
+    private FilterChain chain;
+
+    @Mock
+    private RollerSession rollerSession;
+
+    @Mock
+    private SaltCache saltCache;
+
+    @BeforeEach
+    public void setUp() {
+        MockitoAnnotations.initMocks(this);
+        filter = new LoadSaltFilter();
+    }
+
+    @Test
+    public void testDoFilterGeneratesSalt() throws Exception {
+        try (MockedStatic<RollerSession> mockedRollerSession = mockStatic(RollerSession.class);
+             MockedStatic<SaltCache> mockedSaltCache = mockStatic(SaltCache.class)) {
+
+            mockedRollerSession.when(() -> RollerSession.getRollerSession(request)).thenReturn(rollerSession);
+            mockedSaltCache.when(SaltCache::getInstance).thenReturn(saltCache);
+
+            when(rollerSession.getAuthenticatedUser()).thenReturn(new TestUser("userId"));
+
+            filter.doFilter(request, response, chain);
+
+            verify(request).setAttribute(eq("salt"), anyString());
+            verify(saltCache).put(anyString(), eq("userId"));
+            verify(chain).doFilter(request, response);
+        }
+    }
+
+    @Test
+    public void testDoFilterWithNullRollerSession() throws Exception {
+        try (MockedStatic<RollerSession> mockedRollerSession = mockStatic(RollerSession.class);
+             MockedStatic<SaltCache> mockedSaltCache = mockStatic(SaltCache.class)) {
+
+            mockedRollerSession.when(() -> RollerSession.getRollerSession(request)).thenReturn(null);
+            mockedSaltCache.when(SaltCache::getInstance).thenReturn(saltCache);
+
+            filter.doFilter(request, response, chain);
+
+            verify(request, never()).setAttribute(eq("salt"), anyString());
+            verify(saltCache, never()).put(anyString(), anyString());
+            verify(chain).doFilter(request, response);
+        }
+    }
+
+    private static class TestUser extends User {
+        private final String id;
+
+        TestUser(String id) {
+            this.id = id;
+        }
+
+        public String getId() {
+            return id;
+        }
+    }
+}
diff --git a/app/src/test/java/org/apache/roller/weblogger/ui/core/filters/ValidateSaltFilterTest.java b/app/src/test/java/org/apache/roller/weblogger/ui/core/filters/ValidateSaltFilterTest.java
new file mode 100644
index 000000000..e62c33a08
--- /dev/null
+++ b/app/src/test/java/org/apache/roller/weblogger/ui/core/filters/ValidateSaltFilterTest.java
@@ -0,0 +1,144 @@
+package org.apache.roller.weblogger.ui.core.filters;
+
+import org.apache.roller.weblogger.pojos.User;
+import org.apache.roller.weblogger.ui.core.RollerSession;
+import org.apache.roller.weblogger.ui.rendering.util.cache.SaltCache;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mock;
+import org.mockito.MockedStatic;
+import org.mockito.MockitoAnnotations;
+
+import javax.servlet.FilterChain;
+import javax.servlet.ServletException;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.mockito.Mockito.*;
+
+public class ValidateSaltFilterTest {
+
+    private ValidateSaltFilter filter;
+
+    @Mock
+    private HttpServletRequest request;
+
+    @Mock
+    private HttpServletResponse response;
+
+    @Mock
+    private FilterChain chain;
+
+    @Mock
+    private RollerSession rollerSession;
+
+    @Mock
+    private SaltCache saltCache;
+
+    @BeforeEach
+    public void setUp() {
+        MockitoAnnotations.openMocks(this);
+        filter = new ValidateSaltFilter();
+    }
+
+    @Test
+    public void testDoFilterWithGetMethod() throws Exception {
+        when(request.getMethod()).thenReturn("GET");
+
+        filter.doFilter(request, response, chain);
+
+        verify(chain).doFilter(request, response);
+    }
+
+    @Test
+    public void testDoFilterWithPostMethodAndValidSalt() throws Exception {
+        try (MockedStatic<RollerSession> mockedRollerSession = mockStatic(RollerSession.class);
+             MockedStatic<SaltCache> mockedSaltCache = mockStatic(SaltCache.class)) {
+
+            mockedRollerSession.when(() -> RollerSession.getRollerSession(request)).thenReturn(rollerSession);
+            mockedSaltCache.when(SaltCache::getInstance).thenReturn(saltCache);
+
+            when(request.getMethod()).thenReturn("POST");
+            when(request.getServletPath()).thenReturn("/someurl");
+            when(request.getParameter("salt")).thenReturn("validSalt");
+            when(saltCache.get("validSalt")).thenReturn("userId");
+            when(rollerSession.getAuthenticatedUser()).thenReturn(new TestUser("userId"));
+
+            filter.doFilter(request, response, chain);
+
+            verify(chain).doFilter(request, response);
+            verify(saltCache).remove("validSalt");
+        }
+    }
+
+    @Test
+    public void testDoFilterWithPostMethodAndInvalidSalt() throws Exception {
+        try (MockedStatic<RollerSession> mockedRollerSession = mockStatic(RollerSession.class);
+             MockedStatic<SaltCache> mockedSaltCache = mockStatic(SaltCache.class)) {
+
+            mockedRollerSession.when(() -> RollerSession.getRollerSession(request)).thenReturn(rollerSession);
+            mockedSaltCache.when(SaltCache::getInstance).thenReturn(saltCache);
+
+            when(request.getMethod()).thenReturn("POST");
+            when(request.getServletPath()).thenReturn("/someurl");
+            when(request.getParameter("salt")).thenReturn("invalidSalt");
+            when(saltCache.get("invalidSalt")).thenReturn(null);
+
+            assertThrows(ServletException.class, () -> {
+                filter.doFilter(request, response, chain);
+            });
+        }
+    }
+
+    @Test
+    public void testDoFilterWithPostMethodAndMismatchedUserId() throws Exception {
+        try (MockedStatic<RollerSession> mockedRollerSession = mockStatic(RollerSession.class);
+             MockedStatic<SaltCache> mockedSaltCache = mockStatic(SaltCache.class)) {
+
+            mockedRollerSession.when(() -> RollerSession.getRollerSession(request)).thenReturn(rollerSession);
+            mockedSaltCache.when(SaltCache::getInstance).thenReturn(saltCache);
+
+            when(request.getMethod()).thenReturn("POST");
+            when(request.getServletPath()).thenReturn("/someurl");
+            when(request.getParameter("salt")).thenReturn("validSalt");
+            when(saltCache.get("validSalt")).thenReturn("differentUserId");
+            when(rollerSession.getAuthenticatedUser()).thenReturn(new TestUser("userId"));
+
+            assertThrows(ServletException.class, () -> {
+                filter.doFilter(request, response, chain);
+            });
+        }
+    }
+
+    @Test
+    public void testDoFilterWithPostMethodAndNullRollerSession() throws Exception {
+        try (MockedStatic<RollerSession> mockedRollerSession = mockStatic(RollerSession.class);
+             MockedStatic<SaltCache> mockedSaltCache = mockStatic(SaltCache.class)) {
+
+            mockedRollerSession.when(() -> RollerSession.getRollerSession(request)).thenReturn(null);
+            mockedSaltCache.when(SaltCache::getInstance).thenReturn(saltCache);
+
+            when(request.getMethod()).thenReturn("POST");
+            when(request.getServletPath()).thenReturn("/someurl");
+            when(request.getParameter("salt")).thenReturn("validSalt");
+            when(saltCache.get("validSalt")).thenReturn("");
+
+            filter.doFilter(request, response, chain);
+
+            verify(saltCache, never()).remove("validSalt");
+        }
+    }
+    private static class TestUser extends User {
+        private final String id;
+
+        TestUser(String id) {
+            this.id = id;
+        }
+
+        @Override
+        public String getId() {
+            return id;
+        }
+    }
+}

Reply via email to