This is an automated email from the ASF dual-hosted git repository.
yuzelin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-python.git
The following commit(s) were added to refs/heads/main by this push:
new 0eb79a8 #44 Make Split and Predicate Serializable (#45)
0eb79a8 is described below
commit 0eb79a8b38125534c0a3b08531a7161bcd45ad3a
Author: ChengHui Chen <[email protected]>
AuthorDate: Thu Mar 20 13:30:30 2025 +0800
#44 Make Split and Predicate Serializable (#45)
---
pypaimon/py4j/java_implementation.py | 53 ++++++++++++++++++++++--------------
pypaimon/py4j/util/java_utils.py | 17 ++++++++++++
2 files changed, 49 insertions(+), 21 deletions(-)
diff --git a/pypaimon/py4j/java_implementation.py b/pypaimon/py4j/java_implementation.py
index 9f378b7..9a13037 100644
--- a/pypaimon/py4j/java_implementation.py
+++ b/pypaimon/py4j/java_implementation.py
@@ -23,6 +23,7 @@ import pyarrow as pa
from pypaimon.py4j.java_gateway import get_gateway
from pypaimon.py4j.util import java_utils, constants
+from pypaimon.py4j.util.java_utils import serialize_java_object, deserialize_java_object
from pypaimon.api import \
(catalog, table, read_builder, table_scan, split, row_type,
table_read, write_builder, table_write, commit_message,
@@ -145,33 +146,41 @@ class Plan(table_scan.Plan):
self._j_splits = j_splits
def splits(self) -> List['Split']:
- return list(map(lambda s: Split(s), self._j_splits))
+ return list(map(lambda s: self._build_single_split(s), self._j_splits))
+
+ def _build_single_split(self, j_split) -> 'Split':
+ j_split_bytes = serialize_java_object(j_split)
+ row_count = j_split.rowCount()
+ files_optional = j_split.convertToRawFiles()
+ if not files_optional.isPresent():
+ file_size = 0
+ file_paths = []
+ else:
+ files = files_optional.get()
+ file_size = sum(file.length() for file in files)
+ file_paths = [file.path() for file in files]
+ return Split(j_split_bytes, row_count, file_size, file_paths)
class Split(split.Split):
- def __init__(self, j_split):
- self._j_split = j_split
+ def __init__(self, j_split_bytes, row_count: int, file_size: int, file_paths: List[str]):
+ self._j_split_bytes = j_split_bytes
+ self._row_count = row_count
+ self._file_size = file_size
+ self._file_paths = file_paths
def to_j_split(self):
- return self._j_split
+ return deserialize_java_object(self._j_split_bytes)
def row_count(self) -> int:
- return self._j_split.rowCount()
+ return self._row_count
def file_size(self) -> int:
- files_optional = self._j_split.convertToRawFiles()
- if not files_optional.isPresent():
- return 0
- files = files_optional.get()
- return sum(file.length() for file in files)
+ return self._file_size
def file_paths(self) -> List[str]:
- files_optional = self._j_split.convertToRawFiles()
- if not files_optional.isPresent():
- return []
- files = files_optional.get()
- return [file.path() for file in files]
+ return self._file_paths
class TableRead(table_read.TableRead):
@@ -317,11 +326,11 @@ class BatchTableCommit(table_commit.BatchTableCommit):
class Predicate(predicate.Predicate):
- def __init__(self, j_predicate):
- self._j_predicate = j_predicate
+ def __init__(self, j_predicate_bytes):
+ self._j_predicate_bytes = j_predicate_bytes
def to_j_predicate(self):
- return self._j_predicate
+ return deserialize_java_object(self._j_predicate_bytes)
class PredicateBuilder(predicate.PredicateBuilder):
@@ -350,7 +359,7 @@ class PredicateBuilder(predicate.PredicateBuilder):
index,
literals
)
- return Predicate(j_predicate)
+ return Predicate(serialize_java_object(j_predicate))
def equal(self, field: str, literal: Any) -> Predicate:
return self._build('equal', field, [literal])
@@ -397,8 +406,10 @@ class PredicateBuilder(predicate.PredicateBuilder):
def and_predicates(self, predicates: List[Predicate]) -> Predicate:
predicates = list(map(lambda p: p.to_j_predicate(), predicates))
- return Predicate(get_gateway().jvm.PredicationUtil.buildAnd(predicates))
+ j_predicate = get_gateway().jvm.PredicationUtil.buildAnd(predicates)
+ return Predicate(serialize_java_object(j_predicate))
def or_predicates(self, predicates: List[Predicate]) -> Predicate:
predicates = list(map(lambda p: p.to_j_predicate(), predicates))
- return Predicate(get_gateway().jvm.PredicationUtil.buildOr(predicates))
+ j_predicate = get_gateway().jvm.PredicationUtil.buildOr(predicates)
+ return Predicate(serialize_java_object(j_predicate))
diff --git a/pypaimon/py4j/util/java_utils.py b/pypaimon/py4j/util/java_utils.py
index 0beb527..2a2aac9 100644
--- a/pypaimon/py4j/util/java_utils.py
+++ b/pypaimon/py4j/util/java_utils.py
@@ -100,3 +100,20 @@ def to_arrow_schema(j_row_type):
arrow_schema = schema_reader.schema
schema_reader.close()
return arrow_schema
+
+
+def serialize_java_object(java_obj) -> bytes:
+ gateway = get_gateway()
+ util = gateway.jvm.org.apache.paimon.utils.InstantiationUtil
+ try:
+ java_bytes = util.serializeObject(java_obj)
+ return bytes(java_bytes)
+ except Exception as e:
+ raise RuntimeError(f"Java serialization failed: {e}")
+
+
+def deserialize_java_object(bytes_data):
+ gateway = get_gateway()
+ cl = get_gateway().jvm.Thread.currentThread().getContextClassLoader()
+ util = gateway.jvm.org.apache.paimon.utils.InstantiationUtil
+ return util.deserializeObject(bytes_data, cl)