This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new 77f130bde0 [Python] Predicate Push Down for Scan / Read (#6166)
77f130bde0 is described below
commit 77f130bde03381e3b26b86e4f0e95cff5c1dc622
Author: ChengHui Chen <[email protected]>
AuthorDate: Fri Aug 29 16:34:24 2025 +0800
[Python] Predicate Push Down for Scan / Read (#6166)
---
paimon-python/pypaimon/common/predicate.py | 103 ++++++++++++-
.../pypaimon/manifest/manifest_file_manager.py | 32 ++--
.../pypaimon/manifest/manifest_list_manager.py | 12 +-
.../pypaimon/manifest/schema/simple_stats.py | 6 +-
paimon-python/pypaimon/read/push_down_utils.py | 72 +++++++++
.../pypaimon/read/reader/format_avro_reader.py | 24 +--
.../pypaimon/read/reader/format_pyarrow_reader.py | 50 +------
paimon-python/pypaimon/read/split_read.py | 12 +-
paimon-python/pypaimon/read/table_read.py | 35 ++++-
paimon-python/pypaimon/read/table_scan.py | 161 +++++++--------------
.../pypaimon/tests/predicate_push_down_test.py | 151 +++++++++++++++++++
paimon-python/pypaimon/write/file_store_commit.py | 6 +-
paimon-python/pypaimon/write/writer/data_writer.py | 34 ++---
13 files changed, 471 insertions(+), 227 deletions(-)
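For context, a minimal sketch of how this push-down path is exercised from the Python API, adapted from the test added in this commit. The warehouse path is illustrative, the table is the one created in the new test, and the `new_read()` / `to_arrow()` calls are assumed from the existing pypaimon ReadBuilder API rather than shown in this diff:

    from pypaimon.catalog.catalog_factory import CatalogFactory

    catalog = CatalogFactory.create({'warehouse': '/tmp/warehouse'})  # illustrative path
    table = catalog.get_table('default.test_pk_filter')              # table created in the new test

    # (dt1 startswith 'p1' or dt1 in ['p2']) and dt2 == 2
    builder = table.new_read_builder().new_predicate_builder()
    predicate = builder.and_predicates([
        builder.or_predicates([builder.startswith('dt1', 'p1'), builder.is_in('dt1', ['p2'])]),
        builder.equal('dt2', 2),
    ])

    read_builder = table.new_read_builder().with_filter(predicate)
    splits = read_builder.new_scan().plan().splits()    # partition/stats filtering prunes files here
    result = read_builder.new_read().to_arrow(splits)   # remaining predicates reach the format readers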
diff --git a/paimon-python/pypaimon/common/predicate.py b/paimon-python/pypaimon/common/predicate.py
index ee13aca99b..ba56713032 100644
--- a/paimon-python/pypaimon/common/predicate.py
+++ b/paimon-python/pypaimon/common/predicate.py
@@ -19,7 +19,7 @@ from __future__ import annotations
from dataclasses import dataclass
from functools import reduce
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional
import pyarrow
from pyarrow import compute as pyarrow_compute
@@ -82,6 +82,107 @@ class Predicate:
else:
raise ValueError("Unsupported predicate method:
{}".format(self.method))
+ def test_by_value(self, value: Any) -> bool:
+ if self.method == 'and':
+ return all(p.test_by_value(value) for p in self.literals)
+ if self.method == 'or':
+ t = any(p.test_by_value(value) for p in self.literals)
+ return t
+
+ if self.method == 'equal':
+ return value == self.literals[0]
+ if self.method == 'notEqual':
+ return value != self.literals[0]
+ if self.method == 'lessThan':
+ return value < self.literals[0]
+ if self.method == 'lessOrEqual':
+ return value <= self.literals[0]
+ if self.method == 'greaterThan':
+ return value > self.literals[0]
+ if self.method == 'greaterOrEqual':
+ return value >= self.literals[0]
+ if self.method == 'isNull':
+ return value is None
+ if self.method == 'isNotNull':
+ return value is not None
+ if self.method == 'startsWith':
+ if not isinstance(value, str):
+ return False
+ return value.startswith(self.literals[0])
+ if self.method == 'endsWith':
+ if not isinstance(value, str):
+ return False
+ return value.endswith(self.literals[0])
+ if self.method == 'contains':
+ if not isinstance(value, str):
+ return False
+ return self.literals[0] in value
+ if self.method == 'in':
+ return value in self.literals
+ if self.method == 'notIn':
+ return value not in self.literals
+ if self.method == 'between':
+ return self.literals[0] <= value <= self.literals[1]
+
+ raise ValueError("Unsupported predicate method:
{}".format(self.method))
+
+ def test_by_stats(self, stat: Dict) -> bool:
+ if self.method == 'and':
+ return all(p.test_by_stats(stat) for p in self.literals)
+ if self.method == 'or':
+ t = any(p.test_by_stats(stat) for p in self.literals)
+ return t
+
+ null_count = stat["null_counts"][self.field]
+ row_count = stat["row_count"]
+
+ if self.method == 'isNull':
+ return null_count is not None and null_count > 0
+ if self.method == 'isNotNull':
+ return null_count is None or row_count is None or null_count < row_count
+
+ min_value = stat["min_values"][self.field]
+ max_value = stat["max_values"][self.field]
+
+ if min_value is None or max_value is None or (null_count is not None and null_count == row_count):
+ return False
+
+ if self.method == 'equal':
+ return min_value <= self.literals[0] <= max_value
+ if self.method == 'notEqual':
+ return not (min_value == self.literals[0] == max_value)
+ if self.method == 'lessThan':
+ return self.literals[0] > min_value
+ if self.method == 'lessOrEqual':
+ return self.literals[0] >= min_value
+ if self.method == 'greaterThan':
+ return self.literals[0] < max_value
+ if self.method == 'greaterOrEqual':
+ return self.literals[0] <= max_value
+ if self.method == 'startsWith':
+ if not isinstance(min_value, str) or not isinstance(max_value, str):
+ raise RuntimeError("startsWith predicate on non-str field")
+ return ((min_value.startswith(self.literals[0]) or min_value < self.literals[0])
+ and (max_value.startswith(self.literals[0]) or max_value > self.literals[0]))
+ if self.method == 'endsWith':
+ return True
+ if self.method == 'contains':
+ return True
+ if self.method == 'in':
+ for literal in self.literals:
+ if min_value <= literal <= max_value:
+ return True
+ return False
+ if self.method == 'notIn':
+ for literal in self.literals:
+ if min_value == literal == max_value:
+ return False
+ return True
+ if self.method == 'between':
+ return self.literals[0] <= max_value and self.literals[1] >= min_value
+ else:
+ raise ValueError("Unsupported predicate method:
{}".format(self.method))
+
def to_arrow(self) -> pyarrow_compute.Expression | bool:
if self.method == 'equal':
return pyarrow_dataset.field(self.field) == self.literals[0]
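As a quick illustration of the two evaluation paths added above, here is a hand-built Predicate; the field, index, and literal values are made up for the sketch, and `index` is not used by these methods:

    from pypaimon.common.predicate import Predicate

    # dt2 == 2
    p = Predicate(method='equal', index=0, field='dt2', literals=[2])

    p.test_by_value(2)   # True  -> a partition with dt2 == 2 is kept
    p.test_by_value(1)   # False -> the partition is pruned

    # File-level stats: per-field min/max/null counts plus the file's row count
    stats = {
        'min_values': {'dt2': 1},
        'max_values': {'dt2': 3},
        'null_counts': {'dt2': 0},
        'row_count': 100,
    }
    p.test_by_stats(stats)   # True -> 2 falls inside [1, 3], so the file cannot be skipped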
diff --git a/paimon-python/pypaimon/manifest/manifest_file_manager.py b/paimon-python/pypaimon/manifest/manifest_file_manager.py
index 7c46b368d2..7c97f7b0ca 100644
--- a/paimon-python/pypaimon/manifest/manifest_file_manager.py
+++ b/paimon-python/pypaimon/manifest/manifest_file_manager.py
@@ -55,19 +55,19 @@ class ManifestFileManager:
file_dict = dict(record['_FILE'])
key_dict = dict(file_dict['_KEY_STATS'])
key_stats = SimpleStats(
- min_value=BinaryRowDeserializer.from_bytes(key_dict['_MIN_VALUES'],
- self.trimmed_primary_key_fields),
- max_value=BinaryRowDeserializer.from_bytes(key_dict['_MAX_VALUES'],
- self.trimmed_primary_key_fields),
- null_count=key_dict['_NULL_COUNTS'],
+ min_values=BinaryRowDeserializer.from_bytes(key_dict['_MIN_VALUES'],
+ self.trimmed_primary_key_fields),
+ max_values=BinaryRowDeserializer.from_bytes(key_dict['_MAX_VALUES'],
+ self.trimmed_primary_key_fields),
+ null_counts=key_dict['_NULL_COUNTS'],
)
value_dict = dict(file_dict['_VALUE_STATS'])
value_stats = SimpleStats(
- min_value=BinaryRowDeserializer.from_bytes(value_dict['_MIN_VALUES'],
- self.table.table_schema.fields),
- max_value=BinaryRowDeserializer.from_bytes(value_dict['_MAX_VALUES'],
- self.table.table_schema.fields),
- null_count=value_dict['_NULL_COUNTS'],
+ min_values=BinaryRowDeserializer.from_bytes(value_dict['_MIN_VALUES'],
+ self.table.table_schema.fields),
+ max_values=BinaryRowDeserializer.from_bytes(value_dict['_MAX_VALUES'],
+ self.table.table_schema.fields),
+ null_counts=value_dict['_NULL_COUNTS'],
)
file_meta = DataFileMeta(
file_name=file_dict['_FILE_NAME'],
@@ -118,14 +118,14 @@ class ManifestFileManager:
"_MIN_KEY": BinaryRowSerializer.to_bytes(file.min_key),
"_MAX_KEY": BinaryRowSerializer.to_bytes(file.max_key),
"_KEY_STATS": {
- "_MIN_VALUES":
BinaryRowSerializer.to_bytes(file.key_stats.min_value),
- "_MAX_VALUES":
BinaryRowSerializer.to_bytes(file.key_stats.max_value),
- "_NULL_COUNTS": file.key_stats.null_count,
+ "_MIN_VALUES":
BinaryRowSerializer.to_bytes(file.key_stats.min_values),
+ "_MAX_VALUES":
BinaryRowSerializer.to_bytes(file.key_stats.max_values),
+ "_NULL_COUNTS": file.key_stats.null_counts,
},
"_VALUE_STATS": {
- "_MIN_VALUES":
BinaryRowSerializer.to_bytes(file.value_stats.min_value),
- "_MAX_VALUES":
BinaryRowSerializer.to_bytes(file.value_stats.max_value),
- "_NULL_COUNTS": file.value_stats.null_count,
+ "_MIN_VALUES":
BinaryRowSerializer.to_bytes(file.value_stats.min_values),
+ "_MAX_VALUES":
BinaryRowSerializer.to_bytes(file.value_stats.max_values),
+ "_NULL_COUNTS": file.value_stats.null_counts,
},
"_MIN_SEQUENCE_NUMBER": file.min_sequence_number,
"_MAX_SEQUENCE_NUMBER": file.max_sequence_number,
diff --git a/paimon-python/pypaimon/manifest/manifest_list_manager.py b/paimon-python/pypaimon/manifest/manifest_list_manager.py
index 65fd2b21ac..dc9d5db44d 100644
--- a/paimon-python/pypaimon/manifest/manifest_list_manager.py
+++ b/paimon-python/pypaimon/manifest/manifest_list_manager.py
@@ -58,15 +58,15 @@ class ManifestListManager:
for record in reader:
stats_dict = dict(record['_PARTITION_STATS'])
partition_stats = SimpleStats(
- min_value=BinaryRowDeserializer.from_bytes(
+ min_values=BinaryRowDeserializer.from_bytes(
stats_dict['_MIN_VALUES'],
self.table.table_schema.get_partition_key_fields()
),
- max_value=BinaryRowDeserializer.from_bytes(
+ max_values=BinaryRowDeserializer.from_bytes(
stats_dict['_MAX_VALUES'],
self.table.table_schema.get_partition_key_fields()
),
- null_count=stats_dict['_NULL_COUNTS'],
+ null_counts=stats_dict['_NULL_COUNTS'],
)
manifest_file_meta = ManifestFileMeta(
file_name=record['_FILE_NAME'],
@@ -90,9 +90,9 @@ class ManifestListManager:
"_NUM_ADDED_FILES": meta.num_added_files,
"_NUM_DELETED_FILES": meta.num_deleted_files,
"_PARTITION_STATS": {
- "_MIN_VALUES":
BinaryRowSerializer.to_bytes(meta.partition_stats.min_value),
- "_MAX_VALUES":
BinaryRowSerializer.to_bytes(meta.partition_stats.max_value),
- "_NULL_COUNTS": meta.partition_stats.null_count,
+ "_MIN_VALUES":
BinaryRowSerializer.to_bytes(meta.partition_stats.min_values),
+ "_MAX_VALUES":
BinaryRowSerializer.to_bytes(meta.partition_stats.max_values),
+ "_NULL_COUNTS": meta.partition_stats.null_counts,
},
"_SCHEMA_ID": meta.schema_id,
}
diff --git a/paimon-python/pypaimon/manifest/schema/simple_stats.py b/paimon-python/pypaimon/manifest/schema/simple_stats.py
index 4a73d3eee4..55b2163e76 100644
--- a/paimon-python/pypaimon/manifest/schema/simple_stats.py
+++ b/paimon-python/pypaimon/manifest/schema/simple_stats.py
@@ -24,9 +24,9 @@ from pypaimon.table.row.binary_row import BinaryRow
@dataclass
class SimpleStats:
- min_value: BinaryRow
- max_value: BinaryRow
- null_count: Optional[List[int]]
+ min_values: BinaryRow
+ max_values: BinaryRow
+ null_counts: Optional[List[int]]
SIMPLE_STATS_SCHEMA = {
diff --git a/paimon-python/pypaimon/read/push_down_utils.py b/paimon-python/pypaimon/read/push_down_utils.py
new file mode 100644
index 0000000000..31e66973c6
--- /dev/null
+++ b/paimon-python/pypaimon/read/push_down_utils.py
@@ -0,0 +1,72 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+from typing import Dict, List, Set
+
+from pypaimon.common.predicate import Predicate
+
+
+def extract_predicate_to_list(result: list, input_predicate: 'Predicate', keys: List[str]):
+ if not input_predicate or not keys:
+ return
+
+ if input_predicate.method == 'and':
+ for sub_predicate in input_predicate.literals:
+ extract_predicate_to_list(result, sub_predicate, keys)
+ return
+ elif input_predicate.method == 'or':
+ # condition: involved keys all belong to primary keys
+ involved_fields = _get_all_fields(input_predicate)
+ if involved_fields and involved_fields.issubset(keys):
+ result.append(input_predicate)
+ return
+
+ if input_predicate.field in keys:
+ result.append(input_predicate)
+
+
+def _get_all_fields(predicate: 'Predicate') -> Set[str]:
+ if predicate.field is not None:
+ return {predicate.field}
+ involved_fields = set()
+ if predicate.literals:
+ for sub_predicate in predicate.literals:
+ involved_fields.update(_get_all_fields(sub_predicate))
+ return involved_fields
+
+
+def extract_predicate_to_dict(result: Dict, input_predicate: 'Predicate', keys: List[str]):
+ if not input_predicate or not keys:
+ return
+
+ if input_predicate.method == 'and':
+ for sub_predicate in input_predicate.literals:
+ extract_predicate_to_dict(result, sub_predicate, keys)
+ return
+ elif input_predicate.method == 'or':
+ # ensure no recursive and/or
+ if not input_predicate.literals or any(p.field is None for p in input_predicate.literals):
+ return
+ # condition: only one key for 'or', and the key belongs to keys
+ involved_fields = {p.field for p in input_predicate.literals}
+ if len(involved_fields) == 1 and (field := involved_fields.pop()) in keys:
+ result[field].append(input_predicate)
+ return
+
+ if input_predicate.field in keys:
+ result[input_predicate.field].append(input_predicate)
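A short sketch of what these two helpers produce, with Predicate instances built by hand for illustration (field names, literals, and index values are made up):

    from collections import defaultdict

    from pypaimon.common.predicate import Predicate
    from pypaimon.read.push_down_utils import (extract_predicate_to_dict,
                                               extract_predicate_to_list)

    eq_dt = Predicate(method='equal', index=0, field='dt', literals=['p1'])
    gt_k = Predicate(method='greaterThan', index=1, field='k', literals=[3])
    conj = Predicate(method='and', index=None, field=None, literals=[eq_dt, gt_k])

    # Keep only the conjuncts that touch the given keys.
    pk_conditions = []
    extract_predicate_to_list(pk_conditions, conj, ['k'])          # pk_conditions == [gt_k]

    # Group single-field conjuncts by field name; the dict must default to lists.
    partition_conditions = defaultdict(list)
    extract_predicate_to_dict(partition_conditions, conj, ['dt'])  # {'dt': [eq_dt]}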
diff --git a/paimon-python/pypaimon/read/reader/format_avro_reader.py b/paimon-python/pypaimon/read/reader/format_avro_reader.py
index 83e90606e7..4ce7c04ed4 100644
--- a/paimon-python/pypaimon/read/reader/format_avro_reader.py
+++ b/paimon-python/pypaimon/read/reader/format_avro_reader.py
@@ -20,11 +20,11 @@ from typing import List, Optional
import fastavro
import pyarrow as pa
+import pyarrow.compute as pc
import pyarrow.dataset as ds
from pyarrow import RecordBatch
from pypaimon.common.file_io import FileIO
-from pypaimon.common.predicate import Predicate
from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader
from pypaimon.schema.data_types import DataField, PyarrowFieldParser
@@ -35,26 +35,18 @@ class FormatAvroReader(RecordBatchReader):
provided predicate and projection, and converts Avro records to
RecordBatch format.
"""
- def __init__(self, file_io: FileIO, file_path: str, primary_keys: List[str],
- fields: List[str], full_fields: List[DataField], predicate: Predicate, batch_size: int = 4096):
+ def __init__(self, file_io: FileIO, file_path: str, read_fields: List[str], full_fields: List[DataField],
+ push_down_predicate: pc.Expression | bool, batch_size: int = 4096):
self._file = file_io.filesystem.open_input_file(file_path)
self._avro_reader = fastavro.reader(self._file)
self._batch_size = batch_size
- self._primary_keys = primary_keys
+ self._push_down_predicate = push_down_predicate
- self._fields = fields
+ self._fields = read_fields
full_fields_map = {field.name: field for field in full_fields}
- projected_data_fields = [full_fields_map[name] for name in fields]
+ projected_data_fields = [full_fields_map[name] for name in read_fields]
self._schema = PyarrowFieldParser.from_paimon_schema(projected_data_fields)
- if primary_keys:
- # TODO: utilize predicate to improve performance
- predicate = None
- if predicate is not None:
- self._predicate = predicate.to_arrow()
- else:
- self._predicate = None
-
def read_arrow_batch(self) -> Optional[RecordBatch]:
pydict_data = {name: [] for name in self._fields}
records_in_batch = 0
@@ -68,12 +60,12 @@ class FormatAvroReader(RecordBatchReader):
if records_in_batch == 0:
return None
- if self._predicate is None:
+ if self._push_down_predicate is None:
return pa.RecordBatch.from_pydict(pydict_data, self._schema)
else:
pa_batch = pa.Table.from_pydict(pydict_data, self._schema)
dataset = ds.InMemoryDataset(pa_batch)
- scanner = dataset.scanner(filter=self._predicate)
+ scanner = dataset.scanner(filter=self._push_down_predicate)
combine_chunks = scanner.to_table().combine_chunks()
if combine_chunks.num_rows > 0:
return combine_chunks.to_batches()[0]
diff --git a/paimon-python/pypaimon/read/reader/format_pyarrow_reader.py b/paimon-python/pypaimon/read/reader/format_pyarrow_reader.py
index b01f8113b7..ecef589391 100644
--- a/paimon-python/pypaimon/read/reader/format_pyarrow_reader.py
+++ b/paimon-python/pypaimon/read/reader/format_pyarrow_reader.py
@@ -18,11 +18,11 @@
from typing import List, Optional
+import pyarrow.compute as pc
import pyarrow.dataset as ds
from pyarrow import RecordBatch
from pypaimon.common.file_io import FileIO
-from pypaimon.common.predicate import Predicate
from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader
@@ -32,19 +32,12 @@ class FormatPyArrowReader(RecordBatchReader):
and filters it based on the provided predicate and projection.
"""
- def __init__(self, file_io: FileIO, file_format: str, file_path: str, primary_keys: List[str],
- fields: List[str], predicate: Predicate, batch_size: int = 4096):
-
- if primary_keys:
- # TODO: utilize predicate to improve performance
- predicate = None
- if predicate is not None:
- predicate = predicate.to_arrow()
-
+ def __init__(self, file_io: FileIO, file_format: str, file_path: str, read_fields: List[str],
+ push_down_predicate: pc.Expression | bool, batch_size: int = 4096):
self.dataset = ds.dataset(file_path, format=file_format, filesystem=file_io.filesystem)
self.reader = self.dataset.scanner(
- columns=fields,
- filter=predicate,
+ columns=read_fields,
+ filter=push_down_predicate,
batch_size=batch_size
).to_reader()
@@ -58,36 +51,3 @@ class FormatPyArrowReader(RecordBatchReader):
if self.reader is not None:
self.reader.close()
self.reader = None
-
-
-def _filter_predicate_by_primary_keys(predicate: Predicate, primary_keys):
- """
- Filter out predicates that are not related to primary key fields.
- """
- if predicate is None or primary_keys is None:
- return predicate
-
- if predicate.method in ['and', 'or']:
- filtered_literals = []
- for literal in predicate.literals:
- filtered = _filter_predicate_by_primary_keys(literal, primary_keys)
- if filtered is not None:
- filtered_literals.append(filtered)
-
- if not filtered_literals:
- return None
-
- if len(filtered_literals) == 1:
- return filtered_literals[0]
-
- return Predicate(
- method=predicate.method,
- index=predicate.index,
- field=predicate.field,
- literals=filtered_literals
- )
-
- if predicate.field in primary_keys:
- return predicate
- else:
- return None
diff --git a/paimon-python/pypaimon/read/split_read.py b/paimon-python/pypaimon/read/split_read.py
index 99f8a4da21..1fe0a89d0e 100644
--- a/paimon-python/pypaimon/read/split_read.py
+++ b/paimon-python/pypaimon/read/split_read.py
@@ -49,11 +49,13 @@ NULL_FIELD_INDEX = -1
class SplitRead(ABC):
"""Abstract base class for split reading operations."""
- def __init__(self, table, predicate: Optional[Predicate], read_type: List[DataField], split: Split):
+ def __init__(self, table, predicate: Optional[Predicate], push_down_predicate,
+ read_type: List[DataField], split: Split):
from pypaimon.table.file_store_table import FileStoreTable
self.table: FileStoreTable = table
self.predicate = predicate
+ self.push_down_predicate = push_down_predicate
self.split = split
self.value_arity = len(read_type)
@@ -72,11 +74,11 @@ class SplitRead(ABC):
format_reader: RecordBatchReader
if file_format == "avro":
- format_reader = FormatAvroReader(self.table.file_io, file_path, self.table.primary_keys,
- self._get_final_read_data_fields(), self.read_fields, self.predicate)
+ format_reader = FormatAvroReader(self.table.file_io, file_path, self._get_final_read_data_fields(),
+ self.read_fields, self.push_down_predicate)
elif file_format == "parquet" or file_format == "orc":
- format_reader = FormatPyArrowReader(self.table.file_io, file_format, file_path, self.table.primary_keys,
- self._get_final_read_data_fields(), self.predicate)
+ format_reader = FormatPyArrowReader(self.table.file_io, file_format, file_path,
+ self._get_final_read_data_fields(), self.push_down_predicate)
else:
raise ValueError(f"Unexpected file format: {file_format}")
diff --git a/paimon-python/pypaimon/read/table_read.py b/paimon-python/pypaimon/read/table_read.py
index b8a28c19d1..621549d832 100644
--- a/paimon-python/pypaimon/read/table_read.py
+++ b/paimon-python/pypaimon/read/table_read.py
@@ -19,8 +19,11 @@ from typing import Iterator, List, Optional
import pandas
import pyarrow
+import pyarrow.compute as pc
from pypaimon.common.predicate import Predicate
+from pypaimon.common.predicate_builder import PredicateBuilder
+from pypaimon.read.push_down_utils import extract_predicate_to_list
from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader
from pypaimon.read.split import Split
from pypaimon.read.split_read import (MergeFileSplitRead, RawFileSplitRead,
@@ -37,6 +40,7 @@ class TableRead:
self.table: FileStoreTable = table
self.predicate = predicate
+ self.push_down_predicate = self._push_down_predicate()
self.read_type = read_type
def to_iterator(self, splits: List[Split]) -> Iterator:
@@ -78,12 +82,12 @@ class TableRead:
row_tuple_chunk.append(row.row_tuple[row.offset: row.offset + row.arity])
if len(row_tuple_chunk) >= chunk_size:
- batch = convert_rows_to_arrow_batch(row_tuple_chunk, schema)
+ batch = self.convert_rows_to_arrow_batch(row_tuple_chunk, schema)
yield batch
row_tuple_chunk = []
if row_tuple_chunk:
- batch = convert_rows_to_arrow_batch(row_tuple_chunk, schema)
+ batch = self.convert_rows_to_arrow_batch(row_tuple_chunk, schema)
yield batch
finally:
reader.close()
@@ -105,11 +109,27 @@ class TableRead:
return ray.data.from_arrow(self.to_arrow(splits))
+ def _push_down_predicate(self) -> pc.Expression | bool:
+ if self.predicate is None:
+ return None
+ elif self.table.is_primary_key_table:
+ result = []
+ extract_predicate_to_list(result, self.predicate, self.table.primary_keys)
+ if result:
+ # the field index is unused for arrow field
+ pk_predicates = (PredicateBuilder(self.table.fields).and_predicates(result)).to_arrow()
+ return pk_predicates
+ else:
+ return None
+ else:
+ return self.predicate.to_arrow()
+
def _create_split_read(self, split: Split) -> SplitRead:
if self.table.is_primary_key_table and not split.raw_convertible:
return MergeFileSplitRead(
table=self.table,
predicate=self.predicate,
+ push_down_predicate=self.push_down_predicate,
read_type=self.read_type,
split=split
)
@@ -117,12 +137,13 @@ class TableRead:
return RawFileSplitRead(
table=self.table,
predicate=self.predicate,
+ push_down_predicate=self.push_down_predicate,
read_type=self.read_type,
split=split
)
-
-def convert_rows_to_arrow_batch(row_tuples: List[tuple], schema: pyarrow.Schema) -> pyarrow.RecordBatch:
- columns_data = zip(*row_tuples)
- pydict = {name: list(column) for name, column in zip(schema.names, columns_data)}
- return pyarrow.RecordBatch.from_pydict(pydict, schema=schema)
+ @staticmethod
+ def convert_rows_to_arrow_batch(row_tuples: List[tuple], schema: pyarrow.Schema) -> pyarrow.RecordBatch:
+ columns_data = zip(*row_tuples)
+ pydict = {name: list(column) for name, column in zip(schema.names, columns_data)}
+ return pyarrow.RecordBatch.from_pydict(pydict, schema=schema)
diff --git a/paimon-python/pypaimon/read/table_scan.py b/paimon-python/pypaimon/read/table_scan.py
index d89ddd0bcb..1c2c4f33dc 100644
--- a/paimon-python/pypaimon/read/table_scan.py
+++ b/paimon-python/pypaimon/read/table_scan.py
@@ -20,12 +20,15 @@ from collections import defaultdict
from typing import Callable, List, Optional
from pypaimon.common.predicate import Predicate
+from pypaimon.common.predicate_builder import PredicateBuilder
from pypaimon.manifest.manifest_file_manager import ManifestFileManager
from pypaimon.manifest.manifest_list_manager import ManifestListManager
from pypaimon.manifest.schema.data_file_meta import DataFileMeta
from pypaimon.manifest.schema.manifest_entry import ManifestEntry
from pypaimon.read.interval_partition import IntervalPartition, SortedRun
from pypaimon.read.plan import Plan
+from pypaimon.read.push_down_utils import (extract_predicate_to_dict,
+ extract_predicate_to_list)
from pypaimon.read.split import Split
from pypaimon.schema.data_types import DataField
from pypaimon.snapshot.snapshot_manager import SnapshotManager
@@ -49,15 +52,23 @@ class TableScan:
self.manifest_list_manager = ManifestListManager(table)
self.manifest_file_manager = ManifestFileManager(table)
- self.partition_conditions = self._extract_partition_conditions()
+ pk_conditions = []
+ trimmed_pk = [field.name for field in self.table.table_schema.get_trimmed_primary_key_fields()]
+ extract_predicate_to_list(pk_conditions, self.predicate, trimmed_pk)
+ self.primary_key_predicate = PredicateBuilder(self.table.fields).and_predicates(pk_conditions)
+
+ partition_conditions = defaultdict(list)
+ extract_predicate_to_dict(partition_conditions, self.predicate, self.table.partition_keys)
+ self.partition_key_predicate = partition_conditions
+
self.target_split_size = 128 * 1024 * 1024
self.open_file_cost = 4 * 1024 * 1024
self.idx_of_this_subtask = None
self.number_of_para_subtasks = None
- self.only_read_real_buckets = True if self.table.options.get('bucket',
- -1) == BucketMode.POSTPONE_BUCKET.value else False
+ self.only_read_real_buckets = True \
+ if (self.table.options.get('bucket', -1) == BucketMode.POSTPONE_BUCKET.value) else False
def plan(self) -> Plan:
latest_snapshot = self.snapshot_manager.get_latest_snapshot()
@@ -129,7 +140,7 @@ class TableScan:
filtered_files = []
for file_entry in file_entries:
- if self.partition_conditions and not self._filter_by_partition(file_entry):
+ if self.partition_key_predicate and not self._filter_by_partition(file_entry):
continue
if not self._filter_by_stats(file_entry):
continue
@@ -138,98 +149,31 @@ class TableScan:
return filtered_files
def _filter_by_partition(self, file_entry: ManifestEntry) -> bool:
- # TODO: refactor with a better solution
partition_dict = file_entry.partition.to_dict()
- for field_name, condition in self.partition_conditions.items():
+ for field_name, conditions in self.partition_key_predicate.items():
partition_value = partition_dict[field_name]
- if condition['op'] == '=':
- if str(partition_value) != str(condition['value']):
- return False
- elif condition['op'] == 'in':
- if str(partition_value) not in [str(v) for v in condition['values']]:
- return False
- elif condition['op'] == 'notIn':
- if str(partition_value) in [str(v) for v in condition['values']]:
- return False
- elif condition['op'] == '>':
- if partition_value <= condition['values']:
- return False
- elif condition['op'] == '>=':
- if partition_value < condition['values']:
- return False
- elif condition['op'] == '<':
- if partition_value >= condition['values']:
- return False
- elif condition['op'] == '<=':
- if partition_value > condition['values']:
+ for predicate in conditions:
+ if not predicate.test_by_value(partition_value):
return False
return True
def _filter_by_stats(self, file_entry: ManifestEntry) -> bool:
- # TODO: real support for filtering by stat
- return True
-
- def _extract_partition_conditions(self) -> dict:
- if not self.predicate or not self.table.partition_keys:
- return {}
-
- conditions = {}
- self._extract_conditions_from_predicate(self.predicate, conditions, self.table.partition_keys)
- return conditions
-
- def _extract_conditions_from_predicate(self, predicate: 'Predicate', conditions: dict,
- partition_keys: List[str]):
- if predicate.method == 'and':
- for sub_predicate in predicate.literals:
- self._extract_conditions_from_predicate(sub_predicate, conditions, partition_keys)
- return
- elif predicate.method == 'or':
- all_partition_conditions = True
- for sub_predicate in predicate.literals:
- if sub_predicate.field not in partition_keys:
- all_partition_conditions = False
- break
- if all_partition_conditions:
- for sub_predicate in predicate.literals:
- self._extract_conditions_from_predicate(sub_predicate, conditions, partition_keys)
- return
-
- if predicate.field in partition_keys:
- if predicate.method == 'equal':
- conditions[predicate.field] = {
- 'op': '=',
- 'value': predicate.literals[0] if predicate.literals else None
- }
- elif predicate.method == 'in':
- conditions[predicate.field] = {
- 'op': 'in',
- 'values': predicate.literals if predicate.literals else []
- }
- elif predicate.method == 'notIn':
- conditions[predicate.field] = {
- 'op': 'notIn',
- 'values': predicate.literals if predicate.literals else []
- }
- elif predicate.method == 'greaterThan':
- conditions[predicate.field] = {
- 'op': '>',
- 'value': predicate.literals[0] if predicate.literals else None
- }
- elif predicate.method == 'greaterOrEqual':
- conditions[predicate.field] = {
- 'op': '>=',
- 'value': predicate.literals[0] if predicate.literals else None
- }
- elif predicate.method == 'lessThan':
- conditions[predicate.field] = {
- 'op': '<',
- 'value': predicate.literals[0] if predicate.literals else None
- }
- elif predicate.method == 'lessOrEqual':
- conditions[predicate.field] = {
- 'op': '<=',
- 'value': predicate.literals[0] if predicate.literals else None
- }
+ if file_entry.kind != 0:
+ return False
+ if self.table.is_primary_key_table:
+ predicate = self.primary_key_predicate
+ stats = file_entry.file.key_stats
+ else:
+ predicate = self.predicate
+ stats = file_entry.file.value_stats
+ return predicate.test_by_stats({
+ "min_values": stats.min_values.to_dict(),
+ "max_values": stats.max_values.to_dict(),
+ "null_counts": {
+ stats.min_values.fields[i].name: stats.null_counts[i] for i in range(len(stats.min_values.fields))
+ },
+ "row_count": file_entry.file.row_count,
+ })
def _create_append_only_splits(self, file_entries: List[ManifestEntry]) -> List['Split']:
if not file_entries:
@@ -240,7 +184,7 @@ class TableScan:
def weight_func(f: DataFileMeta) -> int:
return max(f.file_size, self.open_file_cost)
- packed_files: List[List[DataFileMeta]] = _pack_for_ordered(data_files, weight_func, self.target_split_size)
+ packed_files: List[List[DataFileMeta]] = self._pack_for_ordered(data_files, weight_func, self.target_split_size)
return self._build_split_from_pack(packed_files, file_entries, False)
def _create_primary_key_splits(self, file_entries: List[ManifestEntry]) -> List['Split']:
@@ -257,7 +201,8 @@ class TableScan:
def weight_func(fl: List[DataFileMeta]) -> int:
return max(sum(f.file_size for f in fl), self.open_file_cost)
- packed_files: List[List[List[DataFileMeta]]] = _pack_for_ordered(sections, weight_func, self.target_split_size)
+ packed_files: List[List[List[DataFileMeta]]] = self._pack_for_ordered(sections, weight_func,
+ self.target_split_size)
flatten_packed_files: List[List[DataFileMeta]] = [
[file for sub_pack in pack for file in sub_pack]
for pack in packed_files
@@ -295,23 +240,23 @@ class TableScan:
splits.append(split)
return splits
+ @staticmethod
+ def _pack_for_ordered(items: List, weight_func: Callable, target_weight: int) -> List[List]:
+ packed = []
+ bin_items = []
+ bin_weight = 0
-def _pack_for_ordered(items: List, weight_func: Callable, target_weight: int) -> List[List]:
- packed = []
- bin_items = []
- bin_weight = 0
+ for item in items:
+ weight = weight_func(item)
+ if bin_weight + weight > target_weight and len(bin_items) > 0:
+ packed.append(bin_items)
+ bin_items.clear()
+ bin_weight = 0
- for item in items:
- weight = weight_func(item)
- if bin_weight + weight > target_weight and len(bin_items) > 0:
- packed.append(bin_items)
- bin_items.clear()
- bin_weight = 0
-
- bin_weight += weight
- bin_items.append(item)
+ bin_weight += weight
+ bin_items.append(item)
- if len(bin_items) > 0:
- packed.append(bin_items)
+ if len(bin_items) > 0:
+ packed.append(bin_items)
- return packed
+ return packed
diff --git a/paimon-python/pypaimon/tests/predicate_push_down_test.py b/paimon-python/pypaimon/tests/predicate_push_down_test.py
new file mode 100644
index 0000000000..b5b403f674
--- /dev/null
+++ b/paimon-python/pypaimon/tests/predicate_push_down_test.py
@@ -0,0 +1,151 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+import os
+import shutil
+import tempfile
+import unittest
+
+import pyarrow as pa
+
+from pypaimon.catalog.catalog_factory import CatalogFactory
+from pypaimon.common.predicate_builder import PredicateBuilder
+from pypaimon.schema.schema import Schema
+
+
+class PredicatePushDownTest(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ cls.tempdir = tempfile.mkdtemp()
+ cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
+ cls.catalog = CatalogFactory.create({
+ 'warehouse': cls.warehouse
+ })
+ cls.catalog.create_database('default', False)
+
+ cls.pa_schema = pa.schema([
+ pa.field('key1', pa.int32(), nullable=False),
+ pa.field('key2', pa.string(), nullable=False),
+ ('behavior', pa.string()),
+ pa.field('dt1', pa.string(), nullable=False),
+ pa.field('dt2', pa.int32(), nullable=False)
+ ])
+ cls.expected = pa.Table.from_pydict({
+ 'key1': [1, 2, 3, 4, 5, 7, 8],
+ 'key2': ['h', 'g', 'f', 'e', 'd', 'b', 'a'],
+ 'behavior': ['a', 'b-new', 'c', None, 'e', 'g', 'h'],
+ 'dt1': ['p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2'],
+ 'dt2': [2, 2, 1, 2, 2, 1, 2],
+ }, schema=cls.pa_schema)
+
+ @classmethod
+ def tearDownClass(cls):
+ shutil.rmtree(cls.tempdir, ignore_errors=True)
+
+ def testPkReaderWithFilter(self):
+ schema = Schema.from_pyarrow_schema(self.pa_schema,
+ partition_keys=['dt1', 'dt2'],
+ primary_keys=['key1', 'key2'],
+ options={'bucket': '1'})
+ self.catalog.create_table('default.test_pk_filter', schema, False)
+ table = self.catalog.get_table('default.test_pk_filter')
+
+ write_builder = table.new_batch_write_builder()
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+ data1 = {
+ 'key1': [1, 2, 3, 4],
+ 'key2': ['h', 'g', 'f', 'e'],
+ 'behavior': ['a', 'b', 'c', None],
+ 'dt1': ['p1', 'p1', 'p2', 'p1'],
+ 'dt2': [2, 2, 1, 2],
+ }
+ pa_table = pa.Table.from_pydict(data1, schema=self.pa_schema)
+ table_write.write_arrow(pa_table)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+ data1 = {
+ 'key1': [5, 2, 7, 8],
+ 'key2': ['d', 'g', 'b', 'a'],
+ 'behavior': ['e', 'b-new', 'g', 'h'],
+ 'dt1': ['p2', 'p1', 'p1', 'p2'],
+ 'dt2': [2, 2, 1, 2]
+ }
+ pa_table = pa.Table.from_pydict(data1, schema=self.pa_schema)
+ table_write.write_arrow(pa_table)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ # test filter by partition
+ predicate_builder: PredicateBuilder = table.new_read_builder().new_predicate_builder()
+ p1 = predicate_builder.startswith('dt1', "p1")
+ p2 = predicate_builder.is_in('dt1', ["p2"])
+ p3 = predicate_builder.or_predicates([p1, p2])
+ p4 = predicate_builder.equal('dt2', 2)
+ g1 = predicate_builder.and_predicates([p3, p4])
+ # (dt1 startswith 'p1' or dt1 is_in ["p2"]) and dt2 == 2
+ read_builder = table.new_read_builder().with_filter(g1)
+ splits = read_builder.new_scan().plan().splits()
+ self.assertEqual(len(splits), 2)
+ self.assertEqual(splits[0].partition.to_dict()["dt2"], 2)
+ self.assertEqual(splits[1].partition.to_dict()["dt2"], 2)
+
+ # test filter by stats
+ predicate_builder: PredicateBuilder = table.new_read_builder().new_predicate_builder()
+ p1 = predicate_builder.equal('key1', 7)
+ p2 = predicate_builder.is_in('key2', ["e", "f"])
+ p3 = predicate_builder.or_predicates([p1, p2])
+ p4 = predicate_builder.greater_than('key1', 3)
+ g1 = predicate_builder.and_predicates([p3, p4])
+ # (key1 == 7 or key2 is_in ["e", "f"]) and key1 > 3
+ read_builder = table.new_read_builder().with_filter(g1)
+ splits = read_builder.new_scan().plan().splits()
+ # initial splits meta:
+ # p1, 2 -> 2g, 2g; 1e, 4h
+ # p2, 1 -> 3f, 3f
+ # p2, 2 -> 5a, 8d
+ # p1, 1 -> 7b, 7b
+ self.assertEqual(len(splits), 3)
+ # expect to filter out `p1, 2 -> 2g, 2g` and `p2, 1 -> 3f, 3f`
+ count = 0
+ for split in splits:
+ if split.partition.values == ["p1", 2]:
+ count += 1
+ self.assertEqual(len(split.files), 1)
+ min_values = split.files[0].value_stats.min_values.to_dict()
+ max_values = split.files[0].value_stats.max_values.to_dict()
+ self.assertTrue(min_values["key1"] == 1 and min_values["key2"]
== "e"
+ and max_values["key1"] == 4 and
max_values["key2"] == "h")
+ elif split.partition.values == ["p2", 2]:
+ count += 1
+ min_values = split.files[0].value_stats.min_values.to_dict()
+ max_values = split.files[0].value_stats.max_values.to_dict()
+ self.assertTrue(min_values["key1"] == 5 and min_values["key2"]
== "a"
+ and max_values["key1"] == 8 and
max_values["key2"] == "d")
+ elif split.partition.values == ["p1", 1]:
+ count += 1
+ min_values = split.files[0].value_stats.min_values.to_dict()
+ max_values = split.files[0].value_stats.max_values.to_dict()
+ self.assertTrue(min_values["key1"] == max_values["key1"] == 7
+ and max_values["key2"] == max_values["key2"]
== "b")
+ self.assertEqual(count, 3)
diff --git a/paimon-python/pypaimon/write/file_store_commit.py b/paimon-python/pypaimon/write/file_store_commit.py
index e7bf7ba534..efe8207ebe 100644
--- a/paimon-python/pypaimon/write/file_store_commit.py
+++ b/paimon-python/pypaimon/write/file_store_commit.py
@@ -81,15 +81,15 @@ class FileStoreCommit:
num_added_files=sum(len(msg.new_files) for msg in commit_messages),
num_deleted_files=0,
partition_stats=SimpleStats(
- min_value=BinaryRow(
+ min_values=BinaryRow(
values=partition_min_stats,
fields=self.table.table_schema.get_partition_key_fields(),
),
- max_value=BinaryRow(
+ max_values=BinaryRow(
values=partition_max_stats,
fields=self.table.table_schema.get_partition_key_fields(),
),
- null_count=partition_null_counts,
+ null_counts=partition_null_counts,
),
schema_id=self.table.table_schema.id,
)
diff --git a/paimon-python/pypaimon/write/writer/data_writer.py b/paimon-python/pypaimon/write/writer/data_writer.py
index 5d9641718c..bc4553847c 100644
--- a/paimon-python/pypaimon/write/writer/data_writer.py
+++ b/paimon-python/pypaimon/write/writer/data_writer.py
@@ -19,7 +19,7 @@ import uuid
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
-from typing import List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
import pyarrow as pa
import pyarrow.compute as pc
@@ -128,13 +128,13 @@ class DataWriter(ABC):
for field in self.table.table_schema.fields
}
all_fields = self.table.table_schema.fields
- min_value_stats = [column_stats[field.name]['min_value'] for field in all_fields]
- max_value_stats = [column_stats[field.name]['max_value'] for field in all_fields]
- value_null_counts = [column_stats[field.name]['null_count'] for field in all_fields]
+ min_value_stats = [column_stats[field.name]['min_values'] for field in all_fields]
+ max_value_stats = [column_stats[field.name]['max_values'] for field in all_fields]
+ value_null_counts = [column_stats[field.name]['null_counts'] for field in all_fields]
key_fields = self.trimmed_primary_key_fields
- min_key_stats = [column_stats[field.name]['min_value'] for field in key_fields]
- max_key_stats = [column_stats[field.name]['max_value'] for field in key_fields]
- key_null_counts = [column_stats[field.name]['null_count'] for field in key_fields]
+ min_key_stats = [column_stats[field.name]['min_values'] for field in key_fields]
+ max_key_stats = [column_stats[field.name]['max_values'] for field in key_fields]
+ key_null_counts = [column_stats[field.name]['null_counts'] for field in key_fields]
if not all(count == 0 for count in key_null_counts):
raise RuntimeError("Primary key should not be null")
@@ -203,21 +203,21 @@ class DataWriter(ABC):
return best_split
@staticmethod
- def _get_column_stats(record_batch: pa.RecordBatch, column_name: str) -> dict:
+ def _get_column_stats(record_batch: pa.RecordBatch, column_name: str) -> Dict:
column_array = record_batch.column(column_name)
if column_array.null_count == len(column_array):
return {
- "min_value": None,
- "max_value": None,
- "null_count": column_array.null_count,
+ "min_values": None,
+ "max_values": None,
+ "null_counts": column_array.null_count,
}
- min_value = pc.min(column_array).as_py()
- max_value = pc.max(column_array).as_py()
- null_count = column_array.null_count
+ min_values = pc.min(column_array).as_py()
+ max_values = pc.max(column_array).as_py()
+ null_counts = column_array.null_count
return {
- "min_value": min_value,
- "max_value": max_value,
- "null_count": null_count,
+ "min_values": min_values,
+ "max_values": max_values,
+ "null_counts": null_counts,
}