This is an automated email from the ASF dual-hosted git repository.
popduke pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/bifromq.git
The following commit(s) were added to refs/heads/main by this push:
new 8de92fea8 Fixed potential CME thrown from signal fetch (#232)
8de92fea8 is described below
commit 8de92fea8dff5bf41a0c2efe3a3915f973742e27
Author: liaodongnian <[email protected]>
AuthorDate: Sat Mar 7 09:13:44 2026 +0800
Fixed potential CME thrown from signal fetch (#232)
---
.../bifromq/inbox/server/InboxFetchPipeline.java | 28 ++++--
.../server/InboxFetchPipelineMappingTest.java | 103 +++++++++++++++++++++
2 files changed, 124 insertions(+), 7 deletions(-)
diff --git
a/bifromq-inbox/bifromq-inbox-server/src/main/java/org/apache/bifromq/inbox/server/InboxFetchPipeline.java
b/bifromq-inbox/bifromq-inbox-server/src/main/java/org/apache/bifromq/inbox/server/InboxFetchPipeline.java
index f06e56040..7259742c9 100644
---
a/bifromq-inbox/bifromq-inbox-server/src/main/java/org/apache/bifromq/inbox/server/InboxFetchPipeline.java
+++
b/bifromq-inbox/bifromq-inbox-server/src/main/java/org/apache/bifromq/inbox/server/InboxFetchPipeline.java
@@ -25,8 +25,7 @@ import static
org.apache.bifromq.inbox.util.PipelineUtil.PIPELINE_ATTR_KEY_ID;
import io.grpc.stub.StreamObserver;
import io.reactivex.rxjava3.disposables.Disposable;
-import java.util.Collections;
-import java.util.HashSet;
+import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
@@ -90,7 +89,7 @@ final class InboxFetchPipeline extends
AckStream<InboxFetchHint, InboxFetched> i
fetchHint.getSessionId());
inboxSessionMap.computeIfAbsent(
new InboxId(fetchHint.getInboxId(),
fetchHint.getIncarnation()),
- k1 -> new
HashSet<>()).add(fetchHint.getSessionId());
+ k1 ->
ConcurrentHashMap.newKeySet()).add(fetchHint.getSessionId());
}
v.lastFetchQoS0Seq.set(
Math.max(fetchHint.getLastFetchQoS0Seq(),
v.lastFetchQoS0Seq.get()));
@@ -131,16 +130,31 @@ final class InboxFetchPipeline extends
AckStream<InboxFetchHint, InboxFetched> i
public boolean signalFetch(String inboxId, long incarnation, long now) {
log.trace("Signal fetch: tenantId={}, inboxId={}", tenantId, inboxId);
// signal fetch won't refresh expiry
- Set<Long> sessionIds = inboxSessionMap.getOrDefault(new
InboxId(inboxId, incarnation), Collections.emptySet());
- for (Long sessionId : sessionIds) {
+ InboxId inboxKey = new InboxId(inboxId, incarnation);
+ Set<Long> sessionIds = inboxSessionMap.get(inboxKey);
+ // No sessions mapped for this inbox/incarnation: nothing to signal.
+ if (sessionIds == null || sessionIds.isEmpty()) {
+ return false;
+ }
+ boolean triggered = false;
+ // The per-inbox set is created with ConcurrentHashMap.newKeySet() on the hint
+ // path, so its iterator is weakly consistent (no ConcurrentModificationException
+ // under concurrent add/remove) and supports Iterator.remove().
+ Iterator<Long> itr = sessionIds.iterator();
+ while (itr.hasNext()) {
+ Long sessionId = itr.next();
FetchState fetchState = inboxFetchSessions.get(sessionId);
- if (fetchState != null && fetchState.signalFetchTS.get() < now) {
+ // A session id without a FetchState is stale; prune it from the set.
+ if (fetchState == null) {
+ itr.remove();
+ continue;
+ }
+ // Report true for any live session, even if the timestamp check below
+ // skips the fetch for it.
+ triggered = true;
+ if (fetchState.signalFetchTS.get() < now) {
fetchState.hasMore.set(true);
fetchState.signalFetchTS.set(now);
fetch(fetchState);
}
}
- return !sessionIds.isEmpty();
+ // remove(key, value) only unmaps the entry if it still points at this set.
+ // NOTE(review): the isEmpty() check plus remove() is not atomic with the hint
+ // path's computeIfAbsent(...).add(...); a session id added in between could end
+ // up in a set that is then unmapped. Consider an atomic compute-based
+ // registration/cleanup pair; confirm against the full class.
+ if (sessionIds.isEmpty()) {
+ inboxSessionMap.remove(inboxKey, sessionIds);
+ }
+ return triggered;
}
@Override
diff --git
a/bifromq-inbox/bifromq-inbox-server/src/test/java/org/apache/bifromq/inbox/server/InboxFetchPipelineMappingTest.java
b/bifromq-inbox/bifromq-inbox-server/src/test/java/org/apache/bifromq/inbox/server/InboxFetchPipelineMappingTest.java
index 071fab65c..cff671dfa 100644
---
a/bifromq-inbox/bifromq-inbox-server/src/test/java/org/apache/bifromq/inbox/server/InboxFetchPipelineMappingTest.java
+++
b/bifromq-inbox/bifromq-inbox-server/src/test/java/org/apache/bifromq/inbox/server/InboxFetchPipelineMappingTest.java
@@ -25,20 +25,27 @@ import static org.awaitility.Awaitility.await;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertNotNull;
+import static org.testng.Assert.assertNull;
import static org.testng.Assert.assertTrue;
import io.grpc.Context;
import io.grpc.stub.ServerCallStreamObserver;
import io.micrometer.core.instrument.Timer;
import io.micrometer.core.instrument.simple.SimpleMeterRegistry;
+import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
import lombok.SneakyThrows;
import org.apache.bifromq.baserpc.RPCContext;
import org.apache.bifromq.baserpc.metrics.IRPCMeter;
@@ -205,12 +212,96 @@ public class InboxFetchPipelineMappingTest {
pipeline.close();
}
+    @Test
+    public void shouldCleanStaleSessionIdWhenFetchStateMissing() throws Exception {
+        InboxFetcherRegistry registry = new InboxFetcherRegistry();
+        InboxFetchPipeline pipeline = new InboxFetchPipeline(responseObserver, noopFetcher(), registry);
+        try {
+            long sessionId = 4004L;
+            pipeline.onNext(hint(sessionId, 1));
+
+            // Drop the FetchState behind the pipeline's back so the session id stays
+            // mapped under the inbox key: the stale state signalFetch must prune.
+            Map<Long, ?> fetchSessions = fetchSessions(pipeline);
+            fetchSessions.remove(sessionId);
+
+            Map<?, Set<Long>> sessionMap = inboxSessionMap(pipeline);
+            assertEquals(sessionMap.size(), 1);
+            Set<Long> sessionIds = sessionMap.values().iterator().next();
+            assertTrue(sessionIds.contains(sessionId));
+
+            boolean signalled = pipeline.signalFetch(INBOX, INCARNATION, System.nanoTime());
+
+            // No live session remains: nothing is signalled and the emptied entry is dropped.
+            assertFalse(signalled);
+            assertTrue(sessionMap.isEmpty());
+        } finally {
+            // Release the pipeline like the sibling tests do, even if an assertion fails.
+            pipeline.close();
+        }
+    }
+
+    @Test
+    public void shouldNotThrowWhenSignalFetchConcurrentWithSessionRemoval() throws Exception {
+        InboxFetcherRegistry registry = new InboxFetcherRegistry();
+        CountingFetcher fetcher = new CountingFetcher();
+        InboxFetchPipeline pipeline = new InboxFetchPipeline(responseObserver, fetcher, registry);
+        try {
+            long sessionA = 5005L;
+            long sessionB = 6006L;
+
+            pipeline.onNext(hint(sessionA, 5));
+            pipeline.onNext(hint(sessionB, 5));
+
+            // Both workers wait on the latch so their loops start together and overlap.
+            CountDownLatch latch = new CountDownLatch(1);
+            AtomicReference<Throwable> error = new AtomicReference<>();
+
+            // Repeatedly iterates the inbox's session-id set via signalFetch.
+            Thread signalThread = new Thread(() -> {
+                try {
+                    latch.await();
+                    for (int i = 0; i < 500; i++) {
+                        pipeline.signalFetch(INBOX, INCARNATION, System.nanoTime());
+                    }
+                } catch (Throwable t) {
+                    error.compareAndSet(null, t);
+                }
+            });
+
+            // Churns sessionB: hint(sessionB, -1) is assumed to remove the session
+            // (per this test's name) and hint(sessionB, 5) re-adds it.
+            Thread removeThread = new Thread(() -> {
+                try {
+                    latch.await();
+                    for (int i = 0; i < 500; i++) {
+                        pipeline.onNext(hint(sessionB, -1));
+                        pipeline.onNext(hint(sessionB, 5));
+                    }
+                } catch (Throwable t) {
+                    error.compareAndSet(null, t);
+                }
+            });
+
+            signalThread.start();
+            removeThread.start();
+            latch.countDown();
+            // Bounded joins so a deadlock fails this test instead of hanging the suite.
+            long joinTimeoutMillis = 30_000L;
+            signalThread.join(joinTimeoutMillis);
+            removeThread.join(joinTimeoutMillis);
+            assertFalse(signalThread.isAlive(), "signalFetch worker did not finish");
+            assertFalse(removeThread.isAlive(), "session churn worker did not finish");
+
+            assertNull(error.get());
+            await().until(() -> fetcher.fetchCount.get() > 0);
+        } finally {
+            // Always release the pipeline, even when an assertion above fails.
+            pipeline.close();
+        }
+    }
+
+ // Returns the last element of `received`, read while holding its monitor
+ // (presumably responses are appended concurrently by the response observer;
+ // confirm). Assumes at least one response has been received.
private InboxFetched lastReceived() {
synchronized (received) {
return received.get(received.size() - 1);
}
}
+    /**
+     * Returns the pipeline's private {@code inboxFetchSessions} map via reflection.
+     * The live map is returned (not a copy), so mutating it changes pipeline state.
+     */
+    @SuppressWarnings("unchecked")
+    private Map<Long, ?> fetchSessions(InboxFetchPipeline pipeline) throws Exception {
+        Field sessionsField = InboxFetchPipeline.class.getDeclaredField("inboxFetchSessions");
+        sessionsField.setAccessible(true);
+        Object liveMap = sessionsField.get(pipeline);
+        return (Map<Long, ?>) liveMap;
+    }
+
+    /**
+     * Returns the pipeline's private {@code inboxSessionMap} (inbox key to session ids)
+     * via reflection. The live map is returned, so tests observe pipeline updates.
+     */
+    @SuppressWarnings("unchecked")
+    private Map<?, Set<Long>> inboxSessionMap(InboxFetchPipeline pipeline) throws Exception {
+        Field mappingField = InboxFetchPipeline.class.getDeclaredField("inboxSessionMap");
+        mappingField.setAccessible(true);
+        Object liveMap = mappingField.get(pipeline);
+        return (Map<?, Set<Long>>) liveMap;
+    }
+
private static class TestFetcher implements InboxFetchPipeline.Fetcher {
private final BlockingQueue<FetchRequest> requests = new
LinkedBlockingQueue<>();
private final BlockingQueue<CompletableFuture<Fetched>> responses =
new LinkedBlockingQueue<>();
@@ -237,4 +328,16 @@ public class InboxFetchPipelineMappingTest {
future.complete(fetched);
}
}
+
+    /** Fetcher stub that counts calls and immediately completes each one with {@code Result.OK}. */
+    private static class CountingFetcher implements InboxFetchPipeline.Fetcher {
+        private final AtomicInteger fetchCount = new AtomicInteger();
+
+        @Override
+        public CompletableFuture<Fetched> fetch(FetchRequest request) {
+            fetchCount.incrementAndGet();
+            Fetched okResult = Fetched.newBuilder()
+                .setResult(Fetched.Result.OK)
+                .build();
+            return CompletableFuture.completedFuture(okResult);
+        }
+    }
}