http://git-wip-us.apache.org/repos/asf/knox/blob/8affbc02/gateway-server/src/main/java/org/apache/knox/gateway/websockets/ProxyWebSocketAdapter.java ---------------------------------------------------------------------- diff --cc gateway-server/src/main/java/org/apache/knox/gateway/websockets/ProxyWebSocketAdapter.java index 850157e,0000000..a678a72 mode 100644,000000..100644 --- a/gateway-server/src/main/java/org/apache/knox/gateway/websockets/ProxyWebSocketAdapter.java +++ b/gateway-server/src/main/java/org/apache/knox/gateway/websockets/ProxyWebSocketAdapter.java @@@ -1,276 -1,0 +1,289 @@@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.knox.gateway.websockets; + +import java.io.IOException; +import java.net.URI; +import java.util.concurrent.ExecutorService; + ++import javax.websocket.ClientEndpointConfig; +import javax.websocket.CloseReason; +import javax.websocket.ContainerProvider; +import javax.websocket.DeploymentException; +import javax.websocket.WebSocketContainer; + +import org.apache.knox.gateway.i18n.messages.MessagesFactory; +import org.eclipse.jetty.io.RuntimeIOException; +import org.eclipse.jetty.util.component.LifeCycle; +import org.eclipse.jetty.websocket.api.BatchMode; +import org.eclipse.jetty.websocket.api.RemoteEndpoint; +import org.eclipse.jetty.websocket.api.Session; +import org.eclipse.jetty.websocket.api.StatusCode; +import org.eclipse.jetty.websocket.api.WebSocketAdapter; + +/** + * Handles outbound/inbound Websocket connections and sessions. + * + * @since 0.10 + */ +public class ProxyWebSocketAdapter extends WebSocketAdapter { + + private static final WebsocketLogMessages LOG = MessagesFactory + .get(WebsocketLogMessages.class); + + /* URI for the backend */ + private final URI backend; + + /* Session between the frontend (browser) and Knox */ + private Session frontendSession; + + /* Session between the backend (outbound) and Knox */ + private javax.websocket.Session backendSession; + + private WebSocketContainer container; + + private ExecutorService pool; + + /** ++ * Used to transmit headers from browser to backend server. ++ * @since 0.14 ++ */ ++ private ClientEndpointConfig clientConfig; ++ ++ /** + * Create an instance + */ + public ProxyWebSocketAdapter(final URI backend, final ExecutorService pool) { ++ this(backend, pool, null); ++ } ++ ++ public ProxyWebSocketAdapter(final URI backend, final ExecutorService pool, final ClientEndpointConfig clientConfig) { + super(); + this.backend = backend; + this.pool = pool; ++ this.clientConfig = clientConfig; + } + + @Override + public void onWebSocketConnect(final Session frontEndSession) { + + /* + * Let's connect to the backend, this is where the Backend-to-frontend + * plumbing takes place + */ + container = ContainerProvider.getWebSocketContainer(); - final ProxyInboundSocket backendSocket = new ProxyInboundSocket( - getMessageCallback()); ++ ++ final ProxyInboundClient backendSocket = new ProxyInboundClient(getMessageCallback()); + + /* build the configuration */ + + /* Attempt Connect */ + try { - backendSession = container.connectToServer(backendSocket, backend); ++ backendSession = container.connectToServer(backendSocket, clientConfig, backend); ++ + LOG.onConnectionOpen(backend.toString()); + + } catch (DeploymentException e) { + LOG.connectionFailed(e); + throw new RuntimeException(e); + } catch (IOException e) { + LOG.connectionFailed(e); + throw new RuntimeIOException(e); + } + + super.onWebSocketConnect(frontEndSession); + this.frontendSession = frontEndSession; + + } + + @Override + public void onWebSocketBinary(final byte[] payload, final int offset, + final int length) { + + if (isNotConnected()) { + return; + } + + throw new UnsupportedOperationException( + "Websocket support for binary messages is not supported at this time."); + } + + @Override + public void onWebSocketText(final String message) { + + if (isNotConnected()) { + return; + } + + LOG.logMessage("[From Frontend --->]" + message); + + /* Proxy message to backend */ + try { + backendSession.getBasicRemote().sendText(message); + + } catch (IOException e) { + LOG.connectionFailed(e); + } + + } + + @Override + public void onWebSocketClose(int statusCode, String reason) { + super.onWebSocketClose(statusCode, reason); + + /* do the cleaning business in seperate thread so we don't block */ + pool.execute(new Runnable() { + @Override + public void run() { + closeQuietly(); + } + }); + + LOG.onConnectionClose(backend.toString()); + + } + + @Override + public void onWebSocketError(final Throwable t) { + cleanupOnError(t); + } + + /** + * Cleanup sessions + */ + private void cleanupOnError(final Throwable t) { + + LOG.onError(t.toString()); + if (t.toString().contains("exceeds maximum size")) { + if(frontendSession != null && !frontendSession.isOpen()) { + frontendSession.close(StatusCode.MESSAGE_TOO_LARGE, t.getMessage()); + } + } + + else { + if(frontendSession != null && !frontendSession.isOpen()) { + frontendSession.close(StatusCode.SERVER_ERROR, t.getMessage()); + } + + /* do the cleaning business in seperate thread so we don't block */ + pool.execute(new Runnable() { + @Override + public void run() { + closeQuietly(); + } + }); + + } + } + + private MessageEventCallback getMessageCallback() { + + return new MessageEventCallback() { + + @Override + public void doCallback(String message) { + /* do nothing */ + + } + + @Override + public void onConnectionOpen(Object session) { + /* do nothing */ + + } + + @Override + public void onConnectionClose(final CloseReason reason) { + try { + frontendSession.close(reason.getCloseCode().getCode(), + reason.getReasonPhrase()); + } finally { + + /* do the cleaning business in seperate thread so we don't block */ + pool.execute(new Runnable() { + @Override + public void run() { + closeQuietly(); + } + }); + + } + + } + + @Override + public void onError(Throwable cause) { + cleanupOnError(cause); + } + + @Override + public void onMessageText(String message, Object session) { + final RemoteEndpoint remote = getRemote(); + + LOG.logMessage("[From Backend <---]" + message); + + /* Proxy message to frontend */ + try { + remote.sendString(message); + if (remote.getBatchMode() == BatchMode.ON) { + remote.flush(); + } + } catch (IOException e) { + LOG.connectionFailed(e); + throw new RuntimeIOException(e); + } + + } + + @Override + public void onMessageBinary(byte[] message, boolean last, + Object session) { + throw new UnsupportedOperationException( + "Websocket support for binary messages is not supported at this time."); + + } + + }; + + } + + private void closeQuietly() { + + try { + if(backendSession != null && !backendSession.isOpen()) { + backendSession.close(); + } + } catch (IOException e) { + LOG.connectionFailed(e); + } + + if (container instanceof LifeCycle) { + try { + ((LifeCycle) container).stop(); + } catch (Exception e) { + LOG.connectionFailed(e); + } + } + + if(frontendSession != null && !frontendSession.isOpen()) { + frontendSession.close(); + } + + } + +}
http://git-wip-us.apache.org/repos/asf/knox/blob/8affbc02/gateway-server/src/test/java/org/apache/knox/gateway/topology/simple/SimpleDescriptorHandlerTest.java ---------------------------------------------------------------------- diff --cc gateway-server/src/test/java/org/apache/knox/gateway/topology/simple/SimpleDescriptorHandlerTest.java index b713491,0000000..b5558fd mode 100644,000000..100644 --- a/gateway-server/src/test/java/org/apache/knox/gateway/topology/simple/SimpleDescriptorHandlerTest.java +++ b/gateway-server/src/test/java/org/apache/knox/gateway/topology/simple/SimpleDescriptorHandlerTest.java @@@ -1,239 -1,0 +1,392 @@@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.knox.gateway.topology.simple; + +import org.apache.knox.gateway.topology.validation.TopologyValidator; +import org.apache.knox.gateway.util.XmlUtils; ++import java.io.ByteArrayInputStream; ++import java.io.File; ++import java.io.FileOutputStream; ++import java.io.IOException; ++ ++import java.util.ArrayList; ++import java.util.Collections; ++import java.util.HashMap; ++import java.util.List; ++import java.util.Map; ++import java.util.Properties; ++ ++import javax.xml.xpath.XPath; ++import javax.xml.xpath.XPathConstants; ++import javax.xml.xpath.XPathFactory; ++ ++import org.apache.commons.io.FileUtils; +import org.easymock.EasyMock; +import org.junit.Test; +import org.w3c.dom.Document; +import org.w3c.dom.Node; +import org.w3c.dom.NodeList; +import org.xml.sax.SAXException; + - import javax.xml.xpath.XPath; - import javax.xml.xpath.XPathConstants; - import javax.xml.xpath.XPathFactory; - import java.io.*; - import java.util.*; - - import static org.junit.Assert.*; ++import static org.junit.Assert.assertEquals; ++import static org.junit.Assert.assertFalse; ++import static org.junit.Assert.assertNotNull; ++import static org.junit.Assert.assertTrue; ++import static org.junit.Assert.fail; + + +public class SimpleDescriptorHandlerTest { + + private static final String TEST_PROVIDER_CONFIG = + " <gateway>\n" + + " <provider>\n" + + " <role>authentication</role>\n" + + " <name>ShiroProvider</name>\n" + + " <enabled>true</enabled>\n" + + " <param>\n" + + " <!-- \n" + + " session timeout in minutes, this is really idle timeout,\n" + + " defaults to 30mins, if the property value is not defined,, \n" + + " current client authentication would expire if client idles contiuosly for more than this value\n" + + " -->\n" + + " <name>sessionTimeout</name>\n" + + " <value>30</value>\n" + + " </param>\n" + + " <param>\n" + + " <name>main.ldapRealm</name>\n" + + " <value>org.apache.knox.gateway.shirorealm.KnoxLdapRealm</value>\n" + + " </param>\n" + + " <param>\n" + + " <name>main.ldapContextFactory</name>\n" + + " <value>org.apache.knox.gateway.shirorealm.KnoxLdapContextFactory</value>\n" + + " </param>\n" + + " <param>\n" + + " <name>main.ldapRealm.contextFactory</name>\n" + + " <value>$ldapContextFactory</value>\n" + + " </param>\n" + + " <param>\n" + + " <name>main.ldapRealm.userDnTemplate</name>\n" + + " <value>uid={0},ou=people,dc=hadoop,dc=apache,dc=org</value>\n" + + " </param>\n" + + " <param>\n" + + " <name>main.ldapRealm.contextFactory.url</name>\n" + + " <value>ldap://localhost:33389</value>\n" + + " </param>\n" + + " <param>\n" + + " <name>main.ldapRealm.contextFactory.authenticationMechanism</name>\n" + + " <value>simple</value>\n" + + " </param>\n" + + " <param>\n" + + " <name>urls./**</name>\n" + + " <value>authcBasic</value>\n" + + " </param>\n" + + " </provider>\n" + + "\n" + + " <provider>\n" + + " <role>identity-assertion</role>\n" + + " <name>Default</name>\n" + + " <enabled>true</enabled>\n" + + " </provider>\n" + + "\n" + + " <!--\n" + + " Defines rules for mapping host names internal to a Hadoop cluster to externally accessible host names.\n" + + " For example, a hadoop service running in AWS may return a response that includes URLs containing the\n" + + " some AWS internal host name. If the client needs to make a subsequent request to the host identified\n" + + " in those URLs they need to be mapped to external host names that the client Knox can use to connect.\n" + + "\n" + + " If the external hostname and internal host names are same turn of this provider by setting the value of\n" + + " enabled parameter as false.\n" + + "\n" + + " The name parameter specifies the external host names in a comma separated list.\n" + + " The value parameter specifies corresponding internal host names in a comma separated list.\n" + + "\n" + + " Note that when you are using Sandbox, the external hostname needs to be localhost, as seen in out\n" + + " of box sandbox.xml. This is because Sandbox uses port mapping to allow clients to connect to the\n" + + " Hadoop services using localhost. In real clusters, external host names would almost never be localhost.\n" + + " -->\n" + + " <provider>\n" + + " <role>hostmap</role>\n" + + " <name>static</name>\n" + + " <enabled>true</enabled>\n" + + " <param><name>localhost</name><value>sandbox,sandbox.hortonworks.com</value></param>\n" + + " </provider>\n" + + " </gateway>\n"; + + + /** + * KNOX-1006 + * + * N.B. This test depends on the DummyServiceDiscovery extension being configured: + * org.apache.knox.gateway.topology.discovery.test.extension.DummyServiceDiscovery + */ + @Test + public void testSimpleDescriptorHandler() throws Exception { + + final String type = "DUMMY"; + final String address = "http://c6401.ambari.apache.org:8080"; + final String clusterName = "dummy"; + final Map<String, List<String>> serviceURLs = new HashMap<>(); + serviceURLs.put("NAMENODE", null); + serviceURLs.put("JOBTRACKER", null); + serviceURLs.put("WEBHDFS", null); + serviceURLs.put("WEBHCAT", null); + serviceURLs.put("OOZIE", null); + serviceURLs.put("WEBHBASE", null); + serviceURLs.put("HIVE", null); + serviceURLs.put("RESOURCEMANAGER", null); - serviceURLs.put("AMBARIUI", Arrays.asList("http://c6401.ambari.apache.org:8080")); ++ serviceURLs.put("AMBARIUI", Collections.singletonList("http://c6401.ambari.apache.org:8080")); + + // Write the externalized provider config to a temp file + File providerConfig = writeProviderConfig("ambari-cluster-policy.xml", TEST_PROVIDER_CONFIG); + + File topologyFile = null; + try { + File destDir = (new File(".")).getCanonicalFile(); + + // Mock out the simple descriptor + SimpleDescriptor testDescriptor = EasyMock.createNiceMock(SimpleDescriptor.class); + EasyMock.expect(testDescriptor.getName()).andReturn("mysimpledescriptor").anyTimes(); + EasyMock.expect(testDescriptor.getDiscoveryAddress()).andReturn(address).anyTimes(); + EasyMock.expect(testDescriptor.getDiscoveryType()).andReturn(type).anyTimes(); + EasyMock.expect(testDescriptor.getDiscoveryUser()).andReturn(null).anyTimes(); + EasyMock.expect(testDescriptor.getProviderConfig()).andReturn(providerConfig.getAbsolutePath()).anyTimes(); + EasyMock.expect(testDescriptor.getClusterName()).andReturn(clusterName).anyTimes(); + List<SimpleDescriptor.Service> serviceMocks = new ArrayList<>(); + for (String serviceName : serviceURLs.keySet()) { + SimpleDescriptor.Service svc = EasyMock.createNiceMock(SimpleDescriptor.Service.class); + EasyMock.expect(svc.getName()).andReturn(serviceName).anyTimes(); + EasyMock.expect(svc.getURLs()).andReturn(serviceURLs.get(serviceName)).anyTimes(); + EasyMock.replay(svc); + serviceMocks.add(svc); + } + EasyMock.expect(testDescriptor.getServices()).andReturn(serviceMocks).anyTimes(); + EasyMock.replay(testDescriptor); + + // Invoke the simple descriptor handler + Map<String, File> files = + SimpleDescriptorHandler.handle(testDescriptor, + providerConfig.getParentFile(), // simple desc co-located with provider config + destDir); + topologyFile = files.get("topology"); + + // Validate the resulting topology descriptor + assertTrue(topologyFile.exists()); + + // Validate the topology descriptor's correctness + TopologyValidator validator = new TopologyValidator( topologyFile.getAbsolutePath() ); + if( !validator.validateTopology() ){ + throw new SAXException( validator.getErrorString() ); + } + + XPathFactory xPathfactory = XPathFactory.newInstance(); + XPath xpath = xPathfactory.newXPath(); + + // Parse the topology descriptor + Document topologyXml = XmlUtils.readXml(topologyFile); + + // Validate the provider configuration + Document extProviderConf = XmlUtils.readXml(new ByteArrayInputStream(TEST_PROVIDER_CONFIG.getBytes())); + Node gatewayNode = (Node) xpath.compile("/topology/gateway").evaluate(topologyXml, XPathConstants.NODE); + assertTrue("Resulting provider config should be identical to the referenced content.", + extProviderConf.getDocumentElement().isEqualNode(gatewayNode)); + + // Validate the service declarations + Map<String, List<String>> topologyServiceURLs = new HashMap<>(); + NodeList serviceNodes = + (NodeList) xpath.compile("/topology/service").evaluate(topologyXml, XPathConstants.NODESET); + for (int serviceNodeIndex=0; serviceNodeIndex < serviceNodes.getLength(); serviceNodeIndex++) { + Node serviceNode = serviceNodes.item(serviceNodeIndex); + Node roleNode = (Node) xpath.compile("role/text()").evaluate(serviceNode, XPathConstants.NODE); + assertNotNull(roleNode); + String role = roleNode.getNodeValue(); + NodeList urlNodes = (NodeList) xpath.compile("url/text()").evaluate(serviceNode, XPathConstants.NODESET); + for(int urlNodeIndex = 0 ; urlNodeIndex < urlNodes.getLength(); urlNodeIndex++) { + Node urlNode = urlNodes.item(urlNodeIndex); + assertNotNull(urlNode); + String url = urlNode.getNodeValue(); + assertNotNull("Every declared service should have a URL.", url); + if (!topologyServiceURLs.containsKey(role)) { + topologyServiceURLs.put(role, new ArrayList<String>()); + } + topologyServiceURLs.get(role).add(url); + } + } + assertEquals("Unexpected number of service declarations.", serviceURLs.size(), topologyServiceURLs.size()); + + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } finally { + providerConfig.delete(); + if (topologyFile != null) { + topologyFile.delete(); + } + } + } + + - private File writeProviderConfig(String path, String content) throws IOException { - File f = new File(path); ++ /** ++ * KNOX-1006 ++ * ++ * Verify the behavior of the SimpleDescriptorHandler when service discovery fails to produce a valid URL for ++ * a service. ++ * ++ * N.B. This test depends on the PropertiesFileServiceDiscovery extension being configured: ++ * org.apache.hadoop.gateway.topology.discovery.test.extension.PropertiesFileServiceDiscovery ++ */ ++ @Test ++ public void testInvalidServiceURLFromDiscovery() throws Exception { ++ final String CLUSTER_NAME = "myproperties"; ++ ++ // Configure the PropertiesFile Service Discovery implementation for this test ++ final String DEFAULT_VALID_SERVICE_URL = "http://localhost:9999/thiswillwork"; ++ Properties serviceDiscoverySourceProps = new Properties(); ++ serviceDiscoverySourceProps.setProperty(CLUSTER_NAME + ".NAMENODE", ++ DEFAULT_VALID_SERVICE_URL.replace("http", "hdfs")); ++ serviceDiscoverySourceProps.setProperty(CLUSTER_NAME + ".JOBTRACKER", ++ DEFAULT_VALID_SERVICE_URL.replace("http", "rpc")); ++ serviceDiscoverySourceProps.setProperty(CLUSTER_NAME + ".WEBHDFS", DEFAULT_VALID_SERVICE_URL); ++ serviceDiscoverySourceProps.setProperty(CLUSTER_NAME + ".WEBHCAT", DEFAULT_VALID_SERVICE_URL); ++ serviceDiscoverySourceProps.setProperty(CLUSTER_NAME + ".OOZIE", DEFAULT_VALID_SERVICE_URL); ++ serviceDiscoverySourceProps.setProperty(CLUSTER_NAME + ".WEBHBASE", DEFAULT_VALID_SERVICE_URL); ++ serviceDiscoverySourceProps.setProperty(CLUSTER_NAME + ".HIVE", "{SCHEME}://localhost:10000/"); ++ serviceDiscoverySourceProps.setProperty(CLUSTER_NAME + ".RESOURCEMANAGER", DEFAULT_VALID_SERVICE_URL); ++ serviceDiscoverySourceProps.setProperty(CLUSTER_NAME + ".AMBARIUI", DEFAULT_VALID_SERVICE_URL); ++ File serviceDiscoverySource = File.createTempFile("service-discovery", ".properties"); ++ serviceDiscoverySourceProps.store(new FileOutputStream(serviceDiscoverySource), ++ "Test Service Discovery Source"); ++ ++ // Prepare a mock SimpleDescriptor ++ final String type = "PROPERTIES_FILE"; ++ final String address = serviceDiscoverySource.getAbsolutePath(); ++ final Map<String, List<String>> serviceURLs = new HashMap<>(); ++ serviceURLs.put("NAMENODE", null); ++ serviceURLs.put("JOBTRACKER", null); ++ serviceURLs.put("WEBHDFS", null); ++ serviceURLs.put("WEBHCAT", null); ++ serviceURLs.put("OOZIE", null); ++ serviceURLs.put("WEBHBASE", null); ++ serviceURLs.put("HIVE", null); ++ serviceURLs.put("RESOURCEMANAGER", null); ++ serviceURLs.put("AMBARIUI", Collections.singletonList("http://c6401.ambari.apache.org:8080")); + - Writer fw = new FileWriter(f); - fw.write(content); - fw.flush(); - fw.close(); ++ // Write the externalized provider config to a temp file ++ File providerConfig = writeProviderConfig("ambari-cluster-policy.xml", TEST_PROVIDER_CONFIG); ++ ++ File topologyFile = null; ++ try { ++ File destDir = (new File(".")).getCanonicalFile(); ++ ++ // Mock out the simple descriptor ++ SimpleDescriptor testDescriptor = EasyMock.createNiceMock(SimpleDescriptor.class); ++ EasyMock.expect(testDescriptor.getName()).andReturn("mysimpledescriptor").anyTimes(); ++ EasyMock.expect(testDescriptor.getDiscoveryAddress()).andReturn(address).anyTimes(); ++ EasyMock.expect(testDescriptor.getDiscoveryType()).andReturn(type).anyTimes(); ++ EasyMock.expect(testDescriptor.getDiscoveryUser()).andReturn(null).anyTimes(); ++ EasyMock.expect(testDescriptor.getProviderConfig()).andReturn(providerConfig.getAbsolutePath()).anyTimes(); ++ EasyMock.expect(testDescriptor.getClusterName()).andReturn(CLUSTER_NAME).anyTimes(); ++ List<SimpleDescriptor.Service> serviceMocks = new ArrayList<>(); ++ for (String serviceName : serviceURLs.keySet()) { ++ SimpleDescriptor.Service svc = EasyMock.createNiceMock(SimpleDescriptor.Service.class); ++ EasyMock.expect(svc.getName()).andReturn(serviceName).anyTimes(); ++ EasyMock.expect(svc.getURLs()).andReturn(serviceURLs.get(serviceName)).anyTimes(); ++ EasyMock.replay(svc); ++ serviceMocks.add(svc); ++ } ++ EasyMock.expect(testDescriptor.getServices()).andReturn(serviceMocks).anyTimes(); ++ EasyMock.replay(testDescriptor); ++ ++ // Invoke the simple descriptor handler ++ Map<String, File> files = ++ SimpleDescriptorHandler.handle(testDescriptor, ++ providerConfig.getParentFile(), // simple desc co-located with provider config ++ destDir); ++ ++ topologyFile = files.get("topology"); + ++ // Validate the resulting topology descriptor ++ assertTrue(topologyFile.exists()); ++ ++ // Validate the topology descriptor's correctness ++ TopologyValidator validator = new TopologyValidator( topologyFile.getAbsolutePath() ); ++ if( !validator.validateTopology() ){ ++ throw new SAXException( validator.getErrorString() ); ++ } ++ ++ XPathFactory xPathfactory = XPathFactory.newInstance(); ++ XPath xpath = xPathfactory.newXPath(); ++ ++ // Parse the topology descriptor ++ Document topologyXml = XmlUtils.readXml(topologyFile); ++ ++ // Validate the provider configuration ++ Document extProviderConf = XmlUtils.readXml(new ByteArrayInputStream(TEST_PROVIDER_CONFIG.getBytes())); ++ Node gatewayNode = (Node) xpath.compile("/topology/gateway").evaluate(topologyXml, XPathConstants.NODE); ++ assertTrue("Resulting provider config should be identical to the referenced content.", ++ extProviderConf.getDocumentElement().isEqualNode(gatewayNode)); ++ ++ // Validate the service declarations ++ List<String> topologyServices = new ArrayList<>(); ++ Map<String, List<String>> topologyServiceURLs = new HashMap<>(); ++ NodeList serviceNodes = ++ (NodeList) xpath.compile("/topology/service").evaluate(topologyXml, XPathConstants.NODESET); ++ for (int serviceNodeIndex=0; serviceNodeIndex < serviceNodes.getLength(); serviceNodeIndex++) { ++ Node serviceNode = serviceNodes.item(serviceNodeIndex); ++ Node roleNode = (Node) xpath.compile("role/text()").evaluate(serviceNode, XPathConstants.NODE); ++ assertNotNull(roleNode); ++ String role = roleNode.getNodeValue(); ++ topologyServices.add(role); ++ NodeList urlNodes = (NodeList) xpath.compile("url/text()").evaluate(serviceNode, XPathConstants.NODESET); ++ for(int urlNodeIndex = 0 ; urlNodeIndex < urlNodes.getLength(); urlNodeIndex++) { ++ Node urlNode = urlNodes.item(urlNodeIndex); ++ assertNotNull(urlNode); ++ String url = urlNode.getNodeValue(); ++ assertNotNull("Every declared service should have a URL.", url); ++ if (!topologyServiceURLs.containsKey(role)) { ++ topologyServiceURLs.put(role, new ArrayList<String>()); ++ } ++ topologyServiceURLs.get(role).add(url); ++ } ++ } ++ ++ // There should not be a service element for HIVE, since it had no valid URLs ++ assertEquals("Unexpected number of service declarations.", serviceURLs.size() - 1, topologyServices.size()); ++ assertFalse("The HIVE service should have been omitted from the generated topology.", topologyServices.contains("HIVE")); ++ ++ assertEquals("Unexpected number of service URLs.", serviceURLs.size() - 1, topologyServiceURLs.size()); ++ ++ } catch (Exception e) { ++ e.printStackTrace(); ++ fail(e.getMessage()); ++ } finally { ++ serviceDiscoverySource.delete(); ++ providerConfig.delete(); ++ if (topologyFile != null) { ++ topologyFile.delete(); ++ } ++ } ++ } ++ ++ ++ private File writeProviderConfig(String path, String content) throws IOException { ++ File f = new File(path); ++ FileUtils.write(f, content); + return f; + } + +} http://git-wip-us.apache.org/repos/asf/knox/blob/8affbc02/gateway-service-definitions/src/main/resources/services/ambariui/2.2.0/service.xml ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/knox/blob/8affbc02/gateway-service-knoxsso/src/main/java/org/apache/knox/gateway/service/knoxsso/WebSSOResource.java ---------------------------------------------------------------------- diff --cc gateway-service-knoxsso/src/main/java/org/apache/knox/gateway/service/knoxsso/WebSSOResource.java index 8a9d028,0000000..a97cee2 mode 100644,000000..100644 --- a/gateway-service-knoxsso/src/main/java/org/apache/knox/gateway/service/knoxsso/WebSSOResource.java +++ b/gateway-service-knoxsso/src/main/java/org/apache/knox/gateway/service/knoxsso/WebSSOResource.java @@@ -1,322 -1,0 +1,322 @@@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.knox.gateway.service.knoxsso; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.security.Principal; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + +import javax.annotation.PostConstruct; +import javax.servlet.ServletContext; +import javax.servlet.http.Cookie; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpSession; +import javax.ws.rs.GET; +import javax.ws.rs.POST; +import javax.ws.rs.Path; +import javax.ws.rs.Produces; +import javax.ws.rs.core.Context; +import javax.ws.rs.core.Response; +import javax.ws.rs.WebApplicationException; + +import org.apache.knox.gateway.i18n.messages.MessagesFactory; +import org.apache.knox.gateway.services.GatewayServices; +import org.apache.knox.gateway.services.security.token.JWTokenAuthority; +import org.apache.knox.gateway.services.security.token.TokenServiceException; +import org.apache.knox.gateway.services.security.token.impl.JWT; +import org.apache.knox.gateway.util.RegExUtils; +import org.apache.knox.gateway.util.Urls; + +import static javax.ws.rs.core.MediaType.APPLICATION_JSON; +import static javax.ws.rs.core.MediaType.APPLICATION_XML; + +@Path( WebSSOResource.RESOURCE_PATH ) +public class WebSSOResource { + private static final String SSO_COOKIE_NAME = "knoxsso.cookie.name"; + private static final String SSO_COOKIE_SECURE_ONLY_INIT_PARAM = "knoxsso.cookie.secure.only"; + private static final String SSO_COOKIE_MAX_AGE_INIT_PARAM = "knoxsso.cookie.max.age"; + private static final String SSO_COOKIE_DOMAIN_SUFFIX_PARAM = "knoxsso.cookie.domain.suffix"; + private static final String SSO_COOKIE_TOKEN_TTL_PARAM = "knoxsso.token.ttl"; + private static final String SSO_COOKIE_TOKEN_AUDIENCES_PARAM = "knoxsso.token.audiences"; + private static final String SSO_COOKIE_TOKEN_WHITELIST_PARAM = "knoxsso.redirect.whitelist.regex"; + private static final String SSO_ENABLE_SESSION_PARAM = "knoxsso.enable.session"; + private static final String ORIGINAL_URL_REQUEST_PARAM = "originalUrl"; + private static final String ORIGINAL_URL_COOKIE_NAME = "original-url"; + private static final String DEFAULT_SSO_COOKIE_NAME = "hadoop-jwt"; + // default for the whitelist - open up for development - relative paths and localhost only + private static final String DEFAULT_WHITELIST = "^/.*$;^https?://(localhost|127.0.0.1|0:0:0:0:0:0:0:1|::1):\\d{0,9}/.*$"; + static final String RESOURCE_PATH = "/api/v1/websso"; + private static KnoxSSOMessages log = MessagesFactory.get( KnoxSSOMessages.class ); + private String cookieName = null; + private boolean secureOnly = true; + private int maxAge = -1; + private long tokenTTL = 30000l; + private String whitelist = null; + private String domainSuffix = null; + private List<String> targetAudiences = new ArrayList<>(); + private boolean enableSession = false; + + @Context + HttpServletRequest request; + + @Context + HttpServletResponse response; + + @Context + ServletContext context; + + @PostConstruct + public void init() { + + // configured cookieName + cookieName = context.getInitParameter(SSO_COOKIE_NAME); + if (cookieName == null) { + cookieName = DEFAULT_SSO_COOKIE_NAME; + } + + String secure = context.getInitParameter(SSO_COOKIE_SECURE_ONLY_INIT_PARAM); + if (secure != null) { + secureOnly = ("false".equals(secure) ? false : true); + if (!secureOnly) { + log.cookieSecureOnly(secureOnly); + } + } + + String age = context.getInitParameter(SSO_COOKIE_MAX_AGE_INIT_PARAM); + if (age != null) { + try { + log.setMaxAge(age); + maxAge = Integer.parseInt(age); + } + catch (NumberFormatException nfe) { + log.invalidMaxAgeEncountered(age); + } + } + + domainSuffix = context.getInitParameter(SSO_COOKIE_DOMAIN_SUFFIX_PARAM); + + whitelist = context.getInitParameter(SSO_COOKIE_TOKEN_WHITELIST_PARAM); + if (whitelist == null) { + // default to local/relative targets + whitelist = DEFAULT_WHITELIST; + } + + String audiences = context.getInitParameter(SSO_COOKIE_TOKEN_AUDIENCES_PARAM); + if (audiences != null) { + String[] auds = audiences.split(","); + for (int i = 0; i < auds.length; i++) { - targetAudiences.add(auds[i]); ++ targetAudiences.add(auds[i].trim()); + } + } + + String ttl = context.getInitParameter(SSO_COOKIE_TOKEN_TTL_PARAM); + if (ttl != null) { + try { + tokenTTL = Long.parseLong(ttl); + } + catch (NumberFormatException nfe) { + log.invalidTokenTTLEncountered(ttl); + } + } + + String enableSession = context.getInitParameter(SSO_ENABLE_SESSION_PARAM); + this.enableSession = ("true".equals(enableSession)); + } + + @GET + @Produces({APPLICATION_JSON, APPLICATION_XML}) + public Response doGet() { + return getAuthenticationToken(HttpServletResponse.SC_TEMPORARY_REDIRECT); + } + + @POST + @Produces({APPLICATION_JSON, APPLICATION_XML}) + public Response doPost() { + return getAuthenticationToken(HttpServletResponse.SC_SEE_OTHER); + } + + private Response getAuthenticationToken(int statusCode) { + GatewayServices services = (GatewayServices) request.getServletContext() + .getAttribute(GatewayServices.GATEWAY_SERVICES_ATTRIBUTE); + boolean removeOriginalUrlCookie = true; + String original = getCookieValue((HttpServletRequest) request, ORIGINAL_URL_COOKIE_NAME); + if (original == null) { + // in the case where there are no SAML redirects done before here + // we need to get it from the request parameters + removeOriginalUrlCookie = false; + original = getOriginalUrlFromQueryParams(); + if (original.isEmpty()) { + log.originalURLNotFound(); + throw new WebApplicationException("Original URL not found in the request.", Response.Status.BAD_REQUEST); + } + boolean validRedirect = RegExUtils.checkWhitelist(whitelist, original); + if (!validRedirect) { + log.whiteListMatchFail(original, whitelist); + throw new WebApplicationException("Original URL not valid according to the configured whitelist.", + Response.Status.BAD_REQUEST); + } + } + + JWTokenAuthority ts = services.getService(GatewayServices.TOKEN_SERVICE); + Principal p = ((HttpServletRequest)request).getUserPrincipal(); + + try { + JWT token = null; + if (targetAudiences.isEmpty()) { + token = ts.issueToken(p, "RS256", getExpiry()); + } else { + token = ts.issueToken(p, targetAudiences, "RS256", getExpiry()); + } + + // Coverity CID 1327959 + if( token != null ) { + addJWTHadoopCookie( original, token ); + } + + if (removeOriginalUrlCookie) { + removeOriginalUrlCookie(response); + } + + log.aboutToRedirectToOriginal(original); + response.setStatus(statusCode); + response.setHeader("Location", original); + try { + response.getOutputStream().close(); + } catch (IOException e) { + log.unableToCloseOutputStream(e.getMessage(), Arrays.toString(e.getStackTrace())); + } + } + catch (TokenServiceException e) { + log.unableToIssueToken(e); + } + URI location = null; + try { + location = new URI(original); + } + catch(URISyntaxException urise) { + // todo log return error response + } + + if (!enableSession) { + // invalidate the session to avoid autologin + // Coverity CID 1352857 + HttpSession session = request.getSession(false); + if( session != null ) { + session.invalidate(); + } + } + + return Response.seeOther(location).entity("{ \"redirectTo\" : " + original + " }").build(); + } + + private String getOriginalUrlFromQueryParams() { + String original = request.getParameter(ORIGINAL_URL_REQUEST_PARAM); + StringBuffer buf = new StringBuffer(original); + + // Add any other query params. + // Probably not ideal but will not break existing integrations by requiring + // some encoding. + Map<String, String[]> params = request.getParameterMap(); + for (Entry<String, String[]> entry : params.entrySet()) { + if (!ORIGINAL_URL_REQUEST_PARAM.equals(entry.getKey()) + && !original.contains(entry.getKey() + "=")) { + buf.append("&").append(entry.getKey()); + String[] values = entry.getValue(); + if (values.length > 0 && values[0] != null) { + buf.append("="); + } + for (int i = 0; i < values.length; i++) { + if (values[0] != null) { + buf.append(values[i]); + if (i < values.length-1) { + buf.append("&").append(entry.getKey()).append("="); + } + } + } + } + } + + return buf.toString(); + } + + private long getExpiry() { + long expiry = 0l; + if (tokenTTL == -1) { + expiry = -1; + } + else { + expiry = System.currentTimeMillis() + tokenTTL; + } + return expiry; + } + + private void addJWTHadoopCookie(String original, JWT token) { + log.addingJWTCookie(token.toString()); + Cookie c = new Cookie(cookieName, token.toString()); + c.setPath("/"); + try { + String domain = Urls.getDomainName(original, domainSuffix); + if (domain != null) { + c.setDomain(domain); + } + c.setHttpOnly(true); + if (secureOnly) { + c.setSecure(true); + } + if (maxAge != -1) { + c.setMaxAge(maxAge); + } + response.addCookie(c); + log.addedJWTCookie(); + } + catch(Exception e) { + log.unableAddCookieToResponse(e.getMessage(), Arrays.toString(e.getStackTrace())); + throw new WebApplicationException("Unable to add JWT cookie to response."); + } + } + + private void removeOriginalUrlCookie(HttpServletResponse response) { + Cookie c = new Cookie(ORIGINAL_URL_COOKIE_NAME, null); + c.setMaxAge(0); + c.setPath(RESOURCE_PATH); + response.addCookie(c); + } + + private String getCookieValue(HttpServletRequest request, String name) { + Cookie[] cookies = request.getCookies(); + String value = null; + if (cookies != null) { + for(Cookie cookie : cookies){ + if(name.equals(cookie.getName())){ + value = cookie.getValue(); + } + } + } + if (value == null) { + log.cookieNotFound(name); + } + return value; + } +} http://git-wip-us.apache.org/repos/asf/knox/blob/8affbc02/gateway-service-knoxsso/src/test/java/org/apache/knox/gateway/service/knoxsso/WebSSOResourceTest.java ---------------------------------------------------------------------- diff --cc gateway-service-knoxsso/src/test/java/org/apache/knox/gateway/service/knoxsso/WebSSOResourceTest.java index 6f0a805,0000000..6b8411e mode 100644,000000..100644 --- a/gateway-service-knoxsso/src/test/java/org/apache/knox/gateway/service/knoxsso/WebSSOResourceTest.java +++ b/gateway-service-knoxsso/src/test/java/org/apache/knox/gateway/service/knoxsso/WebSSOResourceTest.java @@@ -1,352 -1,0 +1,410 @@@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.knox.gateway.service.knoxsso; + +import org.apache.knox.gateway.util.RegExUtils; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.NoSuchAlgorithmException; +import java.security.Principal; +import java.security.interfaces.RSAPrivateKey; +import java.security.interfaces.RSAPublicKey; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import javax.security.auth.Subject; +import javax.servlet.ServletContext; +import javax.servlet.ServletOutputStream; +import javax.servlet.http.Cookie; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpServletResponseWrapper; + +import org.apache.knox.gateway.services.GatewayServices; +import org.apache.knox.gateway.services.security.token.JWTokenAuthority; +import org.apache.knox.gateway.services.security.token.TokenServiceException; +import org.apache.knox.gateway.services.security.token.impl.JWT; +import org.apache.knox.gateway.services.security.token.impl.JWTToken; +import org.apache.knox.gateway.util.RegExUtils; +import org.easymock.EasyMock; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; + +import com.nimbusds.jose.JWSSigner; +import com.nimbusds.jose.JWSVerifier; +import com.nimbusds.jose.crypto.RSASSASigner; +import com.nimbusds.jose.crypto.RSASSAVerifier; + +/** + * Some tests for the Knox SSO service. + */ +public class WebSSOResourceTest { + + protected static RSAPublicKey publicKey; + protected static RSAPrivateKey privateKey; + + @BeforeClass + public static void setup() throws Exception, NoSuchAlgorithmException { + KeyPairGenerator kpg = KeyPairGenerator.getInstance("RSA"); + kpg.initialize(1024); + KeyPair KPair = kpg.generateKeyPair(); + + publicKey = (RSAPublicKey) KPair.getPublic(); + privateKey = (RSAPrivateKey) KPair.getPrivate(); + } + + @Test + public void testWhitelistMatching() throws Exception { + String whitelist = "^https?://.*example.com:8080/.*$;" + + "^https?://.*example.com/.*$;" + + "^https?://.*example2.com:\\d{0,9}/.*$;" + + "^https://.*example3.com:\\d{0,9}/.*$;" + + "^https?://localhost:\\d{0,9}/.*$;^/.*$"; + + // match on explicit hostname/domain and port + Assert.assertTrue("Failed to match whitelist", RegExUtils.checkWhitelist(whitelist, + "http://host.example.com:8080/")); + // match on non-required port + Assert.assertTrue("Failed to match whitelist", RegExUtils.checkWhitelist(whitelist, + "http://host.example.com/")); + // match on required but any port + Assert.assertTrue("Failed to match whitelist", RegExUtils.checkWhitelist(whitelist, + "http://host.example2.com:1234/")); + // fail on missing port + Assert.assertFalse("Matched whitelist inappropriately", RegExUtils.checkWhitelist(whitelist, + "http://host.example2.com/")); + // fail on invalid port + Assert.assertFalse("Matched whitelist inappropriately", RegExUtils.checkWhitelist(whitelist, + "http://host.example.com:8081/")); + // fail on alphanumeric port + Assert.assertFalse("Matched whitelist inappropriately", RegExUtils.checkWhitelist(whitelist, + "http://host.example.com:A080/")); + // fail on invalid hostname/domain + Assert.assertFalse("Matched whitelist inappropriately", RegExUtils.checkWhitelist(whitelist, + "http://host.example.net:8080/")); + // fail on required port + Assert.assertFalse("Matched whitelist inappropriately", RegExUtils.checkWhitelist(whitelist, + "http://host.example2.com/")); + // fail on required https + Assert.assertFalse("Matched whitelist inappropriately", RegExUtils.checkWhitelist(whitelist, + "http://host.example3.com/")); + // match on localhost and port + Assert.assertTrue("Failed to match whitelist", RegExUtils.checkWhitelist(whitelist, + "http://localhost:8080/")); + // match on local/relative path + Assert.assertTrue("Failed to match whitelist", RegExUtils.checkWhitelist(whitelist, + "/local/resource/")); + } + + @Test + public void testGetToken() throws Exception { + + ServletContext context = EasyMock.createNiceMock(ServletContext.class); + EasyMock.expect(context.getInitParameter("knoxsso.cookie.name")).andReturn(null); + EasyMock.expect(context.getInitParameter("knoxsso.cookie.secure.only")).andReturn(null); + EasyMock.expect(context.getInitParameter("knoxsso.cookie.max.age")).andReturn(null); + EasyMock.expect(context.getInitParameter("knoxsso.cookie.domain.suffix")).andReturn(null); + EasyMock.expect(context.getInitParameter("knoxsso.redirect.whitelist.regex")).andReturn(null); + EasyMock.expect(context.getInitParameter("knoxsso.token.audiences")).andReturn(null); + EasyMock.expect(context.getInitParameter("knoxsso.token.ttl")).andReturn(null); + EasyMock.expect(context.getInitParameter("knoxsso.enable.session")).andReturn(null); + + HttpServletRequest request = EasyMock.createNiceMock(HttpServletRequest.class); + EasyMock.expect(request.getParameter("originalUrl")).andReturn("http://localhost:9080/service"); + EasyMock.expect(request.getParameterMap()).andReturn(Collections.<String,String[]>emptyMap()); + EasyMock.expect(request.getServletContext()).andReturn(context).anyTimes(); + + Principal principal = EasyMock.createNiceMock(Principal.class); + EasyMock.expect(principal.getName()).andReturn("alice").anyTimes(); + EasyMock.expect(request.getUserPrincipal()).andReturn(principal).anyTimes(); + + GatewayServices services = EasyMock.createNiceMock(GatewayServices.class); + EasyMock.expect(context.getAttribute(GatewayServices.GATEWAY_SERVICES_ATTRIBUTE)).andReturn(services); + + JWTokenAuthority authority = new TestJWTokenAuthority(publicKey, privateKey); + EasyMock.expect(services.getService(GatewayServices.TOKEN_SERVICE)).andReturn(authority); + + HttpServletResponse response = EasyMock.createNiceMock(HttpServletResponse.class); + ServletOutputStream outputStream = EasyMock.createNiceMock(ServletOutputStream.class); + CookieResponseWrapper responseWrapper = new CookieResponseWrapper(response, outputStream); + + EasyMock.replay(principal, services, context, request); + + WebSSOResource webSSOResponse = new WebSSOResource(); + webSSOResponse.request = request; + webSSOResponse.response = responseWrapper; + webSSOResponse.context = context; + webSSOResponse.init(); + + // Issue a token + webSSOResponse.doGet(); + + // Check the cookie + Cookie cookie = responseWrapper.getCookie("hadoop-jwt"); + assertNotNull(cookie); + + JWTToken parsedToken = new JWTToken(cookie.getValue()); + assertEquals("alice", parsedToken.getSubject()); + assertTrue(authority.verifyToken(parsedToken)); + } + + @Test + public void testAudiences() throws Exception { + + ServletContext context = EasyMock.createNiceMock(ServletContext.class); + EasyMock.expect(context.getInitParameter("knoxsso.cookie.name")).andReturn(null); + EasyMock.expect(context.getInitParameter("knoxsso.cookie.secure.only")).andReturn(null); + EasyMock.expect(context.getInitParameter("knoxsso.cookie.max.age")).andReturn(null); + EasyMock.expect(context.getInitParameter("knoxsso.cookie.domain.suffix")).andReturn(null); + EasyMock.expect(context.getInitParameter("knoxsso.redirect.whitelist.regex")).andReturn(null); + EasyMock.expect(context.getInitParameter("knoxsso.token.audiences")).andReturn("recipient1,recipient2"); + EasyMock.expect(context.getInitParameter("knoxsso.token.ttl")).andReturn(null); + EasyMock.expect(context.getInitParameter("knoxsso.enable.session")).andReturn(null); + + HttpServletRequest request = EasyMock.createNiceMock(HttpServletRequest.class); + EasyMock.expect(request.getParameter("originalUrl")).andReturn("http://localhost:9080/service"); + EasyMock.expect(request.getParameterMap()).andReturn(Collections.<String,String[]>emptyMap()); + EasyMock.expect(request.getServletContext()).andReturn(context).anyTimes(); + + Principal principal = EasyMock.createNiceMock(Principal.class); + EasyMock.expect(principal.getName()).andReturn("alice").anyTimes(); + EasyMock.expect(request.getUserPrincipal()).andReturn(principal).anyTimes(); ++ ++ GatewayServices services = EasyMock.createNiceMock(GatewayServices.class); ++ EasyMock.expect(context.getAttribute(GatewayServices.GATEWAY_SERVICES_ATTRIBUTE)).andReturn(services); ++ ++ JWTokenAuthority authority = new TestJWTokenAuthority(publicKey, privateKey); ++ EasyMock.expect(services.getService(GatewayServices.TOKEN_SERVICE)).andReturn(authority); ++ ++ HttpServletResponse response = EasyMock.createNiceMock(HttpServletResponse.class); ++ ServletOutputStream outputStream = EasyMock.createNiceMock(ServletOutputStream.class); ++ CookieResponseWrapper responseWrapper = new CookieResponseWrapper(response, outputStream); ++ ++ EasyMock.replay(principal, services, context, request); ++ ++ WebSSOResource webSSOResponse = new WebSSOResource(); ++ webSSOResponse.request = request; ++ webSSOResponse.response = responseWrapper; ++ webSSOResponse.context = context; ++ webSSOResponse.init(); ++ ++ // Issue a token ++ webSSOResponse.doGet(); ++ ++ // Check the cookie ++ Cookie cookie = responseWrapper.getCookie("hadoop-jwt"); ++ assertNotNull(cookie); ++ ++ JWTToken parsedToken = new JWTToken(cookie.getValue()); ++ assertEquals("alice", parsedToken.getSubject()); ++ assertTrue(authority.verifyToken(parsedToken)); ++ ++ // Verify the audiences ++ List<String> audiences = Arrays.asList(parsedToken.getAudienceClaims()); ++ assertEquals(2, audiences.size()); ++ assertTrue(audiences.contains("recipient1")); ++ assertTrue(audiences.contains("recipient2")); ++ } ++ ++ @Test ++ public void testAudiencesWhitespace() throws Exception { ++ ++ ServletContext context = EasyMock.createNiceMock(ServletContext.class); ++ EasyMock.expect(context.getInitParameter("knoxsso.cookie.name")).andReturn(null); ++ EasyMock.expect(context.getInitParameter("knoxsso.cookie.secure.only")).andReturn(null); ++ EasyMock.expect(context.getInitParameter("knoxsso.cookie.max.age")).andReturn(null); ++ EasyMock.expect(context.getInitParameter("knoxsso.cookie.domain.suffix")).andReturn(null); ++ EasyMock.expect(context.getInitParameter("knoxsso.redirect.whitelist.regex")).andReturn(null); ++ EasyMock.expect(context.getInitParameter("knoxsso.token.audiences")).andReturn(" recipient1, recipient2 "); ++ EasyMock.expect(context.getInitParameter("knoxsso.token.ttl")).andReturn(null); ++ EasyMock.expect(context.getInitParameter("knoxsso.enable.session")).andReturn(null); ++ ++ HttpServletRequest request = EasyMock.createNiceMock(HttpServletRequest.class); ++ EasyMock.expect(request.getParameter("originalUrl")).andReturn("http://localhost:9080/service"); ++ EasyMock.expect(request.getParameterMap()).andReturn(Collections.<String,String[]>emptyMap()); ++ EasyMock.expect(request.getServletContext()).andReturn(context).anyTimes(); ++ ++ Principal principal = EasyMock.createNiceMock(Principal.class); ++ EasyMock.expect(principal.getName()).andReturn("alice").anyTimes(); ++ EasyMock.expect(request.getUserPrincipal()).andReturn(principal).anyTimes(); + + GatewayServices services = EasyMock.createNiceMock(GatewayServices.class); + EasyMock.expect(context.getAttribute(GatewayServices.GATEWAY_SERVICES_ATTRIBUTE)).andReturn(services); + + JWTokenAuthority authority = new TestJWTokenAuthority(publicKey, privateKey); + EasyMock.expect(services.getService(GatewayServices.TOKEN_SERVICE)).andReturn(authority); + + HttpServletResponse response = EasyMock.createNiceMock(HttpServletResponse.class); + ServletOutputStream outputStream = EasyMock.createNiceMock(ServletOutputStream.class); + CookieResponseWrapper responseWrapper = new CookieResponseWrapper(response, outputStream); + + EasyMock.replay(principal, services, context, request); + + WebSSOResource webSSOResponse = new WebSSOResource(); + webSSOResponse.request = request; + webSSOResponse.response = responseWrapper; + webSSOResponse.context = context; + webSSOResponse.init(); + + // Issue a token + webSSOResponse.doGet(); + + // Check the cookie + Cookie cookie = responseWrapper.getCookie("hadoop-jwt"); + assertNotNull(cookie); + + JWTToken parsedToken = new JWTToken(cookie.getValue()); + assertEquals("alice", parsedToken.getSubject()); + assertTrue(authority.verifyToken(parsedToken)); + + // Verify the audiences + List<String> audiences = Arrays.asList(parsedToken.getAudienceClaims()); + assertEquals(2, audiences.size()); + assertTrue(audiences.contains("recipient1")); + assertTrue(audiences.contains("recipient2")); + } + + /** + * A wrapper for HttpServletResponseWrapper to store the cookies + */ + private static class CookieResponseWrapper extends HttpServletResponseWrapper { + + private ServletOutputStream outputStream; + private Map<String, Cookie> cookies = new HashMap<>(); + + public CookieResponseWrapper(HttpServletResponse response) { + super(response); + } + + public CookieResponseWrapper(HttpServletResponse response, ServletOutputStream outputStream) { + super(response); + this.outputStream = outputStream; + } + + @Override + public ServletOutputStream getOutputStream() { + return outputStream; + } + + @Override + public void addCookie(Cookie cookie) { + super.addCookie(cookie); + cookies.put(cookie.getName(), cookie); + } + + public Cookie getCookie(String name) { + return cookies.get(name); + } + + } + + private static class TestJWTokenAuthority implements JWTokenAuthority { + + private RSAPublicKey publicKey; + private RSAPrivateKey privateKey; + + public TestJWTokenAuthority(RSAPublicKey publicKey, RSAPrivateKey privateKey) { + this.publicKey = publicKey; + this.privateKey = privateKey; + } + + @Override + public JWT issueToken(Subject subject, String algorithm) + throws TokenServiceException { + Principal p = (Principal) subject.getPrincipals().toArray()[0]; + return issueToken(p, algorithm); + } + + @Override + public JWT issueToken(Principal p, String algorithm) + throws TokenServiceException { + return issueToken(p, null, algorithm); + } + + @Override + public JWT issueToken(Principal p, String audience, String algorithm) + throws TokenServiceException { + return issueToken(p, audience, algorithm, -1); + } + + @Override + public boolean verifyToken(JWT token) throws TokenServiceException { + JWSVerifier verifier = new RSASSAVerifier(publicKey); + return token.verify(verifier); + } + + @Override + public JWT issueToken(Principal p, String audience, String algorithm, + long expires) throws TokenServiceException { + List<String> audiences = null; + if (audience != null) { + audiences = new ArrayList<String>(); + audiences.add(audience); + } + return issueToken(p, audiences, algorithm, expires); + } + + @Override + public JWT issueToken(Principal p, List<String> audiences, String algorithm, + long expires) throws TokenServiceException { + String[] claimArray = new String[4]; + claimArray[0] = "KNOXSSO"; + claimArray[1] = p.getName(); + claimArray[2] = null; + if (expires == -1) { + claimArray[3] = null; + } else { + claimArray[3] = String.valueOf(expires); + } + + JWTToken token = null; + if ("RS256".equals(algorithm)) { + token = new JWTToken("RS256", claimArray, audiences); + JWSSigner signer = new RSASSASigner(privateKey); + token.sign(signer); + } else { + throw new TokenServiceException("Cannot issue token - Unsupported algorithm"); + } + + return token; + } + + @Override + public JWT issueToken(Principal p, String algorithm, long expiry) + throws TokenServiceException { + return issueToken(p, Collections.<String>emptyList(), algorithm, expiry); + } + + @Override + public boolean verifyToken(JWT token, RSAPublicKey publicKey) throws TokenServiceException { + JWSVerifier verifier = new RSASSAVerifier(publicKey); + return token.verify(verifier); + } + + } + +} http://git-wip-us.apache.org/repos/asf/knox/blob/8affbc02/gateway-service-knoxtoken/src/main/java/org/apache/knox/gateway/service/knoxtoken/TokenResource.java ---------------------------------------------------------------------- diff --cc gateway-service-knoxtoken/src/main/java/org/apache/knox/gateway/service/knoxtoken/TokenResource.java index 2c77bdf,0000000..1c16ab3 mode 100644,000000..100644 --- a/gateway-service-knoxtoken/src/main/java/org/apache/knox/gateway/service/knoxtoken/TokenResource.java +++ b/gateway-service-knoxtoken/src/main/java/org/apache/knox/gateway/service/knoxtoken/TokenResource.java @@@ -1,183 -1,0 +1,218 @@@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.knox.gateway.service.knoxtoken; + +import java.io.IOException; +import java.security.Principal; ++import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Map; +import java.util.HashMap; +import java.util.List; + +import javax.annotation.PostConstruct; +import javax.servlet.ServletContext; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.ws.rs.GET; +import javax.ws.rs.POST; +import javax.ws.rs.Path; +import javax.ws.rs.Produces; +import javax.ws.rs.core.Context; +import javax.ws.rs.core.Response; +import org.apache.knox.gateway.i18n.messages.MessagesFactory; +import org.apache.knox.gateway.services.GatewayServices; +import org.apache.knox.gateway.services.security.token.JWTokenAuthority; +import org.apache.knox.gateway.services.security.token.TokenServiceException; +import org.apache.knox.gateway.services.security.token.impl.JWT; +import org.apache.knox.gateway.util.JsonUtils; + +import static javax.ws.rs.core.MediaType.APPLICATION_JSON; +import static javax.ws.rs.core.MediaType.APPLICATION_XML; + +@Path( TokenResource.RESOURCE_PATH ) +public class TokenResource { + private static final String EXPIRES_IN = "expires_in"; + private static final String TOKEN_TYPE = "token_type"; + private static final String ACCESS_TOKEN = "access_token"; + private static final String TARGET_URL = "target_url"; + private static final String BEARER = "Bearer "; + private static final String TOKEN_TTL_PARAM = "knox.token.ttl"; + private static final String TOKEN_AUDIENCES_PARAM = "knox.token.audiences"; + private static final String TOKEN_TARGET_URL = "knox.token.target.url"; + private static final String TOKEN_CLIENT_DATA = "knox.token.client.data"; ++ private static final String TOKEN_CLIENT_CERT_REQUIRED = "knox.token.client.cert.required"; ++ private static final String TOKEN_ALLOWED_PRINCIPALS = "knox.token.allowed.principals"; + static final String RESOURCE_PATH = "knoxtoken/api/v1/token"; + private static TokenServiceMessages log = MessagesFactory.get( TokenServiceMessages.class ); + private long tokenTTL = 30000l; + private List<String> targetAudiences = new ArrayList<>(); + private String tokenTargetUrl = null; + private Map<String,Object> tokenClientDataMap = null; ++ private ArrayList<String> allowedDNs = new ArrayList<>(); ++ private boolean clientCertRequired = false; + + @Context + HttpServletRequest request; + + @Context + HttpServletResponse response; + + @Context + ServletContext context; + + @PostConstruct + public void init() { + + String audiences = context.getInitParameter(TOKEN_AUDIENCES_PARAM); + if (audiences != null) { + String[] auds = audiences.split(","); + for (int i = 0; i < auds.length; i++) { - targetAudiences.add(auds[i]); ++ targetAudiences.add(auds[i].trim()); ++ } ++ } ++ ++ String clientCert = context.getInitParameter(TOKEN_CLIENT_CERT_REQUIRED); ++ clientCertRequired = "true".equals(clientCert); ++ ++ String principals = context.getInitParameter(TOKEN_ALLOWED_PRINCIPALS); ++ if (principals != null) { ++ String[] dns = principals.split(";"); ++ for (int i = 0; i < dns.length; i++) { ++ allowedDNs.add(dns[i]); + } + } + + String ttl = context.getInitParameter(TOKEN_TTL_PARAM); + if (ttl != null) { + try { + tokenTTL = Long.parseLong(ttl); + } + catch (NumberFormatException nfe) { + log.invalidTokenTTLEncountered(ttl); + } + } + + tokenTargetUrl = context.getInitParameter(TOKEN_TARGET_URL); + + String clientData = context.getInitParameter(TOKEN_CLIENT_DATA); + if (clientData != null) { + tokenClientDataMap = new HashMap<>(); + String[] tokenClientData = clientData.split(","); + addClientDataToMap(tokenClientData, tokenClientDataMap); + } + } + + @GET + @Produces({APPLICATION_JSON, APPLICATION_XML}) + public Response doGet() { + return getAuthenticationToken(); + } + + @POST + @Produces({APPLICATION_JSON, APPLICATION_XML}) + public Response doPost() { + return getAuthenticationToken(); + } + ++ private X509Certificate extractCertificate(HttpServletRequest req) { ++ X509Certificate[] certs = (X509Certificate[]) req.getAttribute("javax.servlet.request.X509Certificate"); ++ if (null != certs && certs.length > 0) { ++ return certs[0]; ++ } ++ return null; ++ } ++ + private Response getAuthenticationToken() { ++ if (clientCertRequired) { ++ X509Certificate cert = extractCertificate(request); ++ if (cert != null) { ++ if (!allowedDNs.contains(cert.getSubjectDN().getName())) { ++ return Response.status(403).entity("{ \"Unable to get token - untrusted client cert.\" }").build(); ++ } ++ } ++ else { ++ return Response.status(403).entity("{ \"Unable to get token - client cert required.\" }").build(); ++ } ++ } + GatewayServices services = (GatewayServices) request.getServletContext() + .getAttribute(GatewayServices.GATEWAY_SERVICES_ATTRIBUTE); + + JWTokenAuthority ts = services.getService(GatewayServices.TOKEN_SERVICE); + Principal p = ((HttpServletRequest)request).getUserPrincipal(); + long expires = getExpiry(); + + try { + JWT token = null; + if (targetAudiences.isEmpty()) { + token = ts.issueToken(p, "RS256", expires); + } else { + token = ts.issueToken(p, targetAudiences, "RS256", expires); + } + + if (token != null) { + String accessToken = token.toString(); + + HashMap<String, Object> map = new HashMap<>(); + map.put(ACCESS_TOKEN, accessToken); + map.put(TOKEN_TYPE, BEARER); + map.put(EXPIRES_IN, expires); + if (tokenTargetUrl != null) { + map.put(TARGET_URL, tokenTargetUrl); + } + if (tokenClientDataMap != null) { + map.putAll(tokenClientDataMap); + } + + String jsonResponse = JsonUtils.renderAsJsonString(map); + + response.getWriter().write(jsonResponse); + return Response.ok().build(); + } + else { + return Response.serverError().build(); + } + } + catch (TokenServiceException | IOException e) { + log.unableToIssueToken(e); + } + return Response.ok().entity("{ \"Unable to acquire token.\" }").build(); + } + + void addClientDataToMap(String[] tokenClientData, + Map<String,Object> map) { + String[] kv = null; + for (int i = 0; i < tokenClientData.length; i++) { + kv = tokenClientData[i].split("="); + if (kv.length == 2) { + map.put(kv[0], kv[1]); + } + } + } + + private long getExpiry() { + long expiry = 0l; + if (tokenTTL == -1) { + expiry = -1; + } + else { + expiry = System.currentTimeMillis() + tokenTTL; + } + return expiry; + } +}