This is an automated email from the ASF dual-hosted git repository.

lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 0262ee53c60 [BEAM-13519] Solve race issues when the server responds 
with an error before the GrpcStateClient finishes being constructed. (#17240)
0262ee53c60 is described below

commit 0262ee53c6018d929a8a40fdf66735cc7e934951
Author: Luke Cwik <[email protected]>
AuthorDate: Mon Apr 4 14:32:41 2022 -0700

    [BEAM-13519] Solve race issues when the server responds with an error 
before the GrpcStateClient finishes being constructed. (#17240)
    
    * [BEAM-13519] Solve race issues when the server responds with an error
    before the GrpcStateClient finishes being constructed.
    
    The issue was that the InboundObserver can be invoked before
    outboundObserverFactory#outboundObserverFor returns, meaning that closeAndCleanUp
    may call cache.remove while cache.computeIfAbsent is still constructing the
    GrpcStateClient for the same ApiServiceDescriptor.
    
    Another issue was that the outstandingRequests map could be updated with
    another request within GrpcStateClient during closeAndCleanUp, meaning that the
    CompletableFuture for that request would never be completed exceptionally.
    
    The tests now pass 1000 times locally without getting stuck or failing.
---
 .../harness/state/BeamFnStateGrpcClientCache.java  | 105 ++++++++++++++-------
 .../state/BeamFnStateGrpcClientCacheTest.java      |  83 ++++++++--------
 2 files changed, 117 insertions(+), 71 deletions(-)

diff --git 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BeamFnStateGrpcClientCache.java
 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BeamFnStateGrpcClientCache.java
index d028ef61d45..e272a98902a 100644
--- 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BeamFnStateGrpcClientCache.java
+++ 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BeamFnStateGrpcClientCache.java
@@ -18,10 +18,9 @@
 package org.apache.beam.fn.harness.state;
 
 import java.io.IOException;
+import java.util.HashMap;
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.ConcurrentMap;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse;
 import org.apache.beam.model.fnexecution.v1.BeamFnStateGrpc;
@@ -45,7 +44,7 @@ import org.slf4j.LoggerFactory;
 public class BeamFnStateGrpcClientCache {
   private static final Logger LOG = 
LoggerFactory.getLogger(BeamFnStateGrpcClientCache.class);
 
-  private final ConcurrentMap<ApiServiceDescriptor, BeamFnStateClient> cache;
+  private final Map<ApiServiceDescriptor, BeamFnStateClient> cache;
   private final ManagedChannelFactory channelFactory;
   private final OutboundObserverFactory outboundObserverFactory;
   private final IdGenerator idGenerator;
@@ -59,7 +58,7 @@ public class BeamFnStateGrpcClientCache {
     // This showed a 1-2% improvement in the ProcessBundleBenchmark#testState* 
benchmarks.
     this.channelFactory = channelFactory.withDirectExecutor();
     this.outboundObserverFactory = outboundObserverFactory;
-    this.cache = new ConcurrentHashMap<>();
+    this.cache = new HashMap<>();
   }
 
   /**
@@ -67,30 +66,53 @@ public class BeamFnStateGrpcClientCache {
    * {@link ApiServiceDescriptor} currently has a {@link BeamFnStateClient} 
bound to the same
    * channel.
    */
-  public BeamFnStateClient forApiServiceDescriptor(ApiServiceDescriptor 
apiServiceDescriptor)
-      throws IOException {
-    return cache.computeIfAbsent(apiServiceDescriptor, 
this::createBeamFnStateClient);
-  }
-
-  private BeamFnStateClient createBeamFnStateClient(ApiServiceDescriptor 
apiServiceDescriptor) {
-    return new GrpcStateClient(apiServiceDescriptor);
+  public synchronized BeamFnStateClient forApiServiceDescriptor(
+      ApiServiceDescriptor apiServiceDescriptor) throws IOException {
+    // This method is synchronized so that only one GrpcStateClient is created at a time,
+    // preventing a race where multiple GrpcStateClient objects might be constructed at the
+    // same time for the same ApiServiceDescriptor.
+    BeamFnStateClient rval;
+    synchronized (cache) {
+      rval = cache.get(apiServiceDescriptor);
+    }
+    if (rval == null) {
+      // We can't be synchronized on cache while constructing the 
GrpcStateClient since if the
+      // connection fails, onError may be invoked from the gRPC thread which 
will invoke
+      // closeAndCleanUp that clears the cache.
+      rval = new GrpcStateClient(apiServiceDescriptor);
+      synchronized (cache) {
+        cache.put(apiServiceDescriptor, rval);
+      }
+    }
+    return rval;
   }
 
   /** A {@link BeamFnStateClient} for a given {@link ApiServiceDescriptor}. */
   private class GrpcStateClient implements BeamFnStateClient {
+    private final Object lock = new Object();
     private final ApiServiceDescriptor apiServiceDescriptor;
-    private final ConcurrentMap<String, CompletableFuture<StateResponse>> 
outstandingRequests;
+    private final Map<String, CompletableFuture<StateResponse>> 
outstandingRequests;
     private final StreamObserver<StateRequest> outboundObserver;
     private final ManagedChannel channel;
-    private volatile RuntimeException closed;
+    private RuntimeException closed;
+    private boolean errorDuringConstruction;
 
     private GrpcStateClient(ApiServiceDescriptor apiServiceDescriptor) {
       this.apiServiceDescriptor = apiServiceDescriptor;
-      this.outstandingRequests = new ConcurrentHashMap<>();
+      this.outstandingRequests = new HashMap<>();
       this.channel = channelFactory.forDescriptor(apiServiceDescriptor);
+      this.errorDuringConstruction = false;
       this.outboundObserver =
           outboundObserverFactory.outboundObserverFor(
               BeamFnStateGrpc.newStub(channel)::state, new InboundObserver());
+      // The InboundObserver may call closeAndCleanUp before this constructor
+      // completes, i.e. before outboundObserver has been assigned. In that case
+      // closeAndCleanUp only sets errorDuringConstruction, so we invoke
+      // onCompleted here instead.
+      synchronized (lock) {
+        if (errorDuringConstruction) {
+          outboundObserver.onCompleted();
+        }
+      }
     }
 
     @Override
@@ -98,7 +120,13 @@ public class BeamFnStateGrpcClientCache {
       requestBuilder.setId(idGenerator.getId());
       StateRequest request = requestBuilder.build();
       CompletableFuture<StateResponse> response = new CompletableFuture<>();
-      outstandingRequests.put(request.getId(), response);
+      synchronized (lock) {
+        if (closed != null) {
+          response.completeExceptionally(closed);
+          return response;
+        }
+        outstandingRequests.put(request.getId(), response);
+      }
 
       // If the server closes, gRPC will throw an error if onNext is called.
       LOG.debug("Sending StateRequest {}", request);
@@ -106,27 +134,33 @@ public class BeamFnStateGrpcClientCache {
       return response;
     }
 
-    private synchronized void closeAndCleanUp(RuntimeException cause) {
-      if (closed != null) {
-        return;
-      }
-      cache.remove(apiServiceDescriptor);
-      closed = cause;
-
-      // Make a copy of the map to make the view of the outstanding requests 
consistent.
-      Map<String, CompletableFuture<StateResponse>> outstandingRequestsCopy =
-          new ConcurrentHashMap<>(outstandingRequests);
+    private void closeAndCleanUp(RuntimeException cause) {
+      synchronized (lock) {
+        if (closed != null) {
+          return;
+        }
+        closed = cause;
 
-      if (outstandingRequestsCopy.isEmpty()) {
-        outboundObserver.onCompleted();
-        return;
-      }
+        synchronized (cache) {
+          cache.remove(apiServiceDescriptor);
+        }
 
-      outstandingRequests.clear();
-      LOG.error("BeamFnState failed, clearing outstanding requests {}", 
outstandingRequestsCopy);
+        if (!outstandingRequests.isEmpty()) {
+          LOG.error("BeamFnState failed, clearing outstanding requests {}", 
outstandingRequests);
+          for (CompletableFuture<StateResponse> entry : 
outstandingRequests.values()) {
+            entry.completeExceptionally(cause);
+          }
+          outstandingRequests.clear();
+        }
 
-      for (CompletableFuture<StateResponse> entry : 
outstandingRequestsCopy.values()) {
-        entry.completeExceptionally(cause);
+        // outboundObserver may still be null here because the InboundObserver can
+        // call closeAndCleanUp before the GrpcStateClient constructor has assigned it.
+        // In that case we defer invoking onCompleted to the GrpcStateClient constructor.
+        if (outboundObserver == null) {
+          errorDuringConstruction = true;
+        } else {
+          outboundObserver.onCompleted();
+        }
       }
     }
 
@@ -143,7 +177,10 @@ public class BeamFnStateGrpcClientCache {
       @Override
       public void onNext(StateResponse value) {
         LOG.debug("Received StateResponse {}", value);
-        CompletableFuture<StateResponse> responseFuture = 
outstandingRequests.remove(value.getId());
+        CompletableFuture<StateResponse> responseFuture;
+        synchronized (lock) {
+          responseFuture = outstandingRequests.remove(value.getId());
+        }
         if (responseFuture == null) {
           LOG.warn("Dropped unknown StateResponse {}", value);
           return;
diff --git 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/BeamFnStateGrpcClientCacheTest.java
 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/BeamFnStateGrpcClientCacheTest.java
index 1615a59cb9a..a729755fc12 100644
--- 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/BeamFnStateGrpcClientCacheTest.java
+++ 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/BeamFnStateGrpcClientCacheTest.java
@@ -28,14 +28,19 @@ import java.util.UUID;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
 import java.util.concurrent.LinkedBlockingQueue;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse;
 import org.apache.beam.model.fnexecution.v1.BeamFnStateGrpc;
+import 
org.apache.beam.model.fnexecution.v1.BeamFnStateGrpc.BeamFnStateImplBase;
 import org.apache.beam.model.pipeline.v1.Endpoints;
 import org.apache.beam.sdk.fn.IdGenerators;
 import org.apache.beam.sdk.fn.channel.ManagedChannelFactory;
 import org.apache.beam.sdk.fn.stream.OutboundObserverFactory;
+import org.apache.beam.sdk.fn.test.TestExecutors;
+import org.apache.beam.sdk.fn.test.TestExecutors.TestExecutorService;
 import org.apache.beam.sdk.fn.test.TestStreams;
 import org.apache.beam.vendor.grpc.v1p43p2.io.grpc.Server;
 import org.apache.beam.vendor.grpc.v1p43p2.io.grpc.Status;
@@ -46,7 +51,7 @@ import 
org.apache.beam.vendor.grpc.v1p43p2.io.grpc.stub.StreamObserver;
 import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.Uninterruptibles;
 import org.junit.After;
 import org.junit.Before;
-import org.junit.Ignore;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -59,6 +64,8 @@ public class BeamFnStateGrpcClientCacheTest {
   private static final String TEST_ERROR = "TEST ERROR";
   private static final String SERVER_ERROR = "SERVER ERROR";
 
+  @Rule public TestExecutorService executor = 
TestExecutors.from(Executors::newCachedThreadPool);
+
   private Endpoints.ApiServiceDescriptor apiServiceDescriptor;
   private Server testServer;
   private BeamFnStateGrpcClientCache clientCache;
@@ -103,7 +110,6 @@ public class BeamFnStateGrpcClientCacheTest {
   }
 
   @Test
-  @Ignore("(BEAM-13519) Java precommit timing out")
   public void testCachingOfClient() throws Exception {
     Endpoints.ApiServiceDescriptor otherApiServiceDescriptor =
         Endpoints.ApiServiceDescriptor.newBuilder()
@@ -112,18 +118,17 @@ public class BeamFnStateGrpcClientCacheTest {
     Server testServer2 =
         InProcessServerBuilder.forName(otherApiServiceDescriptor.getUrl())
             .addService(
-                new BeamFnStateGrpc.BeamFnStateImplBase() {
+                new BeamFnStateImplBase() {
                   @Override
                   public StreamObserver<StateRequest> state(
                       StreamObserver<StateResponse> outboundObserver) {
-                    throw new IllegalStateException("Unexpected in test.");
+                    throw new RuntimeException();
                   }
                 })
             .build();
     testServer2.start();
 
     try {
-
       assertSame(
           clientCache.forApiServiceDescriptor(apiServiceDescriptor),
           clientCache.forApiServiceDescriptor(apiServiceDescriptor));
@@ -164,25 +169,27 @@ public class BeamFnStateGrpcClientCacheTest {
   }
 
   @Test
+  // The checker erroneously flags that the CompletableFuture is not being resolved since it is
+  // the result of Executor#submit.
+  @SuppressWarnings("FutureReturnValueIgnored")
   public void testServerErrorCausesPendingAndFutureCallsToFail() throws 
Exception {
     BeamFnStateClient client = 
clientCache.forApiServiceDescriptor(apiServiceDescriptor);
 
-    CompletableFuture<StateResponse> inflight =
-        client.handle(StateRequest.newBuilder().setInstructionId(SUCCESS));
-
-    // Wait for the client to connect.
-    StreamObserver<StateResponse> outboundServerObserver = 
outboundServerObservers.take();
-    // Send an error from the server.
-    outboundServerObserver.onError(
-        new 
StatusRuntimeException(Status.INTERNAL.withDescription(SERVER_ERROR)));
-
-    try {
-      inflight.get();
-      fail("Expected unsuccessful response due to server error");
-    } catch (ExecutionException e) {
-      assertThat(e.toString(), containsString(SERVER_ERROR));
-    }
-
+    Future<CompletableFuture<StateResponse>> stateResponse =
+        executor.submit(() -> 
client.handle(StateRequest.newBuilder().setInstructionId(SUCCESS)));
+    Future<Void> serverResponse =
+        executor.submit(
+            () -> {
+              // Wait for the client to connect.
+              StreamObserver<StateResponse> outboundServerObserver = 
outboundServerObservers.take();
+              // Send an error from the server.
+              outboundServerObserver.onError(
+                  new 
StatusRuntimeException(Status.INTERNAL.withDescription(SERVER_ERROR)));
+              return null;
+            });
+
+    CompletableFuture<StateResponse> inflight = stateResponse.get();
+    serverResponse.get();
     try {
       inflight.get();
       fail("Expected unsuccessful response due to server error");
@@ -192,27 +199,29 @@ public class BeamFnStateGrpcClientCacheTest {
   }
 
   @Test
+  // The checker erroneously flags that the CompletableFuture is not being resolved since it is
+  // the result of Executor#submit.
+  @SuppressWarnings("FutureReturnValueIgnored")
   public void testServerCompletionCausesPendingAndFutureCallsToFail() throws 
Exception {
     BeamFnStateClient client = 
clientCache.forApiServiceDescriptor(apiServiceDescriptor);
 
-    CompletableFuture<StateResponse> inflight =
-        client.handle(StateRequest.newBuilder().setInstructionId(SUCCESS));
-
-    // Wait for the client to connect.
-    StreamObserver<StateResponse> outboundServerObserver = 
outboundServerObservers.take();
-    // Send that the server is done.
-    outboundServerObserver.onCompleted();
-
+    Future<CompletableFuture<StateResponse>> stateResponse =
+        executor.submit(() -> 
client.handle(StateRequest.newBuilder().setInstructionId(SUCCESS)));
+    Future<Void> serverResponse =
+        executor.submit(
+            () -> {
+              // Wait for the client to connect.
+              StreamObserver<StateResponse> outboundServerObserver = 
outboundServerObservers.take();
+              // Send that the server is done.
+              outboundServerObserver.onCompleted();
+              return null;
+            });
+
+    CompletableFuture<StateResponse> inflight = stateResponse.get();
+    serverResponse.get();
     try {
       inflight.get();
-      fail("Expected unsuccessful response due to server completion");
-    } catch (ExecutionException e) {
-      assertThat(e.toString(), containsString("Server hanged up"));
-    }
-
-    try {
-      inflight.get();
-      fail("Expected unsuccessful response due to server completion");
+      fail("Expected unsuccessful response due to server error");
     } catch (ExecutionException e) {
       assertThat(e.toString(), containsString("Server hanged up"));
     }

Reply via email to