This is an automated email from the ASF dual-hosted git repository.

gujiaweijoe pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/bifromq.git


The following commit(s) were added to refs/heads/main by this push:
     new eb9ed9f8 Fixed potential bytebuf leaks in edge cases (#164)
eb9ed9f8 is described below

commit eb9ed9f8674e2725d5e65362564b8e805fea5cdb
Author: Yonny(Yu) Hao <[email protected]>
AuthorDate: Thu Aug 7 11:50:36 2025 +0800

    Fixed potential bytebuf leaks in edge cases (#164)
    
    1. Release the received ByteBuf when a request is rejected because it is not a valid WebSocket upgrade request
    2. Release the received ByteBuf when a connection is rejected due to rate limiting
---
 .../mqtt/handler/ConnectionRateLimitHandler.java   | 33 ++++++++++++++++------
 .../mqtt/handler/ws/WebSocketOnlyHandler.java      |  2 ++
 .../handler/ConnectionRateLimitHandlerTest.java    | 22 +++++++++++++--
 .../mqtt/handler/ws/WebSocketOnlyHandlerTest.java  |  8 ++++--
 4 files changed, 50 insertions(+), 15 deletions(-)

diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/org/apache/bifromq/mqtt/handler/ConnectionRateLimitHandler.java b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/org/apache/bifromq/mqtt/handler/ConnectionRateLimitHandler.java
index eb8845b1..e28d9f9b 100644
--- a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/org/apache/bifromq/mqtt/handler/ConnectionRateLimitHandler.java
+++ b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/org/apache/bifromq/mqtt/handler/ConnectionRateLimitHandler.java
@@ -21,30 +21,25 @@ package org.apache.bifromq.mqtt.handler;
 
 import static 
org.apache.bifromq.plugin.eventcollector.ThreadLocalEventPool.getLocal;
 
-import org.apache.bifromq.plugin.eventcollector.IEventCollector;
-import 
org.apache.bifromq.plugin.eventcollector.mqttbroker.channelclosed.ChannelError;
 import com.google.common.util.concurrent.RateLimiter;
 import io.netty.channel.ChannelDuplexHandler;
 import io.netty.channel.ChannelHandler;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.ChannelPipeline;
+import io.netty.util.ReferenceCountUtil;
 import java.util.concurrent.ThreadLocalRandom;
 import java.util.concurrent.TimeUnit;
 import lombok.extern.slf4j.Slf4j;
+import org.apache.bifromq.plugin.eventcollector.IEventCollector;
+import 
org.apache.bifromq.plugin.eventcollector.mqttbroker.channelclosed.ChannelError;
 
 @Slf4j
 @ChannelHandler.Sharable
 public class ConnectionRateLimitHandler extends ChannelDuplexHandler {
-    /**
-     * Initialize the pipeline when the connection is accepted.
-     */
-    public interface ChannelPipelineInitializer {
-        void initialize(ChannelPipeline pipeline);
-    }
-
     private final RateLimiter rateLimiter;
     private final IEventCollector eventCollector;
     private final ChannelPipelineInitializer initializer;
+    private boolean accepted = false;
 
     public ConnectionRateLimitHandler(RateLimiter limiter,
                                       IEventCollector eventCollector,
@@ -57,9 +52,13 @@ public class ConnectionRateLimitHandler extends ChannelDuplexHandler {
     @Override
     public void channelActive(ChannelHandlerContext ctx) {
         if (rateLimiter.tryAcquire()) {
+            accepted = true;
             initializer.initialize(ctx.pipeline());
             ctx.fireChannelActive();
+            // Remove this handler after the connection is accepted
+            ctx.pipeline().remove(this);
         } else {
+            accepted = false;
             log.debug("Connection dropped due to exceed limit");
             eventCollector.report(getLocal(ChannelError.class)
                 .peerAddress(ChannelAttrs.socketAddress(ctx.channel()))
@@ -73,4 +72,20 @@ public class ConnectionRateLimitHandler extends ChannelDuplexHandler {
             }, ThreadLocalRandom.current().nextLong(100, 3000), 
TimeUnit.MILLISECONDS);
         }
     }
+
+    @Override
+    public void channelRead(ChannelHandlerContext ctx, Object msg) {
+        if (!accepted) {
+            ReferenceCountUtil.release(msg);
+            return;
+        }
+        ctx.fireChannelRead(msg);
+    }
+
+    /**
+     * Initialize the pipeline when the connection is accepted.
+     */
+    public interface ChannelPipelineInitializer {
+        void initialize(ChannelPipeline pipeline);
+    }
 }
diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/org/apache/bifromq/mqtt/handler/ws/WebSocketOnlyHandler.java b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/org/apache/bifromq/mqtt/handler/ws/WebSocketOnlyHandler.java
index 98838a78..fcd4a056 100644
--- a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/org/apache/bifromq/mqtt/handler/ws/WebSocketOnlyHandler.java
+++ b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/org/apache/bifromq/mqtt/handler/ws/WebSocketOnlyHandler.java
@@ -27,6 +27,7 @@ import io.netty.handler.codec.http.FullHttpRequest;
 import io.netty.handler.codec.http.FullHttpResponse;
 import io.netty.handler.codec.http.HttpHeaderNames;
 import io.netty.handler.codec.http.HttpResponseStatus;
+import io.netty.util.ReferenceCountUtil;
 
 /**
  * A simple handler that rejects all requests that are not WebSocket upgrade 
requests.
@@ -46,6 +47,7 @@ public class WebSocketOnlyHandler extends SimpleChannelInboundHandler<FullHttpRe
             !req.headers().get(HttpHeaderNames.UPGRADE, 
"").equalsIgnoreCase("websocket")) {
             FullHttpResponse response =
                 new DefaultFullHttpResponse(req.protocolVersion(), 
HttpResponseStatus.BAD_REQUEST);
+            ReferenceCountUtil.release(req);
             
ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
         } else {
             // Proceed with the pipeline setup for WebSocket.
diff --git 
a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/org/apache/bifromq/mqtt/handler/ConnectionRateLimitHandlerTest.java
 
b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/org/apache/bifromq/mqtt/handler/ConnectionRateLimitHandlerTest.java
index 545bc067..05f618dc 100644
--- a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/org/apache/bifromq/mqtt/handler/ConnectionRateLimitHandlerTest.java
+++ b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/org/apache/bifromq/mqtt/handler/ConnectionRateLimitHandlerTest.java
@@ -14,7 +14,7 @@
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  * KIND, either express or implied.  See the License for the
  * specific language governing permissions and limitations
- * under the License.    
+ * under the License.
  */
 
 package org.apache.bifromq.mqtt.handler;
@@ -24,15 +24,18 @@ import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
+import static org.testng.Assert.assertEquals;
 import static org.testng.Assert.assertFalse;
 import static org.testng.Assert.assertTrue;
 
-import org.apache.bifromq.plugin.eventcollector.EventType;
-import org.apache.bifromq.plugin.eventcollector.IEventCollector;
 import com.google.common.util.concurrent.RateLimiter;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
 import io.netty.channel.ChannelPipeline;
 import io.netty.channel.embedded.EmbeddedChannel;
 import java.util.concurrent.TimeUnit;
+import org.apache.bifromq.plugin.eventcollector.EventType;
+import org.apache.bifromq.plugin.eventcollector.IEventCollector;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 import org.testng.annotations.BeforeMethod;
@@ -64,6 +67,8 @@ public class ConnectionRateLimitHandlerTest {
 
         verify(initializer).initialize(channel.pipeline());
         assertTrue(channel.isActive());
+        // After initialization, the handler should be removed
+        assertFalse(channel.pipeline().toMap().containsValue(handler));
     }
 
     @Test
@@ -78,4 +83,15 @@ public class ConnectionRateLimitHandlerTest {
         assertFalse(channel.isActive());
         verify(eventCollector).report(argThat(e -> e.type() == 
EventType.CHANNEL_ERROR));
     }
+
+    @Test
+    public void testRejectedConnectionReleasesInboundByteBuf() {
+        when(rateLimiter.tryAcquire()).thenReturn(false);
+        EmbeddedChannel channel = new EmbeddedChannel(handler);
+
+        ByteBuf buf = Unpooled.buffer();
+        assertTrue(buf.refCnt() > 0);
+        channel.writeInbound(buf);
+        assertEquals(buf.refCnt(), 0);
+    }
 }
diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/org/apache/bifromq/mqtt/handler/ws/WebSocketOnlyHandlerTest.java b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/org/apache/bifromq/mqtt/handler/ws/WebSocketOnlyHandlerTest.java
index 9c55a6a5..70d86289 100644
--- a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/org/apache/bifromq/mqtt/handler/ws/WebSocketOnlyHandlerTest.java
+++ b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/org/apache/bifromq/mqtt/handler/ws/WebSocketOnlyHandlerTest.java
@@ -37,8 +37,8 @@ import org.testng.annotations.BeforeMethod;
 import org.testng.annotations.Test;
 
 public class WebSocketOnlyHandlerTest {
-    private EmbeddedChannel channel;
     private final String websocketPath = "/mqtt";
+    private EmbeddedChannel channel;
 
     @BeforeMethod
     public void setUp() {
@@ -67,10 +67,12 @@ public class WebSocketOnlyHandlerTest {
             HttpVersion.HTTP_1_1, HttpMethod.GET, "/wrongpath");
         request.headers().set(HttpHeaderNames.UPGRADE, "websocket");
 
+        assertTrue(request.refCnt() > 0);
         assertFalse(channel.writeInbound(request));
         FullHttpResponse response = channel.readOutbound();
+        assertEquals(request.refCnt(), 0);
         assertNotNull(response);
-        assertEquals(HttpResponseStatus.BAD_REQUEST, response.status());
+        assertEquals(response.status(), HttpResponseStatus.BAD_REQUEST);
     }
 
     @Test
@@ -82,7 +84,7 @@ public class WebSocketOnlyHandlerTest {
         assertFalse(channel.writeInbound(request));
         FullHttpResponse response = channel.readOutbound();
         assertNotNull(response);
-        assertEquals(HttpResponseStatus.BAD_REQUEST, response.status());
+        assertEquals(response.status(), HttpResponseStatus.BAD_REQUEST);
     }
 
     @Test

Reply via email to