This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 265f9084c0 [python] Enable update data to compute a New Column (scan + rewrite with shards)  (#7148)
265f9084c0 is described below

commit 265f9084c0155889ad10567cdedd4e6d8414a12c
Author: YeJunHao <[email protected]>
AuthorDate: Fri Jan 30 11:22:33 2026 +0800

    [python] Enable update data to compute a New Column (scan + rewrite with shards)  (#7148)
---
 docs/content/pypaimon/data-evolution.md            | 118 ++++++++
 paimon-python/pypaimon/globalindex/range.py        |   5 +-
 .../read/scanner/data_evolution_split_generator.py |  79 +++--
 paimon-python/pypaimon/tests/blob_table_test.py    |  73 ++---
 .../pypaimon/tests/shard_table_updator_test.py     | 320 +++++++++++++++++++++
 paimon-python/pypaimon/write/table_update.py       | 155 +++++++++-
 6 files changed, 693 insertions(+), 57 deletions(-)

diff --git a/docs/content/pypaimon/data-evolution.md b/docs/content/pypaimon/data-evolution.md
index bd0bde03a4..2268090f73 100644
--- a/docs/content/pypaimon/data-evolution.md
+++ b/docs/content/pypaimon/data-evolution.md
@@ -29,6 +29,13 @@ under the License.
 
 PyPaimon supports Data Evolution mode. See [Data Evolution]({{< ref "append-table/data-evolution" >}}).
 
+## Prerequisites
+
+To use partial updates / data evolution, enable both options when creating the table (as in the sketch below):
+
+- **`row-tracking.enabled`**: `true`
+- **`data-evolution.enabled`**: `true`
+
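+For example, a minimal sketch of creating such a table (the warehouse path, database, and schema are illustrative):
+
+```python
+import pyarrow as pa
+from pypaimon import CatalogFactory, Schema
+
+catalog = CatalogFactory.create({'warehouse': '/tmp/warehouse'})
+catalog.create_database('default', False)
+
+schema = Schema.from_pyarrow_schema(
+    pa.schema([('f0', pa.int8()), ('f1', pa.int16())]),
+    options={'row-tracking.enabled': 'true', 'data-evolution.enabled': 'true'},
+)
+catalog.create_table('default.simple_table', schema, False)
+```
+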
 ## Update Columns By Row ID
 
 You can use `TableUpdate.update_by_arrow_with_row_id` to update columns of data evolution tables.
@@ -36,7 +43,19 @@ You can create `TableUpdate.update_by_arrow_with_row_id` to update columns to da
 The input data should include the `_ROW_ID` column. The update operation automatically sorts the input and matches each `_ROW_ID` to its corresponding `first_row_id`, then groups rows with the same `first_row_id` and writes each group to a separate file.
 
+### Requirements for `_ROW_ID` updates
+
+- **All rows required**: the input table must contain **exactly the full table row count** (one row per existing row).
+- **Row id coverage**: after sorting by `_ROW_ID`, the `_ROW_ID` values must be exactly **0..N-1** (no duplicates, no gaps).
+- **Update columns only**: include `_ROW_ID` plus the columns you want to update (a partial schema is OK).
+
 ```python
+import pyarrow as pa
+from pypaimon import CatalogFactory, Schema
+
+catalog = CatalogFactory.create({'warehouse': '/tmp/warehouse'})
+catalog.create_database('default', False)
+
 simple_pa_schema = pa.schema([
   ('f0', pa.int8()),
   ('f1', pa.int16()),
@@ -78,3 +97,102 @@ table_commit.close()
 #   'f0': [5, 6],
 #   'f1': [-1001, 1002]
 ```
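+
+A compact end-to-end sketch of the same flow (an illustration rather than library-provided wording; it assumes a
+`write_builder` from `table.new_batch_write_builder()` and a table that currently holds exactly two rows, so
+`_ROW_ID` must cover 0..1, and it assumes the `_ROW_ID` type is int64):
+
+```python
+update = write_builder.new_update().with_update_type(['f1'])
+
+# The update table must cover every existing row. It may arrive unsorted;
+# the operation sorts by _ROW_ID internally.
+row_id_update = pa.Table.from_pydict(
+    {'_ROW_ID': [1, 0], 'f1': [1002, -1001]},
+    schema=pa.schema([('_ROW_ID', pa.int64()), ('f1', pa.int16())]),
+)
+
+commit_messages = update.update_by_arrow_with_row_id(row_id_update)
+table_commit = write_builder.new_commit()
+table_commit.commit(commit_messages)
+table_commit.close()
+```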
+
+## Compute a New Column (scan + rewrite with shards)
+
+If you want to **compute a derived column** (or **update an existing column based on other columns**) without providing
+`_ROW_ID`, you can use the shard scan + rewrite workflow:
+
+- Read only the columns you need (projection)
+- Compute the new values in the same row order
+- Write only the updated columns back
+- Commit per shard
+
+This is useful for backfilling a newly added column, or recomputing a column from other columns.
+
+### Example: compute `d = c + b - a`
+
+```python
+import pyarrow as pa
+from pypaimon import CatalogFactory, Schema
+
+catalog = CatalogFactory.create({'warehouse': '/tmp/warehouse'})
+catalog.create_database('default', False)
+
+table_schema = pa.schema([
+    ('a', pa.int32()),
+    ('b', pa.int32()),
+    ('c', pa.int32()),
+    ('d', pa.int32()),
+])
+
+schema = Schema.from_pyarrow_schema(
+    table_schema,
+    options={'row-tracking.enabled': 'true', 'data-evolution.enabled': 'true'},
+)
+catalog.create_table('default.t', schema, False)
+table = catalog.get_table('default.t')
+
+# write initial data (a, b, c only)
+write_builder = table.new_batch_write_builder()
+write = write_builder.new_write().with_write_type(['a', 'b', 'c'])
+commit = write_builder.new_commit()
+write.write_arrow(pa.Table.from_pydict({'a': [1, 2], 'b': [10, 20], 'c': [100, 200]}))
+commit.commit(write.prepare_commit())
+write.close()
+commit.close()
+
+# shard update: read (a, b, c), write only (d)
+update = write_builder.new_update()
+update.with_read_projection(['a', 'b', 'c'])
+update.with_update_type(['d'])
+
+shard_idx = 0
+num_shards = 1
+upd = update.new_shard_updator(shard_idx, num_shards)
+reader = upd.arrow_reader()
+
+for batch in iter(reader.read_next_batch, None):
+    a = batch.column('a').to_pylist()
+    b = batch.column('b').to_pylist()
+    c = batch.column('c').to_pylist()
+    d = [ci + bi - ai for ai, bi, ci in zip(a, b, c)]
+
+    upd.update_by_arrow_batch(
+        pa.RecordBatch.from_pydict({'d': d}, schema=pa.schema([('d', pa.int32())]))
+    )
+
+commit_messages = upd.prepare_commit()
+commit = write_builder.new_commit()
+commit.commit(commit_messages)
+commit.close()
+```
+
+### Example: update an existing column `c = b - a`
+
+```python
+update = write_builder.new_update()
+update.with_read_projection(['a', 'b'])
+update.with_update_type(['c'])
+
+upd = update.new_shard_updator(0, 1)
+reader = upd.arrow_reader()
+for batch in iter(reader.read_next_batch, None):
+    a = batch.column('a').to_pylist()
+    b = batch.column('b').to_pylist()
+    c = [bi - ai for ai, bi in zip(a, b)]
+    upd.update_by_arrow_batch(
+        pa.RecordBatch.from_pydict({'c': c}, schema=pa.schema([('c', pa.int32())]))
+    )
+
+commit_messages = upd.prepare_commit()
+commit = write_builder.new_commit()
+commit.commit(commit_messages)
+commit.close()
+```
+
+### Notes
+
+- **Row order matters**: the rows you write for a shard must match the rows you read from that shard: the **same total row count**, in the same order (batch boundaries may differ).
+- **Parallelism**: run multiple shards by calling `new_shard_updator(shard_idx, num_shards)` for each shard, as in the sketch below.
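+
+A minimal multi-shard sketch (it reuses `update` and `write_builder` from the `d = c + b - a` example above; the
+shard count is illustrative, and shards can be distributed across processes or workers):
+
+```python
+num_shards = 4
+for shard_idx in range(num_shards):
+    upd = update.new_shard_updator(shard_idx, num_shards)
+    reader = upd.arrow_reader()
+    for batch in iter(reader.read_next_batch, None):
+        a = batch.column('a').to_pylist()
+        b = batch.column('b').to_pylist()
+        c = batch.column('c').to_pylist()
+        d = [ci + bi - ai for ai, bi, ci in zip(a, b, c)]
+        upd.update_by_arrow_batch(
+            pa.RecordBatch.from_pydict({'d': d}, schema=pa.schema([('d', pa.int32())]))
+        )
+    # each shard commits its own rewritten files
+    commit = write_builder.new_commit()
+    commit.commit(upd.prepare_commit())
+    commit.close()
+```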
diff --git a/paimon-python/pypaimon/globalindex/range.py b/paimon-python/pypaimon/globalindex/range.py
index f27a637ae7..19b9b40e94 100644
--- a/paimon-python/pypaimon/globalindex/range.py
+++ b/paimon-python/pypaimon/globalindex/range.py
@@ -154,7 +154,7 @@ class Range:
         return result
 
     @staticmethod
-    def sort_and_merge_overlap(ranges: List['Range'], merge: bool = True) -> List['Range']:
+    def sort_and_merge_overlap(ranges: List['Range'], merge: bool = True, adjacent: bool = True) -> List['Range']:
         """
         Sort ranges and optionally merge overlapping ones.
         """
@@ -166,10 +166,11 @@ class Range:
         if not merge:
             return sorted_ranges
 
+        adjacent_value = 1 if adjacent else 0
         result = [sorted_ranges[0]]
         for r in sorted_ranges[1:]:
             last = result[-1]
-            if r.from_ <= last.to + 1:
+            if r.from_ <= last.to + adjacent_value:
                 # Merge with last range
                 result[-1] = Range(last.from_, max(last.to, r.to))
             else:
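
To illustrate the new `adjacent` flag (a quick sketch against the patched API, not part of the patch itself): with the
default `adjacent=True`, ranges that merely touch are merged as before, while the shard filtering added in this commit
passes `adjacent=False` so that only truly overlapping ranges collapse.

```python
from pypaimon.globalindex import Range

ranges = [Range(0, 4), Range(5, 9), Range(7, 12)]

# Default behaviour: (0, 4) and (5, 9) touch, so everything collapses into one range.
merged = Range.sort_and_merge_overlap(ranges, merge=True, adjacent=True)
# -> [Range(0, 12)]

# adjacent=False keeps touching-but-not-overlapping ranges separate.
overlap_only = Range.sort_and_merge_overlap(ranges, merge=True, adjacent=False)
# -> [Range(0, 4), Range(5, 12)]
```
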
diff --git a/paimon-python/pypaimon/read/scanner/data_evolution_split_generator.py b/paimon-python/pypaimon/read/scanner/data_evolution_split_generator.py
index ffb0ffafdf..f52efb7ae1 100644
--- a/paimon-python/pypaimon/read/scanner/data_evolution_split_generator.py
+++ b/paimon-python/pypaimon/read/scanner/data_evolution_split_generator.py
@@ -77,8 +77,9 @@ class DataEvolutionSplitGenerator(AbstractSplitGenerator):
                     self.end_pos_of_this_subtask
                 )
         elif self.idx_of_this_subtask is not None:
-            # shard data range: [plan_start_pos, plan_end_pos)
-            partitioned_files, plan_start_pos, plan_end_pos = self._filter_by_shard(partitioned_files)
+            partitioned_files = self._filter_by_shard(
+                partitioned_files, self.idx_of_this_subtask, self.number_of_para_subtasks
+            )
 
         def weight_func(file_list: List[DataFileMeta]) -> int:
             return max(sum(f.file_size for f in file_list), self.open_file_cost)
@@ -108,9 +109,8 @@ class DataEvolutionSplitGenerator(AbstractSplitGenerator):
                 flatten_packed_files, packed_files, sorted_entries_list
             )
 
-        if self.start_pos_of_this_subtask is not None or self.idx_of_this_subtask is not None:
+        if self.start_pos_of_this_subtask is not None:
             splits = self._wrap_to_sliced_splits(splits, plan_start_pos, plan_end_pos)
-
         # Wrap splits with IndexedSplit if row_ranges is provided
         if self.row_ranges:
             splits = self._wrap_to_indexed_splits(splits)
@@ -242,22 +242,61 @@ class DataEvolutionSplitGenerator(AbstractSplitGenerator):
 
         return filtered_partitioned_files, plan_start_pos, plan_end_pos
 
-    def _filter_by_shard(self, partitioned_files: defaultdict) -> tuple:
-        """
-        Filter file entries by shard for data evolution tables.
-        """
-        # Calculate total rows (excluding blob files)
-        total_row = sum(
-            entry.file.row_count
-            for file_entries in partitioned_files.values()
-            for entry in file_entries
-            if not self._is_blob_file(entry.file.file_name)
-        )
-
-        # Calculate shard range using shared helper
-        start_pos, end_pos = self._compute_shard_range(total_row)
-
-        return self._filter_by_row_range(partitioned_files, start_pos, end_pos)
+    def _filter_by_shard(self, partitioned_files: defaultdict, sub_task_id: int, total_tasks: int) -> defaultdict:
+        list_ranges = []
+        for file_entries in partitioned_files.values():
+            for entry in file_entries:
+                first_row_id = entry.file.first_row_id
+                if first_row_id is None:
+                    raise ValueError("Found None first row id in files")
+                # Range is inclusive [from_, to], so use row_count - 1
+                list_ranges.append(Range(first_row_id, first_row_id + entry.file.row_count - 1))
+
+        sorted_ranges = Range.sort_and_merge_overlap(list_ranges, True, False)
+
+        start_range, end_range = self._divide_ranges(sorted_ranges, sub_task_id, total_tasks)
+        if start_range is None or end_range is None:
+            return defaultdict(list)
+        start_first_row_id = start_range.from_
+        end_first_row_id = end_range.to
+
+        filtered_partitioned_files = {
+            k: [x for x in v if x.file.first_row_id >= start_first_row_id and x.file.first_row_id <= end_first_row_id]
+            for k, v in partitioned_files.items()
+        }
+
+        filtered_partitioned_files = {k: v for k, v in filtered_partitioned_files.items() if v}
+        return defaultdict(list, filtered_partitioned_files)
+
+    @staticmethod
+    def _divide_ranges(
+        sorted_ranges: List[Range], sub_task_id: int, total_tasks: int
+    ) -> Tuple[Optional[Range], Optional[Range]]:
+        if not sorted_ranges:
+            return None, None
+
+        num_ranges = len(sorted_ranges)
+
+        # If more tasks than ranges, some tasks get nothing
+        if sub_task_id >= num_ranges:
+            return None, None
+
+        # Calculate balanced distribution of ranges across tasks
+        base_ranges_per_task = num_ranges // total_tasks
+        remainder = num_ranges % total_tasks
+
+        # Each of the first 'remainder' tasks gets one extra range
+        if sub_task_id < remainder:
+            num_ranges_for_task = base_ranges_per_task + 1
+            start_idx = sub_task_id * (base_ranges_per_task + 1)
+        else:
+            num_ranges_for_task = base_ranges_per_task
+            start_idx = (
+                remainder * (base_ranges_per_task + 1) +
+                (sub_task_id - remainder) * base_ranges_per_task
+            )
+        end_idx = start_idx + num_ranges_for_task - 1
+        return sorted_ranges[start_idx], sorted_ranges[end_idx]
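
As a quick walk-through of the range-distribution arithmetic in `_divide_ranges` above (illustrative numbers, not part
of the patch): with 7 merged ranges and 3 subtasks, the single remainder range goes to the first task, so the tasks
receive range indexes [0..2], [3..4], and [5..6].

```python
# Standalone sketch of the same arithmetic: 7 ranges spread over 3 tasks.
num_ranges, total_tasks = 7, 3
base, remainder = divmod(num_ranges, total_tasks)  # base=2, remainder=1
for sub_task_id in range(total_tasks):
    if sub_task_id < remainder:
        start_idx = sub_task_id * (base + 1)
        count = base + 1
    else:
        start_idx = remainder * (base + 1) + (sub_task_id - remainder) * base
        count = base
    print(sub_task_id, list(range(start_idx, start_idx + count)))
# 0 [0, 1, 2]
# 1 [3, 4]
# 2 [5, 6]
```
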
 
     def _split_by_row_id(self, files: List[DataFileMeta]) -> List[List[DataFileMeta]]:
         """
diff --git a/paimon-python/pypaimon/tests/blob_table_test.py b/paimon-python/pypaimon/tests/blob_table_test.py
index 88d2626bb6..7670ec9447 100755
--- a/paimon-python/pypaimon/tests/blob_table_test.py
+++ b/paimon-python/pypaimon/tests/blob_table_test.py
@@ -2204,12 +2204,12 @@ class DataBlobWriterTest(unittest.TestCase):
         result = table_read.to_arrow(table_scan.plan().splits())
 
         # Verify the data
-        self.assertEqual(result.num_rows, 54, "Should have 54 rows")
+        self.assertEqual(result.num_rows, 80, "Should have 80 rows")
         self.assertEqual(result.num_columns, 4, "Should have 4 columns")
 
         # Verify blob data integrity
         blob_data = result.column('large_blob').to_pylist()
-        self.assertEqual(len(blob_data), 54, "Should have 54 blob records")
+        self.assertEqual(len(blob_data), 80, "Should have 80 blob records")
         # Verify each blob
         for i, blob in enumerate(blob_data):
             self.assertEqual(len(blob), len(large_blob_data), f"Blob {i + 1} should be {large_blob_size:,} bytes")
@@ -2264,21 +2264,22 @@ class DataBlobWriterTest(unittest.TestCase):
         actual_size = len(large_blob_data)
         print(f"Created blob data: {actual_size:,} bytes ({actual_size / (1024 * 1024):.2f} MB)")
 
-        write_builder = table.new_batch_write_builder()
-        writer = write_builder.new_write()
         # Write 30 records
-        for record_id in range(30):
-            test_data = pa.Table.from_pydict({
-                'id': [record_id],  # Unique ID for each row
-                'metadata': [f'Large blob batch {record_id + 1}'],
-                'large_blob': [struct.pack('<I', record_id) + large_blob_data]
-            }, schema=pa_schema)
-            writer.write_arrow(test_data)
+        for i in range(3):
+            write_builder = table.new_batch_write_builder()
+            writer = write_builder.new_write()
+            for record_id in range(10):
+                test_data = pa.Table.from_pydict({
+                    'id': [record_id * i],  # id for this row; repeats across batches
+                    'metadata': [f'Large blob batch {record_id + 1}'],
+                    'large_blob': [struct.pack('<I', record_id) + large_blob_data]
+                }, schema=pa_schema)
+                writer.write_arrow(test_data)
 
-        commit_messages = writer.prepare_commit()
-        commit = write_builder.new_commit()
-        commit.commit(commit_messages)
-        writer.close()
+            commit_messages = writer.prepare_commit()
+            commit = write_builder.new_commit()
+            commit.commit(commit_messages)
+            writer.close()
 
         # Read data back
         read_builder = table.new_read_builder()
@@ -2304,7 +2305,7 @@ class DataBlobWriterTest(unittest.TestCase):
         actual2 = table_read.to_arrow(splits2)
         splits3 = read_builder.new_scan().with_shard(2, 3).plan().splits()
         actual3 = table_read.to_arrow(splits3)
-        actual = pa.concat_tables([actual1, actual2, actual3]).sort_by('id')
+        actual = pa.concat_tables([actual1, actual2, actual3])
 
         # Verify the data
         self.assertEqual(actual.num_rows, 30, "Should have 30 rows")
@@ -2337,22 +2338,25 @@ class DataBlobWriterTest(unittest.TestCase):
         repetitions = large_blob_size // pattern_size
         large_blob_data = blob_pattern * repetitions
 
-        num_row = 20000
-        write_builder = table.new_batch_write_builder()
-        writer = write_builder.new_write()
-        expected = pa.Table.from_pydict({
-            'id': [1] * num_row,
-            'batch_id': [11] * num_row,
-            'metadata': [f'Large blob batch {11}'] * num_row,
-            'large_blob': [i.to_bytes(2, byteorder='little') + large_blob_data for i in range(num_row)]
-        }, schema=pa_schema)
-        writer.write_arrow(expected)
+        for i in range(3):
+            num_row = 6666
+            if i == 0:
+                num_row += 1
+            write_builder = table.new_batch_write_builder()
+            writer = write_builder.new_write()
+            expected = pa.Table.from_pydict({
+                'id': [1] * num_row,
+                'batch_id': [11] * num_row,
+                'metadata': [f'Large blob batch {11}'] * num_row,
+                'large_blob': [i.to_bytes(2, byteorder='little') + large_blob_data for i in range(num_row)]
+            }, schema=pa_schema)
+            writer.write_arrow(expected)
 
-        # Commit all data at once
-        commit_messages = writer.prepare_commit()
-        commit = write_builder.new_commit()
-        commit.commit(commit_messages)
-        writer.close()
+            # Commit all data at once
+            commit_messages = writer.prepare_commit()
+            commit = write_builder.new_commit()
+            commit.commit(commit_messages)
+            writer.close()
 
         # Read data back
         read_builder = table.new_read_builder()
@@ -2364,7 +2368,7 @@ class DataBlobWriterTest(unittest.TestCase):
         self.assertEqual(6666, result.num_rows)
         self.assertEqual(4, result.num_columns)
 
-        self.assertEqual(expected.slice(13334, 6666), result)
+        self.assertEqual(expected, result)
         splits = read_builder.new_scan().plan().splits()
         expected = table_read.to_arrow(splits)
         splits1 = read_builder.new_scan().with_shard(0, 3).plan().splits()
@@ -2490,13 +2494,14 @@ class DataBlobWriterTest(unittest.TestCase):
 
         # Read data back using table API
         read_builder = table.new_read_builder()
-        table_scan = read_builder.new_scan().with_shard(1, 2)
+        table_scan = read_builder.new_scan().with_shard(0, 2)
         table_read = read_builder.new_read()
         splits = table_scan.plan().splits()
         result = table_read.to_arrow(splits)
 
         # Verify the data was read back correctly
-        self.assertEqual(result.num_rows, 2, "Should have 2 rows")
+        # Only one file, so shard 0 covers all records
+        self.assertEqual(result.num_rows, 5, "Should have 5 rows")
         self.assertEqual(result.num_columns, 3, "Should have 3 columns")
 
     def test_blob_write_read_large_data_volume_rolling_with_shard(self):
diff --git a/paimon-python/pypaimon/tests/shard_table_updator_test.py b/paimon-python/pypaimon/tests/shard_table_updator_test.py
new file mode 100644
index 0000000000..641f545b47
--- /dev/null
+++ b/paimon-python/pypaimon/tests/shard_table_updator_test.py
@@ -0,0 +1,320 @@
+"""
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import os
+import tempfile
+import unittest
+
+import pyarrow as pa
+
+from pypaimon import CatalogFactory, Schema
+
+
+class ShardTableUpdatorTest(unittest.TestCase):
+    """Tests for ShardTableUpdator partial column updates in data-evolution mode."""
+
+    @classmethod
+    def setUpClass(cls):
+        cls.tempdir = tempfile.mkdtemp()
+        cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
+        cls.catalog = CatalogFactory.create({
+            'warehouse': cls.warehouse
+        })
+        cls.catalog.create_database('default', False)
+        cls.table_count = 0
+
+    def _create_unique_table_name(self, prefix='test'):
+        ShardTableUpdatorTest.table_count += 1
+        return f'default.{prefix}_{ShardTableUpdatorTest.table_count}'
+
+    def test_compute_column_d_equals_c_plus_b_minus_a(self):
+        """
+        Test: Create a table with columns a, b, c, d.
+        Write initial data for a, b, c.
+        Use ShardTableUpdator to compute d = c + b - a and fill in the d column.
+        """
+        # Step 1: Create table with a, b, c, d columns (all int32)
+        table_schema = pa.schema([
+            ('a', pa.int32()),
+            ('b', pa.int32()),
+            ('c', pa.int32()),
+            ('d', pa.int32()),
+        ])
+        schema = Schema.from_pyarrow_schema(
+            table_schema,
+            options={'row-tracking.enabled': 'true', 'data-evolution.enabled': 'true'}
+        )
+        name = self._create_unique_table_name()
+        self.catalog.create_table(name, schema, False)
+        table = self.catalog.get_table(name)
+
+        # Step 2: Write initial data for a, b, c columns only
+        write_builder = table.new_batch_write_builder()
+        table_write = write_builder.new_write().with_write_type(['a', 'b', 'c'])
+        table_commit = write_builder.new_commit()
+
+        init_data = pa.Table.from_pydict({
+            'a': [1, 2, 3, 4, 5],
+            'b': [10, 20, 30, 40, 50],
+            'c': [100, 200, 300, 400, 500],
+        }, schema=pa.schema([('a', pa.int32()), ('b', pa.int32()), ('c', pa.int32())]))
+
+        table_write.write_arrow(init_data)
+        table_commit.commit(table_write.prepare_commit())
+        table_write.close()
+        table_commit.close()
+
+        # Step 3: Use ShardTableUpdator to compute d = c + b - a
+        table_update = write_builder.new_update()
+        table_update.with_read_projection(['a', 'b', 'c'])
+        table_update.with_update_type(['d'])
+        
+        shard_updator = table_update.new_shard_updator(0, 1)
+
+        # Read data using arrow_reader
+        reader = shard_updator.arrow_reader()
+
+        for batch in iter(reader.read_next_batch, None):
+            # Compute d = c + b - a
+            a_values = batch.column('a').to_pylist()
+            b_values = batch.column('b').to_pylist()
+            c_values = batch.column('c').to_pylist()
+            
+            d_values = [c + b - a for a, b, c in zip(a_values, b_values, c_values)]
+            
+            # Create batch with d column
+            new_batch = pa.RecordBatch.from_pydict({
+                'd': d_values,
+            }, schema=pa.schema([('d', pa.int32())]))
+            
+            # Write d column
+            shard_updator.update_by_arrow_batch(new_batch)
+
+        # Prepare and commit
+        commit_messages = shard_updator.prepare_commit()
+        commit = write_builder.new_commit()
+        commit.commit(commit_messages)
+        commit.close()
+
+        # Step 4: Verify the result
+        read_builder = table.new_read_builder()
+        table_scan = read_builder.new_scan()
+        table_read = read_builder.new_read()
+        actual = table_read.to_arrow(table_scan.plan().splits())
+
+        # Expected values:
+        # Row 0: d = 100 + 10 - 1 = 109
+        # Row 1: d = 200 + 20 - 2 = 218
+        # Row 2: d = 300 + 30 - 3 = 327
+        # Row 3: d = 400 + 40 - 4 = 436
+        # Row 4: d = 500 + 50 - 5 = 545
+        expected = pa.Table.from_pydict({
+            'a': [1, 2, 3, 4, 5],
+            'b': [10, 20, 30, 40, 50],
+            'c': [100, 200, 300, 400, 500],
+            'd': [109, 218, 327, 436, 545],
+        }, schema=table_schema)
+
+        print("\n=== Actual Data ===")
+        print(actual.to_pandas())
+        print("\n=== Expected Data ===")
+        print(expected.to_pandas())
+
+        self.assertEqual(actual, expected)
+        print("\n✅ Test passed! Column d = c + b - a computed correctly!")
+
+    def test_compute_column_d_equals_c_plus_b_minus_a2(self):
+        """
+        Test: Create a table with columns a, b, c, d.
+        Write initial data for a, b, c across 1000 separate commits.
+        Use ShardTableUpdator with 10 shards to compute d = c + b - a and fill in the d column.
+        """
+        # Step 1: Create table with a, b, c, d columns (all int32)
+        table_schema = pa.schema([
+            ('a', pa.int32()),
+            ('b', pa.int32()),
+            ('c', pa.int32()),
+            ('d', pa.int32()),
+        ])
+        schema = Schema.from_pyarrow_schema(
+            table_schema,
+            options={'row-tracking.enabled': 'true', 'data-evolution.enabled': 'true'}
+        )
+        name = self._create_unique_table_name()
+        self.catalog.create_table(name, schema, False)
+        table = self.catalog.get_table(name)
+
+        # Step 2: Write initial data for a, b, c columns only
+        for i in range(1000):
+            write_builder = table.new_batch_write_builder()
+            table_write = write_builder.new_write().with_write_type(['a', 'b', 'c'])
+            table_commit = write_builder.new_commit()
+
+            init_data = pa.Table.from_pydict({
+                'a': [1, 2, 3, 4, 5],
+                'b': [10, 20, 30, 40, 50],
+                'c': [100, 200, 300, 400, 500],
+            }, schema=pa.schema([('a', pa.int32()), ('b', pa.int32()), ('c', pa.int32())]))
+
+            table_write.write_arrow(init_data)
+            table_commit.commit(table_write.prepare_commit())
+            table_write.close()
+            table_commit.close()
+
+        # Step 3: Use ShardTableUpdator to compute d = c + b - a
+        table_update = write_builder.new_update()
+        table_update.with_read_projection(['a', 'b', 'c'])
+        table_update.with_update_type(['d'])
+
+        for i in range(10):
+            d_all_values = []
+            shard_updator = table_update.new_shard_updator(i, 10)
+
+            # Read data using arrow_reader
+            reader = shard_updator.arrow_reader()
+
+            for batch in iter(reader.read_next_batch, None):
+                # Compute d = c + b - a
+                a_values = batch.column('a').to_pylist()
+                b_values = batch.column('b').to_pylist()
+                c_values = batch.column('c').to_pylist()
+
+                d_values = [c + b - a for a, b, c in zip(a_values, b_values, c_values)]
+                d_all_values.extend(d_values)
+
+            # Concatenate all computed values and update once for this shard
+            new_batch = pa.RecordBatch.from_pydict(
+                {'d': d_all_values},
+                schema=pa.schema([('d', pa.int32())]),
+            )
+            shard_updator.update_by_arrow_batch(new_batch)
+
+            # Prepare and commit
+            commit_messages = shard_updator.prepare_commit()
+            commit = write_builder.new_commit()
+            commit.commit(commit_messages)
+            commit.close()
+
+        # Step 4: Verify the result
+        read_builder = table.new_read_builder()
+        table_scan = read_builder.new_scan()
+        table_read = read_builder.new_read()
+        actual = table_read.to_arrow(table_scan.plan().splits())
+
+        # Expected values:
+        # Row 0: d = 100 + 10 - 1 = 109
+        # Row 1: d = 200 + 20 - 2 = 218
+        # Row 2: d = 300 + 30 - 3 = 327
+        # Row 3: d = 400 + 40 - 4 = 436
+        # Row 4: d = 500 + 50 - 5 = 545
+        expected = pa.Table.from_pydict({
+            'a': [1, 2, 3, 4, 5] * 1000,
+            'b': [10, 20, 30, 40, 50] * 1000,
+            'c': [100, 200, 300, 400, 500] * 1000,
+            'd': [109, 218, 327, 436, 545] * 1000,
+        }, schema=table_schema)
+
+        print("\n=== Actual Data ===")
+        print(actual.to_pandas())
+        print("\n=== Expected Data ===")
+        print(expected.to_pandas())
+
+        self.assertEqual(actual, expected)
+        print("\n✅ Test passed! Column d = c + b - a computed correctly!")
+
+    def test_compute_column_with_existing_column(self):
+        table_schema = pa.schema([
+            ('a', pa.int32()),
+            ('b', pa.int32()),
+            ('c', pa.int32()),
+        ])
+        schema = Schema.from_pyarrow_schema(
+            table_schema,
+            options={'row-tracking.enabled': 'true', 'data-evolution.enabled': 'true'}
+        )
+        name = self._create_unique_table_name()
+        self.catalog.create_table(name, schema, False)
+        table = self.catalog.get_table(name)
+
+        # Step 2: Write initial data for a, b, c columns only
+        for i in range(1000):
+            write_builder = table.new_batch_write_builder()
+            table_write = write_builder.new_write().with_write_type(['a', 'b', 'c'])
+            table_commit = write_builder.new_commit()
+
+            init_data = pa.Table.from_pydict({
+                'a': [1, 2, 3, 4, 5],
+                'b': [10, 20, 30, 40, 50],
+                'c': [100, 200, 300, 400, 500],
+            }, schema=pa.schema([('a', pa.int32()), ('b', pa.int32()), ('c', pa.int32())]))
+
+            table_write.write_arrow(init_data)
+            table_commit.commit(table_write.prepare_commit())
+            table_write.close()
+            table_commit.close()
+
+        # Step 3: Use ShardTableUpdator to compute c = b - a
+        table_update = write_builder.new_update()
+        table_update.with_read_projection(['a', 'b'])
+        table_update.with_update_type(['c'])
+
+        for i in range(10):
+            shard_updator = table_update.new_shard_updator(i, 10)
+
+            # Read data using arrow_reader
+            reader = shard_updator.arrow_reader()
+
+            for batch in iter(reader.read_next_batch, None):
+                a_values = batch.column('a').to_pylist()
+                b_values = batch.column('b').to_pylist()
+
+                c_values = [b - a for a, b in zip(a_values, b_values)]
+
+                new_batch = pa.RecordBatch.from_pydict({
+                    'c': c_values,
+                }, schema=pa.schema([('c', pa.int32())]))
+
+                shard_updator.update_by_arrow_batch(new_batch)
+
+            # Prepare and commit
+            commit_messages = shard_updator.prepare_commit()
+            commit = write_builder.new_commit()
+            commit.commit(commit_messages)
+            commit.close()
+
+        # Step 4: Verify the result
+        read_builder = table.new_read_builder()
+        table_scan = read_builder.new_scan()
+        table_read = read_builder.new_read()
+        actual = table_read.to_arrow(table_scan.plan().splits())
+
+        expected = pa.Table.from_pydict({
+            'a': [1, 2, 3, 4, 5] * 1000,
+            'b': [10, 20, 30, 40, 50] * 1000,
+            'c': [9, 18, 27, 36, 45] * 1000,
+        }, schema=table_schema)
+
+        print("\n=== Actual Data ===")
+        print(actual.to_pandas())
+        print("\n=== Expected Data ===")
+        print(expected.to_pandas())
+
+        self.assertEqual(actual, expected)
+        print("\n✅ Test passed! Column c = b - a computed correctly!")
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/paimon-python/pypaimon/write/table_update.py b/paimon-python/pypaimon/write/table_update.py
index baaf3f18e3..8e3f91bde4 100644
--- a/paimon-python/pypaimon/write/table_update.py
+++ b/paimon-python/pypaimon/write/table_update.py
@@ -15,12 +15,20 @@
 #  See the License for the specific language governing permissions and
 # limitations under the License.
 
################################################################################
-from typing import List
+from collections import defaultdict
+from typing import List, Optional, Tuple
 
+import pyarrow
 import pyarrow as pa
 
+from pypaimon.common.memory_size import MemorySize
+from pypaimon.globalindex import Range
+from pypaimon.manifest.schema.data_file_meta import DataFileMeta
+from pypaimon.read.split import DataSplit
 from pypaimon.write.commit_message import CommitMessage
 from pypaimon.write.table_update_by_row_id import TableUpdateByRowId
+from pypaimon.write.writer.data_writer import DataWriter
+from pypaimon.write.writer.append_only_data_writer import AppendOnlyDataWriter
 
 
 class TableUpdate:
@@ -30,6 +38,7 @@ class TableUpdate:
         self.table: FileStoreTable = table
         self.commit_user = commit_user
         self.update_cols = None
+        self.projection = None
 
     def with_update_type(self, update_cols: List[str]):
         for col in update_cols:
@@ -40,7 +49,151 @@ class TableUpdate:
         self.update_cols = update_cols
         return self
 
+    def with_read_projection(self, projection: List[str]):
+        self.projection = projection
+
+    def new_shard_updator(self, shard_num: int, total_shard_count: int):
+        """Create a shard updater for scan+rewrite style updates.
+
+        Args:
+            shard_num: Index of this shard/subtask.
+            total_shard_count: Total number of shards/subtasks.
+        """
+        return ShardTableUpdator(
+            self.table,
+            self.projection,
+            self.update_cols,
+            self.commit_user,
+            shard_num,
+            total_shard_count,
+        )
+
     def update_by_arrow_with_row_id(self, table: pa.Table) -> List[CommitMessage]:
         update_by_row_id = TableUpdateByRowId(self.table, self.commit_user)
         update_by_row_id.update_columns(table, self.update_cols)
         return update_by_row_id.commit_messages
+
+
+class ShardTableUpdator:
+
+    def __init__(
+            self,
+            table,
+            projection: Optional[List[str]],
+            write_cols: List[str],
+            commit_user,
+            shard_num: int,
+            total_shard_count: int,
+    ):
+        from pypaimon.table.file_store_table import FileStoreTable
+        self.table: FileStoreTable = table
+        self.projection = projection
+        self.write_cols = write_cols
+        self.commit_user = commit_user
+        self.total_shard_count = total_shard_count
+        self.shard_num = shard_num
+
+        self.write_pos = 0  # index of the next row range in self.row_ranges to write
+        self.writer: Optional[SingleWriter] = None  # writer for the range currently being filled
+        self.dict = defaultdict(list)  # partition -> list of rewritten DataFileMeta
+
+        scanner = self.table.new_read_builder().new_scan().with_shard(shard_num, total_shard_count)
+        self.splits = scanner.plan().splits()
+
+        self.row_ranges: List[Tuple[Tuple, Range]] = []
+        for split in self.splits:
+            if not isinstance(split, DataSplit):
+                raise ValueError(f"Split {split} is not DataSplit.")
+            files = split.files
+            ranges = self.compute_from_files(files)
+            for row_range in ranges:
+                self.row_ranges.append((tuple(split.partition.values), row_range))
+
+    @staticmethod
+    def compute_from_files(files: List[DataFileMeta]) -> List[Range]:
+        ranges = []
+        for file in files:
+            ranges.append(Range(file.first_row_id, file.first_row_id + file.row_count - 1))
+
+        return Range.sort_and_merge_overlap(ranges, True, False)
+
+    def arrow_reader(self) -> pyarrow.ipc.RecordBatchReader:
+        read_builder = self.table.new_read_builder()
+        read_builder.with_projection(self.projection)
+        return read_builder.new_read().to_arrow_batch_reader(self.splits)
+
+    def prepare_commit(self) -> List[CommitMessage]:
+        commit_messages = []
+        for (partition, files) in self.dict.items():
+            commit_messages.append(CommitMessage(partition, 0, files))
+        return commit_messages
+
+    def update_by_arrow_batch(self, data: pa.RecordBatch):
+        self._init_writer()
+
+        capacity = self.writer.capacity()
+        if capacity <= 0:
+            raise RuntimeError("Writer has no remaining capacity.")
+
+        # Split the batch across writers.
+        first, rest = (data, None) if capacity >= data.num_rows else (data.slice(0, capacity), data.slice(capacity))
+
+        self.writer.write(first)
+        if self.writer.capacity() == 0:
+            self.dict[self.writer.partition()].append(self.writer.end())
+            self.writer = None
+
+        if rest is not None:
+            if self.writer is not None:
+                raise RuntimeError("Should not get here: rest and current writer exist at the same time.")
+            self.update_by_arrow_batch(rest)
+
+    def _init_writer(self):
+        if self.writer is None:
+            if self.write_pos >= len(self.row_ranges):
+                raise RuntimeError(
+                    "No more row ranges to write. "
+                    "Ensure you write exactly the same number of rows as read from this shard."
+                )
+            item = self.row_ranges[self.write_pos]
+            self.write_pos += 1
+            partition = item[0]
+            row_range = item[1]
+            writer = AppendOnlyDataWriter(self.table, partition, 0, 0, self.table.options, self.write_cols)
+            writer.target_file_size = MemorySize.of_mebi_bytes(999999999).get_bytes()
+            self.writer = SingleWriter(writer, partition, row_range.from_, row_range.to - row_range.from_ + 1)
+
+
+class SingleWriter:
+
+    def __init__(self, writer: DataWriter, partition, first_row_id: int, row_count: int):
+        self.writer: DataWriter = writer
+        self._partition = partition
+        self.first_row_id = first_row_id
+        self.row_count = row_count
+        self.written_records_count = 0
+
+    def capacity(self) -> int:
+        return self.row_count - self.written_records_count
+
+    def write(self, data: pa.RecordBatch):
+        if data.num_rows > self.capacity():
+            raise Exception("Data row count exceeds writer capacity.")
+        self.written_records_count += data.num_rows
+        self.writer.write(data)
+        return
+
+    def partition(self) -> Tuple:
+        return self._partition
+
+    def end(self) -> DataFileMeta:
+        if self.capacity() != 0:
+            raise Exception("There is still capacity left in the writer.")
+        files = self.writer.prepare_commit()
+        if len(files) != 1:
+            raise Exception("Should have one file.")
+        file = files[0]
+        if file.row_count != self.row_count:
+            raise Exception("File row count mismatch.")
+        file = file.assign_first_row_id(self.first_row_id)
+        return file
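
Because `ShardTableUpdator.update_by_arrow_batch` slices incoming batches across the per-file writers by remaining
capacity, callers do not have to mirror the reader's batch sizes; only the total row count and the row order per shard
must match. A hedged sketch reusing the `upd`, `reader`, and `write_builder` names from the documentation example above:

```python
# Collect all computed values for this shard and write them back as one batch;
# the updater slices the batch across output files according to each file's capacity.
all_d = []
for batch in iter(reader.read_next_batch, None):
    a = batch.column('a').to_pylist()
    b = batch.column('b').to_pylist()
    c = batch.column('c').to_pylist()
    all_d.extend(ci + bi - ai for ai, bi, ci in zip(a, b, c))

upd.update_by_arrow_batch(
    pa.RecordBatch.from_pydict({'d': all_d}, schema=pa.schema([('d', pa.int32())]))
)
commit = write_builder.new_commit()
commit.commit(upd.prepare_commit())
commit.close()
```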

