This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-python.git


The following commit(s) were added to refs/heads/main by this push:
     new f09dc58  Refactor ReadBuilder#with_projection to accept field names for better usability (#27)
f09dc58 is described below

commit f09dc5887fdcab90b7bcac361e54414fab089fc9
Author: yuzelin <[email protected]>
AuthorDate: Mon Nov 25 20:40:20 2024 +0800

    Refactor ReadBuilder#with_projection to accept field names for better usability (#27)
---
 paimon_python_api/read_builder.py               |  2 +-
 paimon_python_java/pypaimon.py                  | 43 ++++++++---------
 paimon_python_java/tests/test_write_and_read.py | 62 +++++++++++++++++++++++++
 paimon_python_java/util/java_utils.py           |  9 ++++
 4 files changed, 92 insertions(+), 24 deletions(-)

diff --git a/paimon_python_api/read_builder.py 
b/paimon_python_api/read_builder.py
index ad5e6d6..a031a05 100644
--- a/paimon_python_api/read_builder.py
+++ b/paimon_python_api/read_builder.py
@@ -32,7 +32,7 @@ class ReadBuilder(ABC):
         """
 
     @abstractmethod
-    def with_projection(self, projection: List[List[int]]) -> 'ReadBuilder':
+    def with_projection(self, projection: List[str]) -> 'ReadBuilder':
         """Push nested projection."""
 
     @abstractmethod
diff --git a/paimon_python_java/pypaimon.py b/paimon_python_java/pypaimon.py
index 16c7a69..b884fa4 100644
--- a/paimon_python_java/pypaimon.py
+++ b/paimon_python_java/pypaimon.py
@@ -61,37 +61,36 @@ class Table(table.Table):
     def __init__(self, j_table, catalog_options: dict):
         self._j_table = j_table
         self._catalog_options = catalog_options
-        # init arrow schema
-        schema_bytes = 
get_gateway().jvm.SchemaUtil.getArrowSchema(j_table.rowType())
-        schema_reader = 
pa.RecordBatchStreamReader(pa.BufferReader(schema_bytes))
-        self._arrow_schema = schema_reader.schema
-        schema_reader.close()
 
     def new_read_builder(self) -> 'ReadBuilder':
         j_read_builder = 
get_gateway().jvm.InvocationUtil.getReadBuilder(self._j_table)
-        return ReadBuilder(
-            j_read_builder, self._j_table.rowType(), self._catalog_options, 
self._arrow_schema)
+        return ReadBuilder(j_read_builder, self._j_table.rowType(), 
self._catalog_options)
 
     def new_batch_write_builder(self) -> 'BatchWriteBuilder':
         java_utils.check_batch_write(self._j_table)
         j_batch_write_builder = 
get_gateway().jvm.InvocationUtil.getBatchWriteBuilder(self._j_table)
-        return BatchWriteBuilder(j_batch_write_builder, 
self._j_table.rowType(), self._arrow_schema)
+        return BatchWriteBuilder(j_batch_write_builder)
 
 
 class ReadBuilder(read_builder.ReadBuilder):
 
-    def __init__(self, j_read_builder, j_row_type, catalog_options: dict, 
arrow_schema: pa.Schema):
+    def __init__(self, j_read_builder, j_row_type, catalog_options: dict):
         self._j_read_builder = j_read_builder
         self._j_row_type = j_row_type
         self._catalog_options = catalog_options
-        self._arrow_schema = arrow_schema
 
     def with_filter(self, predicate: 'Predicate'):
         self._j_read_builder.withFilter(predicate.to_j_predicate())
         return self
 
-    def with_projection(self, projection: List[List[int]]) -> 'ReadBuilder':
-        self._j_read_builder.withProjection(projection)
+    def with_projection(self, projection: List[str]) -> 'ReadBuilder':
+        field_names = list(map(lambda field: field.name(), 
self._j_row_type.getFields()))
+        int_projection = list(map(lambda p: field_names.index(p), projection))
+        gateway = get_gateway()
+        int_projection_arr = gateway.new_array(gateway.jvm.int, 
len(projection))
+        for i in range(len(projection)):
+            int_projection_arr[i] = int_projection[i]
+        self._j_read_builder.withProjection(int_projection_arr)
         return self
 
     def with_limit(self, limit: int) -> 'ReadBuilder':
@@ -104,7 +103,7 @@ class ReadBuilder(read_builder.ReadBuilder):
 
     def new_read(self) -> 'TableRead':
         j_table_read = self._j_read_builder.newRead().executeFilter()
-        return TableRead(j_table_read, self._j_row_type, 
self._catalog_options, self._arrow_schema)
+        return TableRead(j_table_read, self._j_read_builder.readType(), 
self._catalog_options)
 
     def new_predicate_builder(self) -> 'PredicateBuilder':
         return PredicateBuilder(self._j_row_type)
@@ -141,12 +140,12 @@ class Split(split.Split):
 
 class TableRead(table_read.TableRead):
 
-    def __init__(self, j_table_read, j_row_type, catalog_options, 
arrow_schema):
+    def __init__(self, j_table_read, j_read_type, catalog_options):
         self._j_table_read = j_table_read
-        self._j_row_type = j_row_type
+        self._j_read_type = j_read_type
         self._catalog_options = catalog_options
         self._j_bytes_reader = None
-        self._arrow_schema = arrow_schema
+        self._arrow_schema = java_utils.to_arrow_schema(j_read_type)
 
     def to_arrow(self, splits):
         record_batch_reader = self.to_arrow_batch_reader(splits)
@@ -174,7 +173,7 @@ class TableRead(table_read.TableRead):
             if max_workers <= 0:
                 raise ValueError("max_workers must be greater than 0")
             self._j_bytes_reader = 
get_gateway().jvm.InvocationUtil.createParallelBytesReader(
-                self._j_table_read, self._j_row_type, max_workers)
+                self._j_table_read, self._j_read_type, max_workers)
 
     def _batch_generator(self) -> Iterator[pa.RecordBatch]:
         while True:
@@ -188,10 +187,8 @@ class TableRead(table_read.TableRead):
 
 class BatchWriteBuilder(write_builder.BatchWriteBuilder):
 
-    def __init__(self, j_batch_write_builder, j_row_type, arrow_schema: 
pa.Schema):
+    def __init__(self, j_batch_write_builder):
         self._j_batch_write_builder = j_batch_write_builder
-        self._j_row_type = j_row_type
-        self._arrow_schema = arrow_schema
 
     def overwrite(self, static_partition: Optional[dict] = None) -> 
'BatchWriteBuilder':
         if static_partition is None:
@@ -201,7 +198,7 @@ class BatchWriteBuilder(write_builder.BatchWriteBuilder):
 
     def new_write(self) -> 'BatchTableWrite':
         j_batch_table_write = self._j_batch_write_builder.newWrite()
-        return BatchTableWrite(j_batch_table_write, self._j_row_type, 
self._arrow_schema)
+        return BatchTableWrite(j_batch_table_write, 
self._j_batch_write_builder.rowType())
 
     def new_commit(self) -> 'BatchTableCommit':
         j_batch_table_commit = self._j_batch_write_builder.newCommit()
@@ -210,11 +207,11 @@ class BatchWriteBuilder(write_builder.BatchWriteBuilder):
 
 class BatchTableWrite(table_write.BatchTableWrite):
 
-    def __init__(self, j_batch_table_write, j_row_type, arrow_schema: 
pa.Schema):
+    def __init__(self, j_batch_table_write, j_row_type):
         self._j_batch_table_write = j_batch_table_write
         self._j_bytes_writer = 
get_gateway().jvm.InvocationUtil.createBytesWriter(
             j_batch_table_write, j_row_type)
-        self._arrow_schema = arrow_schema
+        self._arrow_schema = java_utils.to_arrow_schema(j_row_type)
 
     def write_arrow(self, table):
         for record_batch in table.to_reader():
diff --git a/paimon_python_java/tests/test_write_and_read.py 
b/paimon_python_java/tests/test_write_and_read.py
index b468e9f..337b9f5 100644
--- a/paimon_python_java/tests/test_write_and_read.py
+++ b/paimon_python_java/tests/test_write_and_read.py
@@ -445,3 +445,65 @@ class TableWriteReadTest(unittest.TestCase):
         df['f0'] = df['f0'].astype('int32')
         pd.testing.assert_frame_equal(
             actual_df.reset_index(drop=True), df.reset_index(drop=True))
+
+    def testProjection(self):
+        pa_schema = pa.schema([
+            ('f0', pa.int64()),
+            ('f1', pa.string()),
+            ('f2', pa.bool_()),
+            ('f3', pa.string())
+        ])
+        schema = Schema(pa_schema)
+        self.catalog.create_table('default.test_projection', schema, False)
+        table = self.catalog.get_table('default.test_projection')
+
+        # prepare data
+        data = {
+            'f0': [1, 2, 3],
+            'f1': ['a', 'b', 'c'],
+            'f2': [True, True, False],
+            'f3': ['A', 'B', 'C']
+        }
+        df = pd.DataFrame(data)
+
+        # write and commit data
+        write_builder = table.new_batch_write_builder()
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+
+        table_write.write_pandas(df)
+        commit_messages = table_write.prepare_commit()
+        table_commit.commit(commit_messages)
+
+        table_write.close()
+        table_commit.close()
+
+        # case 1: read empty
+        read_builder = table.new_read_builder().with_projection([])
+        table_scan = read_builder.new_scan()
+        table_read = read_builder.new_read()
+        splits = table_scan.plan().splits()
+        result1 = table_read.to_pandas(splits)
+        self.assertTrue(result1.empty)
+
+        # case 2: read fully
+        read_builder = table.new_read_builder().with_projection(['f0', 'f1', 
'f2', 'f3'])
+        table_scan = read_builder.new_scan()
+        table_read = read_builder.new_read()
+        splits = table_scan.plan().splits()
+        result2 = table_read.to_pandas(splits)
+        pd.testing.assert_frame_equal(
+            result2.reset_index(drop=True), df.reset_index(drop=True))
+
+        # case 3: read partially
+        read_builder = table.new_read_builder().with_projection(['f3', 'f2'])
+        table_scan = read_builder.new_scan()
+        table_read = read_builder.new_read()
+        splits = table_scan.plan().splits()
+        result3 = table_read.to_pandas(splits)
+        expected_df = pd.DataFrame({
+            'f3': ['A', 'B', 'C'],
+            'f2': [True, True, False]
+        })
+        pd.testing.assert_frame_equal(
+            result3.reset_index(drop=True), expected_df.reset_index(drop=True))
diff --git a/paimon_python_java/util/java_utils.py 
b/paimon_python_java/util/java_utils.py
index 8c4f276..ce0404a 100644
--- a/paimon_python_java/util/java_utils.py
+++ b/paimon_python_java/util/java_utils.py
@@ -91,3 +91,12 @@ def _to_j_type(name, pa_type):
         return jvm.DataTypes.STRING()
     else:
         raise ValueError(f'Found unsupported data type {str(pa_type)} for 
field {name}.')
+
+
+def to_arrow_schema(j_row_type):
+    # init arrow schema
+    schema_bytes = get_gateway().jvm.SchemaUtil.getArrowSchema(j_row_type)
+    schema_reader = pa.RecordBatchStreamReader(pa.BufferReader(schema_bytes))
+    arrow_schema = schema_reader.schema
+    schema_reader.close()
+    return arrow_schema

Reply via email to