This is an automated email from the ASF dual-hosted git repository.
yhu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 56dac7574cc Fix MqttIO read checkpoint logic (#36056)
56dac7574cc is described below
commit 56dac7574cca54ed6e183b430dbb4687e140828e
Author: Yi Hu <[email protected]>
AuthorDate: Mon Sep 8 11:08:25 2025 -0400
Fix MqttIO read checkpoint logic (#36056)
* Fix MqttIO read checkpoint logic
* add tests
---
.../java/org/apache/beam/sdk/io/mqtt/MqttIO.java | 84 +++++++++++++++-------
.../org/apache/beam/sdk/io/mqtt/MqttIOTest.java | 59 ++++++++++++++-
2 files changed, 115 insertions(+), 28 deletions(-)
diff --git a/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java b/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java
index efc51362d06..78876eb6534 100644
--- a/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java
+++ b/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java
@@ -422,20 +422,16 @@ public class MqttIO {
static class MqttCheckpointMark implements UnboundedSource.CheckpointMark, Serializable {
@VisibleForTesting String clientId;
- @VisibleForTesting Instant oldestMessageTimestamp = Instant.now();
@VisibleForTesting transient List<Message> messages = new ArrayList<>();
- public MqttCheckpointMark() {}
-
- public MqttCheckpointMark(String id) {
- clientId = id;
+ public MqttCheckpointMark(String id, List<Message> messages) {
+ this.clientId = id;
+ this.messages = messages;
}
- public void add(Message message, Instant timestamp) {
- if (timestamp.isBefore(oldestMessageTimestamp)) {
- oldestMessageTimestamp = timestamp;
- }
- messages.add(message);
+ @VisibleForTesting
+ MqttCheckpointMark(String id) {
+ this.clientId = id;
}
@Override
@@ -448,7 +444,6 @@ public class MqttIO {
LOG.warn("Can't ack message for client ID {}", clientId, e);
}
}
- oldestMessageTimestamp = Instant.now();
messages.clear();
}
@@ -464,7 +459,6 @@ public class MqttIO {
if (other instanceof MqttCheckpointMark) {
MqttCheckpointMark that = (MqttCheckpointMark) other;
return Objects.equals(this.clientId, that.clientId)
- && Objects.equals(this.oldestMessageTimestamp, that.oldestMessageTimestamp)
&& Objects.deepEquals(this.messages, that.messages);
} else {
return false;
@@ -473,7 +467,38 @@ public class MqttIO {
@Override
public int hashCode() {
- return Objects.hash(clientId, oldestMessageTimestamp, messages);
+ return Objects.hash(clientId, messages);
+ }
+
+ static class Preparer {
+ @VisibleForTesting String clientId;
+ @VisibleForTesting Instant oldestMessageTimestamp = Instant.now();
+ @VisibleForTesting transient List<Message> messages = new ArrayList<>();
+
+ public Preparer(MqttCheckpointMark checkpointMark) {
+ clientId = checkpointMark.clientId;
+ messages = checkpointMark.messages;
+ }
+
+ public Preparer(String id) {
+ clientId = id;
+ }
+
+ public Preparer() {}
+
+ public void add(Message message, Instant timestamp) {
+ if (timestamp.isBefore(oldestMessageTimestamp)) {
+ oldestMessageTimestamp = timestamp;
+ }
+ messages.add(message);
+ }
+
+ MqttCheckpointMark newCheckpoint() {
+ List<Message> currentMessages = messages;
+ messages = new ArrayList<>();
+ oldestMessageTimestamp = Instant.now();
+ return new MqttCheckpointMark(clientId, currentMessages);
+ }
}
}
@@ -489,16 +514,20 @@ public class MqttIO {
@Override
@SuppressWarnings("unchecked")
public UnboundedReader<T> createReader(
- PipelineOptions options, MqttCheckpointMark checkpointMark) {
+ PipelineOptions options, @Nullable MqttCheckpointMark checkpointMark) {
final UnboundedMqttReader<T> unboundedMqttReader;
+ MqttCheckpointMark.Preparer preparer =
+ checkpointMark == null
+ ? new MqttCheckpointMark.Preparer()
+ : new MqttCheckpointMark.Preparer(checkpointMark);
if (spec.withMetadata()) {
unboundedMqttReader =
new UnboundedMqttReader<>(
this,
- checkpointMark,
+ preparer,
message -> (T) MqttRecord.of(message.getTopic(), message.getPayload()));
} else {
- unboundedMqttReader = new UnboundedMqttReader<>(this, checkpointMark);
+ unboundedMqttReader = new UnboundedMqttReader<>(this, preparer);
}
return unboundedMqttReader;
@@ -538,25 +567,26 @@ public class MqttIO {
private BlockingConnection connection;
private T current;
private Instant currentTimestamp;
- private MqttCheckpointMark checkpointMark;
+ private final MqttCheckpointMark.Preparer checkpointPreparer;
private SerializableFunction<Message, T> extractFn;
- public UnboundedMqttReader(UnboundedMqttSource<T> source, MqttCheckpointMark checkpointMark) {
+ public UnboundedMqttReader(
+ UnboundedMqttSource<T> source, MqttCheckpointMark.Preparer checkpointPreparer) {
this.source = source;
this.current = null;
- if (checkpointMark != null) {
- this.checkpointMark = checkpointMark;
+ if (checkpointPreparer != null) {
+ this.checkpointPreparer = checkpointPreparer;
} else {
- this.checkpointMark = new MqttCheckpointMark();
+ this.checkpointPreparer = new MqttCheckpointMark.Preparer();
}
this.extractFn = message -> (T) message.getPayload();
}
public UnboundedMqttReader(
UnboundedMqttSource<T> source,
- MqttCheckpointMark checkpointMark,
+ MqttCheckpointMark.Preparer checkpointPreparer,
SerializableFunction<Message, T> extractFn) {
- this(source, checkpointMark);
+ this(source, checkpointPreparer);
this.extractFn = extractFn;
}
@@ -567,7 +597,7 @@ public class MqttIO {
try {
client = spec.connectionConfiguration().createClient();
LOG.debug("Reader client ID is {}", client.getClientId());
- checkpointMark.clientId = client.getClientId().toString();
+ checkpointPreparer.clientId = client.getClientId().toString();
connection = createConnection(client);
connection.subscribe(
new Topic[] {new Topic(spec.connectionConfiguration().getTopic(), QoS.AT_LEAST_ONCE)});
@@ -587,7 +617,7 @@ public class MqttIO {
}
current = this.extractFn.apply(message);
currentTimestamp = Instant.now();
- checkpointMark.add(message, currentTimestamp);
+ checkpointPreparer.add(message, currentTimestamp);
} catch (Exception e) {
throw new IOException(e);
}
@@ -608,12 +638,12 @@ public class MqttIO {
@Override
public Instant getWatermark() {
- return checkpointMark.oldestMessageTimestamp;
+ return checkpointPreparer.oldestMessageTimestamp;
}
@Override
public UnboundedSource.CheckpointMark getCheckpointMark() {
- return checkpointMark;
+ return checkpointPreparer.newCheckpoint();
}
@Override
diff --git a/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java b/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java
index f0b4fab3953..754c88f0c6a 100644
--- a/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java
+++ b/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java
@@ -27,6 +27,7 @@ import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
@@ -50,11 +51,13 @@ import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.fusesource.hawtbuf.Buffer;
import org.fusesource.mqtt.client.BlockingConnection;
+import org.fusesource.mqtt.client.Callback;
import org.fusesource.mqtt.client.MQTT;
import org.fusesource.mqtt.client.Message;
import org.fusesource.mqtt.client.QoS;
import org.fusesource.mqtt.client.Topic;
import org.joda.time.Duration;
+import org.joda.time.Instant;
import org.junit.After;
import org.junit.Before;
import org.junit.Ignore;
@@ -286,6 +289,61 @@ public class MqttIOTest {
pipeline.run();
}
+ private static class FakeMessage extends Message {
+
+ private int ackCount;
+
+ public FakeMessage() {
+ super(null, null, null, null);
+ this.ackCount = 0;
+ }
+
+ @Override
+ public void ack() {
+ ++ackCount;
+ }
+
+ @Override
+ public void ack(final Callback<Void> unused) {
+ ++ackCount;
+ }
+
+ public int getAckCount() {
+ return ackCount;
+ }
+ }
+
+ @Test
+ public void testReadCheckpoint() {
+ MqttIO.MqttCheckpointMark.Preparer preparer = new MqttIO.MqttCheckpointMark.Preparer("id");
+ ArrayList<Message> messages = new ArrayList<>();
+ for (int i = 0; i < 5; ++i) {
+ messages.add(new FakeMessage());
+ }
+ preparer.add(messages.get(0), Instant.ofEpochMilli(20));
+ preparer.add(messages.get(1), Instant.ofEpochMilli(10));
+ preparer.add(messages.get(2), Instant.ofEpochMilli(30));
+ assertEquals(Instant.ofEpochMilli(10), preparer.oldestMessageTimestamp);
+ MqttIO.MqttCheckpointMark checkpointA = preparer.newCheckpoint();
+ preparer.add(messages.get(3), Instant.ofEpochMilli(40));
+ preparer.add(messages.get(4), Instant.ofEpochMilli(50));
+ MqttIO.MqttCheckpointMark checkpointB = preparer.newCheckpoint();
+ assertTrue(
+ Arrays.stream(messages.toArray()).allMatch((m -> ((FakeMessage) m).getAckCount() == 0)));
+ checkpointA.finalizeCheckpoint();
+ // only messages in finalized checkpoint acked
+ assertTrue(
+ Arrays.stream(messages.subList(0, 3).toArray())
+ .allMatch((m -> ((FakeMessage) m).getAckCount() == 1)));
+ assertTrue(
+ Arrays.stream(messages.subList(3, 5).toArray())
+ .allMatch((m -> ((FakeMessage) m).getAckCount() == 0)));
+ checkpointB.finalizeCheckpoint();
+ // all messaged acked once
+ assertTrue(
+ Arrays.stream(messages.toArray()).allMatch((m -> ((FakeMessage) m).getAckCount() == 1)));
+ }
+
@Test
public void testWrite() throws Exception {
final int numberOfTestMessages = 200;
@@ -560,7 +618,6 @@ public class MqttIOTest {
// the number of messages of the decoded checkpoint should be zero
assertEquals(0, cp2.messages.size());
assertEquals(cp1.clientId, cp2.clientId);
- assertEquals(cp1.oldestMessageTimestamp, cp2.oldestMessageTimestamp);
}
/**