This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new aa663c1e8e [Python] Support schema evolution read for changing column position (#6458)
aa663c1e8e is described below
commit aa663c1e8ec9bb307e78a29b3d630ceb408a52f0
Author: umi <[email protected]>
AuthorDate: Thu Oct 23 14:53:19 2025 +0800
[Python] Support schema evolution read for changing column position (#6458)
---
paimon-python/pypaimon/common/predicate.py | 21 +-
.../pypaimon/manifest/manifest_file_manager.py | 42 ++-
.../pypaimon/manifest/simple_stats_evolutions.py | 5 +-
.../pypaimon/read/scanner/full_starting_scanner.py | 9 +-
paimon-python/pypaimon/read/split_read.py | 53 +++-
paimon-python/pypaimon/read/table_read.py | 36 ++-
paimon-python/pypaimon/tests/pvfs_test.py | 3 +-
.../pypaimon/tests/reader_append_only_test.py | 2 +-
paimon-python/pypaimon/tests/rest/rest_server.py | 7 +-
.../pypaimon/tests/schema_evolution_read_test.py | 328 +++++++++++++++++++++
10 files changed, 454 insertions(+), 52 deletions(-)
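
For orientation: this change lets the read path resolve columns by name across schema versions, so a data file written before a column was added, or with columns in a different order, can still be read through the current schema. The new TableRead._pad_batch_to_schema helper (see the table_read.py hunk below) reorders each Arrow batch to the reader's schema and fills missing columns with nulls. A standalone pyarrow sketch of that padding idea, not taken from the patch and with illustrative names and data:

import pyarrow as pa

def pad_batch_to_schema(batch: pa.RecordBatch, target_schema: pa.Schema) -> pa.RecordBatch:
    # Resolve columns by name: reuse a column when the batch already has it,
    # otherwise fill it with nulls of the target type (e.g. a column added later).
    columns = []
    for field in target_schema:
        if field.name in batch.column_names:
            columns.append(batch.column(field.name))
        else:
            columns.append(pa.nulls(batch.num_rows, type=field.type))
    return pa.RecordBatch.from_arrays(columns, schema=target_schema)

# A batch written with an older layout (no 'behavior' column, 'dt' last):
old_batch = pa.RecordBatch.from_pydict(
    {'user_id': [1, 2], 'item_id': [1001, 1002], 'dt': ['p1', 'p1']})
# The reading schema inserts 'behavior' before 'dt':
read_schema = pa.schema([('user_id', pa.int64()), ('item_id', pa.int64()),
                         ('behavior', pa.string()), ('dt', pa.string())])
print(pad_batch_to_schema(old_batch, read_schema))
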
diff --git a/paimon-python/pypaimon/common/predicate.py b/paimon-python/pypaimon/common/predicate.py
index 9ae2cdfce3..89c82c9de2 100644
--- a/paimon-python/pypaimon/common/predicate.py
+++ b/paimon-python/pypaimon/common/predicate.py
@@ -94,10 +94,10 @@ class Predicate:
def to_arrow(self) -> Any:
if self.method == 'and':
- return reduce(lambda x, y: x & y,
+ return reduce(lambda x, y: (x[0] & y[0], x[1] | y[1]),
[p.to_arrow() for p in self.literals])
if self.method == 'or':
- return reduce(lambda x, y: x | y,
+ return reduce(lambda x, y: (x[0] | y[0], x[1] | y[1]),
[p.to_arrow() for p in self.literals])
if self.method == 'startsWith':
@@ -108,10 +108,11 @@ class Predicate:
# Ensure the field is cast to string type
string_field = field_ref.cast(pyarrow.string())
result = pyarrow_compute.starts_with(string_field, pattern)
- return result
+ return result, {self.field}
except Exception:
# Fallback to True
-            return pyarrow_dataset.field(self.field).is_valid() | pyarrow_dataset.field(self.field).is_null()
+            return (pyarrow_dataset.field(self.field).is_valid() | pyarrow_dataset.field(self.field).is_null(),
+ {self.field})
if self.method == 'endsWith':
pattern = self.literals[0]
# For PyArrow compatibility
@@ -120,10 +121,11 @@ class Predicate:
# Ensure the field is cast to string type
string_field = field_ref.cast(pyarrow.string())
result = pyarrow_compute.ends_with(string_field, pattern)
- return result
+ return result, {self.field}
except Exception:
# Fallback to True
-            return pyarrow_dataset.field(self.field).is_valid() | pyarrow_dataset.field(self.field).is_null()
+            return (pyarrow_dataset.field(self.field).is_valid() | pyarrow_dataset.field(self.field).is_null(),
+ {self.field})
if self.method == 'contains':
pattern = self.literals[0]
# For PyArrow compatibility
@@ -132,15 +134,16 @@ class Predicate:
# Ensure the field is cast to string type
string_field = field_ref.cast(pyarrow.string())
result = pyarrow_compute.match_substring(string_field, pattern)
- return result
+ return result, {self.field}
except Exception:
# Fallback to True
-            return pyarrow_dataset.field(self.field).is_valid() | pyarrow_dataset.field(self.field).is_null()
+            return (pyarrow_dataset.field(self.field).is_valid() | pyarrow_dataset.field(self.field).is_null(),
+ {self.field})
field = pyarrow_dataset.field(self.field)
tester = Predicate.testers.get(self.method)
if tester:
- return tester.test_by_arrow(field, self.literals)
+ return tester.test_by_arrow(field, self.literals), {self.field}
raise ValueError("Unsupported predicate method:
{}".format(self.method))
diff --git a/paimon-python/pypaimon/manifest/manifest_file_manager.py b/paimon-python/pypaimon/manifest/manifest_file_manager.py
index b635f9e49c..927dad4674 100644
--- a/paimon-python/pypaimon/manifest/manifest_file_manager.py
+++ b/paimon-python/pypaimon/manifest/manifest_file_manager.py
@@ -26,6 +26,7 @@ from pypaimon.manifest.schema.manifest_entry import (MANIFEST_ENTRY_SCHEMA,
ManifestEntry)
from pypaimon.manifest.schema.manifest_file_meta import ManifestFileMeta
from pypaimon.manifest.schema.simple_stats import SimpleStats
+from pypaimon.schema.table_schema import TableSchema
from pypaimon.table.row.generic_row import (GenericRowDeserializer,
GenericRowSerializer)
from pypaimon.table.row.binary_row import BinaryRow
@@ -43,6 +44,7 @@ class ManifestFileManager:
self.partition_keys_fields = self.table.partition_keys_fields
self.primary_keys_fields = self.table.primary_keys_fields
self.trimmed_primary_keys_fields = self.table.trimmed_primary_keys_fields
+ self.schema_cache = {}
def read_entries_parallel(self, manifest_files: List[ManifestFileMeta], manifest_entry_filter=None,
drop_stats=True, max_workers=8) -> List[ManifestEntry]:
@@ -86,17 +88,9 @@ class ManifestFileManager:
null_counts=key_dict['_NULL_COUNTS'],
)
+ schema_fields = self._get_schema(file_dict['_SCHEMA_ID']).fields
+ fields = self._get_value_stats_fields(file_dict, schema_fields)
value_dict = dict(file_dict['_VALUE_STATS'])
- if file_dict['_VALUE_STATS_COLS'] is None:
- if file_dict['_WRITE_COLS'] is None:
- fields = self.table.table_schema.fields
- else:
- read_fields = file_dict['_WRITE_COLS']
-                    fields = [self.table.field_dict[col] for col in read_fields]
- elif not file_dict['_VALUE_STATS_COLS']:
- fields = []
- else:
-                fields = [self.table.field_dict[col] for col in file_dict['_VALUE_STATS_COLS']]
value_stats = SimpleStats(
min_values=BinaryRow(value_dict['_MIN_VALUES'], fields),
max_values=BinaryRow(value_dict['_MAX_VALUES'], fields),
@@ -121,8 +115,8 @@ class ManifestFileManager:
file_source=file_dict['_FILE_SOURCE'],
value_stats_cols=file_dict.get('_VALUE_STATS_COLS'),
external_path=file_dict.get('_EXTERNAL_PATH'),
- first_row_id=file_dict['_FIRST_ROW_ID'],
- write_cols=file_dict['_WRITE_COLS'],
+                first_row_id=file_dict['_FIRST_ROW_ID'] if '_FIRST_ROW_ID' in file_dict else None,
+                write_cols=file_dict['_WRITE_COLS'] if '_WRITE_COLS' in file_dict else None,
)
entry = ManifestEntry(
kind=record['_KIND'],
@@ -138,6 +132,30 @@ class ManifestFileManager:
entries.append(entry)
return entries
+    def _get_value_stats_fields(self, file_dict: dict, schema_fields: list) -> List:
+ if file_dict['_VALUE_STATS_COLS'] is None:
+ if '_WRITE_COLS' in file_dict:
+ if file_dict['_WRITE_COLS'] is None:
+ fields = schema_fields
+ else:
+ read_fields = file_dict['_WRITE_COLS']
+                    fields = [self.table.field_dict[col] for col in read_fields]
+ else:
+ fields = schema_fields
+ elif not file_dict['_VALUE_STATS_COLS']:
+ fields = []
+ else:
+            fields = [self.table.field_dict[col] for col in file_dict['_VALUE_STATS_COLS']]
+ return fields
+
+ def _get_schema(self, schema_id: int) -> TableSchema:
+ if schema_id not in self.schema_cache:
+ schema = self.table.schema_manager.read_schema(schema_id)
+ if schema is None:
+ raise ValueError(f"Schema {schema_id} not found")
+ self.schema_cache[schema_id] = schema
+ return self.schema_cache[schema_id]
+
def write(self, file_name, entries: List[ManifestEntry]):
avro_records = []
for entry in entries:
diff --git a/paimon-python/pypaimon/manifest/simple_stats_evolutions.py b/paimon-python/pypaimon/manifest/simple_stats_evolutions.py
index 0b99acab21..df417d595b 100644
--- a/paimon-python/pypaimon/manifest/simple_stats_evolutions.py
+++ b/paimon-python/pypaimon/manifest/simple_stats_evolutions.py
@@ -28,8 +28,7 @@ class SimpleStatsEvolutions:
def __init__(self, schema_fields: Callable[[int], List[DataField]], table_schema_id: int):
self.schema_fields = schema_fields
self.table_schema_id = table_schema_id
- self.table_data_fields = schema_fields(table_schema_id)
- self.table_fields = None
+ self.table_fields = schema_fields(table_schema_id)
self.evolutions: Dict[int, SimpleStatsEvolution] = {}
def get_or_create(self, data_schema_id: int) -> SimpleStatsEvolution:
@@ -40,8 +39,6 @@ class SimpleStatsEvolutions:
if self.table_schema_id == data_schema_id:
evolution = SimpleStatsEvolution(self.schema_fields(data_schema_id), None, None)
else:
- if self.table_fields is None:
- self.table_fields = self.table_data_fields
data_fields = self.schema_fields(data_schema_id)
index_cast_mapping = self._create_index_cast_mapping(self.table_fields, data_fields)
diff --git a/paimon-python/pypaimon/read/scanner/full_starting_scanner.py b/paimon-python/pypaimon/read/scanner/full_starting_scanner.py
index cacf3ce343..44223b761a 100644
--- a/paimon-python/pypaimon/read/scanner/full_starting_scanner.py
+++ b/paimon-python/pypaimon/read/scanner/full_starting_scanner.py
@@ -64,13 +64,12 @@ class FullStartingScanner(StartingScanner):
self.table.options.get('bucket', -1)) == BucketMode.POSTPONE_BUCKET.value else False
self.data_evolution = self.table.options.get(CoreOptions.DATA_EVOLUTION_ENABLED, 'false').lower() == 'true'
- self._schema_cache = {}
-
def schema_fields_func(schema_id: int):
- if schema_id not in self._schema_cache:
+ if schema_id not in self.manifest_file_manager.schema_cache:
schema = self.table.schema_manager.read_schema(schema_id)
- self._schema_cache[schema_id] = schema
-            return self._schema_cache[schema_id].fields if self._schema_cache[schema_id] else []
+ self.manifest_file_manager.schema_cache[schema_id] = schema
+            return self.manifest_file_manager.schema_cache[schema_id].fields if self.manifest_file_manager.schema_cache[schema_id] else []
self.simple_stats_evolutions = SimpleStatsEvolutions(
schema_fields_func,
diff --git a/paimon-python/pypaimon/read/split_read.py b/paimon-python/pypaimon/read/split_read.py
index 000e272e39..5c75cf6506 100644
--- a/paimon-python/pypaimon/read/split_read.py
+++ b/paimon-python/pypaimon/read/split_read.py
@@ -19,7 +19,7 @@
import os
from abc import ABC, abstractmethod
from functools import partial
-from typing import List, Optional, Tuple, Any
+from typing import List, Optional, Tuple, Any, Dict
from pypaimon.common.core_options import CoreOptions
from pypaimon.common.predicate import Predicate
@@ -46,6 +46,7 @@ from pypaimon.read.reader.key_value_wrap_reader import KeyValueWrapReader
from pypaimon.read.reader.sort_merge_reader import SortMergeReaderWithMinHeap
from pypaimon.read.split import Split
from pypaimon.schema.data_types import AtomicType, DataField
+from pypaimon.schema.table_schema import TableSchema
KEY_PREFIX = "_KEY_"
KEY_FIELD_ID_START = 1000000
@@ -55,12 +56,14 @@ NULL_FIELD_INDEX = -1
class SplitRead(ABC):
"""Abstract base class for split reading operations."""
-    def __init__(self, table, predicate: Optional[Predicate], read_type: List[DataField], split: Split):
+    def __init__(self, table, predicate: Optional[Predicate], read_type: List[DataField], split: Split,
+ schema_fields_cache: Dict):
from pypaimon.table.file_store_table import FileStoreTable
self.table: FileStoreTable = table
self.predicate = predicate
- self.push_down_predicate = self._push_down_predicate()
+ predicate_tuple = self._push_down_predicate()
+        self.push_down_predicate, self.predicate_fields = predicate_tuple if predicate_tuple else (None, None)
self.split = split
self.value_arity = len(read_type)
@@ -68,6 +71,7 @@ class SplitRead(ABC):
self.read_fields = read_type
if isinstance(self, MergeFileSplitRead):
self.read_fields = self._create_key_value_fields(read_type)
+ self.schema_fields_cache = schema_fields_cache
def _push_down_predicate(self) -> Any:
if self.predicate is None:
@@ -84,21 +88,26 @@ class SplitRead(ABC):
def create_reader(self) -> RecordReader:
"""Create a record reader for the given split."""
-    def file_reader_supplier(self, file_path: str, for_merge_read: bool, read_fields: List[str]):
+    def file_reader_supplier(self, file: DataFileMeta, for_merge_read: bool, read_fields: List[str]):
+        read_file_fields, file_filter = self._get_schema(file.schema_id, read_fields)
+ if not file_filter:
+ return None
+
+ file_path = file.file_path
_, extension = os.path.splitext(file_path)
file_format = extension[1:]
format_reader: RecordBatchReader
if file_format == CoreOptions.FILE_FORMAT_AVRO:
-            format_reader = FormatAvroReader(self.table.file_io, file_path, read_fields,
+            format_reader = FormatAvroReader(self.table.file_io, file_path, read_file_fields,
self.read_fields,
self.push_down_predicate)
elif file_format == CoreOptions.FILE_FORMAT_BLOB:
blob_as_descriptor = CoreOptions.get_blob_as_descriptor(self.table.options)
-            format_reader = FormatBlobReader(self.table.file_io, file_path, read_fields,
+            format_reader = FormatBlobReader(self.table.file_io, file_path, read_file_fields,
self.read_fields,
self.push_down_predicate, blob_as_descriptor)
elif file_format == CoreOptions.FILE_FORMAT_PARQUET or file_format == CoreOptions.FILE_FORMAT_ORC:
format_reader = FormatPyArrowReader(self.table.file_io, file_format, file_path,
-                                                read_fields, self.push_down_predicate)
+                                                read_file_fields, self.push_down_predicate)
else:
raise ValueError(f"Unexpected file format: {file_format}")
@@ -111,6 +120,24 @@ class SplitRead(ABC):
return DataFileBatchReader(format_reader, index_mapping,
partition_info, None,
self.table.table_schema.fields)
+ def _get_schema(self, schema_id: int, read_fields) -> TableSchema:
+ if schema_id not in self.schema_fields_cache[0]:
+ schema = self.table.schema_manager.read_schema(schema_id)
+ if schema is None:
+ raise ValueError(f"Schema {schema_id} not found")
+ self.schema_fields_cache[0][schema_id] = schema
+ schema = self.schema_fields_cache[0][schema_id]
+ fields_key = (schema_id, tuple(read_fields))
+ if fields_key not in self.schema_fields_cache[1]:
+ schema_field_names = set(field.name for field in schema.fields)
+ if self.table.is_primary_key_table:
+ schema_field_names.add('_SEQUENCE_NUMBER')
+ schema_field_names.add('_VALUE_KIND')
+ self.schema_fields_cache[1][fields_key] = (
+                [read_field for read_field in read_fields if read_field in schema_field_names],
+                False if self.predicate_fields and self.predicate_fields - schema_field_names else True)
+ return self.schema_fields_cache[1][fields_key]
+
@abstractmethod
def _get_all_data_fields(self):
"""Get all data fields"""
@@ -263,10 +290,10 @@ class RawFileSplitRead(SplitRead):
def create_reader(self) -> RecordReader:
data_readers = []
- for file_path in self.split.file_paths:
+ for file in self.split.files:
supplier = partial(
self.file_reader_supplier,
- file_path=file_path,
+ file=file,
for_merge_read=False,
read_fields=self._get_final_read_data_fields(),
)
@@ -289,10 +316,10 @@ class RawFileSplitRead(SplitRead):
class MergeFileSplitRead(SplitRead):
- def kv_reader_supplier(self, file_path):
+ def kv_reader_supplier(self, file):
reader_supplier = partial(
self.file_reader_supplier,
- file_path=file_path,
+ file=file,
for_merge_read=True,
read_fields=self._get_final_read_data_fields()
)
@@ -303,7 +330,7 @@ class MergeFileSplitRead(SplitRead):
for sorter_run in section:
data_readers = []
for file in sorter_run.files:
- supplier = partial(self.kv_reader_supplier, file.file_path)
+ supplier = partial(self.kv_reader_supplier, file)
data_readers.append(supplier)
readers.append(ConcatRecordReader(data_readers))
return SortMergeReaderWithMinHeap(readers, self.table.table_schema)
@@ -468,7 +495,7 @@ class DataEvolutionSplitRead(SplitRead):
def _create_file_reader(self, file: DataFileMeta, read_fields: [str]) -> RecordReader:
"""Create a file reader for a single file."""
-        return self.file_reader_supplier(file_path=file.file_path, for_merge_read=False, read_fields=read_fields)
+        return self.file_reader_supplier(file=file, for_merge_read=False, read_fields=read_fields)
def _split_field_bunches(self, need_merge_files: List[DataFileMeta]) -> List[FieldBunch]:
"""Split files into field bunches."""
diff --git a/paimon-python/pypaimon/read/table_read.py b/paimon-python/pypaimon/read/table_read.py
index 31545e4ea4..6cd544e745 100644
--- a/paimon-python/pypaimon/read/table_read.py
+++ b/paimon-python/pypaimon/read/table_read.py
@@ -39,6 +39,7 @@ class TableRead:
self.table: FileStoreTable = table
self.predicate = predicate
self.read_type = read_type
+ self.schema_fields_cache = ({}, {})
def to_iterator(self, splits: List[Split]) -> Iterator:
def _record_generator():
@@ -57,10 +58,32 @@ class TableRead:
batch_iterator = self._arrow_batch_generator(splits, schema)
return pyarrow.ipc.RecordBatchReader.from_batches(schema,
batch_iterator)
+ def _pad_batch_to_schema(self, batch: pyarrow.RecordBatch, target_schema):
+ columns = []
+ num_rows = batch.num_rows
+
+ for field in target_schema:
+ if field.name in batch.column_names:
+ col = batch.column(field.name)
+ else:
+ col = pyarrow.nulls(num_rows, type=field.type)
+ columns.append(col)
+
+ return pyarrow.RecordBatch.from_arrays(columns, schema=target_schema)
+
def to_arrow(self, splits: List[Split]) -> Optional[pyarrow.Table]:
batch_reader = self.to_arrow_batch_reader(splits)
- arrow_table = batch_reader.read_all()
- return arrow_table
+
+ schema = PyarrowFieldParser.from_paimon_schema(self.read_type)
+ table_list = []
+ for batch in iter(batch_reader.read_next_batch, None):
+ table_list.append(batch) if schema == batch.schema \
+                else table_list.append(self._pad_batch_to_schema(batch, schema))
+
+ if not table_list:
+            return pyarrow.Table.from_arrays([pyarrow.array([], type=field.type) for field in schema], schema=schema)
+ else:
+ return pyarrow.Table.from_batches(table_list)
def _arrow_batch_generator(self, splits: List[Split], schema: pyarrow.Schema) -> Iterator[pyarrow.RecordBatch]:
chunk_size = 65536
@@ -112,21 +135,24 @@ class TableRead:
table=self.table,
predicate=self.predicate,
read_type=self.read_type,
- split=split
+ split=split,
+ schema_fields_cache=self.schema_fields_cache
)
elif self.table.options.get(CoreOptions.DATA_EVOLUTION_ENABLED, 'false').lower() == 'true':
return DataEvolutionSplitRead(
table=self.table,
predicate=self.predicate,
read_type=self.read_type,
- split=split
+ split=split,
+ schema_fields_cache=self.schema_fields_cache
)
else:
return RawFileSplitRead(
table=self.table,
predicate=self.predicate,
read_type=self.read_type,
- split=split
+ split=split,
+ schema_fields_cache=self.schema_fields_cache
)
@staticmethod
diff --git a/paimon-python/pypaimon/tests/pvfs_test.py b/paimon-python/pypaimon/tests/pvfs_test.py
index 29ef979f9e..7bebceb96e 100644
--- a/paimon-python/pypaimon/tests/pvfs_test.py
+++ b/paimon-python/pypaimon/tests/pvfs_test.py
@@ -151,7 +151,8 @@ class PVFSTest(unittest.TestCase):
self.assertEqual(table_virtual_path,
self.pvfs.info(table_virtual_path).get('name'))
self.assertEqual(True, self.pvfs.exists(database_virtual_path))
user_dirs = self.pvfs.ls(f"pvfs://{self.catalog}/{self.database}/{self.table}", detail=False)
-        self.assertSetEqual(set(user_dirs), {f'pvfs://{self.catalog}/{self.database}/{self.table}/{data_file_name}'})
+        self.assertSetEqual(set(user_dirs), {f'pvfs://{self.catalog}/{self.database}/{self.table}/{data_file_name}',
+                                             f'pvfs://{self.catalog}/{self.database}/{self.table}/schema'})
data_file_name = 'data.txt'
data_file_path = f'pvfs://{self.catalog}/{self.database}/{self.table}/{data_file_name}'
diff --git a/paimon-python/pypaimon/tests/reader_append_only_test.py b/paimon-python/pypaimon/tests/reader_append_only_test.py
index 0367ab409c..3a99196854 100644
--- a/paimon-python/pypaimon/tests/reader_append_only_test.py
+++ b/paimon-python/pypaimon/tests/reader_append_only_test.py
@@ -38,7 +38,7 @@ class AoReaderTest(unittest.TestCase):
cls.catalog = CatalogFactory.create({
'warehouse': cls.warehouse
})
- cls.catalog.create_database('default', False)
+ cls.catalog.create_database('default', True)
cls.pa_schema = pa.schema([
('user_id', pa.int32()),
diff --git a/paimon-python/pypaimon/tests/rest/rest_server.py b/paimon-python/pypaimon/tests/rest/rest_server.py
index 8f7cf23944..d2908e59be 100644
--- a/paimon-python/pypaimon/tests/rest/rest_server.py
+++ b/paimon-python/pypaimon/tests/rest/rest_server.py
@@ -428,12 +428,15 @@ class RESTCatalogServer:
if create_table.identifier.get_full_name() in self.table_metadata_store:
raise TableAlreadyExistException(create_table.identifier)
table_metadata = self._create_table_metadata(
-                    create_table.identifier, 1, create_table.schema, str(uuid.uuid4()), False
+                    create_table.identifier, 0, create_table.schema, str(uuid.uuid4()), False
)
self.table_metadata_store.update({create_table.identifier.get_full_name(): table_metadata})
-                table_dir = Path(self.data_path) / self.warehouse / database_name / create_table.identifier.object_name
+ table_dir = Path(
+                    self.data_path) / self.warehouse / database_name / create_table.identifier.object_name / 'schema'
if not table_dir.exists():
table_dir.mkdir(parents=True)
+ with open(table_dir / "schema-0", "w") as f:
+ f.write(JSON.to_json(table_metadata.schema, indent=2))
return self._mock_response("", 200)
return self._mock_response(ErrorResponse(None, None, "Method Not Allowed", 405), 405)
diff --git a/paimon-python/pypaimon/tests/schema_evolution_read_test.py b/paimon-python/pypaimon/tests/schema_evolution_read_test.py
new file mode 100644
index 0000000000..dde1f2c15f
--- /dev/null
+++ b/paimon-python/pypaimon/tests/schema_evolution_read_test.py
@@ -0,0 +1,328 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+import os
+import shutil
+import tempfile
+import unittest
+
+import pyarrow as pa
+
+from pypaimon import CatalogFactory, Schema
+
+from pypaimon.schema.schema_manager import SchemaManager
+from pypaimon.schema.table_schema import TableSchema
+
+
+class SchemaEvolutionReadTest(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ cls.tempdir = tempfile.mkdtemp()
+ cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
+ cls.catalog = CatalogFactory.create({
+ 'warehouse': cls.warehouse
+ })
+ cls.catalog.create_database('default', False)
+
+ cls.pa_schema = pa.schema([
+ ('user_id', pa.int64()),
+ ('item_id', pa.int64()),
+ ('behavior', pa.string()),
+ ('dt', pa.string())
+ ])
+ cls.raw_data = {
+ 'user_id': [1, 2, 3, 4, 5],
+ 'item_id': [1001, 1002, 1003, 1004, 1005],
+ 'behavior': ['a', 'b', 'c', None, 'e'],
+ 'dt': ['p1', 'p1', 'p1', 'p1', 'p2'],
+ }
+ cls.expected = pa.Table.from_pydict(cls.raw_data, schema=cls.pa_schema)
+
+ @classmethod
+ def tearDownClass(cls):
+ shutil.rmtree(cls.tempdir, ignore_errors=True)
+
+ def test_schema_evolution(self):
+ # schema 0
+ pa_schema = pa.schema([
+ ('user_id', pa.int64()),
+ ('item_id', pa.int64()),
+ ('dt', pa.string())
+ ])
+ schema = Schema.from_pyarrow_schema(pa_schema, partition_keys=['dt'])
+ self.catalog.create_table('default.test_sample', schema, False)
+ table1 = self.catalog.get_table('default.test_sample')
+ write_builder = table1.new_batch_write_builder()
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+ data1 = {
+ 'user_id': [1, 2, 3, 4],
+ 'item_id': [1001, 1002, 1003, 1004],
+ 'dt': ['p1', 'p1', 'p2', 'p1'],
+ }
+ pa_table = pa.Table.from_pydict(data1, schema=pa_schema)
+ table_write.write_arrow(pa_table)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ # schema 1 add behavior column
+ pa_schema = pa.schema([
+ ('user_id', pa.int64()),
+ ('item_id', pa.int64()),
+ ('dt', pa.string()),
+ ('behavior', pa.string())
+ ])
+ schema2 = Schema.from_pyarrow_schema(pa_schema, partition_keys=['dt'])
+        self.catalog.create_table('default.test_schema_evolution', schema2, False)
+ table2 = self.catalog.get_table('default.test_schema_evolution')
+ table2.table_schema.id = 1
+ write_builder = table2.new_batch_write_builder()
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+ data2 = {
+ 'user_id': [5, 6, 7, 8],
+ 'item_id': [1005, 1006, 1007, 1008],
+ 'dt': ['p2', 'p1', 'p2', 'p2'],
+ 'behavior': ['e', 'f', 'g', 'h'],
+ }
+ pa_table = pa.Table.from_pydict(data2, schema=pa_schema)
+ table_write.write_arrow(pa_table)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ # write schema-0 and schema-1 to table2
+ schema_manager = SchemaManager(table2.file_io, table2.table_path)
+        schema_manager.commit(TableSchema.from_schema(schema_id=0, schema=schema))
+        schema_manager.commit(TableSchema.from_schema(schema_id=1, schema=schema2))
+
+ splits = self._scan_table(table1.new_read_builder())
+ read_builder = table2.new_read_builder()
+ splits2 = self._scan_table(read_builder)
+ splits.extend(splits2)
+
+ table_read = read_builder.new_read()
+ actual = table_read.to_arrow(splits)
+ expected = pa.Table.from_pydict({
+ 'user_id': [1, 2, 4, 3, 5, 7, 8, 6],
+ 'item_id': [1001, 1002, 1004, 1003, 1005, 1007, 1008, 1006],
+ 'dt': ["p1", "p1", "p1", "p2", "p2", "p2", "p2", "p1"],
+ 'behavior': [None, None, None, None, "e", "g", "h", "f"],
+ }, schema=pa_schema)
+ self.assertEqual(expected, actual)
+
+ def test_schema_evolution_with_read_filter(self):
+ # schema 0
+ pa_schema = pa.schema([
+ ('user_id', pa.int64()),
+ ('item_id', pa.int64()),
+ ('dt', pa.string())
+ ])
+ schema = Schema.from_pyarrow_schema(pa_schema, partition_keys=['dt'])
+        self.catalog.create_table('default.test_schema_evolution_with_filter', schema, False)
+        table1 = self.catalog.get_table('default.test_schema_evolution_with_filter')
+ write_builder = table1.new_batch_write_builder()
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+ data1 = {
+ 'user_id': [1, 2, 3, 4],
+ 'item_id': [1001, 1002, 1003, 1004],
+ 'dt': ['p1', 'p1', 'p2', 'p1'],
+ }
+ pa_table = pa.Table.from_pydict(data1, schema=pa_schema)
+ table_write.write_arrow(pa_table)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ # schema 1 add behavior column
+ pa_schema = pa.schema([
+ ('user_id', pa.int64()),
+ ('item_id', pa.int64()),
+ ('dt', pa.string()),
+ ('behavior', pa.string())
+ ])
+ schema2 = Schema.from_pyarrow_schema(pa_schema, partition_keys=['dt'])
+        self.catalog.create_table('default.test_schema_evolution_with_filter2', schema2, False)
+        table2 = self.catalog.get_table('default.test_schema_evolution_with_filter2')
+ table2.table_schema.id = 1
+ write_builder = table2.new_batch_write_builder()
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+ data2 = {
+ 'user_id': [5, 6, 7, 8],
+ 'item_id': [1005, 1006, 1007, 1008],
+ 'dt': ['p2', 'p1', 'p2', 'p2'],
+ 'behavior': ['e', 'f', 'g', 'h'],
+ }
+ pa_table = pa.Table.from_pydict(data2, schema=pa_schema)
+ table_write.write_arrow(pa_table)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ # write schema-0 and schema-1 to table2
+ schema_manager = SchemaManager(table2.file_io, table2.table_path)
+        schema_manager.commit(TableSchema.from_schema(schema_id=0, schema=schema))
+        schema_manager.commit(TableSchema.from_schema(schema_id=1, schema=schema2))
+ # behavior filter
+ splits = self._scan_table(table1.new_read_builder())
+
+ read_builder = table2.new_read_builder()
+ predicate_builder = read_builder.new_predicate_builder()
+ predicate = predicate_builder.not_equal('behavior', "g")
+ splits2 = self._scan_table(read_builder.with_filter(predicate))
+ for split in splits2:
+ for file in split.files:
+ file.schema_id = 1
+ splits.extend(splits2)
+
+ table_read = read_builder.new_read()
+ actual = table_read.to_arrow(splits)
+ expected = pa.Table.from_pydict({
+ 'user_id': [5, 8, 6],
+ 'item_id': [1005, 1008, 1006],
+ 'dt': ["p2", "p2", "p1"],
+ 'behavior': ["e", "h", "f"],
+ }, schema=pa_schema)
+ self.assertEqual(expected, actual)
+ # user_id filter
+ splits = self._scan_table(table1.new_read_builder())
+
+ read_builder = table2.new_read_builder()
+ predicate_builder = read_builder.new_predicate_builder()
+ predicate = predicate_builder.less_than('user_id', 6)
+ splits2 = self._scan_table(read_builder.with_filter(predicate))
+ self.assertEqual(1, len(splits2))
+ for split in splits2:
+ for file in split.files:
+ file.schema_id = 1
+ splits.extend(splits2)
+
+ table_read = read_builder.new_read()
+ actual = table_read.to_arrow(splits)
+ expected = pa.Table.from_pydict({
+ 'user_id': [1, 2, 4, 3, 5],
+ 'item_id': [1001, 1002, 1004, 1003, 1005],
+ 'dt': ["p1", "p1", "p1", "p2", "p2"],
+ 'behavior': [None, None, None, None, "e"],
+ }, schema=pa_schema)
+ self.assertEqual(expected, actual)
+
+ def test_schema_evolution_with_scan_filter(self):
+ # schema 0
+ pa_schema = pa.schema([
+ ('user_id', pa.int64()),
+ ('item_id', pa.int64()),
+ ('dt', pa.string())
+ ])
+ schema = Schema.from_pyarrow_schema(pa_schema, partition_keys=['dt'])
+        self.catalog.create_table('default.test_schema_evolution1', schema, False)
+ table1 = self.catalog.get_table('default.test_schema_evolution1')
+ write_builder = table1.new_batch_write_builder()
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+ data1 = {
+ 'user_id': [1, 2, 3, 4],
+ 'item_id': [1001, 1002, 1003, 1004],
+ 'dt': ['p1', 'p1', 'p2', 'p1'],
+ }
+ pa_table = pa.Table.from_pydict(data1, schema=pa_schema)
+ table_write.write_arrow(pa_table)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ # schema 1 add behavior column
+ pa_schema = pa.schema([
+ ('user_id', pa.int64()),
+ ('item_id', pa.int64()),
+ ('behavior', pa.string()),
+ ('dt', pa.string())
+ ])
+ schema2 = Schema.from_pyarrow_schema(pa_schema, partition_keys=['dt'])
+        self.catalog.create_table('default.test_schema_evolution2', schema2, False)
+ table2 = self.catalog.get_table('default.test_schema_evolution2')
+ table2.table_schema.id = 1
+ write_builder = table2.new_batch_write_builder()
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+ data2 = {
+ 'user_id': [5, 6, 7, 8],
+ 'item_id': [1005, 1006, 1007, 1008],
+ 'behavior': ['e', 'f', 'g', 'h'],
+ 'dt': ['p2', 'p1', 'p2', 'p2'],
+ }
+ pa_table = pa.Table.from_pydict(data2, schema=pa_schema)
+ table_write.write_arrow(pa_table)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ # write schema-0 and schema-1 to table2
+ schema_manager = SchemaManager(table2.file_io, table2.table_path)
+        schema_manager.commit(TableSchema.from_schema(schema_id=0, schema=schema))
+        schema_manager.commit(TableSchema.from_schema(schema_id=1, schema=schema2))
+ # scan filter for schema evolution
+        latest_snapshot = table1.new_read_builder().new_scan().starting_scanner.snapshot_manager.get_latest_snapshot()
+ table2.table_path = table1.table_path
+ new_read_buidler = table2.new_read_builder()
+ predicate_builder = new_read_buidler.new_predicate_builder()
+ predicate = predicate_builder.less_than('user_id', 3)
+ new_scan = new_read_buidler.with_filter(predicate).new_scan()
+        manifest_files = new_scan.starting_scanner.manifest_list_manager.read_all(latest_snapshot)
+        entries = new_scan.starting_scanner.read_manifest_entries(manifest_files)
+        self.assertEqual(1, len(entries))  # verify scan filter success for schema evolution
+
+ def _write_test_table(self, table):
+ write_builder = table.new_batch_write_builder()
+
+ # first write
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+ data1 = {
+ 'user_id': [1, 2, 3, 4],
+ 'item_id': [1001, 1002, 1003, 1004],
+ 'behavior': ['a', 'b', 'c', None],
+ 'dt': ['p1', 'p1', 'p2', 'p1'],
+ }
+ pa_table = pa.Table.from_pydict(data1, schema=self.pa_schema)
+ table_write.write_arrow(pa_table)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ # second write
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+ data2 = {
+ 'user_id': [5, 6, 7, 8],
+ 'item_id': [1005, 1006, 1007, 1008],
+ 'behavior': ['e', 'f', 'g', 'h'],
+ 'dt': ['p2', 'p1', 'p2', 'p2'],
+ }
+ pa_table = pa.Table.from_pydict(data2, schema=self.pa_schema)
+ table_write.write_arrow(pa_table)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ def _scan_table(self, read_builder):
+ splits = read_builder.new_scan().plan().splits()
+ return splits