This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
     new 265f9084c0 [python] Enable update data to compute a New Column (scan + rewrite with shards) (#7148)
265f9084c0 is described below
commit 265f9084c0155889ad10567cdedd4e6d8414a12c
Author: YeJunHao <[email protected]>
AuthorDate: Fri Jan 30 11:22:33 2026 +0800
    [python] Enable update data to compute a New Column (scan + rewrite with shards) (#7148)
---
docs/content/pypaimon/data-evolution.md | 118 ++++++++
paimon-python/pypaimon/globalindex/range.py | 5 +-
.../read/scanner/data_evolution_split_generator.py | 79 +++--
paimon-python/pypaimon/tests/blob_table_test.py | 73 ++---
.../pypaimon/tests/shard_table_updator_test.py | 320 +++++++++++++++++++++
paimon-python/pypaimon/write/table_update.py | 155 +++++++++-
6 files changed, 693 insertions(+), 57 deletions(-)
diff --git a/docs/content/pypaimon/data-evolution.md
b/docs/content/pypaimon/data-evolution.md
index bd0bde03a4..2268090f73 100644
--- a/docs/content/pypaimon/data-evolution.md
+++ b/docs/content/pypaimon/data-evolution.md
@@ -29,6 +29,13 @@ under the License.
PyPaimon for Data Evolution mode. See [Data Evolution]({{< ref
"append-table/data-evolution" >}}).
+## Prerequisites
+
+To use partial updates / data evolution, enable both of the following options when creating the table (see the sketch after this list):
+
+- **`row-tracking.enabled`**: `true`
+- **`data-evolution.enabled`**: `true`
+
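+A minimal sketch of creating such a table (the table name and columns are illustrative):
+
+```python
+import pyarrow as pa
+from pypaimon import CatalogFactory, Schema
+
+catalog = CatalogFactory.create({'warehouse': '/tmp/warehouse'})
+catalog.create_database('default', False)
+
+schema = Schema.from_pyarrow_schema(
+    pa.schema([('f0', pa.int8()), ('f1', pa.int16())]),
+    options={'row-tracking.enabled': 'true', 'data-evolution.enabled': 'true'},
+)
+catalog.create_table('default.t', schema, False)
+```
+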
## Update Columns By Row ID
You can use `TableUpdate.update_by_arrow_with_row_id` to update columns in data evolution tables.
@@ -36,7 +43,19 @@ You can create `TableUpdate.update_by_arrow_with_row_id` to
update columns to da
The input data should include the `_ROW_ID` column. The update operation automatically sorts and matches each `_ROW_ID` to its corresponding `first_row_id`, then groups rows with the same `first_row_id` and writes them to a separate file.
+### Requirements for `_ROW_ID` updates
+
+- **All rows required**: the input table must contain **exactly the full table row count** (one row per existing row).
+- **Row id coverage**: after sorting by `_ROW_ID`, it must be **0..N-1** (no duplicates, no gaps).
+- **Update columns only**: include `_ROW_ID` plus the columns you want to update (partial schema is OK).
+
```python
+import pyarrow as pa
+from pypaimon import CatalogFactory, Schema
+
+catalog = CatalogFactory.create({'warehouse': '/tmp/warehouse'})
+catalog.create_database('default', False)
+
simple_pa_schema = pa.schema([
('f0', pa.int8()),
('f1', pa.int16()),
@@ -78,3 +97,102 @@ table_commit.close()
# 'f0': [5, 6],
# 'f1': [-1001, 1002]
```
+
+## Compute a New Column (scan + rewrite with shards)
+
+If you want to **compute a derived column** (or **update an existing column based on other columns**) without providing `_ROW_ID`, you can use the shard scan + rewrite workflow:
+
+- Read only the columns you need (projection)
+- Compute the new values in the same row order
+- Write only the updated columns back
+- Commit per shard
+
+This is useful for backfilling a newly added column, or recomputing a column from other columns.
+
+### Example: compute `d = c + b - a`
+
+```python
+import pyarrow as pa
+from pypaimon import CatalogFactory, Schema
+
+catalog = CatalogFactory.create({'warehouse': '/tmp/warehouse'})
+catalog.create_database('default', False)
+
+table_schema = pa.schema([
+ ('a', pa.int32()),
+ ('b', pa.int32()),
+ ('c', pa.int32()),
+ ('d', pa.int32()),
+])
+
+schema = Schema.from_pyarrow_schema(
+ table_schema,
+ options={'row-tracking.enabled': 'true', 'data-evolution.enabled': 'true'},
+)
+catalog.create_table('default.t', schema, False)
+table = catalog.get_table('default.t')
+
+# write initial data (a, b, c only)
+write_builder = table.new_batch_write_builder()
+write = write_builder.new_write().with_write_type(['a', 'b', 'c'])
+commit = write_builder.new_commit()
+write.write_arrow(pa.Table.from_pydict({'a': [1, 2], 'b': [10, 20], 'c': [100, 200]}))
+commit.commit(write.prepare_commit())
+write.close()
+commit.close()
+
+# shard update: read (a, b, c), write only (d)
+update = write_builder.new_update()
+update.with_read_projection(['a', 'b', 'c'])
+update.with_update_type(['d'])
+
+shard_idx = 0
+num_shards = 1
+upd = update.new_shard_updator(shard_idx, num_shards)
+reader = upd.arrow_reader()
+
+for batch in iter(reader.read_next_batch, None):
+ a = batch.column('a').to_pylist()
+ b = batch.column('b').to_pylist()
+ c = batch.column('c').to_pylist()
+ d = [ci + bi - ai for ai, bi, ci in zip(a, b, c)]
+
+ upd.update_by_arrow_batch(
+ pa.RecordBatch.from_pydict({'d': d}, schema=pa.schema([('d', pa.int32())]))
+ )
+
+commit_messages = upd.prepare_commit()
+commit = write_builder.new_commit()
+commit.commit(commit_messages)
+commit.close()
+```
+
+### Example: update an existing column `c = b - a`
+
+```python
+update = write_builder.new_update()
+update.with_read_projection(['a', 'b'])
+update.with_update_type(['c'])
+
+upd = update.new_shard_updator(0, 1)
+reader = upd.arrow_reader()
+for batch in iter(reader.read_next_batch, None):
+ a = batch.column('a').to_pylist()
+ b = batch.column('b').to_pylist()
+ c = [bi - ai for ai, bi in zip(a, b)]
+ upd.update_by_arrow_batch(
+ pa.RecordBatch.from_pydict({'c': c}, schema=pa.schema([('c', pa.int32())]))
+ )
+
+commit_messages = upd.prepare_commit()
+commit = write_builder.new_commit()
+commit.commit(commit_messages)
+commit.close()
+```
+
+### Notes
+
+- **Row order matters**: the batches you write must have the **same number of rows** as the batches you read, in the same order for that shard.
+- **Parallelism**: run multiple shards by calling `new_shard_updator(shard_idx, num_shards)` for each shard, as in the sketch below.
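+
+A minimal sketch of driving several shards from one process, reusing the `d = c + b - a` example above (in practice each shard can run in its own worker):
+
+```python
+num_shards = 4
+update = write_builder.new_update()
+update.with_read_projection(['a', 'b', 'c'])
+update.with_update_type(['d'])
+
+for shard_idx in range(num_shards):
+    upd = update.new_shard_updator(shard_idx, num_shards)
+    reader = upd.arrow_reader()
+    for batch in iter(reader.read_next_batch, None):
+        a = batch.column('a').to_pylist()
+        b = batch.column('b').to_pylist()
+        c = batch.column('c').to_pylist()
+        d = [ci + bi - ai for ai, bi, ci in zip(a, b, c)]
+        upd.update_by_arrow_batch(
+            pa.RecordBatch.from_pydict({'d': d}, schema=pa.schema([('d', pa.int32())]))
+        )
+    # each shard commits the files it rewrote
+    commit = write_builder.new_commit()
+    commit.commit(upd.prepare_commit())
+    commit.close()
+```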
diff --git a/paimon-python/pypaimon/globalindex/range.py
b/paimon-python/pypaimon/globalindex/range.py
index f27a637ae7..19b9b40e94 100644
--- a/paimon-python/pypaimon/globalindex/range.py
+++ b/paimon-python/pypaimon/globalindex/range.py
@@ -154,7 +154,7 @@ class Range:
return result
@staticmethod
- def sort_and_merge_overlap(ranges: List['Range'], merge: bool = True) ->
List['Range']:
+ def sort_and_merge_overlap(ranges: List['Range'], merge: bool = True,
adjacent: bool = True) -> List['Range']:
"""
Sort ranges and optionally merge overlapping ones.
"""
@@ -166,10 +166,11 @@ class Range:
if not merge:
return sorted_ranges
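+ # When adjacent=True, ranges that merely touch (e.g. [0, 4] and [5, 9]) are merged as well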
+ adjacent_value = 1 if adjacent else 0
result = [sorted_ranges[0]]
for r in sorted_ranges[1:]:
last = result[-1]
- if r.from_ <= last.to + 1:
+ if r.from_ <= last.to + adjacent_value:
# Merge with last range
result[-1] = Range(last.from_, max(last.to, r.to))
else:
diff --git
a/paimon-python/pypaimon/read/scanner/data_evolution_split_generator.py
b/paimon-python/pypaimon/read/scanner/data_evolution_split_generator.py
index ffb0ffafdf..f52efb7ae1 100644
--- a/paimon-python/pypaimon/read/scanner/data_evolution_split_generator.py
+++ b/paimon-python/pypaimon/read/scanner/data_evolution_split_generator.py
@@ -77,8 +77,9 @@ class DataEvolutionSplitGenerator(AbstractSplitGenerator):
self.end_pos_of_this_subtask
)
elif self.idx_of_this_subtask is not None:
- # shard data range: [plan_start_pos, plan_end_pos)
- partitioned_files, plan_start_pos, plan_end_pos =
self._filter_by_shard(partitioned_files)
+ partitioned_files = self._filter_by_shard(
+ partitioned_files, self.idx_of_this_subtask,
self.number_of_para_subtasks
+ )
def weight_func(file_list: List[DataFileMeta]) -> int:
return max(sum(f.file_size for f in file_list),
self.open_file_cost)
@@ -108,9 +109,8 @@ class DataEvolutionSplitGenerator(AbstractSplitGenerator):
flatten_packed_files, packed_files, sorted_entries_list
)
- if self.start_pos_of_this_subtask is not None or
self.idx_of_this_subtask is not None:
+ if self.start_pos_of_this_subtask is not None:
splits = self._wrap_to_sliced_splits(splits, plan_start_pos,
plan_end_pos)
-
# Wrap splits with IndexedSplit if row_ranges is provided
if self.row_ranges:
splits = self._wrap_to_indexed_splits(splits)
@@ -242,22 +242,61 @@ class DataEvolutionSplitGenerator(AbstractSplitGenerator):
return filtered_partitioned_files, plan_start_pos, plan_end_pos
- def _filter_by_shard(self, partitioned_files: defaultdict) -> tuple:
- """
- Filter file entries by shard for data evolution tables.
- """
- # Calculate total rows (excluding blob files)
- total_row = sum(
- entry.file.row_count
- for file_entries in partitioned_files.values()
- for entry in file_entries
- if not self._is_blob_file(entry.file.file_name)
- )
-
- # Calculate shard range using shared helper
- start_pos, end_pos = self._compute_shard_range(total_row)
-
- return self._filter_by_row_range(partitioned_files, start_pos, end_pos)
+ def _filter_by_shard(self, partitioned_files: defaultdict, sub_task_id:
int, total_tasks: int) -> defaultdict:
+ list_ranges = []
+ for file_entries in partitioned_files.values():
+ for entry in file_entries:
+ first_row_id = entry.file.first_row_id
+ if first_row_id is None:
+ raise ValueError("Found None first row id in files")
+ # Range is inclusive [from_, to], so use row_count - 1
+ list_ranges.append(Range(first_row_id, first_row_id +
entry.file.row_count - 1))
+
+ sorted_ranges = Range.sort_and_merge_overlap(list_ranges, True, False)
+
+ start_range, end_range = self._divide_ranges(sorted_ranges,
sub_task_id, total_tasks)
+ if start_range is None or end_range is None:
+ return defaultdict(list)
+ start_first_row_id = start_range.from_
+ end_first_row_id = end_range.to
+
+ filtered_partitioned_files = {
+ k: [x for x in v if x.file.first_row_id >= start_first_row_id and
x.file.first_row_id <= end_first_row_id]
+ for k, v in partitioned_files.items()
+ }
+
+ filtered_partitioned_files = {k: v for k, v in
filtered_partitioned_files.items() if v}
+ return defaultdict(list, filtered_partitioned_files)
+
+ @staticmethod
+ def _divide_ranges(
+ sorted_ranges: List[Range], sub_task_id: int, total_tasks: int
+ ) -> Tuple[Optional[Range], Optional[Range]]:
+ if not sorted_ranges:
+ return None, None
+
+ num_ranges = len(sorted_ranges)
+
+ # If more tasks than ranges, some tasks get nothing
+ if sub_task_id >= num_ranges:
+ return None, None
+
+ # Calculate balanced distribution of ranges across tasks
+ base_ranges_per_task = num_ranges // total_tasks
+ remainder = num_ranges % total_tasks
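+ # e.g. 7 ranges across 3 tasks: task 0 gets indices 0..2, task 1 gets 3..4, task 2 gets 5..6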
+
+ # Each of the first 'remainder' tasks gets one extra range
+ if sub_task_id < remainder:
+ num_ranges_for_task = base_ranges_per_task + 1
+ start_idx = sub_task_id * (base_ranges_per_task + 1)
+ else:
+ num_ranges_for_task = base_ranges_per_task
+ start_idx = (
+ remainder * (base_ranges_per_task + 1) +
+ (sub_task_id - remainder) * base_ranges_per_task
+ )
+ end_idx = start_idx + num_ranges_for_task - 1
+ return sorted_ranges[start_idx], sorted_ranges[end_idx]
def _split_by_row_id(self, files: List[DataFileMeta]) ->
List[List[DataFileMeta]]:
"""
diff --git a/paimon-python/pypaimon/tests/blob_table_test.py
b/paimon-python/pypaimon/tests/blob_table_test.py
index 88d2626bb6..7670ec9447 100755
--- a/paimon-python/pypaimon/tests/blob_table_test.py
+++ b/paimon-python/pypaimon/tests/blob_table_test.py
@@ -2204,12 +2204,12 @@ class DataBlobWriterTest(unittest.TestCase):
result = table_read.to_arrow(table_scan.plan().splits())
# Verify the data
- self.assertEqual(result.num_rows, 54, "Should have 54 rows")
+ self.assertEqual(result.num_rows, 80, "Should have 80 rows")
self.assertEqual(result.num_columns, 4, "Should have 4 columns")
# Verify blob data integrity
blob_data = result.column('large_blob').to_pylist()
- self.assertEqual(len(blob_data), 54, "Should have 54 blob records")
+ self.assertEqual(len(blob_data), 80, "Should have 80 blob records")
# Verify each blob
for i, blob in enumerate(blob_data):
self.assertEqual(len(blob), len(large_blob_data), f"Blob {i + 1}
should be {large_blob_size:,} bytes")
@@ -2264,21 +2264,22 @@ class DataBlobWriterTest(unittest.TestCase):
actual_size = len(large_blob_data)
print(f"Created blob data: {actual_size:,} bytes ({actual_size / (1024
* 1024):.2f} MB)")
- write_builder = table.new_batch_write_builder()
- writer = write_builder.new_write()
# Write 30 records
- for record_id in range(30):
- test_data = pa.Table.from_pydict({
- 'id': [record_id], # Unique ID for each row
- 'metadata': [f'Large blob batch {record_id + 1}'],
- 'large_blob': [struct.pack('<I', record_id) + large_blob_data]
- }, schema=pa_schema)
- writer.write_arrow(test_data)
+ for i in range(3):
+ write_builder = table.new_batch_write_builder()
+ writer = write_builder.new_write()
+ for record_id in range(10):
+ test_data = pa.Table.from_pydict({
+ 'id': [record_id * i], # ID for this row (values may repeat across batches)
+ 'metadata': [f'Large blob batch {record_id + 1}'],
+ 'large_blob': [struct.pack('<I', record_id) +
large_blob_data]
+ }, schema=pa_schema)
+ writer.write_arrow(test_data)
- commit_messages = writer.prepare_commit()
- commit = write_builder.new_commit()
- commit.commit(commit_messages)
- writer.close()
+ commit_messages = writer.prepare_commit()
+ commit = write_builder.new_commit()
+ commit.commit(commit_messages)
+ writer.close()
# Read data back
read_builder = table.new_read_builder()
@@ -2304,7 +2305,7 @@ class DataBlobWriterTest(unittest.TestCase):
actual2 = table_read.to_arrow(splits2)
splits3 = read_builder.new_scan().with_shard(2, 3).plan().splits()
actual3 = table_read.to_arrow(splits3)
- actual = pa.concat_tables([actual1, actual2, actual3]).sort_by('id')
+ actual = pa.concat_tables([actual1, actual2, actual3])
# Verify the data
self.assertEqual(actual.num_rows, 30, "Should have 30 rows")
@@ -2337,22 +2338,25 @@ class DataBlobWriterTest(unittest.TestCase):
repetitions = large_blob_size // pattern_size
large_blob_data = blob_pattern * repetitions
- num_row = 20000
- write_builder = table.new_batch_write_builder()
- writer = write_builder.new_write()
- expected = pa.Table.from_pydict({
- 'id': [1] * num_row,
- 'batch_id': [11] * num_row,
- 'metadata': [f'Large blob batch {11}'] * num_row,
- 'large_blob': [i.to_bytes(2, byteorder='little') + large_blob_data
for i in range(num_row)]
- }, schema=pa_schema)
- writer.write_arrow(expected)
+ for i in range(3):
+ num_row = 6666
+ if i == 0:
+ num_row += 1
+ write_builder = table.new_batch_write_builder()
+ writer = write_builder.new_write()
+ expected = pa.Table.from_pydict({
+ 'id': [1] * num_row,
+ 'batch_id': [11] * num_row,
+ 'metadata': [f'Large blob batch {11}'] * num_row,
+ 'large_blob': [i.to_bytes(2, byteorder='little') +
large_blob_data for i in range(num_row)]
+ }, schema=pa_schema)
+ writer.write_arrow(expected)
- # Commit all data at once
- commit_messages = writer.prepare_commit()
- commit = write_builder.new_commit()
- commit.commit(commit_messages)
- writer.close()
+ # Commit all data at once
+ commit_messages = writer.prepare_commit()
+ commit = write_builder.new_commit()
+ commit.commit(commit_messages)
+ writer.close()
# Read data back
read_builder = table.new_read_builder()
@@ -2364,7 +2368,7 @@ class DataBlobWriterTest(unittest.TestCase):
self.assertEqual(6666, result.num_rows)
self.assertEqual(4, result.num_columns)
- self.assertEqual(expected.slice(13334, 6666), result)
+ self.assertEqual(expected, result)
splits = read_builder.new_scan().plan().splits()
expected = table_read.to_arrow(splits)
splits1 = read_builder.new_scan().with_shard(0, 3).plan().splits()
@@ -2490,13 +2494,14 @@ class DataBlobWriterTest(unittest.TestCase):
# Read data back using table API
read_builder = table.new_read_builder()
- table_scan = read_builder.new_scan().with_shard(1, 2)
+ table_scan = read_builder.new_scan().with_shard(0, 2)
table_read = read_builder.new_read()
splits = table_scan.plan().splits()
result = table_read.to_arrow(splits)
# Verify the data was read back correctly
- self.assertEqual(result.num_rows, 2, "Should have 2 rows")
+ # Only one file, so shard 0 covers all the records
+ self.assertEqual(result.num_rows, 5, "Should have 5 rows")
self.assertEqual(result.num_columns, 3, "Should have 3 columns")
def test_blob_write_read_large_data_volume_rolling_with_shard(self):
diff --git a/paimon-python/pypaimon/tests/shard_table_updator_test.py
b/paimon-python/pypaimon/tests/shard_table_updator_test.py
new file mode 100644
index 0000000000..641f545b47
--- /dev/null
+++ b/paimon-python/pypaimon/tests/shard_table_updator_test.py
@@ -0,0 +1,320 @@
+"""
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import os
+import tempfile
+import unittest
+
+import pyarrow as pa
+
+from pypaimon import CatalogFactory, Schema
+
+
+class ShardTableUpdatorTest(unittest.TestCase):
+ """Tests for ShardTableUpdator partial column updates in data-evolution
mode."""
+
+ @classmethod
+ def setUpClass(cls):
+ cls.tempdir = tempfile.mkdtemp()
+ cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
+ cls.catalog = CatalogFactory.create({
+ 'warehouse': cls.warehouse
+ })
+ cls.catalog.create_database('default', False)
+ cls.table_count = 0
+
+ def _create_unique_table_name(self, prefix='test'):
+ ShardTableUpdatorTest.table_count += 1
+ return f'default.{prefix}_{ShardTableUpdatorTest.table_count}'
+
+ def test_compute_column_d_equals_c_plus_b_minus_a(self):
+ """
+ Test: Create a table with columns a, b, c, d.
+ Write initial data for a, b, c.
+ Use ShardTableUpdator to compute d = c + b - a and fill in the d
column.
+ """
+ # Step 1: Create table with a, b, c, d columns (all int32)
+ table_schema = pa.schema([
+ ('a', pa.int32()),
+ ('b', pa.int32()),
+ ('c', pa.int32()),
+ ('d', pa.int32()),
+ ])
+ schema = Schema.from_pyarrow_schema(
+ table_schema,
+ options={'row-tracking.enabled': 'true', 'data-evolution.enabled':
'true'}
+ )
+ name = self._create_unique_table_name()
+ self.catalog.create_table(name, schema, False)
+ table = self.catalog.get_table(name)
+
+ # Step 2: Write initial data for a, b, c columns only
+ write_builder = table.new_batch_write_builder()
+ table_write = write_builder.new_write().with_write_type(['a', 'b',
'c'])
+ table_commit = write_builder.new_commit()
+
+ init_data = pa.Table.from_pydict({
+ 'a': [1, 2, 3, 4, 5],
+ 'b': [10, 20, 30, 40, 50],
+ 'c': [100, 200, 300, 400, 500],
+ }, schema=pa.schema([('a', pa.int32()), ('b', pa.int32()), ('c',
pa.int32())]))
+
+ table_write.write_arrow(init_data)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ # Step 3: Use ShardTableUpdator to compute d = c + b - a
+ table_update = write_builder.new_update()
+ table_update.with_read_projection(['a', 'b', 'c'])
+ table_update.with_update_type(['d'])
+
+ shard_updator = table_update.new_shard_updator(0, 1)
+
+ # Read data using arrow_reader
+ reader = shard_updator.arrow_reader()
+
+ for batch in iter(reader.read_next_batch, None):
+ # Compute d = c + b - a
+ a_values = batch.column('a').to_pylist()
+ b_values = batch.column('b').to_pylist()
+ c_values = batch.column('c').to_pylist()
+
+ d_values = [c + b - a for a, b, c in zip(a_values, b_values,
c_values)]
+
+ # Create batch with d column
+ new_batch = pa.RecordBatch.from_pydict({
+ 'd': d_values,
+ }, schema=pa.schema([('d', pa.int32())]))
+
+ # Write d column
+ shard_updator.update_by_arrow_batch(new_batch)
+
+ # Prepare and commit
+ commit_messages = shard_updator.prepare_commit()
+ commit = write_builder.new_commit()
+ commit.commit(commit_messages)
+ commit.close()
+
+ # Step 4: Verify the result
+ read_builder = table.new_read_builder()
+ table_scan = read_builder.new_scan()
+ table_read = read_builder.new_read()
+ actual = table_read.to_arrow(table_scan.plan().splits())
+
+ # Expected values:
+ # Row 0: d = 100 + 10 - 1 = 109
+ # Row 1: d = 200 + 20 - 2 = 218
+ # Row 2: d = 300 + 30 - 3 = 327
+ # Row 3: d = 400 + 40 - 4 = 436
+ # Row 4: d = 500 + 50 - 5 = 545
+ expected = pa.Table.from_pydict({
+ 'a': [1, 2, 3, 4, 5],
+ 'b': [10, 20, 30, 40, 50],
+ 'c': [100, 200, 300, 400, 500],
+ 'd': [109, 218, 327, 436, 545],
+ }, schema=table_schema)
+
+ print("\n=== Actual Data ===")
+ print(actual.to_pandas())
+ print("\n=== Expected Data ===")
+ print(expected.to_pandas())
+
+ self.assertEqual(actual, expected)
+ print("\n✅ Test passed! Column d = c + b - a computed correctly!")
+
+ def test_compute_column_d_equals_c_plus_b_minus_a2(self):
+ """
+ Test: Create a table with columns a, b, c, d.
+ Write initial data for a, b, c.
+ Use ShardTableUpdator to compute d = c + b - a and fill in the d
column.
+ """
+ # Step 1: Create table with a, b, c, d columns (all int32)
+ table_schema = pa.schema([
+ ('a', pa.int32()),
+ ('b', pa.int32()),
+ ('c', pa.int32()),
+ ('d', pa.int32()),
+ ])
+ schema = Schema.from_pyarrow_schema(
+ table_schema,
+ options={'row-tracking.enabled': 'true', 'data-evolution.enabled':
'true'}
+ )
+ name = self._create_unique_table_name()
+ self.catalog.create_table(name, schema, False)
+ table = self.catalog.get_table(name)
+
+ # Step 2: Write initial data for a, b, c columns only
+ for i in range(1000):
+ write_builder = table.new_batch_write_builder()
+ table_write = write_builder.new_write().with_write_type(['a', 'b',
'c'])
+ table_commit = write_builder.new_commit()
+
+ init_data = pa.Table.from_pydict({
+ 'a': [1, 2, 3, 4, 5],
+ 'b': [10, 20, 30, 40, 50],
+ 'c': [100, 200, 300, 400, 500],
+ }, schema=pa.schema([('a', pa.int32()), ('b', pa.int32()), ('c',
pa.int32())]))
+
+ table_write.write_arrow(init_data)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ # Step 3: Use ShardTableUpdator to compute d = c + b - a
+ table_update = write_builder.new_update()
+ table_update.with_read_projection(['a', 'b', 'c'])
+ table_update.with_update_type(['d'])
+
+ for i in range(10):
+ d_all_values = []
+ shard_updator = table_update.new_shard_updator(i, 10)
+
+ # Read data using arrow_reader
+ reader = shard_updator.arrow_reader()
+
+ for batch in iter(reader.read_next_batch, None):
+ # Compute d = c + b - a
+ a_values = batch.column('a').to_pylist()
+ b_values = batch.column('b').to_pylist()
+ c_values = batch.column('c').to_pylist()
+
+ d_values = [c + b - a for a, b, c in zip(a_values, b_values,
c_values)]
+ d_all_values.extend(d_values)
+
+ # Concatenate all computed values and update once for this shard
+ new_batch = pa.RecordBatch.from_pydict(
+ {'d': d_all_values},
+ schema=pa.schema([('d', pa.int32())]),
+ )
+ shard_updator.update_by_arrow_batch(new_batch)
+
+ # Prepare and commit
+ commit_messages = shard_updator.prepare_commit()
+ commit = write_builder.new_commit()
+ commit.commit(commit_messages)
+ commit.close()
+
+ # Step 4: Verify the result
+ read_builder = table.new_read_builder()
+ table_scan = read_builder.new_scan()
+ table_read = read_builder.new_read()
+ actual = table_read.to_arrow(table_scan.plan().splits())
+
+ # Expected values:
+ # Row 0: d = 100 + 10 - 1 = 109
+ # Row 1: d = 200 + 20 - 2 = 218
+ # Row 2: d = 300 + 30 - 3 = 327
+ # Row 3: d = 400 + 40 - 4 = 436
+ # Row 4: d = 500 + 50 - 5 = 545
+ expected = pa.Table.from_pydict({
+ 'a': [1, 2, 3, 4, 5] * 1000,
+ 'b': [10, 20, 30, 40, 50] * 1000,
+ 'c': [100, 200, 300, 400, 500] * 1000,
+ 'd': [109, 218, 327, 436, 545] * 1000,
+ }, schema=table_schema)
+
+ print("\n=== Actual Data ===")
+ print(actual.to_pandas())
+ print("\n=== Expected Data ===")
+ print(expected.to_pandas())
+
+ self.assertEqual(actual, expected)
+ print("\n✅ Test passed! Column d = c + b - a computed correctly!")
+
+ def test_compute_column_with_existing_column(self):
+ table_schema = pa.schema([
+ ('a', pa.int32()),
+ ('b', pa.int32()),
+ ('c', pa.int32()),
+ ])
+ schema = Schema.from_pyarrow_schema(
+ table_schema,
+ options={'row-tracking.enabled': 'true', 'data-evolution.enabled':
'true'}
+ )
+ name = self._create_unique_table_name()
+ self.catalog.create_table(name, schema, False)
+ table = self.catalog.get_table(name)
+
+ # Step 2: Write initial data for a, b, c columns only
+ for i in range(1000):
+ write_builder = table.new_batch_write_builder()
+ table_write = write_builder.new_write().with_write_type(['a', 'b',
'c'])
+ table_commit = write_builder.new_commit()
+
+ init_data = pa.Table.from_pydict({
+ 'a': [1, 2, 3, 4, 5],
+ 'b': [10, 20, 30, 40, 50],
+ 'c': [100, 200, 300, 400, 500],
+ }, schema=pa.schema([('a', pa.int32()), ('b', pa.int32()), ('c',
pa.int32())]))
+
+ table_write.write_arrow(init_data)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ # Step 3: Use ShardTableUpdator to compute c = b - a
+ table_update = write_builder.new_update()
+ table_update.with_read_projection(['a', 'b'])
+ table_update.with_update_type(['c'])
+
+ for i in range(10):
+ shard_updator = table_update.new_shard_updator(i, 10)
+
+ # Read data using arrow_reader
+ reader = shard_updator.arrow_reader()
+
+ for batch in iter(reader.read_next_batch, None):
+ a_values = batch.column('a').to_pylist()
+ b_values = batch.column('b').to_pylist()
+
+ c_values = [b - a for a, b in zip(a_values, b_values)]
+
+ new_batch = pa.RecordBatch.from_pydict({
+ 'c': c_values,
+ }, schema=pa.schema([('c', pa.int32())]))
+
+ shard_updator.update_by_arrow_batch(new_batch)
+
+ # Prepare and commit
+ commit_messages = shard_updator.prepare_commit()
+ commit = write_builder.new_commit()
+ commit.commit(commit_messages)
+ commit.close()
+
+ # Step 4: Verify the result
+ read_builder = table.new_read_builder()
+ table_scan = read_builder.new_scan()
+ table_read = read_builder.new_read()
+ actual = table_read.to_arrow(table_scan.plan().splits())
+
+ expected = pa.Table.from_pydict({
+ 'a': [1, 2, 3, 4, 5] * 1000,
+ 'b': [10, 20, 30, 40, 50] * 1000,
+ 'c': [9, 18, 27, 36, 45] * 1000,
+ }, schema=table_schema)
+
+ print("\n=== Actual Data ===")
+ print(actual.to_pandas())
+ print("\n=== Expected Data ===")
+ print(expected.to_pandas())
+
+ self.assertEqual(actual, expected)
+ print("\n✅ Test passed! Column d = c + b - a computed correctly!")
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/paimon-python/pypaimon/write/table_update.py
b/paimon-python/pypaimon/write/table_update.py
index baaf3f18e3..8e3f91bde4 100644
--- a/paimon-python/pypaimon/write/table_update.py
+++ b/paimon-python/pypaimon/write/table_update.py
@@ -15,12 +15,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
-from typing import List
+from collections import defaultdict
+from typing import List, Optional, Tuple
+import pyarrow
import pyarrow as pa
+from pypaimon.common.memory_size import MemorySize
+from pypaimon.globalindex import Range
+from pypaimon.manifest.schema.data_file_meta import DataFileMeta
+from pypaimon.read.split import DataSplit
from pypaimon.write.commit_message import CommitMessage
from pypaimon.write.table_update_by_row_id import TableUpdateByRowId
+from pypaimon.write.writer.data_writer import DataWriter
+from pypaimon.write.writer.append_only_data_writer import AppendOnlyDataWriter
class TableUpdate:
@@ -30,6 +38,7 @@ class TableUpdate:
self.table: FileStoreTable = table
self.commit_user = commit_user
self.update_cols = None
+ self.projection = None
def with_update_type(self, update_cols: List[str]):
for col in update_cols:
@@ -40,7 +49,151 @@ class TableUpdate:
self.update_cols = update_cols
return self
+ def with_read_projection(self, projection: List[str]):
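+ """Set the columns to read (projection) when scanning for this update."""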
+ self.projection = projection
+
+ def new_shard_updator(self, shard_num: int, total_shard_count: int):
+ """Create a shard updater for scan+rewrite style updates.
+
+ Args:
+ shard_num: Index of this shard/subtask.
+ total_shard_count: Total number of shards/subtasks.
+ """
+ return ShardTableUpdator(
+ self.table,
+ self.projection,
+ self.update_cols,
+ self.commit_user,
+ shard_num,
+ total_shard_count,
+ )
+
def update_by_arrow_with_row_id(self, table: pa.Table) ->
List[CommitMessage]:
update_by_row_id = TableUpdateByRowId(self.table, self.commit_user)
update_by_row_id.update_columns(table, self.update_cols)
return update_by_row_id.commit_messages
+
+
+class ShardTableUpdator:
+
+ def __init__(
+ self,
+ table,
+ projection: Optional[List[str]],
+ write_cols: List[str],
+ commit_user,
+ shard_num: int,
+ total_shard_count: int,
+ ):
+ from pypaimon.table.file_store_table import FileStoreTable
+ self.table: FileStoreTable = table
+ self.projection = projection
+ self.write_cols = write_cols
+ self.commit_user = commit_user
+ self.total_shard_count = total_shard_count
+ self.shard_num = shard_num
+
+ self.write_pos = 0
+ self.writer: Optional[SingleWriter] = None
+ self.dict = defaultdict(list)
+
+ scanner =
self.table.new_read_builder().new_scan().with_shard(shard_num,
total_shard_count)
+ self.splits = scanner.plan().splits()
+
+ self.row_ranges: List[Tuple[Tuple, Range]] = []
+ for split in self.splits:
+ if not isinstance(split, DataSplit):
+ raise ValueError(f"Split {split} is not DataSplit.")
+ files = split.files
+ ranges = self.compute_from_files(files)
+ for row_range in ranges:
+ self.row_ranges.append((tuple(split.partition.values),
row_range))
+
+ @staticmethod
+ def compute_from_files(files: List[DataFileMeta]) -> List[Range]:
+ ranges = []
+ for file in files:
+ ranges.append(Range(file.first_row_id, file.first_row_id +
file.row_count - 1))
+
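+ # merge=True, adjacent=False: overlapping ranges are merged, but adjacent files keep separate ranges, each later rewritten as its own file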
+ return Range.sort_and_merge_overlap(ranges, True, False)
+
+ def arrow_reader(self) -> pyarrow.ipc.RecordBatchReader:
+ read_builder = self.table.new_read_builder()
+ read_builder.with_projection(self.projection)
+ return read_builder.new_read().to_arrow_batch_reader(self.splits)
+
+ def prepare_commit(self) -> List[CommitMessage]:
+ commit_messages = []
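+ # one CommitMessage per partition, covering all files rewritten by this shard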
+ for (partition, files) in self.dict.items():
+ commit_messages.append(CommitMessage(partition, 0, files))
+ return commit_messages
+
+ def update_by_arrow_batch(self, data: pa.RecordBatch):
+ self._init_writer()
+
+ capacity = self.writer.capacity()
+ if capacity <= 0:
+ raise RuntimeError("Writer has no remaining capacity.")
+
+ # Split the batch across writers.
+ first, rest = (data, None) if capacity >= data.num_rows else
(data.slice(0, capacity), data.slice(capacity))
+
+ self.writer.write(first)
+ if self.writer.capacity() == 0:
+ self.dict[self.writer.partition()].append(self.writer.end())
+ self.writer = None
+
+ if rest is not None:
+ if self.writer is not None:
+ raise RuntimeError("Should not get here, rest and current
writer exist in the same time.")
+ self.update_by_arrow_batch(rest)
+
+ def _init_writer(self):
+ if self.writer is None:
+ if self.write_pos >= len(self.row_ranges):
+ raise RuntimeError(
+ "No more row ranges to write. "
+ "Ensure you write exactly the same number of rows as read
from this shard."
+ )
+ item = self.row_ranges[self.write_pos]
+ self.write_pos += 1
+ partition = item[0]
+ row_range = item[1]
+ writer = AppendOnlyDataWriter(self.table, partition, 0, 0,
self.table.options, self.write_cols)
+ writer.target_file_size =
MemorySize.of_mebi_bytes(999999999).get_bytes()
+ self.writer = SingleWriter(writer, partition, row_range.from_,
row_range.to - row_range.from_ + 1)
+
+
+class SingleWriter:
+
+ def __init__(self, writer: DataWriter, partition, first_row_id: int,
row_count: int):
+ self.writer: DataWriter = writer
+ self._partition = partition
+ self.first_row_id = first_row_id
+ self.row_count = row_count
+ self.written_records_count = 0
+
+ def capacity(self) -> int:
+ return self.row_count - self.written_records_count
+
+ def write(self, data: pa.RecordBatch):
+ if data.num_rows > self.capacity():
+ raise Exception("Data num size exceeds capacity.")
+ self.written_records_count += data.num_rows
+ self.writer.write(data)
+ return
+
+ def partition(self) -> Tuple:
+ return self._partition
+
+ def end(self) -> DataFileMeta:
+ if self.capacity() != 0:
+ raise Exception("There still capacity left in the writer.")
+ files = self.writer.prepare_commit()
+ if len(files) != 1:
+ raise Exception("Should have one file.")
+ file = files[0]
+ if file.row_count != self.row_count:
+ raise Exception("File row count mismatch.")
+ file = file.assign_first_row_id(self.first_row_id)
+ return file