This is an automated email from the ASF dual-hosted git repository.
gujiaweijoe pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/bifromq.git
The following commit(s) were added to refs/heads/main by this push:
new eb9ed9f8 Fixed potential bytebuf leaks in edge cases (#164)
eb9ed9f8 is described below
commit eb9ed9f8674e2725d5e65362564b8e805fea5cdb
Author: Yonny(Yu) Hao <[email protected]>
AuthorDate: Thu Aug 7 11:50:36 2025 +0800
Fixed potential bytebuf leaks in edge cases (#164)
1. Release the received ByteBuf when the WebSocket upgrade fails
2. Release the received ByteBuf when the connection is rejected due to rate limiting
---
.../mqtt/handler/ConnectionRateLimitHandler.java | 33 ++++++++++++++++------
.../mqtt/handler/ws/WebSocketOnlyHandler.java | 2 ++
.../handler/ConnectionRateLimitHandlerTest.java | 22 +++++++++++++--
.../mqtt/handler/ws/WebSocketOnlyHandlerTest.java | 8 ++++--
4 files changed, 50 insertions(+), 15 deletions(-)
diff --git
a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/org/apache/bifromq/mqtt/handler/ConnectionRateLimitHandler.java
b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/org/apache/bifromq/mqtt/handler/ConnectionRateLimitHandler.java
index eb8845b1..e28d9f9b 100644
---
a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/org/apache/bifromq/mqtt/handler/ConnectionRateLimitHandler.java
+++
b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/org/apache/bifromq/mqtt/handler/ConnectionRateLimitHandler.java
@@ -21,30 +21,25 @@ package org.apache.bifromq.mqtt.handler;
import static
org.apache.bifromq.plugin.eventcollector.ThreadLocalEventPool.getLocal;
-import org.apache.bifromq.plugin.eventcollector.IEventCollector;
-import
org.apache.bifromq.plugin.eventcollector.mqttbroker.channelclosed.ChannelError;
import com.google.common.util.concurrent.RateLimiter;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
+import io.netty.util.ReferenceCountUtil;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import lombok.extern.slf4j.Slf4j;
+import org.apache.bifromq.plugin.eventcollector.IEventCollector;
+import
org.apache.bifromq.plugin.eventcollector.mqttbroker.channelclosed.ChannelError;
@Slf4j
@ChannelHandler.Sharable
public class ConnectionRateLimitHandler extends ChannelDuplexHandler {
- /**
- * Initialize the pipeline when the connection is accepted.
- */
- public interface ChannelPipelineInitializer {
- void initialize(ChannelPipeline pipeline);
- }
-
private final RateLimiter rateLimiter;
private final IEventCollector eventCollector;
private final ChannelPipelineInitializer initializer;
+ private boolean accepted = false;
public ConnectionRateLimitHandler(RateLimiter limiter,
IEventCollector eventCollector,
@@ -57,9 +52,13 @@ public class ConnectionRateLimitHandler extends
ChannelDuplexHandler {
@Override
public void channelActive(ChannelHandlerContext ctx) {
if (rateLimiter.tryAcquire()) {
+ accepted = true;
initializer.initialize(ctx.pipeline());
ctx.fireChannelActive();
+ // Remove this handler after the connection is accepted
+ ctx.pipeline().remove(this);
} else {
+ accepted = false;
log.debug("Connection dropped due to exceed limit");
eventCollector.report(getLocal(ChannelError.class)
.peerAddress(ChannelAttrs.socketAddress(ctx.channel()))
@@ -73,4 +72,20 @@ public class ConnectionRateLimitHandler extends
ChannelDuplexHandler {
}, ThreadLocalRandom.current().nextLong(100, 3000),
TimeUnit.MILLISECONDS);
}
}
+
+ @Override
+ public void channelRead(ChannelHandlerContext ctx, Object msg) {
+ if (!accepted) {
+ ReferenceCountUtil.release(msg);
+ return;
+ }
+ ctx.fireChannelRead(msg);
+ }
+
+ /**
+ * Initialize the pipeline when the connection is accepted.
+ */
+ public interface ChannelPipelineInitializer {
+ void initialize(ChannelPipeline pipeline);
+ }
}
diff --git
a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/org/apache/bifromq/mqtt/handler/ws/WebSocketOnlyHandler.java
b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/org/apache/bifromq/mqtt/handler/ws/WebSocketOnlyHandler.java
index 98838a78..fcd4a056 100644
---
a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/org/apache/bifromq/mqtt/handler/ws/WebSocketOnlyHandler.java
+++
b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/org/apache/bifromq/mqtt/handler/ws/WebSocketOnlyHandler.java
@@ -27,6 +27,7 @@ import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpResponseStatus;
+import io.netty.util.ReferenceCountUtil;
/**
* A simple handler that rejects all requests that are not WebSocket upgrade
requests.
@@ -46,6 +47,7 @@ public class WebSocketOnlyHandler extends
SimpleChannelInboundHandler<FullHttpRe
!req.headers().get(HttpHeaderNames.UPGRADE,
"").equalsIgnoreCase("websocket")) {
FullHttpResponse response =
new DefaultFullHttpResponse(req.protocolVersion(),
HttpResponseStatus.BAD_REQUEST);
+ ReferenceCountUtil.release(req);
ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
} else {
// Proceed with the pipeline setup for WebSocket.
diff --git
a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/org/apache/bifromq/mqtt/handler/ConnectionRateLimitHandlerTest.java
b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/org/apache/bifromq/mqtt/handler/ConnectionRateLimitHandlerTest.java
index 545bc067..05f618dc 100644
---
a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/org/apache/bifromq/mqtt/handler/ConnectionRateLimitHandlerTest.java
+++
b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/org/apache/bifromq/mqtt/handler/ConnectionRateLimitHandlerTest.java
@@ -14,7 +14,7 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
- * under the License.
+ * under the License.
*/
package org.apache.bifromq.mqtt.handler;
@@ -24,15 +24,18 @@ import static org.mockito.Mockito.any;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
+import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertTrue;
-import org.apache.bifromq.plugin.eventcollector.EventType;
-import org.apache.bifromq.plugin.eventcollector.IEventCollector;
import com.google.common.util.concurrent.RateLimiter;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.embedded.EmbeddedChannel;
import java.util.concurrent.TimeUnit;
+import org.apache.bifromq.plugin.eventcollector.EventType;
+import org.apache.bifromq.plugin.eventcollector.IEventCollector;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.testng.annotations.BeforeMethod;
@@ -64,6 +67,8 @@ public class ConnectionRateLimitHandlerTest {
verify(initializer).initialize(channel.pipeline());
assertTrue(channel.isActive());
+ // After initialization, the handler should be removed
+ assertFalse(channel.pipeline().toMap().containsValue(handler));
}
@Test
@@ -78,4 +83,15 @@ public class ConnectionRateLimitHandlerTest {
assertFalse(channel.isActive());
verify(eventCollector).report(argThat(e -> e.type() ==
EventType.CHANNEL_ERROR));
}
+
+ @Test
+ public void testRejectedConnectionReleasesInboundByteBuf() {
+ when(rateLimiter.tryAcquire()).thenReturn(false);
+ EmbeddedChannel channel = new EmbeddedChannel(handler);
+
+ ByteBuf buf = Unpooled.buffer();
+ assertTrue(buf.refCnt() > 0);
+ channel.writeInbound(buf);
+ assertEquals(buf.refCnt(), 0);
+ }
}
diff --git
a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/org/apache/bifromq/mqtt/handler/ws/WebSocketOnlyHandlerTest.java
b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/org/apache/bifromq/mqtt/handler/ws/WebSocketOnlyHandlerTest.java
index 9c55a6a5..70d86289 100644
---
a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/org/apache/bifromq/mqtt/handler/ws/WebSocketOnlyHandlerTest.java
+++
b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/org/apache/bifromq/mqtt/handler/ws/WebSocketOnlyHandlerTest.java
@@ -37,8 +37,8 @@ import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;
public class WebSocketOnlyHandlerTest {
- private EmbeddedChannel channel;
private final String websocketPath = "/mqtt";
+ private EmbeddedChannel channel;
@BeforeMethod
public void setUp() {
@@ -67,10 +67,12 @@ public class WebSocketOnlyHandlerTest {
HttpVersion.HTTP_1_1, HttpMethod.GET, "/wrongpath");
request.headers().set(HttpHeaderNames.UPGRADE, "websocket");
+ assertTrue(request.refCnt() > 0);
assertFalse(channel.writeInbound(request));
FullHttpResponse response = channel.readOutbound();
+ assertEquals(request.refCnt(), 0);
assertNotNull(response);
- assertEquals(HttpResponseStatus.BAD_REQUEST, response.status());
+ assertEquals(response.status(), HttpResponseStatus.BAD_REQUEST);
}
@Test
@@ -82,7 +84,7 @@ public class WebSocketOnlyHandlerTest {
assertFalse(channel.writeInbound(request));
FullHttpResponse response = channel.readOutbound();
assertNotNull(response);
- assertEquals(HttpResponseStatus.BAD_REQUEST, response.status());
+ assertEquals(response.status(), HttpResponseStatus.BAD_REQUEST);
}
@Test