This is an automated email from the ASF dual-hosted git repository. exceptionfactory pushed a commit to branch support/nifi-1.x in repository https://gitbox.apache.org/repos/asf/nifi.git
The following commit(s) were added to refs/heads/support/nifi-1.x by this push: new c091eade00 NIFI-11767 Refactored Groovy tests in nifi-web-error and nifi-web-security to Java c091eade00 is described below commit c091eade00fdffe27dfe9463c3b2e1b379b7f83c Author: dan-s1 <dsti...@gmail.com> AuthorDate: Mon Jul 3 12:46:35 2023 -0500 NIFI-11767 Refactored Groovy tests in nifi-web-error and nifi-web-security to Java This closes #7457 Signed-off-by: David Handermann <exceptionfact...@apache.org> (cherry picked from commit d24318cdb8003925f9a5411bf20094a02b0c084a) --- .../nifi/web/filter/CatchAllFilterTest.groovy | 133 ------- .../apache/nifi/web/filter/CatchAllFilterTest.java | 106 ++++++ .../nifi-web/nifi-web-security/pom.xml | 6 - .../web/security/ProxiedEntitiesUtilsTest.groovy | 393 --------------------- .../requests/ContentLengthFilterTest.groovy | 277 --------------- .../web/security/ProxiedEntitiesUtilsTest.java | 242 +++++++++++++ .../security/requests/ContentLengthFilterTest.java | 232 ++++++++++++ 7 files changed, 580 insertions(+), 809 deletions(-) diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-error/src/test/groovy/org/apache/nifi/web/filter/CatchAllFilterTest.groovy b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-error/src/test/groovy/org/apache/nifi/web/filter/CatchAllFilterTest.groovy deleted file mode 100644 index 43be5a9a70..0000000000 --- a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-error/src/test/groovy/org/apache/nifi/web/filter/CatchAllFilterTest.groovy +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.nifi.web.filter - - -import org.junit.jupiter.api.BeforeAll -import org.junit.jupiter.api.Test -import org.slf4j.Logger -import org.slf4j.LoggerFactory - -import javax.servlet.FilterChain -import javax.servlet.FilterConfig -import javax.servlet.RequestDispatcher -import javax.servlet.ServletContext -import javax.servlet.ServletRequest -import javax.servlet.ServletResponse -import javax.servlet.http.HttpServletRequest -import javax.servlet.http.HttpServletResponse - -class CatchAllFilterTest { - private static final Logger logger = LoggerFactory.getLogger(CatchAllFilterTest.class) - - @BeforeAll - static void setUpOnce() throws Exception { - logger.metaClass.methodMissing = { String name, args -> - logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}") - } - } - - private static String getValue(String parameterName, Map<String, String> params = [:]) { - params.containsKey(parameterName) ? params[parameterName] : "" - } - - @Test - void testInitShouldCallSuper() { - // Arrange - def EXPECTED_ALLOWED_CONTEXT_PATHS = ["/path1", "/path2"].join(", ") - def parameters = [allowedContextPaths: EXPECTED_ALLOWED_CONTEXT_PATHS] - FilterConfig mockFilterConfig = [ - getInitParameter : { String parameterName -> - return getValue(parameterName, parameters) - }, - getServletContext: { -> - [getInitParameter: { String parameterName -> - return getValue(parameterName, parameters) - }] as ServletContext - }] as FilterConfig - - CatchAllFilter caf = new CatchAllFilter() - - // Act - caf.init(mockFilterConfig) - logger.info("Allowed context paths: ${caf.getAllowedContextPaths()}") - - // Assert - assert caf.getAllowedContextPaths() == EXPECTED_ALLOWED_CONTEXT_PATHS - } - - @Test - void testShouldDoFilter() { - // Arrange - final String EXPECTED_ALLOWED_CONTEXT_PATHS = ["/path1", "/path2"].join(", ") - final String EXPECTED_FORWARD_PATH = "index.jsp" - final Map PARAMETERS = [ - allowedContextPaths: EXPECTED_ALLOWED_CONTEXT_PATHS, - forwardPath : EXPECTED_FORWARD_PATH - ] - - final String EXPECTED_CONTEXT_PATH = "" - - // Mock collaborators - FilterConfig mockFilterConfig = [ - getInitParameter : { String parameterName -> - return getValue(parameterName, PARAMETERS) - }, - getServletContext: { -> - [getInitParameter: { String parameterName -> - return getValue(parameterName, PARAMETERS) - }] as ServletContext - }] as FilterConfig - - // Local map to store request attributes - def requestAttributes = [:] - - // Local string to store resulting path - String forwardedRequestTo = "" - - final Map HEADERS = [ - "X-ProxyContextPath" : "", - "X-Forwarded-Context": "", - "X-Forwarded-Prefix" : ""] - - HttpServletRequest mockRequest = [ - getContextPath : { -> EXPECTED_CONTEXT_PATH }, - getHeader : { String headerName -> getValue(headerName, HEADERS) }, - setAttribute : { String attr, String value -> - requestAttributes[attr] = value - logger.mock("Set request attribute ${attr} to ${value}") - }, - getRequestDispatcher: { String path -> - [forward: { ServletRequest request, ServletResponse response -> - forwardedRequestTo = path - logger.mock("Forwarded request to ${path}") - }] as RequestDispatcher - }] as HttpServletRequest - HttpServletResponse mockResponse = [:] as HttpServletResponse - FilterChain mockFilterChain = [:] as FilterChain - - CatchAllFilter caf = new CatchAllFilter() - caf.init(mockFilterConfig) - logger.info("Allowed context paths: ${caf.getAllowedContextPaths()}") - - // Act - caf.doFilter(mockRequest, mockResponse, mockFilterChain) - - // Assert - assert forwardedRequestTo == EXPECTED_FORWARD_PATH - } -} diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-error/src/test/java/org/apache/nifi/web/filter/CatchAllFilterTest.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-error/src/test/java/org/apache/nifi/web/filter/CatchAllFilterTest.java new file mode 100644 index 0000000000..fa7851afbc --- /dev/null +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-error/src/test/java/org/apache/nifi/web/filter/CatchAllFilterTest.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.web.filter; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import javax.servlet.FilterChain; +import javax.servlet.FilterConfig; +import javax.servlet.RequestDispatcher; +import javax.servlet.ServletContext; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +public class CatchAllFilterTest { + @Mock + private ServletContext servletContext; + + @Mock + private FilterConfig filterConfig; + + @BeforeEach + public void setUp() { + when(filterConfig.getServletContext()).thenReturn(servletContext); + } + + @Test + public void testInitShouldCallSuper() throws ServletException { + String expectedAllowedContextPaths = getExpectedAllowedContextPaths(); + final Map<String, String> parameters = Collections.singletonMap("allowedContextPaths", getExpectedAllowedContextPaths()); + when(servletContext.getInitParameter(anyString())).thenAnswer(invocation -> getValue(invocation.getArgument(0), parameters)); + when(filterConfig.getInitParameter(anyString())).thenAnswer(invocation -> getValue(invocation.getArgument(0), parameters)); + + CatchAllFilter catchAllFilter = new CatchAllFilter(); + catchAllFilter.init(filterConfig); + + assertEquals(expectedAllowedContextPaths, catchAllFilter.getAllowedContextPaths()); + } + + @Test + public void testShouldDoFilter(@Mock HttpServletRequest request, @Mock RequestDispatcher requestDispatcher, + @Mock HttpServletResponse response, @Mock FilterChain filterChain ) throws ServletException, IOException { + final String expectedAllowedContextPaths = getExpectedAllowedContextPaths(); + final String expectedForwardPath = "index.jsp"; + final Map<String, String> parameters = new HashMap<>(); + parameters.put("allowedContextPaths", expectedAllowedContextPaths); + parameters.put("forwardPath", expectedForwardPath); + final Map<String, Object> requestAttributes = new HashMap<>(); + final String[] forwardedRequestTo = new String[1]; + final Map<String, String> headers = new HashMap<>(); + headers.put("X-ProxyContextPath", ""); + headers.put("X-Forwarded-Context", ""); + headers.put("X-Forwarded-Prefix", ""); + + when(servletContext.getInitParameter(anyString())).thenAnswer(invocation -> getValue(invocation.getArgument(0), parameters)); + when(filterConfig.getInitParameter(anyString())).thenAnswer(invocation -> getValue(invocation.getArgument(0), parameters)); + when(request.getHeader(anyString())).thenAnswer(invocation -> getValue(invocation.getArgument(0), headers)); + doAnswer(invocation -> requestAttributes.put(invocation.getArgument(0), invocation.getArgument(1))).when(request).setAttribute(anyString(), any()); + when(request.getRequestDispatcher(anyString())).thenAnswer(outerInvocation -> { + doAnswer(innerInvocation -> forwardedRequestTo[0] = outerInvocation.getArgument(0)).when(requestDispatcher).forward(any(), any()); + return requestDispatcher;}); + + CatchAllFilter catchAllFilter = new CatchAllFilter(); + catchAllFilter.init(filterConfig); + catchAllFilter.doFilter(request, response, filterChain); + + assertEquals(expectedForwardPath, forwardedRequestTo[0]); + } + + private String getExpectedAllowedContextPaths() { + return String.join(",", "/path1", "/path2"); + } + + private static String getValue(String parameterName, Map<String, String> params) { + return params.getOrDefault(parameterName, ""); + } +} diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/pom.xml b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/pom.xml index 835a727748..9b5ba864f5 100644 --- a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/pom.xml +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/pom.xml @@ -293,11 +293,5 @@ <artifactId>jetty-servlet</artifactId> <scope>test</scope> </dependency> - <dependency> - <groupId>org.codehaus.groovy</groupId> - <artifactId>groovy-json</artifactId> - <version>${nifi.groovy.version}</version> - <scope>test</scope> - </dependency> </dependencies> </project> diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/test/groovy/org/apache/nifi/web/security/ProxiedEntitiesUtilsTest.groovy b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/test/groovy/org/apache/nifi/web/security/ProxiedEntitiesUtilsTest.groovy deleted file mode 100644 index 9cdb4c791f..0000000000 --- a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/test/groovy/org/apache/nifi/web/security/ProxiedEntitiesUtilsTest.groovy +++ /dev/null @@ -1,393 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.nifi.web.security - -import org.apache.nifi.authorization.user.NiFiUser -import org.junit.jupiter.api.BeforeAll -import org.junit.jupiter.api.Test -import org.slf4j.Logger -import org.slf4j.LoggerFactory - -import java.nio.charset.StandardCharsets - -import static org.junit.jupiter.api.Assertions.assertEquals -import static org.junit.jupiter.api.Assertions.assertFalse -import static org.junit.jupiter.api.Assertions.assertNotEquals - -class ProxiedEntitiesUtilsTest { - private static final Logger logger = LoggerFactory.getLogger(ProxiedEntitiesUtils.class) - - private static final String SAFE_USER_NAME_JOHN = "jdoe" - private static final String SAFE_USER_DN_JOHN = "CN=${SAFE_USER_NAME_JOHN}, OU=Apache NiFi" - - private static final String SAFE_USER_NAME_PROXY_1 = "proxy1.nifi.apache.org" - private static final String SAFE_USER_DN_PROXY_1 = "CN=${SAFE_USER_NAME_PROXY_1}, OU=Apache NiFi" - - private static final String SAFE_USER_NAME_PROXY_2 = "proxy2.nifi.apache.org" - private static final String SAFE_USER_DN_PROXY_2 = "CN=${SAFE_USER_NAME_PROXY_2}, OU=Apache NiFi" - - private static - final String MALICIOUS_USER_NAME_JOHN = "${SAFE_USER_NAME_JOHN}, OU=Apache NiFi><CN=${SAFE_USER_NAME_PROXY_1}" - private static final String MALICIOUS_USER_DN_JOHN = "CN=${MALICIOUS_USER_NAME_JOHN}, OU=Apache NiFi" - - private static - final String MALICIOUS_USER_NAME_JOHN_ESCAPED = sanitizeDn(MALICIOUS_USER_NAME_JOHN) - - private static final String UNICODE_DN_1 = "CN=Алйс, OU=Apache NiFi" - private static final String UNICODE_DN_1_ENCODED = "<" + base64Encode(UNICODE_DN_1) + ">" - - private static final String UNICODE_DN_2 = "CN=Боб, OU=Apache NiFi" - private static final String UNICODE_DN_2_ENCODED = "<" + base64Encode(UNICODE_DN_2) + ">" - - @BeforeAll - static void setUpOnce() throws Exception { - logger.metaClass.methodMissing = { String name, args -> - logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}") - } - } - - private static String sanitizeDn(String dn = "") { - dn.replaceAll(/>/, '\\\\>').replaceAll('<', '\\\\<') - } - - private static String base64Encode(String dn = "") { - return Base64.getEncoder().encodeToString(dn.getBytes(StandardCharsets.UTF_8)) - } - - private static String printUnicodeString(final String raw) { - StringBuilder sb = new StringBuilder() - for (int i = 0; i < raw.size(); i++) { - int codePoint = Character.codePointAt(raw, i) - int charCount = Character.charCount(codePoint) - if (charCount > 1) { - i += charCount - 1 // 2. - if (i >= raw.length()) { - throw new IllegalArgumentException("Code point indicated more characters than available") - } - } - sb.append(String.format("\\u%04x ", codePoint)) - } - return sb.toString().trim() - } - - @Test - void testSanitizeDnShouldHandleFuzzing() throws Exception { - // Arrange - final String DESIRED_NAME = SAFE_USER_NAME_JOHN - logger.info(" Desired name: ${DESIRED_NAME} | ${printUnicodeString(DESIRED_NAME)}") - - // Contains various attempted >< escapes, trailing NULL, and BACKSPACE + 'n' - final List MALICIOUS_NAMES = [MALICIOUS_USER_NAME_JOHN, - SAFE_USER_NAME_JOHN + ">", - SAFE_USER_NAME_JOHN + "><>", - SAFE_USER_NAME_JOHN + "\\>", - SAFE_USER_NAME_JOHN + "\u003e", - SAFE_USER_NAME_JOHN + "\u005c\u005c\u003e", - SAFE_USER_NAME_JOHN + "\u0000", - SAFE_USER_NAME_JOHN + "\u0008n"] - - // Act - MALICIOUS_NAMES.each { String name -> - logger.info(" Raw name: ${name} | ${printUnicodeString(name)}") - String sanitizedName = ProxiedEntitiesUtils.sanitizeDn(name) - logger.info("Sanitized name: ${sanitizedName} | ${printUnicodeString(sanitizedName)}") - - // Assert - assertNotEquals(DESIRED_NAME, sanitizedName) - } - } - - @Test - void testShouldFormatProxyDn() throws Exception { - // Arrange - final String DN = SAFE_USER_DN_JOHN - logger.info(" Provided proxy DN: ${DN}") - - final String EXPECTED_PROXY_DN = "<${DN}>" - logger.info(" Expected proxy DN: ${EXPECTED_PROXY_DN}") - - // Act - String forjohnedProxyDn = ProxiedEntitiesUtils.formatProxyDn(DN) - logger.info("Forjohned proxy DN: ${forjohnedProxyDn}") - - // Assert - assertEquals(EXPECTED_PROXY_DN, forjohnedProxyDn) - } - - @Test - void testFormatProxyDnShouldHandleMaliciousInput() throws Exception { - // Arrange - final String DN = MALICIOUS_USER_DN_JOHN - logger.info(" Provided proxy DN: ${DN}") - - final String SANITIZED_DN = sanitizeDn(DN) - final String EXPECTED_PROXY_DN = "<${SANITIZED_DN}>" - logger.info(" Expected proxy DN: ${EXPECTED_PROXY_DN}") - - // Act - String forjohnedProxyDn = ProxiedEntitiesUtils.formatProxyDn(DN) - logger.info("Forjohned proxy DN: ${forjohnedProxyDn}") - - // Assert - assertEquals(EXPECTED_PROXY_DN, forjohnedProxyDn) - } - - @Test - void testGetProxiedEntitiesChain() throws Exception { - // Arrange - String[] input = [SAFE_USER_NAME_JOHN, SAFE_USER_DN_PROXY_1, SAFE_USER_DN_PROXY_2] - final String expectedOutput = "<${SAFE_USER_NAME_JOHN}><${SAFE_USER_DN_PROXY_1}><${SAFE_USER_DN_PROXY_2}>" - - // Act - def output = ProxiedEntitiesUtils.getProxiedEntitiesChain(input) - - // Assert - assertEquals(expectedOutput, output) - } - - @Test - void testGetProxiedEntitiesChainShouldHandleMaliciousInput() throws Exception { - // Arrange - String[] input = [MALICIOUS_USER_DN_JOHN, SAFE_USER_DN_PROXY_1, SAFE_USER_DN_PROXY_2] - final String expectedOutput = "<${sanitizeDn(MALICIOUS_USER_DN_JOHN)}><${SAFE_USER_DN_PROXY_1}><${SAFE_USER_DN_PROXY_2}>" - - // Act - def output = ProxiedEntitiesUtils.getProxiedEntitiesChain(input) - - // Assert - assertEquals(expectedOutput, output) - } - - @Test - void testGetProxiedEntitiesChainShouldEncodeUnicode() throws Exception { - // Arrange - String[] input = [SAFE_USER_NAME_JOHN, UNICODE_DN_1, UNICODE_DN_2] - final String expectedOutput = "<${SAFE_USER_NAME_JOHN}><${UNICODE_DN_1_ENCODED}><${UNICODE_DN_2_ENCODED}>" - - // Act - def output = ProxiedEntitiesUtils.getProxiedEntitiesChain(input) - - // Assert - assertEquals(expectedOutput, output) - } - - @Test - void testFormatProxyDnShouldEncodeNonAsciiCharacters() throws Exception { - // Arrange - logger.info(" Provided DN: ${UNICODE_DN_1}") - final String expectedFormattedDn = "<${UNICODE_DN_1_ENCODED}>" - logger.info(" Expected DN: expected") - - // Act - String formattedDn = ProxiedEntitiesUtils.formatProxyDn(UNICODE_DN_1) - logger.info("Formatted DN: ${formattedDn}") - - // Assert - assertEquals(expectedFormattedDn, formattedDn) - } - - @Test - void testShouldBuildProxyChain() throws Exception { - // Arrange - def mockProxy1 = [getIdentity: { -> SAFE_USER_NAME_PROXY_1 }, getChain: { -> null }, isAnonymous: { -> false}] as NiFiUser - def mockJohn = [getIdentity: { -> SAFE_USER_NAME_JOHN }, getChain: { -> mockProxy1 }, isAnonymous: { -> false}] as NiFiUser - - // Act - String proxiedEntitiesChain = ProxiedEntitiesUtils.buildProxiedEntitiesChainString(mockJohn) - logger.info("Proxied entities chain: ${proxiedEntitiesChain}") - - // Assert - assertEquals("<${SAFE_USER_NAME_JOHN}><${SAFE_USER_NAME_PROXY_1}>" as String, proxiedEntitiesChain) - } - - @Test - void testBuildProxyChainFromNullUserShouldBeAnonymous() throws Exception { - // Arrange - - // Act - String proxiedEntitiesChain = ProxiedEntitiesUtils.buildProxiedEntitiesChainString(null) - logger.info("Proxied entities chain: ${proxiedEntitiesChain}") - - // Assert - assertEquals("<>", proxiedEntitiesChain) - } - - @Test - void testBuildProxyChainFromAnonymousUserShouldBeAnonymous() throws Exception { - // Arrange - def mockProxy1 = [getIdentity: { -> SAFE_USER_NAME_PROXY_1 }, getChain: { -> null }, isAnonymous: { -> false}] as NiFiUser - def mockAnonymous = [getIdentity: { -> "anonymous" }, getChain: { -> mockProxy1 }, isAnonymous: { -> true}] as NiFiUser - - // Act - String proxiedEntitiesChain = ProxiedEntitiesUtils.buildProxiedEntitiesChainString(mockAnonymous) - logger.info("Proxied entities chain: ${proxiedEntitiesChain}") - - // Assert - assertEquals("<><${SAFE_USER_NAME_PROXY_1}>" as String, proxiedEntitiesChain) - } - - @Test - void testBuildProxyChainShouldHandleUnicode() throws Exception { - // Arrange - def mockProxy1 = [getIdentity: { -> UNICODE_DN_1 }, getChain: { -> null }, isAnonymous: { -> false}] as NiFiUser - def mockJohn = [getIdentity: { -> SAFE_USER_NAME_JOHN }, getChain: { -> mockProxy1 }, isAnonymous: { -> false}] as NiFiUser - - // Act - String proxiedEntitiesChain = ProxiedEntitiesUtils.buildProxiedEntitiesChainString(mockJohn) - logger.info("Proxied entities chain: ${proxiedEntitiesChain}") - - // Assert - assertEquals("<${SAFE_USER_NAME_JOHN}><${UNICODE_DN_1_ENCODED}>" as String, proxiedEntitiesChain) - } - - @Test - void testBuildProxyChainShouldHandleMaliciousUser() throws Exception { - // Arrange - def mockProxy1 = [getIdentity: { -> SAFE_USER_NAME_PROXY_1 }, getChain: { -> null }, isAnonymous: { -> false}] as NiFiUser - def mockJohn = [getIdentity: { -> MALICIOUS_USER_NAME_JOHN }, getChain: { -> mockProxy1 }, isAnonymous: { -> false}] as NiFiUser - - // Act - String proxiedEntitiesChain = ProxiedEntitiesUtils.buildProxiedEntitiesChainString(mockJohn) - logger.info("Proxied entities chain: ${proxiedEntitiesChain}") - - // Assert - assertEquals("<${MALICIOUS_USER_NAME_JOHN_ESCAPED}><${SAFE_USER_NAME_PROXY_1}>" as String, proxiedEntitiesChain) - } - - @Test - void testShouldTokenizeProxiedEntitiesChainWithUserNames() throws Exception { - // Arrange - final List NAMES = [SAFE_USER_NAME_JOHN, SAFE_USER_NAME_PROXY_1, SAFE_USER_NAME_PROXY_2] - final String RAW_PROXY_CHAIN = "<${NAMES.join("><")}>" - logger.info(" Provided proxy chain: ${RAW_PROXY_CHAIN}") - - // Act - def tokenizedNames = ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(RAW_PROXY_CHAIN) - logger.info("Tokenized proxy chain: ${tokenizedNames}") - - // Assert - assertEquals(NAMES, tokenizedNames) - } - - @Test - void testShouldTokenizeAnonymous() throws Exception { - // Arrange - final List NAMES = [""] - final String RAW_PROXY_CHAIN = "<>" - logger.info(" Provided proxy chain: ${RAW_PROXY_CHAIN}") - - // Act - def tokenizedNames = ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(RAW_PROXY_CHAIN) - logger.info("Tokenized proxy chain: ${tokenizedNames}") - - // Assert - assertEquals(NAMES, tokenizedNames) - } - - @Test - void testShouldTokenizeDoubleAnonymous() throws Exception { - // Arrange - final List NAMES = ["", ""] - final String RAW_PROXY_CHAIN = "<><>" - logger.info(" Provided proxy chain: ${RAW_PROXY_CHAIN}") - - // Act - def tokenizedNames = ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(RAW_PROXY_CHAIN) - logger.info("Tokenized proxy chain: ${tokenizedNames}") - - // Assert - assertEquals(NAMES, tokenizedNames) - } - - @Test - void testShouldTokenizeNestedAnonymous() throws Exception { - // Arrange - final List NAMES = [SAFE_USER_DN_PROXY_1, "", SAFE_USER_DN_PROXY_2] - final String RAW_PROXY_CHAIN = "<${SAFE_USER_DN_PROXY_1}><><${SAFE_USER_DN_PROXY_2}>" - logger.info(" Provided proxy chain: ${RAW_PROXY_CHAIN}") - - // Act - def tokenizedNames = ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(RAW_PROXY_CHAIN) - logger.info("Tokenized proxy chain: ${tokenizedNames}") - - // Assert - assertEquals(NAMES, tokenizedNames) - } - - @Test - void testShouldTokenizeProxiedEntitiesChainWithDNs() throws Exception { - // Arrange - final List DNS = [SAFE_USER_DN_JOHN, SAFE_USER_DN_PROXY_1, SAFE_USER_DN_PROXY_2] - final String RAW_PROXY_CHAIN = "<${DNS.join("><")}>" - logger.info(" Provided proxy chain: ${RAW_PROXY_CHAIN}") - - // Act - def tokenizedDns = ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(RAW_PROXY_CHAIN) - logger.info("Tokenized proxy chain: ${tokenizedDns.collect { "\"${it}\"" }}") - - // Assert - assertEquals(DNS, tokenizedDns) - } - - @Test - void testShouldTokenizeProxiedEntitiesChainWithAnonymousUser() throws Exception { - // Arrange - final List NAMES = ["", SAFE_USER_NAME_PROXY_1, SAFE_USER_NAME_PROXY_2] - final String RAW_PROXY_CHAIN = "<${NAMES.join("><")}>" - logger.info(" Provided proxy chain: ${RAW_PROXY_CHAIN}") - - // Act - def tokenizedNames = ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(RAW_PROXY_CHAIN) - logger.info("Tokenized proxy chain: ${tokenizedNames}") - - // Assert - assertEquals(NAMES, tokenizedNames) - } - - @Test - void testTokenizeProxiedEntitiesChainShouldHandleMaliciousUser() throws Exception { - // Arrange - final List NAMES = [MALICIOUS_USER_NAME_JOHN, SAFE_USER_NAME_PROXY_1, SAFE_USER_NAME_PROXY_2] - final String RAW_PROXY_CHAIN = "<${NAMES.collect { sanitizeDn(it) }.join("><")}>" - logger.info(" Provided proxy chain: ${RAW_PROXY_CHAIN}") - - // Act - def tokenizedNames = ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(RAW_PROXY_CHAIN) - logger.info("Tokenized proxy chain: ${tokenizedNames.collect { "\"${it}\"" }}") - - // Assert - assertEquals(NAMES, tokenizedNames) - assertEquals(NAMES.size(), tokenizedNames.size()) - assertFalse(tokenizedNames.contains(SAFE_USER_NAME_JOHN)) - } - - @Test - void testTokenizeProxiedEntitiesChainShouldDecodeNonAsciiValues() throws Exception { - // Arrange - final String RAW_PROXY_CHAIN = "<${SAFE_USER_NAME_JOHN}><${UNICODE_DN_1_ENCODED}><${UNICODE_DN_2_ENCODED}>" - final List TOKENIZED_NAMES = [SAFE_USER_NAME_JOHN, UNICODE_DN_1, UNICODE_DN_2] - logger.info(" Provided proxy chain: ${RAW_PROXY_CHAIN}") - - // Act - def tokenizedNames = ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(RAW_PROXY_CHAIN) - logger.info("Tokenized proxy chain: ${tokenizedNames.collect { "\"${it}\"" }}") - - // Assert - assertEquals(TOKENIZED_NAMES, tokenizedNames) - assertEquals(TOKENIZED_NAMES.size(), tokenizedNames.size()) - } -} diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/test/groovy/org/apache/nifi/web/security/requests/ContentLengthFilterTest.groovy b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/test/groovy/org/apache/nifi/web/security/requests/ContentLengthFilterTest.groovy deleted file mode 100644 index 82eb941811..0000000000 --- a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/test/groovy/org/apache/nifi/web/security/requests/ContentLengthFilterTest.groovy +++ /dev/null @@ -1,277 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.nifi.web.security.requests - -import org.apache.commons.lang3.StringUtils -import org.apache.nifi.stream.io.StreamUtils -import org.eclipse.jetty.server.LocalConnector -import org.eclipse.jetty.server.Server -import org.eclipse.jetty.servlet.FilterHolder -import org.eclipse.jetty.servlet.ServletContextHandler -import org.eclipse.jetty.servlet.ServletHolder -import org.junit.jupiter.api.AfterEach -import org.junit.jupiter.api.BeforeEach -import org.junit.jupiter.api.Test -import org.slf4j.Logger -import org.slf4j.LoggerFactory - -import javax.servlet.DispatcherType -import javax.servlet.ServletException -import javax.servlet.ServletInputStream -import javax.servlet.http.HttpServlet -import javax.servlet.http.HttpServletRequest -import javax.servlet.http.HttpServletResponse -import java.util.concurrent.TimeUnit - -import static org.junit.jupiter.api.Assertions.assertFalse -import static org.junit.jupiter.api.Assertions.assertTrue - -class ContentLengthFilterTest { - private static final Logger logger = LoggerFactory.getLogger(ContentLengthFilterTest.class) - - private static final int MAX_CONTENT_LENGTH = 1000 - private static final int SERVER_IDLE_TIMEOUT = 2500 // only one request needed + value large enough for slow systems - private static final String POST_REQUEST = "POST / HTTP/1.1\r\nContent-Length: %d\r\nHost: h\r\n\r\n%s" - private static final String FORM_REQUEST = "POST / HTTP/1.1\r\nContent-Length: %d\r\nHost: h\r\nContent-Type: application/x-www-form-urlencoded\r\nAccept-Charset: UTF-8\r\n\r\n%s" - public static final int FORM_CONTENT_SIZE = 128 - - // These variables hold data for content small enough to be allowed - private static final int SMALL_CLAIM_SIZE_BYTES = 150 - private static final String SMALL_PAYLOAD = "1" * SMALL_CLAIM_SIZE_BYTES - - // These variables hold data for content too large to be allowed - private static final int LARGE_CLAIM_SIZE_BYTES = 2000 - private static final String LARGE_PAYLOAD = "1" * LARGE_CLAIM_SIZE_BYTES - - private Server serverUnderTest - private LocalConnector localConnector - private ServletContextHandler contextUnderTest - - @BeforeEach - void setUp() { - createSimpleReadServer() - } - - @AfterEach - void tearDown() { - stopServer() - } - - void stopServer() throws Exception { - if (serverUnderTest && serverUnderTest.isRunning()) { - serverUnderTest.stop() - } - } - - private void configureAndStartServer(HttpServlet servlet, int maxFormContentSize) throws Exception { - serverUnderTest = new Server() - localConnector = new LocalConnector(serverUnderTest) - localConnector.setIdleTimeout(SERVER_IDLE_TIMEOUT) - serverUnderTest.addConnector(localConnector) - - contextUnderTest = new ServletContextHandler(serverUnderTest, "/") - if (maxFormContentSize > 0) { - contextUnderTest.setMaxFormContentSize(maxFormContentSize) - } - contextUnderTest.addServlet(new ServletHolder(servlet), "/*") - - // This only adds the ContentLengthFilter if a valid maxFormContentSize is not provided - if (maxFormContentSize < 0) { - FilterHolder holder = contextUnderTest.addFilter(ContentLengthFilter.class, "/*", EnumSet.of(DispatcherType.REQUEST) as EnumSet<DispatcherType>) - holder.setInitParameter(ContentLengthFilter.MAX_LENGTH_INIT_PARAM, String.valueOf(MAX_CONTENT_LENGTH)) - } - serverUnderTest.start() - } - - /** - * Initializes a server which consumes any provided request input stream and returns HTTP 200. It has no - * {@code maxFormContentSize}, so the {@link ContentLengthFilter} is applied. The response contains a header and the - * response body indicating the total number of request content bytes read. - * - * @throws Exception if there is a problem setting up the server - */ - private void createSimpleReadServer() throws Exception { - HttpServlet mockServlet = [ - doPost: { HttpServletRequest req, HttpServletResponse resp -> - byte[] byteBuffer = new byte[2048] - int bytesRead = StreamUtils.fillBuffer(req.getInputStream(), byteBuffer, false) - resp.setHeader("Bytes-Read", bytesRead as String) - resp.setStatus(HttpServletResponse.SC_OK) - resp.getWriter().write("Read ${bytesRead} bytes of request input") - } - ] as HttpServlet - configureAndStartServer(mockServlet, -1) - } - - private static void logResponse(String response, String s = "Response: ") { - String responseId = String.valueOf(System.currentTimeMillis() % 100) - final String delimiterLine = "\n-----" + responseId + "-----\n" - String formattedResponse = s + delimiterLine + response + delimiterLine - logger.info(formattedResponse) - } - - @Test - void testRequestsWithMissingContentLengthHeader() throws Exception { - // This shows that the ContentLengthFilter allows a request that does not have a content-length header. - String response = localConnector.getResponse("POST / HTTP/1.0\r\n\r\n") - assertFalse(StringUtils.containsIgnoreCase(response, "411 Length Required")) - } - - /** - * This shows that the ContentLengthFilter rejects a request when the client claims more than the max + sends more than - * the max. - */ - @Test - void testShouldRejectRequestWithLongContentLengthHeader() throws Exception { - // Arrange - final String requestBody = String.format(POST_REQUEST, LARGE_CLAIM_SIZE_BYTES, LARGE_PAYLOAD) - logger.info("Making request with CL: ${LARGE_CLAIM_SIZE_BYTES} and actual length: ${LARGE_PAYLOAD.length()}") - - // Act - String response = localConnector.getResponse(requestBody) - logResponse(response) - - // Assert - assertTrue(response.contains("413 Payload Too Large")) - } - - /** - * This shows that the ContentLengthFilter rejects a request when the client claims more than the max + sends less than - * the claim. - */ - @Test - void testShouldRejectRequestWithLongContentLengthHeaderAndSmallPayload() throws Exception { - // Arrange - String incompletePayload = "1" * (SMALL_CLAIM_SIZE_BYTES / 2) - final String requestBody = String.format(POST_REQUEST, LARGE_CLAIM_SIZE_BYTES, incompletePayload) - logger.info("Making request with CL: ${LARGE_CLAIM_SIZE_BYTES} and actual length: ${incompletePayload.length()}") - - // Act - String response = localConnector.getResponse(requestBody) - logResponse(response) - - // Assert - assertTrue(response.contains("413 Payload Too Large")) - } - - /** - * This shows that the ContentLengthFilter <em>allows</em> a request when the client claims less - * than the max + sends more than the max, but restricts the request body to the stated content - * length size. - */ - @Test - void testShouldRejectRequestWithSmallContentLengthHeaderAndLargePayload() throws Exception { - // Arrange - final String requestBody = String.format(POST_REQUEST, SMALL_CLAIM_SIZE_BYTES, LARGE_PAYLOAD) - logger.info("Making request with CL: ${SMALL_CLAIM_SIZE_BYTES} and actual length: ${LARGE_PAYLOAD.length()}") - - // Act - String response = localConnector.getResponse(requestBody) - logResponse(response) - - // Assert - assertTrue(response.contains("200")) - assertTrue(response.contains("Bytes-Read: " + SMALL_CLAIM_SIZE_BYTES)) - assertTrue(response.contains("Read " + SMALL_CLAIM_SIZE_BYTES + " bytes")) - } - - /** - * This shows that the server times out when the client claims less than the max + sends less than the max + sends - * less than it claims to send. - */ - @Test - void testShouldTimeoutRequestWithSmallContentLengthHeaderAndSmallerPayload() throws Exception { - // Arrange - String smallerPayload = SMALL_PAYLOAD[0..(SMALL_PAYLOAD.length() / 2)] - final String requestBody = String.format(POST_REQUEST, SMALL_CLAIM_SIZE_BYTES, smallerPayload) - logger.info("Making request with CL: ${SMALL_CLAIM_SIZE_BYTES} and actual length: ${smallerPayload.length()}") - - // Act - String response = localConnector.getResponse(requestBody, 500, TimeUnit.MILLISECONDS) - logResponse(response) - - // Assert - assertTrue(response.contains("500 Server Error")) - assertTrue(response.contains("Timeout")) - } - - @Test - void testFilterShouldAllowSiteToSiteTransfer() throws Exception { - // Arrange - final String SITE_TO_SITE_POST_REQUEST = "POST /nifi-api/data-transfer/input-ports HTTP/1.1\r\nContent-Length: %d\r\nHost: h\r\n\r\n%s" - - final String siteToSiteRequest = String.format(SITE_TO_SITE_POST_REQUEST, LARGE_CLAIM_SIZE_BYTES, LARGE_PAYLOAD) - logResponse(siteToSiteRequest, "Request: ") - - // Act - String response = localConnector.getResponse(siteToSiteRequest) - logResponse(response) - - // Assert - assertTrue(response.contains("200 OK")) - } - - @Test - void testJettyMaxFormSize() throws Exception { - // This shows that the jetty server option for 'maxFormContentSize' is insufficient for our needs because it - // catches requests like this: - - // Configure the server but do not apply the CLF because the FORM_CONTENT_SIZE > 0 - configureAndStartServer(new HttpServlet() { - @Override - protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { - try { - req.getParameterMap() - ServletInputStream input = req.getInputStream() - int count = 0 - while (!input.isFinished()) { - input.read() - count += 1 - } - final int FORM_LIMIT_BYTES = FORM_CONTENT_SIZE + "a=\n".length() - if (count > FORM_LIMIT_BYTES) { - logger.warn("Bytes read ({}) is larger than the limit ({})", count, FORM_LIMIT_BYTES) - resp.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Should not reach this code.") - } else { - logger.warn("Bytes read ({}) is less than or equal to the limit ({})", count, FORM_LIMIT_BYTES) - resp.sendError(HttpServletResponse.SC_EXPECTATION_FAILED, "Read Too Many Bytes") - } - } catch (final Exception e) { - // This is the jetty context returning a 400 from the maxFormContentSize setting: - if (StringUtils.containsIgnoreCase(e.getCause().toString(), "Form is larger than max length " + FORM_CONTENT_SIZE)) { - logger.warn("Exception thrown by input stream: ", e) - resp.sendError(HttpServletResponse.SC_REQUEST_ENTITY_TOO_LARGE, "Payload Too Large") - } else { - logger.warn("Exception thrown by input stream: ", e) - resp.sendError(HttpServletResponse.SC_FORBIDDEN, "Should not reach this code, either.") - } - } - } - }, FORM_CONTENT_SIZE) - - // Test to catch a form submission that exceeds the FORM_CONTENT_SIZE limit - String form = "a=" + "1" * FORM_CONTENT_SIZE - String response = localConnector.getResponse(String.format(FORM_REQUEST, form.length(), form)) - logResponse(response) - assertTrue(response.contains("413 Payload Too Large")) - - - // But it does not catch requests like this: - response = localConnector.getResponse(String.format(POST_REQUEST, form.length(), form + form)) - assertTrue(response.contains("417 Read Too Many Bytes")) - } -} \ No newline at end of file diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/test/java/org/apache/nifi/web/security/ProxiedEntitiesUtilsTest.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/test/java/org/apache/nifi/web/security/ProxiedEntitiesUtilsTest.java new file mode 100644 index 0000000000..e245ea861c --- /dev/null +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/test/java/org/apache/nifi/web/security/ProxiedEntitiesUtilsTest.java @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.web.security; + +import org.apache.nifi.authorization.user.NiFiUser; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Base64; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +public class ProxiedEntitiesUtilsTest { + private static final String SAFE_USER_NAME_JOHN = "jdoe"; + private static final String SAFE_USER_DN_JOHN = "CN=" + SAFE_USER_NAME_JOHN + ", OU=Apache NiFi"; + private static final String SAFE_USER_NAME_PROXY_1 = "proxy1.nifi.apache.org"; + private static final String SAFE_USER_DN_PROXY_1 = "CN=" + SAFE_USER_NAME_PROXY_1 + ", OU=Apache NiFi"; + private static final String SAFE_USER_NAME_PROXY_2 = "proxy2.nifi.apache.org"; + private static final String SAFE_USER_DN_PROXY_2 = "CN=" + SAFE_USER_NAME_PROXY_2 + ", OU=Apache NiFi"; + private static final String MALICIOUS_USER_NAME_JOHN = SAFE_USER_NAME_JOHN + ", OU=Apache NiFi><CN=" + SAFE_USER_NAME_PROXY_1; + private static final String MALICIOUS_USER_DN_JOHN = "CN=" + MALICIOUS_USER_NAME_JOHN + ", OU=Apache NiFi"; + private static final String MALICIOUS_USER_NAME_JOHN_ESCAPED = sanitizeDn(MALICIOUS_USER_NAME_JOHN); + private static final String UNICODE_DN_1 = "CN=Алйс, OU=Apache NiFi"; + private static final String UNICODE_DN_1_ENCODED = "<" + base64Encode(UNICODE_DN_1) + ">"; + private static final String UNICODE_DN_2 = "CN=Боб, OU=Apache NiFi"; + private static final String UNICODE_DN_2_ENCODED = "<" + base64Encode(UNICODE_DN_2) + ">"; + private static final String ANONYMOUS_USER = ""; + private static final String ANONYMOUS_PROXIED_ENTITY_CHAIN = "<>"; + + private static String sanitizeDn(String dn) { + return dn.replaceAll(">", "\\\\>").replaceAll("<", "\\\\<"); + } + + private static String base64Encode(String dn) { + return Base64.getEncoder().encodeToString(dn.getBytes(StandardCharsets.UTF_8)); + } + + @ParameterizedTest + @MethodSource("getMaliciousNames" ) + public void testSanitizeDnShouldHandleFuzzing(String maliciousName) { + assertNotEquals(formatDn(SAFE_USER_NAME_JOHN), ProxiedEntitiesUtils.formatProxyDn(maliciousName)); + } + + // Contains various attempted >< escapes, trailing NULL, and BACKSPACE + 'n' + private static List<String> getMaliciousNames() { + return Arrays.asList(MALICIOUS_USER_NAME_JOHN, + SAFE_USER_NAME_JOHN + ">", + SAFE_USER_NAME_JOHN + "><>", + SAFE_USER_NAME_JOHN + "\\>", + SAFE_USER_NAME_JOHN + "\u003e", + SAFE_USER_NAME_JOHN + "\u005c\u005c\u003e", + SAFE_USER_NAME_JOHN + "\u0000", + SAFE_USER_NAME_JOHN + "\u0008n"); + } + + @Test + public void testShouldFormatProxyDn() { + assertEquals(formatDn(SAFE_USER_DN_JOHN), ProxiedEntitiesUtils.formatProxyDn(SAFE_USER_DN_JOHN)); + } + + @Test + public void testFormatProxyDnShouldHandleMaliciousInput() { + assertEquals(formatSanitizedDn(MALICIOUS_USER_DN_JOHN), ProxiedEntitiesUtils.formatProxyDn(MALICIOUS_USER_DN_JOHN)); + } + + @Test + public void testGetProxiedEntitiesChain() { + String[] input = new String [] {SAFE_USER_NAME_JOHN, SAFE_USER_DN_PROXY_1, SAFE_USER_DN_PROXY_2}; + assertEquals(formatDns(input), ProxiedEntitiesUtils.getProxiedEntitiesChain(input)); + } + + @Test + public void testGetProxiedEntitiesChainShouldHandleMaliciousInput() { + final String expectedOutput = formatSanitizedDn(MALICIOUS_USER_DN_JOHN) + formatDns(SAFE_USER_DN_PROXY_1, SAFE_USER_DN_PROXY_2); + assertEquals(expectedOutput, ProxiedEntitiesUtils.getProxiedEntitiesChain(MALICIOUS_USER_DN_JOHN, SAFE_USER_DN_PROXY_1, SAFE_USER_DN_PROXY_2)); + } + + @Test + public void testGetProxiedEntitiesChainShouldEncodeUnicode() { + assertEquals(formatDns(SAFE_USER_NAME_JOHN, UNICODE_DN_1_ENCODED, UNICODE_DN_2_ENCODED), + ProxiedEntitiesUtils.getProxiedEntitiesChain(SAFE_USER_NAME_JOHN, UNICODE_DN_1, UNICODE_DN_2)); + } + + @Test + public void testFormatProxyDnShouldEncodeNonAsciiCharacters() { + assertEquals(formatDn(UNICODE_DN_1_ENCODED), ProxiedEntitiesUtils.formatProxyDn(UNICODE_DN_1)); + } + + @Test + public void testShouldBuildProxyChain(@Mock NiFiUser proxy1, @Mock NiFiUser john) { + when(proxy1.getIdentity()).thenReturn(SAFE_USER_NAME_PROXY_1); + when(proxy1.getChain()).thenReturn(null); + when(proxy1.isAnonymous()).thenReturn(false); + when(john.getIdentity()).thenReturn(SAFE_USER_NAME_JOHN); + when(john.getChain()).thenReturn(proxy1); + when(john.isAnonymous()).thenReturn(false); + + assertEquals(formatDns(SAFE_USER_NAME_JOHN, SAFE_USER_NAME_PROXY_1), ProxiedEntitiesUtils.buildProxiedEntitiesChainString(john)); + } + + @Test + public void testBuildProxyChainFromNullUserShouldBeAnonymous() { + assertEquals(ANONYMOUS_PROXIED_ENTITY_CHAIN, ProxiedEntitiesUtils.buildProxiedEntitiesChainString(null)); + } + + @Test + public void testBuildProxyChainFromAnonymousUserShouldBeAnonymous(@Mock NiFiUser proxy1, @Mock NiFiUser anonymous) { + when(proxy1.getIdentity()).thenReturn(SAFE_USER_NAME_PROXY_1); + when(proxy1.getChain()).thenReturn(null); + when(proxy1.isAnonymous()).thenReturn(false); + when(anonymous.getChain()).thenReturn(proxy1); + when(anonymous.isAnonymous()).thenReturn(true); + + assertEquals(formatDns(ANONYMOUS_USER, SAFE_USER_NAME_PROXY_1), ProxiedEntitiesUtils.buildProxiedEntitiesChainString(anonymous)); + } + + @Test + public void testBuildProxyChainShouldHandleUnicode(@Mock NiFiUser proxy1, @Mock NiFiUser john) { + when(proxy1.getIdentity()).thenReturn(UNICODE_DN_1); + when(proxy1.getChain()).thenReturn(null); + when(proxy1.isAnonymous()).thenReturn(false); + when(john.getIdentity()).thenReturn(SAFE_USER_NAME_JOHN); + when(john.getChain()).thenReturn(proxy1); + when(john.isAnonymous()).thenReturn(false); + + assertEquals(formatDns(SAFE_USER_NAME_JOHN, UNICODE_DN_1_ENCODED), ProxiedEntitiesUtils.buildProxiedEntitiesChainString(john)); + } + + @Test + public void testBuildProxyChainShouldHandleMaliciousUser(@Mock NiFiUser proxy1, @Mock NiFiUser john) { + when(proxy1.getIdentity()).thenReturn(SAFE_USER_NAME_PROXY_1); + when(proxy1.getChain()).thenReturn(null); + when(proxy1.isAnonymous()).thenReturn(false); + when(john.getIdentity()).thenReturn(MALICIOUS_USER_NAME_JOHN); + when(john.getChain()).thenReturn(proxy1); + when(john.isAnonymous()).thenReturn(false); + + assertEquals(formatDns(MALICIOUS_USER_NAME_JOHN_ESCAPED, SAFE_USER_NAME_PROXY_1), ProxiedEntitiesUtils.buildProxiedEntitiesChainString(john)); + } + + @Test + public void testShouldTokenizeProxiedEntitiesChainWithUserNames() { + final List<String> names = Arrays.asList(SAFE_USER_NAME_JOHN, SAFE_USER_NAME_PROXY_1, SAFE_USER_NAME_PROXY_2); + final String rawProxyChain = formatDns(names.toArray(new String[0])); + + assertEquals(names, ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(rawProxyChain)); + } + + @Test + public void testShouldTokenizeAnonymous() { + assertEquals(Collections.singletonList(ANONYMOUS_USER), ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(ANONYMOUS_PROXIED_ENTITY_CHAIN)); + } + + @Test + public void testShouldTokenizeDoubleAnonymous() { + assertEquals(Arrays.asList(ANONYMOUS_USER, ANONYMOUS_USER), ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(ANONYMOUS_PROXIED_ENTITY_CHAIN.repeat(2))); + } + + @Test + public void testShouldTokenizeNestedAnonymous() { + final List<String> names = Arrays.asList(SAFE_USER_DN_PROXY_1, ANONYMOUS_USER, SAFE_USER_DN_PROXY_2); + final String rawProxyChain = formatDns(names.toArray(new String [0])); + + assertEquals(names, ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(rawProxyChain)); + } + + @Test + public void testShouldTokenizeProxiedEntitiesChainWithDNs() { + final List<String> dns = Arrays.asList(SAFE_USER_DN_JOHN, SAFE_USER_DN_PROXY_1, SAFE_USER_DN_PROXY_2); + final String rawProxyChain = formatDns(dns.toArray(new String[0])); + + assertEquals(dns, ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(rawProxyChain)); + } + + @Test + public void testShouldTokenizeProxiedEntitiesChainWithAnonymousUser() { + final List<String> names = Arrays.asList(ANONYMOUS_USER, SAFE_USER_NAME_PROXY_1, SAFE_USER_NAME_PROXY_2); + final String rawProxyChain = formatDns(names.toArray(new String[0])); + + assertEquals(names, ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(rawProxyChain)); + } + + @Test + public void testTokenizeProxiedEntitiesChainShouldHandleMaliciousUser() { + final List<String> names = Arrays.asList(MALICIOUS_USER_NAME_JOHN, SAFE_USER_NAME_PROXY_1, SAFE_USER_NAME_PROXY_2); + final String rawProxyChain = names.stream() + .map(this::formatSanitizedDn) + .collect(Collectors.joining()); + List<String> tokenizedNames = ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(rawProxyChain); + + assertEquals(names, tokenizedNames); + assertFalse(tokenizedNames.contains(SAFE_USER_NAME_JOHN)); + } + + @Test + public void testTokenizeProxiedEntitiesChainShouldDecodeNonAsciiValues() { + List<String> tokenizedNames = + ProxiedEntitiesUtils.tokenizeProxiedEntitiesChain(formatDns(SAFE_USER_NAME_JOHN, UNICODE_DN_1_ENCODED, UNICODE_DN_2_ENCODED)); + + assertEquals(Arrays.asList(SAFE_USER_NAME_JOHN, UNICODE_DN_1, UNICODE_DN_2), tokenizedNames); + } + + private String formatSanitizedDn(String dn) { + return formatDn((sanitizeDn(dn))); + } + + private String formatDn(String dn) { + return formatDns(dn); + } + + private String formatDns(String...dns) { + return Arrays.stream(dns) + .collect(Collectors.joining("><", "<", ">")); + } +} diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/test/java/org/apache/nifi/web/security/requests/ContentLengthFilterTest.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/test/java/org/apache/nifi/web/security/requests/ContentLengthFilterTest.java new file mode 100644 index 0000000000..f34268d8e9 --- /dev/null +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-web/nifi-web-security/src/test/java/org/apache/nifi/web/security/requests/ContentLengthFilterTest.java @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.web.security.requests; + +import org.apache.commons.lang3.StringUtils; +import org.apache.nifi.stream.io.StreamUtils; +import org.eclipse.jetty.server.LocalConnector; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.servlet.FilterHolder; +import org.eclipse.jetty.servlet.ServletContextHandler; +import org.eclipse.jetty.servlet.ServletHolder; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; + +import javax.servlet.DispatcherType; +import javax.servlet.ServletInputStream; +import javax.servlet.http.HttpServlet; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.util.EnumSet; +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@ExtendWith(MockitoExtension.class) +class ContentLengthFilterTest { + private static final String POST_REQUEST = "POST / HTTP/1.1\r\nContent-Length: %d\r\nHost: h\r\n\r\n%s"; + public static final int FORM_CONTENT_SIZE = 128; + + // These variables hold data for content small enough to be allowed + private static final int SMALL_CLAIM_SIZE_BYTES = 150; + private static final String SMALL_PAYLOAD = "1".repeat(SMALL_CLAIM_SIZE_BYTES); + + // These variables hold data for content too large to be allowed + private static final int LARGE_CLAIM_SIZE_BYTES = 2000; + private static final String LARGE_PAYLOAD = "1".repeat(LARGE_CLAIM_SIZE_BYTES); + + private Server serverUnderTest; + private LocalConnector localConnector; + + @BeforeEach + public void setUp() throws Exception { + createSimpleReadServer(); + } + + @AfterEach + public void tearDown() throws Exception { + stopServer(); + } + + @Test + public void testRequestsWithMissingContentLengthHeader() throws Exception { + // This shows that the ContentLengthFilter allows a request that does not have a content-length header. + String response = localConnector.getResponse("POST / HTTP/1.0\r\n\r\n"); + assertFalse(StringUtils.containsIgnoreCase(response, "411 Length Required")); + } + + /** + * This shows that the ContentLengthFilter rejects a request when the client claims more than the max + sends more than + * the max. + */ + @Test + public void testShouldRejectRequestWithLongContentLengthHeader() throws Exception { + final String requestBody = String.format(POST_REQUEST, LARGE_CLAIM_SIZE_BYTES, LARGE_PAYLOAD); + String response = localConnector.getResponse(requestBody); + + assertTrue(response.contains("413 Payload Too Large")); + } + + /** + * This shows that the ContentLengthFilter rejects a request when the client claims more than the max + sends less than + * the claim. + */ + @Test + public void testShouldRejectRequestWithLongContentLengthHeaderAndSmallPayload() throws Exception { + String incompletePayload = "1".repeat(SMALL_CLAIM_SIZE_BYTES / 2); + final String requestBody = String.format(POST_REQUEST, LARGE_CLAIM_SIZE_BYTES, incompletePayload); + String response = localConnector.getResponse(requestBody); + + assertTrue(response.contains("413 Payload Too Large")); + } + + /** + * This shows that the ContentLengthFilter <em>allows</em> a request when the client claims less + * than the max + sends more than the max, but restricts the request body to the stated content + * length size. + */ + @Test + public void testShouldRejectRequestWithSmallContentLengthHeaderAndLargePayload() throws Exception { + final String requestBody = String.format(POST_REQUEST, SMALL_CLAIM_SIZE_BYTES, LARGE_PAYLOAD); + String response = localConnector.getResponse(requestBody); + + assertTrue(response.contains("200")); + assertTrue(response.contains("Bytes-Read: " + SMALL_CLAIM_SIZE_BYTES)); + assertTrue(response.contains("Read " + SMALL_CLAIM_SIZE_BYTES + " bytes")); + } + + /** + * This shows that the server times out when the client claims less than the max + sends less than the max + sends + * less than it claims to send. + */ + @Test + public void testShouldTimeoutRequestWithSmallContentLengthHeaderAndSmallerPayload() throws Exception { + String smallerPayload = SMALL_PAYLOAD.substring(0, SMALL_PAYLOAD.length() / 2); + final String requestBody = String.format(POST_REQUEST, SMALL_CLAIM_SIZE_BYTES, smallerPayload); + String response = localConnector.getResponse(requestBody, 500, TimeUnit.MILLISECONDS); + + assertTrue(response.contains("500 Server Error")); + assertTrue(response.contains("Timeout")); + } + + @Test + public void testFilterShouldAllowSiteToSiteTransfer() throws Exception { + final String siteToSitePostRequest = "POST /nifi-api/data-transfer/input-ports HTTP/1.1\r\nContent-Length: %d\r\nHost: h\r\n\r\n%s"; + final String siteToSiteRequest = String.format(siteToSitePostRequest, LARGE_CLAIM_SIZE_BYTES, LARGE_PAYLOAD); + String response = localConnector.getResponse(siteToSiteRequest); + + assertTrue(response.contains("200 OK")); + } + + @Test + void testJettyMaxFormSize() throws Exception { + // This shows that the jetty server option for 'maxFormContentSize' is insufficient for our needs because it + // catches requests like this: + + // Configure the server but do not apply the CLF because the FORM_CONTENT_SIZE > 0 + configureAndStartServer(new HttpServlet() { + @Override + protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException { + try { + req.getParameterMap(); + ServletInputStream input = req.getInputStream(); + int count = 0; + while (!input.isFinished()) { + input.read(); + count += 1; + } + final int formLimitBytes = FORM_CONTENT_SIZE + "a=\n".length(); + if (count > formLimitBytes) { + resp.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Should not reach this code."); + } else { + resp.sendError(HttpServletResponse.SC_EXPECTATION_FAILED, "Read Too Many Bytes"); + } + } catch (final Exception e) { + // This is the jetty context returning a 400 from the maxFormContentSize setting: + if (StringUtils.containsIgnoreCase(e.getCause().toString(), "Form is larger than max length " + FORM_CONTENT_SIZE)) { + resp.sendError(HttpServletResponse.SC_REQUEST_ENTITY_TOO_LARGE, "Payload Too Large"); + } else { + resp.sendError(HttpServletResponse.SC_FORBIDDEN, "Should not reach this code, either."); + } + } + } + }, FORM_CONTENT_SIZE); + + // Test to catch a form submission that exceeds the FORM_CONTENT_SIZE limit + String form = "a=" + "1".repeat(FORM_CONTENT_SIZE); + final String formRequest = "POST / HTTP/1.1\r\nContent-Length: %d\r\nHost: h\r\nContent-Type: application/x-www-form-urlencoded\r\nAccept-Charset: UTF-8\r\n\r\n%s"; + String response = localConnector.getResponse(String.format(formRequest, form.length(), form)); + + assertTrue(response.contains("413 Payload Too Large")); + + // But it does not catch requests like this: + response = localConnector.getResponse(String.format(POST_REQUEST, form.length(), form + form)); + assertTrue(response.contains("417 Read Too Many Bytes")); + } + + /** + * Initializes a server which consumes any provided request input stream and returns HTTP 200. It has no + * {@code maxFormContentSize}, so the {@link ContentLengthFilter} is applied. The response contains a header and the + * response body indicating the total number of request content bytes read. + * + * @throws Exception if there is a problem setting up the server + */ + private void createSimpleReadServer() throws Exception { + HttpServlet mockServlet = new HttpServlet() { + @Override + protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException { + byte[] byteBuffer = new byte[2048]; + int bytesRead = StreamUtils.fillBuffer(req.getInputStream(), byteBuffer, false); + resp.setHeader("Bytes-Read", Integer.toString(bytesRead)); + resp.setStatus(HttpServletResponse.SC_OK); + resp.getWriter().write("Read " + bytesRead + " bytes of request input"); + } + }; + configureAndStartServer(mockServlet, -1); + } + + private void configureAndStartServer(HttpServlet servlet, int maxFormContentSize) throws Exception { + serverUnderTest = new Server(); + localConnector = new LocalConnector(serverUnderTest); + localConnector.setIdleTimeout(2500); // only one request needed + value large enough for slow systems + serverUnderTest.addConnector(localConnector); + + ServletContextHandler contextUnderTest = new ServletContextHandler(serverUnderTest, "/"); + if (maxFormContentSize > 0) { + contextUnderTest.setMaxFormContentSize(maxFormContentSize); + } + contextUnderTest.addServlet(new ServletHolder(servlet), "/*"); + + // This only adds the ContentLengthFilter if a valid maxFormContentSize is not provided + if (maxFormContentSize < 0) { + FilterHolder holder = contextUnderTest.addFilter(ContentLengthFilter.class, "/*", EnumSet.of(DispatcherType.REQUEST)); + holder.setInitParameter(ContentLengthFilter.MAX_LENGTH_INIT_PARAM, String.valueOf(1000)); + } + serverUnderTest.start(); + } + + void stopServer() throws Exception { + if (serverUnderTest != null && serverUnderTest.isRunning()) { + serverUnderTest.stop(); + } + } +}