This is an automated email from the ASF dual-hosted git repository.
shenghang pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/seatunnel.git
The following commit(s) were added to refs/heads/dev by this push:
new 75bc71beb8 [Fix][Connector-V2][Hbase] Avoid duplicate split assignment
on restore (#10310)
75bc71beb8 is described below
commit 75bc71beb8be0a7135c5c781c4fe9402428de681
Author: yzeng1618 <[email protected]>
AuthorDate: Sun Jan 11 20:48:14 2026 +0800
[Fix][Connector-V2][Hbase] Avoid duplicate split assignment on restore
(#10310)
Co-authored-by: zengyi <[email protected]>
---
.../hbase/source/HbaseSourceSplitEnumerator.java | 52 +++++++++++--
.../source/HbaseSourceSplitEnumeratorTest.java | 86 ++++++++++++++++++++++
2 files changed, 133 insertions(+), 5 deletions(-)
diff --git a/seatunnel-connectors-v2/connector-hbase/src/main/java/org/apache/seatunnel/connectors/seatunnel/hbase/source/HbaseSourceSplitEnumerator.java b/seatunnel-connectors-v2/connector-hbase/src/main/java/org/apache/seatunnel/connectors/seatunnel/hbase/source/HbaseSourceSplitEnumerator.java
index 73b1d6862a..54306ef6ec 100644
--- a/seatunnel-connectors-v2/connector-hbase/src/main/java/org/apache/seatunnel/connectors/seatunnel/hbase/source/HbaseSourceSplitEnumerator.java
+++ b/seatunnel-connectors-v2/connector-hbase/src/main/java/org/apache/seatunnel/connectors/seatunnel/hbase/source/HbaseSourceSplitEnumerator.java
@@ -50,6 +50,9 @@ public class HbaseSourceSplitEnumerator
/** The splits that have not assigned */
private Set<HbaseSourceSplit> pendingSplit;
+ /** Whether the pending splits have been initialized */
+ private boolean initialized = false;
+
private HbaseParameters hbaseParameters;
private HbaseClient hbaseClient;
@@ -71,23 +74,40 @@ public class HbaseSourceSplitEnumerator
Context<HbaseSourceSplit> context,
HbaseParameters hbaseParameters,
HbaseClient hbaseClient) {
- this(context, hbaseParameters, new HashSet<>());
- this.hbaseClient = hbaseClient;
+ this(context, hbaseParameters, new HashSet<>(), hbaseClient);
+ }
+
+ @VisibleForTesting
+ public HbaseSourceSplitEnumerator(
+ Context<HbaseSourceSplit> context,
+ HbaseParameters hbaseParameters,
+ HbaseSourceState sourceState,
+ HbaseClient hbaseClient) {
+ this(context, hbaseParameters, sourceState.getAssignedSplits(), hbaseClient);
}
private HbaseSourceSplitEnumerator(
Context<HbaseSourceSplit> context,
HbaseParameters hbaseParameters,
Set<HbaseSourceSplit> assignedSplit) {
+ this(context, hbaseParameters, assignedSplit, HbaseClient.createInstance(hbaseParameters));
+ }
+
+ private HbaseSourceSplitEnumerator(
+ Context<HbaseSourceSplit> context,
+ HbaseParameters hbaseParameters,
+ Set<HbaseSourceSplit> assignedSplit,
+ HbaseClient hbaseClient) {
this.context = context;
this.hbaseParameters = hbaseParameters;
this.assignedSplit = assignedSplit;
- this.hbaseClient = HbaseClient.createInstance(hbaseParameters);
+ this.hbaseClient = hbaseClient;
}
@Override
public void open() {
this.pendingSplit = new HashSet<>();
+ this.initialized = false;
}
@Override
@@ -110,7 +130,9 @@ public class HbaseSourceSplitEnumerator
public void addSplitsBack(List<HbaseSourceSplit> splits, int subtaskId) {
if (!splits.isEmpty()) {
pendingSplit.addAll(splits);
- assignSplit(subtaskId);
+ if (context.registeredReaders().contains(subtaskId)) {
+ assignSplit(subtaskId);
+ }
}
}
@@ -121,10 +143,30 @@ public class HbaseSourceSplitEnumerator
@Override
public void registerReader(int subtaskId) {
- pendingSplit = getTableSplits();
+ initializePendingSplits();
assignSplit(subtaskId);
}
+ private void initializePendingSplits() {
+ if (initialized) {
+ return;
+ }
+ Set<HbaseSourceSplit> tableSplits = getTableSplits();
+ Set<String> existedSplitIds =
+ pendingSplit.stream().map(HbaseSourceSplit::splitId).collect(Collectors.toSet());
+ if (!assignedSplit.isEmpty()) {
+ existedSplitIds.addAll(
+ assignedSplit.stream()
+ .map(HbaseSourceSplit::splitId)
+ .collect(Collectors.toSet()));
+ }
+ pendingSplit.addAll(
+ tableSplits.stream()
+ .filter(split -> !existedSplitIds.contains(split.splitId()))
+ .collect(Collectors.toSet()));
+ initialized = true;
+ }
+
@Override
public HbaseSourceState snapshotState(long checkpointId) throws Exception {
return new HbaseSourceState(assignedSplit);
diff --git a/seatunnel-connectors-v2/connector-hbase/src/test/java/org/apache/seatunnel/connectors/seatunnel/hbase/source/HbaseSourceSplitEnumeratorTest.java b/seatunnel-connectors-v2/connector-hbase/src/test/java/org/apache/seatunnel/connectors/seatunnel/hbase/source/HbaseSourceSplitEnumeratorTest.java
index fd5eb0cceb..0fffeec0cc 100644
--- a/seatunnel-connectors-v2/connector-hbase/src/test/java/org/apache/seatunnel/connectors/seatunnel/hbase/source/HbaseSourceSplitEnumeratorTest.java
+++ b/seatunnel-connectors-v2/connector-hbase/src/test/java/org/apache/seatunnel/connectors/seatunnel/hbase/source/HbaseSourceSplitEnumeratorTest.java
@@ -27,16 +27,26 @@ import org.apache.hadoop.hbase.util.Bytes;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
import java.util.Set;
+import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class HbaseSourceSplitEnumeratorTest {
@@ -376,4 +386,80 @@ public class HbaseSourceSplitEnumeratorTest {
assertArrayEquals(Bytes.toBytes("row100"), split.getStartRow());
assertArrayEquals(Bytes.toBytes("row200"), split.getEndRow());
}
+
+ @Test
+ void testRestoreOnlyAssignReturnedSplits() throws Exception {
+ when(context.currentParallelism()).thenReturn(1);
+ when(context.registeredReaders()).thenReturn(Collections.emptySet());
+
+ byte[][] startKeys = {
+ HConstants.EMPTY_BYTE_ARRAY, Bytes.toBytes("row100"), Bytes.toBytes("row200")
+ };
+ byte[][] endKeys = {
+ Bytes.toBytes("row100"), Bytes.toBytes("row200"), HConstants.EMPTY_BYTE_ARRAY
+ };
+ when(regionLocator.getStartKeys()).thenReturn(startKeys);
+ when(regionLocator.getEndKeys()).thenReturn(endKeys);
+
+ Set<HbaseSourceSplit> assignedSplits = new HashSet<>();
+ assignedSplits.add(new HbaseSourceSplit(0, startKeys[0], endKeys[0]));
+ assignedSplits.add(new HbaseSourceSplit(1, startKeys[1], endKeys[1]));
+ assignedSplits.add(new HbaseSourceSplit(2, startKeys[2], endKeys[2]));
+
+ HbaseSourceSplitEnumerator restoredEnumerator =
+ new HbaseSourceSplitEnumerator(
+ context,
+ hbaseParameters,
+ new HbaseSourceState(assignedSplits),
+ hbaseClient);
+
+ restoredEnumerator.open();
+
+ List<HbaseSourceSplit> returnedSplits =
+ Arrays.asList(
+ new HbaseSourceSplit(1, startKeys[1], endKeys[1]),
+ new HbaseSourceSplit(2, startKeys[2], endKeys[2]));
+ restoredEnumerator.addSplitsBack(returnedSplits, 0);
+
+ ArgumentCaptor<List<HbaseSourceSplit>> assignedCaptor = ArgumentCaptor.forClass(List.class);
+ restoredEnumerator.registerReader(0);
+
+ verify(context, times(1)).assignSplit(eq(0), assignedCaptor.capture());
+ Set<String> assignedSplitIds =
+ assignedCaptor.getValue().stream()
+ .map(HbaseSourceSplit::splitId)
+ .collect(Collectors.toSet());
+ assertEquals(2, assignedSplitIds.size());
+ assertTrue(assignedSplitIds.contains("hbase_source_split_1"));
+ assertTrue(assignedSplitIds.contains("hbase_source_split_2"));
+ assertFalse(assignedSplitIds.contains("hbase_source_split_0"));
+ }
+
+ @Test
+ void testRegisterReaderInitializePendingSplitOnlyOnceWhenParallelismMoreThanOne()
+ throws Exception {
+ when(context.currentParallelism()).thenReturn(2);
+
+ byte[][] startKeys = {
+ HConstants.EMPTY_BYTE_ARRAY,
+ Bytes.toBytes("row100"),
+ Bytes.toBytes("row200"),
+ Bytes.toBytes("row300")
+ };
+ byte[][] endKeys = {
+ Bytes.toBytes("row100"),
+ Bytes.toBytes("row200"),
+ Bytes.toBytes("row300"),
+ HConstants.EMPTY_BYTE_ARRAY
+ };
+ when(regionLocator.getStartKeys()).thenReturn(startKeys);
+ when(regionLocator.getEndKeys()).thenReturn(endKeys);
+
+ enumerator.open();
+ enumerator.registerReader(0);
+ enumerator.registerReader(1);
+
+ verify(hbaseClient, times(1)).getRegionLocator("test_table");
+ assertEquals(0, enumerator.currentUnassignedSplitSize());
+ }
}