This is an automated email from the ASF dual-hosted git repository.

wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new aa0631a  ARROW-5877: [FlightRPC] Fix Python<->Java auth issues
aa0631a is described below

commit aa0631a95a3b92c4e14b9089656b073042ca17b8
Author: David Li <li.david...@gmail.com>
AuthorDate: Wed Jul 10 14:40:19 2019 -0500

    ARROW-5877: [FlightRPC] Fix Python<->Java auth issues
    
    - Previously, Java servers didn't actually wait for the Handshake RPC to complete
    - Java servers didn't interrupt auth handlers if the client sent an error
    - Python/C++ clients didn't explicitly finish their end of the connection
    
    Together, these led to the 'hanging forever' issue @rymurr saw.
    
    I've left some TODOs because I would like to raise Flight-specific exceptions
    (which I'm working on in parallel).
    
    Travis: https://travis-ci.com/lihalite/arrow/builds/118503572
    AppVeyor: https://ci.appveyor.com/project/lihalite/arrow/builds/25858510
    
    Author: David Li <li.david...@gmail.com>
    
    Closes #4838 from lihalite/arrow-5877 and squashes the following commits:
    
    fc35d190c <David Li> Wait for authentication to complete server-side
---
 cpp/src/arrow/flight/client.cc                     |  5 ++
 .../arrow/flight/auth/ClientAuthWrapper.java       | 46 ++++++++---
 .../arrow/flight/auth/ServerAuthHandler.java       |  2 +
 .../arrow/flight/auth/ServerAuthWrapper.java       |  2 +
 .../java/org/apache/arrow/flight/TestAuth.java     | 94 ++++++++++++++++++++++
 .../auth/{TestAuth.java => TestBasicAuth.java}     |  2 +-
 6 files changed, 141 insertions(+), 10 deletions(-)

diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc
index c508dca..cb4b92d 100644
--- a/cpp/src/arrow/flight/client.cc
+++ b/cpp/src/arrow/flight/client.cc
@@ -430,7 +430,12 @@ class FlightClient::FlightClientImpl {
     GrpcClientAuthSender outgoing{stream};
     GrpcClientAuthReader incoming{stream};
     RETURN_NOT_OK(auth_handler_->Authenticate(&outgoing, &incoming));
+    // Explicitly close our side of the connection
+    bool finished_writes = stream->WritesDone();
     RETURN_NOT_OK(internal::FromGrpcStatus(stream->Finish()));
+    if (!finished_writes) {
+      return Status::UnknownError("Could not finish writing before closing");
+    }
     return Status::OK();
   }
 
diff --git 
a/java/flight/src/main/java/org/apache/arrow/flight/auth/ClientAuthWrapper.java 
b/java/flight/src/main/java/org/apache/arrow/flight/auth/ClientAuthWrapper.java
index f916c92..9b80340 100644
--- 
a/java/flight/src/main/java/org/apache/arrow/flight/auth/ClientAuthWrapper.java
+++ 
b/java/flight/src/main/java/org/apache/arrow/flight/auth/ClientAuthWrapper.java
@@ -18,6 +18,9 @@
 package org.apache.arrow.flight.auth;
 
 import java.util.Iterator;
+import java.util.Objects;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
 import java.util.concurrent.LinkedBlockingQueue;
 
 import org.apache.arrow.flight.auth.ClientAuthHandler.ClientAuthSender;
@@ -25,9 +28,9 @@ import org.apache.arrow.flight.impl.Flight.HandshakeRequest;
 import org.apache.arrow.flight.impl.Flight.HandshakeResponse;
 import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceStub;
 
-import com.google.common.base.Throwables;
 import com.google.protobuf.ByteString;
 
+import io.grpc.StatusRuntimeException;
 import io.grpc.stub.StreamObserver;
 
 /**
@@ -45,7 +48,17 @@ public class ClientAuthWrapper {
     AuthObserver observer = new AuthObserver();
     observer.responseObserver = stub.handshake(observer);
     authHandler.authenticate(observer.sender, observer.iter);
-    observer.responseObserver.onCompleted();
+    if (!observer.sender.errored) {
+      observer.responseObserver.onCompleted();
+    }
+    try {
+      if (!observer.completed.get()) {
+        // TODO: ARROW-5681
+        throw new RuntimeException("Unauthenticated");
+      }
+    } catch (InterruptedException | ExecutionException e) {
+      throw new RuntimeException(e);
+    }
   }
 
   private static class AuthObserver implements 
StreamObserver<HandshakeResponse> {
@@ -53,11 +66,11 @@ public class ClientAuthWrapper {
     private volatile StreamObserver<HandshakeRequest> responseObserver;
     private final LinkedBlockingQueue<byte[]> messages = new 
LinkedBlockingQueue<>();
     private final AuthSender sender = new AuthSender();
-    private volatile boolean completed = false;
-    private Throwable ex = null;
+    private CompletableFuture<Boolean> completed;
 
     public AuthObserver() {
       super();
+      completed = new CompletableFuture<>();
     }
 
     @Override
@@ -72,7 +85,7 @@ public class ClientAuthWrapper {
 
       @Override
       public byte[] next() {
-        while (ex == null && (!completed || !messages.isEmpty())) {
+        while (!completed.isDone() || !messages.isEmpty()) {
           byte[] bytes = messages.poll();
           if (bytes == null) {
             // busy wait.
@@ -82,8 +95,19 @@ public class ClientAuthWrapper {
           }
         }
 
-        if (ex != null) {
-          throw Throwables.propagate(ex);
+        if (completed.isCompletedExceptionally()) {
+          // Preserve prior exception behavior
+          // TODO: with ARROW-5681, throw an appropriate Flight exception if 
gRPC raised an exception
+          try {
+            completed.get();
+          } catch (InterruptedException e) {
+            throw new RuntimeException(e);
+          } catch (ExecutionException e) {
+            if (e.getCause() instanceof StatusRuntimeException) {
+              throw (StatusRuntimeException) e.getCause();
+            }
+            throw new RuntimeException(e);
+          }
         }
 
         throw new IllegalStateException("You attempted to retrieve messages 
after there were none.");
@@ -97,11 +121,13 @@ public class ClientAuthWrapper {
 
     @Override
     public void onError(Throwable t) {
-      ex = t;
+      completed.completeExceptionally(t);
     }
 
     private class AuthSender implements ClientAuthSender {
 
+      private boolean errored = false;
+
       @Override
       public void send(byte[] payload) {
         responseObserver.onNext(HandshakeRequest.newBuilder()
@@ -111,6 +137,8 @@ public class ClientAuthWrapper {
 
       @Override
       public void onError(String message, Throwable cause) {
+        this.errored = true;
+        Objects.requireNonNull(cause);
         responseObserver.onError(cause);
       }
 
@@ -118,7 +146,7 @@ public class ClientAuthWrapper {
 
     @Override
     public void onCompleted() {
-      completed = true;
+      completed.complete(true);
     }
   }
 
diff --git 
a/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthHandler.java 
b/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthHandler.java
index a19126b..0507d3b 100644
--- 
a/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthHandler.java
+++ 
b/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthHandler.java
@@ -36,6 +36,8 @@ public interface ServerAuthHandler {
   /**
    * Handle the initial handshake with the client.
    *
+   * @param outgoing A writer to send messages to the client.
+   * @param incoming An iterator of messages from the client.
    * @return true if client is authenticated, false otherwise.
    */
   boolean authenticate(ServerAuthSender outgoing, Iterator<byte[]> incoming);
diff --git 
a/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthWrapper.java 
b/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthWrapper.java
index f0c5dae..a3c698b 100644
--- 
a/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthWrapper.java
+++ 
b/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthWrapper.java
@@ -58,6 +58,7 @@ public class ServerAuthWrapper {
 
         responseObserver.onError(Status.PERMISSION_DENIED.asException());
       } catch (Exception ex) {
+        ex.printStackTrace();
         responseObserver.onError(ex);
       }
     };
@@ -109,6 +110,7 @@ public class ServerAuthWrapper {
 
     @Override
     public void onError(Throwable t) {
+      completed = true;
       while (future == null) {/* busy wait */}
       future.cancel(true);
     }
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestAuth.java 
b/java/flight/src/test/java/org/apache/arrow/flight/TestAuth.java
new file mode 100644
index 0000000..bfaf660
--- /dev/null
+++ b/java/flight/src/test/java/org/apache/arrow/flight/TestAuth.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.util.Iterator;
+import java.util.Optional;
+
+import org.apache.arrow.flight.auth.ClientAuthHandler;
+import org.apache.arrow.flight.auth.ServerAuthHandler;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+
+import org.junit.Test;
+
+public class TestAuth {
+
+  /** An auth handler that does not send messages should not block the server 
forever. */
+  @Test(expected = RuntimeException.class)
+  public void noMessages() throws Exception {
+    try (final BufferAllocator allocator = new 
RootAllocator(Integer.MAX_VALUE);
+        final FlightServer s = FlightTestUtil
+            .getStartedServer(
+                location -> FlightServer.builder(allocator, location, new 
NoOpFlightProducer()).authHandler(
+                    new OneshotAuthHandler()).build());
+        final FlightClient client = FlightClient.builder(allocator, 
s.getLocation()).build()) {
+      client.authenticate(new ClientAuthHandler() {
+        @Override
+        public void authenticate(ClientAuthSender outgoing, Iterator<byte[]> 
incoming) {
+        }
+
+        @Override
+        public byte[] getCallToken() {
+          return new byte[0];
+        }
+      });
+    }
+  }
+
+  /** An auth handler that sends an error should not block the server forever. 
*/
+  @Test(expected = RuntimeException.class)
+  public void clientError() throws Exception {
+    try (final BufferAllocator allocator = new 
RootAllocator(Integer.MAX_VALUE);
+        final FlightServer s = FlightTestUtil
+            .getStartedServer(
+                location -> FlightServer.builder(allocator, location, new 
NoOpFlightProducer()).authHandler(
+                    new OneshotAuthHandler()).build());
+        final FlightClient client = FlightClient.builder(allocator, 
s.getLocation()).build()) {
+      client.authenticate(new ClientAuthHandler() {
+        @Override
+        public void authenticate(ClientAuthSender outgoing, Iterator<byte[]> 
incoming) {
+          outgoing.send(new byte[0]);
+          // Ensure the server-side runs
+          incoming.next();
+          outgoing.onError("test", new RuntimeException("test"));
+        }
+
+        @Override
+        public byte[] getCallToken() {
+          return new byte[0];
+        }
+      });
+    }
+  }
+
+  private static class OneshotAuthHandler implements ServerAuthHandler {
+
+    @Override
+    public Optional<String> isValid(byte[] token) {
+      return Optional.of("test");
+    }
+
+    @Override
+    public boolean authenticate(ServerAuthSender outgoing, Iterator<byte[]> 
incoming) {
+      incoming.next();
+      outgoing.send(new byte[0]);
+      return false;
+    }
+  }
+}
diff --git 
a/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java 
b/java/flight/src/test/java/org/apache/arrow/flight/auth/TestBasicAuth.java
similarity index 99%
rename from java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java
rename to 
java/flight/src/test/java/org/apache/arrow/flight/auth/TestBasicAuth.java
index 54bbadb..9fe6b04 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/auth/TestAuth.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/auth/TestBasicAuth.java
@@ -48,7 +48,7 @@ import com.google.common.collect.ImmutableList;
 
 import io.grpc.StatusRuntimeException;
 
-public class TestAuth {
+public class TestBasicAuth {
   final String PERMISSION_DENIED = "PERMISSION_DENIED";
 
   private static final String USERNAME = "flight";

Reply via email to