Github user sudheeshkatkam commented on a diff in the pull request: https://github.com/apache/drill/pull/578#discussion_r85853323 --- Diff: exec/java-exec/src/main/java/org/apache/drill/exec/rpc/user/UserClient.java --- @@ -78,21 +101,241 @@ public void submitQuery(UserResultsListener resultsListener, RunQuery query) { send(queryResultHandler.getWrappedListener(resultsListener), RpcType.RUN_QUERY, query, QueryId.class); } - public void connect(RpcConnectionHandler<ServerConnection> handler, DrillbitEndpoint endpoint, - UserProperties props, UserBitShared.UserCredentials credentials) { + public CheckedFuture<Void, RpcException> connect(DrillbitEndpoint endpoint, ConnectionParameters parameters, + UserCredentials credentials) { + final FutureHandler handler = new FutureHandler(); UserToBitHandshake.Builder hsBuilder = UserToBitHandshake.newBuilder() .setRpcVersion(UserRpcConfig.RPC_VERSION) .setSupportListening(true) .setSupportComplexTypes(supportComplexTypes) .setSupportTimeout(true) - .setCredentials(credentials); + .setCredentials(credentials) + .setProperties(parameters.serializeForServer()); + this.parameters = parameters; + + connectAsClient(queryResultHandler.getWrappedConnectionHandler(handler), + hsBuilder.build(), endpoint.getAddress(), endpoint.getUserPort()); + return handler; + } + + /** + * Check (after {@link #connect connecting}) if server requires authentication. + * + * @return true if server requires authentication + */ + public boolean serverRequiresAuthentication() { + return supportedAuthMechs != null; + } + + /** + * Returns a list of supported authentication mechanism. If called before {@link #connect connecting}, + * returns null. If called after {@link #connect connecting}, returns a list of supported mechanisms + * iff authentication is required. 
+ * + * @return list of supported authentication mechanisms + */ + public List<String> getSupportedAuthenticationMechanisms() { + return supportedAuthMechs; + } + + /** + * Authenticate to the server asynchronously. Returns a future that {@link CheckedFuture#checkedGet results} + * in null if authentication succeeds, or throws a {@link SaslException} with relevant message if + * authentication fails. + * + * This method uses parameters provided at {@link #connect connection time} and override them with the + * given parameters, if any. + * + * @param overrides parameter overrides + * @return result of authentication request + */ + public CheckedFuture<Void, SaslException> authenticate(final ConnectionParameters overrides) { + if (supportedAuthMechs == null) { + throw new IllegalStateException("Server does not require authentication."); + } + parameters.merge(overrides); + + final SettableFuture<Void> settableFuture = SettableFuture.create(); // future used in SASL exchange + final CheckedFuture<Void, SaslException> future = + new AbstractCheckedFuture<Void, SaslException>(settableFuture) { + + @Override + protected SaslException mapException(Exception e) { + if (connection != null) { + connection.close(); // to ensure connection is dropped + } + if (e instanceof ExecutionException) { + final Throwable cause = e.getCause(); + if (cause instanceof SaslException) { + return new SaslException("Authentication failed: " + cause.getMessage(), cause); + } + } + return new SaslException("Authentication failed unexpectedly.", e); + } + }; - if (props != null) { - hsBuilder.setProperties(props); + final ClientAuthenticationProvider authenticationProvider; + try { + authenticationProvider = + UserAuthenticationUtil.getClientAuthenticationProvider(parameters, supportedAuthMechs); + } catch (final SaslException e) { + settableFuture.setException(e); + return future; } - this.connectAsClient(queryResultHandler.getWrappedConnectionHandler(handler), - hsBuilder.build(), 
endpoint.getAddress(), endpoint.getUserPort()); + final String providerName = authenticationProvider.name(); + logger.trace("Will try to login for {} mechanism.", providerName); + final UserGroupInformation ugi; + try { + ugi = authenticationProvider.login(parameters); + } catch (final SaslException e) { + settableFuture.setException(e); + return future; + } + + logger.trace("Will try to authenticate to server using {} mechanism.", providerName); + try { + saslClient = authenticationProvider.createSaslClient(ugi, parameters); + } catch (final SaslException e) { + settableFuture.setException(e); + return future; + } + + if (saslClient == null) { + settableFuture.setException(new SaslException("Cannot initiate authentication. Insufficient credentials?")); + return future; + } + logger.trace("Initiating SASL exchange."); + + try { + final ByteString responseData; + if (saslClient.hasInitialResponse()) { + responseData = ByteString.copyFrom(evaluateChallenge(ugi, saslClient, new byte[0])); + } else { + responseData = ByteString.EMPTY; + } + send(new SaslChallengeHandler(ugi, settableFuture), + RpcType.SASL_MESSAGE, + SaslMessage.newBuilder() + .setMechanism(providerName) + .setStatus(SaslStatus.SASL_START) + .setData(responseData) + .build(), + SaslMessage.class); + logger.trace("Initiated SASL exchange."); + } catch (final SaslException e) { + settableFuture.setException(e); + } + return future; + } + + private static byte[] evaluateChallenge(final UserGroupInformation ugi, final SaslClient saslClient, + final byte[] challenge) throws SaslException { + try { + return ugi.doAs(new PrivilegedExceptionAction<byte[]>() { + @Override + public byte[] run() throws Exception { + return saslClient.evaluateChallenge(challenge); + } + }); + } catch (final UndeclaredThrowableException e) { + final Throwable cause = e.getCause(); + if (cause instanceof SaslException) { + throw (SaslException) cause; + } else { + throw new SaslException( + String.format("Unexpected failure (%s)", 
saslClient.getMechanismName()), cause); + } + } catch (final IOException | InterruptedException e) { + throw new SaslException(String.format("Unexpected failure (%s)", saslClient.getMechanismName()), e); + } + } + + // handles SASL message exchange + private class SaslChallengeHandler implements RpcOutcomeListener<SaslMessage> { + + private final UserGroupInformation ugi; + private final SettableFuture<Void> future; + + public SaslChallengeHandler(UserGroupInformation ugi, SettableFuture<Void> future) { + this.ugi = ugi; + this.future = future; + } + + @Override + public void failed(RpcException ex) { + future.setException(new SaslException("Unexpected failure", ex)); + } + + @Override + public void success(SaslMessage value, ByteBuf buffer) { + logger.trace("Server responded with message of type: {}", value.getStatus()); + switch (value.getStatus()) { + case SASL_AUTH_IN_PROGRESS: { + try { + final SaslMessage.Builder response = SaslMessage.newBuilder(); + final byte[] responseBytes = evaluateChallenge(ugi, saslClient, value.getData().toByteArray()); + final boolean isComplete = saslClient.isComplete(); + logger.trace("Evaluated challenge. Completed? {}. Sending response to server.", isComplete); + response.setData(responseBytes != null ? ByteString.copyFrom(responseBytes) : ByteString.EMPTY); + // if isComplete, the client will get one more response from server + response.setStatus(isComplete ? 
SaslStatus.SASL_AUTH_SUCCESS : SaslStatus.SASL_AUTH_IN_PROGRESS); + send(new SaslChallengeHandler(ugi, future), + connection, + RpcType.SASL_MESSAGE, + response.build(), + SaslMessage.class, + true // the connection will not be backed up at this point + ); + } catch (Exception e) { + future.setException(e); + } + break; + } + case SASL_AUTH_SUCCESS: { + try { + if (saslClient.isComplete()) { + logger.trace("Successfully authenticated to server using {}", saslClient.getMechanismName()); + saslClient.dispose(); --- End diff -- Will update PR with the latest changes that include refactoring this `switch case` statement.
--- If your project is set up for it, you can reply to this email and have your reply appear on GitHub as well. If your project does not have this feature enabled and you would like to enable it, or if the feature is enabled but not working, please contact infrastructure at infrastruct...@apache.org or file a JIRA ticket with INFRA. ---