ATLAS-584 Integrate CSRF prevention filter (kevalbhatt18 via shwethags)
Project: http://git-wip-us.apache.org/repos/asf/incubator-atlas/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-atlas/commit/dda382f4 Tree: http://git-wip-us.apache.org/repos/asf/incubator-atlas/tree/dda382f4 Diff: http://git-wip-us.apache.org/repos/asf/incubator-atlas/diff/dda382f4 Branch: refs/heads/master Commit: dda382f491f1bb26bdeb31620092a11ce9fdb710 Parents: 33fdad7 Author: Shwetha GS <[email protected]> Authored: Thu Jul 7 10:31:02 2016 +0530 Committer: Shwetha GS <[email protected]> Committed: Thu Jul 7 10:31:02 2016 +0530 ---------------------------------------------------------------------- .../public/js/collection/BaseCollection.js | 4 +- dashboardv2/public/js/main.js | 16 +- dashboardv2/public/js/models/BaseModel.js | 3 +- .../public/js/utils/CommonViewFunction.js | 55 ++++- .../business_catalog/BusinessCatalogHeader.js | 11 +- dashboardv2/public/js/views/site/Header.js | 11 +- distro/src/conf/atlas-application.properties | 8 +- release-log.txt | 1 + .../web/filters/AtlasCSRFPreventionFilter.java | 247 +++++++++++++++++++ .../atlas/web/resources/AdminResource.java | 11 + webapp/src/main/resources/spring-security.xml | 4 + .../filters/AtlasCSRFPreventionFilterTest.java | 149 +++++++++++ 12 files changed, 492 insertions(+), 28 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-atlas/blob/dda382f4/dashboardv2/public/js/collection/BaseCollection.js ---------------------------------------------------------------------- diff --git a/dashboardv2/public/js/collection/BaseCollection.js b/dashboardv2/public/js/collection/BaseCollection.js index 0c148ac..e4ac1ae 100644 --- a/dashboardv2/public/js/collection/BaseCollection.js +++ b/dashboardv2/public/js/collection/BaseCollection.js @@ -19,8 +19,9 @@ define(['require', 'utils/Globals', 'utils/Utils', + 'utils/CommonViewFunction', 'backbone.paginator' -], function(require, Globals, Utils) { +], function(require, Globals, Utils, CommonViewFunction) { 'use strict'; var BaseCollection = Backbone.PageableCollection.extend( @@ -138,6 +139,7 @@ define(['require', return retCols; }, nonCrudOperation: function(url, requestMethod, options) { + options['beforeSend'] = CommonViewFunction.addRestCsrfCustomHeader; return Backbone.sync.call(this, null, this, _.extend({ url: url, type: requestMethod http://git-wip-us.apache.org/repos/asf/incubator-atlas/blob/dda382f4/dashboardv2/public/js/main.js ---------------------------------------------------------------------- diff --git a/dashboardv2/public/js/main.js b/dashboardv2/public/js/main.js index ceed10c..552d906 100644 --- a/dashboardv2/public/js/main.js +++ b/dashboardv2/public/js/main.js @@ -140,11 +140,23 @@ require.config({ require(['App', 'router/Router', + 'utils/CommonViewFunction', + 'utils/Globals', 'utils/Overrides', 'bootstrap', 'd3', 'select2' -], function(App, Router) { +], function(App, Router, CommonViewFunction, Globals) { App.appRouter = new Router(); - App.start(); + CommonViewFunction.userDataFetch({ + url: Globals.baseURL + "/api/atlas/admin/session", + callback: function(response) { + if (response && response.userName) { + Globals.userLogedIn.status = true; + Globals.userLogedIn.response = response; + } + App.start(); + } + }); + }); http://git-wip-us.apache.org/repos/asf/incubator-atlas/blob/dda382f4/dashboardv2/public/js/models/BaseModel.js ---------------------------------------------------------------------- diff --git a/dashboardv2/public/js/models/BaseModel.js b/dashboardv2/public/js/models/BaseModel.js index da96d04..27e0332 100644 --- a/dashboardv2/public/js/models/BaseModel.js +++ b/dashboardv2/public/js/models/BaseModel.js @@ -16,7 +16,7 @@ * limitations under the License. */ -define(['require', 'utils/Utils', 'backbone'], function(require, Utils, Backbone) { +define(['require', 'utils/Utils', 'backbone','utils/CommonViewFunction'], function(require, Utils, Backbone,CommonViewFunction) { 'use strict'; var BaseModel = Backbone.Model.extend( @@ -60,6 +60,7 @@ define(['require', 'utils/Utils', 'backbone'], function(require, Utils, Backbone * @return {[type]} [description] */ nonCrudOperation: function(url, requestMethod, options) { + options['beforeSend'] = CommonViewFunction.addRestCsrfCustomHeader; return Backbone.sync.call(this, null, this, _.extend({ url: url, type: requestMethod http://git-wip-us.apache.org/repos/asf/incubator-atlas/blob/dda382f4/dashboardv2/public/js/utils/CommonViewFunction.js ---------------------------------------------------------------------- diff --git a/dashboardv2/public/js/utils/CommonViewFunction.js b/dashboardv2/public/js/utils/CommonViewFunction.js index ed6a34f..6d88dc8 100644 --- a/dashboardv2/public/js/utils/CommonViewFunction.js +++ b/dashboardv2/public/js/utils/CommonViewFunction.js @@ -419,7 +419,6 @@ define(['require', 'utils/Utils', 'modules/Modal', 'utils/Messages', 'utils/Glob return "api/atlas/v1/entities/" + options.guid + "/tags/" + name; }; VCatalog.save(null, { - beforeSend: function() {}, success: function(data) { Utils.notifySuccess({ content: "Term " + name + Messages.addTermToEntitySuccessMessage @@ -435,7 +434,7 @@ define(['require', 'utils/Utils', 'modules/Modal', 'utils/Messages', 'utils/Glob if (data && data.responseText) { var data = JSON.parse(data.responseText); Utils.notifyError({ - content: data.message + content: data.message || data.msgDesc }); if (options.callback) { options.callback(); @@ -446,13 +445,63 @@ define(['require', 'utils/Utils', 'modules/Modal', 'utils/Messages', 'utils/Glob }); }) } + CommonViewFunction.addRestCsrfCustomHeader = function(xhr, settings) { + // if (settings.url == null || !settings.url.startsWith('/webhdfs/')) { + if (settings.url == null) { + return; + } + var method = settings.type; + if (CommonViewFunction.restCsrfCustomHeader != null && !CommonViewFunction.restCsrfMethodsToIgnore[method]) { + // The value of the header is unimportant. Only its presence matters. + xhr.setRequestHeader(CommonViewFunction.restCsrfCustomHeader, '""'); + } + } + CommonViewFunction.restCsrfCustomHeader = null; + CommonViewFunction.restCsrfMethodsToIgnore = null; CommonViewFunction.userDataFetch = function(options) { + var csrfEnabled = false, + header = null, + methods = []; + + function getTrimmedStringArrayValue(string) { + var str = string, + array = []; + if (str) { + var splitStr = str.split(','); + for (var i = 0; i < splitStr.length; i++) { + array.push(splitStr[i].trim()); + } + } + return array; + } if (options.url) { $.ajax({ url: options.url, success: function(response) { + if (response) { + if (response['atlas.rest-csrf.enabled']) { + var str = "" + response['atlas.rest-csrf.enabled']; + csrfEnabled = (str.toLowerCase() == 'true'); + } + if (response['atlas.rest-csrf.custom-header']) { + header = response['atlas.rest-csrf.custom-header'].trim(); + } + if (response['atlas.rest-csrf.methods-to-ignore']) { + methods = getTrimmedStringArrayValue(response['atlas.rest-csrf.methods-to-ignore']); + } + if (csrfEnabled) { + CommonViewFunction.restCsrfCustomHeader = header; + CommonViewFunction.restCsrfMethodsToIgnore = {}; + methods.map(function(method) { CommonViewFunction.restCsrfMethodsToIgnore[method] = true; }); + Backbone.$.ajaxSetup({ + beforeSend: CommonViewFunction.addRestCsrfCustomHeader + }); + } + } + }, + complete: function(response) { if (options.callback) { - options.callback(response); + options.callback(response.responseJSON); } } }); http://git-wip-us.apache.org/repos/asf/incubator-atlas/blob/dda382f4/dashboardv2/public/js/views/business_catalog/BusinessCatalogHeader.js ---------------------------------------------------------------------- diff --git a/dashboardv2/public/js/views/business_catalog/BusinessCatalogHeader.js b/dashboardv2/public/js/views/business_catalog/BusinessCatalogHeader.js index 8fa436c..6be1d2d 100644 --- a/dashboardv2/public/js/views/business_catalog/BusinessCatalogHeader.js +++ b/dashboardv2/public/js/views/business_catalog/BusinessCatalogHeader.js @@ -40,16 +40,7 @@ define(['require', render: function() { var that = this; $(this.el).html(this.template()); - if (!Globals.userLogedIn.status) { - CommonViewFunction.userDataFetch({ - url: Globals.baseURL + "/api/atlas/admin/session", - callback: function(response) { - that.$('.userName').html(response.userName); - Globals.userLogedIn.status = true; - Globals.userLogedIn.response = response; - } - }); - } else { + if (Globals.userLogedIn.status) { that.$('.userName').html(Globals.userLogedIn.response.userName); } var that = this; http://git-wip-us.apache.org/repos/asf/incubator-atlas/blob/dda382f4/dashboardv2/public/js/views/site/Header.js ---------------------------------------------------------------------- diff --git a/dashboardv2/public/js/views/site/Header.js b/dashboardv2/public/js/views/site/Header.js index 467cbf5..f53d3e8 100644 --- a/dashboardv2/public/js/views/site/Header.js +++ b/dashboardv2/public/js/views/site/Header.js @@ -30,16 +30,7 @@ define(['require', initialize: function(options) {}, onRender: function() { var that = this; - if (!Globals.userLogedIn.status) { - CommonViewFunction.userDataFetch({ - url: Globals.baseURL + "/api/atlas/admin/session", - callback: function(response) { - that.$('.userName').html(response.userName); - Globals.userLogedIn.status = true; - Globals.userLogedIn.response = response; - } - }); - } else { + if (Globals.userLogedIn.status) { that.$('.userName').html(Globals.userLogedIn.response.userName); } }, http://git-wip-us.apache.org/repos/asf/incubator-atlas/blob/dda382f4/distro/src/conf/atlas-application.properties ---------------------------------------------------------------------- diff --git a/distro/src/conf/atlas-application.properties b/distro/src/conf/atlas-application.properties index 79a4982..215d8d5 100755 --- a/distro/src/conf/atlas-application.properties +++ b/distro/src/conf/atlas-application.properties @@ -178,4 +178,10 @@ atlas.authorizer.impl=SIMPLE #atlas.graph.storage.cache.db-cache-time=120000 ######### Business Catalog ######### -atlas.taxonomy.default.name=Catalog \ No newline at end of file +atlas.taxonomy.default.name=Catalog + +######### CSRF Configs ######### +atlas.rest-csrf.enabled=true +atlas.rest-csrf.browser-useragents-regex=^Mozilla.*,^Opera.*,^Chrome.* +atlas.rest-csrf.methods-to-ignore=GET,OPTIONS,HEAD,TRACE +atlas.rest-csrf.custom-header=X-XSRF-HEADER http://git-wip-us.apache.org/repos/asf/incubator-atlas/blob/dda382f4/release-log.txt ---------------------------------------------------------------------- diff --git a/release-log.txt b/release-log.txt index 20a8a2c..78ae9a2 100644 --- a/release-log.txt +++ b/release-log.txt @@ -6,6 +6,7 @@ INCOMPATIBLE CHANGES: ALL CHANGES: +ATLAS-584 Integrate CSRF prevention filter (kevalbhatt18 via shwethags) ATLAS-963 UI: Entity details is not display String array attribute values correctly (kevalbhatt18 via shwethags) ATLAS-988 HiveHookIT.testInsertIntoTable is broken (svimal2106 via shwethags) ATLAS-655 Please delete old releases from mirroring system (shwethags) http://git-wip-us.apache.org/repos/asf/incubator-atlas/blob/dda382f4/webapp/src/main/java/org/apache/atlas/web/filters/AtlasCSRFPreventionFilter.java ---------------------------------------------------------------------- diff --git a/webapp/src/main/java/org/apache/atlas/web/filters/AtlasCSRFPreventionFilter.java b/webapp/src/main/java/org/apache/atlas/web/filters/AtlasCSRFPreventionFilter.java new file mode 100644 index 0000000..3cc83c5 --- /dev/null +++ b/webapp/src/main/java/org/apache/atlas/web/filters/AtlasCSRFPreventionFilter.java @@ -0,0 +1,247 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.atlas.web.filters; + +import java.io.IOException; +import java.util.HashSet; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.FilterConfig; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.apache.log4j.Logger; +import org.apache.atlas.ApplicationProperties; +import org.apache.atlas.AtlasException; +import org.apache.commons.configuration.Configuration; +import org.codehaus.jackson.map.ObjectMapper; +import org.json.simple.JSONObject; + +public class AtlasCSRFPreventionFilter implements Filter { + private static final Logger LOG = Logger.getLogger(AtlasCSRFPreventionFilter.class); + private static Configuration configuration; + + static { + try { + configuration = ApplicationProperties.get(); + LOG.info("Configuration obtained :: "+configuration); + } catch (AtlasException e) { + LOG.error(e.getMessage(), e); + } + } + + public static final boolean isCSRF_ENABLED = configuration.getBoolean("atlas.rest-csrf.enabled", true); + public static final String BROWSER_USER_AGENT_PARAM = "atlas.rest-csrf.browser-useragents-regex"; + public static final String BROWSER_USER_AGENTS_DEFAULT = "^Mozilla.*,^Opera.*,^Chrome"; + public static final String CUSTOM_METHODS_TO_IGNORE_PARAM = "atlas.rest-csrf.methods-to-ignore"; + public static final String METHODS_TO_IGNORE_DEFAULT = "GET,OPTIONS,HEAD,TRACE"; + public static final String CUSTOM_HEADER_PARAM = "atlas.rest-csrf.custom-header"; + public static final String HEADER_DEFAULT = "X-XSRF-HEADER"; + public static final String HEADER_USER_AGENT = "User-Agent"; + + private String headerName = HEADER_DEFAULT; + private Set<String> methodsToIgnore = null; + private Set<Pattern> browserUserAgents; + + public AtlasCSRFPreventionFilter() { + try { + if (isCSRF_ENABLED){ + init(null); + } + } catch (Exception e) { + LOG.error("Error while initializing Filter ", e); + } + } + + public void init(FilterConfig filterConfig) throws ServletException { + String customHeader = configuration.getString(CUSTOM_HEADER_PARAM, HEADER_DEFAULT); + if (customHeader != null) { + headerName = customHeader; + } + + String customMethodsToIgnore = configuration.getString(CUSTOM_METHODS_TO_IGNORE_PARAM, METHODS_TO_IGNORE_DEFAULT); + if (customMethodsToIgnore != null) { + parseMethodsToIgnore(customMethodsToIgnore); + } else { + parseMethodsToIgnore(METHODS_TO_IGNORE_DEFAULT); + } + String agents = configuration.getString(BROWSER_USER_AGENT_PARAM, BROWSER_USER_AGENTS_DEFAULT); + if (agents == null) { + agents = BROWSER_USER_AGENTS_DEFAULT; + } + parseBrowserUserAgents(agents); + LOG.info("Adding cross-site request forgery (CSRF) protection"); + } + + void parseMethodsToIgnore(String mti) { + String[] methods = mti.split(","); + methodsToIgnore = new HashSet<String>(); + for (int i = 0; i < methods.length; i++) { + methodsToIgnore.add(methods[i]); + } + } + + void parseBrowserUserAgents(String userAgents) { + String[] agentsArray = userAgents.split(","); + browserUserAgents = new HashSet<Pattern>(); + for (String patternString : agentsArray) { + browserUserAgents.add(Pattern.compile(patternString)); + } + } + + protected boolean isBrowser(String userAgent) { + if (userAgent == null) { + return false; + } + if (browserUserAgents != null){ + for (Pattern pattern : browserUserAgents) { + Matcher matcher = pattern.matcher(userAgent); + if (matcher.matches()) { + return true; + } + } + } + return false; + } + + public interface HttpInteraction { + /** + * Returns the value of a header. + * + * @param header + * name of header + * @return value of header + */ + String getHeader(String header); + + /** + * Returns the method. + * + * @return method + */ + String getMethod(); + + /** + * Called by the filter after it decides that the request may proceed. + * + * @throws IOException + * if there is an I/O error + * @throws ServletException + * if the implementation relies on the servlet API and a + * servlet API call has failed + */ + void proceed() throws IOException, ServletException; + + /** + * Called by the filter after it decides that the request is a potential + * CSRF attack and therefore must be rejected. + * + * @param code + * status code to send + * @param message + * response message + * @throws IOException + * if there is an I/O error + */ + void sendError(int code, String message) throws IOException; + } + + public void handleHttpInteraction(HttpInteraction httpInteraction) + throws IOException, ServletException { + if (!isBrowser(httpInteraction.getHeader(HEADER_USER_AGENT)) + || methodsToIgnore.contains(httpInteraction.getMethod()) + || httpInteraction.getHeader(headerName) != null) { + httpInteraction.proceed(); + }else { + httpInteraction.sendError(HttpServletResponse.SC_BAD_REQUEST,"Missing Required Header for CSRF Vulnerability Protection"); + } + } + + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { + if (isCSRF_ENABLED){ + final HttpServletRequest httpRequest = (HttpServletRequest)request; + final HttpServletResponse httpResponse = (HttpServletResponse)response; + handleHttpInteraction(new ServletFilterHttpInteraction(httpRequest, httpResponse, chain)); + }else{ + chain.doFilter(request, response); + } + } + + public void destroy() { + } + + private static final class ServletFilterHttpInteraction implements + HttpInteraction { + + private final FilterChain chain; + private final HttpServletRequest httpRequest; + private final HttpServletResponse httpResponse; + + /** + * Creates a new ServletFilterHttpInteraction. + * + * @param httpRequest + * request to process + * @param httpResponse + * response to process + * @param chain + * filter chain to forward to if HTTP interaction is allowed + */ + public ServletFilterHttpInteraction(HttpServletRequest httpRequest, + HttpServletResponse httpResponse, FilterChain chain) { + this.httpRequest = httpRequest; + this.httpResponse = httpResponse; + this.chain = chain; + } + + @Override + public String getHeader(String header) { + return httpRequest.getHeader(header); + } + + @Override + public String getMethod() { + return httpRequest.getMethod(); + } + + @Override + public void proceed() throws IOException, ServletException { + chain.doFilter(httpRequest, httpResponse); + } + + @Override + public void sendError(int code, String message) throws IOException { + JSONObject json = new JSONObject(); + ObjectMapper mapper = new ObjectMapper(); + json.put("msgDesc", message); + String jsonAsStr = mapper.writeValueAsString(json); + httpResponse.setContentType("application/json"); + httpResponse.setStatus(code); + httpResponse.setCharacterEncoding("UTF-8"); + httpResponse.getWriter().write(jsonAsStr); + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-atlas/blob/dda382f4/webapp/src/main/java/org/apache/atlas/web/resources/AdminResource.java ---------------------------------------------------------------------- diff --git a/webapp/src/main/java/org/apache/atlas/web/resources/AdminResource.java b/webapp/src/main/java/org/apache/atlas/web/resources/AdminResource.java index 3a46068..b7f6cf2 100755 --- a/webapp/src/main/java/org/apache/atlas/web/resources/AdminResource.java +++ b/webapp/src/main/java/org/apache/atlas/web/resources/AdminResource.java @@ -31,6 +31,7 @@ import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import org.apache.atlas.AtlasClient; +import org.apache.atlas.web.filters.AtlasCSRFPreventionFilter; import org.apache.atlas.web.service.ServiceState; import org.apache.atlas.web.util.Servlets; import org.apache.commons.configuration.ConfigurationException; @@ -51,6 +52,11 @@ import com.google.inject.Inject; @Singleton public class AdminResource { + private static final String isCSRF_ENABLED = "atlas.rest-csrf.enabled"; + private static final String BROWSER_USER_AGENT_PARAM = "atlas.rest-csrf.browser-useragents-regex"; + private static final String CUSTOM_METHODS_TO_IGNORE_PARAM = "atlas.rest-csrf.methods-to-ignore"; + private static final String CUSTOM_HEADER_PARAM = "atlas.rest-csrf.custom-header"; + private Response version; private ServiceState serviceState; @@ -147,6 +153,11 @@ public class AdminResource { } } + responseData.put(isCSRF_ENABLED, AtlasCSRFPreventionFilter.isCSRF_ENABLED); + responseData.put(BROWSER_USER_AGENT_PARAM, AtlasCSRFPreventionFilter.BROWSER_USER_AGENTS_DEFAULT); + responseData.put(CUSTOM_METHODS_TO_IGNORE_PARAM, AtlasCSRFPreventionFilter.METHODS_TO_IGNORE_DEFAULT); + responseData.put(CUSTOM_HEADER_PARAM, AtlasCSRFPreventionFilter.HEADER_DEFAULT); + responseData.put("userName", userName); responseData.put("groups", groups); Response response = Response.ok(responseData).build(); http://git-wip-us.apache.org/repos/asf/incubator-atlas/blob/dda382f4/webapp/src/main/resources/spring-security.xml ---------------------------------------------------------------------- diff --git a/webapp/src/main/resources/spring-security.xml b/webapp/src/main/resources/spring-security.xml index c21a644..ea9aa94 100644 --- a/webapp/src/main/resources/spring-security.xml +++ b/webapp/src/main/resources/spring-security.xml @@ -43,6 +43,7 @@ <intercept-url pattern="/**" access="isAuthenticated()" /> <security:custom-filter ref="krbAuthenticationFilter" after="SERVLET_API_SUPPORT_FILTER" /> + <security:custom-filter ref="CSRFPreventionFilter" after="REMEMBER_ME_FILTER" /> <form-login login-page="/login.jsp" @@ -59,6 +60,9 @@ <beans:bean id="krbAuthenticationFilter" class="org.apache.atlas.web.filters.AtlasAuthenticationFilter"> </beans:bean> + + <beans:bean id="CSRFPreventionFilter" class="org.apache.atlas.web.filters.AtlasCSRFPreventionFilter"> + </beans:bean> <beans:bean id="atlasAuthenticationSuccessHandler" class="org.apache.atlas.web.security.AtlasAuthenticationSuccessHandler" /> http://git-wip-us.apache.org/repos/asf/incubator-atlas/blob/dda382f4/webapp/src/test/java/org/apache/atlas/web/filters/AtlasCSRFPreventionFilterTest.java ---------------------------------------------------------------------- diff --git a/webapp/src/test/java/org/apache/atlas/web/filters/AtlasCSRFPreventionFilterTest.java b/webapp/src/test/java/org/apache/atlas/web/filters/AtlasCSRFPreventionFilterTest.java new file mode 100644 index 0000000..a742dd5 --- /dev/null +++ b/webapp/src/test/java/org/apache/atlas/web/filters/AtlasCSRFPreventionFilterTest.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.atlas.web.filters; + +import java.io.IOException; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.junit.Test; +import org.mockito.Mockito; + +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.verify; + +public class AtlasCSRFPreventionFilterTest { + private static final String EXPECTED_MESSAGE = "Missing Required Header for CSRF Vulnerability Protection"; + private static final String X_CUSTOM_HEADER = "X-CUSTOM_HEADER"; + private String userAgent = "Mozilla"; + + @Test + public void testNoHeaderDefaultConfig_badRequest() throws ServletException, IOException { + // CSRF has not been sent + HttpServletRequest mockReq = Mockito.mock(HttpServletRequest.class); + Mockito.when(mockReq.getHeader(AtlasCSRFPreventionFilter.HEADER_DEFAULT)).thenReturn(null); + Mockito.when(mockReq.getHeader(AtlasCSRFPreventionFilter.HEADER_USER_AGENT)).thenReturn(userAgent); + + // Objects to verify interactions based on request + HttpServletResponse mockRes = Mockito.mock(HttpServletResponse.class); + FilterChain mockChain = Mockito.mock(FilterChain.class); + + // Object under test + AtlasCSRFPreventionFilter filter = new AtlasCSRFPreventionFilter(); + filter.doFilter(mockReq, mockRes, mockChain); + + verify(mockRes, atLeastOnce()).sendError(HttpServletResponse.SC_BAD_REQUEST, EXPECTED_MESSAGE); + Mockito.verifyZeroInteractions(mockChain); + } + + @Test + public void testHeaderPresentDefaultConfig_goodRequest() throws ServletException, IOException { + // CSRF HAS been sent + HttpServletRequest mockReq = Mockito.mock(HttpServletRequest.class); + Mockito.when(mockReq.getHeader(AtlasCSRFPreventionFilter.HEADER_DEFAULT)).thenReturn("valueUnimportant"); + Mockito.when(mockReq.getHeader(AtlasCSRFPreventionFilter.HEADER_USER_AGENT)).thenReturn(userAgent); + + // Objects to verify interactions based on request + HttpServletResponse mockRes = Mockito.mock(HttpServletResponse.class); + FilterChain mockChain = Mockito.mock(FilterChain.class); + + // Object under test + AtlasCSRFPreventionFilter filter = new AtlasCSRFPreventionFilter(); + filter.doFilter(mockReq, mockRes, mockChain); + + Mockito.verify(mockChain).doFilter(mockReq, mockRes); + } + + @Test + public void testHeaderPresentCustomHeaderConfig_goodRequest() throws ServletException, IOException { + // CSRF HAS been sent + HttpServletRequest mockReq = Mockito.mock(HttpServletRequest.class); + Mockito.when(mockReq.getHeader(X_CUSTOM_HEADER)).thenReturn("valueUnimportant"); + + // Objects to verify interactions based on request + HttpServletResponse mockRes = Mockito.mock(HttpServletResponse.class); + FilterChain mockChain = Mockito.mock(FilterChain.class); + + // Object under test + AtlasCSRFPreventionFilter filter = new AtlasCSRFPreventionFilter(); + filter.doFilter(mockReq, mockRes, mockChain); + + Mockito.verify(mockChain).doFilter(mockReq, mockRes); + } + + @Test + public void testMissingHeaderWithCustomHeaderConfig_badRequest() throws ServletException, IOException { + // CSRF has not been sent + HttpServletRequest mockReq = Mockito.mock(HttpServletRequest.class); + Mockito.when(mockReq.getHeader(X_CUSTOM_HEADER)).thenReturn(null); + Mockito.when(mockReq.getHeader(AtlasCSRFPreventionFilter.HEADER_USER_AGENT)).thenReturn(userAgent); + + // Objects to verify interactions based on request + HttpServletResponse mockRes = Mockito.mock(HttpServletResponse.class); + FilterChain mockChain = Mockito.mock(FilterChain.class); + + // Object under test + AtlasCSRFPreventionFilter filter = new AtlasCSRFPreventionFilter(); + filter.doFilter(mockReq, mockRes, mockChain); + + Mockito.verifyZeroInteractions(mockChain); + } + + @Test + public void testMissingHeaderIgnoreGETMethodConfig_goodRequest() + throws ServletException, IOException { + // CSRF has not been sent + HttpServletRequest mockReq = Mockito.mock(HttpServletRequest.class); + Mockito.when(mockReq.getHeader(AtlasCSRFPreventionFilter.HEADER_DEFAULT)).thenReturn(null); + Mockito.when(mockReq.getMethod()).thenReturn("GET"); + Mockito.when(mockReq.getHeader(AtlasCSRFPreventionFilter.HEADER_USER_AGENT)).thenReturn(userAgent); + + // Objects to verify interactions based on request + HttpServletResponse mockRes = Mockito.mock(HttpServletResponse.class); + FilterChain mockChain = Mockito.mock(FilterChain.class); + + // Object under test + AtlasCSRFPreventionFilter filter = new AtlasCSRFPreventionFilter(); + filter.doFilter(mockReq, mockRes, mockChain); + + Mockito.verify(mockChain).doFilter(mockReq, mockRes); + } + + @Test + public void testMissingHeaderMultipleIgnoreMethodsConfig_badRequest() + throws ServletException, IOException { + // CSRF has not been sent + HttpServletRequest mockReq = Mockito.mock(HttpServletRequest.class); + Mockito.when(mockReq.getHeader(AtlasCSRFPreventionFilter.HEADER_DEFAULT)) + .thenReturn(null); + Mockito.when(mockReq.getMethod()).thenReturn("PUT"); + Mockito.when(mockReq.getHeader(AtlasCSRFPreventionFilter.HEADER_USER_AGENT)).thenReturn(userAgent); + + // Objects to verify interactions based on request + HttpServletResponse mockRes = Mockito.mock(HttpServletResponse.class); + FilterChain mockChain = Mockito.mock(FilterChain.class); + + // Object under test + AtlasCSRFPreventionFilter filter = new AtlasCSRFPreventionFilter(); + filter.doFilter(mockReq, mockRes, mockChain); + + Mockito.verifyZeroInteractions(mockChain); + } +}
