This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new c6a5df0e13 [python] Support mutual conversion between paimon DataField and pyarrow Field (#6000)
c6a5df0e13 is described below
commit c6a5df0e13b9e9b482bce1e5cc9367bb4f653603
Author: HeavenZH <[email protected]>
AuthorDate: Fri Aug 1 12:22:37 2025 +0800
[python] Support mutual conversion between paimon DataField and pyarrow Field (#6000)
---
paimon-python/pypaimon/schema/data_types.py | 213 ++++++++++++----------
paimon-python/pypaimon/schema/schema.py | 24 ++-
paimon-python/pypaimon/schema/table_schema.py | 13 +-
paimon-python/pypaimon/tests/schema_test.py | 59 ++++++
paimon-python/pypaimon/write/row_key_extractor.py | 1 -
5 files changed, 191 insertions(+), 119 deletions(-)
diff --git a/paimon-python/pypaimon/schema/data_types.py b/paimon-python/pypaimon/schema/data_types.py
index c6cf516ad5..1d26ebca5b 100644
--- a/paimon-python/pypaimon/schema/data_types.py
+++ b/paimon-python/pypaimon/schema/data_types.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-import json
import re
import threading
from abc import ABC, abstractmethod
@@ -163,12 +162,12 @@ class DataField:
default_value: Optional[str] = None
def __init__(
- self,
- id: int,
- name: str,
- type: DataType,
- description: Optional[str] = None,
- default_value: Optional[str] = None,
+ self,
+ id: int,
+ name: str,
+ type: DataType,
+ description: Optional[str] = None,
+ default_value: Optional[str] = None,
):
self.id = id
self.name = name
@@ -195,43 +194,6 @@ class DataField:
return result
- def to_pyarrow_field(self):
- data_type = self.type
- if not isinstance(data_type, AtomicType):
- raise ValueError(f"Unsupported data type: {data_type.__class__}")
- type_name = data_type.type.upper()
- if type_name == 'INT':
- type_name = pyarrow.int32()
- elif type_name == 'BIGINT':
- type_name = pyarrow.int64()
- elif type_name == 'FLOAT':
- type_name = pyarrow.float32()
- elif type_name == 'DOUBLE':
- type_name = pyarrow.float64()
- elif type_name == 'BOOLEAN':
- type_name = pyarrow.bool_()
- elif type_name == 'STRING':
- type_name = pyarrow.string()
- elif type_name == 'BINARY':
- type_name = pyarrow.binary()
- elif type_name == 'DATE':
- type_name = pyarrow.date32()
- elif type_name == 'TIMESTAMP':
- type_name = pyarrow.timestamp('ms')
- elif type_name.startswith('DECIMAL'):
- match = re.match(r'DECIMAL\((\d+),\s*(\d+)\)', type_name)
- if match:
- precision, scale = map(int, match.groups())
- type_name = pyarrow.decimal128(precision, scale)
- else:
- type_name = pyarrow.decimal128(38, 18)
- else:
- raise ValueError(f"Unsupported data type: {type_name}")
- metadata = {}
- if self.description:
- metadata[b'description'] = self.description.encode('utf-8')
- return pyarrow.field(self.name, type_name, nullable=data_type.nullable, metadata=metadata)
-
@dataclass
class RowType(DataType):
@@ -306,39 +268,9 @@ class DataTypeParser:
except ValueError:
raise Exception(f"Unknown type: {base_type}")
- @staticmethod
- def parse_atomic_type_pyarrow_field(field: pyarrow.Field) -> DataType:
- type_name = str(field.type)
- if type_name.startswith('int') or type_name.startswith('uint'):
- type_name = 'INT'
- elif type_name.startswith('float'):
- type_name = 'FLOAT'
- elif type_name.startswith('double'):
- type_name = 'DOUBLE'
- elif type_name.startswith('bool'):
- type_name = 'BOOLEAN'
- elif type_name.startswith('string'):
- type_name = 'STRING'
- elif type_name.startswith('binary'):
- type_name = 'BINARY'
- elif type_name.startswith('date'):
- type_name = 'DATE'
- elif type_name.startswith('timestamp'):
- type_name = 'TIMESTAMP'
- elif type_name.startswith('decimal'):
- match = re.match(r'decimal\((\d+),\s*(\d+)\)', type_name)
- if match:
- precision, scale = map(int, match.groups())
- type_name = f'DECIMAL({precision},{scale})'
- else:
- type_name = 'DECIMAL(38,18)'
- else:
- raise ValueError(f"Unknown type: {type_name}")
- return AtomicType(type_name, field.nullable)
-
@staticmethod
def parse_data_type(
- json_data: Union[Dict[str, Any], str], field_id: Optional[AtomicInteger] = None
+ json_data: Union[Dict[str, Any], str], field_id: Optional[AtomicInteger] = None
) -> DataType:
if isinstance(json_data, str):
@@ -389,12 +321,12 @@ class DataTypeParser:
@staticmethod
def parse_data_field(
- json_data: Dict[str, Any], field_id: Optional[AtomicInteger] = None
+ json_data: Dict[str, Any], field_id: Optional[AtomicInteger] = None
) -> DataField:
if (
- DataField.FIELD_ID in json_data
- and json_data[DataField.FIELD_ID] is not None
+ DataField.FIELD_ID in json_data
+ and json_data[DataField.FIELD_ID] is not None
):
if field_id is not None and field_id.get() != -1:
raise ValueError("Partial field id is not allowed.")
@@ -428,31 +360,112 @@ class DataTypeParser:
)
-def parse_data_type_from_json(
- json_str: str, field_id: Optional[AtomicInteger] = None
-) -> DataType:
- json_data = json.loads(json_str)
- return DataTypeParser.parse_data_type(json_data, field_id)
+class PyarrowFieldParse:
+ @staticmethod
+ def to_pyarrow_field(data_field: DataField) -> pyarrow.Field:
+ pa_field_type = PyarrowFieldParse.to_pyarrow_data_type(data_field.type)
+ metadata = {}
+ if data_field.description:
+ metadata[b'description'] = data_field.description.encode('utf-8')
+ return pyarrow.field(data_field.name, pa_field_type, nullable=data_field.type.nullable, metadata=metadata)
-def parse_data_field_from_json(
- json_str: str, field_id: Optional[AtomicInteger] = None
-) -> DataField:
- json_data = json.loads(json_str)
- return DataTypeParser.parse_data_field(json_data, field_id)
-
+ @staticmethod
+ def to_pyarrow_data_type(data_type: DataType) -> pyarrow.DataType:
+ if isinstance(data_type, AtomicType):
+ type_name = data_type.type.upper()
+ if type_name == 'INT':
+ return pyarrow.int32()
+ elif type_name == 'BIGINT':
+ return pyarrow.int64()
+ elif type_name == 'FLOAT':
+ return pyarrow.float32()
+ elif type_name == 'DOUBLE':
+ return pyarrow.float64()
+ elif type_name == 'BOOLEAN':
+ return pyarrow.bool_()
+ elif type_name == 'STRING':
+ return pyarrow.string()
+ elif type_name == 'BINARY':
+ return pyarrow.binary()
+ elif type_name == 'DATE':
+ return pyarrow.date32()
+ elif type_name == 'TIMESTAMP':
+ return pyarrow.timestamp('ms')
+ elif type_name.startswith('DECIMAL'):
+ match = re.match(r'DECIMAL\((\d+),\s*(\d+)\)', type_name)
+ if match:
+ precision, scale = map(int, match.groups())
+ return pyarrow.decimal128(precision, scale)
+ else:
+ return pyarrow.decimal128(38, 18)
+ else:
+ raise ValueError(f"Unsupported data type: {type_name}")
+ elif isinstance(data_type, ArrayType):
+ return pyarrow.list_(PyarrowFieldParse.to_pyarrow_data_type(data_type.element))
+ elif isinstance(data_type, MapType):
+ key_type = PyarrowFieldParse.to_pyarrow_data_type(data_type.key)
+ value_type = PyarrowFieldParse.to_pyarrow_data_type(data_type.value)
+ return pyarrow.map_(key_type, value_type)
+ else:
+ raise ValueError(f"Unsupported data type: {data_type}")
-def parse_data_fields_from_pyarrow_schema(pa_schema: pyarrow.Schema) -> list[DataField]:
- fields = []
- for i, pa_field in enumerate(pa_schema):
- pa_field: pyarrow.Field
- data_type = DataTypeParser.parse_atomic_type_pyarrow_field(pa_field)
- data_field = DataField(
- id=i,
+ @staticmethod
+ def from_pyarrow_field(field_idx: int, pa_field: pyarrow.Field) -> DataField:
+ data_type = PyarrowFieldParse.from_pyarrow_data_type(pa_field.type, pa_field.nullable)
+ description = pa_field.metadata.get(b'description', b'').decode('utf-8') \
+ if pa_field.metadata and b'description' in pa_field.metadata else None
+ return DataField(
+ id=field_idx,
name=pa_field.name,
type=data_type,
- description=pa_field.metadata.get(b'description', b'').decode
- ('utf-8') if pa_field.metadata and b'description' in pa_field.metadata else None
+ description=description
)
- fields.append(data_field)
- return fields
+
+ @staticmethod
+ def from_pyarrow_data_type(pa_type: pyarrow.DataType, nullable: bool) -> DataType:
+ type_name = str(pa_type)
+ if type_name.startswith('int') or type_name.startswith('uint'):
+ type_name = 'INT'
+ elif type_name.startswith('float'):
+ type_name = 'FLOAT'
+ elif type_name.startswith('double'):
+ type_name = 'DOUBLE'
+ elif type_name.startswith('bool'):
+ type_name = 'BOOLEAN'
+ elif type_name.startswith('string'):
+ type_name = 'STRING'
+ elif type_name.startswith('binary'):
+ type_name = 'BINARY'
+ elif type_name.startswith('date'):
+ type_name = 'DATE'
+ elif type_name.startswith('timestamp'):
+ type_name = 'TIMESTAMP'
+ elif type_name.startswith('decimal'):
+ match = re.match(r'decimal\((\d+),\s*(\d+)\)', type_name)
+ if match:
+ precision, scale = map(int, match.groups())
+ type_name = f'DECIMAL({precision},{scale})'
+ else:
+ type_name = 'DECIMAL(38,18)'
+ elif type_name.startswith('list'):
+ pa_type: pyarrow.ListType
+ element_type = PyarrowFieldParse.from_pyarrow_data_type(pa_type.value_type, nullable)
+ return ArrayType(nullable, element_type)
+ elif type_name.startswith('map'):
+ pa_type: pyarrow.MapType
+ key_type = PyarrowFieldParse.from_pyarrow_data_type(pa_type.key_type, nullable)
+ value_type = PyarrowFieldParse.from_pyarrow_data_type(pa_type.item_type, nullable)
+ return MapType(nullable, key_type, value_type)
+ else:
+ raise ValueError(f"Unknown type: {type_name}")
+ return AtomicType(type_name)
+
+ @staticmethod
+ def parse_pyarrow_schema(pa_schema: pyarrow.Schema) -> List[DataField]:
+ fields = []
+ for i, pa_field in enumerate(pa_schema):
+ pa_field: pyarrow.Field
+ data_field = PyarrowFieldParse.from_pyarrow_field(i, pa_field)
+ fields.append(data_field)
+ return fields
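
A minimal usage sketch of the round trip these helpers enable, assuming the
classes behave as defined in the hunks above:

    import pyarrow

    from pypaimon.schema.data_types import AtomicType, DataField, PyarrowFieldParse

    # Paimon DataField -> pyarrow Field; the description travels as field metadata.
    field = DataField(0, "name", AtomicType('STRING'), 'customer name')
    pa_field = PyarrowFieldParse.to_pyarrow_field(field)
    assert pa_field.type == pyarrow.string()
    assert pa_field.metadata[b'description'] == b'customer name'

    # pyarrow Field -> Paimon DataField; the field id is supplied by the caller.
    back = PyarrowFieldParse.from_pyarrow_field(0, pa_field)
    assert back.name == 'name' and back.description == 'customer name'
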
diff --git a/paimon-python/pypaimon/schema/schema.py b/paimon-python/pypaimon/schema/schema.py
index 354a9f1d21..9bd8523a72 100644
--- a/paimon-python/pypaimon/schema/schema.py
+++ b/paimon-python/pypaimon/schema/schema.py
@@ -19,7 +19,7 @@ from dataclasses import dataclass
from typing import Optional, List, Dict
import pyarrow as pa
-from pypaimon.schema.data_types import DataField
+from pypaimon.schema.data_types import DataField, PyarrowFieldParse
from pypaimon.common.rest_json import json_field
@@ -31,11 +31,23 @@ class Schema:
FIELD_OPTIONS = "options"
FIELD_COMMENT = "comment"
- pa_schema: Optional[pa.Schema] = None
fields: List[DataField] = json_field(FIELD_FIELDS, default_factory=list)
- partition_keys: List[str] = json_field(
- FIELD_PARTITION_KEYS, default_factory=list)
- primary_keys: List[str] = json_field(
- FIELD_PRIMARY_KEYS, default_factory=list)
+ partition_keys: List[str] = json_field(FIELD_PARTITION_KEYS, default_factory=list)
+ primary_keys: List[str] = json_field(FIELD_PRIMARY_KEYS, default_factory=list)
options: Dict[str, str] = json_field(FIELD_OPTIONS, default_factory=dict)
comment: Optional[str] = json_field(FIELD_COMMENT, default=None)
+
+ def __init__(self, fields: Optional[List[DataField]] = None, partition_keys: Optional[List[str]] = None,
+ primary_keys: Optional[List[str]] = None,
+ options: Optional[Dict[str, str]] = None, comment: Optional[str] = None):
+ self.fields = fields if fields is not None else []
+ self.partition_keys = partition_keys if partition_keys is not None else []
+ self.primary_keys = primary_keys if primary_keys is not None else []
+ self.options = options if options is not None else {}
+ self.comment = comment
+
+ @staticmethod
+ def build_from_pyarrow_schema(pa_schema: pa.Schema, partition_keys: Optional[List[str]] = None,
+ primary_keys: Optional[List[str]] = None, options: Optional[Dict[str, str]] = None,
+ comment: Optional[str] = None):
+ return Schema(PyarrowFieldParse.parse_pyarrow_schema(pa_schema), partition_keys, primary_keys, options, comment)
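
A sketch of the new entry point, assuming the builder shown above; field ids
are assigned positionally from the pyarrow schema:

    import pyarrow as pa

    from pypaimon import Schema

    pa_schema = pa.schema([
        pa.field('user_id', pa.int32(), nullable=False),
        pa.field('behavior', pa.string()),
    ])
    schema = Schema.build_from_pyarrow_schema(pa_schema, primary_keys=['user_id'])
    # schema.fields is a List[DataField] with ids 0 and 1.
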
diff --git a/paimon-python/pypaimon/schema/table_schema.py b/paimon-python/pypaimon/schema/table_schema.py
index 736de902db..28084b2771 100644
--- a/paimon-python/pypaimon/schema/table_schema.py
+++ b/paimon-python/pypaimon/schema/table_schema.py
@@ -22,11 +22,9 @@ from dataclasses import dataclass
from pathlib import Path
from typing import List, Dict, Optional
-import pyarrow
from pypaimon import Schema
from pypaimon.common.rest_json import json_field
-from pypaimon.schema import data_types
from pypaimon.common.core_options import CoreOptions
from pypaimon.common.file_io import FileIO
from pypaimon.schema.data_types import DataField
@@ -74,13 +72,6 @@ class TableSchema:
self.time_millis = time_millis if time_millis is not None else int(time.time() * 1000)
def to_schema(self) -> Schema:
- try:
- pa_fields = []
- for field in self.fields:
- pa_fields.append(field.to_pyarrow_field())
- pyarrow.schema(pa_fields)
- except Exception as e:
- print(e)
return Schema(
fields=self.fields,
partition_keys=self.partition_keys,
@@ -133,12 +124,10 @@ class TableSchema:
@staticmethod
def from_schema(schema_id: int, schema: Schema) -> "TableSchema":
fields: List[DataField] = schema.fields
- if not schema.fields:
- fields = data_types.parse_data_fields_from_pyarrow_schema(schema.pa_schema)
partition_keys: List[str] = schema.partition_keys
primary_keys: List[str] = schema.primary_keys
options: Dict[str, str] = schema.options
- highest_field_id: int = None # max(field.id for field in fields)
+ highest_field_id: int = max(field.id for field in fields)
return TableSchema(
TableSchema.CURRENT_VERSION,
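
With highest_field_id now derived from the fields, a round trip through
from_schema is straightforward; a sketch, assuming the constructor order used
by the test below:

    from pypaimon import Schema
    from pypaimon.schema.data_types import AtomicType, DataField
    from pypaimon.schema.table_schema import TableSchema

    fields = [DataField(0, 'k', AtomicType('INT')), DataField(5, 'v', AtomicType('STRING'))]
    table_schema = TableSchema.from_schema(schema_id=0, schema=Schema(fields=fields))
    # highest_field_id == max(field.id for field in fields) == 5
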
diff --git a/paimon-python/pypaimon/tests/schema_test.py b/paimon-python/pypaimon/tests/schema_test.py
new file mode 100644
index 0000000000..4064069ebd
--- /dev/null
+++ b/paimon-python/pypaimon/tests/schema_test.py
@@ -0,0 +1,59 @@
+"""
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import unittest
+
+import pyarrow
+
+from pypaimon import Schema
+from pypaimon.schema.table_schema import TableSchema
+from pypaimon.schema.data_types import AtomicType, ArrayType, MapType, DataField, PyarrowFieldParse
+
+
+class SchemaTestCase(unittest.TestCase):
+ def test_types(self):
+ data_fields = [
+ DataField(0, "name", AtomicType('INT'), 'desc name'),
+ DataField(1, "arr", ArrayType(True, AtomicType('INT')), 'desc
arr1'),
+ DataField(2, "map1",
+ MapType(False, AtomicType('INT'), MapType(False, AtomicType('INT'), AtomicType('INT'))),
+ 'desc map1'),
+ ]
+ table_schema = TableSchema(TableSchema.CURRENT_VERSION, len(data_fields), data_fields,
+ max(field.id for field in data_fields),
+ [], [], {}, "")
+ pa_fields = []
+ for field in table_schema.fields:
+ pa_field = PyarrowFieldParse.to_pyarrow_field(field)
+ pa_fields.append(pa_field)
+ schema = Schema.build_from_pyarrow_schema(
+ pa_schema=pyarrow.schema(pa_fields),
+ partition_keys=table_schema.partition_keys,
+ primary_keys=table_schema.primary_keys,
+ options=table_schema.options,
+ comment=table_schema.comment
+ )
+ table_schema2 = TableSchema.from_schema(len(data_fields), schema)
+ # print("table_schema2:", table_schema2)
+ l1 = []
+ for field in table_schema.fields:
+ l1.append(field.to_dict())
+ l2 = []
+ for field in table_schema2.fields:
+ l2.append(field.to_dict())
+ self.assertEqual(l1, l2)
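
The nested map case the test covers compiles down to a nested pyarrow map
type; a small sketch under the conversion rules from data_types.py above:

    import pyarrow

    from pypaimon.schema.data_types import AtomicType, MapType, PyarrowFieldParse

    nested = MapType(False, AtomicType('INT'), MapType(False, AtomicType('INT'), AtomicType('INT')))
    pa_type = PyarrowFieldParse.to_pyarrow_data_type(nested)
    assert pa_type == pyarrow.map_(pyarrow.int32(), pyarrow.map_(pyarrow.int32(), pyarrow.int32()))
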
diff --git a/paimon-python/pypaimon/write/row_key_extractor.py b/paimon-python/pypaimon/write/row_key_extractor.py
index cda3ad07ba..13a755c53c 100644
--- a/paimon-python/pypaimon/write/row_key_extractor.py
+++ b/paimon-python/pypaimon/write/row_key_extractor.py
@@ -58,7 +58,6 @@ class RowKeyExtractor(ABC):
@abstractmethod
def _extract_buckets_batch(self, table: pa.RecordBatch) -> List[int]:
"""Extract bucket numbers for all rows. Must be implemented by
subclasses."""
- pass
class FixedBucketRowKeyExtractor(RowKeyExtractor):
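
The removed pass was redundant: a docstring alone is a complete function body
in Python, so an abstract method needs nothing after it. A generic
illustration, not project code:

    from abc import ABC, abstractmethod

    class Extractor(ABC):
        @abstractmethod
        def extract(self) -> None:
            """The docstring is the whole body; no pass statement is required."""
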