Github user sudheeshkatkam commented on a diff in the pull request:

    https://github.com/apache/drill/pull/578#discussion_r85853323
  
    --- Diff: 
exec/java-exec/src/main/java/org/apache/drill/exec/rpc/user/UserClient.java ---
    @@ -78,21 +101,241 @@ public void submitQuery(UserResultsListener 
resultsListener, RunQuery query) {
         send(queryResultHandler.getWrappedListener(resultsListener), 
RpcType.RUN_QUERY, query, QueryId.class);
       }
     
    -  public void connect(RpcConnectionHandler<ServerConnection> handler, 
DrillbitEndpoint endpoint,
    -                      UserProperties props, UserBitShared.UserCredentials 
credentials) {
    +  public CheckedFuture<Void, RpcException> connect(DrillbitEndpoint 
endpoint, ConnectionParameters parameters,
    +                                                   UserCredentials 
credentials) {
    +    final FutureHandler handler = new FutureHandler();
         UserToBitHandshake.Builder hsBuilder = UserToBitHandshake.newBuilder()
             .setRpcVersion(UserRpcConfig.RPC_VERSION)
             .setSupportListening(true)
             .setSupportComplexTypes(supportComplexTypes)
             .setSupportTimeout(true)
    -        .setCredentials(credentials);
    +        .setCredentials(credentials)
    +        .setProperties(parameters.serializeForServer());
    +    this.parameters = parameters;
    +
    +    
connectAsClient(queryResultHandler.getWrappedConnectionHandler(handler),
    +        hsBuilder.build(), endpoint.getAddress(), endpoint.getUserPort());
    +    return handler;
    +  }
    +
    +  /**
    +   * Checks (after {@link #connect connecting}) whether the server requires authentication.
    +   *
    +   * @return true if server requires authentication
    +   */
    +  public boolean serverRequiresAuthentication() {
    +    return supportedAuthMechs != null;
    +  }
    +
    +  /**
    +   * Returns a list of supported authentication mechanisms. If called before {@link #connect connecting},
    +   * returns null. If called after {@link #connect connecting}, returns a list of supported mechanisms
    +   * iff authentication is required.
    +   *
    +   * @return list of supported authentication mechanisms
    +   */
    +  public List<String> getSupportedAuthenticationMechanisms() {
    +    return supportedAuthMechs;
    +  }
    +
    +  /**
    +   * Authenticate to the server asynchronously. Returns a future that 
{@link CheckedFuture#checkedGet results}
    +   * in null if authentication succeeds, or throws a {@link SaslException} 
with relevant message if
    +   * authentication fails.
    +   *
    +   * This method uses the parameters provided at {@link #connect connection time} and overrides them with the
    +   * given parameters, if any.
    +   *
    +   * @param overrides parameter overrides
    +   * @return result of authentication request
    +   */
    +  public CheckedFuture<Void, SaslException> authenticate(final 
ConnectionParameters overrides) {
    +    if (supportedAuthMechs == null) {
    +      throw new IllegalStateException("Server does not require 
authentication.");
    +    }
    +    parameters.merge(overrides);
    +
    +    final SettableFuture<Void> settableFuture = SettableFuture.create(); 
// future used in SASL exchange
    +    final CheckedFuture<Void, SaslException> future =
    +        new AbstractCheckedFuture<Void, SaslException>(settableFuture) {
    +
    +          @Override
    +          protected SaslException mapException(Exception e) {
    +            if (connection != null) {
    +              connection.close(); // to ensure connection is dropped
    +            }
    +            if (e instanceof ExecutionException) {
    +              final Throwable cause = e.getCause();
    +              if (cause instanceof SaslException) {
    +                return new SaslException("Authentication failed: " + 
cause.getMessage(), cause);
    +              }
    +            }
    +            return new SaslException("Authentication failed 
unexpectedly.", e);
    +          }
    +        };
     
    -    if (props != null) {
    -      hsBuilder.setProperties(props);
    +    final ClientAuthenticationProvider authenticationProvider;
    +    try {
    +      authenticationProvider =
    +          
UserAuthenticationUtil.getClientAuthenticationProvider(parameters, 
supportedAuthMechs);
    +    } catch (final SaslException e) {
    +      settableFuture.setException(e);
    +      return future;
         }
     
    -    
this.connectAsClient(queryResultHandler.getWrappedConnectionHandler(handler),
    -        hsBuilder.build(), endpoint.getAddress(), endpoint.getUserPort());
    +    final String providerName = authenticationProvider.name();
    +    logger.trace("Will try to login for {} mechanism.", providerName);
    +    final UserGroupInformation ugi;
    +    try {
    +      ugi = authenticationProvider.login(parameters);
    +    } catch (final SaslException e) {
    +      settableFuture.setException(e);
    +      return future;
    +    }
    +
    +    logger.trace("Will try to authenticate to server using {} mechanism.", 
providerName);
    +    try {
    +      saslClient = authenticationProvider.createSaslClient(ugi, 
parameters);
    +    } catch (final SaslException e) {
    +      settableFuture.setException(e);
    +      return future;
    +    }
    +
    +    if (saslClient == null) {
    +      settableFuture.setException(new SaslException("Cannot initiate 
authentication. Insufficient credentials?"));
    +      return future;
    +    }
    +    logger.trace("Initiating SASL exchange.");
    +
    +    try {
    +      final ByteString responseData;
    +      if (saslClient.hasInitialResponse()) {
    +        responseData = ByteString.copyFrom(evaluateChallenge(ugi, 
saslClient, new byte[0]));
    +      } else {
    +        responseData = ByteString.EMPTY;
    +      }
    +      send(new SaslChallengeHandler(ugi, settableFuture),
    +          RpcType.SASL_MESSAGE,
    +          SaslMessage.newBuilder()
    +              .setMechanism(providerName)
    +              .setStatus(SaslStatus.SASL_START)
    +              .setData(responseData)
    +              .build(),
    +          SaslMessage.class);
    +      logger.trace("Initiated SASL exchange.");
    +    } catch (final SaslException e) {
    +      settableFuture.setException(e);
    +    }
    +    return future;
    +  }
    +
    +  private static byte[] evaluateChallenge(final UserGroupInformation ugi, 
final SaslClient saslClient,
    +                                          final byte[] challenge) throws 
SaslException {
    +    try {
    +      return ugi.doAs(new PrivilegedExceptionAction<byte[]>() {
    +        @Override
    +        public byte[] run() throws Exception {
    +          return saslClient.evaluateChallenge(challenge);
    +        }
    +      });
    +    } catch (final UndeclaredThrowableException e) {
    +      final Throwable cause = e.getCause();
    +      if (cause instanceof SaslException) {
    +        throw (SaslException) cause;
    +      } else {
    +        throw new SaslException(
    +            String.format("Unexpected failure (%s)", 
saslClient.getMechanismName()), cause);
    +      }
    +    } catch (final IOException | InterruptedException e) {
    +      throw new SaslException(String.format("Unexpected failure (%s)", 
saslClient.getMechanismName()), e);
    +    }
    +  }
    +
    +  // handles SASL message exchange
    +  private class SaslChallengeHandler implements 
RpcOutcomeListener<SaslMessage> {
    +
    +    private final UserGroupInformation ugi;
    +    private final SettableFuture<Void> future;
    +
    +    public SaslChallengeHandler(UserGroupInformation ugi, 
SettableFuture<Void> future) {
    +      this.ugi = ugi;
    +      this.future = future;
    +    }
    +
    +    @Override
    +    public void failed(RpcException ex) {
    +      future.setException(new SaslException("Unexpected failure", ex));
    +    }
    +
    +    @Override
    +    public void success(SaslMessage value, ByteBuf buffer) {
    +      logger.trace("Server responded with message of type: {}", 
value.getStatus());
    +      switch (value.getStatus()) {
    +      case SASL_AUTH_IN_PROGRESS: {
    +        try {
    +          final SaslMessage.Builder response = SaslMessage.newBuilder();
    +          final byte[] responseBytes = evaluateChallenge(ugi, saslClient, 
value.getData().toByteArray());
    +          final boolean isComplete = saslClient.isComplete();
    +          logger.trace("Evaluated challenge. Completed? {}. Sending 
response to server.", isComplete);
    +          response.setData(responseBytes != null ? 
ByteString.copyFrom(responseBytes) : ByteString.EMPTY);
    +          // if isComplete, the client will get one more response from 
server
    +          response.setStatus(isComplete ? SaslStatus.SASL_AUTH_SUCCESS : 
SaslStatus.SASL_AUTH_IN_PROGRESS);
    +          send(new SaslChallengeHandler(ugi, future),
    +              connection,
    +              RpcType.SASL_MESSAGE,
    +              response.build(),
    +              SaslMessage.class,
    +              true // the connection will not be backed up at this point
    +          );
    +        } catch (Exception e) {
    +          future.setException(e);
    +        }
    +        break;
    +      }
    +      case SASL_AUTH_SUCCESS: {
    +        try {
    +          if (saslClient.isComplete()) {
    +            logger.trace("Successfully authenticated to server using {}", 
saslClient.getMechanismName());
    +            saslClient.dispose();
    --- End diff --
    
    I will update the PR with the latest changes, which include refactoring this `switch`/`case` statement.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastruct...@apache.org or file a JIRA ticket
with INFRA.
---

Reply via email to