This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new bda3eb031b [python] with_shard should be evenly distributed for data evolution mode (#7271)
bda3eb031b is described below
commit bda3eb031b829f2d4dce377fa64e3a4b09606c44
Author: Jingsong Lee <[email protected]>
AuthorDate: Thu Feb 12 10:50:43 2026 +0800
[python] with_shard should be evenly distributed for data evolution mode (#7271)
---
.../read/scanner/data_evolution_split_generator.py | 281 ++++++---------------
.../pypaimon/read/scanner/split_generator.py | 12 +-
paimon-python/pypaimon/read/split.py | 54 +++-
paimon-python/pypaimon/tests/blob_table_test.py | 22 +-
.../pypaimon/tests/data_evolution_test.py | 65 +++++
paimon-python/pypaimon/write/table_update.py | 61 ++++-
6 files changed, 256 insertions(+), 239 deletions(-)
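For context, with_shard(i, n) asks a scan to return only subtask i's portion of the table out of n parallel subtasks. Before this change, data evolution tables handed whole merged row-ID ranges to subtasks, so shards could carry very different row counts; the patch below derives per-subtask row boundaries instead. A minimal usage sketch (the table object and worker count n are assumptions for illustration, not part of the patch):

    # Each worker i of n reads only its shard; after this change the shards of a
    # data evolution table are balanced by row count rather than by range count.
    read_builder = table.new_read_builder()
    pieces = []
    for i in range(n):
        splits = read_builder.new_scan().with_shard(i, n).plan().splits()
        pieces.append(read_builder.new_read().to_arrow(splits))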
diff --git a/paimon-python/pypaimon/read/scanner/data_evolution_split_generator.py b/paimon-python/pypaimon/read/scanner/data_evolution_split_generator.py
index 241966134f..4ac154fed0 100644
--- a/paimon-python/pypaimon/read/scanner/data_evolution_split_generator.py
+++ b/paimon-python/pypaimon/read/scanner/data_evolution_split_generator.py
@@ -16,7 +16,7 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
from collections import defaultdict
-from typing import List, Optional, Dict, Tuple
+from typing import List, Optional, Tuple
from pypaimon.globalindex.indexed_split import IndexedSplit
from pypaimon.globalindex.range import Range
@@ -66,15 +66,11 @@ class DataEvolutionSplitGenerator(AbstractSplitGenerator):
slice_row_ranges = None # Row ID ranges for slice-based filtering
- if self.start_pos_of_this_subtask is not None:
+ if self.start_pos_of_this_subtask is not None or self.idx_of_this_subtask is not None:
# Calculate Row ID range for slice-based filtering
slice_row_ranges = self._calculate_slice_row_ranges(partitioned_files)
# Filter files by Row ID range
partitioned_files = self._filter_files_by_row_ranges(partitioned_files, slice_row_ranges)
- elif self.idx_of_this_subtask is not None:
- partitioned_files = self._filter_by_shard(
- partitioned_files, self.idx_of_this_subtask, self.number_of_para_subtasks
- )
def weight_func(file_list: List[DataFileMeta]) -> int:
return max(sum(f.file_size for f in file_list), self.open_file_cost)
@@ -133,21 +129,14 @@ class DataEvolutionSplitGenerator(AbstractSplitGenerator):
pack = packed_files[i] if i < len(packed_files) else []
raw_convertible = all(len(sub_pack) == 1 for sub_pack in pack)
- file_paths = []
- total_file_size = 0
- total_record_count = 0
-
for data_file in file_group:
data_file.set_file_path(
self.table.table_path,
file_entries[0].partition,
file_entries[0].bucket
)
- file_paths.append(data_file.file_path)
- total_file_size += data_file.file_size
- total_record_count += data_file.row_count
- if file_paths:
+ if file_group:
# Get deletion files for this split
data_deletion_files = None
if self.deletion_files_map:
@@ -161,9 +150,6 @@ class DataEvolutionSplitGenerator(AbstractSplitGenerator):
files=file_group,
partition=file_entries[0].partition,
bucket=file_entries[0].bucket,
- file_paths=file_paths,
- row_count=total_record_count,
- file_size=total_file_size,
raw_convertible=raw_convertible,
data_deletion_files=data_deletion_files
)
@@ -196,13 +182,31 @@ class DataEvolutionSplitGenerator(AbstractSplitGenerator):
def _divide_ranges_by_position(self, sorted_ranges: List[Range]) -> Tuple[Optional[Range], Optional[Range]]:
"""
Divide ranges by position (start_pos, end_pos) to get the Row ID range for this slice.
+ If idx_of_this_subtask exists, divide total rows by number_of_para_subtasks.
"""
if not sorted_ranges:
return None, None
total_row_count = sum(r.count() for r in sorted_ranges)
- start_pos = self.start_pos_of_this_subtask
- end_pos = self.end_pos_of_this_subtask
+
+ # If idx_of_this_subtask exists, calculate start_pos and end_pos based on number_of_para_subtasks
+ if self.idx_of_this_subtask is not None:
+ # Calculate shard boundaries based on total row count
+ rows_per_task = total_row_count // self.number_of_para_subtasks
+ remainder = total_row_count % self.number_of_para_subtasks
+
+ start_pos = self.idx_of_this_subtask * rows_per_task
+ # Distribute remainder rows across first 'remainder' tasks
+ if self.idx_of_this_subtask < remainder:
+ start_pos += self.idx_of_this_subtask
+ end_pos = start_pos + rows_per_task + 1
+ else:
+ start_pos += remainder
+ end_pos = start_pos + rows_per_task
+ else:
+ # Use existing start_pos and end_pos
+ start_pos = self.start_pos_of_this_subtask
+ end_pos = self.end_pos_of_this_subtask
if start_pos >= total_row_count:
return None, None
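The boundary arithmetic added above can be checked in isolation. A minimal standalone sketch (the function name is illustrative, and end_pos is treated as exclusive, matching how the slice range appears to be consumed downstream):

    def shard_row_bounds(total_rows: int, num_tasks: int, task_idx: int):
        # Mirror of the hunk above: every task gets total_rows // num_tasks rows
        # and the first (total_rows % num_tasks) tasks receive one extra row.
        rows_per_task = total_rows // num_tasks
        remainder = total_rows % num_tasks
        start = task_idx * rows_per_task
        if task_idx < remainder:
            start += task_idx
            end = start + rows_per_task + 1
        else:
            start += remainder
            end = start + rows_per_task
        return start, end

    # 10 rows over 3 subtasks -> [0, 4), [4, 7), [7, 10): every row is covered
    # exactly once and no two shards differ by more than one row.
    assert [shard_row_bounds(10, 3, i) for i in range(3)] == [(0, 4), (4, 7), (7, 10)]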
@@ -239,201 +243,86 @@ class DataEvolutionSplitGenerator(AbstractSplitGenerator):
def _filter_files_by_row_ranges(partitioned_files: defaultdict, row_ranges: List[Range]) -> defaultdict:
"""
Filter files by Row ID ranges. Keep files that overlap with the given ranges.
+ Blob files are only included if they overlap with non-blob files that match the ranges.
"""
filtered_partitioned_files = defaultdict(list)
for key, file_entries in partitioned_files.items():
- filtered_entries = []
-
+ # Separate blob and non-blob files
+ non_blob_entries = []
+ blob_entries = []
+
for entry in file_entries:
+ if DataFileMeta.is_blob_file(entry.file.file_name):
+ blob_entries.append(entry)
+ else:
+ non_blob_entries.append(entry)
+
+ # First, filter non-blob files based on row ranges
+ filtered_non_blob_entries = []
+ non_blob_ranges = []
+ for entry in non_blob_entries:
first_row_id = entry.file.first_row_id
file_range = Range(first_row_id, first_row_id + entry.file.row_count - 1)
-
+
# Check if file overlaps with any of the row ranges
overlaps = False
for r in row_ranges:
if r.overlaps(file_range):
overlaps = True
break
-
+
if overlaps:
- filtered_entries.append(entry)
-
+ filtered_non_blob_entries.append(entry)
+ non_blob_ranges.append(file_range)
+
+ # Then, filter blob files based on row ID range of non-blob files
+ filtered_blob_entries = []
+ non_blob_ranges = Range.sort_and_merge_overlap(non_blob_ranges, True, True)
+ # Only keep blob files that overlap with merged non-blob ranges
+ for entry in blob_entries:
+ first_row_id = entry.file.first_row_id
+ blob_range = Range(first_row_id, first_row_id + entry.file.row_count - 1)
+ # Check if blob file overlaps with any merged range
+ for merged_range in non_blob_ranges:
+ if merged_range.overlaps(blob_range):
+ filtered_blob_entries.append(entry)
+ break
+
+ # Combine filtered non-blob and blob files
+ filtered_entries = filtered_non_blob_entries + filtered_blob_entries
+
if filtered_entries:
filtered_partitioned_files[key] = filtered_entries
return filtered_partitioned_files
- def _filter_by_shard(self, partitioned_files: defaultdict, sub_task_id: int, total_tasks: int) -> defaultdict:
- list_ranges = []
- for file_entries in partitioned_files.values():
- for entry in file_entries:
- first_row_id = entry.file.first_row_id
- if first_row_id is None:
- raise ValueError("Found None first row id in files")
- # Range is inclusive [from_, to], so use row_count - 1
- list_ranges.append(Range(first_row_id, first_row_id + entry.file.row_count - 1))
-
- sorted_ranges = Range.sort_and_merge_overlap(list_ranges, True, False)
-
- start_range, end_range = self._divide_ranges(sorted_ranges, sub_task_id, total_tasks)
- if start_range is None or end_range is None:
- return defaultdict(list)
- start_first_row_id = start_range.from_
- end_first_row_id = end_range.to
-
- filtered_partitioned_files = {
- k: [x for x in v if x.file.first_row_id >= start_first_row_id and x.file.first_row_id <= end_first_row_id]
- for k, v in partitioned_files.items()
- }
-
- filtered_partitioned_files = {k: v for k, v in filtered_partitioned_files.items() if v}
- return defaultdict(list, filtered_partitioned_files)
-
@staticmethod
- def _divide_ranges(
- sorted_ranges: List[Range], sub_task_id: int, total_tasks: int
- ) -> Tuple[Optional[Range], Optional[Range]]:
- if not sorted_ranges:
- return None, None
-
- num_ranges = len(sorted_ranges)
-
- # If more tasks than ranges, some tasks get nothing
- if sub_task_id >= num_ranges:
- return None, None
-
- # Calculate balanced distribution of ranges across tasks
- base_ranges_per_task = num_ranges // total_tasks
- remainder = num_ranges % total_tasks
-
- # Each of the first 'remainder' tasks gets one extra range
- if sub_task_id < remainder:
- num_ranges_for_task = base_ranges_per_task + 1
- start_idx = sub_task_id * (base_ranges_per_task + 1)
- else:
- num_ranges_for_task = base_ranges_per_task
- start_idx = (
- remainder * (base_ranges_per_task + 1) +
- (sub_task_id - remainder) * base_ranges_per_task
- )
- end_idx = start_idx + num_ranges_for_task - 1
- return sorted_ranges[start_idx], sorted_ranges[end_idx]
-
- def _split_by_row_id(self, files: List[DataFileMeta]) -> List[List[DataFileMeta]]:
+ def _split_by_row_id(files: List[DataFileMeta]) -> List[List[DataFileMeta]]:
"""
Split files by row ID for data evolution tables.
+ Files are grouped by their overlapping row ID ranges.
"""
- split_by_row_id = []
-
- # Filter blob files to only include those within the row ID range of non-blob files
- sorted_files = self._filter_blob(files)
-
- # Split files by firstRowId
- last_row_id = -1
- check_row_id_start = 0
- current_split = []
-
- for file in sorted_files:
+ list_ranges = []
+ for file in files:
first_row_id = file.first_row_id
- if first_row_id is None:
- # Files without firstRowId are treated as individual splits
- split_by_row_id.append([file])
- continue
-
- if not DataFileMeta.is_blob_file(file.file_name) and first_row_id != last_row_id:
- if current_split:
- split_by_row_id.append(current_split)
-
- # Validate that files don't overlap
- if first_row_id < check_row_id_start:
- file_names = [f.file_name for f in sorted_files]
- raise ValueError(
- f"There are overlapping files in the split:
{file_names}, "
- f"the wrong file is: {file.file_name}"
- )
-
- current_split = []
- last_row_id = first_row_id
- check_row_id_start = first_row_id + file.row_count
+ list_ranges.append(Range(first_row_id, first_row_id + file.row_count - 1))
- current_split.append(file)
-
- if current_split:
- split_by_row_id.append(current_split)
-
- return split_by_row_id
-
- def _compute_slice_split_file_idx_map(
- self,
- plan_start_pos: int,
- plan_end_pos: int,
- split: Split,
- file_end_pos: int
- ) -> Dict[str, Tuple[int, int]]:
- """
- Compute file index map for a split, determining which rows to read from each file.
- For data files, the range is calculated based on the file's position in the cumulative row space.
- For blob files (which may be rolled), the range is calculated based on each file's first_row_id.
- """
- shard_file_idx_map = {}
-
- # First pass: data files only. Compute range and apply directly to avoid second-pass lookup.
- current_pos = file_end_pos
- data_file_infos = []
- for file in split.files:
- if DataFileMeta.is_blob_file(file.file_name):
- continue
- file_begin_pos = current_pos
- current_pos += file.row_count
- data_file_range = self._compute_file_range(
- plan_start_pos, plan_end_pos, file_begin_pos, file.row_count
- )
- data_file_infos.append((file, data_file_range))
- if data_file_range is not None:
- shard_file_idx_map[file.file_name] = data_file_range
-
- if not data_file_infos:
- # No data file, skip this split
- shard_file_idx_map[self.NEXT_POS_KEY] = file_end_pos
- return shard_file_idx_map
+ if not list_ranges:
+ return []
- next_pos = current_pos
+ sorted_ranges = Range.sort_and_merge_overlap(list_ranges, True, False)
- # Second pass: only blob files (data files already in shard_file_idx_map from first pass)
- for file in split.files:
- if not DataFileMeta.is_blob_file(file.file_name):
- continue
- blob_first_row_id = file.first_row_id if file.first_row_id is not None else 0
- data_file_range = None
- data_file_first_row_id = None
- for df, fr in data_file_infos:
- df_first = df.first_row_id if df.first_row_id is not None else 0
- if df_first <= blob_first_row_id < df_first + df.row_count:
- data_file_range = fr
- data_file_first_row_id = df_first
+ range_to_files = {}
+ for file in files:
+ first_row_id = file.first_row_id
+ file_range = Range(first_row_id, first_row_id + file.row_count - 1)
+ for r in sorted_ranges:
+ if r.overlaps(file_range):
+ range_to_files.setdefault(r, []).append(file)
break
- if data_file_range is None:
- continue
- if data_file_range == (-1, -1):
- shard_file_idx_map[file.file_name] = (-1, -1)
- continue
- blob_rel_start = blob_first_row_id - data_file_first_row_id
- blob_rel_end = blob_rel_start + file.row_count
- shard_start, shard_end = data_file_range
- intersect_start = max(blob_rel_start, shard_start)
- intersect_end = min(blob_rel_end, shard_end)
- if intersect_start >= intersect_end:
- shard_file_idx_map[file.file_name] = (-1, -1)
- elif intersect_start == blob_rel_start and intersect_end == blob_rel_end:
- pass
- else:
- local_start = intersect_start - blob_rel_start
- local_end = intersect_end - blob_rel_start
- shard_file_idx_map[file.file_name] = (local_start, local_end)
- shard_file_idx_map[self.NEXT_POS_KEY] = next_pos
- return shard_file_idx_map
+ return list(range_to_files.values())
def _wrap_to_indexed_splits(self, splits: List[Split], row_ranges: List[Range]) -> List[Split]:
"""
@@ -478,25 +367,3 @@ class DataEvolutionSplitGenerator(AbstractSplitGenerator):
indexed_splits.append(IndexedSplit(split, expected, scores))
return indexed_splits
-
- @staticmethod
- def _filter_blob(files: List[DataFileMeta]) -> List[DataFileMeta]:
- """
- Filter blob files to only include those within row ID range of non-blob files.
- """
- result = []
- row_id_start = -1
- row_id_end = -1
-
- for file in files:
- if not DataFileMeta.is_blob_file(file.file_name):
- if file.first_row_id is not None:
- row_id_start = file.first_row_id
- row_id_end = file.first_row_id + file.row_count
- result.append(file)
- else:
- if file.first_row_id is not None and row_id_start != -1:
- if row_id_start <= file.first_row_id < row_id_end:
- result.append(file)
-
- return result
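The rewritten _split_by_row_id above groups files by overlapping row-ID ranges rather than walking them in first_row_id order. A simplified sketch of that grouping, with plain (start, end) tuples standing in for pypaimon's Range (the merge helper below is illustrative and not the library's sort_and_merge_overlap):

    def merge_ranges(ranges):
        # Sort by start and merge inclusive [start, end] ranges that touch or overlap.
        merged = []
        for start, end in sorted(ranges):
            if merged and start <= merged[-1][1] + 1:
                merged[-1] = (merged[-1][0], max(merged[-1][1], end))
            else:
                merged.append((start, end))
        return merged

    def group_files_by_row_range(files):
        # files: list of (name, first_row_id, row_count) tuples
        merged = merge_ranges([(rid, rid + cnt - 1) for _, rid, cnt in files])
        groups = {}
        for name, rid, cnt in files:
            lo, hi = rid, rid + cnt - 1
            for m in merged:
                if m[0] <= hi and lo <= m[1]:  # file overlaps this merged range
                    groups.setdefault(m, []).append(name)
                    break
        return list(groups.values())

    # A data file covering rows 0-9 and a blob file covering rows 0-4 land in the
    # same group, while a file starting at row 100 forms its own group.
    assert group_files_by_row_range(
        [("d0", 0, 10), ("blob0", 0, 5), ("d1", 100, 10)]
    ) == [["d0", "blob0"], ["d1"]]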
diff --git a/paimon-python/pypaimon/read/scanner/split_generator.py b/paimon-python/pypaimon/read/scanner/split_generator.py
index f4f2cebbe7..f11240a8a9 100644
--- a/paimon-python/pypaimon/read/scanner/split_generator.py
+++ b/paimon-python/pypaimon/read/scanner/split_generator.py
@@ -98,21 +98,14 @@ class AbstractSplitGenerator(ABC):
else:
raw_convertible = True
- file_paths = []
- total_file_size = 0
- total_record_count = 0
-
for data_file in file_group:
data_file.set_file_path(
self.table.table_path,
file_entries[0].partition,
file_entries[0].bucket
)
- file_paths.append(data_file.file_path)
- total_file_size += data_file.file_size
- total_record_count += data_file.row_count
- if file_paths:
+ if file_group:
# Get deletion files for this split
data_deletion_files = None
if self.deletion_files_map:
@@ -126,9 +119,6 @@ class AbstractSplitGenerator(ABC):
files=file_group,
partition=file_entries[0].partition,
bucket=file_entries[0].bucket,
- file_paths=file_paths,
- row_count=total_record_count,
- file_size=total_file_size,
raw_convertible=raw_convertible,
data_deletion_files=data_deletion_files
)
diff --git a/paimon-python/pypaimon/read/split.py b/paimon-python/pypaimon/read/split.py
index 12d20c0947..5bd63d8f52 100644
--- a/paimon-python/pypaimon/read/split.py
+++ b/paimon-python/pypaimon/read/split.py
@@ -17,7 +17,7 @@
################################################################################
from abc import ABC, abstractmethod
-from typing import List, Optional
+from typing import List, Optional, Callable
from pypaimon.manifest.schema.data_file_meta import DataFileMeta
from pypaimon.table.row.generic_row import GenericRow
@@ -77,18 +77,12 @@ class DataSplit(Split):
files: List[DataFileMeta],
partition: GenericRow,
bucket: int,
- file_paths: List[str],
- row_count: int,
- file_size: int,
raw_convertible: bool = False,
data_deletion_files: Optional[List[DeletionFile]] = None
):
self._files = files
self._partition = partition
self._bucket = bucket
- self._file_paths = file_paths
- self._row_count = row_count
- self._file_size = file_size
self.raw_convertible = raw_convertible
self.data_deletion_files = data_deletion_files
@@ -96,6 +90,40 @@ class DataSplit(Split):
def files(self) -> List[DataFileMeta]:
return self._files
+ def filter_file(self, func: Callable[[DataFileMeta], bool]) -> Optional['DataSplit']:
+ """
+ Filter files based on a predicate function and create a new DataSplit.
+
+ Args:
+ func: A function that takes a DataFileMeta and returns True if the file should be kept
+
+ Returns:
+ A new DataSplit with filtered files, adjusted data_deletion_files
+ """
+ # Filter files based on the predicate
+ filtered_files = [f for f in self._files if func(f)]
+
+ # If no files match, return None
+ if not filtered_files:
+ return None
+
+ # Find indices of filtered files to adjust data_deletion_files
+ filtered_indices = [i for i, f in enumerate(self._files) if func(f)]
+
+ # Filter data_deletion_files to match filtered files
+ filtered_data_deletion_files = None
+ if self.data_deletion_files is not None:
+ filtered_data_deletion_files = [self.data_deletion_files[i] for i in filtered_indices]
+
+ # Create new DataSplit with filtered data
+ return DataSplit(
+ files=filtered_files,
+ partition=self._partition,
+ bucket=self._bucket,
+ raw_convertible=self.raw_convertible,
+ data_deletion_files=filtered_data_deletion_files
+ )
+
@property
def partition(self) -> GenericRow:
return self._partition
@@ -106,18 +134,18 @@ class DataSplit(Split):
@property
def row_count(self) -> int:
- return self._row_count
+ """Calculate total row count from all files."""
+ return sum(f.row_count for f in self._files)
@property
def file_size(self) -> int:
- return self._file_size
+ """Calculate total file size from all files."""
+ return sum(f.file_size for f in self._files)
@property
def file_paths(self) -> List[str]:
- return self._file_paths
-
- def set_row_count(self, row_count: int) -> None:
- self._row_count = row_count
+ """Get file paths from all files."""
+ return [f.file_path for f in self._files if f.file_path is not None]
def merged_row_count(self) -> Optional[int]:
"""
diff --git a/paimon-python/pypaimon/tests/blob_table_test.py b/paimon-python/pypaimon/tests/blob_table_test.py
index 7670ec9447..b1236d764b 100755
--- a/paimon-python/pypaimon/tests/blob_table_test.py
+++ b/paimon-python/pypaimon/tests/blob_table_test.py
@@ -1455,7 +1455,7 @@ class DataBlobWriterTest(unittest.TestCase):
splits = table_scan.plan().splits()
result = table_read.to_arrow(splits)
- self.assertEqual(sum([s._row_count for s in splits]), 40 * 2)
+ self.assertEqual(sum([s.row_count for s in splits]), 40 * 2)
# Verify the data
self.assertEqual(result.num_rows, 40, "Should have 40 rows")
@@ -2204,12 +2204,12 @@ class DataBlobWriterTest(unittest.TestCase):
result = table_read.to_arrow(table_scan.plan().splits())
# Verify the data
- self.assertEqual(result.num_rows, 80, "Should have 54 rows")
+ self.assertEqual(result.num_rows, 54, "Should have 54 rows")
self.assertEqual(result.num_columns, 4, "Should have 4 columns")
# Verify blob data integrity
blob_data = result.column('large_blob').to_pylist()
- self.assertEqual(len(blob_data), 80, "Should have 54 blob records")
+ self.assertEqual(len(blob_data), 54, "Should have 54 blob records")
# Verify each blob
for i, blob in enumerate(blob_data):
self.assertEqual(len(blob), len(large_blob_data), f"Blob {i + 1} should be {large_blob_size:,} bytes")
@@ -2500,8 +2500,18 @@ class DataBlobWriterTest(unittest.TestCase):
result = table_read.to_arrow(splits)
# Verify the data was read back correctly
- # Just one file, so split 0 occupied the whole records
- self.assertEqual(result.num_rows, 5, "Should have 2 rows")
+ self.assertEqual(result.num_rows, 3, "Should have 3 rows")
+ self.assertEqual(result.num_columns, 3, "Should have 3 columns")
+
+ # Read data back using table API
+ read_builder = table.new_read_builder()
+ table_scan = read_builder.new_scan().with_shard(1, 2)
+ table_read = read_builder.new_read()
+ splits = table_scan.plan().splits()
+ result = table_read.to_arrow(splits)
+
+ # Verify the data was read back correctly
+ self.assertEqual(result.num_rows, 2, "Should have 2 rows")
self.assertEqual(result.num_columns, 3, "Should have 3 columns")
def test_blob_write_read_large_data_volume_rolling_with_shard(self):
@@ -2770,7 +2780,7 @@ class DataBlobWriterTest(unittest.TestCase):
table_read = read_builder.new_read()
splits = table_scan.plan().splits()
- total_split_row_count = sum([s._row_count for s in splits])
+ total_split_row_count = sum([s.row_count for s in splits])
self.assertEqual(total_split_row_count, num_rows * 2,
f"Total split row count should be {num_rows}, got
{total_split_row_count}")
diff --git a/paimon-python/pypaimon/tests/data_evolution_test.py b/paimon-python/pypaimon/tests/data_evolution_test.py
index cfb09a0caf..9f95672a19 100644
--- a/paimon-python/pypaimon/tests/data_evolution_test.py
+++ b/paimon-python/pypaimon/tests/data_evolution_test.py
@@ -24,6 +24,7 @@ import pyarrow as pa
from pypaimon import CatalogFactory, Schema
from pypaimon.manifest.manifest_list_manager import ManifestListManager
+from pypaimon.read.read_builder import ReadBuilder
from pypaimon.snapshot.snapshot_manager import SnapshotManager
@@ -238,6 +239,70 @@ class DataEvolutionTest(unittest.TestCase):
% len(result_oob),
)
+ def test_with_slice_partitioned_table(self):
+ pa_schema = pa.schema([
+ ("pt", pa.int64()),
+ ("b", pa.int32()),
+ ("c", pa.int32()),
+ ])
+ schema = Schema.from_pyarrow_schema(
+ pa_schema,
+ partition_keys=["pt"],
+ options={
+ "row-tracking.enabled": "true",
+ "data-evolution.enabled": "true",
+ "source.split.target-size": "512m",
+ },
+ )
+ table_name = "default.test_with_slice_partitioned_table"
+ self.catalog.create_table(table_name, schema, ignore_if_exists=True)
+ table = self.catalog.get_table(table_name)
+
+ for batch in [
+ {"pt": [1, 1], "b": [10, 20], "c": [100, 200]},
+ {"pt": [2, 2], "b": [1011, 2011], "c": [1001, 2001]},
+ {"pt": [2, 2], "b": [-10, -20], "c": [-100, -200]},
+ ]:
+ wb = table.new_batch_write_builder()
+ tw = wb.new_write()
+ tc = wb.new_commit()
+ tw.write_arrow(pa.Table.from_pydict(batch, schema=pa_schema))
+ tc.commit(tw.prepare_commit())
+ tw.close()
+ tc.close()
+
+ rb: ReadBuilder = table.new_read_builder()
+ full_splits = rb.new_scan().plan().splits()
+ full_result = rb.new_read().to_pandas(full_splits)
+ self.assertEqual(
+ len(full_result),
+ 6,
+ "Full scan should return 6 rows",
+ )
+
+ predicate_builder = rb.new_predicate_builder()
+ rb.with_filter(predicate_builder.equal("pt", 2))
+
+ # 0 to 2
+ scan_oob = rb.new_scan().with_slice(0, 2)
+ splits_oob = scan_oob.plan().splits()
+ result_oob = rb.new_read().to_pandas(splits_oob)
+ self.assertEqual(
+ sorted(result_oob["b"].tolist()),
+ [1011, 2011],
+ "Full set b mismatch",
+ )
+
+ # 2 to 4
+ scan_oob = rb.new_scan().with_slice(2, 4)
+ splits_oob = scan_oob.plan().splits()
+ result_oob = rb.new_read().to_pandas(splits_oob)
+ self.assertEqual(
+ sorted(result_oob["b"].tolist()),
+ [-20, -10],
+ "Full set b mismatch",
+ )
+
def test_multiple_appends(self):
simple_pa_schema = pa.schema([
('f0', pa.int32()),
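The new test above pairs a pt == 2 filter with with_slice; condensed, its expectations read as follows (same rb object as in the test, with with_slice(start, end) behaving as a half-open row window over the filtered scan):

    first_half = rb.new_read().to_pandas(rb.new_scan().with_slice(0, 2).plan().splits())
    second_half = rb.new_read().to_pandas(rb.new_scan().with_slice(2, 4).plan().splits())
    assert sorted(first_half["b"].tolist()) == [1011, 2011]   # rows 0-1 of pt == 2
    assert sorted(second_half["b"].tolist()) == [-10, -20]    # rows 2-3 of pt == 2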
diff --git a/paimon-python/pypaimon/write/table_update.py b/paimon-python/pypaimon/write/table_update.py
index 8e3f91bde4..747085a980 100644
--- a/paimon-python/pypaimon/write/table_update.py
+++ b/paimon-python/pypaimon/write/table_update.py
@@ -31,6 +31,61 @@ from pypaimon.write.writer.data_writer import DataWriter
from pypaimon.write.writer.append_only_data_writer import AppendOnlyDataWriter
+def _filter_by_whole_file_shard(splits: List[DataSplit], sub_task_id: int, total_tasks: int) -> List[DataSplit]:
+ list_ranges = []
+ for split in splits:
+ for file in split.files:
+ first_row_id = file.first_row_id
+ list_ranges.append(Range(first_row_id, first_row_id + file.row_count - 1))
+
+ sorted_ranges = Range.sort_and_merge_overlap(list_ranges, True, False)
+
+ start_range, end_range = _divide_ranges(sorted_ranges, sub_task_id, total_tasks)
+ if start_range is None or end_range is None:
+ return []
+ start_first_row_id = start_range.from_
+ end_first_row_id = end_range.to
+
+ def filter_data_file(f: DataFileMeta) -> bool:
+ return start_first_row_id <= f.first_row_id <= end_first_row_id
+
+ filtered_splits = []
+
+ for split in splits:
+ split = split.filter_file(filter_data_file)
+ if split is not None:
+ filtered_splits.append(split)
+
+ return filtered_splits
+
+
+def _divide_ranges(
+ sorted_ranges: List[Range], sub_task_id: int, total_tasks: int
+) -> Tuple[Optional[Range], Optional[Range]]:
+ if not sorted_ranges:
+ return None, None
+
+ num_ranges = len(sorted_ranges)
+
+ # If more tasks than ranges, some tasks get nothing
+ if sub_task_id >= num_ranges:
+ return None, None
+
+ # Calculate balanced distribution of ranges across tasks
+ base_ranges_per_task = num_ranges // total_tasks
+ remainder = num_ranges % total_tasks
+
+ # Each of the first 'remainder' tasks gets one extra range
+ if sub_task_id < remainder:
+ num_ranges_for_task = base_ranges_per_task + 1
+ start_idx = sub_task_id * (base_ranges_per_task + 1)
+ else:
+ num_ranges_for_task = base_ranges_per_task
+ start_idx = (remainder * (base_ranges_per_task + 1) + (sub_task_id - remainder) * base_ranges_per_task)
+ end_idx = start_idx + num_ranges_for_task - 1
+ return sorted_ranges[start_idx], sorted_ranges[end_idx]
+
+
class TableUpdate:
def __init__(self, table, commit_user):
from pypaimon.table.file_store_table import FileStoreTable
@@ -97,8 +152,10 @@ class ShardTableUpdator:
self.writer: Optional[SingleWriter] = None
self.dict = defaultdict(list)
- scanner = self.table.new_read_builder().new_scan().with_shard(shard_num, total_shard_count)
- self.splits = scanner.plan().splits()
+ scanner = self.table.new_read_builder().new_scan()
+ splits = scanner.plan().splits()
+ splits = _filter_by_whole_file_shard(splits, shard_num, total_shard_count)
+ self.splits = splits
self.row_ranges: List[(Tuple, Range)] = []
for split in self.splits:
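For reference, the index arithmetic in _divide_ranges spreads merged ranges as evenly as possible over writer tasks; for example, 7 ranges over 3 tasks become the index windows 0-2, 3-4 and 5-6. A standalone mirror of just that arithmetic (the helper name is illustrative, and the empty/over-subscribed guards are omitted):

    def range_index_window(num_ranges: int, total_tasks: int, sub_task_id: int):
        # Same balanced split as _divide_ranges: the first (num_ranges % total_tasks)
        # tasks receive one extra range; returns an inclusive (start_idx, end_idx).
        base = num_ranges // total_tasks
        remainder = num_ranges % total_tasks
        if sub_task_id < remainder:
            start_idx = sub_task_id * (base + 1)
            count = base + 1
        else:
            start_idx = remainder * (base + 1) + (sub_task_id - remainder) * base
            count = base
        return start_idx, start_idx + count - 1

    assert [range_index_window(7, 3, i) for i in range(3)] == [(0, 2), (3, 4), (5, 6)]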