This is an automated email from the ASF dual-hosted git repository.

yhu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 56dac7574cc Fix MqttIO read checkpoint logic (#36056)
56dac7574cc is described below

commit 56dac7574cca54ed6e183b430dbb4687e140828e
Author: Yi Hu <[email protected]>
AuthorDate: Mon Sep 8 11:08:25 2025 -0400

    Fix MqttIO read checkpoint logic (#36056)
    
    * Fix MqttIO read checkpoint logic
    
    * add tests
---
 .../java/org/apache/beam/sdk/io/mqtt/MqttIO.java   | 84 +++++++++++++++-------
 .../org/apache/beam/sdk/io/mqtt/MqttIOTest.java    | 59 ++++++++++++++-
 2 files changed, 115 insertions(+), 28 deletions(-)

diff --git a/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java b/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java
index efc51362d06..78876eb6534 100644
--- a/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java
+++ b/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java
@@ -422,20 +422,16 @@ public class MqttIO {
   static class MqttCheckpointMark implements UnboundedSource.CheckpointMark, 
Serializable {
 
     @VisibleForTesting String clientId;
-    @VisibleForTesting Instant oldestMessageTimestamp = Instant.now();
     @VisibleForTesting transient List<Message> messages = new ArrayList<>();
 
-    public MqttCheckpointMark() {}
-
-    public MqttCheckpointMark(String id) {
-      clientId = id;
+    public MqttCheckpointMark(String id, List<Message> messages) {
+      this.clientId = id;
+      this.messages = messages;
     }
 
-    public void add(Message message, Instant timestamp) {
-      if (timestamp.isBefore(oldestMessageTimestamp)) {
-        oldestMessageTimestamp = timestamp;
-      }
-      messages.add(message);
+    @VisibleForTesting
+    MqttCheckpointMark(String id) {
+      this.clientId = id;
     }
 
     @Override
@@ -448,7 +444,6 @@ public class MqttIO {
           LOG.warn("Can't ack message for client ID {}", clientId, e);
         }
       }
-      oldestMessageTimestamp = Instant.now();
       messages.clear();
     }
 
@@ -464,7 +459,6 @@ public class MqttIO {
       if (other instanceof MqttCheckpointMark) {
         MqttCheckpointMark that = (MqttCheckpointMark) other;
         return Objects.equals(this.clientId, that.clientId)
-            && Objects.equals(this.oldestMessageTimestamp, 
that.oldestMessageTimestamp)
             && Objects.deepEquals(this.messages, that.messages);
       } else {
         return false;
@@ -473,7 +467,38 @@ public class MqttIO {
 
     @Override
     public int hashCode() {
-      return Objects.hash(clientId, oldestMessageTimestamp, messages);
+      return Objects.hash(clientId, messages);
+    }
+
+    static class Preparer {
+      @VisibleForTesting String clientId;
+      @VisibleForTesting Instant oldestMessageTimestamp = Instant.now();
+      @VisibleForTesting transient List<Message> messages = new ArrayList<>();
+
+      public Preparer(MqttCheckpointMark checkpointMark) {
+        clientId = checkpointMark.clientId;
+        messages = checkpointMark.messages;
+      }
+
+      public Preparer(String id) {
+        clientId = id;
+      }
+
+      public Preparer() {}
+
+      public void add(Message message, Instant timestamp) {
+        if (timestamp.isBefore(oldestMessageTimestamp)) {
+          oldestMessageTimestamp = timestamp;
+        }
+        messages.add(message);
+      }
+
+      MqttCheckpointMark newCheckpoint() {
+        List<Message> currentMessages = messages;
+        messages = new ArrayList<>();
+        oldestMessageTimestamp = Instant.now();
+        return new MqttCheckpointMark(clientId, currentMessages);
+      }
     }
   }
 
@@ -489,16 +514,20 @@ public class MqttIO {
     @Override
     @SuppressWarnings("unchecked")
     public UnboundedReader<T> createReader(
-        PipelineOptions options, MqttCheckpointMark checkpointMark) {
+        PipelineOptions options, @Nullable MqttCheckpointMark checkpointMark) {
       final UnboundedMqttReader<T> unboundedMqttReader;
+      MqttCheckpointMark.Preparer preparer =
+          checkpointMark == null
+              ? new MqttCheckpointMark.Preparer()
+              : new MqttCheckpointMark.Preparer(checkpointMark);
       if (spec.withMetadata()) {
         unboundedMqttReader =
             new UnboundedMqttReader<>(
                 this,
-                checkpointMark,
+                preparer,
                 message -> (T) MqttRecord.of(message.getTopic(), 
message.getPayload()));
       } else {
-        unboundedMqttReader = new UnboundedMqttReader<>(this, checkpointMark);
+        unboundedMqttReader = new UnboundedMqttReader<>(this, preparer);
       }
 
       return unboundedMqttReader;
@@ -538,25 +567,26 @@ public class MqttIO {
     private BlockingConnection connection;
     private T current;
     private Instant currentTimestamp;
-    private MqttCheckpointMark checkpointMark;
+    private final MqttCheckpointMark.Preparer checkpointPreparer;
     private SerializableFunction<Message, T> extractFn;
 
-    public UnboundedMqttReader(UnboundedMqttSource<T> source, 
MqttCheckpointMark checkpointMark) {
+    public UnboundedMqttReader(
+        UnboundedMqttSource<T> source, MqttCheckpointMark.Preparer 
checkpointPreparer) {
       this.source = source;
       this.current = null;
-      if (checkpointMark != null) {
-        this.checkpointMark = checkpointMark;
+      if (checkpointPreparer != null) {
+        this.checkpointPreparer = checkpointPreparer;
       } else {
-        this.checkpointMark = new MqttCheckpointMark();
+        this.checkpointPreparer = new MqttCheckpointMark.Preparer();
       }
       this.extractFn = message -> (T) message.getPayload();
     }
 
     public UnboundedMqttReader(
         UnboundedMqttSource<T> source,
-        MqttCheckpointMark checkpointMark,
+        MqttCheckpointMark.Preparer checkpointPreparer,
         SerializableFunction<Message, T> extractFn) {
-      this(source, checkpointMark);
+      this(source, checkpointPreparer);
       this.extractFn = extractFn;
     }
 
@@ -567,7 +597,7 @@ public class MqttIO {
       try {
         client = spec.connectionConfiguration().createClient();
         LOG.debug("Reader client ID is {}", client.getClientId());
-        checkpointMark.clientId = client.getClientId().toString();
+        checkpointPreparer.clientId = client.getClientId().toString();
         connection = createConnection(client);
         connection.subscribe(
             new Topic[] {new Topic(spec.connectionConfiguration().getTopic(), 
QoS.AT_LEAST_ONCE)});
@@ -587,7 +617,7 @@ public class MqttIO {
         }
         current = this.extractFn.apply(message);
         currentTimestamp = Instant.now();
-        checkpointMark.add(message, currentTimestamp);
+        checkpointPreparer.add(message, currentTimestamp);
       } catch (Exception e) {
         throw new IOException(e);
       }
@@ -608,12 +638,12 @@ public class MqttIO {
 
     @Override
     public Instant getWatermark() {
-      return checkpointMark.oldestMessageTimestamp;
+      return checkpointPreparer.oldestMessageTimestamp;
     }
 
     @Override
     public UnboundedSource.CheckpointMark getCheckpointMark() {
-      return checkpointMark;
+      return checkpointPreparer.newCheckpoint();
     }
 
     @Override
diff --git a/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java b/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java
index f0b4fab3953..754c88f0c6a 100644
--- a/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java
+++ b/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java
@@ -27,6 +27,7 @@ import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
 import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.List;
 import java.util.Map;
@@ -50,11 +51,13 @@ import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
 import org.fusesource.hawtbuf.Buffer;
 import org.fusesource.mqtt.client.BlockingConnection;
+import org.fusesource.mqtt.client.Callback;
 import org.fusesource.mqtt.client.MQTT;
 import org.fusesource.mqtt.client.Message;
 import org.fusesource.mqtt.client.QoS;
 import org.fusesource.mqtt.client.Topic;
 import org.joda.time.Duration;
+import org.joda.time.Instant;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Ignore;
@@ -286,6 +289,61 @@ public class MqttIOTest {
     pipeline.run();
   }
 
+  private static class FakeMessage extends Message {
+
+    private int ackCount;
+
+    public FakeMessage() {
+      super(null, null, null, null);
+      this.ackCount = 0;
+    }
+
+    @Override
+    public void ack() {
+      ++ackCount;
+    }
+
+    @Override
+    public void ack(final Callback<Void> unused) {
+      ++ackCount;
+    }
+
+    public int getAckCount() {
+      return ackCount;
+    }
+  }
+
+  @Test
+  public void testReadCheckpoint() {
+    MqttIO.MqttCheckpointMark.Preparer preparer = new 
MqttIO.MqttCheckpointMark.Preparer("id");
+    ArrayList<Message> messages = new ArrayList<>();
+    for (int i = 0; i < 5; ++i) {
+      messages.add(new FakeMessage());
+    }
+    preparer.add(messages.get(0), Instant.ofEpochMilli(20));
+    preparer.add(messages.get(1), Instant.ofEpochMilli(10));
+    preparer.add(messages.get(2), Instant.ofEpochMilli(30));
+    assertEquals(Instant.ofEpochMilli(10), preparer.oldestMessageTimestamp);
+    MqttIO.MqttCheckpointMark checkpointA = preparer.newCheckpoint();
+    preparer.add(messages.get(3), Instant.ofEpochMilli(40));
+    preparer.add(messages.get(4), Instant.ofEpochMilli(50));
+    MqttIO.MqttCheckpointMark checkpointB = preparer.newCheckpoint();
+    assertTrue(
+        Arrays.stream(messages.toArray()).allMatch((m -> ((FakeMessage) 
m).getAckCount() == 0)));
+    checkpointA.finalizeCheckpoint();
+    // only messages in finalized checkpoint acked
+    assertTrue(
+        Arrays.stream(messages.subList(0, 3).toArray())
+            .allMatch((m -> ((FakeMessage) m).getAckCount() == 1)));
+    assertTrue(
+        Arrays.stream(messages.subList(3, 5).toArray())
+            .allMatch((m -> ((FakeMessage) m).getAckCount() == 0)));
+    checkpointB.finalizeCheckpoint();
+    // all messages acked once
+    assertTrue(
+        Arrays.stream(messages.toArray()).allMatch((m -> ((FakeMessage) 
m).getAckCount() == 1)));
+  }
+
   @Test
   public void testWrite() throws Exception {
     final int numberOfTestMessages = 200;
@@ -560,7 +618,6 @@ public class MqttIOTest {
     // the number of messages of the decoded checkpoint should be zero
     assertEquals(0, cp2.messages.size());
     assertEquals(cp1.clientId, cp2.clientId);
-    assertEquals(cp1.oldestMessageTimestamp, cp2.oldestMessageTimestamp);
   }
 
   /**

Reply via email to