This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new c6a5df0e13 [python] Support mutual conversion between paimon DataField and pyarrow Field (#6000)
c6a5df0e13 is described below

commit c6a5df0e13b9e9b482bce1e5cc9367bb4f653603
Author: HeavenZH <[email protected]>
AuthorDate: Fri Aug 1 12:22:37 2025 +0800

    [python] Support mutual conversion between paimon DataField and pyarrow Field (#6000)
---
 paimon-python/pypaimon/schema/data_types.py       | 213 ++++++++++++----------
 paimon-python/pypaimon/schema/schema.py           |  24 ++-
 paimon-python/pypaimon/schema/table_schema.py     |  13 +-
 paimon-python/pypaimon/tests/schema_test.py       |  59 ++++++
 paimon-python/pypaimon/write/row_key_extractor.py |   1 -
 5 files changed, 191 insertions(+), 119 deletions(-)

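For readers skimming the patch: the commit moves DataField <-> pyarrow.Field conversion out of ad-hoc helpers and into a new PyarrowFieldParse class, adding nested ARRAY and MAP support along the way. A minimal round-trip sketch of the new API, with import paths taken from the diff below (they may differ in a released package):

    from pypaimon.schema.data_types import AtomicType, DataField, PyarrowFieldParse

    # paimon DataField -> pyarrow Field; the description travels in field metadata
    field = DataField(0, "user_id", AtomicType('INT'), 'primary id')
    pa_field = PyarrowFieldParse.to_pyarrow_field(field)

    # pyarrow Field -> paimon DataField; the id is the positional index
    field_back = PyarrowFieldParse.from_pyarrow_field(0, pa_field)
    assert field_back.to_dict() == field.to_dict()
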
diff --git a/paimon-python/pypaimon/schema/data_types.py b/paimon-python/pypaimon/schema/data_types.py
index c6cf516ad5..1d26ebca5b 100644
--- a/paimon-python/pypaimon/schema/data_types.py
+++ b/paimon-python/pypaimon/schema/data_types.py
@@ -15,7 +15,6 @@
 #  specific language governing permissions and limitations
 #  under the License.
 
-import json
 import re
 import threading
 from abc import ABC, abstractmethod
@@ -163,12 +162,12 @@ class DataField:
     default_value: Optional[str] = None
 
     def __init__(
-        self,
-        id: int,
-        name: str,
-        type: DataType,
-        description: Optional[str] = None,
-        default_value: Optional[str] = None,
+            self,
+            id: int,
+            name: str,
+            type: DataType,
+            description: Optional[str] = None,
+            default_value: Optional[str] = None,
     ):
         self.id = id
         self.name = name
@@ -195,43 +194,6 @@ class DataField:
 
         return result
 
-    def to_pyarrow_field(self):
-        data_type = self.type
-        if not isinstance(data_type, AtomicType):
-            raise ValueError(f"Unsupported data type: {data_type.__class__}")
-        type_name = data_type.type.upper()
-        if type_name == 'INT':
-            type_name = pyarrow.int32()
-        elif type_name == 'BIGINT':
-            type_name = pyarrow.int64()
-        elif type_name == 'FLOAT':
-            type_name = pyarrow.float32()
-        elif type_name == 'DOUBLE':
-            type_name = pyarrow.float64()
-        elif type_name == 'BOOLEAN':
-            type_name = pyarrow.bool_()
-        elif type_name == 'STRING':
-            type_name = pyarrow.string()
-        elif type_name == 'BINARY':
-            type_name = pyarrow.binary()
-        elif type_name == 'DATE':
-            type_name = pyarrow.date32()
-        elif type_name == 'TIMESTAMP':
-            type_name = pyarrow.timestamp('ms')
-        elif type_name.startswith('DECIMAL'):
-            match = re.match(r'DECIMAL\((\d+),\s*(\d+)\)', type_name)
-            if match:
-                precision, scale = map(int, match.groups())
-                type_name = pyarrow.decimal128(precision, scale)
-            else:
-                type_name = pyarrow.decimal128(38, 18)
-        else:
-            raise ValueError(f"Unsupported data type: {type_name}")
-        metadata = {}
-        if self.description:
-            metadata[b'description'] = self.description.encode('utf-8')
-        return pyarrow.field(self.name, type_name, nullable=data_type.nullable, metadata=metadata)
-
 
 @dataclass
 class RowType(DataType):
@@ -306,39 +268,9 @@ class DataTypeParser:
         except ValueError:
             raise Exception(f"Unknown type: {base_type}")
 
-    @staticmethod
-    def parse_atomic_type_pyarrow_field(field: pyarrow.Field) -> DataType:
-        type_name = str(field.type)
-        if type_name.startswith('int') or type_name.startswith('uint'):
-            type_name = 'INT'
-        elif type_name.startswith('float'):
-            type_name = 'FLOAT'
-        elif type_name.startswith('double'):
-            type_name = 'DOUBLE'
-        elif type_name.startswith('bool'):
-            type_name = 'BOOLEAN'
-        elif type_name.startswith('string'):
-            type_name = 'STRING'
-        elif type_name.startswith('binary'):
-            type_name = 'BINARY'
-        elif type_name.startswith('date'):
-            type_name = 'DATE'
-        elif type_name.startswith('timestamp'):
-            type_name = 'TIMESTAMP'
-        elif type_name.startswith('decimal'):
-            match = re.match(r'decimal\((\d+),\s*(\d+)\)', type_name)
-            if match:
-                precision, scale = map(int, match.groups())
-                type_name = f'DECIMAL({precision},{scale})'
-            else:
-                type_name = 'DECIMAL(38,18)'
-        else:
-            raise ValueError(f"Unknown type: {type_name}")
-        return AtomicType(type_name, field.nullable)
-
     @staticmethod
     def parse_data_type(
-        json_data: Union[Dict[str, Any], str], field_id: Optional[AtomicInteger] = None
+            json_data: Union[Dict[str, Any], str], field_id: Optional[AtomicInteger] = None
     ) -> DataType:
 
         if isinstance(json_data, str):
@@ -389,12 +321,12 @@ class DataTypeParser:
 
     @staticmethod
     def parse_data_field(
-        json_data: Dict[str, Any], field_id: Optional[AtomicInteger] = None
+            json_data: Dict[str, Any], field_id: Optional[AtomicInteger] = None
     ) -> DataField:
 
         if (
-            DataField.FIELD_ID in json_data
-            and json_data[DataField.FIELD_ID] is not None
+                DataField.FIELD_ID in json_data
+                and json_data[DataField.FIELD_ID] is not None
         ):
             if field_id is not None and field_id.get() != -1:
                 raise ValueError("Partial field id is not allowed.")
@@ -428,31 +360,112 @@ class DataTypeParser:
         )
 
 
-def parse_data_type_from_json(
-    json_str: str, field_id: Optional[AtomicInteger] = None
-) -> DataType:
-    json_data = json.loads(json_str)
-    return DataTypeParser.parse_data_type(json_data, field_id)
+class PyarrowFieldParse:
 
+    @staticmethod
+    def to_pyarrow_field(data_field: DataField) -> pyarrow.Field:
+        pa_field_type = PyarrowFieldParse.to_pyarrow_data_type(data_field.type)
+        metadata = {}
+        if data_field.description:
+            metadata[b'description'] = data_field.description.encode('utf-8')
+        return pyarrow.field(data_field.name, pa_field_type, nullable=data_field.type.nullable, metadata=metadata)
 
-def parse_data_field_from_json(
-    json_str: str, field_id: Optional[AtomicInteger] = None
-) -> DataField:
-    json_data = json.loads(json_str)
-    return DataTypeParser.parse_data_field(json_data, field_id)
-
+    @staticmethod
+    def to_pyarrow_data_type(data_type: DataType) -> pyarrow.DataType:
+        if isinstance(data_type, AtomicType):
+            type_name = data_type.type.upper()
+            if type_name == 'INT':
+                return pyarrow.int32()
+            elif type_name == 'BIGINT':
+                return pyarrow.int64()
+            elif type_name == 'FLOAT':
+                return pyarrow.float32()
+            elif type_name == 'DOUBLE':
+                return pyarrow.float64()
+            elif type_name == 'BOOLEAN':
+                return pyarrow.bool_()
+            elif type_name == 'STRING':
+                return pyarrow.string()
+            elif type_name == 'BINARY':
+                return pyarrow.binary()
+            elif type_name == 'DATE':
+                return pyarrow.date32()
+            elif type_name == 'TIMESTAMP':
+                return pyarrow.timestamp('ms')
+            elif type_name.startswith('DECIMAL'):
+                match = re.match(r'DECIMAL\((\d+),\s*(\d+)\)', type_name)
+                if match:
+                    precision, scale = map(int, match.groups())
+                    return pyarrow.decimal128(precision, scale)
+                else:
+                    return pyarrow.decimal128(38, 18)
+            else:
+                raise ValueError(f"Unsupported data type: {type_name}")
+        elif isinstance(data_type, ArrayType):
+            return pyarrow.list_(PyarrowFieldParse.to_pyarrow_data_type(data_type.element))
+        elif isinstance(data_type, MapType):
+            key_type = PyarrowFieldParse.to_pyarrow_data_type(data_type.key)
+            value_type = PyarrowFieldParse.to_pyarrow_data_type(data_type.value)
+            return pyarrow.map_(key_type, value_type)
+        else:
+            raise ValueError(f"Unsupported data type: {data_type}")
 
-def parse_data_fields_from_pyarrow_schema(pa_schema: pyarrow.Schema) -> list[DataField]:
-    fields = []
-    for i, pa_field in enumerate(pa_schema):
-        pa_field: pyarrow.Field
-        data_type = DataTypeParser.parse_atomic_type_pyarrow_field(pa_field)
-        data_field = DataField(
-            id=i,
+    @staticmethod
+    def from_pyarrow_field(field_idx: int, pa_field: pyarrow.Field) -> DataField:
+        data_type = PyarrowFieldParse.from_pyarrow_data_type(pa_field.type, pa_field.nullable)
+        description = pa_field.metadata.get(b'description', b'').decode('utf-8') \
+            if pa_field.metadata and b'description' in pa_field.metadata else None
+        return DataField(
+            id=field_idx,
             name=pa_field.name,
             type=data_type,
-            description=pa_field.metadata.get(b'description', b'').decode
-            ('utf-8') if pa_field.metadata and b'description' in pa_field.metadata else None
+            description=description
         )
-        fields.append(data_field)
-    return fields
+
+    @staticmethod
+    def from_pyarrow_data_type(pa_type: pyarrow.DataType, nullable: bool) -> DataType:
+        type_name = str(pa_type)
+        if type_name.startswith('int') or type_name.startswith('uint'):
+            type_name = 'INT'
+        elif type_name.startswith('float'):
+            type_name = 'FLOAT'
+        elif type_name.startswith('double'):
+            type_name = 'DOUBLE'
+        elif type_name.startswith('bool'):
+            type_name = 'BOOLEAN'
+        elif type_name.startswith('string'):
+            type_name = 'STRING'
+        elif type_name.startswith('binary'):
+            type_name = 'BINARY'
+        elif type_name.startswith('date'):
+            type_name = 'DATE'
+        elif type_name.startswith('timestamp'):
+            type_name = 'TIMESTAMP'
+        elif type_name.startswith('decimal'):
+            match = re.match(r'decimal\((\d+),\s*(\d+)\)', type_name)
+            if match:
+                precision, scale = map(int, match.groups())
+                type_name = f'DECIMAL({precision},{scale})'
+            else:
+                type_name = 'DECIMAL(38,18)'
+        elif type_name.startswith('list'):
+            pa_type: pyarrow.ListType
+            element_type = PyarrowFieldParse.from_pyarrow_data_type(pa_type.value_type, nullable)
+            return ArrayType(nullable, element_type)
+        elif type_name.startswith('map'):
+            pa_type: pyarrow.MapType
+            key_type = PyarrowFieldParse.from_pyarrow_data_type(pa_type.key_type, nullable)
+            value_type = PyarrowFieldParse.from_pyarrow_data_type(pa_type.item_type, nullable)
+            return MapType(nullable, key_type, value_type)
+        else:
+            raise ValueError(f"Unknown type: {type_name}")
+        return AtomicType(type_name)
+
+    @staticmethod
+    def parse_pyarrow_schema(pa_schema: pyarrow.Schema) -> List[DataField]:
+        fields = []
+        for i, pa_field in enumerate(pa_schema):
+            pa_field: pyarrow.Field
+            data_field = PyarrowFieldParse.from_pyarrow_field(i, pa_field)
+            fields.append(data_field)
+        return fields
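The nested-type handling above can be exercised directly. A sketch of the expected behavior based on the committed code (assumes only pyarrow installed alongside pypaimon):

    import pyarrow
    from pypaimon.schema.data_types import ArrayType, AtomicType, MapType, PyarrowFieldParse

    # ARRAY<INT> maps to a pyarrow list type, MAP<STRING, DOUBLE> to a pyarrow map type
    arr = ArrayType(True, AtomicType('INT'))
    assert PyarrowFieldParse.to_pyarrow_data_type(arr) == pyarrow.list_(pyarrow.int32())

    mp = MapType(False, AtomicType('STRING'), AtomicType('DOUBLE'))
    assert PyarrowFieldParse.to_pyarrow_data_type(mp) == pyarrow.map_(pyarrow.string(), pyarrow.float64())

    # a whole pyarrow schema becomes List[DataField], with ids assigned by position
    pa_schema = pyarrow.schema([pyarrow.field('tags', pyarrow.list_(pyarrow.string()))])
    fields = PyarrowFieldParse.parse_pyarrow_schema(pa_schema)
    assert fields[0].id == 0 and isinstance(fields[0].type, ArrayType)
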
diff --git a/paimon-python/pypaimon/schema/schema.py b/paimon-python/pypaimon/schema/schema.py
index 354a9f1d21..9bd8523a72 100644
--- a/paimon-python/pypaimon/schema/schema.py
+++ b/paimon-python/pypaimon/schema/schema.py
@@ -19,7 +19,7 @@ from dataclasses import dataclass
 from typing import Optional, List, Dict
 
 import pyarrow as pa
-from pypaimon.schema.data_types import DataField
+from pypaimon.schema.data_types import DataField, PyarrowFieldParse
 from pypaimon.common.rest_json import json_field
 
 
@@ -31,11 +31,23 @@ class Schema:
     FIELD_OPTIONS = "options"
     FIELD_COMMENT = "comment"
 
-    pa_schema: Optional[pa.Schema] = None
     fields: List[DataField] = json_field(FIELD_FIELDS, default_factory=list)
-    partition_keys: List[str] = json_field(
-        FIELD_PARTITION_KEYS, default_factory=list)
-    primary_keys: List[str] = json_field(
-        FIELD_PRIMARY_KEYS, default_factory=list)
+    partition_keys: List[str] = json_field(FIELD_PARTITION_KEYS, default_factory=list)
+    primary_keys: List[str] = json_field(FIELD_PRIMARY_KEYS, default_factory=list)
     options: Dict[str, str] = json_field(FIELD_OPTIONS, default_factory=dict)
     comment: Optional[str] = json_field(FIELD_COMMENT, default=None)
+
+    def __init__(self, fields: Optional[List[DataField]] = None, partition_keys: Optional[List[str]] = None,
+                 primary_keys: Optional[List[str]] = None,
+                 options: Optional[Dict[str, str]] = None, comment: Optional[str] = None):
+        self.fields = fields if fields is not None else []
+        self.partition_keys = partition_keys if partition_keys is not None else []
+        self.primary_keys = primary_keys if primary_keys is not None else []
+        self.options = options if options is not None else {}
+        self.comment = comment
+
+    @staticmethod
+    def build_from_pyarrow_schema(pa_schema: pa.Schema, partition_keys: Optional[List[str]] = None,
+                                  primary_keys: Optional[List[str]] = None, options: Optional[Dict[str, str]] = None,
+                                  comment: Optional[str] = None):
+        return Schema(PyarrowFieldParse.parse_pyarrow_schema(pa_schema), partition_keys, primary_keys, options, comment)
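With pa_schema removed from the dataclass, callers now build a Schema from a pyarrow schema explicitly. A usage sketch (the column names and the 'bucket' option are made up for illustration):

    import pyarrow as pa
    from pypaimon import Schema

    pa_schema = pa.schema([
        pa.field('dt', pa.string()),
        pa.field('user_id', pa.int64()),
        pa.field('behavior', pa.string()),
    ])
    schema = Schema.build_from_pyarrow_schema(
        pa_schema,
        partition_keys=['dt'],
        primary_keys=['dt', 'user_id'],
        options={'bucket': '4'},  # illustrative option only
    )
    assert [f.name for f in schema.fields] == ['dt', 'user_id', 'behavior']
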
diff --git a/paimon-python/pypaimon/schema/table_schema.py b/paimon-python/pypaimon/schema/table_schema.py
index 736de902db..28084b2771 100644
--- a/paimon-python/pypaimon/schema/table_schema.py
+++ b/paimon-python/pypaimon/schema/table_schema.py
@@ -22,11 +22,9 @@ from dataclasses import dataclass
 from pathlib import Path
 from typing import List, Dict, Optional
 
-import pyarrow
 
 from pypaimon import Schema
 from pypaimon.common.rest_json import json_field
-from pypaimon.schema import data_types
 from pypaimon.common.core_options import CoreOptions
 from pypaimon.common.file_io import FileIO
 from pypaimon.schema.data_types import DataField
@@ -74,13 +72,6 @@ class TableSchema:
         self.time_millis = time_millis if time_millis is not None else int(time.time() * 1000)
 
     def to_schema(self) -> Schema:
-        try:
-            pa_fields = []
-            for field in self.fields:
-                pa_fields.append(field.to_pyarrow_field())
-            pyarrow.schema(pa_fields)
-        except Exception as e:
-            print(e)
         return Schema(
             fields=self.fields,
             partition_keys=self.partition_keys,
@@ -133,12 +124,10 @@ class TableSchema:
     @staticmethod
     def from_schema(schema_id: int, schema: Schema) -> "TableSchema":
         fields: List[DataField] = schema.fields
-        if not schema.fields:
-            fields = data_types.parse_data_fields_from_pyarrow_schema(schema.pa_schema)
         partition_keys: List[str] = schema.partition_keys
         primary_keys: List[str] = schema.primary_keys
         options: Dict[str, str] = schema.options
-        highest_field_id: int = None  # max(field.id for field in fields)
+        highest_field_id: int = max(field.id for field in fields)
 
         return TableSchema(
             TableSchema.CURRENT_VERSION,
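This hunk also fixes a latent bug: from_schema used to leave highest_field_id as None (the real computation was commented out) and no longer falls back to schema.pa_schema. A sketch of the fixed behavior (the highest_field_id attribute name is assumed from the constructor parameter):

    from pypaimon import Schema
    from pypaimon.schema.data_types import AtomicType, DataField
    from pypaimon.schema.table_schema import TableSchema

    schema = Schema(fields=[
        DataField(0, 'k', AtomicType('INT')),
        DataField(5, 'v', AtomicType('STRING')),
    ])
    table_schema = TableSchema.from_schema(0, schema)
    assert table_schema.highest_field_id == 5  # previously None
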
diff --git a/paimon-python/pypaimon/tests/schema_test.py b/paimon-python/pypaimon/tests/schema_test.py
new file mode 100644
index 0000000000..4064069ebd
--- /dev/null
+++ b/paimon-python/pypaimon/tests/schema_test.py
@@ -0,0 +1,59 @@
+"""
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import unittest
+
+import pyarrow
+
+from pypaimon import Schema
+from pypaimon.schema.table_schema import TableSchema
+from pypaimon.schema.data_types import AtomicType, ArrayType, MapType, DataField, PyarrowFieldParse
+
+
+class SchemaTestCase(unittest.TestCase):
+    def test_types(self):
+        data_fields = [
+            DataField(0, "name", AtomicType('INT'), 'desc  name'),
+            DataField(1, "arr", ArrayType(True, AtomicType('INT')), 'desc arr1'),
+            DataField(2, "map1",
+                      MapType(False, AtomicType('INT'), MapType(False, AtomicType('INT'), AtomicType('INT'))),
+                      'desc map1'),
+        ]
+        table_schema = TableSchema(TableSchema.CURRENT_VERSION, len(data_fields), data_fields,
+                                   max(field.id for field in data_fields),
+                                   [], [], {}, "")
+        pa_fields = []
+        for field in table_schema.fields:
+            pa_field = PyarrowFieldParse.to_pyarrow_field(field)
+            pa_fields.append(pa_field)
+        schema = Schema.build_from_pyarrow_schema(
+            pa_schema=pyarrow.schema(pa_fields),
+            partition_keys=table_schema.partition_keys,
+            primary_keys=table_schema.primary_keys,
+            options=table_schema.options,
+            comment=table_schema.comment
+        )
+        table_schema2 = TableSchema.from_schema(len(data_fields), schema)
+        # print("table_schema2:", table_schema2)
+        l1 = []
+        for field in table_schema.fields:
+            l1.append(field.to_dict())
+        l2 = []
+        for field in table_schema2.fields:
+            l2.append(field.to_dict())
+        self.assertEqual(l1, l2)
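The test round-trips DataField -> pyarrow.Field -> DataField and compares the dict forms. It should run under the stock unittest runner from the paimon-python directory, e.g.:

    python -m unittest pypaimon.tests.schema_test
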
diff --git a/paimon-python/pypaimon/write/row_key_extractor.py b/paimon-python/pypaimon/write/row_key_extractor.py
index cda3ad07ba..13a755c53c 100644
--- a/paimon-python/pypaimon/write/row_key_extractor.py
+++ b/paimon-python/pypaimon/write/row_key_extractor.py
@@ -58,7 +58,6 @@ class RowKeyExtractor(ABC):
     @abstractmethod
     def _extract_buckets_batch(self, table: pa.RecordBatch) -> List[int]:
         """Extract bucket numbers for all rows. Must be implemented by 
subclasses."""
-        pass
 
 
 class FixedBucketRowKeyExtractor(RowKeyExtractor):
