http://git-wip-us.apache.org/repos/asf/aurora/blob/06ddaadb/commons/src/main/java/org/apache/aurora/common/thrift/Thrift.java ---------------------------------------------------------------------- diff --git a/commons/src/main/java/org/apache/aurora/common/thrift/Thrift.java b/commons/src/main/java/org/apache/aurora/common/thrift/Thrift.java new file mode 100644 index 0000000..b36b46e --- /dev/null +++ b/commons/src/main/java/org/apache/aurora/common/thrift/Thrift.java @@ -0,0 +1,390 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.aurora.common.thrift; + +import java.io.IOException; +import java.lang.reflect.InvocationHandler; +import java.lang.reflect.Method; +import java.lang.reflect.Proxy; +import java.net.InetSocketAddress; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.RejectedExecutionException; + +import com.google.common.base.Function; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Lists; +import com.google.common.util.concurrent.ThreadFactoryBuilder; + +import org.apache.aurora.common.thrift.callers.DebugCaller; +import org.apache.thrift.async.AsyncMethodCallback; +import org.apache.thrift.transport.TTransport; +import org.apache.thrift.transport.TTransportException; + +import org.apache.aurora.common.base.MorePreconditions; +import org.apache.aurora.common.net.loadbalancing.RequestTracker; +import org.apache.aurora.common.net.pool.Connection; +import org.apache.aurora.common.net.pool.ObjectPool; +import org.apache.aurora.common.quantity.Amount; +import org.apache.aurora.common.quantity.Time; +import org.apache.aurora.common.stats.StatsProvider; +import org.apache.aurora.common.thrift.callers.Caller; +import org.apache.aurora.common.thrift.callers.DeadlineCaller; +import org.apache.aurora.common.thrift.callers.RetryingCaller; +import org.apache.aurora.common.thrift.callers.StatTrackingCaller; +import org.apache.aurora.common.thrift.callers.ThriftCaller; + +/** + * A generic thrift client that handles reconnection in the case of protocol errors, automatic + * retries, call deadlines and call statistics tracking. This class aims for behavior compatible + * with the <a href="http://github.com/fauna/thrift_client">generic ruby thrift client</a>. + * + * <p>In order to enforce call deadlines for synchronous clients, this class uses an + * {@link java.util.concurrent.ExecutorService}. 
If a custom executor is supplied, it should throw + * a subclass of {@link RejectedExecutionException} to signal thread resource exhaustion, in which + * case the client will fail fast and propagate the event as a {@link TResourceExhaustedException}. + * + * TODO(William Farner): Before open sourcing, look into changing the current model of wrapped proxies + * to use a single proxy and wrapped functions for decorators. + * + * @author John Sirois + */ +public class Thrift<T> { + + /** + * The default thrift call configuration used if none is specified. + * + * Specifies the following settings: + * <ul> + * <li>global call timeout: 1 second + * <li>call retries: 0 + * <li>retryable exceptions: TTransportException (network exceptions including socket timeouts) + * <li>wait for connections: true + * <li>debug: false + * </ul> + */ + public static final Config DEFAULT_CONFIG = Config.builder() + .withRequestTimeout(Amount.of(1L, Time.SECONDS)) + .noRetries() + .retryOn(TTransportException.class) // if maxRetries is set non-zero + .create(); + + /** + * The default thrift call configuration used for an async client if none is specified. + * + * Specifies the following settings: + * <ul> + * <li>global call timeout: none + * <li>call retries: 0 + * <li>retryable exceptions: IOException, TTransportException + * (network exceptions but not timeouts) + * <li>wait for connections: true + * <li>debug: false + * </ul> + */ + @SuppressWarnings("unchecked") + public static final Config DEFAULT_ASYNC_CONFIG = Config.builder(DEFAULT_CONFIG) + .withRequestTimeout(Amount.of(0L, Time.SECONDS)) + .noRetries() + .retryOn(ImmutableSet.<Class<? 
extends Exception>>builder() + .add(IOException.class) + .add(TTransportException.class).build()) // if maxRetries is set non-zero + .create(); + + private final Config defaultConfig; + private final ExecutorService executorService; + private final ObjectPool<Connection<TTransport, InetSocketAddress>> connectionPool; + private final RequestTracker<InetSocketAddress> requestTracker; + private final String serviceName; + private final Class<T> serviceInterface; + private final Function<TTransport, T> clientFactory; + private final boolean async; + private final boolean withSsl; + + /** + * Constructs an instance with the {@link #DEFAULT_CONFIG}, cached thread pool + * {@link ExecutorService}, and synchronous calls. + * + * @see #Thrift(Config, ExecutorService, ObjectPool, RequestTracker , String, Class, Function, + * boolean, boolean) + */ + public Thrift(ObjectPool<Connection<TTransport, InetSocketAddress>> connectionPool, + RequestTracker<InetSocketAddress> requestTracker, + String serviceName, Class<T> serviceInterface, Function<TTransport, T> clientFactory) { + + this(DEFAULT_CONFIG, connectionPool, requestTracker, serviceName, serviceInterface, + clientFactory, false, false); + } + + /** + * Constructs an instance with the {@link #DEFAULT_CONFIG} and cached thread pool + * {@link ExecutorService}. + * + * @see #Thrift(Config, ExecutorService, ObjectPool, RequestTracker , String, Class, Function, + * boolean, boolean) + */ + public Thrift(ObjectPool<Connection<TTransport, InetSocketAddress>> connectionPool, + RequestTracker<InetSocketAddress> requestTracker, + String serviceName, Class<T> serviceInterface, Function<TTransport, T> clientFactory, + boolean async) { + + this(getConfig(async), connectionPool, requestTracker, serviceName, + serviceInterface, clientFactory, async, false); + } + + /** + * Constructs an instance with the {@link #DEFAULT_CONFIG} and cached thread pool + * {@link ExecutorService}. 
+ * + * @see #Thrift(Config, ExecutorService, ObjectPool, RequestTracker , String, Class, Function, + * boolean, boolean) + */ + public Thrift(ObjectPool<Connection<TTransport, InetSocketAddress>> connectionPool, + RequestTracker<InetSocketAddress> requestTracker, + String serviceName, Class<T> serviceInterface, Function<TTransport, T> clientFactory, + boolean async, boolean ssl) { + + this(getConfig(async), connectionPool, requestTracker, serviceName, + serviceInterface, clientFactory, async, ssl); + } + + /** + * Constructs an instance with a cached thread pool {@link ExecutorService}. + * + * @see #Thrift(Config, ExecutorService, ObjectPool, RequestTracker , String, Class, Function, + * boolean, boolean) + */ + public Thrift(Config config, ObjectPool<Connection<TTransport, InetSocketAddress>> connectionPool, + RequestTracker<InetSocketAddress> requestTracker, + String serviceName, Class<T> serviceInterface, Function<TTransport, T> clientFactory, + boolean async, boolean ssl) { + + this(config, + Executors.newCachedThreadPool( + new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat("Thrift["+ serviceName +"][%d]") + .build()), + connectionPool, requestTracker, serviceName, serviceInterface, clientFactory, async, ssl); + } + + /** + * Constructs an instance with the {@link #DEFAULT_CONFIG}. + * + * @see #Thrift(Config, ExecutorService, ObjectPool, RequestTracker , String, Class, Function, + * boolean, boolean) + */ + public Thrift(ExecutorService executorService, + ObjectPool<Connection<TTransport, InetSocketAddress>> connectionPool, + RequestTracker<InetSocketAddress> requestTracker, + String serviceName, Class<T> serviceInterface, Function<TTransport, T> clientFactory, + boolean async, boolean ssl) { + + this(getConfig(async), executorService, connectionPool, requestTracker, serviceName, + serviceInterface, clientFactory, async, ssl); + } + + private static Config getConfig(boolean async) { + return async ? 
DEFAULT_ASYNC_CONFIG : DEFAULT_CONFIG; + } + + /** + * Constructs a new Thrift factory for creating clients that make calls to a particular thrift + * service. + * + * <p>Note that the combination of {@code config} and {@code connectionPool} need to be chosen + * with care depending on usage of the generated thrift clients. In particular, if configured + * to not wait for connections, the {@code connectionPool} ought to be warmed up with a set of + * connections or else be actively building connections in the background. + * + * <p>TODO(John Sirois): consider adding an method to ObjectPool that would allow Thrift to handle + * this case by pro-actively warming the pool. + * + * @param config the default configuration to use for all thrift calls; also the configuration all + * {@link ClientBuilder}s start with + * @param executorService for invoking calls with a specified deadline + * @param connectionPool the source for thrift connections + * @param serviceName a /vars friendly name identifying the service clients will connect to + * @param serviceInterface the thrift compiler generate interface class for the remote service + * (Iface) + * @param clientFactory a function that can generate a concrete thrift client for the given + * {@code serviceInterface} + * @param async enable asynchronous API + * @param ssl enable TLS handshaking for Thrift calls + */ + public Thrift(Config config, ExecutorService executorService, + ObjectPool<Connection<TTransport, InetSocketAddress>> connectionPool, + RequestTracker<InetSocketAddress> requestTracker, String serviceName, + Class<T> serviceInterface, Function<TTransport, T> clientFactory, boolean async, boolean ssl) { + + defaultConfig = Preconditions.checkNotNull(config); + this.executorService = Preconditions.checkNotNull(executorService); + this.connectionPool = Preconditions.checkNotNull(connectionPool); + this.requestTracker = Preconditions.checkNotNull(requestTracker); + this.serviceName = 
MorePreconditions.checkNotBlank(serviceName); + this.serviceInterface = checkServiceInterface(serviceInterface); + this.clientFactory = Preconditions.checkNotNull(clientFactory); + this.async = async; + this.withSsl = ssl; + } + + static <I> Class<I> checkServiceInterface(Class<I> serviceInterface) { + Preconditions.checkNotNull(serviceInterface); + Preconditions.checkArgument(serviceInterface.isInterface(), + "%s must be a thrift service interface", serviceInterface); + return serviceInterface; + } + + /** + * Closes any open connections and prepares this thrift client for graceful shutdown. Any thrift + * client proxies returned from {@link #create()} will become invalid. + */ + public void close() { + connectionPool.close(); + executorService.shutdown(); + } + + /** + * A builder class that allows modifications of call behavior to be made for a given Thrift + * client. Note that in the case of conflicting configuration calls, the last call wins. So, + * for example, the following sequence would result in all calls being subject to a 5 second + * global deadline: + * <code> + * builder.blocking().withDeadline(5, TimeUnit.SECONDS).create() + * </code> + * + * @see Config + */ + public final class ClientBuilder extends Config.AbstractBuilder<ClientBuilder> { + private ClientBuilder(Config template) { + super(template); + } + + @Override + protected ClientBuilder getThis() { + return this; + } + + /** + * Creates a new client using the built up configuration changes. + */ + public T create() { + return createClient(getConfig()); + } + } + + /** + * Creates a new thrift client builder that inherits this Thrift instance's default configuration. + * This is useful for customizing a client for a particular thrift call that makes sense to treat + * differently from the rest of the calls to a given service. + */ + public ClientBuilder builder() { + return builder(defaultConfig); + } + + /** + * Creates a new thrift client builder that inherits the given configuration. 
+ * This is useful for customizing a client for a particular thrift call that makes sense to treat + * differently from the rest of the calls to a given service. + */ + public ClientBuilder builder(Config config) { + Preconditions.checkNotNull(config); + return new ClientBuilder(config); + } + + /** + * Creates a new client using the default configuration specified for this Thrift instance. + */ + public T create() { + return createClient(defaultConfig); + } + + private T createClient(Config config) { + StatsProvider statsProvider = config.getStatsProvider(); + + // lease/call/[invalidate]/release + boolean debug = config.isDebug(); + + Caller decorated = new ThriftCaller<T>(connectionPool, requestTracker, clientFactory, + config.getConnectTimeout(), debug); + + // [retry] + if (config.getMaxRetries() > 0) { + decorated = new RetryingCaller(decorated, async, statsProvider, serviceName, + config.getMaxRetries(), config.getRetryableExceptions(), debug); + } + + // [deadline] + if (config.getRequestTimeout().getValue() > 0) { + Preconditions.checkArgument(!async, + "Request deadlines may not be used with an asynchronous client."); + + decorated = new DeadlineCaller(decorated, async, executorService, config.getRequestTimeout()); + } + + // [debug] + if (debug) { + decorated = new DebugCaller(decorated, async); + } + + // stats + if (config.enableStats()) { + decorated = new StatTrackingCaller(decorated, async, statsProvider, serviceName); + } + + final Caller caller = decorated; + + final InvocationHandler invocationHandler = new InvocationHandler() { + @Override + public Object invoke(Object o, Method method, Object[] args) throws Throwable { + AsyncMethodCallback callback = null; + if (args != null && async) { + List<Object> argsList = Lists.newArrayList(args); + callback = extractCallback(argsList); + args = argsList.toArray(); + } + + return caller.call(method, args, callback, null); + } + }; + + @SuppressWarnings("unchecked") + T instance = (T) 
Proxy.newProxyInstance(serviceInterface.getClassLoader(), + new Class<?>[] {serviceInterface}, invocationHandler); + return instance; + } + + /** + * Verifies that the final argument in a list of objects is a fully-formed + * {@link AsyncMethodCallback} and extracts it, removing it from the argument list. + * + * @param args Argument list to remove the callback from. + * @return The callback extracted from {@code args}. + */ + private static AsyncMethodCallback extractCallback(List<Object> args) { + // TODO(William Farner): Check all interface methods when building the Thrift client + // and verify that last arguments are all callbacks...this saves us from checking + // each time. + + // Check that the last argument is a callback. + Preconditions.checkArgument(args.size() > 0); + Object lastArg = args.get(args.size() - 1); + Preconditions.checkArgument(lastArg instanceof AsyncMethodCallback, + "Last argument of an async thrift call is expected to be of type AsyncMethodCallback."); + + return (AsyncMethodCallback) args.remove(args.size() - 1); + } +}
http://git-wip-us.apache.org/repos/asf/aurora/blob/06ddaadb/commons/src/main/java/org/apache/aurora/common/thrift/ThriftConnectionFactory.java ---------------------------------------------------------------------- diff --git a/commons/src/main/java/org/apache/aurora/common/thrift/ThriftConnectionFactory.java b/commons/src/main/java/org/apache/aurora/common/thrift/ThriftConnectionFactory.java new file mode 100644 index 0000000..8c302d3 --- /dev/null +++ b/commons/src/main/java/org/apache/aurora/common/thrift/ThriftConnectionFactory.java @@ -0,0 +1,366 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.aurora.common.thrift; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import org.apache.aurora.common.base.Closure; +import org.apache.aurora.common.base.Closures; +import org.apache.aurora.common.base.MorePreconditions; +import org.apache.aurora.common.net.pool.Connection; +import org.apache.aurora.common.net.pool.ConnectionFactory; +import org.apache.aurora.common.quantity.Amount; +import org.apache.aurora.common.quantity.Time; +import org.apache.thrift.transport.TFramedTransport; +import org.apache.thrift.transport.TNonblockingSocket; +import org.apache.thrift.transport.TSocket; +import org.apache.thrift.transport.TTransport; +import org.apache.thrift.transport.TTransportException; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; + +import javax.net.ssl.SSLSocket; +import javax.net.ssl.SSLSocketFactory; + +/** + * A connection factory for thrift transport connections to a given host. This connection factory + * is lazy and will only create a configured maximum number of active connections - where a + * {@link ConnectionFactory#create(Amount) created} connection that has + * not been {@link #destroy destroyed} is considered active. + * + * @author John Sirois + */ +public class ThriftConnectionFactory + implements ConnectionFactory<Connection<TTransport, InetSocketAddress>> { + + public enum TransportType { + BLOCKING, FRAMED, NONBLOCKING; + + /** + * Async clients implicitly use a framed transport, requiring the server they connect to to do + * the same. 
This prevents specifying a nonblocking client without a framed transport, since + * that is not compatible with thrift and would simply cause the client to blow up when making a + * request. Instead, you must explicitly say useFramedTransport(true) for any buildAsync(). + */ + public static TransportType get(boolean framedTransport, boolean nonblocking) { + if (nonblocking) { + Preconditions.checkArgument(framedTransport, + "nonblocking client requires a server running framed transport"); + return NONBLOCKING; + } + + return framedTransport ? FRAMED : BLOCKING; + } + } + + private static InetSocketAddress asEndpoint(String host, int port) { + MorePreconditions.checkNotBlank(host); + Preconditions.checkArgument(port > 0); + return InetSocketAddress.createUnresolved(host, port); + } + + private InetSocketAddress endpoint; + private final int maxConnections; + private final TransportType transportType; + private final Amount<Long, Time> socketTimeout; + private final Closure<Connection<TTransport, InetSocketAddress>> postCreateCallback; + private boolean sslTransport = false; + + private final Set<Connection<TTransport, InetSocketAddress>> activeConnections = + Sets.newSetFromMap( + Maps.<Connection<TTransport, InetSocketAddress>, Boolean>newIdentityHashMap()); + private volatile int lastActiveConnectionsSize = 0; + + private final Lock activeConnectionsWriteLock = new ReentrantLock(true); + + /** + * Creates a thrift connection factory with a plain socket (non-framed transport). + * This is the same as calling {@link #ThriftConnectionFactory(String, int, int, boolean)} with + * {@code framedTransport} set to {@code false}. + * + * @param host Host to connect to. + * @param port Port to connect on. + * @param maxConnections Maximum number of connections for this host:port. + */ + public ThriftConnectionFactory(String host, int port, int maxConnections) { + this(host, port, maxConnections, TransportType.BLOCKING); + } + + /** + * Creates a thrift connection factory. 
+ * If {@code framedTransport} is set to {@code true}, {@link TFramedTransport} will be used, + * otherwise a raw {@link TSocket} will be used. + * + * @param host Host to connect to. + * @param port Port to connect on. + * @param maxConnections Maximum number of connections for this host:port. + * @param framedTransport Whether to use framed or blocking transport. + */ + public ThriftConnectionFactory(String host, int port, int maxConnections, + boolean framedTransport) { + + this(asEndpoint(host, port), maxConnections, TransportType.get(framedTransport, false)); + } + + /** + * Creates a thrift connection factory. + * If {@code framedTransport} is set to {@code true}, {@link TFramedTransport} will be used, + * otherwise a raw {@link TSocket} will be used. + * + * @param endpoint Endpoint to connect to. + * @param maxConnections Maximum number of connections for this host:port. + * @param framedTransport Whether to use framed or blocking transport. + */ + public ThriftConnectionFactory(InetSocketAddress endpoint, int maxConnections, + boolean framedTransport) { + + this(endpoint, maxConnections, TransportType.get(framedTransport, false)); + } + + /** + * Creates a thrift connection factory. + * If {@code framedTransport} is set to {@code true}, {@link TFramedTransport} will be used, + * otherwise a raw {@link TSocket} will be used. + * If {@code nonblocking} is set to {@code true}, {@link TNonblockingSocket} will be used, + * otherwise a raw {@link TSocket} will be used. + * Timeouts are ignored when nonblocking transport is used. + * + * @param host Host to connect to. + * @param port Port to connect on. + * @param maxConnections Maximum number of connections for this host:port. + * @param transportType Whether to use normal blocking, framed blocking, or non-blocking + * (implicitly framed) transport. 
+ */ + public ThriftConnectionFactory(String host, int port, int maxConnections, + TransportType transportType) { + this(host, port, maxConnections, transportType, null); + } + + /** + * Creates a thrift connection factory. + * If {@code framedTransport} is set to {@code true}, {@link TFramedTransport} will be used, + * otherwise a raw {@link TSocket} will be used. + * If {@code nonblocking} is set to {@code true}, {@link TNonblockingSocket} will be used, + * otherwise a raw {@link TSocket} will be used. + * Timeouts are ignored when nonblocking transport is used. + * + * @param host Host to connect to. + * @param port Port to connect on. + * @param maxConnections Maximum number of connections for this host:port. + * @param transportType Whether to use normal blocking, framed blocking, or non-blocking + * (implicitly framed) transport. + * @param socketTimeout timeout on thrift i/o operations, or null to default to connectTimeout o + * the blocking client. + */ + public ThriftConnectionFactory(String host, int port, int maxConnections, + TransportType transportType, Amount<Long, Time> socketTimeout) { + this(asEndpoint(host, port), maxConnections, transportType, socketTimeout); + } + + public ThriftConnectionFactory(InetSocketAddress endpoint, int maxConnections, + TransportType transportType) { + this(endpoint, maxConnections, transportType, null); + } + + /** + * Creates a thrift connection factory. + * If {@code framedTransport} is set to {@code true}, {@link TFramedTransport} will be used, + * otherwise a raw {@link TSocket} will be used. + * If {@code nonblocking} is set to {@code true}, {@link TNonblockingSocket} will be used, + * otherwise a raw {@link TSocket} will be used. + * Timeouts are ignored when nonblocking transport is used. + * + * @param endpoint Endpoint to connect to. + * @param maxConnections Maximum number of connections for this host:port. 
+ * @param transportType Whether to use normal blocking, framed blocking, or non-blocking + * (implicitly framed) transport. + * @param socketTimeout timeout on thrift i/o operations, or null to default to connectTimeout o + * the blocking client. + */ + public ThriftConnectionFactory(InetSocketAddress endpoint, int maxConnections, + TransportType transportType, Amount<Long, Time> socketTimeout) { + this(endpoint, maxConnections, transportType, socketTimeout, + Closures.<Connection<TTransport, InetSocketAddress>>noop(), false); + } + + public ThriftConnectionFactory(InetSocketAddress endpoint, int maxConnections, + TransportType transportType, Amount<Long, Time> socketTimeout, + Closure<Connection<TTransport, InetSocketAddress>> postCreateCallback, + boolean sslTransport) { + Preconditions.checkArgument(maxConnections > 0, "maxConnections must be at least 1"); + if (socketTimeout != null) { + Preconditions.checkArgument(socketTimeout.as(Time.MILLISECONDS) >= 0); + } + + this.endpoint = Preconditions.checkNotNull(endpoint); + this.maxConnections = maxConnections; + this.transportType = transportType; + this.socketTimeout = socketTimeout; + this.postCreateCallback = Preconditions.checkNotNull(postCreateCallback); + this.sslTransport = sslTransport; + } + + @Override + public boolean mightCreate() { + return lastActiveConnectionsSize < maxConnections; + } + + /** + * FIXME: shouldn't this throw TimeoutException instead of returning null + * in the timeout cases as per the ConnectionFactory.create javadoc? 
+ */ + @Override + public Connection<TTransport, InetSocketAddress> create(Amount<Long, Time> timeout) + throws TTransportException, IOException { + + Preconditions.checkNotNull(timeout); + if (timeout.getValue() == 0) { + return create(); + } + + try { + long timeRemainingNs = timeout.as(Time.NANOSECONDS); + long start = System.nanoTime(); + if(activeConnectionsWriteLock.tryLock(timeRemainingNs, TimeUnit.NANOSECONDS)) { + try { + if (!willCreateSafe()) { + return null; + } + + timeRemainingNs -= (System.nanoTime() - start); + + return createConnection((int) TimeUnit.NANOSECONDS.toMillis(timeRemainingNs)); + } finally { + activeConnectionsWriteLock.unlock(); + } + } else { + return null; + } + } catch (InterruptedException e) { + return null; + } + } + + private Connection<TTransport, InetSocketAddress> create() + throws TTransportException, IOException { + activeConnectionsWriteLock.lock(); + try { + if (!willCreateSafe()) { + return null; + } + + return createConnection(0); + } finally { + activeConnectionsWriteLock.unlock(); + } + } + + private Connection<TTransport, InetSocketAddress> createConnection(int timeoutMillis) + throws TTransportException, IOException { + TTransport transport = createTransport(timeoutMillis); + if (transport == null) { + return null; + } + + Connection<TTransport, InetSocketAddress> connection = + new TTransportConnection(transport, endpoint); + postCreateCallback.execute(connection); + activeConnections.add(connection); + lastActiveConnectionsSize = activeConnections.size(); + return connection; + } + + private boolean willCreateSafe() { + return activeConnections.size() < maxConnections; + } + + @VisibleForTesting + TTransport createTransport(int timeoutMillis) throws TTransportException, IOException { + TSocket socket = null; + if (transportType != TransportType.NONBLOCKING) { + // can't do a nonblocking create on a blocking transport + if (timeoutMillis <= 0) { + return null; + } + + if (sslTransport) { + SSLSocketFactory factory 
= (SSLSocketFactory) SSLSocketFactory.getDefault(); + SSLSocket ssl_socket = (SSLSocket) factory.createSocket(endpoint.getHostName(), endpoint.getPort()); + ssl_socket.setSoTimeout(timeoutMillis); + return new TSocket(ssl_socket); + } else { + socket = new TSocket(endpoint.getHostName(), endpoint.getPort(), timeoutMillis); + } + } + + try { + switch (transportType) { + case BLOCKING: + socket.open(); + setSocketTimeout(socket); + return socket; + case FRAMED: + TFramedTransport transport = new TFramedTransport(socket); + transport.open(); + setSocketTimeout(socket); + return transport; + case NONBLOCKING: + try { + return new TNonblockingSocket(endpoint.getHostName(), endpoint.getPort()); + } catch (IOException e) { + throw new IOException("Failed to create non-blocking transport to " + endpoint, e); + } + } + } catch (TTransportException e) { + throw new TTransportException("Failed to create transport to " + endpoint, e); + } + + throw new IllegalArgumentException("unknown transport type " + transportType); + } + + private void setSocketTimeout(TSocket socket) { + if (socketTimeout != null) { + socket.setTimeout(socketTimeout.as(Time.MILLISECONDS).intValue()); + } + } + + @Override + public void destroy(Connection<TTransport, InetSocketAddress> connection) { + activeConnectionsWriteLock.lock(); + try { + boolean wasActiveConnection = activeConnections.remove(connection); + Preconditions.checkArgument(wasActiveConnection, + "connection %s not created by this factory", connection); + lastActiveConnectionsSize = activeConnections.size(); + } finally { + activeConnectionsWriteLock.unlock(); + } + + // We close the connection outside the critical section which means we may have more connections + // "active" (open) than maxConnections for a very short time + connection.close(); + } + + @Override + public String toString() { + return String.format("%s[%s]", getClass().getSimpleName(), endpoint); + } +} 
http://git-wip-us.apache.org/repos/asf/aurora/blob/06ddaadb/commons/src/main/java/org/apache/aurora/common/thrift/ThriftException.java ---------------------------------------------------------------------- diff --git a/commons/src/main/java/org/apache/aurora/common/thrift/ThriftException.java b/commons/src/main/java/org/apache/aurora/common/thrift/ThriftException.java new file mode 100644 index 0000000..27e9f5e --- /dev/null +++ b/commons/src/main/java/org/apache/aurora/common/thrift/ThriftException.java @@ -0,0 +1,26 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.aurora.common.thrift; + +/** + * Checked exception used to wrap exceptions caught during thrift calls, optionally + * retaining the underlying exception as its cause.
+ */ +public class ThriftException extends Exception { + /** Creates an exception with the given detail message and no cause. */ public ThriftException(String message) { + super(message); + } + /** Creates an exception with the given detail message, wrapping the underlying cause {@code t}. */ public ThriftException(String message, Throwable t) { + super(message, t); + } +} http://git-wip-us.apache.org/repos/asf/aurora/blob/06ddaadb/commons/src/main/java/org/apache/aurora/common/thrift/ThriftFactory.java ---------------------------------------------------------------------- diff --git a/commons/src/main/java/org/apache/aurora/common/thrift/ThriftFactory.java b/commons/src/main/java/org/apache/aurora/common/thrift/ThriftFactory.java new file mode 100644 index 0000000..75c58e2 --- /dev/null +++ b/commons/src/main/java/org/apache/aurora/common/thrift/ThriftFactory.java @@ -0,0 +1,653 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.aurora.common.thrift; + +import java.io.IOException; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.net.InetSocketAddress; +import java.util.Collection; +import java.util.NoSuchElementException; +import java.util.Set; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.RejectedExecutionHandler; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.logging.Logger; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Function; +import com.google.common.base.Optional; +import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; +import com.google.common.util.concurrent.ThreadFactoryBuilder; + +import org.apache.thrift.async.TAsyncClient; +import org.apache.thrift.async.TAsyncClientManager; +import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.protocol.TProtocol; +import org.apache.thrift.protocol.TProtocolFactory; +import org.apache.thrift.transport.TNonblockingTransport; +import org.apache.thrift.transport.TTransport; + +import org.apache.aurora.common.base.Closure; +import org.apache.aurora.common.base.Closures; +import org.apache.aurora.common.base.MorePreconditions; +import org.apache.aurora.common.net.loadbalancing.LeastConnectedStrategy; +import org.apache.aurora.common.net.loadbalancing.LoadBalancer; +import org.apache.aurora.common.net.loadbalancing.LoadBalancerImpl; +import org.apache.aurora.common.net.loadbalancing.LoadBalancingStrategy; +import org.apache.aurora.common.net.loadbalancing.MarkDeadStrategyWithHostCheck; +import 
org.apache.aurora.common.net.loadbalancing.TrafficMonitorAdapter; +import org.apache.aurora.common.net.monitoring.TrafficMonitor; +import org.apache.aurora.common.net.pool.Connection; +import org.apache.aurora.common.net.pool.ConnectionPool; +import org.apache.aurora.common.net.pool.DynamicHostSet; +import org.apache.aurora.common.net.pool.DynamicPool; +import org.apache.aurora.common.net.pool.MetaPool; +import org.apache.aurora.common.net.pool.ObjectPool; +import org.apache.aurora.common.quantity.Amount; +import org.apache.aurora.common.quantity.Time; +import org.apache.aurora.common.stats.Stats; +import org.apache.aurora.common.stats.StatsProvider; +import org.apache.aurora.common.thrift.ThriftConnectionFactory.TransportType; +import org.apache.aurora.common.util.BackoffDecider; +import org.apache.aurora.common.util.BackoffStrategy; +import org.apache.aurora.common.util.TruncatedBinaryBackoff; +import org.apache.aurora.common.util.concurrent.ForwardingExecutorService; + +/** + * A utility that provides convenience methods to build common {@link Thrift}s. + * + * The thrift factory allows you to specify parameters that define how the client connects to + * and communicates with servers, such as the transport type, connection settings, and load + * balancing. Request-level settings like sync/async and retries should be set on the + * {@link Thrift} instance that this factory will create. + * + * The factory will attempt to provide reasonable defaults to allow the caller to minimize the + * amount of necessary configuration. 
Currently, the default behavior includes: + * + * <ul> + * <li> A test lease/release for each host will be performed every second + * {@link #withDeadConnectionRestoreInterval(Amount)} + * <li> At most 50 connections will be established to each host + * {@link #withMaxConnectionsPerEndpoint(int)} + * <li> Unframed transport {@link #useFramedTransport(boolean)} + * <li> A load balancing strategy that will mark hosts dead and prefer least-connected hosts. + * Hosts are marked dead if the most recent connection attempt was a failure or else based on + * the windowed error rate of attempted RPCs. If the error rate for a connected host exceeds + * 20% over the last second, the host will be disabled for 2 seconds ascending up to 10 seconds + * if the elevated error rate persists. + * {@link #withLoadBalancingStrategy(LoadBalancingStrategy)} + * <li> Statistics are reported through {@link Stats} + * {@link #withStatsProvider(StatsProvider)} + * <li> A service name matching the thrift interface name {@link #withServiceName(String)} + * </ul> + * + * @author John Sirois + */ +public class ThriftFactory<T> { + private static final Amount<Long,Time> DEFAULT_DEAD_TARGET_RESTORE_INTERVAL = + Amount.of(1L, Time.SECONDS); + + private static final int DEFAULT_MAX_CONNECTIONS_PER_ENDPOINT = 50; + + private Class<T> serviceInterface; + private Function<TTransport, T> clientFactory; + private int maxConnectionsPerEndpoint; + private Amount<Long,Time> connectionRestoreInterval; + private boolean framedTransport; + private LoadBalancingStrategy<InetSocketAddress> loadBalancingStrategy = null; + private final TrafficMonitor<InetSocketAddress> monitor; + private Amount<Long,Time> socketTimeout = null; + private Closure<Connection<TTransport, InetSocketAddress>> postCreateCallback = Closures.noop(); + private StatsProvider statsProvider = Stats.STATS_PROVIDER; + private Optional<String> endpointName = Optional.absent(); + private String serviceName; + private boolean sslTransport; + + 
public static <T> ThriftFactory<T> create(Class<T> serviceInterface) { + return new ThriftFactory<T>(serviceInterface); + } + + /** + * Creates a default factory that will use unframed blocking transport. + * + * @param serviceInterface The interface of the thrift service to make a client for. + */ + private ThriftFactory(Class<T> serviceInterface) { + this.serviceInterface = Thrift.checkServiceInterface(serviceInterface); + this.maxConnectionsPerEndpoint = DEFAULT_MAX_CONNECTIONS_PER_ENDPOINT; + this.connectionRestoreInterval = DEFAULT_DEAD_TARGET_RESTORE_INTERVAL; + this.framedTransport = false; + this.monitor = new TrafficMonitor<InetSocketAddress>(serviceInterface.getName()); + this.serviceName = serviceInterface.getEnclosingClass().getSimpleName(); + this.sslTransport = false; + } + + private void checkBaseState() { + Preconditions.checkArgument(maxConnectionsPerEndpoint > 0, + "Must allow at least 1 connection per endpoint; %s specified", maxConnectionsPerEndpoint); + } + + public TrafficMonitor<InetSocketAddress> getMonitor() { + return monitor; + } + + /** + * Creates the thrift client, and initializes connection pools. + * + * @param backends Backends to connect to. + * @return A new thrift client. + */ + public Thrift<T> build(Set<InetSocketAddress> backends) { + checkBaseState(); + MorePreconditions.checkNotBlank(backends); + + ManagedThreadPool managedThreadPool = createManagedThreadpool(backends.size()); + LoadBalancer<InetSocketAddress> loadBalancer = createLoadBalancer(); + Function<TTransport, T> clientFactory = getClientFactory(); + + ObjectPool<Connection<TTransport, InetSocketAddress>> connectionPool = + createConnectionPool(backends, loadBalancer, managedThreadPool, false); + + return new Thrift<T>(managedThreadPool, connectionPool, loadBalancer, serviceName, + serviceInterface, clientFactory, false, sslTransport); + } + + /** + * Creates a synchronous thrift client that will communicate with a dynamic host set. 
+ * + * @param hostSet The host set to use as a backend. + * @return A thrift client. + * @throws ThriftFactoryException If an error occurred while creating the client. + */ + public Thrift<T> build(DynamicHostSet<ServiceInstance> hostSet) throws ThriftFactoryException { + checkBaseState(); + Preconditions.checkNotNull(hostSet); + + ManagedThreadPool managedThreadPool = createManagedThreadpool(1); + LoadBalancer<InetSocketAddress> loadBalancer = createLoadBalancer(); + Function<TTransport, T> clientFactory = getClientFactory(); + + ObjectPool<Connection<TTransport, InetSocketAddress>> connectionPool = + createConnectionPool(hostSet, loadBalancer, managedThreadPool, false, endpointName); + + return new Thrift<T>(managedThreadPool, connectionPool, loadBalancer, serviceName, + serviceInterface, clientFactory, false, sslTransport); + } + + private ManagedThreadPool createManagedThreadpool(int initialEndpointCount) { + return new ManagedThreadPool(serviceName, initialEndpointCount, maxConnectionsPerEndpoint); + } + + /** + * A finite thread pool that monitors backend choice events to dynamically resize. This + * {@link java.util.concurrent.ExecutorService} implementation immediately rejects requests when + * there are no more available worked threads (requests are not queued). 
+ */ + private static class ManagedThreadPool extends ForwardingExecutorService<ThreadPoolExecutor> + implements Closure<Collection<InetSocketAddress>> { + + private static final Logger LOG = Logger.getLogger(ManagedThreadPool.class.getName()); + + private static ThreadPoolExecutor createThreadPool(String serviceName, int initialSize) { + ThreadFactory threadFactory = + new ThreadFactoryBuilder() + .setNameFormat("Thrift[" +serviceName + "][%d]") + .setDaemon(true) + .build(); + return new ThreadPoolExecutor(initialSize, initialSize, 0, TimeUnit.MILLISECONDS, + new SynchronousQueue<Runnable>(), threadFactory); + } + + private final int maxConnectionsPerEndpoint; + + public ManagedThreadPool(String serviceName, int initialEndpointCount, + int maxConnectionsPerEndpoint) { + + super(createThreadPool(serviceName, initialEndpointCount * maxConnectionsPerEndpoint)); + this.maxConnectionsPerEndpoint = maxConnectionsPerEndpoint; + setRejectedExecutionHandler(initialEndpointCount); + } + + private void setRejectedExecutionHandler(int endpointCount) { + final String message = + String.format("All %d x %d connections in use", endpointCount, maxConnectionsPerEndpoint); + delegate.setRejectedExecutionHandler(new RejectedExecutionHandler() { + @Override public void rejectedExecution(Runnable runnable, ThreadPoolExecutor executor) { + throw new RejectedExecutionException(message); + } + }); + } + + @Override + public void execute(Collection<InetSocketAddress> chosenBackends) { + int previousPoolSize = delegate.getMaximumPoolSize(); + /* + * In the case of no available backends, we need to make sure we pass in a positive pool + * size to our delegate. In particular, java.util.concurrent.ThreadPoolExecutor does not + * accept zero as a valid core or max pool size. 
+ */ + int backendCount = Math.max(chosenBackends.size(), 1); + int newPoolSize = backendCount * maxConnectionsPerEndpoint; + + if (previousPoolSize != newPoolSize) { + LOG.info(String.format("Re-sizing deadline thread pool from: %d to: %d", + previousPoolSize, newPoolSize)); + if (previousPoolSize < newPoolSize) { // Don't cross the beams! + delegate.setMaximumPoolSize(newPoolSize); + delegate.setCorePoolSize(newPoolSize); + } else { + delegate.setCorePoolSize(newPoolSize); + delegate.setMaximumPoolSize(newPoolSize); + } + setRejectedExecutionHandler(backendCount); + } + } + } + + /** + * Creates an asynchronous thrift client that will communicate with a fixed set of backends. + * + * @param backends Backends to connect to. + * @return A thrift client. + * @throws ThriftFactoryException If an error occurred while creating the client. + */ + public Thrift<T> buildAsync(Set<InetSocketAddress> backends) throws ThriftFactoryException { + checkBaseState(); + MorePreconditions.checkNotBlank(backends); + + LoadBalancer<InetSocketAddress> loadBalancer = createLoadBalancer(); + Closure<Collection<InetSocketAddress>> noop = Closures.noop(); + Function<TTransport, T> asyncClientFactory = getAsyncClientFactory(); + + ObjectPool<Connection<TTransport, InetSocketAddress>> connectionPool = + createConnectionPool(backends, loadBalancer, noop, true); + + return new Thrift<T>(connectionPool, loadBalancer, + serviceName, serviceInterface, asyncClientFactory, true); + } + + /** + * Creates an asynchronous thrift client that will communicate with a dynamic host set. + * + * @param hostSet The host set to use as a backend. + * @return A thrift client. + * @throws ThriftFactoryException If an error occurred while creating the client. 
+ */ + public Thrift<T> buildAsync(DynamicHostSet<ServiceInstance> hostSet) + throws ThriftFactoryException { + checkBaseState(); + Preconditions.checkNotNull(hostSet); + + LoadBalancer<InetSocketAddress> loadBalancer = createLoadBalancer(); + Closure<Collection<InetSocketAddress>> noop = Closures.noop(); + Function<TTransport, T> asyncClientFactory = getAsyncClientFactory(); + + ObjectPool<Connection<TTransport, InetSocketAddress>> connectionPool = + createConnectionPool(hostSet, loadBalancer, noop, true, endpointName); + + return new Thrift<T>(connectionPool, loadBalancer, + serviceName, serviceInterface, asyncClientFactory, true); + } + + /** + * Prepare the client factory, which will create client class instances from transports. + * + * @return The client factory to use. + */ + private Function<TTransport, T> getClientFactory() { + return clientFactory == null ? createClientFactory(serviceInterface) : clientFactory; + } + + /** + * Prepare the async client factory, which will create client class instances from transports. + * + * @return The client factory to use. + * @throws ThriftFactoryException If there was a problem creating the factory. + */ + private Function<TTransport, T> getAsyncClientFactory() throws ThriftFactoryException { + try { + return clientFactory == null ? 
createAsyncClientFactory(serviceInterface) : clientFactory; + } catch (IOException e) { + throw new ThriftFactoryException("Failed to create async client factory.", e); + } + } + + private ObjectPool<Connection<TTransport, InetSocketAddress>> createConnectionPool( + Set<InetSocketAddress> backends, LoadBalancer<InetSocketAddress> loadBalancer, + Closure<Collection<InetSocketAddress>> onBackendsChosen, boolean nonblocking) { + + ImmutableMap.Builder<InetSocketAddress, ObjectPool<Connection<TTransport, InetSocketAddress>>> + backendBuilder = ImmutableMap.builder(); + for (InetSocketAddress backend : backends) { + backendBuilder.put(backend, createConnectionPool(backend, nonblocking)); + } + + return new MetaPool<TTransport, InetSocketAddress>(backendBuilder.build(), + loadBalancer, onBackendsChosen, connectionRestoreInterval); + } + + private ObjectPool<Connection<TTransport, InetSocketAddress>> createConnectionPool( + DynamicHostSet<ServiceInstance> hostSet, LoadBalancer<InetSocketAddress> loadBalancer, + Closure<Collection<InetSocketAddress>> onBackendsChosen, + final boolean nonblocking, Optional<String> serviceEndpointName) + throws ThriftFactoryException { + + Function<InetSocketAddress, ObjectPool<Connection<TTransport, InetSocketAddress>>> + endpointPoolFactory = + new Function<InetSocketAddress, ObjectPool<Connection<TTransport, InetSocketAddress>>>() { + @Override public ObjectPool<Connection<TTransport, InetSocketAddress>> apply( + InetSocketAddress endpoint) { + return createConnectionPool(endpoint, nonblocking); + } + }; + + try { + return new DynamicPool<ServiceInstance, TTransport, InetSocketAddress>(hostSet, + endpointPoolFactory, loadBalancer, onBackendsChosen, connectionRestoreInterval, + Util.getAddress(serviceEndpointName), Util.IS_ALIVE); + } catch (DynamicHostSet.MonitorException e) { + throw new ThriftFactoryException("Failed to monitor host set.", e); + } + } + + private ObjectPool<Connection<TTransport, InetSocketAddress>> 
createConnectionPool( + InetSocketAddress backend, boolean nonblocking) { + + ThriftConnectionFactory connectionFactory = new ThriftConnectionFactory( + backend, maxConnectionsPerEndpoint, TransportType.get(framedTransport, nonblocking), + socketTimeout, postCreateCallback, sslTransport); + + return new ConnectionPool<Connection<TTransport, InetSocketAddress>>(connectionFactory, + statsProvider); + } + + @VisibleForTesting + public ThriftFactory<T> withClientFactory(Function<TTransport, T> clientFactory) { + this.clientFactory = Preconditions.checkNotNull(clientFactory); + + return this; + } + + public ThriftFactory<T> withSslEnabled() { + this.sslTransport = true; + return this; + } + + /** + * Specifies the maximum number of connections that should be made to any single endpoint. + * + * @param maxConnectionsPerEndpoint Maximum number of connections per endpoint. + * @return A reference to the factory. + */ + public ThriftFactory<T> withMaxConnectionsPerEndpoint(int maxConnectionsPerEndpoint) { + Preconditions.checkArgument(maxConnectionsPerEndpoint > 0); + this.maxConnectionsPerEndpoint = maxConnectionsPerEndpoint; + + return this; + } + + /** + * Specifies the interval at which dead endpoint connections should be checked and revived. + * + * @param connectionRestoreInterval the time interval to check. + * @return A reference to the factory. + */ + public ThriftFactory<T> withDeadConnectionRestoreInterval( + Amount<Long, Time> connectionRestoreInterval) { + Preconditions.checkNotNull(connectionRestoreInterval); + Preconditions.checkArgument(connectionRestoreInterval.getValue() >= 0, + "A negative interval is invalid: %s", connectionRestoreInterval); + this.connectionRestoreInterval = connectionRestoreInterval; + + return this; + } + + /** + * Instructs the factory whether framed transport should be used. + * + * @param framedTransport Whether to use framed transport. + * @return A reference to the factory. 
+ */ + public ThriftFactory<T> useFramedTransport(boolean framedTransport) { + this.framedTransport = framedTransport; + + return this; + } + + /** + * Specifies the load balancer to use when interacting with multiple backends. + * + * @param strategy Load balancing strategy. + * @return A reference to the factory. + */ + public ThriftFactory<T> withLoadBalancingStrategy( + LoadBalancingStrategy<InetSocketAddress> strategy) { + this.loadBalancingStrategy = Preconditions.checkNotNull(strategy); + + return this; + } + + private LoadBalancer<InetSocketAddress> createLoadBalancer() { + if (loadBalancingStrategy == null) { + loadBalancingStrategy = createDefaultLoadBalancingStrategy(); + } + + return LoadBalancerImpl.create(TrafficMonitorAdapter.create(loadBalancingStrategy, monitor)); + } + + private LoadBalancingStrategy<InetSocketAddress> createDefaultLoadBalancingStrategy() { + Function<InetSocketAddress, BackoffDecider> backoffFactory = + new Function<InetSocketAddress, BackoffDecider>() { + @Override public BackoffDecider apply(InetSocketAddress socket) { + BackoffStrategy backoffStrategy = new TruncatedBinaryBackoff( + Amount.of(2L, Time.SECONDS), Amount.of(10L, Time.SECONDS)); + + return BackoffDecider.builder(socket.toString()) + .withTolerateFailureRate(0.2) + .withRequestWindow(Amount.of(1L, Time.SECONDS)) + .withSeedSize(5) + .withStrategy(backoffStrategy) + .withRecoveryType(BackoffDecider.RecoveryType.FULL_CAPACITY) + .withStatsProvider(statsProvider) + .build(); + } + }; + + return new MarkDeadStrategyWithHostCheck<InetSocketAddress>( + new LeastConnectedStrategy<InetSocketAddress>(), backoffFactory); + } + + /** + * Specifies the net read/write timeout to set via SO_TIMEOUT on the thrift blocking client + * or AsyncClient.setTimeout on the thrift async client. Defaults to the connectTimeout on + * the blocking client if not set. + * + * @param socketTimeout timeout on thrift i/o operations + * @return A reference to the factory. 
+ */ + public ThriftFactory<T> withSocketTimeout(Amount<Long, Time> socketTimeout) { + this.socketTimeout = Preconditions.checkNotNull(socketTimeout); + Preconditions.checkArgument(socketTimeout.as(Time.MILLISECONDS) >= 0); + + return this; + } + + /** + * Specifies the callback to notify when a connection has been created. The callback may + * be used to make thrift calls to the connection, but must not invalidate it. + * Defaults to a no-op closure. + * + * @param postCreateCallback function to setup new connections + * @return A reference to the factory. + */ + public ThriftFactory<T> withPostCreateCallback( + Closure<Connection<TTransport, InetSocketAddress>> postCreateCallback) { + this.postCreateCallback = Preconditions.checkNotNull(postCreateCallback); + + return this; + } + + /** + * Registers a custom stats provider to use to track various client stats. + * + * @param statsProvider the {@code StatsProvider} to use + * @return A reference to the factory. + */ + public ThriftFactory<T> withStatsProvider(StatsProvider statsProvider) { + this.statsProvider = Preconditions.checkNotNull(statsProvider); + + return this; + } + + /** + * Name to be passed to Thrift constructor, used in stats. + * + * @param serviceName string to use + * @return A reference to the factory. + */ + public ThriftFactory<T> withServiceName(String serviceName) { + this.serviceName = MorePreconditions.checkNotBlank(serviceName); + + return this; + } + + /** + * Set the end-point to use from {@link ServiceInstance#getAdditionalEndpoints()}. + * If not set, the default behavior is to use {@link ServiceInstance#getServiceEndpoint()}. 
+ * + * @param endpointName the (optional) name of the end-point, if unset - the + * default/primary end-point is selected + * @return a reference to the factory for chaining + */ + public ThriftFactory<T> withEndpointName(String endpointName) { + this.endpointName = Optional.of(endpointName); + return this; + } + + private static <T> Function<TTransport, T> createClientFactory(Class<T> serviceInterface) { + final Constructor<? extends T> implementationConstructor = + findImplementationConstructor(serviceInterface); + + return new Function<TTransport, T>() { + @Override public T apply(TTransport transport) { + try { + return implementationConstructor.newInstance(new TBinaryProtocol(transport)); + } catch (InstantiationException e) { + throw new RuntimeException(e); + } catch (IllegalAccessException e) { + throw new RuntimeException(e); + } catch (InvocationTargetException e) { + throw new RuntimeException(e); + } + } + }; + } + + private <T> Function<TTransport, T> createAsyncClientFactory( + final Class<T> serviceInterface) throws IOException { + + final TAsyncClientManager clientManager = new TAsyncClientManager(); + Runtime.getRuntime().addShutdownHook(new Thread() { + @Override public void run() { + clientManager.stop(); + } + }); + + final Constructor<? 
extends T> implementationConstructor = + findAsyncImplementationConstructor(serviceInterface); + + return new Function<TTransport, T>() { + @Override public T apply(TTransport transport) { + Preconditions.checkNotNull(transport); + Preconditions.checkArgument(transport instanceof TNonblockingTransport, + "Invalid transport provided to client factory: " + transport.getClass()); + + try { + T client = implementationConstructor.newInstance(new TBinaryProtocol.Factory(), + clientManager, transport); + + if (socketTimeout != null) { + ((TAsyncClient) client).setTimeout(socketTimeout.as(Time.MILLISECONDS)); + } + + return client; + } catch (InstantiationException e) { + throw new RuntimeException(e); + } catch (IllegalAccessException e) { + throw new RuntimeException(e); + } catch (InvocationTargetException e) { + throw new RuntimeException(e); + } + } + }; + } + + private static <T> Constructor<? extends T> findImplementationConstructor( + final Class<T> serviceInterface) { + Class<? extends T> implementationClass = findImplementationClass(serviceInterface); + try { + return implementationClass.getConstructor(TProtocol.class); + } catch (NoSuchMethodException e) { + throw new IllegalArgumentException("Failed to find a single argument TProtocol constructor " + + "in service client class: " + implementationClass); + } + } + + private static <T> Constructor<? extends T> findAsyncImplementationConstructor( + final Class<T> serviceInterface) { + Class<? extends T> implementationClass = findImplementationClass(serviceInterface); + try { + return implementationClass.getConstructor(TProtocolFactory.class, TAsyncClientManager.class, + TNonblockingTransport.class); + } catch (NoSuchMethodException e) { + throw new IllegalArgumentException("Failed to find expected constructor " + + "in service client class: " + implementationClass); + } + } + + @SuppressWarnings("unchecked") + private static <T> Class<? 
extends T> findImplementationClass(final Class<T> serviceInterface) { + try { + return (Class<? extends T>) + Iterables.find(ImmutableList.copyOf(serviceInterface.getEnclosingClass().getClasses()), + new Predicate<Class<?>>() { + @Override public boolean apply(Class<?> inner) { + return !serviceInterface.equals(inner) + && serviceInterface.isAssignableFrom(inner); + } + }); + } catch (NoSuchElementException e) { + throw new IllegalArgumentException("Could not find a sibling enclosed implementation of " + + "service interface: " + serviceInterface); + } + } + + public static class ThriftFactoryException extends Exception { + public ThriftFactoryException(String msg) { + super(msg); + } + + public ThriftFactoryException(String msg, Throwable t) { + super(msg, t); + } + } +} http://git-wip-us.apache.org/repos/asf/aurora/blob/06ddaadb/commons/src/main/java/org/apache/aurora/common/thrift/Util.java ---------------------------------------------------------------------- diff --git a/commons/src/main/java/org/apache/aurora/common/thrift/Util.java b/commons/src/main/java/org/apache/aurora/common/thrift/Util.java new file mode 100644 index 0000000..7802614 --- /dev/null +++ b/commons/src/main/java/org/apache/aurora/common/thrift/Util.java @@ -0,0 +1,231 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.aurora.common.thrift; + +import java.net.InetSocketAddress; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import javax.annotation.Nullable; + +import com.google.common.base.Function; +import com.google.common.base.Joiner; +import com.google.common.base.Optional; +import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; +import com.google.common.base.Strings; +import com.google.common.collect.Lists; + +import org.apache.thrift.TBase; +import org.apache.thrift.TFieldIdEnum; +import org.apache.thrift.meta_data.FieldMetaData; + +/** + * Utility functions for thrift. + * + * @author William Farner + */ +public class Util { + + /** + * Maps a {@link ServiceInstance} to an {@link InetSocketAddress} given the {@code endpointName}. + * + * @param optionalEndpointName the name of the end-point on the service's additional end-points, + * if not set, maps to the primary service end-point + */ + public static Function<ServiceInstance, InetSocketAddress> getAddress( + final Optional<String> optionalEndpointName) { + if (!optionalEndpointName.isPresent()) { + return GET_ADDRESS; + } + + final String endpointName = optionalEndpointName.get(); + return getAddress( + new Function<ServiceInstance, Endpoint>() { + @Override public Endpoint apply(@Nullable ServiceInstance serviceInstance) { + Map<String, Endpoint> endpoints = serviceInstance.getAdditionalEndpoints(); + Preconditions.checkArgument(endpoints.containsKey(endpointName), + "Did not find end-point %s on %s", endpointName, serviceInstance); + return endpoints.get(endpointName); + } + }); + } + + private static Function<ServiceInstance, InetSocketAddress> getAddress( + final Function<ServiceInstance, Endpoint> serviceToEndpoint) { + return new Function<ServiceInstance, InetSocketAddress>() { + @Override public InetSocketAddress apply(ServiceInstance serviceInstance) { + Endpoint endpoint = serviceToEndpoint.apply(serviceInstance); + return 
InetSocketAddress.createUnresolved(endpoint.getHost(), endpoint.getPort()); + } + }; + } + + private static Function<ServiceInstance, Endpoint> GET_PRIMARY_ENDPOINT = + new Function<ServiceInstance, Endpoint>() { + @Override public Endpoint apply(ServiceInstance input) { + return input.getServiceEndpoint(); + } + }; + + public static Function<ServiceInstance, InetSocketAddress> GET_ADDRESS = + getAddress(GET_PRIMARY_ENDPOINT); + + public static final Predicate<ServiceInstance> IS_ALIVE = new Predicate<ServiceInstance>() { + @Override public boolean apply(ServiceInstance serviceInstance) { + switch (serviceInstance.getStatus()) { + case ALIVE: + return true; + + // We'll be optimistic here and let MTCP's ranking deal with + // unhealthy services in a WARNING state. + case WARNING: + return true; + + // Services which are just starting up, on the other hand... are much easier to just not + // send requests to. The STARTING state is useful to distinguish from WARNING or ALIVE: + // you exist in ZooKeeper, but don't yet serve traffic. + case STARTING: + default: + return false; + } + } + }; + + /** + * Pretty-prints a thrift object contents. + * + * @param t The thrift object to print. + * @return The pretty-printed version of the thrift object. + */ + public static String prettyPrint(TBase t) { + return t == null ? "null" : printTbase(t, 0); + } + + /** + * Prints an object contained in a thrift message. + * + * @param o The object to print. + * @param depth The print nesting level. + * @return The pretty-printed version of the thrift field. 
+ */ + private static String printValue(Object o, int depth) { + if (o == null) { + return "null"; + } else if (TBase.class.isAssignableFrom(o.getClass())) { + return "\n" + printTbase((TBase) o, depth + 1); + } else if (Map.class.isAssignableFrom(o.getClass())) { + return printMap((Map) o, depth + 1); + } else if (List.class.isAssignableFrom(o.getClass())) { + return printList((List) o, depth + 1); + } else if (Set.class.isAssignableFrom(o.getClass())) { + return printSet((Set) o, depth + 1); + } else if (String.class == o.getClass()) { + return '"' + o.toString() + '"'; + } else { + return o.toString(); + } + } + + private static final String METADATA_MAP_FIELD_NAME = "metaDataMap"; + + /** + * Prints a TBase. + * + * @param t The object to print. + * @param depth The print nesting level. + * @return The pretty-printed version of the TBase. + */ + private static String printTbase(TBase t, int depth) { + List<String> fields = Lists.newArrayList(); + for (Map.Entry<? extends TFieldIdEnum, FieldMetaData> entry : + FieldMetaData.getStructMetaDataMap(t.getClass()).entrySet()) { + @SuppressWarnings("unchecked") + boolean fieldSet = t.isSet(entry.getKey()); + String strValue; + if (fieldSet) { + @SuppressWarnings("unchecked") + Object value = t.getFieldValue(entry.getKey()); + strValue = printValue(value, depth); + } else { + strValue = "not set"; + } + fields.add(tabs(depth) + entry.getValue().fieldName + ": " + strValue); + } + + return Joiner.on("\n").join(fields); + } + + /** + * Prints a map in a style that is consistent with TBase pretty printing. + * + * @param map The map to print + * @param depth The print nesting level. + * @return The pretty-printed version of the map. 
+ */ + private static String printMap(Map<?, ?> map, int depth) { + List<String> entries = Lists.newArrayList(); + for (Map.Entry entry : map.entrySet()) { + entries.add(tabs(depth) + printValue(entry.getKey(), depth) + + " = " + printValue(entry.getValue(), depth)); + } + + return entries.isEmpty() ? "{}" + : String.format("{\n%s\n%s}", Joiner.on(",\n").join(entries), tabs(depth - 1)); + } + + /** + * Prints a list in a style that is consistent with TBase pretty printing. + * + * @param list The list to print + * @param depth The print nesting level. + * @return The pretty-printed version of the list + */ + private static String printList(List<?> list, int depth) { + List<String> entries = Lists.newArrayList(); + for (int i = 0; i < list.size(); i++) { + entries.add( + String.format("%sItem[%d] = %s", tabs(depth), i, printValue(list.get(i), depth))); + } + + return entries.isEmpty() ? "[]" + : String.format("[\n%s\n%s]", Joiner.on(",\n").join(entries), tabs(depth - 1)); + } + /** + * Prints a set in a style that is consistent with TBase pretty printing. + * + * @param set The set to print + * @param depth The print nesting level. + * @return The pretty-printed version of the set + */ + private static String printSet(Set<?> set, int depth) { + List<String> entries = Lists.newArrayList(); + for (Object item : set) { + entries.add( + String.format("%sItem = %s", tabs(depth), printValue(item, depth))); + } + + return entries.isEmpty() ? "{}" + : String.format("{\n%s\n%s}", Joiner.on(",\n").join(entries), tabs(depth - 1)); + } + + private static String tabs(int n) { + return Strings.repeat(" ", n); + } + + private Util() { + // Utility class. 
+ } +} http://git-wip-us.apache.org/repos/asf/aurora/blob/06ddaadb/commons/src/main/java/org/apache/aurora/common/thrift/callers/Caller.java ---------------------------------------------------------------------- diff --git a/commons/src/main/java/org/apache/aurora/common/thrift/callers/Caller.java b/commons/src/main/java/org/apache/aurora/common/thrift/callers/Caller.java new file mode 100644 index 0000000..0200c49 --- /dev/null +++ b/commons/src/main/java/org/apache/aurora/common/thrift/callers/Caller.java @@ -0,0 +1,99 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.aurora.common.thrift.callers; + +import com.google.common.base.Preconditions; +import org.apache.aurora.common.quantity.Amount; +import org.apache.aurora.common.quantity.Time; +import org.apache.thrift.async.AsyncMethodCallback; + +import javax.annotation.Nullable; +import java.lang.reflect.Method; + +/** +* A caller that invokes a method on an object. +* +* @author William Farner +*/ +public interface Caller { + + /** + * Invokes a method on an object, using the given arguments. The method call may be + * asynchronous, in which case {@code callback} will be non-null. + * + * @param method The method being invoked. + * @param args The arguments to call {@code method} with. + * @param callback The callback to use if the method is asynchronous. + * @param connectTimeoutOverride Optional override for the default connection timeout. 
+ * @return The return value from invoking the method. + * @throws Throwable Exception, as prescribed by the method's contract. + */ + public Object call(Method method, Object[] args, @Nullable AsyncMethodCallback callback, + @Nullable Amount<Long, Time> connectTimeoutOverride) throws Throwable; + + /** + * Captures the result of a request, whether synchronous or asynchronous. It should be expected + * that for every request made, exactly one of these methods will be called. + */ + static interface ResultCapture { + /** + * Called when the request completed successfully. + */ + void success(); + + /** + * Called when the request failed. + * + * @param t Throwable that was caught. Must never be null. + * @return {@code true} if a wrapped callback should be notified of the failure, + * {@code false} otherwise. + */ + boolean fail(Throwable t); + } + + /** + * A callback that adapts a {@link ResultCapture} with an {@link AsyncMethodCallback} while + * maintaining the AsyncMethodCallback interface. The wrapped callback will handle invocation + * of the underlying callback based on the return values from the ResultCapture. 
+ */ + static class WrappedMethodCallback implements AsyncMethodCallback { + private final AsyncMethodCallback wrapped; + private final ResultCapture capture; + + private boolean callbackTriggered = false; + + public WrappedMethodCallback(AsyncMethodCallback wrapped, ResultCapture capture) { + this.wrapped = wrapped; + this.capture = capture; + } + + private void callbackTriggered() { + Preconditions.checkState(!callbackTriggered, "Each callback may only be triggered once."); + callbackTriggered = true; + } + + @Override @SuppressWarnings("unchecked") public void onComplete(Object o) { + capture.success(); + wrapped.onComplete(o); + callbackTriggered(); + } + + @Override public void onError(Exception t) { + if (capture.fail(t)) { + wrapped.onError(t); + callbackTriggered(); + } + } + } +} http://git-wip-us.apache.org/repos/asf/aurora/blob/06ddaadb/commons/src/main/java/org/apache/aurora/common/thrift/callers/CallerDecorator.java ---------------------------------------------------------------------- diff --git a/commons/src/main/java/org/apache/aurora/common/thrift/callers/CallerDecorator.java b/commons/src/main/java/org/apache/aurora/common/thrift/callers/CallerDecorator.java new file mode 100644 index 0000000..bd0a952 --- /dev/null +++ b/commons/src/main/java/org/apache/aurora/common/thrift/callers/CallerDecorator.java @@ -0,0 +1,78 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.aurora.common.thrift.callers; + +import org.apache.aurora.common.quantity.Amount; +import org.apache.aurora.common.quantity.Time; +import org.apache.thrift.async.AsyncMethodCallback; + +import javax.annotation.Nullable; +import java.lang.reflect.Method; + +/** +* A caller that decorates another caller. +* +* @author William Farner +*/ +abstract class CallerDecorator implements Caller { + private final Caller decoratedCaller; + private final boolean async; + + CallerDecorator(Caller decoratedCaller, boolean async) { + this.decoratedCaller = decoratedCaller; + this.async = async; + } + + /** + * Convenience method for invoking the method and shunting the capture into the callback if + * the call is asynchronous. + * + * @param method The method being invoked. + * @param args The arguments to call {@code method} with. + * @param callback The callback to use if the method is asynchronous. + * @param capture The result capture to notify of the call result. + * @param connectTimeoutOverride Optional override for the default connection timeout. + * @return The return value from invoking the method. + * @throws Throwable Exception, as prescribed by the method's contract. + */ + protected final Object invoke(Method method, Object[] args, + @Nullable AsyncMethodCallback callback, @Nullable final ResultCapture capture, + @Nullable Amount<Long, Time> connectTimeoutOverride) throws Throwable { + + // Swap the wrapped callback out for ours. + if (callback != null) { + callback = new WrappedMethodCallback(callback, capture); + } + + try { + Object result = decoratedCaller.call(method, args, callback, connectTimeoutOverride); + if (callback == null && capture != null) capture.success(); + + return result; + } catch (Exception t) { + // We allow this one to go to both sync and async captures. 
+ if (callback != null) { + callback.onError(t); + return null; + } else { + if (capture != null) capture.fail(t); + throw t; + } + } + } + + boolean isAsync() { + return async; + } +} http://git-wip-us.apache.org/repos/asf/aurora/blob/06ddaadb/commons/src/main/java/org/apache/aurora/common/thrift/callers/DeadlineCaller.java ---------------------------------------------------------------------- diff --git a/commons/src/main/java/org/apache/aurora/common/thrift/callers/DeadlineCaller.java b/commons/src/main/java/org/apache/aurora/common/thrift/callers/DeadlineCaller.java new file mode 100644 index 0000000..75ed1ec --- /dev/null +++ b/commons/src/main/java/org/apache/aurora/common/thrift/callers/DeadlineCaller.java @@ -0,0 +1,93 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.aurora.common.thrift.callers; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.TimeoutException; + +import javax.annotation.Nullable; + +import com.google.common.base.Throwables; + +import org.apache.aurora.common.thrift.TResourceExhaustedException; +import org.apache.thrift.async.AsyncMethodCallback; + +import org.apache.aurora.common.quantity.Amount; +import org.apache.aurora.common.quantity.Time; +import org.apache.aurora.common.thrift.TTimeoutException; + +/** + * A caller that imposes a time deadline on the underlying caller. If the underlying calls fail + * to meet the deadline {@link TTimeoutException} is thrown. If the executor service rejects + * execution of a task, {@link TResourceExhaustedException} is thrown. + * + * @author William Farner + */ +public class DeadlineCaller extends CallerDecorator { + private final ExecutorService executorService; + private final Amount<Long, Time> timeout; + + /** + * Creates a new deadline caller. + * + * @param decoratedCaller The caller to decorate with a deadline. + * @param async Whether the caller is asynchronous. + * @param executorService The executor service to use for performing calls. + * @param timeout The timeout by which the underlying call should complete in. 
+ */ + public DeadlineCaller(Caller decoratedCaller, boolean async, ExecutorService executorService, + Amount<Long, Time> timeout) { + super(decoratedCaller, async); + + this.executorService = executorService; + this.timeout = timeout; + } + + @Override + public Object call(final Method method, final Object[] args, + @Nullable final AsyncMethodCallback callback, + @Nullable final Amount<Long, Time> connectTimeoutOverride) throws Throwable { + try { + Future<Object> result = executorService.submit(new Callable<Object>() { + @Override public Object call() throws Exception { + try { + return invoke(method, args, callback, null, connectTimeoutOverride); + } catch (Throwable t) { + Throwables.propagateIfInstanceOf(t, Exception.class); + throw new RuntimeException(t); + } + } + }); + + try { + return result.get(timeout.getValue(), timeout.getUnit().getTimeUnit()); + } catch (TimeoutException e) { + result.cancel(true); + throw new TTimeoutException(e); + } catch (ExecutionException e) { + throw e.getCause(); + } + } catch (RejectedExecutionException e) { + throw new TResourceExhaustedException(e); + } catch (InvocationTargetException e) { + throw e.getCause(); + } + } +} http://git-wip-us.apache.org/repos/asf/aurora/blob/06ddaadb/commons/src/main/java/org/apache/aurora/common/thrift/callers/DebugCaller.java ---------------------------------------------------------------------- diff --git a/commons/src/main/java/org/apache/aurora/common/thrift/callers/DebugCaller.java b/commons/src/main/java/org/apache/aurora/common/thrift/callers/DebugCaller.java new file mode 100644 index 0000000..aff4006 --- /dev/null +++ b/commons/src/main/java/org/apache/aurora/common/thrift/callers/DebugCaller.java @@ -0,0 +1,73 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.aurora.common.thrift.callers; + +import com.google.common.base.Joiner; + +import org.apache.aurora.common.quantity.Amount; +import org.apache.aurora.common.quantity.Time; +import org.apache.thrift.async.AsyncMethodCallback; + +import javax.annotation.Nullable; + +import java.lang.reflect.Method; +import java.util.logging.Logger; + +/** + * A caller that reports debugging information about calls. + * + * @author William Farner + */ +public class DebugCaller extends CallerDecorator { + private static final Logger LOG = Logger.getLogger(DebugCaller.class.getName()); + private static final Joiner ARG_JOINER = Joiner.on(", "); + + /** + * Creates a new debug caller. + * + * @param decoratedCaller The caller to decorate with debug information. + * @param async Whether the caller is asynchronous. + */ + public DebugCaller(Caller decoratedCaller, boolean async) { + super(decoratedCaller, async); + } + + @Override + public Object call(final Method method, final Object[] args, + @Nullable AsyncMethodCallback callback, @Nullable Amount<Long, Time> connectTimeoutOverride) + throws Throwable { + ResultCapture capture = new ResultCapture() { + @Override public void success() { + // No-op. 
+ } + + @Override public boolean fail(Throwable t) { + StringBuilder message = new StringBuilder("Thrift call failed: "); + message.append(method.getName()).append("("); + ARG_JOINER.appendTo(message, args); + message.append(")"); + LOG.warning(message.toString()); + + return true; + } + }; + + try { + return invoke(method, args, callback, capture, connectTimeoutOverride); + } catch (Throwable t) { + capture.fail(t); + throw t; + } + } +}