This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch release-1.3
in repository https://gitbox.apache.org/repos/asf/paimon.git
commit 909a165964f89be1f2614a0ef64dc9449eaf9684 Author: umi <[email protected]> AuthorDate: Tue Sep 23 11:28:04 2025 +0800 [python] Support reading data by splitting according to rows (#6274) --- .../pypaimon/manifest/manifest_file_manager.py | 4 +- paimon-python/pypaimon/read/plan.py | 6 +- .../pypaimon/read/reader/concat_batch_reader.py | 27 ++ paimon-python/pypaimon/read/split.py | 2 + paimon-python/pypaimon/read/split_read.py | 7 +- paimon-python/pypaimon/read/table_scan.py | 181 ++++++--- .../pypaimon/tests/py36/ao_simple_test.py | 332 ++++++++++++++++ .../pypaimon/tests/rest/rest_simple_test.py | 419 +++++++++++++++++++-- paimon-python/pypaimon/write/file_store_commit.py | 2 +- 9 files changed, 898 insertions(+), 82 deletions(-) diff --git a/paimon-python/pypaimon/manifest/manifest_file_manager.py b/paimon-python/pypaimon/manifest/manifest_file_manager.py index aec8bc7ed0..f4b0ab0be3 100644 --- a/paimon-python/pypaimon/manifest/manifest_file_manager.py +++ b/paimon-python/pypaimon/manifest/manifest_file_manager.py @@ -40,7 +40,7 @@ class ManifestFileManager: self.primary_key_fields = self.table.table_schema.get_primary_key_fields() self.trimmed_primary_key_fields = self.table.table_schema.get_trimmed_primary_key_fields() - def read(self, manifest_file_name: str, shard_filter=None) -> List[ManifestEntry]: + def read(self, manifest_file_name: str, bucket_filter=None) -> List[ManifestEntry]: manifest_file_path = self.manifest_path / manifest_file_name entries = [] @@ -97,7 +97,7 @@ class ManifestFileManager: total_buckets=record['_TOTAL_BUCKETS'], file=file_meta ) - if shard_filter is not None and not shard_filter(entry): + if bucket_filter is not None and not bucket_filter(entry): continue entries.append(entry) return entries diff --git a/paimon-python/pypaimon/read/plan.py b/paimon-python/pypaimon/read/plan.py index 9a65fd6f12..8c69a41a9b 100644 --- a/paimon-python/pypaimon/read/plan.py +++ b/paimon-python/pypaimon/read/plan.py @@ -19,18 +19,14 @@ from dataclasses import dataclass from typing import List -from pypaimon.manifest.schema.manifest_entry import ManifestEntry + from pypaimon.read.split import Split @dataclass class Plan: """Implementation of Plan for native Python reading.""" - _files: List[ManifestEntry] _splits: List[Split] - def files(self) -> List[ManifestEntry]: - return self._files - def splits(self) -> List[Split]: return self._splits diff --git a/paimon-python/pypaimon/read/reader/concat_batch_reader.py b/paimon-python/pypaimon/read/reader/concat_batch_reader.py index 76b9f10c71..a5a596e1ea 100644 --- a/paimon-python/pypaimon/read/reader/concat_batch_reader.py +++ b/paimon-python/pypaimon/read/reader/concat_batch_reader.py @@ -49,3 +49,30 @@ class ConcatBatchReader(RecordBatchReader): self.current_reader.close() self.current_reader = None self.queue.clear() + + +class ShardBatchReader(ConcatBatchReader): + + def __init__(self, readers, split_start_row, split_end_row): + super().__init__(readers) + self.split_start_row = split_start_row + self.split_end_row = split_end_row + self.cur_end = 0 + + def read_arrow_batch(self) -> Optional[RecordBatch]: + batch = super().read_arrow_batch() + if batch is None: + return None + if self.split_start_row is not None or self.split_end_row is not None: + cur_begin = self.cur_end # begin idx of current batch based on the split + self.cur_end += batch.num_rows + # shard the first batch and the last batch + if self.split_start_row <= cur_begin < self.cur_end <= self.split_end_row: + return batch + elif cur_begin <= 
self.split_start_row < self.cur_end: + return batch.slice(self.split_start_row - cur_begin, + min(self.split_end_row, self.cur_end) - self.split_start_row) + elif cur_begin < self.split_end_row <= self.cur_end: + return batch.slice(0, self.split_end_row - cur_begin) + else: + return batch diff --git a/paimon-python/pypaimon/read/split.py b/paimon-python/pypaimon/read/split.py index 9b802d9880..f1ab5f3a5b 100644 --- a/paimon-python/pypaimon/read/split.py +++ b/paimon-python/pypaimon/read/split.py @@ -32,6 +32,8 @@ class Split: _file_paths: List[str] _row_count: int _file_size: int + split_start_row: int = None + split_end_row: int = None raw_convertible: bool = False @property diff --git a/paimon-python/pypaimon/read/split_read.py b/paimon-python/pypaimon/read/split_read.py index 1fe0a89d0e..d9e9939c11 100644 --- a/paimon-python/pypaimon/read/split_read.py +++ b/paimon-python/pypaimon/read/split_read.py @@ -24,7 +24,7 @@ from typing import List, Optional, Tuple from pypaimon.common.predicate import Predicate from pypaimon.read.interval_partition import IntervalPartition, SortedRun from pypaimon.read.partition_info import PartitionInfo -from pypaimon.read.reader.concat_batch_reader import ConcatBatchReader +from pypaimon.read.reader.concat_batch_reader import ConcatBatchReader, ShardBatchReader from pypaimon.read.reader.concat_record_reader import ConcatRecordReader from pypaimon.read.reader.data_file_record_reader import DataFileBatchReader from pypaimon.read.reader.drop_delete_reader import DropDeleteRecordReader @@ -249,7 +249,10 @@ class RawFileSplitRead(SplitRead): if not data_readers: return EmptyFileRecordReader() - concat_reader = ConcatBatchReader(data_readers) + if self.split.split_start_row is not None: + concat_reader = ShardBatchReader(data_readers, self.split.split_start_row, self.split.split_end_row) + else: + concat_reader = ConcatBatchReader(data_readers) # if the table is appendonly table, we don't need extra filter, all predicates has pushed down if self.table.is_primary_key_table and self.predicate: return FilterRecordReader(concat_reader, self.predicate) diff --git a/paimon-python/pypaimon/read/table_scan.py b/paimon-python/pypaimon/read/table_scan.py index 0b9f97db4f..6a6ab9f3f8 100644 --- a/paimon-python/pypaimon/read/table_scan.py +++ b/paimon-python/pypaimon/read/table_scan.py @@ -33,7 +33,6 @@ from pypaimon.read.split import Split from pypaimon.schema.data_types import DataField from pypaimon.snapshot.snapshot_manager import SnapshotManager from pypaimon.table.bucket_mode import BucketMode -from pypaimon.write.row_key_extractor import FixedBucketRowKeyExtractor class TableScan: @@ -71,9 +70,21 @@ class TableScan: self.table.options.get('bucket', -1)) == BucketMode.POSTPONE_BUCKET.value else False def plan(self) -> Plan: + file_entries = self.plan_files() + if not file_entries: + return Plan([]) + if self.table.is_primary_key_table: + splits = self._create_primary_key_splits(file_entries) + else: + splits = self._create_append_only_splits(file_entries) + + splits = self._apply_push_down_limit(splits) + return Plan(splits) + + def plan_files(self) -> List[ManifestEntry]: latest_snapshot = self.snapshot_manager.get_latest_snapshot() if not latest_snapshot: - return Plan([], []) + return [] manifest_files = self.manifest_list_manager.read_all(latest_snapshot) deleted_entries = set() @@ -92,40 +103,98 @@ class TableScan: entry for entry in added_entries if (tuple(entry.partition.values), entry.bucket, entry.file.file_name) not in deleted_entries ] - if self.predicate: 
file_entries = self._filter_by_predicate(file_entries) - - partitioned_split = defaultdict(list) - for entry in file_entries: - partitioned_split[(tuple(entry.partition.values), entry.bucket)].append(entry) - - splits = [] - for key, values in partitioned_split.items(): - if self.table.is_primary_key_table: - splits += self._create_primary_key_splits(values) - else: - splits += self._create_append_only_splits(values) - - splits = self._apply_push_down_limit(splits) - - return Plan(file_entries, splits) + return file_entries def with_shard(self, idx_of_this_subtask, number_of_para_subtasks) -> 'TableScan': + if idx_of_this_subtask >= number_of_para_subtasks: + raise Exception("idx_of_this_subtask must be less than number_of_para_subtasks") self.idx_of_this_subtask = idx_of_this_subtask self.number_of_para_subtasks = number_of_para_subtasks return self + def _append_only_filter_by_shard(self, partitioned_files: defaultdict) -> (defaultdict, int, int): + total_row = 0 + # Sort by file creation time to ensure consistent sharding + for key, file_entries in partitioned_files.items(): + for entry in file_entries: + total_row += entry.file.row_count + + # Calculate number of rows this shard should process + # Last shard handles all remaining rows (handles non-divisible cases) + if self.idx_of_this_subtask == self.number_of_para_subtasks - 1: + num_row = total_row - total_row // self.number_of_para_subtasks * self.idx_of_this_subtask + else: + num_row = total_row // self.number_of_para_subtasks + # Calculate start row and end row position for current shard in all data + start_row = self.idx_of_this_subtask * (total_row // self.number_of_para_subtasks) + end_row = start_row + num_row + + plan_start_row = 0 + plan_end_row = 0 + entry_end_row = 0 # end row position of current file in all data + splits_start_row = 0 + filtered_partitioned_files = defaultdict(list) + # Iterate through all file entries to find files that overlap with current shard range + for key, file_entries in partitioned_files.items(): + filtered_entries = [] + for entry in file_entries: + entry_begin_row = entry_end_row # Starting row position of current file in all data + entry_end_row += entry.file.row_count # Update to row position after current file + + # If current file is completely after shard range, stop iteration + if entry_begin_row >= end_row: + break + # If current file is completely before shard range, skip it + if entry_end_row <= start_row: + continue + if entry_begin_row <= start_row < entry_end_row: + splits_start_row = entry_begin_row + plan_start_row = start_row - entry_begin_row + # If shard end position is within current file, record relative end position + if entry_begin_row < end_row <= entry_end_row: + plan_end_row = end_row - splits_start_row + # Add files that overlap with shard range to result + filtered_entries.append(entry) + if filtered_entries: + filtered_partitioned_files[key] = filtered_entries + + return filtered_partitioned_files, plan_start_row, plan_end_row + + def _compute_split_start_end_row(self, splits: List[Split], plan_start_row, plan_end_row): + file_end_row = 0 # end row position of current file in all data + for split in splits: + files = split.files + split_start_row = file_end_row + # Iterate through all file entries to find files that overlap with current shard range + for file in files: + file_begin_row = file_end_row # Starting row position of current file in all data + file_end_row += file.row_count # Update to row position after current file + + # If shard start position is within 
current file, record actual start position and relative offset + if file_begin_row <= plan_start_row < file_end_row: + split.split_start_row = plan_start_row - file_begin_row + + # If shard end position is within current file, record relative end position + if file_begin_row < plan_end_row <= file_end_row: + split.split_end_row = plan_end_row - split_start_row + if split.split_start_row is None: + split.split_start_row = 0 + if split.split_end_row is None: + split.split_end_row = split.row_count + + def _primary_key_filter_by_shard(self, file_entries: List[ManifestEntry]) -> List[ManifestEntry]: + filtered_entries = [] + for entry in file_entries: + if entry.bucket % self.number_of_para_subtasks == self.idx_of_this_subtask: + filtered_entries.append(entry) + return filtered_entries + def _bucket_filter(self, entry: Optional[ManifestEntry]) -> bool: bucket = entry.bucket if self.only_read_real_buckets and bucket < 0: return False - if self.idx_of_this_subtask is not None: - if self.table.is_primary_key_table: - return bucket % self.number_of_para_subtasks == self.idx_of_this_subtask - else: - file = entry.file.file_name - return FixedBucketRowKeyExtractor.hash(file) % self.number_of_para_subtasks == self.idx_of_this_subtask return True def _apply_push_down_limit(self, splits: List[Split]) -> List[Split]: @@ -185,38 +254,60 @@ class TableScan: }) def _create_append_only_splits(self, file_entries: List[ManifestEntry]) -> List['Split']: - if not file_entries: - return [] + partitioned_files = defaultdict(list) + for entry in file_entries: + partitioned_files[(tuple(entry.partition.values), entry.bucket)].append(entry) - data_files: List[DataFileMeta] = [e.file for e in file_entries] + if self.idx_of_this_subtask is not None: + partitioned_files, plan_start_row, plan_end_row = self._append_only_filter_by_shard(partitioned_files) def weight_func(f: DataFileMeta) -> int: return max(f.file_size, self.open_file_cost) - packed_files: List[List[DataFileMeta]] = self._pack_for_ordered(data_files, weight_func, self.target_split_size) - return self._build_split_from_pack(packed_files, file_entries, False) + splits = [] + for key, file_entries in partitioned_files.items(): + if not file_entries: + return [] - def _create_primary_key_splits(self, file_entries: List[ManifestEntry]) -> List['Split']: - if not file_entries: - return [] + data_files: List[DataFileMeta] = [e.file for e in file_entries] - data_files: List[DataFileMeta] = [e.file for e in file_entries] - partition_sort_runs: List[List[SortedRun]] = IntervalPartition(data_files).partition() - sections: List[List[DataFileMeta]] = [ - [file for s in sl for file in s.files] - for sl in partition_sort_runs - ] + packed_files: List[List[DataFileMeta]] = self._pack_for_ordered(data_files, weight_func, + self.target_split_size) + splits += self._build_split_from_pack(packed_files, file_entries, False) + if self.idx_of_this_subtask is not None: + self._compute_split_start_end_row(splits, plan_start_row, plan_end_row) + return splits + + def _create_primary_key_splits(self, file_entries: List[ManifestEntry]) -> List['Split']: + if self.idx_of_this_subtask is not None: + file_entries = self._primary_key_filter_by_shard(file_entries) + partitioned_files = defaultdict(list) + for entry in file_entries: + partitioned_files[(tuple(entry.partition.values), entry.bucket)].append(entry) def weight_func(fl: List[DataFileMeta]) -> int: return max(sum(f.file_size for f in fl), self.open_file_cost) - packed_files: List[List[List[DataFileMeta]]] = 
self._pack_for_ordered(sections, weight_func, - self.target_split_size) - flatten_packed_files: List[List[DataFileMeta]] = [ - [file for sub_pack in pack for file in sub_pack] - for pack in packed_files - ] - return self._build_split_from_pack(flatten_packed_files, file_entries, True) + splits = [] + for key, file_entries in partitioned_files.items(): + if not file_entries: + return [] + + data_files: List[DataFileMeta] = [e.file for e in file_entries] + partition_sort_runs: List[List[SortedRun]] = IntervalPartition(data_files).partition() + sections: List[List[DataFileMeta]] = [ + [file for s in sl for file in s.files] + for sl in partition_sort_runs + ] + + packed_files: List[List[List[DataFileMeta]]] = self._pack_for_ordered(sections, weight_func, + self.target_split_size) + flatten_packed_files: List[List[DataFileMeta]] = [ + [file for sub_pack in pack for file in sub_pack] + for pack in packed_files + ] + splits += self._build_split_from_pack(flatten_packed_files, file_entries, True) + return splits def _build_split_from_pack(self, packed_files, file_entries, for_primary_key_split: bool) -> List['Split']: splits = [] diff --git a/paimon-python/pypaimon/tests/py36/ao_simple_test.py b/paimon-python/pypaimon/tests/py36/ao_simple_test.py new file mode 100644 index 0000000000..17ebf58be7 --- /dev/null +++ b/paimon-python/pypaimon/tests/py36/ao_simple_test.py @@ -0,0 +1,332 @@ +""" +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" +import pyarrow as pa + +from pypaimon import Schema +from pypaimon.tests.py36.pyarrow_compat import table_sort_by +from pypaimon.tests.rest.rest_base_test import RESTBaseTest + + +class AOSimpleTest(RESTBaseTest): + def setUp(self): + super().setUp() + self.pa_schema = pa.schema([ + ('user_id', pa.int64()), + ('item_id', pa.int64()), + ('behavior', pa.string()), + ('dt', pa.string()), + ]) + self.data = { + 'user_id': [2, 4, 6, 8, 10], + 'item_id': [1001, 1002, 1003, 1004, 1005], + 'behavior': ['a', 'b', 'c', 'd', 'e'], + 'dt': ['2000-10-10', '2025-08-10', '2025-08-11', '2025-08-12', '2025-08-13'] + } + self.expected = pa.Table.from_pydict(self.data, schema=self.pa_schema) + + def test_with_shard_ao_unaware_bucket(self): + schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt']) + self.rest_catalog.create_table('default.test_with_shard_ao_unaware_bucket', schema, False) + table = self.rest_catalog.get_table('default.test_with_shard_ao_unaware_bucket') + write_builder = table.new_batch_write_builder() + # first write + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + data1 = { + 'user_id': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], + 'item_id': [1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014], + 'behavior': ['a', 'b', 'c', None, 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm'], + 'dt': ['p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1'], + } + pa_table = pa.Table.from_pydict(data1, schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + # second write + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + data2 = { + 'user_id': [5, 6, 7, 8, 18], + 'item_id': [1005, 1006, 1007, 1008, 1018], + 'behavior': ['e', 'f', 'g', 'h', 'z'], + 'dt': ['p2', 'p1', 'p2', 'p2', 'p1'], + } + pa_table = pa.Table.from_pydict(data2, schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + read_builder = table.new_read_builder() + table_read = read_builder.new_read() + splits = read_builder.new_scan().with_shard(2, 3).plan().splits() + actual = table_sort_by(table_read.to_arrow(splits), 'user_id') + expected = pa.Table.from_pydict({ + 'user_id': [5, 7, 7, 8, 9, 11, 13], + 'item_id': [1005, 1007, 1007, 1008, 1009, 1011, 1013], + 'behavior': ['e', 'f', 'g', 'h', 'h', 'j', 'l'], + 'dt': ['p2', 'p2', 'p2', 'p2', 'p2', 'p2', 'p2'], + }, schema=self.pa_schema) + self.assertEqual(actual, expected) + + # Get the three actual tables + splits1 = read_builder.new_scan().with_shard(0, 3).plan().splits() + actual1 = table_sort_by(table_read.to_arrow(splits1), 'user_id') + splits2 = read_builder.new_scan().with_shard(1, 3).plan().splits() + actual2 = table_sort_by(table_read.to_arrow(splits2), 'user_id') + splits3 = read_builder.new_scan().with_shard(2, 3).plan().splits() + actual3 = table_sort_by(table_read.to_arrow(splits3), 'user_id') + + # Concatenate the three tables + actual = table_sort_by(pa.concat_tables([actual1, actual2, actual3]), 'user_id') + expected = table_sort_by(self._read_test_table(read_builder), 'user_id') + self.assertEqual(actual, expected) + + def test_with_shard_ao_fixed_bucket(self): + schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt'], + options={'bucket': '5', 'bucket-key': 'item_id'}) + 
self.rest_catalog.create_table('default.test_with_slice_ao_fixed_bucket', schema, False) + table = self.rest_catalog.get_table('default.test_with_slice_ao_fixed_bucket') + write_builder = table.new_batch_write_builder() + # first write + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + data1 = { + 'user_id': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], + 'item_id': [1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014], + 'behavior': ['a', 'b', 'c', None, 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm'], + 'dt': ['p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1'], + } + pa_table = pa.Table.from_pydict(data1, schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + # second write + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + data2 = { + 'user_id': [5, 6, 7, 8], + 'item_id': [1005, 1006, 1007, 1008], + 'behavior': ['e', 'f', 'g', 'h'], + 'dt': ['p2', 'p1', 'p2', 'p2'], + } + pa_table = pa.Table.from_pydict(data2, schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + read_builder = table.new_read_builder() + table_read = read_builder.new_read() + splits = read_builder.new_scan().with_shard(0, 3).plan().splits() + actual = table_sort_by(table_read.to_arrow(splits), 'user_id') + expected = pa.Table.from_pydict({ + 'user_id': [1, 2, 3, 5, 8, 12], + 'item_id': [1001, 1002, 1003, 1005, 1008, 1012], + 'behavior': ['a', 'b', 'c', 'd', 'g', 'k'], + 'dt': ['p1', 'p1', 'p2', 'p2', 'p1', 'p1'], + }, schema=self.pa_schema) + self.assertEqual(actual, expected) + + # Get the three actual tables + splits1 = read_builder.new_scan().with_shard(0, 3).plan().splits() + actual1 = table_sort_by(table_read.to_arrow(splits1), 'user_id') + splits2 = read_builder.new_scan().with_shard(1, 3).plan().splits() + actual2 = table_sort_by(table_read.to_arrow(splits2), 'user_id') + splits3 = read_builder.new_scan().with_shard(2, 3).plan().splits() + actual3 = table_sort_by(table_read.to_arrow(splits3), 'user_id') + + # Concatenate the three tables + actual = table_sort_by(pa.concat_tables([actual1, actual2, actual3]), 'user_id') + expected = table_sort_by(self._read_test_table(read_builder), 'user_id') + self.assertEqual(actual, expected) + + def test_shard_single_partition(self): + """Test sharding with single partition - tests _filter_by_shard with simple data""" + schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt']) + self.rest_catalog.create_table('default.test_shard_single_partition', schema, False) + table = self.rest_catalog.get_table('default.test_shard_single_partition') + write_builder = table.new_batch_write_builder() + + # Write data with single partition + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + data = { + 'user_id': [1, 2, 3, 4, 5, 6], + 'item_id': [1001, 1002, 1003, 1004, 1005, 1006], + 'behavior': ['a', 'b', 'c', 'd', 'e', 'f'], + 'dt': ['p1', 'p1', 'p1', 'p1', 'p1', 'p1'], + } + pa_table = pa.Table.from_pydict(data, schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + read_builder = table.new_read_builder() + table_read = read_builder.new_read() + + # Test first shard (0, 2) - should get 
first 3 rows + splits = read_builder.new_scan().with_shard(0, 2).plan().splits() + actual = table_sort_by(table_read.to_arrow(splits), 'user_id') + expected = pa.Table.from_pydict({ + 'user_id': [1, 2, 3], + 'item_id': [1001, 1002, 1003], + 'behavior': ['a', 'b', 'c'], + 'dt': ['p1', 'p1', 'p1'], + }, schema=self.pa_schema) + self.assertEqual(actual, expected) + + # Test second shard (1, 2) - should get last 3 rows + splits = read_builder.new_scan().with_shard(1, 2).plan().splits() + actual = table_sort_by(table_read.to_arrow(splits), 'user_id') + expected = pa.Table.from_pydict({ + 'user_id': [4, 5, 6], + 'item_id': [1004, 1005, 1006], + 'behavior': ['d', 'e', 'f'], + 'dt': ['p1', 'p1', 'p1'], + }, schema=self.pa_schema) + self.assertEqual(actual, expected) + + def test_shard_uneven_distribution(self): + """Test sharding with uneven row distribution across shards""" + schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt']) + self.rest_catalog.create_table('default.test_shard_uneven', schema, False) + table = self.rest_catalog.get_table('default.test_shard_uneven') + write_builder = table.new_batch_write_builder() + + # Write data with 7 rows (not evenly divisible by 3) + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + data = { + 'user_id': [1, 2, 3, 4, 5, 6, 7], + 'item_id': [1001, 1002, 1003, 1004, 1005, 1006, 1007], + 'behavior': ['a', 'b', 'c', 'd', 'e', 'f', 'g'], + 'dt': ['p1', 'p1', 'p1', 'p1', 'p1', 'p1', 'p1'], + } + pa_table = pa.Table.from_pydict(data, schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + read_builder = table.new_read_builder() + table_read = read_builder.new_read() + + # Test sharding into 3 parts: 2, 2, 3 rows + splits = read_builder.new_scan().with_shard(0, 3).plan().splits() + actual1 = table_sort_by(table_read.to_arrow(splits), 'user_id') + expected1 = pa.Table.from_pydict({ + 'user_id': [1, 2], + 'item_id': [1001, 1002], + 'behavior': ['a', 'b'], + 'dt': ['p1', 'p1'], + }, schema=self.pa_schema) + self.assertEqual(actual1, expected1) + + splits = read_builder.new_scan().with_shard(1, 3).plan().splits() + actual2 = table_sort_by(table_read.to_arrow(splits), 'user_id') + expected2 = pa.Table.from_pydict({ + 'user_id': [3, 4], + 'item_id': [1003, 1004], + 'behavior': ['c', 'd'], + 'dt': ['p1', 'p1'], + }, schema=self.pa_schema) + self.assertEqual(actual2, expected2) + + splits = read_builder.new_scan().with_shard(2, 3).plan().splits() + actual3 = table_sort_by(table_read.to_arrow(splits), 'user_id') + expected3 = pa.Table.from_pydict({ + 'user_id': [5, 6, 7], + 'item_id': [1005, 1006, 1007], + 'behavior': ['e', 'f', 'g'], + 'dt': ['p1', 'p1', 'p1'], + }, schema=self.pa_schema) + self.assertEqual(actual3, expected3) + + def test_shard_many_small_shards(self): + """Test sharding with many small shards""" + schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt']) + self.rest_catalog.create_table('default.test_shard_many_small', schema, False) + table = self.rest_catalog.get_table('default.test_shard_many_small') + write_builder = table.new_batch_write_builder() + + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + data = { + 'user_id': [1, 2, 3, 4, 5, 6], + 'item_id': [1001, 1002, 1003, 1004, 1005, 1006], + 'behavior': ['a', 'b', 'c', 'd', 'e', 'f'], + 'dt': ['p1', 'p1', 'p1', 'p1', 'p1', 'p1'], + } + pa_table = pa.Table.from_pydict(data, 
schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + read_builder = table.new_read_builder() + table_read = read_builder.new_read() + + # Test with 6 shards (one row per shard) + for i in range(6): + splits = read_builder.new_scan().with_shard(i, 6).plan().splits() + actual = table_read.to_arrow(splits) + self.assertEqual(len(actual), 1) + self.assertEqual(actual['user_id'][0].as_py(), i + 1) + + def test_shard_boundary_conditions(self): + """Test sharding boundary conditions with edge cases""" + schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt']) + self.rest_catalog.create_table('default.test_shard_boundary', schema, False) + table = self.rest_catalog.get_table('default.test_shard_boundary') + write_builder = table.new_batch_write_builder() + + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + data = { + 'user_id': [1, 2, 3, 4, 5], + 'item_id': [1001, 1002, 1003, 1004, 1005], + 'behavior': ['a', 'b', 'c', 'd', 'e'], + 'dt': ['p1', 'p1', 'p1', 'p1', 'p1'], + } + pa_table = pa.Table.from_pydict(data, schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + read_builder = table.new_read_builder() + table_read = read_builder.new_read() + + # Test first shard (0, 4) - should get 1 row (5//4 = 1) + splits = read_builder.new_scan().with_shard(0, 4).plan().splits() + actual = table_read.to_arrow(splits) + self.assertEqual(len(actual), 1) + + # Test middle shard (1, 4) - should get 1 row + splits = read_builder.new_scan().with_shard(1, 4).plan().splits() + actual = table_read.to_arrow(splits) + self.assertEqual(len(actual), 1) + + # Test last shard (3, 4) - should get 2 rows (remainder goes to last shard) + splits = read_builder.new_scan().with_shard(3, 4).plan().splits() + actual = table_read.to_arrow(splits) + self.assertEqual(len(actual), 2) diff --git a/paimon-python/pypaimon/tests/rest/rest_simple_test.py b/paimon-python/pypaimon/tests/rest/rest_simple_test.py index 95a20345b0..03685c317a 100644 --- a/paimon-python/pypaimon/tests/rest/rest_simple_test.py +++ b/paimon-python/pypaimon/tests/rest/rest_simple_test.py @@ -22,9 +22,7 @@ import pyarrow as pa from pypaimon import Schema from pypaimon.tests.rest.rest_base_test import RESTBaseTest -from pypaimon.write.row_key_extractor import (DynamicBucketRowKeyExtractor, - FixedBucketRowKeyExtractor, - UnawareBucketRowKeyExtractor) +from pypaimon.write.row_key_extractor import FixedBucketRowKeyExtractor, DynamicBucketRowKeyExtractor class RESTSimpleTest(RESTBaseTest): @@ -45,58 +43,425 @@ class RESTSimpleTest(RESTBaseTest): self.expected = pa.Table.from_pydict(self.data, schema=self.pa_schema) def test_with_shard_ao_unaware_bucket(self): - schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['user_id']) - self.rest_catalog.create_table('default.test_with_shard', schema, False) - table = self.rest_catalog.get_table('default.test_with_shard') + schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt']) + self.rest_catalog.create_table('default.test_with_shard_ao_unaware_bucket', schema, False) + table = self.rest_catalog.get_table('default.test_with_shard_ao_unaware_bucket') + write_builder = table.new_batch_write_builder() + # first write + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + data1 = { + 'user_id': [1, 2, 3, 
4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], + 'item_id': [1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014], + 'behavior': ['a', 'b', 'c', None, 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm'], + 'dt': ['p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1'], + } + pa_table = pa.Table.from_pydict(data1, schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + # second write + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + data2 = { + 'user_id': [5, 6, 7, 8, 18], + 'item_id': [1005, 1006, 1007, 1008, 1018], + 'behavior': ['e', 'f', 'g', 'h', 'z'], + 'dt': ['p2', 'p1', 'p2', 'p2', 'p1'], + } + pa_table = pa.Table.from_pydict(data2, schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + read_builder = table.new_read_builder() + table_read = read_builder.new_read() + splits = read_builder.new_scan().with_shard(2, 3).plan().splits() + actual = table_read.to_arrow(splits).sort_by('user_id') + expected = pa.Table.from_pydict({ + 'user_id': [5, 7, 7, 8, 9, 11, 13], + 'item_id': [1005, 1007, 1007, 1008, 1009, 1011, 1013], + 'behavior': ['e', 'f', 'g', 'h', 'h', 'j', 'l'], + 'dt': ['p2', 'p2', 'p2', 'p2', 'p2', 'p2', 'p2'], + }, schema=self.pa_schema) + self.assertEqual(actual, expected) + + # Get the three actual tables + splits1 = read_builder.new_scan().with_shard(0, 3).plan().splits() + actual1 = table_read.to_arrow(splits1).sort_by('user_id') + splits2 = read_builder.new_scan().with_shard(1, 3).plan().splits() + actual2 = table_read.to_arrow(splits2).sort_by('user_id') + splits3 = read_builder.new_scan().with_shard(2, 3).plan().splits() + actual3 = table_read.to_arrow(splits3).sort_by('user_id') + + # Concatenate the three tables + actual = pa.concat_tables([actual1, actual2, actual3]).sort_by('user_id') + expected = self._read_test_table(read_builder).sort_by('user_id') + self.assertEqual(actual, expected) + def test_with_shard_ao_fixed_bucket(self): + schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt'], + options={'bucket': '5', 'bucket-key': 'item_id'}) + self.rest_catalog.create_table('default.test_with_slice_ao_fixed_bucket', schema, False) + table = self.rest_catalog.get_table('default.test_with_slice_ao_fixed_bucket') write_builder = table.new_batch_write_builder() + # first write + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + data1 = { + 'user_id': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], + 'item_id': [1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014], + 'behavior': ['a', 'b', 'c', None, 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm'], + 'dt': ['p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1'], + } + pa_table = pa.Table.from_pydict(data1, schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + # second write table_write = write_builder.new_write() table_commit = write_builder.new_commit() - self.assertIsInstance(table_write.row_key_extractor, UnawareBucketRowKeyExtractor) + data2 = { + 'user_id': [5, 6, 7, 8], + 'item_id': [1005, 1006, 1007, 1008], + 'behavior': ['e', 'f', 'g', 'h'], + 'dt': ['p2', 'p1', 'p2', 'p2'], + } + pa_table = 
pa.Table.from_pydict(data2, schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() - pa_table = pa.Table.from_pydict(self.data, schema=self.pa_schema) + read_builder = table.new_read_builder() + table_read = read_builder.new_read() + splits = read_builder.new_scan().with_shard(0, 3).plan().splits() + actual = table_read.to_arrow(splits).sort_by('user_id') + expected = pa.Table.from_pydict({ + 'user_id': [1, 2, 3, 5, 8, 12], + 'item_id': [1001, 1002, 1003, 1005, 1008, 1012], + 'behavior': ['a', 'b', 'c', 'd', 'g', 'k'], + 'dt': ['p1', 'p1', 'p2', 'p2', 'p1', 'p1'], + }, schema=self.pa_schema) + self.assertEqual(actual, expected) + + # Get the three actual tables + splits1 = read_builder.new_scan().with_shard(0, 3).plan().splits() + actual1 = table_read.to_arrow(splits1).sort_by('user_id') + splits2 = read_builder.new_scan().with_shard(1, 3).plan().splits() + actual2 = table_read.to_arrow(splits2).sort_by('user_id') + splits3 = read_builder.new_scan().with_shard(2, 3).plan().splits() + actual3 = table_read.to_arrow(splits3).sort_by('user_id') + + # Concatenate the three tables + actual = pa.concat_tables([actual1, actual2, actual3]).sort_by('user_id') + expected = self._read_test_table(read_builder).sort_by('user_id') + self.assertEqual(actual, expected) + + def test_shard_single_partition(self): + """Test sharding with single partition - tests _filter_by_shard with simple data""" + schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt']) + self.rest_catalog.create_table('default.test_shard_single_partition', schema, False) + table = self.rest_catalog.get_table('default.test_shard_single_partition') + write_builder = table.new_batch_write_builder() + + # Write data with single partition + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + data = { + 'user_id': [1, 2, 3, 4, 5, 6], + 'item_id': [1001, 1002, 1003, 1004, 1005, 1006], + 'behavior': ['a', 'b', 'c', 'd', 'e', 'f'], + 'dt': ['p1', 'p1', 'p1', 'p1', 'p1', 'p1'], + } + pa_table = pa.Table.from_pydict(data, schema=self.pa_schema) table_write.write_arrow(pa_table) table_commit.commit(table_write.prepare_commit()) table_write.close() table_commit.close() - splits = [] read_builder = table.new_read_builder() - splits.extend(read_builder.new_scan().with_shard(0, 3).plan().splits()) - splits.extend(read_builder.new_scan().with_shard(1, 3).plan().splits()) - splits.extend(read_builder.new_scan().with_shard(2, 3).plan().splits()) + table_read = read_builder.new_read() + # Test first shard (0, 2) - should get first 3 rows + plan = read_builder.new_scan().with_shard(0, 2).plan() + actual = table_read.to_arrow(plan.splits()).sort_by('user_id') + expected = pa.Table.from_pydict({ + 'user_id': [1, 2, 3], + 'item_id': [1001, 1002, 1003], + 'behavior': ['a', 'b', 'c'], + 'dt': ['p1', 'p1', 'p1'], + }, schema=self.pa_schema) + self.assertEqual(actual, expected) + + # Test second shard (1, 2) - should get last 3 rows + plan = read_builder.new_scan().with_shard(1, 2).plan() + actual = table_read.to_arrow(plan.splits()).sort_by('user_id') + expected = pa.Table.from_pydict({ + 'user_id': [4, 5, 6], + 'item_id': [1004, 1005, 1006], + 'behavior': ['d', 'e', 'f'], + 'dt': ['p1', 'p1', 'p1'], + }, schema=self.pa_schema) + self.assertEqual(actual, expected) + + def test_shard_uneven_distribution(self): + """Test sharding with uneven row distribution across shards""" + schema = 
Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt']) + self.rest_catalog.create_table('default.test_shard_uneven', schema, False) + table = self.rest_catalog.get_table('default.test_shard_uneven') + write_builder = table.new_batch_write_builder() + + # Write data with 7 rows (not evenly divisible by 3) + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + data = { + 'user_id': [1, 2, 3, 4, 5, 6, 7], + 'item_id': [1001, 1002, 1003, 1004, 1005, 1006, 1007], + 'behavior': ['a', 'b', 'c', 'd', 'e', 'f', 'g'], + 'dt': ['p1', 'p1', 'p1', 'p1', 'p1', 'p1', 'p1'], + } + pa_table = pa.Table.from_pydict(data, schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + read_builder = table.new_read_builder() table_read = read_builder.new_read() - actual = table_read.to_arrow(splits) - self.assertEqual(actual.sort_by('user_id'), self.expected) + # Test sharding into 3 parts: 2, 2, 3 rows + plan1 = read_builder.new_scan().with_shard(0, 3).plan() + actual1 = table_read.to_arrow(plan1.splits()).sort_by('user_id') + expected1 = pa.Table.from_pydict({ + 'user_id': [1, 2], + 'item_id': [1001, 1002], + 'behavior': ['a', 'b'], + 'dt': ['p1', 'p1'], + }, schema=self.pa_schema) + self.assertEqual(actual1, expected1) + + plan2 = read_builder.new_scan().with_shard(1, 3).plan() + actual2 = table_read.to_arrow(plan2.splits()).sort_by('user_id') + expected2 = pa.Table.from_pydict({ + 'user_id': [3, 4], + 'item_id': [1003, 1004], + 'behavior': ['c', 'd'], + 'dt': ['p1', 'p1'], + }, schema=self.pa_schema) + self.assertEqual(actual2, expected2) + + plan3 = read_builder.new_scan().with_shard(2, 3).plan() + actual3 = table_read.to_arrow(plan3.splits()).sort_by('user_id') + expected3 = pa.Table.from_pydict({ + 'user_id': [5, 6, 7], + 'item_id': [1005, 1006, 1007], + 'behavior': ['e', 'f', 'g'], + 'dt': ['p1', 'p1', 'p1'], + }, schema=self.pa_schema) + self.assertEqual(actual3, expected3) + + def test_shard_single_shard(self): + """Test sharding with only one shard - should return all data""" + schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt']) + self.rest_catalog.create_table('default.test_shard_single', schema, False) + table = self.rest_catalog.get_table('default.test_shard_single') + write_builder = table.new_batch_write_builder() - def test_with_shard_ao_fixed_bucket(self): - schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['user_id'], - options={'bucket': '5', 'bucket-key': 'item_id'}) - self.rest_catalog.create_table('default.test_with_shard', schema, False) - table = self.rest_catalog.get_table('default.test_with_shard') + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + data = { + 'user_id': [1, 2, 3, 4], + 'item_id': [1001, 1002, 1003, 1004], + 'behavior': ['a', 'b', 'c', 'd'], + 'dt': ['p1', 'p1', 'p2', 'p2'], + } + pa_table = pa.Table.from_pydict(data, schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + read_builder = table.new_read_builder() + table_read = read_builder.new_read() + + # Test single shard (0, 1) - should get all data + plan = read_builder.new_scan().with_shard(0, 1).plan() + actual = table_read.to_arrow(plan.splits()).sort_by('user_id') + expected = pa.Table.from_pydict(data, schema=self.pa_schema) + self.assertEqual(actual, expected) + + def 
test_shard_many_small_shards(self): + """Test sharding with many small shards""" + schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt']) + self.rest_catalog.create_table('default.test_shard_many_small', schema, False) + table = self.rest_catalog.get_table('default.test_shard_many_small') write_builder = table.new_batch_write_builder() + table_write = write_builder.new_write() table_commit = write_builder.new_commit() - self.assertIsInstance(table_write.row_key_extractor, FixedBucketRowKeyExtractor) + data = { + 'user_id': [1, 2, 3, 4, 5, 6], + 'item_id': [1001, 1002, 1003, 1004, 1005, 1006], + 'behavior': ['a', 'b', 'c', 'd', 'e', 'f'], + 'dt': ['p1', 'p1', 'p1', 'p1', 'p1', 'p1'], + } + pa_table = pa.Table.from_pydict(data, schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() - pa_table = pa.Table.from_pydict(self.data, schema=self.pa_schema) + read_builder = table.new_read_builder() + table_read = read_builder.new_read() + + # Test with 6 shards (one row per shard) + for i in range(6): + plan = read_builder.new_scan().with_shard(i, 6).plan() + actual = table_read.to_arrow(plan.splits()) + self.assertEqual(len(actual), 1) + self.assertEqual(actual['user_id'][0].as_py(), i + 1) + + def test_shard_boundary_conditions(self): + """Test sharding boundary conditions with edge cases""" + schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt']) + self.rest_catalog.create_table('default.test_shard_boundary', schema, False) + table = self.rest_catalog.get_table('default.test_shard_boundary') + write_builder = table.new_batch_write_builder() + + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + data = { + 'user_id': [1, 2, 3, 4, 5], + 'item_id': [1001, 1002, 1003, 1004, 1005], + 'behavior': ['a', 'b', 'c', 'd', 'e'], + 'dt': ['p1', 'p1', 'p1', 'p1', 'p1'], + } + pa_table = pa.Table.from_pydict(data, schema=self.pa_schema) table_write.write_arrow(pa_table) table_commit.commit(table_write.prepare_commit()) table_write.close() table_commit.close() - splits = [] read_builder = table.new_read_builder() - splits.extend(read_builder.new_scan().with_shard(0, 3).plan().splits()) - splits.extend(read_builder.new_scan().with_shard(1, 3).plan().splits()) - splits.extend(read_builder.new_scan().with_shard(2, 3).plan().splits()) + table_read = read_builder.new_read() + + # Test first shard (0, 4) - should get 1 row (5//4 = 1) + plan = read_builder.new_scan().with_shard(0, 4).plan() + actual = table_read.to_arrow(plan.splits()) + self.assertEqual(len(actual), 1) + + # Test middle shard (1, 4) - should get 1 row + plan = read_builder.new_scan().with_shard(1, 4).plan() + actual = table_read.to_arrow(plan.splits()) + self.assertEqual(len(actual), 1) + + # Test last shard (3, 4) - should get 2 rows (remainder goes to last shard) + plan = read_builder.new_scan().with_shard(3, 4).plan() + actual = table_read.to_arrow(plan.splits()) + self.assertEqual(len(actual), 2) + + def test_with_shard_large_dataset(self): + """Test with_shard method using 50000 rows of data to verify performance and correctness""" + schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt'], + options={'bucket': '5', 'bucket-key': 'item_id'}) + self.rest_catalog.create_table('default.test_with_shard_large_dataset', schema, False) + table = self.rest_catalog.get_table('default.test_with_shard_large_dataset') + write_builder = table.new_batch_write_builder() + 
# Generate 50000 rows of test data + num_rows = 50000 + batch_size = 5000 # Write in batches to avoid memory issues + + for batch_start in range(0, num_rows, batch_size): + batch_end = min(batch_start + batch_size, num_rows) + batch_data = { + 'user_id': list(range(batch_start + 1, batch_end + 1)), + 'item_id': [2000 + i for i in range(batch_start, batch_end)], + 'behavior': [chr(ord('a') + (i % 26)) for i in range(batch_start, batch_end)], + 'dt': [f'p{(i % 5) + 1}' for i in range(batch_start, batch_end)], + } + + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + pa_table = pa.Table.from_pydict(batch_data, schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + read_builder = table.new_read_builder() table_read = read_builder.new_read() - actual = table_read.to_arrow(splits) - self.assertEqual(actual.sort_by("user_id"), self.expected) + + # Test with 6 shards + num_shards = 6 + shard_results = [] + total_rows_from_shards = 0 + + for shard_idx in range(num_shards): + splits = read_builder.new_scan().with_shard(shard_idx, num_shards).plan().splits() + shard_result = table_read.to_arrow(splits) + shard_results.append(shard_result) + shard_rows = len(shard_result) if shard_result else 0 + total_rows_from_shards += shard_rows + print(f"Shard {shard_idx}/{num_shards}: {shard_rows} rows") + + # Verify that all shards together contain all the data + concatenated_result = pa.concat_tables(shard_results).sort_by('user_id') + + # Read all data without sharding for comparison + all_splits = read_builder.new_scan().plan().splits() + all_data = table_read.to_arrow(all_splits).sort_by('user_id') + + # Verify total row count + self.assertEqual(len(concatenated_result), len(all_data)) + self.assertEqual(len(all_data), num_rows) + self.assertEqual(total_rows_from_shards, num_rows) + + # Verify data integrity - check first and last few rows + self.assertEqual(concatenated_result['user_id'][0].as_py(), 1) + self.assertEqual(concatenated_result['user_id'][-1].as_py(), num_rows) + self.assertEqual(concatenated_result['item_id'][0].as_py(), 2000) + self.assertEqual(concatenated_result['item_id'][-1].as_py(), 2000 + num_rows - 1) + + # Verify that concatenated result equals all data + self.assertEqual(concatenated_result, all_data) + # Test with different shard configurations + # Test with 10 shards + shard_10_results = [] + for shard_idx in range(10): + splits = read_builder.new_scan().with_shard(shard_idx, 10).plan().splits() + shard_result = table_read.to_arrow(splits) + if shard_result: + shard_10_results.append(shard_result) + + if shard_10_results: + concatenated_10_shards = pa.concat_tables(shard_10_results).sort_by('user_id') + self.assertEqual(len(concatenated_10_shards), num_rows) + self.assertEqual(concatenated_10_shards, all_data) + + # Test with single shard (should return all data) + single_shard_splits = read_builder.new_scan().with_shard(0, 1).plan().splits() + single_shard_result = table_read.to_arrow(single_shard_splits).sort_by('user_id') + self.assertEqual(len(single_shard_result), num_rows) + self.assertEqual(single_shard_result, all_data) + + print(f"Successfully tested with_shard method using {num_rows} rows of data") + + def test_shard_parameter_validation(self): + """Test edge cases for parameter validation""" + schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt']) + 
self.rest_catalog.create_table('default.test_shard_validation_edge', schema, False) + table = self.rest_catalog.get_table('default.test_shard_validation_edge') + + read_builder = table.new_read_builder() + # Test invalid case with number_of_para_subtasks = 1 + with self.assertRaises(Exception) as context: + read_builder.new_scan().with_shard(1, 1).plan() + self.assertEqual(str(context.exception), "idx_of_this_subtask must be less than number_of_para_subtasks") def test_with_shard_pk_dynamic_bucket(self): schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['user_id'], primary_keys=['user_id', 'dt']) diff --git a/paimon-python/pypaimon/write/file_store_commit.py b/paimon-python/pypaimon/write/file_store_commit.py index 03c4d034d0..5920f50ad8 100644 --- a/paimon-python/pypaimon/write/file_store_commit.py +++ b/paimon-python/pypaimon/write/file_store_commit.py @@ -101,7 +101,7 @@ class FileStoreCommit: f"in {msg.partition} does not belong to this partition") commit_entries = [] - current_entries = TableScan(self.table, partition_filter, None, []).plan().files() + current_entries = TableScan(self.table, partition_filter, None, []).plan_files() for entry in current_entries: entry.kind = 1 commit_entries.append(entry)
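
For context, the new row-based sharding is driven entirely from the scan API. Below is a minimal usage sketch mirroring the new tests; `rest_catalog` is assumed to be an already-initialised pypaimon REST catalog (the tests obtain one from RESTBaseTest), and the table name is a placeholder.

    import pyarrow as pa
    from pypaimon import Schema

    pa_schema = pa.schema([
        ('user_id', pa.int64()),
        ('item_id', pa.int64()),
        ('dt', pa.string()),
    ])
    schema = Schema.from_pyarrow_schema(pa_schema, partition_keys=['dt'])
    rest_catalog.create_table('default.shard_demo', schema, False)  # placeholder table name
    table = rest_catalog.get_table('default.shard_demo')

    # ... write data via table.new_batch_write_builder() as in the new tests ...

    read_builder = table.new_read_builder()
    table_read = read_builder.new_read()
    num_subtasks = 3
    shards = []
    for idx in range(num_subtasks):
        # Each subtask plans only its own slice of the table. For append-only tables the
        # slice is a contiguous row range; for primary-key tables it is the set of buckets
        # with bucket % num_subtasks == idx.
        splits = read_builder.new_scan().with_shard(idx, num_subtasks).plan().splits()
        shards.append(table_read.to_arrow(splits))

    # Concatenating all shards reproduces the full table content.
    full = pa.concat_tables(shards).sort_by('user_id')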
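The append-only path assigns each subtask a contiguous row range over the planned files: every shard receives total_row // n rows, and the last shard also absorbs the remainder. A simplified sketch of that arithmetic (same rule as _append_only_filter_by_shard):

    def shard_row_range(total_row: int, idx: int, n: int):
        base = total_row // n
        start_row = idx * base
        # the last shard takes whatever is left, so non-divisible totals are covered
        num_row = total_row - base * idx if idx == n - 1 else base
        return start_row, start_row + num_row

    # 7 rows over 3 subtasks -> 2, 2 and 3 rows, matching test_shard_uneven_distribution
    print([shard_row_range(7, i, 3) for i in range(3)])  # [(0, 2), (2, 4), (4, 7)]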
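Within a split, ShardBatchReader then trims the concatenated Arrow batches to that row range. The snippet below is a standalone illustration of the slicing rule, not the reader itself: rows are kept if they fall in the half-open range [split_start_row, split_end_row) measured from the start of the split.

    import pyarrow as pa

    def slice_batches(batches, split_start_row, split_end_row):
        cur_end = 0
        for batch in batches:
            cur_begin = cur_end          # first row of this batch, relative to the split
            cur_end += batch.num_rows    # one past the last row of this batch
            if cur_end <= split_start_row or cur_begin >= split_end_row:
                continue                 # batch lies entirely outside the shard's range
            lo = max(cur_begin, split_start_row)
            hi = min(cur_end, split_end_row)
            yield batch.slice(lo - cur_begin, hi - lo)

    batches = [pa.RecordBatch.from_arrays([pa.array(list(range(i, i + 4)))], names=['v'])
               for i in (0, 4, 8)]
    kept = pa.Table.from_batches(list(slice_batches(batches, 3, 9)))
    print(kept['v'].to_pylist())  # [3, 4, 5, 6, 7, 8]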
