This is an automated email from the ASF dual-hosted git repository.

mapohl pushed a commit to branch release-1.17
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 5a553c5ef47a9b777f077a20af96666e34fcc7f6
Author: Patrick Lucas <[email protected]>
AuthorDate: Wed Jul 12 14:56:22 2023 +0200

    [FLINK-32583][rest] Fix deadlock in RestClient
---
 .../org/apache/flink/runtime/rest/RestClient.java  |  53 ++++++++-
 .../apache/flink/runtime/rest/RestClientTest.java  | 124 +++++++++++++++++++++
 2 files changed, 175 insertions(+), 2 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestClient.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestClient.java
index 2ab8fdac5792..b4d7f7786a8e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestClient.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestClient.java
@@ -59,6 +59,8 @@ import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandler;
 import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext;
 import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInitializer;
 import org.apache.flink.shaded.netty4.io.netty.channel.ChannelOption;
+import 
org.apache.flink.shaded.netty4.io.netty.channel.DefaultSelectStrategyFactory;
+import org.apache.flink.shaded.netty4.io.netty.channel.SelectStrategyFactory;
 import 
org.apache.flink.shaded.netty4.io.netty.channel.SimpleChannelInboundHandler;
 import org.apache.flink.shaded.netty4.io.netty.channel.nio.NioEventLoopGroup;
 import org.apache.flink.shaded.netty4.io.netty.channel.socket.SocketChannel;
@@ -90,6 +92,7 @@ import java.io.File;
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.StringWriter;
+import java.nio.channels.spi.SelectorProvider;
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.util.ArrayList;
@@ -101,6 +104,7 @@ import java.util.List;
 import java.util.Optional;
 import java.util.ServiceLoader;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.Executor;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
@@ -123,10 +127,24 @@ public class RestClient implements AutoCloseableAsync {
 
     private final AtomicBoolean isRunning = new AtomicBoolean(true);
 
+    // Used to track unresolved request futures in case they need to be resolved when the client is
+    // closed
+    private final Collection<CompletableFuture<Channel>> 
responseChannelFutures =
+            ConcurrentHashMap.newKeySet();
+
     @VisibleForTesting List<OutboundChannelHandlerFactory> 
outboundChannelHandlerFactories;
 
     public RestClient(Configuration configuration, Executor executor)
             throws ConfigurationException {
+        this(configuration, executor, DefaultSelectStrategyFactory.INSTANCE);
+    }
+
+    @VisibleForTesting
+    RestClient(
+            Configuration configuration,
+            Executor executor,
+            SelectStrategyFactory selectStrategyFactory)
+            throws ConfigurationException {
         Preconditions.checkNotNull(configuration);
         this.executor = Preconditions.checkNotNull(executor);
         this.terminationFuture = new CompletableFuture<>();
@@ -200,8 +218,16 @@ public class RestClient implements AutoCloseableAsync {
                         }
                     }
                 };
+
+        // No NioEventLoopGroup constructor available that allows passing nThreads, threadFactory,
+        // and selectStrategyFactory without also passing a SelectorProvider, so mimicking its
+        // default value seen in other constructors
         NioEventLoopGroup group =
-                new NioEventLoopGroup(1, new 
ExecutorThreadFactory("flink-rest-client-netty"));
+                new NioEventLoopGroup(
+                        1,
+                        new ExecutorThreadFactory("flink-rest-client-netty"),
+                        SelectorProvider.provider(),
+                        selectStrategyFactory);
 
         bootstrap = new Bootstrap();
         bootstrap
@@ -215,6 +241,11 @@ public class RestClient implements AutoCloseableAsync {
         LOG.debug("Rest client endpoint started.");
     }
 
+    @VisibleForTesting
+    Collection<CompletableFuture<Channel>> getResponseChannelFutures() {
+        return responseChannelFutures;
+    }
+
     @Override
     public CompletableFuture<Void> closeAsync() {
         return shutdownInternally(Time.seconds(10L));
@@ -243,6 +274,8 @@ public class RestClient implements AutoCloseableAsync {
                             .shutdownGracefully(0L, timeout.toMilliseconds(), 
TimeUnit.MILLISECONDS)
                             .addListener(
                                     finished -> {
+                                        notifyResponseFuturesOfShutdown();
+
                                         if (finished.isSuccess()) {
                                             terminationFuture.complete(null);
                                         } else {
@@ -256,6 +289,15 @@ public class RestClient implements AutoCloseableAsync {
         return terminationFuture;
     }
 
+    private void notifyResponseFuturesOfShutdown() {
+        responseChannelFutures.forEach(
+                future ->
+                        future.completeExceptionally(
+                                new IllegalStateException(
+                                        "RestClient closed before request 
completed")));
+        responseChannelFutures.clear();
+    }
+
     public <
                     M extends MessageHeaders<EmptyRequestBody, P, 
EmptyMessageParameters>,
                     P extends ResponseBody>
@@ -468,12 +510,19 @@ public class RestClient implements AutoCloseableAsync {
 
     private <P extends ResponseBody> CompletableFuture<P> submitRequest(
             String targetAddress, int targetPort, Request httpRequest, 
JavaType responseType) {
-        final ChannelFuture connectFuture = bootstrap.connect(targetAddress, 
targetPort);
+        if (!isRunning.get()) {
+            return FutureUtils.completedExceptionally(
+                    new IllegalStateException("RestClient is already closed"));
+        }
 
         final CompletableFuture<Channel> channelFuture = new 
CompletableFuture<>();
+        responseChannelFutures.add(channelFuture);
 
+        final ChannelFuture connectFuture = bootstrap.connect(targetAddress, 
targetPort);
         connectFuture.addListener(
                 (ChannelFuture future) -> {
+                    responseChannelFutures.remove(channelFuture);
+
                     if (future.isSuccess()) {
                         channelFuture.complete(future.channel());
                     } else {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/rest/RestClientTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/rest/RestClientTest.java
index cc99eeeca2ca..370dbacd2cfd 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/rest/RestClientTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/rest/RestClientTest.java
@@ -20,6 +20,7 @@ package org.apache.flink.runtime.rest;
 
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.core.testutils.OneShotLatch;
 import org.apache.flink.runtime.rest.messages.EmptyMessageParameters;
 import org.apache.flink.runtime.rest.messages.EmptyRequestBody;
 import org.apache.flink.runtime.rest.messages.EmptyResponseBody;
@@ -33,12 +34,17 @@ import org.apache.flink.util.TestLogger;
 import org.apache.flink.util.concurrent.Executors;
 import org.apache.flink.util.function.CheckedSupplier;
 
+import org.apache.flink.shaded.netty4.io.netty.channel.Channel;
 import org.apache.flink.shaded.netty4.io.netty.channel.ConnectTimeoutException;
+import 
org.apache.flink.shaded.netty4.io.netty.channel.DefaultSelectStrategyFactory;
+import org.apache.flink.shaded.netty4.io.netty.channel.SelectStrategy;
+import org.apache.flink.shaded.netty4.io.netty.channel.SelectStrategyFactory;
 import 
org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponseStatus;
 
 import org.junit.Assert;
 import org.junit.ClassRule;
 import org.junit.Test;
+import org.junit.function.ThrowingRunnable;
 
 import java.io.IOException;
 import java.net.ServerSocket;
@@ -51,8 +57,12 @@ import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 
 import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertThrows;
 
 /** Tests for {@link RestClient}. */
 public class RestClientTest extends TestLogger {
@@ -207,6 +217,120 @@ public class RestClientTest extends TestLogger {
         }
     }
 
+    /**
+     * Tests that the futures returned by {@link RestClient} fail immediately if the client is
+     * already closed.
+     *
+     * <p>See FLINK-32583
+     */
+    @Test
+    public void testCloseClientBeforeRequest() throws Exception {
+        try (final RestClient restClient =
+                new RestClient(new Configuration(), 
Executors.directExecutor())) {
+            restClient.close(); // Intentionally close the client prior to the 
request
+
+            CompletableFuture<?> future =
+                    restClient.sendRequest(
+                            unroutableIp,
+                            80,
+                            new TestMessageHeaders(),
+                            EmptyMessageParameters.getInstance(),
+                            EmptyRequestBody.getInstance());
+
+            // Call get() on the future with a timeout of 0s so we can test that the exception
+            // thrown is not a TimeoutException, which is what would be thrown if restClient were
+            // not already closed
+            final ThrowingRunnable getFuture = () -> future.get(0, 
TimeUnit.SECONDS);
+
+            final Throwable cause = assertThrows(ExecutionException.class, 
getFuture).getCause();
+            assertThat(cause, instanceOf(IllegalStateException.class));
+            assertThat(cause.getMessage(), equalTo("RestClient is already 
closed"));
+        }
+    }
+
+    @Test
+    public void testCloseClientWhileProcessingRequest() throws Exception {
+        // Set up a Netty SelectStrategy with latches that allow us to step forward through Netty's
+        // request state machine, closing the client at a particular moment
+        final OneShotLatch connectTriggered = new OneShotLatch();
+        final OneShotLatch closeTriggered = new OneShotLatch();
+        final SelectStrategy fallbackSelectStrategy =
+                DefaultSelectStrategyFactory.INSTANCE.newSelectStrategy();
+        final SelectStrategyFactory selectStrategyFactory =
+                () ->
+                        (selectSupplier, hasTasks) -> {
+                            connectTriggered.trigger();
+                            closeTriggered.await();
+
+                            return fallbackSelectStrategy.calculateStrategy(
+                                    selectSupplier, hasTasks);
+                        };
+
+        try (final RestClient restClient =
+                new RestClient(
+                        new Configuration(), Executors.directExecutor(), 
selectStrategyFactory)) {
+            // Check that client's internal collection of pending response futures is empty prior to
+            // the request
+            assertThat(restClient.getResponseChannelFutures(), empty());
+
+            final CompletableFuture<?> requestFuture =
+                    restClient.sendRequest(
+                            unroutableIp,
+                            80,
+                            new TestMessageHeaders(),
+                            EmptyMessageParameters.getInstance(),
+                            EmptyRequestBody.getInstance());
+
+            // Check that client's internal collection of pending response futures now has one
+            // entry, presumably due to the call to sendRequest
+            assertThat(restClient.getResponseChannelFutures(), hasSize(1));
+
+            // Wait for Netty to start connecting, then while it's paused in the SelectStrategy,
+            // close the client before unpausing Netty
+            connectTriggered.await();
+            final CompletableFuture<Void> closeFuture = 
restClient.closeAsync();
+            closeTriggered.trigger();
+
+            // Close should complete successfully
+            closeFuture.get();
+
+            final Throwable cause =
+                    assertThrows(
+                                    ExecutionException.class,
+                                    () -> requestFuture.get(0, 
TimeUnit.SECONDS))
+                            .getCause();
+            assertThat(cause, instanceOf(IllegalStateException.class));
+            assertThat(cause.getMessage(), equalTo("executor not accepting a 
task"));
+        }
+    }
+
+    @Test
+    public void testResponseChannelFuturesResolvedExceptionallyOnClose() 
throws Exception {
+        try (final RestClient restClient =
+                new RestClient(new Configuration(), 
Executors.directExecutor())) {
+            CompletableFuture<Channel> responseChannelFuture = new 
CompletableFuture<>();
+
+            // Add the future to the client's internal collection of pending response futures
+            restClient.getResponseChannelFutures().add(responseChannelFuture);
+
+            // Close the client, which should resolve all pending response futures exceptionally and
+            // clear the collection
+            restClient.close();
+
+            // Ensure the client's internal collection of pending response futures was cleared after
+            // close
+            assertThat(restClient.getResponseChannelFutures(), empty());
+
+            final Throwable cause =
+                    assertThrows(
+                                    ExecutionException.class,
+                                    () -> responseChannelFuture.get(0, 
TimeUnit.SECONDS))
+                            .getCause();
+            assertThat(cause, instanceOf(IllegalStateException.class));
+            assertThat(cause.getMessage(), equalTo("RestClient closed before 
request completed"));
+        }
+    }
+
     private static class TestMessageHeaders
             implements RuntimeMessageHeaders<
                     EmptyRequestBody, EmptyResponseBody, 
EmptyMessageParameters> {

Reply via email to