This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new c3487e1331 [python] Introduce TableScan.with_shard (#6068)
c3487e1331 is described below
commit c3487e1331540f1a3ba0923e8cca99687e48c6af
Author: umi <[email protected]>
AuthorDate: Wed Aug 13 16:13:47 2025 +0800
[python] Introduce TableScan.with_shard (#6068)
---
.../pypaimon/manifest/manifest_file_manager.py | 5 +-
paimon-python/pypaimon/read/table_scan.py | 22 +-
paimon-python/pypaimon/schema/table_schema.py | 17 ++
paimon-python/pypaimon/table/file_store_table.py | 13 +-
.../pypaimon/tests/rest_catalog_base_test.py | 229 +++++++++++++++++++++
...talog_test.py => rest_table_read_write_test.py} | 208 +------------------
paimon-python/pypaimon/tests/rest_table_test.py | 147 +++++++++++++
paimon-python/pypaimon/write/row_key_extractor.py | 29 ++-
8 files changed, 453 insertions(+), 217 deletions(-)
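For context, a minimal usage sketch of the new API, mirroring the tests added in this patch: each of N parallel subtasks calls with_shard(idx, N) on its own scan so that together the subtasks cover every split exactly once. The catalog setup and table name below are illustrative assumptions, not part of this commit.

    # Hypothetical usage of TableScan.with_shard; 'catalog' and the table name are placeholders.
    table = catalog.get_table('default.test_with_shard')
    read_builder = table.new_read_builder()

    number_of_para_subtasks = 3
    splits = []
    for idx in range(number_of_para_subtasks):
        # Each subtask plans only the manifest entries assigned to its shard.
        splits.extend(read_builder.new_scan().with_shard(idx, number_of_para_subtasks).plan().splits())

    table_read = read_builder.new_read()
    result = table_read.to_arrow(splits)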
diff --git a/paimon-python/pypaimon/manifest/manifest_file_manager.py b/paimon-python/pypaimon/manifest/manifest_file_manager.py
index 686d365910..8f7c8e325c 100644
--- a/paimon-python/pypaimon/manifest/manifest_file_manager.py
+++ b/paimon-python/pypaimon/manifest/manifest_file_manager.py
@@ -15,7 +15,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
-
import uuid
from io import BytesIO
from typing import List
@@ -42,7 +41,7 @@ class ManifestFileManager:
        self.primary_key_fields = self.table.table_schema.get_primary_key_fields()
        self.trimmed_primary_key_fields = self.table.table_schema.get_trimmed_primary_key_fields()
-    def read(self, manifest_file_name: str) -> List[ManifestEntry]:
+    def read(self, manifest_file_name: str, shard_filter=None) -> List[ManifestEntry]:
manifest_file_path = self.manifest_path / manifest_file_name
entries = []
@@ -74,6 +73,8 @@ class ManifestFileManager:
total_buckets=record['_TOTAL_BUCKETS'],
file=file_meta
)
+ if shard_filter and not shard_filter(entry):
+ continue
entries.append(entry)
return entries
diff --git a/paimon-python/pypaimon/read/table_scan.py b/paimon-python/pypaimon/read/table_scan.py
index 83e1b9cf6c..fa4ade3a5f 100644
--- a/paimon-python/pypaimon/read/table_scan.py
+++ b/paimon-python/pypaimon/read/table_scan.py
@@ -29,6 +29,7 @@ from pypaimon.read.plan import Plan
from pypaimon.read.split import Split
from pypaimon.schema.data_types import DataField
from pypaimon.snapshot.snapshot_manager import SnapshotManager
+from pypaimon.write.row_key_extractor import FixedBucketRowKeyExtractor
class TableScan:
@@ -40,7 +41,6 @@ class TableScan:
self.table: FileStoreTable = table
self.predicate = predicate
- self.predicate = predicate
self.limit = limit
self.read_type = read_type
@@ -52,6 +52,9 @@ class TableScan:
self.target_split_size = 128 * 1024 * 1024
self.open_file_cost = 4 * 1024 * 1024
+ self.idx_of_this_subtask = None
+ self.number_of_para_subtasks = None
+
def plan(self) -> Plan:
latest_snapshot = self.snapshot_manager.get_latest_snapshot()
if not latest_snapshot:
@@ -60,7 +63,9 @@ class TableScan:
file_entries = []
for manifest_file_path in manifest_files:
-            manifest_entries = self.manifest_file_manager.read(manifest_file_path)
+            manifest_entries = self.manifest_file_manager.read(manifest_file_path,
+                                                               (lambda row: self._shard_filter(row))
+                                                               if self.idx_of_this_subtask is not None else None)
for entry in manifest_entries:
if entry.kind == 0:
file_entries.append(entry)
@@ -83,6 +88,19 @@ class TableScan:
return Plan(splits)
+    def with_shard(self, idx_of_this_subtask, number_of_para_subtasks) -> 'TableScan':
+ self.idx_of_this_subtask = idx_of_this_subtask
+ self.number_of_para_subtasks = number_of_para_subtasks
+ return self
+
+ def _shard_filter(self, entry: Optional[ManifestEntry]) -> bool:
+ if self.table.is_primary_key_table:
+ bucket = entry.bucket
+            return bucket % self.number_of_para_subtasks == self.idx_of_this_subtask
+ else:
+ file = entry.file.file_name
+            return FixedBucketRowKeyExtractor.hash(file) % self.number_of_para_subtasks == self.idx_of_this_subtask
+
def _apply_push_down_limit(self, splits: List[Split]) -> List[Split]:
if self.limit is None:
return splits
diff --git a/paimon-python/pypaimon/schema/table_schema.py b/paimon-python/pypaimon/schema/table_schema.py
index da51f168d2..80b787eb6e 100644
--- a/paimon-python/pypaimon/schema/table_schema.py
+++ b/paimon-python/pypaimon/schema/table_schema.py
@@ -69,6 +69,17 @@ class TableSchema:
self.options = options or {}
self.comment = comment
        self.time_millis = time_millis if time_millis is not None else int(time.time() * 1000)
+ self.get_trimmed_primary_key_fields()
+
+ from typing import List
+
+ def cross_partition_update(self) -> bool:
+ if not self.primary_keys or not self.partition_keys:
+ return False
+
+ # Check if primary keys contain all partition keys
+ # Return True if they don't contain all (cross-partition update)
+ return not all(pk in self.primary_keys for pk in self.partition_keys)
def to_schema(self) -> Schema:
return Schema(
@@ -184,5 +195,11 @@ class TableSchema:
if not self.primary_keys or not self.partition_keys:
return self.get_primary_key_fields()
        adjusted = [pk for pk in self.primary_keys if pk not in self.partition_keys]
+ # Validate that filtered list is not empty
+ if not adjusted:
+ raise ValueError(
+ f"Primary key constraint {self.primary_keys} "
+                f"should not be same with partition fields {self.partition_keys}, "
+ "this will result in only one record in a partition")
field_map = {field.name: field for field in self.fields}
return [field_map[name] for name in adjusted if name in field_map]
diff --git a/paimon-python/pypaimon/table/file_store_table.py b/paimon-python/pypaimon/table/file_store_table.py
index 0d2e542789..1ceb007fae 100644
--- a/paimon-python/pypaimon/table/file_store_table.py
+++ b/paimon-python/pypaimon/table/file_store_table.py
@@ -27,7 +27,8 @@ from pypaimon.schema.table_schema import TableSchema
from pypaimon.table.bucket_mode import BucketMode
from pypaimon.table.table import Table
from pypaimon.write.batch_write_builder import BatchWriteBuilder
-from pypaimon.write.row_key_extractor import (FixedBucketRowKeyExtractor,
+from pypaimon.write.row_key_extractor import (DynamicBucketRowKeyExtractor,
+ FixedBucketRowKeyExtractor,
RowKeyExtractor,
UnawareBucketRowKeyExtractor)
@@ -47,13 +48,15 @@ class FileStoreTable(Table):
self.schema_manager = SchemaManager(file_io, table_path)
self.is_primary_key_table = bool(self.primary_keys)
+        self.cross_partition_update = self.table_schema.cross_partition_update()
def bucket_mode(self) -> BucketMode:
if self.is_primary_key_table:
- if self.primary_keys == self.partition_keys:
- return BucketMode.CROSS_PARTITION
if self.options.get(CoreOptions.BUCKET, -1) == -1:
- return BucketMode.HASH_DYNAMIC
+ if self.cross_partition_update:
+ return BucketMode.CROSS_PARTITION
+ else:
+ return BucketMode.HASH_DYNAMIC
else:
return BucketMode.HASH_FIXED
else:
@@ -75,6 +78,6 @@ class FileStoreTable(Table):
elif bucket_mode == BucketMode.BUCKET_UNAWARE:
return UnawareBucketRowKeyExtractor(self.table_schema)
        elif bucket_mode == BucketMode.HASH_DYNAMIC or bucket_mode == BucketMode.CROSS_PARTITION:
- raise ValueError(f"Unsupported bucket mode {bucket_mode} yet")
+ return DynamicBucketRowKeyExtractor(self.table_schema)
else:
raise ValueError(f"Unsupported bucket mode: {bucket_mode}")
diff --git a/paimon-python/pypaimon/tests/rest_catalog_base_test.py b/paimon-python/pypaimon/tests/rest_catalog_base_test.py
new file mode 100644
index 0000000000..773d94abed
--- /dev/null
+++ b/paimon-python/pypaimon/tests/rest_catalog_base_test.py
@@ -0,0 +1,229 @@
+"""
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import glob
+import logging
+import os
+import shutil
+import tempfile
+import unittest
+import uuid
+
+import pyarrow as pa
+
+from pypaimon.api import ConfigResponse, Identifier
+from pypaimon.api.auth import BearTokenAuthProvider
+from pypaimon.api.options import Options
+from pypaimon.catalog.catalog_context import CatalogContext
+from pypaimon.catalog.catalog_factory import CatalogFactory
+from pypaimon.catalog.rest.rest_catalog import RESTCatalog
+from pypaimon.catalog.table_metadata import TableMetadata
+from pypaimon.schema.data_types import (ArrayType, AtomicType, DataField,
+ MapType)
+from pypaimon.schema.schema import Schema
+from pypaimon.schema.table_schema import TableSchema
+from pypaimon.tests.rest_server import RESTCatalogServer
+
+
+class RESTCatalogBaseTest(unittest.TestCase):
+ def setUp(self):
+ self.temp_dir = tempfile.mkdtemp(prefix="unittest_")
+ self.warehouse = os.path.join(self.temp_dir, 'warehouse')
+
+ self.config = ConfigResponse(defaults={"prefix": "mock-test"})
+ self.token = str(uuid.uuid4())
+ self.server = RESTCatalogServer(
+ data_path=self.temp_dir,
+ auth_provider=BearTokenAuthProvider(self.token),
+ config=self.config,
+ warehouse="warehouse"
+ )
+ self.server.start()
+ print(f"\nServer started at: {self.server.get_url()}")
+
+ self.options = {
+ 'metastore': 'rest',
+ 'uri': f"http://localhost:{self.server.port}",
+ 'warehouse': "warehouse",
+ 'dlf.region': 'cn-hangzhou',
+ "token.provider": "bear",
+ 'token': self.token,
+ }
+ self.rest_catalog = CatalogFactory.create(self.options)
+ self.rest_catalog.create_database("default", False)
+
+ self.pa_schema = pa.schema([
+ ('user_id', pa.int64()),
+ ('item_id', pa.int64()),
+ ('behavior', pa.string()),
+ ('dt', pa.string()),
+ ('long-dt', pa.string())
+ ])
+ self.raw_data = {
+ 'user_id': [1, 2, 3, 4, 5, 6, 7, 8],
+ 'item_id': [1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008],
+ 'behavior': ['a', 'b', 'c', None, 'e', 'f', 'g', 'h'],
+ 'dt': ['p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p2'],
+            'long-dt': ['2024-10-10', '2024-10-10', '2024-10-10', '2024-01-01', '2024-10-10', '2025-01-23',
+ 'abcdefghijklmnopk', '2025-08-08']
+ }
+        self.expected = pa.Table.from_pydict(self.raw_data, schema=self.pa_schema)
+
+ schema = Schema.from_pyarrow_schema(self.pa_schema)
+        self.rest_catalog.create_table('default.test_reader_iterator', schema, False)
+        self.table = self.rest_catalog.get_table('default.test_reader_iterator')
+ write_builder = self.table.new_batch_write_builder()
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+ table_write.write_arrow(self.expected)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ def tearDown(self):
+ # Shutdown server
+ self.server.shutdown()
+ print("Server stopped")
+ shutil.rmtree(self.temp_dir, ignore_errors=True)
+
+ def test_rest_catalog(self):
+ """Example usage of RESTCatalogServer"""
+ # Setup logging
+ logging.basicConfig(level=logging.INFO)
+
+ # Create config
+ config = ConfigResponse(defaults={"prefix": "mock-test"})
+ token = str(uuid.uuid4())
+ # Create server
+ server = RESTCatalogServer(
+ data_path="/tmp/test_warehouse",
+ auth_provider=BearTokenAuthProvider(token),
+ config=config,
+ warehouse="test_warehouse"
+ )
+ try:
+ # Start server
+ server.start()
+ print(f"Server started at: {server.get_url()}")
+ test_databases = {
+ "default": server.mock_database("default", {"env": "test"}),
+ "test_db1": server.mock_database("test_db1", {"env": "test"}),
+ "test_db2": server.mock_database("test_db2", {"env": "test"}),
+ "prod_db": server.mock_database("prod_db", {"env": "prod"})
+ }
+ data_fields = [
+ DataField(0, "name", AtomicType('INT'), 'desc name'),
+            DataField(1, "arr11", ArrayType(True, AtomicType('INT')), 'desc arr11'),
+ DataField(2, "map11", MapType(False, AtomicType('INT'),
+                                          MapType(False, AtomicType('INT'), AtomicType('INT'))),
+                          'desc arr11'),
+ ]
+            schema = TableSchema(TableSchema.CURRENT_VERSION, len(data_fields), data_fields, len(data_fields),
+ [], [], {}, "")
+ test_tables = {
+                "default.user": TableMetadata(uuid=str(uuid.uuid4()), is_external=True, schema=schema),
+ }
+ server.table_metadata_store.update(test_tables)
+ server.database_store.update(test_databases)
+ options = {
+ 'uri': f"http://localhost:{server.port}",
+ 'warehouse': 'test_warehouse',
+ 'dlf.region': 'cn-hangzhou',
+ "token.provider": "bear",
+ 'token': token
+ }
+            rest_catalog = RESTCatalog(CatalogContext.create_from_options(Options(options)))
+            self.assertSetEqual(set(rest_catalog.list_databases()), {*test_databases})
+            self.assertEqual(rest_catalog.get_database('default').name, test_databases.get('default').name)
+            table = rest_catalog.get_table(Identifier.from_string('default.user'))
+ self.assertEqual(table.identifier.get_full_name(), 'default.user')
+ finally:
+ # Shutdown server
+ server.shutdown()
+ print("Server stopped")
+
+ def test_write(self):
+ # Setup logging
+ logging.basicConfig(level=logging.INFO)
+
+        schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt'])
+ self.rest_catalog.create_table("default.test_table", schema, False)
+ table = self.rest_catalog.get_table("default.test_table")
+
+ data = {
+ 'user_id': [1, 2, 3, 4],
+ 'item_id': [1001, 1002, 1003, 1004],
+ 'behavior': ['a', 'b', 'c', None],
+ 'dt': ['p1', 'p1', 'p2', 'p1'],
+ 'long-dt': ['2024-10-10', '2024-10-10', '2024-10-10', '2024-01-01']
+ }
+ expect = pa.Table.from_pydict(data, schema=self.pa_schema)
+
+ write_builder = table.new_batch_write_builder()
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+ table_write.write_arrow(expect)
+ commit_messages = table_write.prepare_commit()
+ table_commit.commit(commit_messages)
+ table_write.close()
+ table_commit.close()
+
+        self.assertTrue(os.path.exists(self.warehouse + "/default/test_table/snapshot/LATEST"))
+        self.assertTrue(os.path.exists(self.warehouse + "/default/test_table/snapshot/snapshot-1"))
+        self.assertTrue(os.path.exists(self.warehouse + "/default/test_table/manifest"))
+        self.assertTrue(os.path.exists(self.warehouse + "/default/test_table/dt=p1"))
+        self.assertEqual(len(glob.glob(self.warehouse + "/default/test_table/manifest/*.avro")), 2)
+
+ def _write_test_table(self, table):
+ write_builder = table.new_batch_write_builder()
+
+ # first write
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+ data1 = {
+ 'user_id': [1, 2, 3, 4],
+ 'item_id': [1001, 1002, 1003, 1004],
+ 'behavior': ['a', 'b', 'c', None],
+ 'dt': ['p1', 'p1', 'p2', 'p1'],
+            'long-dt': ['2024-10-10', '2024-10-10', '2024-10-10', '2024-01-01'],
+ }
+ pa_table = pa.Table.from_pydict(data1, schema=self.pa_schema)
+ table_write.write_arrow(pa_table)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ # second write
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+ data2 = {
+ 'user_id': [5, 6, 7, 8],
+ 'item_id': [1005, 1006, 1007, 1008],
+ 'behavior': ['e', 'f', 'g', 'h'],
+ 'dt': ['p2', 'p1', 'p2', 'p2'],
+            'long-dt': ['2024-10-10', '2025-01-23', 'abcdefghijklmnopk', '2025-08-08'],
+ }
+ pa_table = pa.Table.from_pydict(data2, schema=self.pa_schema)
+ table_write.write_arrow(pa_table)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ def _read_test_table(self, read_builder):
+ table_read = read_builder.new_read()
+ splits = read_builder.new_scan().plan().splits()
+ return table_read.to_arrow(splits)
diff --git a/paimon-python/pypaimon/tests/rest_catalog_test.py b/paimon-python/pypaimon/tests/rest_table_read_write_test.py
similarity index 59%
rename from paimon-python/pypaimon/tests/rest_catalog_test.py
rename to paimon-python/pypaimon/tests/rest_table_read_write_test.py
index 9d3429a225..fd1cd65b07 100644
--- a/paimon-python/pypaimon/tests/rest_catalog_test.py
+++ b/paimon-python/pypaimon/tests/rest_table_read_write_test.py
@@ -15,179 +15,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
-import glob
-import logging
-import os
-import shutil
-import tempfile
-import unittest
-import uuid
import pandas as pd
import pyarrow as pa
-from pypaimon.api import ConfigResponse, Identifier
-from pypaimon.api.auth import BearTokenAuthProvider
-from pypaimon.api.options import Options
-from pypaimon.catalog.catalog_context import CatalogContext
-from pypaimon.catalog.catalog_factory import CatalogFactory
-from pypaimon.catalog.rest.rest_catalog import RESTCatalog
-from pypaimon.catalog.table_metadata import TableMetadata
-from pypaimon.schema.data_types import (ArrayType, AtomicType, DataField,
- MapType)
from pypaimon.schema.schema import Schema
-from pypaimon.schema.table_schema import TableSchema
-from pypaimon.tests.rest_server import RESTCatalogServer
-
-
-class RESTCatalogTest(unittest.TestCase):
- def setUp(self):
- self.temp_dir = tempfile.mkdtemp(prefix="unittest_")
- self.warehouse = os.path.join(self.temp_dir, 'warehouse')
-
- self.config = ConfigResponse(defaults={"prefix": "mock-test"})
- self.token = str(uuid.uuid4())
- self.server = RESTCatalogServer(
- data_path=self.temp_dir,
- auth_provider=BearTokenAuthProvider(self.token),
- config=self.config,
- warehouse="warehouse"
- )
- self.server.start()
- print(f"\nServer started at: {self.server.get_url()}")
-
- self.options = {
- 'metastore': 'rest',
- 'uri': f"http://localhost:{self.server.port}",
- 'warehouse': "warehouse",
- 'dlf.region': 'cn-hangzhou',
- "token.provider": "bear",
- 'token': self.token,
- }
- self.rest_catalog = CatalogFactory.create(self.options)
- self.rest_catalog.create_database("default", False)
-
- self.pa_schema = pa.schema([
- ('user_id', pa.int64()),
- ('item_id', pa.int64()),
- ('behavior', pa.string()),
- ('dt', pa.string()),
- ('long-dt', pa.string())
- ])
- self.raw_data = {
- 'user_id': [1, 2, 3, 4, 5, 6, 7, 8],
- 'item_id': [1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008],
- 'behavior': ['a', 'b', 'c', None, 'e', 'f', 'g', 'h'],
- 'dt': ['p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p2'],
-        'long-dt': ['2024-10-10', '2024-10-10', '2024-10-10', '2024-01-01', '2024-10-10', '2025-01-23',
- 'abcdefghijklmnopk', '2025-08-08']
- }
-        self.expected = pa.Table.from_pydict(self.raw_data, schema=self.pa_schema)
+from pypaimon.tests.rest_catalog_base_test import RESTCatalogBaseTest
- schema = Schema.from_pyarrow_schema(self.pa_schema)
-        self.rest_catalog.create_table('default.test_reader_iterator', schema, False)
-        self.table = self.rest_catalog.get_table('default.test_reader_iterator')
- write_builder = self.table.new_batch_write_builder()
- table_write = write_builder.new_write()
- table_commit = write_builder.new_commit()
- table_write.write_arrow(self.expected)
- table_commit.commit(table_write.prepare_commit())
- table_write.close()
- table_commit.close()
-
- def tearDown(self):
- # Shutdown server
- self.server.shutdown()
- print("Server stopped")
- shutil.rmtree(self.temp_dir, ignore_errors=True)
-
- def test_rest_catalog(self):
- """Example usage of RESTCatalogServer"""
- # Setup logging
- logging.basicConfig(level=logging.INFO)
-
- # Create config
- config = ConfigResponse(defaults={"prefix": "mock-test"})
- token = str(uuid.uuid4())
- # Create server
- server = RESTCatalogServer(
- data_path="/tmp/test_warehouse",
- auth_provider=BearTokenAuthProvider(token),
- config=config,
- warehouse="test_warehouse"
- )
- try:
- # Start server
- server.start()
- print(f"Server started at: {server.get_url()}")
- test_databases = {
- "default": server.mock_database("default", {"env": "test"}),
- "test_db1": server.mock_database("test_db1", {"env": "test"}),
- "test_db2": server.mock_database("test_db2", {"env": "test"}),
- "prod_db": server.mock_database("prod_db", {"env": "prod"})
- }
- data_fields = [
- DataField(0, "name", AtomicType('INT'), 'desc name'),
-            DataField(1, "arr11", ArrayType(True, AtomicType('INT')), 'desc arr11'),
- DataField(2, "map11", MapType(False, AtomicType('INT'),
-                                          MapType(False, AtomicType('INT'), AtomicType('INT'))),
- 'desc arr11'),
- ]
-            schema = TableSchema(TableSchema.CURRENT_VERSION, len(data_fields), data_fields, len(data_fields),
- [], [], {}, "")
- test_tables = {
-                "default.user": TableMetadata(uuid=str(uuid.uuid4()), is_external=True, schema=schema),
- }
- server.table_metadata_store.update(test_tables)
- server.database_store.update(test_databases)
- options = {
- 'uri': f"http://localhost:{server.port}",
- 'warehouse': 'test_warehouse',
- 'dlf.region': 'cn-hangzhou',
- "token.provider": "bear",
- 'token': token
- }
-            rest_catalog = RESTCatalog(CatalogContext.create_from_options(Options(options)))
-            self.assertSetEqual(set(rest_catalog.list_databases()), {*test_databases})
-            self.assertEqual(rest_catalog.get_database('default').name, test_databases.get('default').name)
-            table = rest_catalog.get_table(Identifier.from_string('default.user'))
- self.assertEqual(table.identifier.get_full_name(), 'default.user')
- finally:
- # Shutdown server
- server.shutdown()
- print("Server stopped")
-
- def test_write(self):
- # Setup logging
- logging.basicConfig(level=logging.INFO)
-
-        schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt'])
- self.rest_catalog.create_table("default.test_table", schema, False)
- table = self.rest_catalog.get_table("default.test_table")
- data = {
- 'user_id': [1, 2, 3, 4],
- 'item_id': [1001, 1002, 1003, 1004],
- 'behavior': ['a', 'b', 'c', None],
- 'dt': ['p1', 'p1', 'p2', 'p1'],
- 'long-dt': ['2024-10-10', '2024-10-10', '2024-10-10', '2024-01-01']
- }
- expect = pa.Table.from_pydict(data, schema=self.pa_schema)
-
- write_builder = table.new_batch_write_builder()
- table_write = write_builder.new_write()
- table_commit = write_builder.new_commit()
- table_write.write_arrow(expect)
- commit_messages = table_write.prepare_commit()
- table_commit.commit(commit_messages)
- table_write.close()
- table_commit.close()
-
-        self.assertTrue(os.path.exists(self.warehouse + "/default/test_table/snapshot/LATEST"))
-        self.assertTrue(os.path.exists(self.warehouse + "/default/test_table/snapshot/snapshot-1"))
-        self.assertTrue(os.path.exists(self.warehouse + "/default/test_table/manifest"))
-        self.assertTrue(os.path.exists(self.warehouse + "/default/test_table/dt=p1"))
-        self.assertEqual(len(glob.glob(self.warehouse + "/default/test_table/manifest/*.avro")), 2)
+class RESTTableReadWriteTest(RESTCatalogBaseTest):
def testParquetAppendOnlyReader(self):
        schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt'])
@@ -437,43 +273,3 @@ class RESTCatalogTest(unittest.TestCase):
actual = duckdb_con.query("SELECT * FROM duckdb_table").fetchdf()
expect = pd.DataFrame(self.raw_data)
        pd.testing.assert_frame_equal(actual.reset_index(drop=True), expect.reset_index(drop=True))
-
- def _write_test_table(self, table):
- write_builder = table.new_batch_write_builder()
-
- # first write
- table_write = write_builder.new_write()
- table_commit = write_builder.new_commit()
- data1 = {
- 'user_id': [1, 2, 3, 4],
- 'item_id': [1001, 1002, 1003, 1004],
- 'behavior': ['a', 'b', 'c', None],
- 'dt': ['p1', 'p1', 'p2', 'p1'],
-            'long-dt': ['2024-10-10', '2024-10-10', '2024-10-10', '2024-01-01'],
- }
- pa_table = pa.Table.from_pydict(data1, schema=self.pa_schema)
- table_write.write_arrow(pa_table)
- table_commit.commit(table_write.prepare_commit())
- table_write.close()
- table_commit.close()
-
- # second write
- table_write = write_builder.new_write()
- table_commit = write_builder.new_commit()
- data2 = {
- 'user_id': [5, 6, 7, 8],
- 'item_id': [1005, 1006, 1007, 1008],
- 'behavior': ['e', 'f', 'g', 'h'],
- 'dt': ['p2', 'p1', 'p2', 'p2'],
-            'long-dt': ['2024-10-10', '2025-01-23', 'abcdefghijklmnopk', '2025-08-08'],
- }
- pa_table = pa.Table.from_pydict(data2, schema=self.pa_schema)
- table_write.write_arrow(pa_table)
- table_commit.commit(table_write.prepare_commit())
- table_write.close()
- table_commit.close()
-
- def _read_test_table(self, read_builder):
- table_read = read_builder.new_read()
- splits = read_builder.new_scan().plan().splits()
- return table_read.to_arrow(splits)
diff --git a/paimon-python/pypaimon/tests/rest_table_test.py b/paimon-python/pypaimon/tests/rest_table_test.py
new file mode 100644
index 0000000000..a1feed8d50
--- /dev/null
+++ b/paimon-python/pypaimon/tests/rest_table_test.py
@@ -0,0 +1,147 @@
+"""
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import pyarrow as pa
+
+from pypaimon.schema.schema import Schema
+from pypaimon.tests.rest_catalog_base_test import RESTCatalogBaseTest
+from pypaimon.write.row_key_extractor import (DynamicBucketRowKeyExtractor,
+ FixedBucketRowKeyExtractor,
+ UnawareBucketRowKeyExtractor)
+
+
+class RESTTableTest(RESTCatalogBaseTest):
+ def setUp(self):
+ super().setUp()
+ self.pa_schema = pa.schema([
+ ('user_id', pa.int64()),
+ ('item_id', pa.int64()),
+ ('behavior', pa.string()),
+ ('dt', pa.string()),
+ ])
+ self.data = {
+ 'user_id': [2, 4, 6, 8, 10],
+ 'item_id': [1001, 1002, 1003, 1004, 1005],
+ 'behavior': ['a', 'b', 'c', 'd', 'e'],
+            'dt': ['2000-10-10', '2025-08-10', '2025-08-11', '2025-08-12', '2025-08-13']
+ }
+ self.expected = pa.Table.from_pydict(self.data, schema=self.pa_schema)
+
+ def test_with_shard_ao_unaware_bucket(self):
+        schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['user_id'])
+        self.rest_catalog.create_table('default.test_with_shard', schema, False)
+ table = self.rest_catalog.get_table('default.test_with_shard')
+
+ write_builder = table.new_batch_write_builder()
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+        self.assertIsInstance(table_write.row_key_extractor, UnawareBucketRowKeyExtractor)
+
+ pa_table = pa.Table.from_pydict(self.data, schema=self.pa_schema)
+ table_write.write_arrow(pa_table)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ splits = []
+ read_builder = table.new_read_builder()
+ splits.extend(read_builder.new_scan().with_shard(0, 3).plan().splits())
+ splits.extend(read_builder.new_scan().with_shard(1, 3).plan().splits())
+ splits.extend(read_builder.new_scan().with_shard(2, 3).plan().splits())
+
+ table_read = read_builder.new_read()
+ actual = table_read.to_arrow(splits)
+
+ self.assertEqual(actual.sort_by('user_id'), self.expected)
+
+ def test_with_shard_ao_fixed_bucket(self):
+        schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['user_id'],
+                                            options={'bucket': '5', 'bucket-key': 'item_id'})
+        self.rest_catalog.create_table('default.test_with_shard', schema, False)
+ table = self.rest_catalog.get_table('default.test_with_shard')
+
+ write_builder = table.new_batch_write_builder()
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+        self.assertIsInstance(table_write.row_key_extractor, FixedBucketRowKeyExtractor)
+
+ pa_table = pa.Table.from_pydict(self.data, schema=self.pa_schema)
+ table_write.write_arrow(pa_table)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ splits = []
+ read_builder = table.new_read_builder()
+ splits.extend(read_builder.new_scan().with_shard(0, 3).plan().splits())
+ splits.extend(read_builder.new_scan().with_shard(1, 3).plan().splits())
+ splits.extend(read_builder.new_scan().with_shard(2, 3).plan().splits())
+
+ table_read = read_builder.new_read()
+ actual = table_read.to_arrow(splits)
+ self.assertEqual(actual.sort_by("user_id"), self.expected)
+
+ def test_with_shard_pk_dynamic_bucket(self):
+        schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['user_id'], primary_keys=['user_id', 'dt'])
+        self.rest_catalog.create_table('default.test_with_shard', schema, False)
+ table = self.rest_catalog.get_table('default.test_with_shard')
+
+ write_builder = table.new_batch_write_builder()
+ table_write = write_builder.new_write()
+        self.assertIsInstance(table_write.row_key_extractor, DynamicBucketRowKeyExtractor)
+
+ pa_table = pa.Table.from_pydict(self.data, schema=self.pa_schema)
+
+ with self.assertRaises(ValueError) as context:
+ table_write.write_arrow(pa_table)
+
+        self.assertEqual(str(context.exception), "Can't extract bucket from row in dynamic bucket mode")
+
+ def test_with_shard_pk_fixed_bucket(self):
+        schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['user_id'], primary_keys=['user_id', 'dt'],
+                                            options={'bucket': '5'})
+        self.rest_catalog.create_table('default.test_with_shard', schema, False)
+ table = self.rest_catalog.get_table('default.test_with_shard')
+
+ write_builder = table.new_batch_write_builder()
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+        self.assertIsInstance(table_write.row_key_extractor, FixedBucketRowKeyExtractor)
+
+ pa_table = pa.Table.from_pydict(self.data, schema=self.pa_schema)
+ table_write.write_arrow(pa_table)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ splits = []
+ read_builder = table.new_read_builder()
+ splits.extend(read_builder.new_scan().with_shard(0, 3).plan().splits())
+ splits.extend(read_builder.new_scan().with_shard(1, 3).plan().splits())
+ splits.extend(read_builder.new_scan().with_shard(2, 3).plan().splits())
+
+ table_read = read_builder.new_read()
+ actual = table_read.to_arrow(splits)
+ data_expected = {
+ 'user_id': [4, 6, 2, 10, 8],
+ 'item_id': [1002, 1003, 1001, 1005, 1004],
+ 'behavior': ['b', 'c', 'a', 'e', 'd'],
+            'dt': ['2025-08-10', '2025-08-11', '2000-10-10', '2025-08-13', '2025-08-12']
+ }
+ expected = pa.Table.from_pydict(data_expected, schema=self.pa_schema)
+ self.assertEqual(actual, expected)
diff --git a/paimon-python/pypaimon/write/row_key_extractor.py b/paimon-python/pypaimon/write/row_key_extractor.py
index 8d847f30b5..801cc5fff4 100644
--- a/paimon-python/pypaimon/write/row_key_extractor.py
+++ b/paimon-python/pypaimon/write/row_key_extractor.py
@@ -15,7 +15,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
-
+import hashlib
+import json
from abc import ABC, abstractmethod
from typing import List, Tuple
@@ -84,9 +85,14 @@ class FixedBucketRowKeyExtractor(RowKeyExtractor):
hashes = []
for row_idx in range(data.num_rows):
row_values = tuple(col[row_idx].as_py() for col in columns)
- hashes.append(hash(row_values))
+ hashes.append(self.hash(row_values))
return [abs(hash_val) % self.num_buckets for hash_val in hashes]
+ @staticmethod
+ def hash(data) -> int:
+ data_json = json.dumps(data)
+ return int(hashlib.md5(data_json.encode()).hexdigest(), 16)
+
class UnawareBucketRowKeyExtractor(RowKeyExtractor):
"""Extractor for unaware bucket mode (bucket = -1, no primary keys)."""
@@ -100,3 +106,22 @@ class UnawareBucketRowKeyExtractor(RowKeyExtractor):
def _extract_buckets_batch(self, data: pa.RecordBatch) -> List[int]:
return [0] * data.num_rows
+
+
+class DynamicBucketRowKeyExtractor(RowKeyExtractor):
+ """
+ Row key extractor for dynamic bucket mode
+ Ensures bucket configuration is set to -1 and prevents bucket extraction
+ """
+
+ def __init__(self, table_schema: 'TableSchema'):
+ super().__init__(table_schema)
+ num_buckets = table_schema.options.get(CoreOptions.BUCKET, -1)
+
+ if num_buckets != -1:
+ raise ValueError(
+                f"Only 'bucket' = '-1' is allowed for 'DynamicBucketRowKeyExtractor', but found: {num_buckets}"
+ )
+
+ def _extract_buckets_batch(self, data: pa.RecordBatch) -> int:
+        raise ValueError("Can't extract bucket from row in dynamic bucket mode")