Repository: flume Updated Branches: refs/heads/trunk 6f6f69b8b -> a103a6771
FLUME-2574. SSL support for Thrift RPC. (Johny Rufus via Hari) Project: http://git-wip-us.apache.org/repos/asf/flume/repo Commit: http://git-wip-us.apache.org/repos/asf/flume/commit/a103a677 Tree: http://git-wip-us.apache.org/repos/asf/flume/tree/a103a677 Diff: http://git-wip-us.apache.org/repos/asf/flume/diff/a103a677 Branch: refs/heads/trunk Commit: a103a677145a43aa6fa78dfeeb34018879e24a94 Parents: 6f6f69b Author: Hari Shreedharan <[email protected]> Authored: Wed Feb 18 21:10:56 2015 -0800 Committer: Hari Shreedharan <[email protected]> Committed: Wed Feb 18 21:10:56 2015 -0800 ---------------------------------------------------------------------- .../org/apache/flume/source/ThriftSource.java | 274 +++++++++++++------ .../org/apache/flume/sink/TestThriftSink.java | 139 +++++++++- .../apache/flume/source/TestThriftSource.java | 37 +++ .../src/test/resources/keystorefile.jks | Bin 0 -> 1294 bytes .../src/test/resources/truststorefile.jks | Bin 0 -> 887 bytes .../org/apache/flume/api/ThriftRpcClient.java | 113 +++++++- .../apache/flume/api/ThriftTestingSource.java | 63 ++++- 7 files changed, 527 insertions(+), 99 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flume/blob/a103a677/flume-ng-core/src/main/java/org/apache/flume/source/ThriftSource.java ---------------------------------------------------------------------- diff --git a/flume-ng-core/src/main/java/org/apache/flume/source/ThriftSource.java b/flume-ng-core/src/main/java/org/apache/flume/source/ThriftSource.java index 551fe13..06bb604 100644 --- a/flume-ng-core/src/main/java/org/apache/flume/source/ThriftSource.java +++ b/flume-ng-core/src/main/java/org/apache/flume/source/ThriftSource.java @@ -35,18 +35,30 @@ import org.apache.flume.thrift.ThriftFlumeEvent; import org.apache.thrift.TException; import org.apache.thrift.protocol.TCompactProtocol; import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.protocol.TProtocolFactory; import org.apache.thrift.server.TNonblockingServer; import org.apache.thrift.server.TServer; +import org.apache.thrift.server.TThreadPoolServer; import org.apache.thrift.transport.TFastFramedTransport; import org.apache.thrift.transport.TNonblockingServerSocket; import org.apache.thrift.transport.TNonblockingServerTransport; import org.apache.thrift.transport.TServerSocket; import org.apache.thrift.transport.TServerTransport; +import org.apache.thrift.transport.TSSLTransportFactory; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLServerSocket; +import java.io.FileInputStream; import java.lang.reflect.Method; +import java.net.InetAddress; import java.net.InetSocketAddress; +import java.net.ServerSocket; +import java.security.KeyStore; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.LinkedList; import java.util.List; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -77,15 +89,27 @@ public class ThriftSource extends AbstractSource implements Configurable, public static final String CONFIG_PROTOCOL = "protocol"; public static final String BINARY_PROTOCOL = "binary"; public static final String COMPACT_PROTOCOL = "compact"; - + + private static final String SSL_KEY = "ssl"; + private static final String KEYSTORE_KEY = "keystore"; + private static final String KEYSTORE_PASSWORD_KEY = "keystore-password"; + private static final String KEYSTORE_TYPE_KEY = "keystore-type"; + private static final String KEYMANAGER_TYPE = "keymanager-type"; + private static final String EXCLUDE_PROTOCOLS = "exclude-protocols"; + private Integer port; private String bindAddress; private int maxThreads = 0; private SourceCounter sourceCounter; private TServer server; - private TServerTransport serverTransport; private ExecutorService servingExecutor; private String protocol; + private String keystore; + private String keystorePassword; + private String keystoreType; + private String keyManagerType; + private final List<String> excludeProtocols = new LinkedList<String>(); + private boolean enableSsl = false; @Override public void configure(Context context) { @@ -99,6 +123,7 @@ public class ThriftSource extends AbstractSource implements Configurable, try { maxThreads = context.getInteger(CONFIG_THREADS, 0); + maxThreads = (maxThreads <= 0) ? Integer.MAX_VALUE : maxThreads; } catch (NumberFormatException e) { logger.warn("Thrift source\'s \"threads\" property must specify an " + "integer value: " + context.getString(CONFIG_THREADS)); @@ -107,111 +132,58 @@ public class ThriftSource extends AbstractSource implements Configurable, if (sourceCounter == null) { sourceCounter = new SourceCounter(getName()); } - + protocol = context.getString(CONFIG_PROTOCOL); if (protocol == null) { // default is to use the compact protocol. protocol = COMPACT_PROTOCOL; - } + } Preconditions.checkArgument( (protocol.equalsIgnoreCase(BINARY_PROTOCOL) || - protocol.equalsIgnoreCase(COMPACT_PROTOCOL)), + protocol.equalsIgnoreCase(COMPACT_PROTOCOL)), "binary or compact are the only valid Thrift protocol types to " + - "choose from."); - } - - @Override - public void start() { - logger.info("Starting thrift source"); - - maxThreads = (maxThreads <= 0) ? Integer.MAX_VALUE : maxThreads; - Class<?> serverClass = null; - Class<?> argsClass = null; - TServer.AbstractServerArgs args = null; - /* - * Use reflection to determine if TThreadedSelectServer is available. If - * it is not available, use TThreadPoolServer - */ - try { - serverClass = Class.forName("org.apache.thrift" + - ".server.TThreadedSelectorServer"); + "choose from."); - argsClass = Class.forName("org.apache.thrift" + - ".server.TThreadedSelectorServer$Args"); - - // Looks like TThreadedSelectorServer is available, so continue.. - ExecutorService sourceService; - ThreadFactory threadFactory = new ThreadFactoryBuilder().setNameFormat( - "Flume Thrift IPC Thread %d").build(); - if (maxThreads == 0) { - sourceService = Executors.newCachedThreadPool(threadFactory); + enableSsl = context.getBoolean(SSL_KEY, false); + if (enableSsl) { + keystore = context.getString(KEYSTORE_KEY); + keystorePassword = context.getString(KEYSTORE_PASSWORD_KEY); + keystoreType = context.getString(KEYSTORE_TYPE_KEY, "JKS"); + keyManagerType = context.getString(KEYMANAGER_TYPE, KeyManagerFactory.getDefaultAlgorithm()); + String excludeProtocolsStr = context.getString(EXCLUDE_PROTOCOLS); + if (excludeProtocolsStr == null) { + excludeProtocols.add("SSLv3"); } else { - sourceService = Executors.newFixedThreadPool(maxThreads, threadFactory); + excludeProtocols.addAll(Arrays.asList(excludeProtocolsStr.split(" "))); + if (!excludeProtocols.contains("SSLv3")) { + excludeProtocols.add("SSLv3"); + } } - serverTransport = new TNonblockingServerSocket( - new InetSocketAddress(bindAddress, port)); - args = (TNonblockingServer.AbstractNonblockingServerArgs) argsClass - .getConstructor(TNonblockingServerTransport.class) - .newInstance(serverTransport); - Method m = argsClass.getDeclaredMethod("executorService", - ExecutorService.class); - m.invoke(args, sourceService); - } catch (ClassNotFoundException e) { - logger.info("TThreadedSelectorServer not found, " + - "using TThreadPoolServer"); + Preconditions.checkNotNull(keystore, + KEYSTORE_KEY + " must be specified when SSL is enabled"); + Preconditions.checkNotNull(keystorePassword, + KEYSTORE_PASSWORD_KEY + " must be specified when SSL is enabled"); try { - // Looks like TThreadedSelectorServer is not available, - // so create a TThreadPoolServer instead. - - serverTransport = new TServerSocket(new InetSocketAddress - (bindAddress, port)); - - serverClass = Class.forName("org.apache.thrift" + - ".server.TThreadPoolServer"); - argsClass = Class.forName("org.apache.thrift.server" + - ".TThreadPoolServer$Args"); - args = (TServer.AbstractServerArgs) argsClass - .getConstructor(TServerTransport.class) - .newInstance(serverTransport); - Method m = argsClass.getDeclaredMethod("maxWorkerThreads",int.class); - m.invoke(args, maxThreads); - } catch (ClassNotFoundException e1) { - throw new FlumeException("Cannot find TThreadSelectorServer or " + - "TThreadPoolServer. Please install a compatible version of thrift " + - "in the classpath", e1); - } catch (Throwable throwable) { - throw new FlumeException("Cannot start Thrift source.", throwable); + KeyStore ks = KeyStore.getInstance(keystoreType); + ks.load(new FileInputStream(keystore), keystorePassword.toCharArray()); + } catch (Exception ex) { + throw new FlumeException( + "Thrift source configured with invalid keystore: " + keystore, ex); } - } catch (Throwable throwable) { - throw new FlumeException("Cannot start Thrift source.", throwable); } + } - try { - if (protocol.equals(BINARY_PROTOCOL)) { - logger.info("Using TBinaryProtocol"); - args.protocolFactory(new TBinaryProtocol.Factory()); - } else { - logger.info("Using TCompactProtocol"); - args.protocolFactory(new TCompactProtocol.Factory()); - } - args.inputTransportFactory(new TFastFramedTransport.Factory()); - args.outputTransportFactory(new TFastFramedTransport.Factory()); - args.processor(new ThriftSourceProtocol - .Processor<ThriftSourceHandler>(new ThriftSourceHandler())); - /* - * Both THsHaServer and TThreadedSelectorServer allows us to pass in - * the executor service to use - unfortunately the "executorService" - * method does not exist in the parent abstract Args class, - * so use reflection to pass the executor in. - * - */ + @Override + public void start() { + logger.info("Starting thrift source"); - server = (TServer) serverClass.getConstructor(argsClass).newInstance - (args); - } catch (Throwable ex) { - throw new FlumeException("Cannot start Thrift Source.", ex); - } + // create the server + server = getTThreadedSelectorServer(); + // if in ssl mode or if SelectorServer is unavailable + if (server == null) { + server = getTThreadPoolServer(); + } servingExecutor = Executors.newSingleThreadExecutor(new ThreadFactoryBuilder().setNameFormat("Flume Thrift Source I/O Boss") @@ -245,6 +217,126 @@ public class ThriftSource extends AbstractSource implements Configurable, super.start(); } + private TServerTransport getSSLServerTransport() { + try { + TServerTransport transport; + TSSLTransportFactory.TSSLTransportParameters params = + new TSSLTransportFactory.TSSLTransportParameters(); + params.setKeyStore(keystore, keystorePassword, keyManagerType, keystoreType); + transport = TSSLTransportFactory.getServerSocket( + port, 120000, InetAddress.getByName(bindAddress), params); + + ServerSocket serverSock = ((TServerSocket) transport).getServerSocket(); + if (serverSock instanceof SSLServerSocket) { + SSLServerSocket sslServerSock = (SSLServerSocket) serverSock; + List<String> enabledProtocols = new ArrayList<String>(); + for (String protocol : sslServerSock.getEnabledProtocols()) { + if (!excludeProtocols.contains(protocol)) { + enabledProtocols.add(protocol); + } + } + sslServerSock.setEnabledProtocols(enabledProtocols.toArray(new String[0])); + } + return transport; + } catch (Throwable throwable) { + throw new FlumeException("Cannot start Thrift source.", throwable); + } + } + + private TServerTransport getTServerTransport() { + try { + return new TServerSocket(new InetSocketAddress + (bindAddress, port)); + } catch (Throwable throwable) { + throw new FlumeException("Cannot start Thrift source.", throwable); + } + } + + private TProtocolFactory getProtocolFactory() { + if (protocol.equals(BINARY_PROTOCOL)) { + logger.info("Using TBinaryProtocol"); + return new TBinaryProtocol.Factory(); + } else { + logger.info("Using TCompactProtocol"); + return new TCompactProtocol.Factory(); + } + } + + private TServer getTThreadedSelectorServer() { + if(enableSsl) { + return null; + } + Class<?> serverClass; + Class<?> argsClass; + TServer.AbstractServerArgs args; + try { + serverClass = Class.forName("org.apache.thrift" + + ".server.TThreadedSelectorServer"); + argsClass = Class.forName("org.apache.thrift" + + ".server.TThreadedSelectorServer$Args"); + + TServerTransport serverTransport = new TNonblockingServerSocket( + new InetSocketAddress(bindAddress, port)); + ExecutorService sourceService; + ThreadFactory threadFactory = new ThreadFactoryBuilder().setNameFormat( + "Flume Thrift IPC Thread %d").build(); + if (maxThreads == 0) { + sourceService = Executors.newCachedThreadPool(threadFactory); + } else { + sourceService = Executors.newFixedThreadPool(maxThreads, threadFactory); + } + args = (TNonblockingServer.AbstractNonblockingServerArgs) argsClass + .getConstructor(TNonblockingServerTransport.class) + .newInstance(serverTransport); + Method m = argsClass.getDeclaredMethod("executorService", + ExecutorService.class); + m.invoke(args, sourceService); + + populateServerParams(args); + + /* + * Both THsHaServer and TThreadedSelectorServer allows us to pass in + * the executor service to use - unfortunately the "executorService" + * method does not exist in the parent abstract Args class, + * so use reflection to pass the executor in. + * + */ + server = (TServer) serverClass.getConstructor(argsClass).newInstance(args); + } catch(ClassNotFoundException e) { + return null; + } catch (Throwable ex) { + throw new FlumeException("Cannot start Thrift Source.", ex); + } + return server; + } + + private TServer getTThreadPoolServer() { + TServerTransport serverTransport; + if (enableSsl) { + serverTransport = getSSLServerTransport(); + } else { + serverTransport = getTServerTransport(); + } + TThreadPoolServer.Args serverArgs = new TThreadPoolServer.Args(serverTransport); + serverArgs.maxWorkerThreads(maxThreads); + populateServerParams(serverArgs); + return new TThreadPoolServer(serverArgs); + } + + private void populateServerParams(TServer.AbstractServerArgs args) { + //populate the ProtocolFactory + args.protocolFactory(getProtocolFactory()); + + //populate the transportFactory + args.inputTransportFactory(new TFastFramedTransport.Factory()); + args.outputTransportFactory(new TFastFramedTransport.Factory()); + + // populate the Processor + args.processor(new ThriftSourceProtocol + .Processor<ThriftSourceHandler>(new ThriftSourceHandler())); + } + + @Override public void stop() { if(server != null && server.isServing()) { server.stop(); http://git-wip-us.apache.org/repos/asf/flume/blob/a103a677/flume-ng-core/src/test/java/org/apache/flume/sink/TestThriftSink.java ---------------------------------------------------------------------- diff --git a/flume-ng-core/src/test/java/org/apache/flume/sink/TestThriftSink.java b/flume-ng-core/src/test/java/org/apache/flume/sink/TestThriftSink.java index fccaede..1beec76 100644 --- a/flume-ng-core/src/test/java/org/apache/flume/sink/TestThriftSink.java +++ b/flume-ng-core/src/test/java/org/apache/flume/sink/TestThriftSink.java @@ -30,12 +30,16 @@ import org.apache.flume.api.ThriftTestingSource; import org.apache.flume.channel.MemoryChannel; import org.apache.flume.conf.Configurables; import org.apache.flume.event.EventBuilder; -import org.apache.flume.source.ThriftSource; +import org.apache.flume.lifecycle.LifecycleController; +import org.apache.flume.lifecycle.LifecycleState; + import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.TrustManagerFactory; import java.nio.charset.Charset; import java.util.Random; import java.util.concurrent.atomic.AtomicLong; @@ -195,4 +199,137 @@ public class TestThriftSink { Assert.assertEquals(Sink.Status.BACKOFF, sink.process()); sink.stop(); } + + @Test + public void testSslProcess() throws Exception { + Event event = EventBuilder.withBody("test event 1", Charsets.UTF_8); + src = new ThriftTestingSource(ThriftTestingSource.HandlerType.OK.name(), port, + ThriftRpcClient.COMPACT_PROTOCOL, "src/test/resources/keystorefile.jks", + "password", KeyManagerFactory.getDefaultAlgorithm(), "JKS"); + Context context = new Context(); + context.put("hostname", hostname); + context.put("port", String.valueOf(port)); + context.put("ssl", String.valueOf(true)); + context.put("batch-size", String.valueOf(2)); + context.put("connect-timeout", String.valueOf(2000L)); + context.put("request-timeout", String.valueOf(3000L)); + context.put("truststore", "src/test/resources/truststorefile.jks"); + context.put("truststore-password", "password"); + context.put("trustmanager-type", TrustManagerFactory.getDefaultAlgorithm()); + + Configurables.configure(sink, context); + channel.start(); + sink.start(); + Transaction transaction = channel.getTransaction(); + transaction.begin(); + for (int i = 0; i < 11; i++) { + channel.put(event); + } + transaction.commit(); + transaction.close(); + for (int i = 0; i < 6; i++) { + Sink.Status status = sink.process(); + Assert.assertEquals(Sink.Status.READY, status); + } + Assert.assertEquals(Sink.Status.BACKOFF, sink.process()); + + sink.stop(); + Assert.assertEquals(11, src.flumeEvents.size()); + Assert.assertEquals(6, src.batchCount); + Assert.assertEquals(0, src.individualCount); + } + + @Test + public void testSslSinkWithNonSslServer() throws Exception { + Event event = EventBuilder.withBody("test event 1", Charsets.UTF_8); + src = new ThriftTestingSource(ThriftTestingSource.HandlerType.OK.name(), + port, ThriftRpcClient.COMPACT_PROTOCOL); + + Context context = new Context(); + context.put("hostname", hostname); + context.put("port", String.valueOf(port)); + context.put("ssl", String.valueOf(true)); + context.put("batch-size", String.valueOf(2)); + context.put("connect-timeout", String.valueOf(2000L)); + context.put("request-timeout", String.valueOf(3000L)); + context.put("truststore", "src/test/resources/truststorefile.jks"); + context.put("truststore-password", "password"); + context.put("trustmanager-type", TrustManagerFactory.getDefaultAlgorithm()); + + Configurables.configure(sink, context); + channel.start(); + sink.start(); + Assert.assertTrue(LifecycleController.waitForOneOf(sink, + LifecycleState.START_OR_ERROR, 5000)); + Transaction transaction = channel.getTransaction(); + transaction.begin(); + for (int i = 0; i < 11; i++) { + channel.put(event); + } + transaction.commit(); + transaction.close(); + + boolean failed = false; + try { + for (int i = 0; i < 6; i++) { + Sink.Status status = sink.process(); + failed = true; + } + } catch (EventDeliveryException ex) { + // This is correct + } + + sink.stop(); + Assert.assertTrue(LifecycleController.waitForOneOf(sink, + LifecycleState.STOP_OR_ERROR, 5000)); + if (failed) { + Assert.fail("SSL-enabled sink successfully connected to a non-SSL-enabled server, that's wrong."); + } + } + + @Test + public void testSslSinkWithNonTrustedCert() throws Exception { + Event event = EventBuilder.withBody("test event 1", Charsets.UTF_8); + src = new ThriftTestingSource(ThriftTestingSource.HandlerType.OK.name(), port, + ThriftRpcClient.COMPACT_PROTOCOL, "src/test/resources/keystorefile.jks", + "password", KeyManagerFactory.getDefaultAlgorithm(), "JKS"); + + Context context = new Context(); + context.put("hostname", hostname); + context.put("port", String.valueOf(port)); + context.put("ssl", String.valueOf(true)); + context.put("batch-size", String.valueOf(2)); + context.put("connect-timeout", String.valueOf(2000L)); + context.put("request-timeout", String.valueOf(3000L)); + + Configurables.configure(sink, context); + channel.start(); + sink.start(); + Assert.assertTrue(LifecycleController.waitForOneOf(sink, + LifecycleState.START_OR_ERROR, 5000)); + Transaction transaction = channel.getTransaction(); + transaction.begin(); + for (int i = 0; i < 11; i++) { + channel.put(event); + } + transaction.commit(); + transaction.close(); + + boolean failed = false; + try { + for (int i = 0; i < 6; i++) { + Sink.Status status = sink.process(); + failed = true; + } + } catch (EventDeliveryException ex) { + // This is correct + } + + sink.stop(); + Assert.assertTrue(LifecycleController.waitForOneOf(sink, + LifecycleState.STOP_OR_ERROR, 5000)); + if (failed) { + Assert.fail("SSL-enabled sink successfully connected to a server with an untrusted certificate when it should have failed"); + } + } } http://git-wip-us.apache.org/repos/asf/flume/blob/a103a677/flume-ng-core/src/test/java/org/apache/flume/source/TestThriftSource.java ---------------------------------------------------------------------- diff --git a/flume-ng-core/src/test/java/org/apache/flume/source/TestThriftSource.java b/flume-ng-core/src/test/java/org/apache/flume/source/TestThriftSource.java index 357965f..8b9fa23 100644 --- a/flume-ng-core/src/test/java/org/apache/flume/source/TestThriftSource.java +++ b/flume-ng-core/src/test/java/org/apache/flume/source/TestThriftSource.java @@ -35,11 +35,14 @@ import org.apache.flume.channel.MemoryChannel; import org.apache.flume.channel.ReplicatingChannelSelector; import org.apache.flume.conf.Configurables; import org.apache.flume.event.EventBuilder; + import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.KeyManagerFactory; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -87,6 +90,40 @@ public class TestThriftSource { source.setChannelProcessor(new ChannelProcessor(rcs)); } + @Test + public void testAppendSSL() throws Exception { + Properties sslprops = (Properties)props.clone(); + sslprops.put("ssl", "true"); + sslprops.put("truststore", "src/test/resources/truststorefile.jks"); + sslprops.put("truststore-password", "password"); + sslprops.put("trustmanager-type", TrustManagerFactory.getDefaultAlgorithm()); + client = RpcClientFactory.getThriftInstance(sslprops); + + Context context = new Context(); + channel.configure(context); + configureSource(); + context.put(ThriftSource.CONFIG_BIND, "0.0.0.0"); + context.put(ThriftSource.CONFIG_PORT, String.valueOf(port)); + context.put("ssl", "true"); + context.put("keystore", "src/test/resources/keystorefile.jks"); + context.put("keystore-password", "password"); + context.put("keymanager-type", KeyManagerFactory.getDefaultAlgorithm()); + Configurables.configure(source, context); + source.start(); + for(int i = 0; i < 30; i++) { + client.append(EventBuilder.withBody(String.valueOf(i).getBytes())); + } + Transaction transaction = channel.getTransaction(); + transaction.begin(); + + for (int i = 0; i < 30; i++) { + Event event = channel.take(); + Assert.assertNotNull(event); + Assert.assertEquals(String.valueOf(i), new String(event.getBody())); + } + transaction.commit(); + transaction.close(); + } @Test public void testAppend() throws Exception { http://git-wip-us.apache.org/repos/asf/flume/blob/a103a677/flume-ng-core/src/test/resources/keystorefile.jks ---------------------------------------------------------------------- diff --git a/flume-ng-core/src/test/resources/keystorefile.jks b/flume-ng-core/src/test/resources/keystorefile.jks new file mode 100644 index 0000000..20ac6a8 Binary files /dev/null and b/flume-ng-core/src/test/resources/keystorefile.jks differ http://git-wip-us.apache.org/repos/asf/flume/blob/a103a677/flume-ng-core/src/test/resources/truststorefile.jks ---------------------------------------------------------------------- diff --git a/flume-ng-core/src/test/resources/truststorefile.jks b/flume-ng-core/src/test/resources/truststorefile.jks new file mode 100644 index 0000000..a98c490 Binary files /dev/null and b/flume-ng-core/src/test/resources/truststorefile.jks differ http://git-wip-us.apache.org/repos/asf/flume/blob/a103a677/flume-ng-sdk/src/main/java/org/apache/flume/api/ThriftRpcClient.java ---------------------------------------------------------------------- diff --git a/flume-ng-sdk/src/main/java/org/apache/flume/api/ThriftRpcClient.java b/flume-ng-sdk/src/main/java/org/apache/flume/api/ThriftRpcClient.java index 6382a0e..4f75a2b 100644 --- a/flume-ng-sdk/src/main/java/org/apache/flume/api/ThriftRpcClient.java +++ b/flume-ng-sdk/src/main/java/org/apache/flume/api/ThriftRpcClient.java @@ -31,7 +31,15 @@ import org.apache.thrift.transport.TSocket; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSocket; +import javax.net.ssl.SSLSocketFactory; + +import java.io.FileInputStream; import java.nio.ByteBuffer; +import java.security.KeyStore; import java.util.ArrayList; import java.util.HashSet; import java.util.Iterator; @@ -41,6 +49,7 @@ import java.util.Properties; import java.util.Queue; import java.util.Random; import java.util.Set; +import java.util.Arrays; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; @@ -78,6 +87,15 @@ public class ThriftRpcClient extends AbstractRpcClient { private final Random random = new Random(); private String protocol; + private boolean enableSsl; + private String truststore; + private String truststorePassword; + private String truststoreType; + private String trustManagerType; + private static final String TRUSTMANAGER_TYPE = "trustmanager-type"; + private final List<String> excludeProtocols = new LinkedList<String>(); + + public ThriftRpcClient() { stateLock = new ReentrantLock(true); connState = State.INIT; @@ -311,6 +329,29 @@ public class ThriftRpcClient extends AbstractRpcClient { connectionPoolSize = RpcClientConfigurationConstants .DEFAULT_CONNECTION_POOL_SIZE; } + + enableSsl = Boolean.parseBoolean(properties.getProperty( + RpcClientConfigurationConstants.CONFIG_SSL)); + if(enableSsl) { + truststore = properties.getProperty( + RpcClientConfigurationConstants.CONFIG_TRUSTSTORE); + truststorePassword = properties.getProperty( + RpcClientConfigurationConstants.CONFIG_TRUSTSTORE_PASSWORD); + truststoreType = properties.getProperty( + RpcClientConfigurationConstants.CONFIG_TRUSTSTORE_TYPE, "JKS"); + trustManagerType = properties.getProperty( + TRUSTMANAGER_TYPE, TrustManagerFactory.getDefaultAlgorithm()); + String excludeProtocolsStr = properties.getProperty( + RpcClientConfigurationConstants.CONFIG_EXCLUDE_PROTOCOLS); + if (excludeProtocolsStr == null) { + excludeProtocols.add("SSLv3"); + } else { + excludeProtocols.addAll(Arrays.asList(excludeProtocolsStr.split(" "))); + if (!excludeProtocols.contains("SSLv3")) { + excludeProtocols.add("SSLv3"); + } + } + } connectionManager = new ConnectionPoolManager(connectionPoolSize); connState = State.READY; } catch (Throwable ex) { @@ -341,8 +382,27 @@ public class ThriftRpcClient extends AbstractRpcClient { private final int hashCode; public ClientWrapper() throws Exception{ - transport = new TFastFramedTransport(new TSocket(hostname, port)); - transport.open(); + TSocket tsocket; + if(enableSsl) { + // JDK6's factory doesn't appear to pass the protocol onto the Socket properly so we have + // to do some magic to make sure that happens. Not an issue in JDK7 + // Lifted from thrift-0.9.1 to make the SSLContext + SSLContext sslContext = createSSLContext(truststore, truststorePassword, trustManagerType, truststoreType); + + // Create the factory from it + SSLSocketFactory sslSockFactory = sslContext.getSocketFactory(); + + // Create the TSocket from that + tsocket = createSSLSocket(sslSockFactory, hostname, port, 120000, excludeProtocols); + } else { + tsocket = new TSocket(hostname, port); + } + + transport = new TFastFramedTransport(tsocket); + // The transport is already open for SSL as part of TSSLTransportFactory.getClientSocket + if(!transport.isOpen()) { + transport.open(); + } if (protocol.equals(BINARY_PROTOCOL)) { LOGGER.info("Using TBinaryProtocol"); client = new ThriftSourceProtocol.Client(new TBinaryProtocol @@ -456,4 +516,53 @@ public class ThriftRpcClient extends AbstractRpcClient { } } } + + /** + * Lifted from ACCUMULO-3318 - Lifted from TSSLTransportFactory in Thrift-0.9.1. The method to create a client socket with an SSLContextFactory object is not visibile to us. Have to use + * SslConnectionParams instead of TSSLTransportParameters because no getters exist on TSSLTransportParameters. + * + */ + private static SSLContext createSSLContext(String truststore, + String truststorePassword, String trustManagerType, + String truststoreType) throws FlumeException { + SSLContext ctx; + try { + ctx = SSLContext.getInstance("TLS"); + TrustManagerFactory tmf = null; + KeyManagerFactory kmf = null; + tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + KeyStore ts = null; + if (truststore != null && truststoreType != null) { + ts = KeyStore.getInstance(truststoreType); + ts.load(new FileInputStream(truststore), truststorePassword.toCharArray()); + tmf.init(ts); + } + + tmf.init(ts); + ctx.init(null, tmf.getTrustManagers(), null); + + } catch (Exception e) { + throw new FlumeException("Error creating the transport", e); + } + return ctx; + } + + private static TSocket createSSLSocket(SSLSocketFactory factory, String host, + int port, int timeout, List<String> excludeProtocols) throws FlumeException { + try { + SSLSocket socket = (SSLSocket) factory.createSocket(host, port); + socket.setSoTimeout(timeout); + + List<String> enabledProtocols = new ArrayList<String>(); + for (String protocol : socket.getEnabledProtocols()) { + if (!excludeProtocols.contains(protocol)) { + enabledProtocols.add(protocol); + } + } + socket.setEnabledProtocols(enabledProtocols.toArray(new String[0])); + return new TSocket(socket); + } catch (Exception e) { + throw new FlumeException("Could not connect to " + host + " on port " + port, e); + } + } } http://git-wip-us.apache.org/repos/asf/flume/blob/a103a677/flume-ng-sdk/src/test/java/org/apache/flume/api/ThriftTestingSource.java ---------------------------------------------------------------------- diff --git a/flume-ng-sdk/src/test/java/org/apache/flume/api/ThriftTestingSource.java b/flume-ng-sdk/src/test/java/org/apache/flume/api/ThriftTestingSource.java index 63d2fc3..70d2c1b 100644 --- a/flume-ng-sdk/src/test/java/org/apache/flume/api/ThriftTestingSource.java +++ b/flume-ng-sdk/src/test/java/org/apache/flume/api/ThriftTestingSource.java @@ -26,15 +26,19 @@ import org.apache.flume.thrift.ThriftFlumeEvent; import org.apache.flume.thrift.ThriftSourceProtocol; import org.apache.thrift.TException; import org.apache.thrift.protocol.TBinaryProtocol; -import org.apache.thrift.protocol.TBinaryProtocol.Factory; import org.apache.thrift.protocol.TCompactProtocol; -import org.apache.thrift.protocol.TProtocol; import org.apache.thrift.protocol.TProtocolFactory; import org.apache.thrift.server.THsHaServer; import org.apache.thrift.server.TServer; +import org.apache.thrift.transport.TSSLTransportFactory; import org.apache.thrift.transport.TNonblockingServerSocket; +import org.apache.thrift.transport.TServerSocket; import org.apache.thrift.transport.TNonblockingServerTransport; +import org.apache.thrift.transport.TFastFramedTransport; +import org.apache.thrift.transport.TServerTransport; +import java.lang.reflect.Method; +import java.net.InetAddress; import java.net.InetSocketAddress; import java.util.List; import java.util.Queue; @@ -184,9 +188,7 @@ public class ThriftTestingSource { } } - public ThriftTestingSource(String handlerName, int port, String protocol) throws Exception { - TNonblockingServerTransport serverTransport = new TNonblockingServerSocket(new - InetSocketAddress("0.0.0.0", port)); + private ThriftSourceProtocol.Iface getHandler(String handlerName) { ThriftSourceProtocol.Iface handler = null; if (handlerName.equals(HandlerType.OK.name())) { handler = new ThriftOKHandler(); @@ -201,6 +203,14 @@ public class ThriftTestingSource { } else if (handlerName.equals(HandlerType.ALTERNATE.name())) { handler = new ThriftAlternateHandler(); } + return handler; + } + + public ThriftTestingSource(String handlerName, int port, String protocol) throws Exception { + TNonblockingServerTransport serverTransport = new TNonblockingServerSocket(new + InetSocketAddress("0.0.0.0", port)); + ThriftSourceProtocol.Iface handler = getHandler(handlerName); + TProtocolFactory transportProtocolFactory = null; if (protocol != null && protocol == ThriftRpcClient.BINARY_PROTOCOL) { transportProtocolFactory = new TBinaryProtocol.Factory(); @@ -219,6 +229,49 @@ public class ThriftTestingSource { }); } + public ThriftTestingSource(String handlerName, int port, + String protocol, String keystore, + String keystorePassword, String keyManagerType, + String keystoreType) throws Exception { + TSSLTransportFactory.TSSLTransportParameters params = + new TSSLTransportFactory.TSSLTransportParameters(); + params.setKeyStore(keystore, keystorePassword, keyManagerType, keystoreType); + + TServerSocket serverTransport = TSSLTransportFactory.getServerSocket( + port, 10000, InetAddress.getByName("0.0.0.0"), params); + + ThriftSourceProtocol.Iface handler = getHandler(handlerName); + + Class serverClass = Class.forName("org.apache.thrift" + + ".server.TThreadPoolServer"); + Class argsClass = Class.forName("org.apache.thrift.server" + + ".TThreadPoolServer$Args"); + TServer.AbstractServerArgs args = (TServer.AbstractServerArgs) argsClass + .getConstructor(TServerTransport.class) + .newInstance(serverTransport); + Method m = argsClass.getDeclaredMethod("maxWorkerThreads", int.class); + m.invoke(args, Integer.MAX_VALUE); + TProtocolFactory transportProtocolFactory = null; + if (protocol != null && protocol == ThriftRpcClient.BINARY_PROTOCOL) { + transportProtocolFactory = new TBinaryProtocol.Factory(); + } else { + transportProtocolFactory = new TCompactProtocol.Factory(); + } + args.protocolFactory(transportProtocolFactory); + args.inputTransportFactory(new TFastFramedTransport.Factory()); + args.outputTransportFactory(new TFastFramedTransport.Factory()); + args.processor(new ThriftSourceProtocol + .Processor<ThriftSourceProtocol.Iface>(handler)); + server = (TServer) serverClass.getConstructor(argsClass).newInstance + (args); + Executors.newSingleThreadExecutor().submit(new Runnable() { + @Override + public void run() { + server.serve(); + } + }); + } + public enum HandlerType { OK, FAIL,
