This is an automated email from the ASF dual-hosted git repository.

yuzelin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 0eb79a8  #44 Make Split and Predicate Serializable (#45)
0eb79a8 is described below

commit 0eb79a8b38125534c0a3b08531a7161bcd45ad3a
Author: ChengHui Chen <[email protected]>
AuthorDate: Thu Mar 20 13:30:30 2025 +0800

    #44 Make Split and Predicate Serializable (#45)
---
 pypaimon/py4j/java_implementation.py | 53 ++++++++++++++++++++++--------------
 pypaimon/py4j/util/java_utils.py     | 17 ++++++++++++
 2 files changed, 49 insertions(+), 21 deletions(-)

diff --git a/pypaimon/py4j/java_implementation.py b/pypaimon/py4j/java_implementation.py
index 9f378b7..9a13037 100644
--- a/pypaimon/py4j/java_implementation.py
+++ b/pypaimon/py4j/java_implementation.py
@@ -23,6 +23,7 @@ import pyarrow as pa
 
 from pypaimon.py4j.java_gateway import get_gateway
 from pypaimon.py4j.util import java_utils, constants
+from pypaimon.py4j.util.java_utils import serialize_java_object, deserialize_java_object
 from pypaimon.api import \
     (catalog, table, read_builder, table_scan, split, row_type,
      table_read, write_builder, table_write, commit_message,
@@ -145,33 +146,41 @@ class Plan(table_scan.Plan):
         self._j_splits = j_splits
 
     def splits(self) -> List['Split']:
-        return list(map(lambda s: Split(s), self._j_splits))
+        return list(map(lambda s: self._build_single_split(s), self._j_splits))
+
+    def _build_single_split(self, j_split) -> 'Split':
+        j_split_bytes = serialize_java_object(j_split)
+        row_count = j_split.rowCount()
+        files_optional = j_split.convertToRawFiles()
+        if not files_optional.isPresent():
+            file_size = 0
+            file_paths = []
+        else:
+            files = files_optional.get()
+            file_size = sum(file.length() for file in files)
+            file_paths = [file.path() for file in files]
+        return Split(j_split_bytes, row_count, file_size, file_paths)
 
 
 class Split(split.Split):
 
-    def __init__(self, j_split):
-        self._j_split = j_split
+    def __init__(self, j_split_bytes, row_count: int, file_size: int, file_paths: List[str]):
+        self._j_split_bytes = j_split_bytes
+        self._row_count = row_count
+        self._file_size = file_size
+        self._file_paths = file_paths
 
     def to_j_split(self):
-        return self._j_split
+        return deserialize_java_object(self._j_split_bytes)
 
     def row_count(self) -> int:
-        return self._j_split.rowCount()
+        return self._row_count
 
     def file_size(self) -> int:
-        files_optional = self._j_split.convertToRawFiles()
-        if not files_optional.isPresent():
-            return 0
-        files = files_optional.get()
-        return sum(file.length() for file in files)
+        return self._file_size
 
     def file_paths(self) -> List[str]:
-        files_optional = self._j_split.convertToRawFiles()
-        if not files_optional.isPresent():
-            return []
-        files = files_optional.get()
-        return [file.path() for file in files]
+        return self._file_paths
 
 
 class TableRead(table_read.TableRead):
@@ -317,11 +326,11 @@ class BatchTableCommit(table_commit.BatchTableCommit):
 
 class Predicate(predicate.Predicate):
 
-    def __init__(self, j_predicate):
-        self._j_predicate = j_predicate
+    def __init__(self, j_predicate_bytes):
+        self._j_predicate_bytes = j_predicate_bytes
 
     def to_j_predicate(self):
-        return self._j_predicate
+        return deserialize_java_object(self._j_predicate_bytes)
 
 
 class PredicateBuilder(predicate.PredicateBuilder):
@@ -350,7 +359,7 @@ class PredicateBuilder(predicate.PredicateBuilder):
             index,
             literals
         )
-        return Predicate(j_predicate)
+        return Predicate(serialize_java_object(j_predicate))
 
     def equal(self, field: str, literal: Any) -> Predicate:
         return self._build('equal', field, [literal])
@@ -397,8 +406,10 @@ class PredicateBuilder(predicate.PredicateBuilder):
 
     def and_predicates(self, predicates: List[Predicate]) -> Predicate:
         predicates = list(map(lambda p: p.to_j_predicate(), predicates))
-        return Predicate(get_gateway().jvm.PredicationUtil.buildAnd(predicates))
+        j_predicate = get_gateway().jvm.PredicationUtil.buildAnd(predicates)
+        return Predicate(serialize_java_object(j_predicate))
 
     def or_predicates(self, predicates: List[Predicate]) -> Predicate:
         predicates = list(map(lambda p: p.to_j_predicate(), predicates))
-        return Predicate(get_gateway().jvm.PredicationUtil.buildOr(predicates))
+        j_predicate = get_gateway().jvm.PredicationUtil.buildOr(predicates)
+        return Predicate(serialize_java_object(j_predicate))
diff --git a/pypaimon/py4j/util/java_utils.py b/pypaimon/py4j/util/java_utils.py
index 0beb527..2a2aac9 100644
--- a/pypaimon/py4j/util/java_utils.py
+++ b/pypaimon/py4j/util/java_utils.py
@@ -100,3 +100,20 @@ def to_arrow_schema(j_row_type):
     arrow_schema = schema_reader.schema
     schema_reader.close()
     return arrow_schema
+
+
+def serialize_java_object(java_obj) -> bytes:
+    gateway = get_gateway()
+    util = gateway.jvm.org.apache.paimon.utils.InstantiationUtil
+    try:
+        java_bytes = util.serializeObject(java_obj)
+        return bytes(java_bytes)
+    except Exception as e:
+        raise RuntimeError(f"Java serialization failed: {e}")
+
+
+def deserialize_java_object(bytes_data):
+    gateway = get_gateway()
+    cl = get_gateway().jvm.Thread.currentThread().getContextClassLoader()
+    util = gateway.jvm.org.apache.paimon.utils.InstantiationUtil
+    return util.deserializeObject(bytes_data, cl)

Reply via email to