This is an automated email from the ASF dual-hosted git repository.

arvid pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new b662ed1  [FLINK-23182][connectors/rabbitmq] Fix connection leak in RMQSource
b662ed1 is described below

commit b662ed19406c8f552c916b6840c8d62fad64b77c
Author: Michal Ciesielczyk <michal.ciesielc...@deep.bi>
AuthorDate: Wed Jun 30 00:23:42 2021 +0200

    [FLINK-23182][connectors/rabbitmq] Fix connection leak in RMQSource
---
 .../streaming/connectors/rabbitmq/RMQSource.java   | 49 ++++++++++----------
 .../connectors/rabbitmq/RMQSourceTest.java         | 53 ++++++++++++++++++++++
 2 files changed, 76 insertions(+), 26 deletions(-)

diff --git a/flink-connectors/flink-connector-rabbitmq/src/main/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSource.java b/flink-connectors/flink-connector-rabbitmq/src/main/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSource.java
index 3273dfd..b3444f1 100644
--- a/flink-connectors/flink-connector-rabbitmq/src/main/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSource.java
+++ b/flink-connectors/flink-connector-rabbitmq/src/main/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSource.java
@@ -28,6 +28,8 @@ import org.apache.flink.streaming.api.functions.source.MessageAcknowledgingSourc
 import 
org.apache.flink.streaming.api.functions.source.MultipleIdsMessageAcknowledgingSourceBase;
 import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
 import 
org.apache.flink.streaming.connectors.rabbitmq.common.RMQConnectionConfig;
+import org.apache.flink.util.ExceptionUtils;
+import org.apache.flink.util.IOUtils;
 import org.apache.flink.util.Preconditions;
 
 import com.rabbitmq.client.AMQP;
@@ -257,6 +259,7 @@ public class RMQSource<OUT> extends MultipleIdsMessageAcknowledgingSourceBase<OU
             channel.basicConsume(queueName, autoAck, consumer);
 
         } catch (IOException e) {
+            IOUtils.closeAllQuietly(channel, connection);
             throw new RuntimeException(
                     "Cannot create RMQ connection with "
                             + queueName
@@ -273,44 +276,38 @@ public class RMQSource<OUT> extends MultipleIdsMessageAcknowledgingSourceBase<OU
     @Override
     public void close() throws Exception {
         super.close();
+        Exception exception = null;
 
         try {
             if (consumer != null && channel != null) {
                 channel.basicCancel(consumer.getConsumerTag());
             }
         } catch (IOException e) {
-            throw new RuntimeException(
-                    "Error while cancelling RMQ consumer on "
-                            + queueName
-                            + " at "
-                            + rmqConnectionConfig.getHost(),
-                    e);
+            exception =
+                    new RuntimeException(
+                            "Error while cancelling RMQ consumer on "
+                                    + queueName
+                                    + " at "
+                                    + rmqConnectionConfig.getHost(),
+                            e);
         }
 
         try {
-            if (channel != null) {
-                channel.close();
-            }
+            IOUtils.closeAll(channel, connection);
         } catch (IOException e) {
-            throw new RuntimeException(
-                    "Error while closing RMQ channel with "
-                            + queueName
-                            + " at "
-                            + rmqConnectionConfig.getHost(),
-                    e);
+            exception =
+                    ExceptionUtils.firstOrSuppressed(
+                            new RuntimeException(
+                                    "Error while closing RMQ source with "
+                                            + queueName
+                                            + " at "
+                                            + rmqConnectionConfig.getHost(),
+                                    e),
+                            exception);
         }
 
-        try {
-            if (connection != null) {
-                connection.close();
-            }
-        } catch (IOException e) {
-            throw new RuntimeException(
-                    "Error while closing RMQ connection with "
-                            + queueName
-                            + " at "
-                            + rmqConnectionConfig.getHost(),
-                    e);
+        if (exception != null) {
+            throw exception;
         }
     }
 
diff --git a/flink-connectors/flink-connector-rabbitmq/src/test/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSourceTest.java b/flink-connectors/flink-connector-rabbitmq/src/test/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSourceTest.java
index d6b4807..c30bc93 100644
--- a/flink-connectors/flink-connector-rabbitmq/src/test/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSourceTest.java
+++ b/flink-connectors/flink-connector-rabbitmq/src/test/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSourceTest.java
@@ -63,6 +63,7 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 import static org.mockito.ArgumentMatchers.any;
@@ -158,6 +159,25 @@ public class RMQSourceTest {
     }
 
     @Test
+    public void testResourceCleanupOnOpenFailure() throws Exception {
+        RMQConnectionConfig connectionConfig = 
Mockito.mock(RMQConnectionConfig.class);
+        ConnectionFactory connectionFactory = 
Mockito.mock(ConnectionFactory.class);
+        Connection connection = Mockito.mock(Connection.class);
+        
Mockito.when(connectionConfig.getConnectionFactory()).thenReturn(connectionFactory);
+        Mockito.when(connectionConfig.getHost()).thenReturn("hostDummy");
+        Mockito.when(connectionFactory.newConnection()).thenReturn(connection);
+        Mockito.when(connection.createChannel()).thenThrow(new IOException());
+
+        RMQSource<String> rmqSource =
+                new RMQSource<>(
+                        connectionConfig, "queueDummy", true, new 
StringDeserializationScheme());
+        RuntimeException ex =
+                assertThrows(RuntimeException.class, () -> rmqSource.open(new 
Configuration()));
+        assertEquals("Cannot create RMQ connection with queueDummy at 
hostDummy", ex.getMessage());
+        Mockito.verify(rmqSource.connection, Mockito.atLeastOnce()).close();
+    }
+
+    @Test
     public void testOpenCallDeclaresQueueInStandardMode() throws Exception {
         FunctionInitializationContext mockContext = getMockContext();
 
@@ -178,6 +198,39 @@ public class RMQSourceTest {
     }
 
     @Test
+    public void testResourceCleanupOnClose() throws Exception {
+        FunctionInitializationContext mockContext = getMockContext();
+
+        RMQConnectionConfig connectionConfig = 
Mockito.mock(RMQConnectionConfig.class);
+        ConnectionFactory connectionFactory = 
Mockito.mock(ConnectionFactory.class);
+        Connection connection = Mockito.mock(Connection.class);
+        Channel channel = Mockito.mock(Channel.class);
+
+        
Mockito.when(connectionConfig.getConnectionFactory()).thenReturn(connectionFactory);
+        Mockito.when(connectionFactory.newConnection()).thenReturn(connection);
+        Mockito.when(connectionConfig.getHost()).thenReturn("hostDummy");
+        Mockito.when(connection.createChannel()).thenReturn(channel);
+        Mockito.doThrow(new IOException("Consumer cancel 
error")).when(channel).basicCancel(any());
+        Mockito.doThrow(new IOException("Channel 
error")).when(channel).close();
+        Mockito.doThrow(new IOException("Connection 
error")).when(connection).close();
+
+        RMQSource<String> rmqSource = new 
RMQMockedRuntimeTestSource(connectionConfig);
+        rmqSource.initializeState(mockContext);
+        rmqSource.open(new Configuration());
+
+        Exception ex = assertThrows(RuntimeException.class, rmqSource::close);
+        assertEquals(
+                "Error while cancelling RMQ consumer on queueDummy at 
hostDummy", ex.getMessage());
+        assertEquals(1, ex.getSuppressed().length);
+        assertEquals(
+                "Error while closing RMQ source with queueDummy at hostDummy",
+                ex.getSuppressed()[0].getMessage());
+        Mockito.verify(rmqSource.channel, 
Mockito.atLeastOnce()).basicCancel(any());
+        Mockito.verify(rmqSource.channel, Mockito.atLeastOnce()).close();
+        Mockito.verify(rmqSource.connection, Mockito.atLeastOnce()).close();
+    }
+
+    @Test
     public void testCheckpointing() throws Exception {
         source.autoAck = false;
 

Reply via email to