This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new 6fe570b16d [python] introduce update_columns in python api. (#6861)
6fe570b16d is described below
commit 6fe570b16d603b74629989e37a6cff2cee084f81
Author: zhoulii <[email protected]>
AuthorDate: Tue Dec 23 20:23:23 2025 +0800
[python] introduce update_columns in python api. (#6861)
---
docs/content/program-api/python-api.md | 152 +++++++++
.../pypaimon/tests/partial_columns_write_test.py | 376 +++++++++++++++++++++
.../pypaimon/write/partial_column_write.py | 214 ++++++++++++
paimon-python/pypaimon/write/table_write.py | 31 +-
paimon-python/pypaimon/write/write_builder.py | 13 +-
5 files changed, 780 insertions(+), 6 deletions(-)
diff --git a/docs/content/program-api/python-api.md b/docs/content/program-api/python-api.md
index c2c6a2fd30..12c15c8f6d 100644
--- a/docs/content/program-api/python-api.md
+++ b/docs/content/program-api/python-api.md
@@ -213,6 +213,158 @@ write_builder = table.new_batch_write_builder().overwrite()
write_builder = table.new_batch_write_builder().overwrite({'dt': '2024-01-01'})
```
+### Write partial columns
+
+When data evolution is enabled, you can write partial columns to a table:
+
+```python
+simple_pa_schema = pa.schema([
+ ('f0', pa.int8()),
+ ('f1', pa.int16()),
+])
+schema = Schema.from_pyarrow_schema(simple_pa_schema,
+ options={'row-tracking.enabled': 'true', 'data-evolution.enabled': 'true'})
+catalog.create_table('default.test_row_tracking', schema, False)
+table = catalog.get_table('default.test_row_tracking')
+
+# write all columns
+write_builder = table.new_batch_write_builder()
+table_write = write_builder.new_write()
+table_commit = write_builder.new_commit()
+expect_data = pa.Table.from_pydict({
+ 'f0': [-1, 2],
+ 'f1': [-1001, 1002]
+}, schema=simple_pa_schema)
+table_write.write_arrow(expect_data)
+table_commit.commit(table_write.prepare_commit())
+table_write.close()
+table_commit.close()
+
+# write partial columns
+table_write = write_builder.new_write().with_write_type(['f0'])
+table_commit = write_builder.new_commit()
+data2 = pa.Table.from_pydict({
+ 'f0': [3, 4],
+}, schema=pa.schema([
+ ('f0', pa.int8()),
+]))
+table_write.write_arrow(data2)
+cmts = table_write.prepare_commit()
+
+# assign first row id
+cmts[0].new_files[0].first_row_id = 0
+table_commit.commit(cmts)
+table_write.close()
+table_commit.close()
+```
+
+Paimon data-evolution tables use `first_row_id` to split files. When writing partial columns,
+you should split the data into multiple parts by rows and assign a `first_row_id` to each file
+before committing; otherwise, a fatal error may occur when the table is read.
+
+For example, in the following code, `write-1` generates a file with `first_row_id=0` containing 2 rows,
+and `write-2` generates a file with `first_row_id=2` that also contains 2 rows. If we then update column `f0`
+without splitting the data into multiple parts by rows, the generated file would have `first_row_id=0` and
+contain 4 rows, and reading this table would cause a fatal error.
+
+```python
+table = catalog.get_table('default.test_row_tracking')
+
+# write-1
+write_builder = table.new_batch_write_builder()
+table_write = write_builder.new_write()
+table_commit = write_builder.new_commit()
+expect_data = pa.Table.from_pydict({
+ 'f0': [-1, 2],
+ 'f1': [-1001, 1002]
+}, schema=simple_pa_schema)
+table_write.write_arrow(expect_data)
+table_commit.commit(table_write.prepare_commit())
+table_write.close()
+table_commit.close()
+
+# write-2
+table_write = write_builder.new_write()
+table_commit = write_builder.new_commit()
+expect_data = pa.Table.from_pydict({
+ 'f0': [3, 4],
+ 'f1': [1003, 1004]
+}, schema=simple_pa_schema)
+table_write.write_arrow(expect_data)
+table_commit.commit(table_write.prepare_commit())
+table_write.close()
+table_commit.close()
+
+# write partial columns
+table_write = write_builder.new_write().with_write_type(['f0'])
+table_commit = write_builder.new_commit()
+data2 = pa.Table.from_pydict({
+ 'f0': [5, 6, 7, 8],
+}, schema=pa.schema([
+ ('f0', pa.int8()),
+]))
+table_write.write_arrow(data2)
+cmts = table_write.prepare_commit()
+cmts[0].new_files[0].first_row_id = 0
+table_commit.commit(cmts)
+table_write.close()
+table_commit.close()
+
+read_builder = table.new_read_builder()
+table_scan = read_builder.new_scan()
+table_read = read_builder.new_read()
+
+# a fatal error will be thrown
+actual_data = table_read.to_arrow(table_scan.plan().splits())
+```
+
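+For the scenario above, here is a minimal illustrative sketch of the correct manual approach (assuming, as
+described, that the two existing files hold rows 0-1 and 2-3): split the update into one part per existing
+file and assign the matching `first_row_id` before committing.
+
+```python
+# part 1: rows belonging to the existing file with first_row_id=0
+write_builder = table.new_batch_write_builder()
+table_write = write_builder.new_write().with_write_type(['f0'])
+table_commit = write_builder.new_commit()
+part1 = pa.Table.from_pydict({'f0': [5, 6]}, schema=pa.schema([('f0', pa.int8())]))
+table_write.write_arrow(part1)
+cmts = table_write.prepare_commit()
+cmts[0].new_files[0].first_row_id = 0
+table_commit.commit(cmts)
+table_write.close()
+table_commit.close()
+
+# part 2: rows belonging to the existing file with first_row_id=2
+table_write = write_builder.new_write().with_write_type(['f0'])
+table_commit = write_builder.new_commit()
+part2 = pa.Table.from_pydict({'f0': [7, 8]}, schema=pa.schema([('f0', pa.int8())]))
+table_write.write_arrow(part2)
+cmts = table_write.prepare_commit()
+cmts[0].new_files[0].first_row_id = 2
+table_commit.commit(cmts)
+table_write.close()
+table_commit.close()
+```
+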
+### Update columns
+
+Handling file `first_row_id` manually is inconvenient and error-prone. If you don't want to do this,
+you can enable `update_columns_by_row_id` when creating the `WriteBuilder` and set the write type on the
+`TableWrite`; then you can write partial columns without handling file `first_row_id`.
+The input data should include the `_ROW_ID` column. The write operation automatically sorts the rows,
+matches each `_ROW_ID` to its corresponding `first_row_id`, then groups rows with the same `first_row_id`
+and writes each group to a separate file.
+
+```python
+table = catalog.get_table('default.test_row_tracking')
+
+# write-1
+# same as above
+
+# write-2
+# same as above
+
+# update partial columns
+write_builder = table.new_batch_write_builder().update_columns_by_row_id()
+table_write = write_builder.new_write().with_write_type(['f0'])
+table_commit = write_builder.new_commit()
+data2 = pa.Table.from_pydict({
+ '_ROW_ID': [0, 1, 2, 3],
+ 'f0': [5, 6, 7, 8],
+}, schema=pa.schema([
+ ('_ROW_ID', pa.int64()),
+ ('f0', pa.int8()),
+]))
+table_write.write_arrow(data2)
+cmts = table_write.prepare_commit()
+table_commit.commit(cmts)
+table_write.close()
+table_commit.close()
+
+read_builder = table.new_read_builder()
+table_scan = read_builder.new_scan()
+table_read = read_builder.new_read()
+actual_data = table_read.to_arrow(table_scan.plan().splits())
+expect_data = pa.Table.from_pydict({
+ 'f0': [5, 6, 7, 8],
+ 'f1': [-1001, 1002, 1003, 1004]
+}, schema=pa.schema([
+ ('f0', pa.int8()),
+ ('f1', pa.int16()),
+]))
+assert actual_data == expect_data
+```
+
## Batch Read
### Predicate pushdown
diff --git a/paimon-python/pypaimon/tests/partial_columns_write_test.py b/paimon-python/pypaimon/tests/partial_columns_write_test.py
new file mode 100644
index 0000000000..6f3b0b75f6
--- /dev/null
+++ b/paimon-python/pypaimon/tests/partial_columns_write_test.py
@@ -0,0 +1,376 @@
+"""
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import os
+import shutil
+import tempfile
+import unittest
+
+import pyarrow as pa
+
+from pypaimon import CatalogFactory, Schema
+
+
+class PartialColumnsWriteTest(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ cls.tempdir = tempfile.mkdtemp()
+ cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
+ cls.catalog = CatalogFactory.create({
+ 'warehouse': cls.warehouse
+ })
+ cls.catalog.create_database('default', True)
+
+ # Define table schema for testing
+ cls.pa_schema = pa.schema([
+ ('id', pa.int32()),
+ ('name', pa.string()),
+ ('age', pa.int32()),
+ ('city', pa.string()),
+ ])
+
+ # Define options for data evolution
+ cls.table_options = {
+ 'row-tracking.enabled': 'true',
+ 'data-evolution.enabled': 'true'
+ }
+
+ @classmethod
+ def tearDownClass(cls):
+ shutil.rmtree(cls.tempdir, ignore_errors=True)
+
+ def _create_table(self):
+ """Helper method to create a table with initial data."""
+ # Generate unique table name for each test
+ import uuid
+ table_name = f'test_data_evolution_{uuid.uuid4().hex[:8]}'
+ schema = Schema.from_pyarrow_schema(self.pa_schema, options=self.table_options)
+ self.catalog.create_table(f'default.{table_name}', schema, False)
+ table = self.catalog.get_table(f'default.{table_name}')
+
+ # Write batch-1
+ write_builder = table.new_batch_write_builder()
+
+ initial_data = pa.Table.from_pydict({
+ 'id': [1, 2],
+ 'name': ['Alice', 'Bob'],
+ 'age': [25, 30],
+ 'city': ['NYC', 'LA']
+ }, schema=self.pa_schema)
+
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+ table_write.write_arrow(initial_data)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ # Write batch-2
+ following_data = pa.Table.from_pydict({
+ 'id': [3, 4, 5],
+ 'name': ['Charlie', 'David', 'Eve'],
+ 'age': [35, 40, 45],
+ 'city': ['Chicago', 'Houston', 'Phoenix']
+ }, schema=self.pa_schema)
+
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+ table_write.write_arrow(following_data)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ return table
+
+ def test_update_existing_column(self):
+ """Test updating an existing column using data evolution."""
+ # Create table with initial data
+ table = self._create_table()
+
+ # Create data evolution writer using BatchTableWrite
+ write_builder = table.new_batch_write_builder()
+ batch_write = write_builder.new_write()
+
+ # Prepare update data (sorted by row_id)
+ update_data = pa.Table.from_pydict({
+ '_ROW_ID': [1, 0, 2, 3, 4],
+ 'age': [31, 26, 36, 39, 42]
+ })
+
+ # Update the age column
+ write_builder = table.new_batch_write_builder().update_columns_by_row_id()
+ batch_write = write_builder.new_write().with_write_type(['age'])
+ batch_write.write_arrow(update_data)
+ commit_messages = batch_write.prepare_commit()
+
+ # Commit the changes
+ table_commit = write_builder.new_commit()
+ table_commit.commit(commit_messages)
+ table_commit.close()
+ batch_write.close()
+
+ # Verify the updated data
+ read_builder = table.new_read_builder()
+ table_read = read_builder.new_read()
+ splits = read_builder.new_scan().plan().splits()
+ result = table_read.to_arrow(splits)
+
+ # Check that ages were updated for rows 0-2
+ ages = result['age'].to_pylist()
+ expected_ages = [26, 31, 36, 39, 42]
+ self.assertEqual(ages, expected_ages)
+
+ def test_update_multiple_columns(self):
+ """Test updating multiple columns at once."""
+ # Create table with initial data
+ table = self._create_table()
+
+ # Create data evolution writer using BatchTableWrite
+ write_builder = table.new_batch_write_builder()
+ batch_write = write_builder.new_write()
+
+ # Prepare update data (sorted by row_id)
+ update_data = pa.Table.from_pydict({
+ '_ROW_ID': [1, 0, 2, 3, 4],
+ 'age': [31, 26, 36, 39, 42],
+ 'city': ['Los Angeles', 'New York', 'Chicago', 'Phoenix', 'Houston']
+ })
+
+ # Update multiple columns
+ write_builder = table.new_batch_write_builder().update_columns_by_row_id()
+ batch_write = write_builder.new_write().with_write_type(['age', 'city'])
+ batch_write.write_arrow(update_data)
+ commit_messages = batch_write.prepare_commit()
+
+ # Commit the changes
+ table_commit = write_builder.new_commit()
+ table_commit.commit(commit_messages)
+ table_commit.close()
+ batch_write.close()
+
+ # Verify the updated data
+ read_builder = table.new_read_builder()
+ table_read = read_builder.new_read()
+ splits = read_builder.new_scan().plan().splits()
+ result = table_read.to_arrow(splits)
+
+ # Check that both age and city were updated for rows 0-2
+ ages = result['age'].to_pylist()
+ cities = result['city'].to_pylist()
+
+ expected_ages = [26, 31, 36, 39, 42]
+ expected_cities = ['New York', 'Los Angeles', 'Chicago', 'Phoenix', 'Houston']
+
+ self.assertEqual(ages, expected_ages)
+ self.assertEqual(cities, expected_cities)
+
+ def test_nonexistent_column(self):
+ """Test that updating a non-existent column raises an error."""
+ table = self._create_table()
+
+ # Create data evolution writer using BatchTableWrite
+ write_builder = table.new_batch_write_builder()
+ batch_write = write_builder.new_write()
+
+ # Try to update a non-existent column
+ update_data = pa.Table.from_pydict({
+ '_ROW_ID': [0, 1, 2, 3, 4],
+ 'nonexistent_column': [100, 200, 300, 400, 500]
+ })
+
+ # Should raise ValueError
+ with self.assertRaises(ValueError) as context:
+ write_builder = table.new_batch_write_builder().update_columns_by_row_id()
+ batch_write = write_builder.new_write().with_write_type(['nonexistent_column'])
+ batch_write.write_arrow(update_data)
+
+ self.assertIn('not in table schema', str(context.exception))
+ batch_write.close()
+
+ def test_missing_row_id_column(self):
+ """Test that missing row_id column raises an error."""
+ table = self._create_table()
+
+ # Create data evolution writer using BatchTableWrite
+ write_builder = table.new_batch_write_builder()
+ batch_write = write_builder.new_write()
+
+ # Prepare update data without row_id column
+ update_data = pa.Table.from_pydict({
+ 'age': [26, 27, 28, 29, 30]
+ })
+
+ # Should raise ValueError
+ with self.assertRaises(ValueError) as context:
+ write_builder = table.new_batch_write_builder().update_columns_by_row_id()
+ batch_write = write_builder.new_write().with_write_type(['age'])
+ batch_write.write_arrow(update_data)
+
+ self.assertIn("Input data must contain _ROW_ID column", str(context.exception))
+ batch_write.close()
+
+ def test_partitioned_table_update(self):
+ """Test updating columns in a partitioned table."""
+ # Create partitioned table
+ schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['city'], options=self.table_options)
+ self.catalog.create_table('default.test_partitioned_evolution', schema, False)
+ table = self.catalog.get_table('default.test_partitioned_evolution')
+
+ # Write initial data
+ write_builder = table.new_batch_write_builder()
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+
+ initial_data = pa.Table.from_pydict({
+ 'id': [1, 2, 3, 4, 5],
+ 'name': ['Alice', 'Bob', 'Charlie', 'David', 'Eve'],
+ 'age': [25, 30, 35, 40, 45],
+ 'city': ['NYC', 'NYC', 'LA', 'LA', 'Chicago']
+ }, schema=self.pa_schema)
+
+ table_write.write_arrow(initial_data)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ # Create data evolution writer using BatchTableWrite
+ write_builder = table.new_batch_write_builder().update_columns_by_row_id()
+ batch_write = write_builder.new_write().with_write_type(['age'])
+
+ # Update ages
+ update_data = pa.Table.from_pydict({
+ '_ROW_ID': [1, 0, 2, 3, 4],
+ 'age': [31, 26, 36, 41, 46]
+ })
+
+ batch_write.write_arrow(update_data)
+ commit_messages = batch_write.prepare_commit()
+
+ # Commit the changes
+ table_commit = write_builder.new_commit()
+ table_commit.commit(commit_messages)
+ table_commit.close()
+ batch_write.close()
+
+ # Verify the updated data
+ read_builder = table.new_read_builder()
+ table_read = read_builder.new_read()
+ splits = read_builder.new_scan().plan().splits()
+ result = table_read.to_arrow(splits)
+
+ # Check ages were updated
+ ages = result['age'].to_pylist()
+ expected_ages = [26, 31, 36, 41, 46]
+ self.assertEqual(ages, expected_ages)
+
+ def test_multiple_calls(self):
+ """Test multiple calls to update_columns, each updating a single column."""
+ # Create table with initial data
+ table = self._create_table()
+
+ # First update: Update age column
+ write_builder = table.new_batch_write_builder().update_columns_by_row_id()
+ batch_write = write_builder.new_write().with_write_type(['age'])
+
+ update_age_data = pa.Table.from_pydict({
+ '_ROW_ID': [1, 0, 2, 3, 4],
+ 'age': [31, 26, 36, 41, 46]
+ })
+
+ batch_write.write_arrow(update_age_data)
+ commit_messages = batch_write.prepare_commit()
+ table_commit = write_builder.new_commit()
+ table_commit.commit(commit_messages)
+ table_commit.close()
+
+ # Second update: Update city column
+ update_city_data = pa.Table.from_pydict({
+ '_ROW_ID': [1, 0, 2, 3, 4],
+ 'city': ['Los Angeles', 'New York', 'Chicago', 'Phoenix', 'Houston']
+ })
+ write_builder = table.new_batch_write_builder().update_columns_by_row_id()
+ batch_write = write_builder.new_write().with_write_type(['city'])
+ batch_write.write_arrow(update_city_data)
+ commit_messages = batch_write.prepare_commit()
+ table_commit = write_builder.new_commit()
+ table_commit.commit(commit_messages)
+ table_commit.close()
+
+ # Close the batch write
+ batch_write.close()
+
+ # Verify both columns were updated correctly
+ read_builder = table.new_read_builder()
+ table_read = read_builder.new_read()
+ splits = read_builder.new_scan().plan().splits()
+ result = table_read.to_arrow(splits)
+
+ ages = result['age'].to_pylist()
+ cities = result['city'].to_pylist()
+
+ expected_ages = [26, 31, 36, 41, 46]
+ expected_cities = ['New York', 'Los Angeles', 'Chicago', 'Phoenix', 'Houston']
+
+ self.assertEqual(ages, expected_ages, "Age column was not updated correctly")
+ self.assertEqual(cities, expected_cities, "City column was not updated correctly")
+
+ def test_wrong_total_row_count(self):
+ """Test that wrong total row count raises an error."""
+ # Create table with initial data
+ table = self._create_table()
+
+ # Create data evolution writer using BatchTableWrite
+ write_builder = table.new_batch_write_builder().update_columns_by_row_id()
+ batch_write = write_builder.new_write().with_write_type(['age'])
+
+ # Prepare update data with wrong row count (only 3 rows instead of 5)
+ update_data = pa.Table.from_pydict({
+ '_ROW_ID': [0, 1, 2],
+ 'age': [26, 31, 36]
+ })
+
+ # Should raise ValueError for total row count mismatch
+ with self.assertRaises(ValueError) as context:
+ batch_write.write_arrow(update_data)
+
+ self.assertIn("does not match table total row count", str(context.exception))
+ batch_write.close()
+
+ def test_wrong_first_row_id_row_count(self):
+ """Test that wrong row count for a first_row_id raises an error."""
+ # Create table with initial data
+ table = self._create_table()
+
+ # Create data evolution writer using BatchTableWrite
+ write_builder = table.new_batch_write_builder().update_columns_by_row_id()
+ batch_write = write_builder.new_write().with_write_type(['age'])
+
+ # Prepare update data with a duplicate row_id (violates monotonically increasing row IDs)
+ update_data = pa.Table.from_pydict({
+ '_ROW_ID': [0, 1, 1, 4, 5],
+ 'age': [26, 31, 36, 37, 38]
+ })
+
+ # Should raise ValueError for row ID validation
+ with self.assertRaises(ValueError) as context:
+ batch_write.write_arrow(update_data)
+
+ self.assertIn("Row IDs are not monotonically increasing", str(context.exception))
+ batch_write.close()
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/paimon-python/pypaimon/write/partial_column_write.py b/paimon-python/pypaimon/write/partial_column_write.py
new file mode 100644
index 0000000000..c60464f306
--- /dev/null
+++ b/paimon-python/pypaimon/write/partial_column_write.py
@@ -0,0 +1,214 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import bisect
+from typing import Dict, List, Optional
+
+import pyarrow as pa
+import pyarrow.compute as pc
+
+from pypaimon.snapshot.snapshot import BATCH_COMMIT_IDENTIFIER
+from pypaimon.table.row.generic_row import GenericRow
+from pypaimon.table.special_fields import SpecialFields
+from pypaimon.write.file_store_write import FileStoreWrite
+
+
+class PartialColumnWrite:
+ """
+ Table write for partial column updates (data evolution).
+
+ This writer is designed for adding/updating specific columns in existing tables.
+ Input data should contain _ROW_ID column.
+ """
+
+ FIRST_ROW_ID_COLUMN = '_FIRST_ROW_ID'
+
+ def __init__(self, table, commit_user: str):
+ from pypaimon.table.file_store_table import FileStoreTable
+
+ self.table: FileStoreTable = table
+ self.commit_user = commit_user
+
+ # Load existing first_row_ids and build partition map
+ (self.first_row_ids,
+ self.first_row_id_to_partition_map,
+ self.first_row_id_to_row_count_map,
+ self.total_row_count) = self._load_existing_files_info()
+
+ # Collect commit messages
+ self.commit_messages = []
+
+ def _load_existing_files_info(self):
+ """Load existing first_row_ids and build partition map for efficient lookup."""
+ first_row_ids = []
+ first_row_id_to_partition_map: Dict[int, GenericRow] = {}
+ first_row_id_to_row_count_map: Dict[int, int] = {}
+
+ read_builder = self.table.new_read_builder()
+ scan = read_builder.new_scan()
+ splits = scan.plan().splits()
+
+ for split in splits:
+ for file in split.files:
+ if file.first_row_id is not None and not file.file_name.endswith('.blob'):
+ first_row_id = file.first_row_id
+ first_row_ids.append(first_row_id)
+ first_row_id_to_partition_map[first_row_id] = split.partition
+ first_row_id_to_row_count_map[first_row_id] = file.row_count
+
+ total_row_count = sum(first_row_id_to_row_count_map.values())
+
+ return (sorted(list(set(first_row_ids))), first_row_id_to_partition_map,
+ first_row_id_to_row_count_map, total_row_count)
+
+ def update_columns(self, data: pa.Table, column_names: List[str]) -> List:
+ """
+ Add or update columns in the table.
+
+ Args:
+ data: Input data containing row_id and columns to update
+ column_names: Names of columns to update (excluding row_id)
+
+ Returns:
+ List of commit messages
+ """
+
+ # Validate column_names is not empty
+ if not column_names:
+ raise ValueError("column_names cannot be empty")
+
+ # Validate input data has row_id column
+ if SpecialFields.ROW_ID.name not in data.column_names:
+ raise ValueError(f"Input data must contain {SpecialFields.ROW_ID.name} column")
+
+ # Validate all update columns exist in the schema
+ for col_name in column_names:
+ if col_name not in self.table.field_names:
+ raise ValueError(f"Column {col_name} not found in table schema")
+
+ # Validate data row count matches total row count
+ if data.num_rows != self.total_row_count:
+ raise ValueError(
+ f"Input data row count ({data.num_rows}) does not match table total row count ({self.total_row_count})")
+
+ # Sort data by _ROW_ID column
+ sorted_data = data.sort_by([(SpecialFields.ROW_ID.name, "ascending")])
+
+ # Calculate first_row_id for each row
+ data_with_first_row_id = self._calculate_first_row_id(sorted_data)
+
+ # Group by first_row_id and write each group
+ self._write_by_first_row_id(data_with_first_row_id, column_names)
+
+ return self.commit_messages
+
+ def _calculate_first_row_id(self, data: pa.Table) -> pa.Table:
+ """Calculate _first_row_id for each row based on _ROW_ID."""
+ row_ids = data[SpecialFields.ROW_ID.name].to_pylist()
+
+ # Validate that row_ids are monotonically increasing starting from 0
+ expected_row_ids = list(range(len(row_ids)))
+ if row_ids != expected_row_ids:
+ raise ValueError(f"Row IDs are not monotonically increasing starting from 0. "
+ f"Expected: {expected_row_ids}")
+
+ # Calculate first_row_id for each row_id
+ first_row_id_values = []
+ for row_id in row_ids:
+ first_row_id = self._floor_binary_search(self.first_row_ids, row_id)
+ first_row_id_values.append(first_row_id)
+
+ # Add first_row_id column to the table
+ first_row_id_array = pa.array(first_row_id_values, type=pa.int64())
+ return data.append_column(self.FIRST_ROW_ID_COLUMN, first_row_id_array)
+
+ def _floor_binary_search(self, sorted_seq: List[int], value: int) -> int:
+ """Binary search to find the floor value in sorted sequence."""
+ if not sorted_seq:
+ raise ValueError("The input sorted sequence is empty.")
+
+ idx = bisect.bisect_right(sorted_seq, value) - 1
+ if idx < 0:
+ raise ValueError(f"Value {value} is less than the first element in the sorted sequence.")
+
+ return sorted_seq[idx]
+
+ def _write_by_first_row_id(self, data: pa.Table, column_names: List[str]):
+ """Write data grouped by first_row_id."""
+ # Extract unique first_row_id values
+ first_row_id_array = data[self.FIRST_ROW_ID_COLUMN]
+ unique_first_row_ids = pc.unique(first_row_id_array).to_pylist()
+
+ for first_row_id in unique_first_row_ids:
+ # Filter rows for this first_row_id
+ mask = pc.equal(first_row_id_array, first_row_id)
+ group_data = data.filter(mask)
+
+ # Get partition for this first_row_id
+ partition = self._find_partition_by_first_row_id(first_row_id)
+
+ if partition is None:
+ raise ValueError(f"No existing file found for first_row_id {first_row_id}")
+
+ # Write this group
+ self._write_group(partition, first_row_id, group_data, column_names)
+
+ def _find_partition_by_first_row_id(self, first_row_id: int) -> Optional[GenericRow]:
+ """Find the partition for a given first_row_id using the pre-built partition map."""
+ return self.first_row_id_to_partition_map.get(first_row_id)
+
+ def _write_group(self, partition: GenericRow, first_row_id: int,
+ data: pa.Table, column_names: List[str]):
+ """Write a group of data with the same first_row_id."""
+
+ # Validate data row count matches the first_row_id's row count
+ expected_row_count = self.first_row_id_to_row_count_map.get(first_row_id, 0)
+ if data.num_rows != expected_row_count:
+ raise ValueError(
+ f"Data row count ({data.num_rows}) does not match expected row count ({expected_row_count}) "
+ f"for first_row_id {first_row_id}")
+
+ # Create a file store write for this partition
+ file_store_write = FileStoreWrite(self.table, self.commit_user)
+
+ # Set write columns to only update specific columns
+ # Note: _ROW_ID is metadata column, not part of schema
+ write_cols = column_names
+ file_store_write.write_cols = write_cols
+
+ # Convert partition to tuple for hashing
+ partition_tuple = tuple(partition.values)
+
+ # Write data - convert Table to RecordBatch
+ data_to_write = data.select(write_cols)
+ for batch in data_to_write.to_batches():
+ file_store_write.write(partition_tuple, 0, batch)
+
+ # Prepare commit and assign first_row_id
+ commit_messages = file_store_write.prepare_commit(BATCH_COMMIT_IDENTIFIER)
+
+ # Assign first_row_id to the new files
+ for msg in commit_messages:
+ for file in msg.new_files:
+ # Assign the same first_row_id as the original file
+ file.first_row_id = first_row_id
+ file.write_cols = write_cols
+
+ self.commit_messages.extend(commit_messages)
+
+ # Close the writer
+ file_store_write.close()
diff --git a/paimon-python/pypaimon/write/table_write.py b/paimon-python/pypaimon/write/table_write.py
index 0ac73356a3..8bc2c023ee 100644
--- a/paimon-python/pypaimon/write/table_write.py
+++ b/paimon-python/pypaimon/write/table_write.py
@@ -16,7 +16,7 @@
# limitations under the License.
################################################################################
from collections import defaultdict
-from typing import List
+from typing import List, Optional
import pyarrow as pa
@@ -24,6 +24,7 @@ from pypaimon.schema.data_types import PyarrowFieldParser
from pypaimon.snapshot.snapshot import BATCH_COMMIT_IDENTIFIER
from pypaimon.write.commit_message import CommitMessage
from pypaimon.write.file_store_write import FileStoreWrite
+from pypaimon.write.partial_column_write import PartialColumnWrite
class TableWrite:
@@ -80,15 +81,39 @@ class TableWrite:
class BatchTableWrite(TableWrite):
- def __init__(self, table, commit_user):
+ def __init__(self, table, commit_user, update_columns_by_row_id=False):
super().__init__(table, commit_user)
self.batch_committed = False
+ self._partial_column_write: Optional[PartialColumnWrite] = None
+ if update_columns_by_row_id:
+ self._partial_column_write = PartialColumnWrite(self.table, self.commit_user)
+
+ def write_arrow(self, table: pa.Table):
+ if self._partial_column_write is not None:
+ return self._partial_column_write.update_columns(table, self.file_store_write.write_cols)
+ super().write_arrow(table)
+
+ def write_arrow_batch(self, data: pa.RecordBatch):
+ if self._partial_column_write is not None:
+ table = pa.Table.from_batches([data])
+ return self._partial_column_write.update_columns(table, self.file_store_write.write_cols)
+ super().write_arrow_batch(data)
+
+ def write_pandas(self, dataframe):
+ if self._partial_column_write is not None:
+ table = pa.Table.from_pandas(dataframe)
+ return self._partial_column_write.update_columns(table, self.file_store_write.write_cols)
+ super().write_pandas(dataframe)
def prepare_commit(self) -> List[CommitMessage]:
if self.batch_committed:
raise RuntimeError("BatchTableWrite only supports one-time committing.")
self.batch_committed = True
- return self.file_store_write.prepare_commit(BATCH_COMMIT_IDENTIFIER)
+
+ if self._partial_column_write is not None:
+ return self._partial_column_write.commit_messages
+ else:
+ return self.file_store_write.prepare_commit(BATCH_COMMIT_IDENTIFIER)
class StreamTableWrite(TableWrite):
diff --git a/paimon-python/pypaimon/write/write_builder.py b/paimon-python/pypaimon/write/write_builder.py
index 7b96a8c2a2..c83c11e746 100644
--- a/paimon-python/pypaimon/write/write_builder.py
+++ b/paimon-python/pypaimon/write/write_builder.py
@@ -20,8 +20,10 @@ import uuid
from abc import ABC
from typing import Optional
-from pypaimon.write.table_commit import BatchTableCommit, StreamTableCommit, TableCommit
-from pypaimon.write.table_write import BatchTableWrite, StreamTableWrite, TableWrite
+from pypaimon.write.table_commit import (BatchTableCommit, StreamTableCommit,
+ TableCommit)
+from pypaimon.write.table_write import (BatchTableWrite, StreamTableWrite,
+ TableWrite)
class WriteBuilder(ABC):
@@ -31,6 +33,7 @@ class WriteBuilder(ABC):
self.table: FileStoreTable = table
self.commit_user = self._create_commit_user()
self.static_partition = None
+ self._update_columns_by_row_id = False
def overwrite(self, static_partition: Optional[dict] = None):
self.static_partition = static_partition if static_partition is not None else {}
@@ -49,11 +52,15 @@ class WriteBuilder(ABC):
else:
return str(uuid.uuid4())
+ def update_columns_by_row_id(self):
+ self._update_columns_by_row_id = True
+ return self
+
class BatchWriteBuilder(WriteBuilder):
def new_write(self) -> BatchTableWrite:
- return BatchTableWrite(self.table, self.commit_user)
+ return BatchTableWrite(self.table, self.commit_user, self._update_columns_by_row_id)
def new_commit(self) -> BatchTableCommit:
commit = BatchTableCommit(self.table, self.commit_user, self.static_partition)