This is an automated email from the ASF dual-hosted git repository.

fcsaky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-connector-aws.git


The following commit(s) were added to refs/heads/main by this push:
     new 03e2fad  [FLINK-37627] Fix restart from checkpoint/savepoint at shard 
split can cause data loss
03e2fad is described below

commit 03e2fadf3c874422215e1b09b1e57551377796f4
Author: Arun Lakshman <[email protected]>
AuthorDate: Wed Apr 9 14:37:54 2025 -0700

    [FLINK-37627] Fix restart from checkpoint/savepoint at shard split can 
cause data loss
    
    Co-authored-by: Abhi Gupta <[email protected]>
    
    Closes #198
---
 .../enumerator/KinesisStreamsSourceEnumerator.java |  26 +++--
 .../enumerator/assigner/UniformShardAssigner.java  |   3 +-
 .../source/enumerator/tracker/SplitTracker.java    |   6 ++
 .../source/reader/KinesisStreamsSourceReader.java  |  72 ++++++++++++-
 .../kinesis/source/split/KinesisShardSplit.java    |  32 +++++-
 .../source/split/KinesisShardSplitSerializer.java  | 115 +++++++++++---------
 .../assigner/UniformShardAssignerTest.java         |   4 +-
 .../reader/KinesisStreamsSourceReaderTest.java     | 120 +++++++++++++++++++++
 .../split/KinesisShardSplitSerializerTest.java     |  49 ++++++++-
 .../source/split/KinesisShardSplitTest.java        |  39 +++++++
 .../connector/kinesis/source/util/TestUtil.java    |  13 +++
 11 files changed, 412 insertions(+), 67 deletions(-)

diff --git 
a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/enumerator/KinesisStreamsSourceEnumerator.java
 
b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/enumerator/KinesisStreamsSourceEnumerator.java
index 0a39450..c4302d8 100644
--- 
a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/enumerator/KinesisStreamsSourceEnumerator.java
+++ 
b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/enumerator/KinesisStreamsSourceEnumerator.java
@@ -137,16 +137,26 @@ public class KinesisStreamsSourceEnumerator
         }
     }
 
-    private void handleFinishedSplits(int subtask, SplitsFinishedEvent 
splitsFinishedEvent) {
+    private void handleFinishedSplits(int subtaskId, SplitsFinishedEvent 
splitsFinishedEvent) {
         splitTracker.markAsFinished(splitsFinishedEvent.getFinishedSplitIds());
-        splitAssignment
-                .get(subtask)
-                .removeIf(
-                        split ->
-                                splitsFinishedEvent
-                                        .getFinishedSplitIds()
-                                        .contains(split.splitId()));
+        Set<KinesisShardSplit> splitsAssignment = 
splitAssignment.get(subtaskId);
+        // during recovery, splitAssignment may return null since there might 
be no split assigned
+        // to the subtask, but there might be a SplitsFinishedEvent from that 
subtask.
+        // We will not do child shard assignment if that is the case since 
that might lead to child
+        // shards trying to get assigned before there are any readers.
+        if (splitsAssignment == null) {
+            LOG.info(
+                    "handleFinishedSplits called for subtask: {} which doesn't 
have any "
+                            + "assigned splits right now. This might happen 
due to job restarts. "
+                            + "Child shard discovery might be delayed until we 
have enough readers. "
+                            + "Finished split ids: {}",
+                    subtaskId,
+                    splitsFinishedEvent.getFinishedSplitIds());
+            return;
+        }
 
+        splitsAssignment.removeIf(
+                split -> 
splitsFinishedEvent.getFinishedSplitIds().contains(split.splitId()));
         assignSplits();
     }
 
diff --git 
a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/enumerator/assigner/UniformShardAssigner.java
 
b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/enumerator/assigner/UniformShardAssigner.java
index f5d45df..d2de3ff 100644
--- 
a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/enumerator/assigner/UniformShardAssigner.java
+++ 
b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/enumerator/assigner/UniformShardAssigner.java
@@ -44,7 +44,8 @@ public class UniformShardAssigner implements 
KinesisShardAssigner {
     public int assign(KinesisShardSplit split, Context context) {
         Preconditions.checkArgument(
                 !context.getRegisteredReaders().isEmpty(),
-                "Expected at least one registered reader. Unable to assign 
split.");
+                "Expected at least one registered reader. Unable to assign 
split with id: %s.",
+                split.splitId());
         BigInteger hashKeyStart = new BigInteger(split.getStartingHashKey());
         BigInteger hashKeyEnd = new BigInteger(split.getEndingHashKey());
         BigInteger hashKeyMid = hashKeyStart.add(hashKeyEnd).divide(TWO);
diff --git 
a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/enumerator/tracker/SplitTracker.java
 
b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/enumerator/tracker/SplitTracker.java
index dcd60fb..0d38bd5 100644
--- 
a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/enumerator/tracker/SplitTracker.java
+++ 
b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/enumerator/tracker/SplitTracker.java
@@ -159,6 +159,12 @@ public class SplitTracker {
         return allParentsFinished;
     }
 
+    /**
+     * Checks if split with specified id is finished.
+     *
+     * @param splitId Id of the split to check
+     * @return true if split is finished, otherwise false
+     */
     private boolean isFinished(String splitId) {
         return !knownSplits.containsKey(splitId);
     }
diff --git 
a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/KinesisStreamsSourceReader.java
 
b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/KinesisStreamsSourceReader.java
index 281d7fe..ccadba9 100644
--- 
a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/KinesisStreamsSourceReader.java
+++ 
b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/reader/KinesisStreamsSourceReader.java
@@ -33,9 +33,14 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import software.amazon.awssdk.services.kinesis.model.Record;
 
+import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.NavigableMap;
+import java.util.Set;
+import java.util.TreeMap;
 
 /**
  * Coordinates the reading from assigned splits. Runs on the TaskManager.
@@ -49,6 +54,8 @@ public class KinesisStreamsSourceReader<T>
 
     private static final Logger LOG = 
LoggerFactory.getLogger(KinesisStreamsSourceReader.class);
     private final Map<String, KinesisShardMetrics> shardMetricGroupMap;
+    private final NavigableMap<Long, Set<KinesisShardSplit>> finishedSplits;
+    private long currentCheckpointId;
 
     public KinesisStreamsSourceReader(
             SingleThreadFetcherManager<Record, KinesisShardSplit> 
splitFetcherManager,
@@ -58,15 +65,67 @@ public class KinesisStreamsSourceReader<T>
             Map<String, KinesisShardMetrics> shardMetricGroupMap) {
         super(splitFetcherManager, recordEmitter, config, context);
         this.shardMetricGroupMap = shardMetricGroupMap;
+        this.finishedSplits = new TreeMap<>();
+        this.currentCheckpointId = Long.MIN_VALUE;
     }
 
     @Override
     protected void onSplitFinished(Map<String, KinesisShardSplitState> 
finishedSplitIds) {
+        if (finishedSplitIds.isEmpty()) {
+            return;
+        }
+        finishedSplits.computeIfAbsent(currentCheckpointId, k -> new 
HashSet<>());
+        finishedSplitIds.values().stream()
+                .map(
+                        finishedSplit ->
+                                new KinesisShardSplit(
+                                        finishedSplit.getStreamArn(),
+                                        finishedSplit.getShardId(),
+                                        
finishedSplit.getNextStartingPosition(),
+                                        
finishedSplit.getKinesisShardSplit().getParentShardIds(),
+                                        
finishedSplit.getKinesisShardSplit().getStartingHashKey(),
+                                        
finishedSplit.getKinesisShardSplit().getEndingHashKey(),
+                                        true))
+                .forEach(split -> 
finishedSplits.get(currentCheckpointId).add(split));
+
         context.sendSourceEventToCoordinator(
                 new SplitsFinishedEvent(new 
HashSet<>(finishedSplitIds.keySet())));
         finishedSplitIds.keySet().forEach(this::unregisterShardMetricGroup);
     }
 
+    /**
+     * At snapshot, we also store the pending finished split ids in the 
current checkpoint so that
+     * in case we have to restore the reader from state, we also send the 
finished split ids;
+     * otherwise we run a risk of data loss during restarts of the source 
because of the
+     * SplitsFinishedEvent going missing.
+     *
+     * @param checkpointId the checkpoint id
+     * @return the splits to snapshot, including any pending finished splits
+     */
+    @Override
+    public List<KinesisShardSplit> snapshotState(long checkpointId) {
+        this.currentCheckpointId = checkpointId;
+        List<KinesisShardSplit> splits = new 
ArrayList<>(super.snapshotState(checkpointId));
+
+        if (!finishedSplits.isEmpty()) {
+            // Add all finished splits to the snapshot
+            finishedSplits.values().forEach(splits::addAll);
+        }
+
+        return splits;
+    }
+
+    /**
+     * During notifyCheckpointComplete, we should clean up the state of 
finished splits that are
+     * less than or equal to the checkpoint id.
+     *
+     * @param checkpointId the checkpoint id
+     */
+    @Override
+    public void notifyCheckpointComplete(long checkpointId) {
+        finishedSplits.headMap(checkpointId, true).clear();
+    }
+
     @Override
     protected KinesisShardSplitState initializedState(KinesisShardSplit split) 
{
         return new KinesisShardSplitState(split);
@@ -79,8 +138,17 @@ public class KinesisStreamsSourceReader<T>
 
     @Override
     public void addSplits(List<KinesisShardSplit> splits) {
-        splits.forEach(this::registerShardMetricGroup);
-        super.addSplits(splits);
+        List<KinesisShardSplit> unfinishedSplits = new ArrayList<>();
+        for (KinesisShardSplit split : splits) {
+            if (split.isFinished()) {
+                context.sendSourceEventToCoordinator(
+                        new 
SplitsFinishedEvent(Collections.singleton(split.splitId())));
+            } else {
+                unfinishedSplits.add(split);
+            }
+        }
+        unfinishedSplits.forEach(this::registerShardMetricGroup);
+        super.addSplits(unfinishedSplits);
     }
 
     @Override
diff --git 
a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/split/KinesisShardSplit.java
 
b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/split/KinesisShardSplit.java
index fb5fef5..cefa249 100644
--- 
a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/split/KinesisShardSplit.java
+++ 
b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/split/KinesisShardSplit.java
@@ -44,6 +44,7 @@ public final class KinesisShardSplit implements SourceSplit {
     private final Set<String> parentShardIds;
     private final String startingHashKey;
     private final String endingHashKey;
+    private final boolean finished;
 
     public KinesisShardSplit(
             String streamArn,
@@ -52,6 +53,24 @@ public final class KinesisShardSplit implements SourceSplit {
             Set<String> parentShardIds,
             String startingHashKey,
             String endingHashKey) {
+        this(
+                streamArn,
+                shardId,
+                startingPosition,
+                parentShardIds,
+                startingHashKey,
+                endingHashKey,
+                false);
+    }
+
+    public KinesisShardSplit(
+            String streamArn,
+            String shardId,
+            StartingPosition startingPosition,
+            Set<String> parentShardIds,
+            String startingHashKey,
+            String endingHashKey,
+            boolean finished) {
         checkNotNull(streamArn, "streamArn cannot be null");
         checkNotNull(shardId, "shardId cannot be null");
         checkNotNull(startingPosition, "startingPosition cannot be null");
@@ -65,6 +84,11 @@ public final class KinesisShardSplit implements SourceSplit {
         this.parentShardIds = new HashSet<>(parentShardIds);
         this.startingHashKey = startingHashKey;
         this.endingHashKey = endingHashKey;
+        this.finished = finished;
+    }
+
+    public boolean isFinished() {
+        return finished;
     }
 
     @Override
@@ -116,6 +140,8 @@ public final class KinesisShardSplit implements SourceSplit 
{
                 + ", endingHashKey='"
                 + endingHashKey
                 + '\''
+                + ", finished="
+                + finished
                 + '}';
     }
 
@@ -133,7 +159,8 @@ public final class KinesisShardSplit implements SourceSplit 
{
                 && Objects.equals(startingPosition, that.startingPosition)
                 && Objects.equals(parentShardIds, that.parentShardIds)
                 && Objects.equals(startingHashKey, that.startingHashKey)
-                && Objects.equals(endingHashKey, that.endingHashKey);
+                && Objects.equals(endingHashKey, that.endingHashKey)
+                && Objects.equals(finished, that.finished);
     }
 
     @Override
@@ -144,6 +171,7 @@ public final class KinesisShardSplit implements SourceSplit 
{
                 startingPosition,
                 parentShardIds,
                 startingHashKey,
-                endingHashKey);
+                endingHashKey,
+                finished);
     }
 }
diff --git 
a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/split/KinesisShardSplitSerializer.java
 
b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/split/KinesisShardSplitSerializer.java
index 5433a64..57adfb2 100644
--- 
a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/split/KinesisShardSplitSerializer.java
+++ 
b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/main/java/org/apache/flink/connector/kinesis/source/split/KinesisShardSplitSerializer.java
@@ -22,6 +22,7 @@ import org.apache.flink.annotation.Internal;
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.core.io.SimpleVersionedSerializer;
 import org.apache.flink.core.io.VersionMismatchException;
+import org.apache.flink.util.function.BiConsumerWithException;
 
 import software.amazon.awssdk.services.kinesis.model.ShardIteratorType;
 
@@ -42,8 +43,8 @@ import java.util.Set;
 @Internal
 public class KinesisShardSplitSerializer implements 
SimpleVersionedSerializer<KinesisShardSplit> {
 
-    private static final int CURRENT_VERSION = 1;
-    private static final Set<Integer> COMPATIBLE_VERSIONS = new 
HashSet<>(Arrays.asList(0, 1));
+    private static final int CURRENT_VERSION = 2;
+    private static final Set<Integer> COMPATIBLE_VERSIONS = new 
HashSet<>(Arrays.asList(0, 1, 2));
 
     @Override
     public int getVersion() {
@@ -52,66 +53,67 @@ public class KinesisShardSplitSerializer implements 
SimpleVersionedSerializer<Ki
 
     @Override
     public byte[] serialize(KinesisShardSplit split) throws IOException {
+        return serialize(split, this::serializeV2);
+    }
+
+    @VisibleForTesting
+    byte[] serialize(
+            KinesisShardSplit split,
+            BiConsumerWithException<KinesisShardSplit, DataOutputStream, 
IOException> serializer)
+            throws IOException {
+
         try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
                 DataOutputStream out = new DataOutputStream(baos)) {
 
-            out.writeUTF(split.getStreamArn());
-            out.writeUTF(split.getShardId());
-            
out.writeUTF(split.getStartingPosition().getShardIteratorType().toString());
-            if (split.getStartingPosition().getStartingMarker() == null) {
-                out.writeBoolean(false);
-            } else {
-                out.writeBoolean(true);
-                Object startingMarker = 
split.getStartingPosition().getStartingMarker();
-                out.writeBoolean(startingMarker instanceof Instant);
-                if (startingMarker instanceof Instant) {
-                    out.writeLong(((Instant) startingMarker).toEpochMilli());
-                }
-                out.writeBoolean(startingMarker instanceof String);
-                if (startingMarker instanceof String) {
-                    out.writeUTF((String) startingMarker);
-                }
-            }
-            out.writeInt(split.getParentShardIds().size());
-            for (String parentShardId : split.getParentShardIds()) {
-                out.writeUTF(parentShardId);
-            }
-            out.writeUTF(split.getStartingHashKey());
-            out.writeUTF(split.getEndingHashKey());
-
+            serializer.accept(split, out);
             out.flush();
+
             return baos.toByteArray();
         }
     }
 
-    /** This method used only to test backwards compatibility of 
deserialization logic. */
+    /** This method is used to test backwards compatibility of deserialization 
logic. */
     @VisibleForTesting
-    byte[] serializeV0(KinesisShardSplit split) throws IOException {
-        try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
-                DataOutputStream out = new DataOutputStream(baos)) {
-
-            out.writeUTF(split.getStreamArn());
-            out.writeUTF(split.getShardId());
-            
out.writeUTF(split.getStartingPosition().getShardIteratorType().toString());
-            if (split.getStartingPosition().getStartingMarker() == null) {
-                out.writeBoolean(false);
-            } else {
-                out.writeBoolean(true);
-                Object startingMarker = 
split.getStartingPosition().getStartingMarker();
-                out.writeBoolean(startingMarker instanceof Instant);
-                if (startingMarker instanceof Instant) {
-                    out.writeLong(((Instant) startingMarker).toEpochMilli());
-                }
-                out.writeBoolean(startingMarker instanceof String);
-                if (startingMarker instanceof String) {
-                    out.writeUTF((String) startingMarker);
-                }
+    void serializeV0(KinesisShardSplit split, DataOutputStream out) throws 
IOException {
+        out.writeUTF(split.getStreamArn());
+        out.writeUTF(split.getShardId());
+        
out.writeUTF(split.getStartingPosition().getShardIteratorType().toString());
+        if (split.getStartingPosition().getStartingMarker() == null) {
+            out.writeBoolean(false);
+        } else {
+            out.writeBoolean(true);
+            Object startingMarker = 
split.getStartingPosition().getStartingMarker();
+            out.writeBoolean(startingMarker instanceof Instant);
+            if (startingMarker instanceof Instant) {
+                out.writeLong(((Instant) startingMarker).toEpochMilli());
+            }
+            out.writeBoolean(startingMarker instanceof String);
+            if (startingMarker instanceof String) {
+                out.writeUTF((String) startingMarker);
             }
-            out.flush();
-            return baos.toByteArray();
         }
     }
 
+    /** This method is used to test backwards compatibility of deserialization 
logic. */
+    @VisibleForTesting
+    void serializeV1(KinesisShardSplit split, DataOutputStream out) throws 
IOException {
+        serializeV0(split, out);
+
+        out.writeInt(split.getParentShardIds().size());
+        for (String parentShardId : split.getParentShardIds()) {
+            out.writeUTF(parentShardId);
+        }
+        out.writeUTF(split.getStartingHashKey());
+        out.writeUTF(split.getEndingHashKey());
+    }
+
+    @VisibleForTesting
+    void serializeV2(KinesisShardSplit split, DataOutputStream out) throws 
IOException {
+        serializeV1(split, out);
+
+        out.writeBoolean(split.isFinished());
+    }
+
     @Override
     public KinesisShardSplit deserialize(int version, byte[] serialized) 
throws IOException {
         try (ByteArrayInputStream bais = new ByteArrayInputStream(serialized);
@@ -140,7 +142,8 @@ public class KinesisShardSplitSerializer implements 
SimpleVersionedSerializer<Ki
             }
 
             Set<String> parentShardIds = new HashSet<>();
-            if (version == CURRENT_VERSION) {
+            // parentShardIds was added in V1
+            if (version >= 1) {
                 int parentShardCount = in.readInt();
                 for (int i = 0; i < parentShardCount; i++) {
                     parentShardIds.add(in.readUTF());
@@ -149,7 +152,8 @@ public class KinesisShardSplitSerializer implements 
SimpleVersionedSerializer<Ki
 
             String startingHashKey;
             String endingHashKey;
-            if (version == CURRENT_VERSION) {
+            // startingHashKey and endingHashKey were added in V1
+            if (version >= 1) {
                 startingHashKey = in.readUTF();
                 endingHashKey = in.readUTF();
             } else {
@@ -157,13 +161,20 @@ public class KinesisShardSplitSerializer implements 
SimpleVersionedSerializer<Ki
                 endingHashKey = "0";
             }
 
+            boolean finished = false;
+            // isFinished was added in V2
+            if (version >= 2) {
+                finished = in.readBoolean();
+            }
+
             return new KinesisShardSplit(
                     streamArn,
                     shardId,
                     new StartingPosition(shardIteratorType, startingMarker),
                     parentShardIds,
                     startingHashKey,
-                    endingHashKey);
+                    endingHashKey,
+                    finished);
         }
     }
 }
diff --git 
a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/enumerator/assigner/UniformShardAssignerTest.java
 
b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/enumerator/assigner/UniformShardAssignerTest.java
index 338b722..dcd873d 100644
--- 
a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/enumerator/assigner/UniformShardAssignerTest.java
+++ 
b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/enumerator/assigner/UniformShardAssignerTest.java
@@ -101,7 +101,9 @@ class UniformShardAssignerTest {
         assertThatExceptionOfType(IllegalArgumentException.class)
                 .isThrownBy(() -> assigner.assign(split, assignerContext))
                 .withMessageContaining(
-                        "Expected at least one registered reader. Unable to 
assign split.");
+                        String.format(
+                                "Expected at least one registered reader. 
Unable to assign split with id: %s.",
+                                split.splitId()));
     }
 
     private void createReaderWithAssignedSplits(
diff --git 
a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/KinesisStreamsSourceReaderTest.java
 
b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/KinesisStreamsSourceReaderTest.java
index dfdc584..6118638 100644
--- 
a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/KinesisStreamsSourceReaderTest.java
+++ 
b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/reader/KinesisStreamsSourceReaderTest.java
@@ -36,14 +36,18 @@ import org.apache.flink.metrics.testutils.MetricListener;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 
+import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.Supplier;
+import java.util.stream.Collectors;
 
 import static 
org.apache.flink.connector.kinesis.source.util.KinesisStreamProxyProvider.getTestStreamProxy;
+import static 
org.apache.flink.connector.kinesis.source.util.TestUtil.getFinishedTestSplit;
 import static 
org.apache.flink.connector.kinesis.source.util.TestUtil.getTestSplit;
 import static 
org.apache.flink.connector.kinesis.source.util.TestUtil.getTestSplitState;
 import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
@@ -147,4 +151,120 @@ class KinesisStreamsSourceReaderTest {
         TestUtil.assertMillisBehindLatest(
                 split, TestUtil.MILLIS_BEHIND_LATEST_TEST_VALUE, 
metricListener);
     }
+
+    @Test
+    void testSnapshotStateWithFinishedSplits() throws Exception {
+        // Create and add a split
+        KinesisShardSplit split = getTestSplit();
+        List<KinesisShardSplit> splits = Collections.singletonList(split);
+        sourceReader.addSplits(splits);
+
+        // Set checkpoint ID by taking initial snapshot
+        List<KinesisShardSplit> initialSnapshot = 
sourceReader.snapshotState(1L);
+        assertThat(initialSnapshot).hasSize(1).containsExactly(split);
+
+        // Simulate split finishing
+        Map<String, KinesisShardSplitState> finishedSplits = new HashMap<>();
+        finishedSplits.put(split.splitId(), new KinesisShardSplitState(split));
+        sourceReader.onSplitFinished(finishedSplits);
+
+        // Take another snapshot
+        List<KinesisShardSplit> snapshotSplits = 
sourceReader.snapshotState(2L);
+        List<KinesisShardSplit> snapshotFinishedSplits =
+                snapshotSplits.stream()
+                        .filter(KinesisShardSplit::isFinished)
+                        .collect(Collectors.toList());
+        // Verify we have 2 splits - the original split and the finished split
+        assertThat(snapshotSplits).hasSize(2);
+        assertThat(snapshotFinishedSplits)
+                .hasSize(1)
+                .allSatisfy(
+                        s -> {
+                            assertThat(s.splitId()).isEqualTo(split.splitId());
+                        });
+    }
+
+    @Test
+    void testAddSplitsWithStateRestoration() throws Exception {
+        KinesisShardSplit finishedSplit1 = 
getFinishedTestSplit("finished-split-1");
+        KinesisShardSplit finishedSplit2 = 
getFinishedTestSplit("finished-split-2");
+
+        // Create active split
+        KinesisShardSplit activeSplit = getTestSplit();
+
+        List<KinesisShardSplit> allSplits =
+                Arrays.asList(finishedSplit1, finishedSplit2, activeSplit);
+
+        // Clear any previous events
+        testingReaderContext.clearSentEvents();
+
+        // Add splits
+        sourceReader.addSplits(allSplits);
+
+        // Verify finished events were sent
+        List<SourceEvent> events = testingReaderContext.getSentEvents();
+        assertThat(events)
+                .hasSize(2)
+                .allMatch(e -> e instanceof SplitsFinishedEvent)
+                .satisfiesExactlyInAnyOrder(
+                        e ->
+                                assertThat(((SplitsFinishedEvent) 
e).getFinishedSplitIds())
+                                        .containsExactly("finished-split-1"),
+                        e ->
+                                assertThat(((SplitsFinishedEvent) 
e).getFinishedSplitIds())
+                                        .containsExactly("finished-split-2"));
+
+        // Verify metrics registered only for active split
+        
assertThat(shardMetricGroupMap).hasSize(1).containsKey(activeSplit.splitId());
+    }
+
+    @Test
+    void testNotifyCheckpointCompleteRemovesFinishedSplits() throws Exception {
+        KinesisShardSplit split = getTestSplit();
+        List<KinesisShardSplit> splits = Collections.singletonList(split);
+
+        sourceReader.addSplits(splits);
+
+        // Simulate splits finishing at different checkpoints
+        Map<String, KinesisShardSplitState> finishedSplits1 = new HashMap<>();
+        KinesisShardSplit finishedSplit1 = getFinishedTestSplit("split-1");
+        finishedSplits1.put("split-1", new 
KinesisShardSplitState(finishedSplit1));
+        sourceReader.snapshotState(1L); // Set checkpoint ID
+        sourceReader.onSplitFinished(finishedSplits1);
+
+        Map<String, KinesisShardSplitState> finishedSplits2 = new HashMap<>();
+        KinesisShardSplit finishedSplit2 = getFinishedTestSplit("split-2");
+        finishedSplits2.put("split-2", new 
KinesisShardSplitState(finishedSplit2));
+        sourceReader.snapshotState(2L); // Set checkpoint ID
+        sourceReader.onSplitFinished(finishedSplits2);
+
+        // Take snapshot to verify initial state
+        List<KinesisShardSplit> snapshotSplits = 
sourceReader.snapshotState(3L);
+
+        assertThat(snapshotSplits).hasSize(3);
+        assertThat(
+                        snapshotSplits.stream()
+                                .filter(KinesisShardSplit::isFinished)
+                                .map(KinesisShardSplit::splitId)
+                                .collect(Collectors.toList()))
+                .hasSize(2)
+                .containsExactlyInAnyOrder("split-1", "split-2");
+
+        // Complete checkpoint 1
+        sourceReader.notifyCheckpointComplete(1L);
+
+        // Take another snapshot to verify state after completion
+        snapshotSplits = sourceReader.snapshotState(4L);
+
+        // Verify checkpoint 1 splits were removed
+        assertThat(snapshotSplits).hasSize(2);
+        assertThat(
+                        snapshotSplits.stream()
+                                .filter(KinesisShardSplit::isFinished)
+                                .map(KinesisShardSplit::splitId)
+                                .collect(Collectors.toList()))
+                .hasSize(1)
+                .first()
+                .isEqualTo("split-2");
+    }
 }
diff --git 
a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/split/KinesisShardSplitSerializerTest.java
 
b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/split/KinesisShardSplitSerializerTest.java
index 9b0662b..1aa1014 100644
--- 
a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/split/KinesisShardSplitSerializerTest.java
+++ 
b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/split/KinesisShardSplitSerializerTest.java
@@ -81,7 +81,7 @@ class KinesisShardSplitSerializerTest {
                         STARTING_HASH_KEY_TEST_VALUE,
                         ENDING_HASH_KEY_TEST_VALUE);
 
-        byte[] oldSerializedState = serializer.serializeV0(initialSplit);
+        byte[] oldSerializedState = serializer.serialize(initialSplit, 
serializer::serializeV0);
         KinesisShardSplit deserializedSplit = serializer.deserialize(0, 
oldSerializedState);
 
         assertThat(deserializedSplit)
@@ -117,6 +117,53 @@ class KinesisShardSplitSerializerTest {
                 
.withMessageContaining(String.valueOf(wrongVersionSerializer.getVersion()));
     }
 
+    @Test
+    void testSerializeAndDeserializeWithFinishedSplits() throws Exception {
+        final KinesisShardSplit initialSplit =
+                new KinesisShardSplit(
+                        STREAM_ARN,
+                        generateShardId(10),
+                        
StartingPosition.continueFromSequenceNumber("some-sequence-number"),
+                        new HashSet<>(Arrays.asList(generateShardId(2), 
generateShardId(5))),
+                        STARTING_HASH_KEY_TEST_VALUE,
+                        ENDING_HASH_KEY_TEST_VALUE,
+                        true);
+
+        KinesisShardSplitSerializer serializer = new 
KinesisShardSplitSerializer();
+
+        byte[] serialized = serializer.serialize(initialSplit);
+        KinesisShardSplit deserializedSplit =
+                serializer.deserialize(serializer.getVersion(), serialized);
+
+        
assertThat(deserializedSplit).usingRecursiveComparison().isEqualTo(initialSplit);
+        assertThat(deserializedSplit.isFinished()).isTrue();
+    }
+
+    @Test
+    void testDeserializeVersion1() throws Exception {
+        final KinesisShardSplitSerializer serializer = new 
KinesisShardSplitSerializer();
+
+        final KinesisShardSplit initialSplit =
+                new KinesisShardSplit(
+                        STREAM_ARN,
+                        generateShardId(10),
+                        
StartingPosition.continueFromSequenceNumber("some-sequence-number"),
+                        new HashSet<>(Arrays.asList(generateShardId(2), 
generateShardId(5))),
+                        STARTING_HASH_KEY_TEST_VALUE,
+                        ENDING_HASH_KEY_TEST_VALUE);
+
+        byte[] oldSerializedState = serializer.serialize(initialSplit, 
serializer::serializeV1);
+        KinesisShardSplit deserializedSplit = serializer.deserialize(1, 
oldSerializedState);
+
+        assertThat(deserializedSplit)
+                .usingRecursiveComparison(
+                        RecursiveComparisonConfiguration.builder()
+                                .withIgnoredFields("finished")
+                                .build())
+                .isEqualTo(initialSplit);
+        assertThat(deserializedSplit.isFinished()).isEqualTo(false);
+    }
+
     private static class WrongVersionSerializer extends 
KinesisShardSplitSerializer {
         @Override
         public int getVersion() {
diff --git 
a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/split/KinesisShardSplitTest.java
 
b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/split/KinesisShardSplitTest.java
index d00ea18..16dd517 100644
--- 
a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/split/KinesisShardSplitTest.java
+++ 
b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/split/KinesisShardSplitTest.java
@@ -22,10 +22,14 @@ import nl.jqno.equalsverifier.EqualsVerifier;
 import org.junit.jupiter.api.Test;
 
 import java.util.Collections;
+import java.util.HashSet;
+import java.util.NavigableMap;
 import java.util.Set;
+import java.util.TreeMap;
 
 import static 
org.apache.flink.connector.kinesis.source.util.TestUtil.ENDING_HASH_KEY_TEST_VALUE;
 import static 
org.apache.flink.connector.kinesis.source.util.TestUtil.STARTING_HASH_KEY_TEST_VALUE;
+import static org.assertj.core.api.Assertions.assertThat;
 import static 
org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType;
 
 class KinesisShardSplitTest {
@@ -130,4 +134,39 @@ class KinesisShardSplitTest {
     void testEquals() {
         EqualsVerifier.forClass(KinesisShardSplit.class).verify();
     }
+
+    @Test
+    void testFinishedSplitsMapConstructor() {
+        NavigableMap<Long, Set<String>> finishedSplitsMap = new TreeMap<>();
+        Set<String> splits = new HashSet<>();
+        splits.add("split1");
+        splits.add("split2");
+        finishedSplitsMap.put(1L, splits);
+
+        KinesisShardSplit split =
+                new KinesisShardSplit(
+                        STREAM_ARN,
+                        SHARD_ID,
+                        STARTING_POSITION,
+                        PARENT_SHARD_IDS,
+                        STARTING_HASH_KEY_TEST_VALUE,
+                        ENDING_HASH_KEY_TEST_VALUE,
+                        true);
+
+        assertThat(split.isFinished()).isTrue();
+    }
+
+    @Test
+    void testFinishedSplitsMapDefaultEmpty() {
+        KinesisShardSplit split =
+                new KinesisShardSplit(
+                        STREAM_ARN,
+                        SHARD_ID,
+                        STARTING_POSITION,
+                        PARENT_SHARD_IDS,
+                        STARTING_HASH_KEY_TEST_VALUE,
+                        ENDING_HASH_KEY_TEST_VALUE);
+
+        assertThat(split.isFinished()).isFalse();
+    }
 }
diff --git 
a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/util/TestUtil.java
 
b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/util/TestUtil.java
index 035d449..bdf7061 100644
--- 
a/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/util/TestUtil.java
+++ 
b/flink-connector-aws/flink-connector-aws-kinesis-streams/src/test/java/org/apache/flink/connector/kinesis/source/util/TestUtil.java
@@ -35,7 +35,9 @@ import software.amazon.awssdk.services.kinesis.model.Shard;
 
 import java.math.BigInteger;
 import java.time.Instant;
+import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashSet;
 import java.util.Optional;
 import java.util.Set;
 import java.util.stream.IntStream;
@@ -139,6 +141,17 @@ public class TestUtil {
                 ENDING_HASH_KEY_TEST_VALUE);
     }
 
+    public static KinesisShardSplit getFinishedTestSplit(String shardId) {
+        return new KinesisShardSplit(
+                STREAM_ARN,
+                shardId,
+                StartingPosition.fromStart(),
+                new HashSet<>(Arrays.asList(generateShardId(2), 
generateShardId(5))),
+                STARTING_HASH_KEY_TEST_VALUE,
+                ENDING_HASH_KEY_TEST_VALUE,
+                true);
+    }
+
     public static KinesisShardSplit getTestSplit(
             BigInteger startingHashKey, BigInteger endingHashKey) {
         return new KinesisShardSplit(


Reply via email to