This is an automated email from the ASF dual-hosted git repository.
fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
The following commit(s) were added to refs/heads/main by this push:
new b34d8dde Fix `Table.scan` to enable case sensitive argument (#1423)
b34d8dde is described below
commit b34d8dde5ca53b9dd9a823457dadd3b9e76abceb
Author: Jiakai Li <[email protected]>
AuthorDate: Mon Dec 16 21:34:01 2024 +1300
Fix `Table.scan` to enable case sensitive argument (#1423)
* fix-table-scan-enable-case-sensitivity
* Updates included:
- Add more readable integration test for case-sensitive and
case-insensitive `Table.scan`
- Remove less readable test
- Enable `case_sensitive` for delete and overwrite
* Remove less readable test
* Add integration test `Table.delete` and `Table.overwrite`
* Fix typo
* Add test cases for default `Table.delete` case-sensitivity
* Update `case_sensitive` argument position
---
pyiceberg/table/__init__.py | 42 ++++++++----
pyiceberg/table/update/snapshot.py | 15 +++--
tests/integration/test_deletes.py | 134 ++++++++++++++++++++++++++++++++++++-
tests/integration/test_reads.py | 44 ++++++++++++
4 files changed, 214 insertions(+), 21 deletions(-)
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 3eb74eee..766ffba6 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -268,12 +268,10 @@ class Transaction:
return self
- def _scan(self, row_filter: Union[str, BooleanExpression] = ALWAYS_TRUE)
-> DataScan:
+ def _scan(self, row_filter: Union[str, BooleanExpression] = ALWAYS_TRUE,
case_sensitive: bool = True) -> DataScan:
"""Minimal data scan of the table with the current state of the
transaction."""
return DataScan(
- table_metadata=self.table_metadata,
- io=self._table.io,
- row_filter=row_filter,
+ table_metadata=self.table_metadata, io=self._table.io,
row_filter=row_filter, case_sensitive=case_sensitive
)
def upgrade_table_version(self, format_version: TableVersion) ->
Transaction:
@@ -422,6 +420,7 @@ class Transaction:
df: pa.Table,
overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
snapshot_properties: Dict[str, str] = EMPTY_DICT,
+ case_sensitive: bool = True,
) -> None:
"""
Shorthand for adding a table overwrite with a PyArrow table to the
transaction.
@@ -436,6 +435,7 @@ class Transaction:
df: The Arrow dataframe that will be used to overwrite the table
overwrite_filter: ALWAYS_TRUE when you overwrite all the data,
or a boolean expression in case of a partial
overwrite
+ case_sensitive: A bool determine if the provided
`overwrite_filter` is case-sensitive
snapshot_properties: Custom properties to be added to the snapshot
summary
"""
try:
@@ -459,7 +459,7 @@ class Transaction:
self.table_metadata.schema(), provided_schema=df.schema,
downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
)
- self.delete(delete_filter=overwrite_filter,
snapshot_properties=snapshot_properties)
+ self.delete(delete_filter=overwrite_filter,
case_sensitive=case_sensitive, snapshot_properties=snapshot_properties)
with
self.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as
update_snapshot:
# skip writing data files if the dataframe is empty
@@ -470,11 +470,16 @@ class Transaction:
for data_file in data_files:
update_snapshot.append_data_file(data_file)
- def delete(self, delete_filter: Union[str, BooleanExpression],
snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
+ def delete(
+ self,
+ delete_filter: Union[str, BooleanExpression],
+ snapshot_properties: Dict[str, str] = EMPTY_DICT,
+ case_sensitive: bool = True,
+ ) -> None:
"""
Shorthand for deleting record from a table.
- An deletee may produce zero or more snapshots based on the operation:
+ A delete may produce zero or more snapshots based on the operation:
- DELETE: In case existing Parquet files can be dropped completely.
- REPLACE: In case existing Parquet files need to be rewritten
@@ -482,6 +487,7 @@ class Transaction:
Args:
delete_filter: A boolean expression to delete rows from a table
snapshot_properties: Custom properties to be added to the snapshot
summary
+ case_sensitive: A bool determine if the provided `delete_filter`
is case-sensitive
"""
from pyiceberg.io.pyarrow import (
ArrowScan,
@@ -499,14 +505,14 @@ class Transaction:
delete_filter = _parse_row_filter(delete_filter)
with
self.update_snapshot(snapshot_properties=snapshot_properties).delete() as
delete_snapshot:
- delete_snapshot.delete_by_predicate(delete_filter)
+ delete_snapshot.delete_by_predicate(delete_filter, case_sensitive)
# Check if there are any files that require an actual rewrite of a
data file
if delete_snapshot.rewrites_needed is True:
- bound_delete_filter = bind(self.table_metadata.schema(),
delete_filter, case_sensitive=True)
+ bound_delete_filter = bind(self.table_metadata.schema(),
delete_filter, case_sensitive)
preserve_row_filter =
_expression_to_complementary_pyarrow(bound_delete_filter)
- files = self._scan(row_filter=delete_filter).plan_files()
+ files = self._scan(row_filter=delete_filter,
case_sensitive=case_sensitive).plan_files()
commit_uuid = uuid.uuid4()
counter = itertools.count(0)
@@ -988,6 +994,7 @@ class Table:
df: pa.Table,
overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
snapshot_properties: Dict[str, str] = EMPTY_DICT,
+ case_sensitive: bool = True,
) -> None:
"""
Shorthand for overwriting the table with a PyArrow table.
@@ -1003,12 +1010,18 @@ class Table:
overwrite_filter: ALWAYS_TRUE when you overwrite all the data,
or a boolean expression in case of a partial
overwrite
snapshot_properties: Custom properties to be added to the snapshot
summary
+ case_sensitive: A bool determine if the provided
`overwrite_filter` is case-sensitive
"""
with self.transaction() as tx:
- tx.overwrite(df=df, overwrite_filter=overwrite_filter,
snapshot_properties=snapshot_properties)
+ tx.overwrite(
+ df=df, overwrite_filter=overwrite_filter,
case_sensitive=case_sensitive, snapshot_properties=snapshot_properties
+ )
def delete(
- self, delete_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
snapshot_properties: Dict[str, str] = EMPTY_DICT
+ self,
+ delete_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
+ snapshot_properties: Dict[str, str] = EMPTY_DICT,
+ case_sensitive: bool = True,
) -> None:
"""
Shorthand for deleting rows from the table.
@@ -1016,9 +1029,10 @@ class Table:
Args:
delete_filter: The predicate that used to remove rows
snapshot_properties: Custom properties to be added to the snapshot
summary
+ case_sensitive: A bool determine if the provided `delete_filter`
is case-sensitive
"""
with self.transaction() as tx:
- tx.delete(delete_filter=delete_filter,
snapshot_properties=snapshot_properties)
+ tx.delete(delete_filter=delete_filter,
case_sensitive=case_sensitive, snapshot_properties=snapshot_properties)
def add_files(
self, file_paths: List[str], snapshot_properties: Dict[str, str] =
EMPTY_DICT, check_duplicate_files: bool = True
@@ -1311,7 +1325,7 @@ def _match_deletes_to_data_file(data_entry:
ManifestEntry, positional_delete_ent
class DataScan(TableScan):
def _build_partition_projection(self, spec_id: int) -> BooleanExpression:
- project = inclusive_projection(self.table_metadata.schema(),
self.table_metadata.specs()[spec_id])
+ project = inclusive_projection(self.table_metadata.schema(),
self.table_metadata.specs()[spec_id], self.case_sensitive)
return project(self.row_filter)
@cached_property
diff --git a/pyiceberg/table/update/snapshot.py
b/pyiceberg/table/update/snapshot.py
index 47e5fc55..c0d0056e 100644
--- a/pyiceberg/table/update/snapshot.py
+++ b/pyiceberg/table/update/snapshot.py
@@ -318,6 +318,7 @@ class _DeleteFiles(_SnapshotProducer["_DeleteFiles"]):
"""
_predicate: BooleanExpression
+ _case_sensitive: bool
def __init__(
self,
@@ -329,6 +330,7 @@ class _DeleteFiles(_SnapshotProducer["_DeleteFiles"]):
):
super().__init__(operation, transaction, io, commit_uuid,
snapshot_properties)
self._predicate = AlwaysFalse()
+ self._case_sensitive = True
def _commit(self) -> UpdatesAndRequirements:
# Only produce a commit when there is something to delete
@@ -340,7 +342,7 @@ class _DeleteFiles(_SnapshotProducer["_DeleteFiles"]):
def _build_partition_projection(self, spec_id: int) -> BooleanExpression:
schema = self._transaction.table_metadata.schema()
spec = self._transaction.table_metadata.specs()[spec_id]
- project = inclusive_projection(schema, spec)
+ project = inclusive_projection(schema, spec, self._case_sensitive)
return project(self._predicate)
@cached_property
@@ -350,10 +352,11 @@ class _DeleteFiles(_SnapshotProducer["_DeleteFiles"]):
def _build_manifest_evaluator(self, spec_id: int) ->
Callable[[ManifestFile], bool]:
schema = self._transaction.table_metadata.schema()
spec = self._transaction.table_metadata.specs()[spec_id]
- return manifest_evaluator(spec, schema,
self.partition_filters[spec_id], case_sensitive=True)
+ return manifest_evaluator(spec, schema,
self.partition_filters[spec_id], self._case_sensitive)
- def delete_by_predicate(self, predicate: BooleanExpression) -> None:
+ def delete_by_predicate(self, predicate: BooleanExpression,
case_sensitive: bool = True) -> None:
self._predicate = Or(self._predicate, predicate)
+ self._case_sensitive = case_sensitive
@cached_property
def _compute_deletes(self) -> Tuple[List[ManifestFile],
List[ManifestEntry], bool]:
@@ -376,8 +379,10 @@ class _DeleteFiles(_SnapshotProducer["_DeleteFiles"]):
)
manifest_evaluators: Dict[int, Callable[[ManifestFile], bool]] =
KeyDefaultDict(self._build_manifest_evaluator)
- strict_metrics_evaluator = _StrictMetricsEvaluator(schema,
self._predicate, case_sensitive=True).eval
- inclusive_metrics_evaluator = _InclusiveMetricsEvaluator(schema,
self._predicate, case_sensitive=True).eval
+ strict_metrics_evaluator = _StrictMetricsEvaluator(schema,
self._predicate, case_sensitive=self._case_sensitive).eval
+ inclusive_metrics_evaluator = _InclusiveMetricsEvaluator(
+ schema, self._predicate, case_sensitive=self._case_sensitive
+ ).eval
existing_manifests = []
total_deleted_entries = []
diff --git a/tests/integration/test_deletes.py
b/tests/integration/test_deletes.py
index 2cdf9916..affc480f 100644
--- a/tests/integration/test_deletes.py
+++ b/tests/integration/test_deletes.py
@@ -16,7 +16,7 @@
# under the License.
# pylint:disable=redefined-outer-name
from datetime import datetime
-from typing import List
+from typing import Generator, List
import pyarrow as pa
import pytest
@@ -28,9 +28,10 @@ from pyiceberg.expressions import AlwaysTrue, EqualTo
from pyiceberg.manifest import ManifestEntryStatus
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
+from pyiceberg.table import Table
from pyiceberg.table.snapshots import Operation, Summary
from pyiceberg.transforms import IdentityTransform
-from pyiceberg.types import FloatType, IntegerType, LongType, NestedField,
TimestampType
+from pyiceberg.types import FloatType, IntegerType, LongType, NestedField,
StringType, TimestampType
def run_spark_commands(spark: SparkSession, sqls: List[str]) -> None:
@@ -38,6 +39,24 @@ def run_spark_commands(spark: SparkSession, sqls: List[str])
-> None:
spark.sql(sql)
[email protected]()
+def test_table(session_catalog: RestCatalog) -> Generator[Table, None, None]:
+ identifier = "default.__test_table"
+ arrow_table = pa.Table.from_arrays([pa.array([1, 2, 3, 4, 5]),
pa.array(["a", "b", "c", "d", "e"])], names=["idx", "value"])
+ test_table = session_catalog.create_table(
+ identifier,
+ schema=Schema(
+ NestedField(1, "idx", LongType()),
+ NestedField(2, "value", StringType()),
+ ),
+ )
+ test_table.append(arrow_table)
+
+ yield test_table
+
+ session_catalog.drop_table(identifier)
+
+
@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_partitioned_table_delete_full_file(spark: SparkSession,
session_catalog: RestCatalog, format_version: int) -> None:
@@ -770,3 +789,114 @@ def
test_delete_after_partition_evolution_from_partitioned(session_catalog: Rest
# Expect 8 records: 10 records - 2
assert len(tbl.scan().to_arrow()) == 8
+
+
[email protected]
+def test_delete_with_filter_case_sensitive_by_default(test_table: Table) ->
None:
+ record_to_delete = {"idx": 2, "value": "b"}
+ assert record_to_delete in test_table.scan().to_arrow().to_pylist()
+
+ with pytest.raises(ValueError) as e:
+ test_table.delete(f"Idx == {record_to_delete['idx']}")
+ assert "Could not find field with name Idx" in str(e.value)
+ assert record_to_delete in test_table.scan().to_arrow().to_pylist()
+
+ test_table.delete(f"idx == {record_to_delete['idx']}")
+ assert record_to_delete not in test_table.scan().to_arrow().to_pylist()
+
+
[email protected]
+def test_delete_with_filter_case_sensitive(test_table: Table) -> None:
+ record_to_delete = {"idx": 2, "value": "b"}
+ assert record_to_delete in test_table.scan().to_arrow().to_pylist()
+
+ with pytest.raises(ValueError) as e:
+ test_table.delete(f"Idx == {record_to_delete['idx']}",
case_sensitive=True)
+ assert "Could not find field with name Idx" in str(e.value)
+ assert record_to_delete in test_table.scan().to_arrow().to_pylist()
+
+ test_table.delete(f"idx == {record_to_delete['idx']}", case_sensitive=True)
+ assert record_to_delete not in test_table.scan().to_arrow().to_pylist()
+
+
[email protected]
+def test_delete_with_filter_case_insensitive(test_table: Table) -> None:
+ record_to_delete_1 = {"idx": 2, "value": "b"}
+ record_to_delete_2 = {"idx": 3, "value": "c"}
+ assert record_to_delete_1 in test_table.scan().to_arrow().to_pylist()
+ assert record_to_delete_2 in test_table.scan().to_arrow().to_pylist()
+
+ test_table.delete(f"Idx == {record_to_delete_1['idx']}",
case_sensitive=False)
+ assert record_to_delete_1 not in test_table.scan().to_arrow().to_pylist()
+
+ test_table.delete(f"idx == {record_to_delete_2['idx']}",
case_sensitive=False)
+ assert record_to_delete_2 not in test_table.scan().to_arrow().to_pylist()
+
+
[email protected]
+def test_overwrite_with_filter_case_sensitive_by_default(test_table: Table) ->
None:
+ record_to_overwrite = {"idx": 2, "value": "b"}
+ assert record_to_overwrite in test_table.scan().to_arrow().to_pylist()
+
+ new_record_to_insert = {"idx": 10, "value": "x"}
+ new_table = pa.Table.from_arrays(
+ [
+ pa.array([new_record_to_insert["idx"]]),
+ pa.array([new_record_to_insert["value"]]),
+ ],
+ names=["idx", "value"],
+ )
+
+ with pytest.raises(ValueError) as e:
+ test_table.overwrite(df=new_table, overwrite_filter=f"Idx ==
{record_to_overwrite['idx']}")
+ assert "Could not find field with name Idx" in str(e.value)
+ assert record_to_overwrite in test_table.scan().to_arrow().to_pylist()
+ assert new_record_to_insert not in test_table.scan().to_arrow().to_pylist()
+
+ test_table.overwrite(df=new_table, overwrite_filter=f"idx ==
{record_to_overwrite['idx']}")
+ assert record_to_overwrite not in test_table.scan().to_arrow().to_pylist()
+ assert new_record_to_insert in test_table.scan().to_arrow().to_pylist()
+
+
[email protected]
+def test_overwrite_with_filter_case_sensitive(test_table: Table) -> None:
+ record_to_overwrite = {"idx": 2, "value": "b"}
+ assert record_to_overwrite in test_table.scan().to_arrow().to_pylist()
+
+ new_record_to_insert = {"idx": 10, "value": "x"}
+ new_table = pa.Table.from_arrays(
+ [
+ pa.array([new_record_to_insert["idx"]]),
+ pa.array([new_record_to_insert["value"]]),
+ ],
+ names=["idx", "value"],
+ )
+
+ with pytest.raises(ValueError) as e:
+ test_table.overwrite(df=new_table, overwrite_filter=f"Idx ==
{record_to_overwrite['idx']}", case_sensitive=True)
+ assert "Could not find field with name Idx" in str(e.value)
+ assert record_to_overwrite in test_table.scan().to_arrow().to_pylist()
+ assert new_record_to_insert not in test_table.scan().to_arrow().to_pylist()
+
+ test_table.overwrite(df=new_table, overwrite_filter=f"idx ==
{record_to_overwrite['idx']}", case_sensitive=True)
+ assert record_to_overwrite not in test_table.scan().to_arrow().to_pylist()
+ assert new_record_to_insert in test_table.scan().to_arrow().to_pylist()
+
+
[email protected]
+def test_overwrite_with_filter_case_insensitive(test_table: Table) -> None:
+ record_to_overwrite = {"idx": 2, "value": "b"}
+ assert record_to_overwrite in test_table.scan().to_arrow().to_pylist()
+
+ new_record_to_insert = {"idx": 10, "value": "x"}
+ new_table = pa.Table.from_arrays(
+ [
+ pa.array([new_record_to_insert["idx"]]),
+ pa.array([new_record_to_insert["value"]]),
+ ],
+ names=["idx", "value"],
+ )
+
+ test_table.overwrite(df=new_table, overwrite_filter=f"Idx ==
{record_to_overwrite['idx']}", case_sensitive=False)
+ assert record_to_overwrite not in test_table.scan().to_arrow().to_pylist()
+ assert new_record_to_insert in test_table.scan().to_arrow().to_pylist()
diff --git a/tests/integration/test_reads.py b/tests/integration/test_reads.py
index f8bc57bb..0279c219 100644
--- a/tests/integration/test_reads.py
+++ b/tests/integration/test_reads.py
@@ -621,6 +621,50 @@ def test_filter_on_new_column(catalog: Catalog) -> None:
assert arrow_table["b"].to_pylist() == [None]
[email protected]
[email protected]("catalog",
[pytest.lazy_fixture("session_catalog_hive"),
pytest.lazy_fixture("session_catalog")])
+def test_filter_case_sensitive_by_default(catalog: Catalog) -> None:
+ test_table_add_column = catalog.load_table("default.test_table_add_column")
+ arrow_table = test_table_add_column.scan().to_arrow()
+ assert "2" in arrow_table["b"].to_pylist()
+
+ arrow_table = test_table_add_column.scan(row_filter="b == '2'").to_arrow()
+ assert arrow_table["b"].to_pylist() == ["2"]
+
+ with pytest.raises(ValueError) as e:
+ _ = test_table_add_column.scan(row_filter="B == '2'").to_arrow()
+ assert "Could not find field with name B" in str(e.value)
+
+
[email protected]
[email protected]("catalog",
[pytest.lazy_fixture("session_catalog_hive"),
pytest.lazy_fixture("session_catalog")])
+def test_filter_case_sensitive(catalog: Catalog) -> None:
+ test_table_add_column = catalog.load_table("default.test_table_add_column")
+ arrow_table = test_table_add_column.scan().to_arrow()
+ assert "2" in arrow_table["b"].to_pylist()
+
+ arrow_table = test_table_add_column.scan(row_filter="b == '2'",
case_sensitive=True).to_arrow()
+ assert arrow_table["b"].to_pylist() == ["2"]
+
+ with pytest.raises(ValueError) as e:
+ _ = test_table_add_column.scan(row_filter="B == '2'",
case_sensitive=True).to_arrow()
+ assert "Could not find field with name B" in str(e.value)
+
+
[email protected]
[email protected]("catalog",
[pytest.lazy_fixture("session_catalog_hive"),
pytest.lazy_fixture("session_catalog")])
+def test_filter_case_insensitive(catalog: Catalog) -> None:
+ test_table_add_column = catalog.load_table("default.test_table_add_column")
+ arrow_table = test_table_add_column.scan().to_arrow()
+ assert "2" in arrow_table["b"].to_pylist()
+
+ arrow_table = test_table_add_column.scan(row_filter="b == '2'",
case_sensitive=False).to_arrow()
+ assert arrow_table["b"].to_pylist() == ["2"]
+
+ arrow_table = test_table_add_column.scan(row_filter="B == '2'",
case_sensitive=False).to_arrow()
+ assert arrow_table["b"].to_pylist() == ["2"]
+
+
@pytest.mark.integration
@pytest.mark.parametrize("catalog",
[pytest.lazy_fixture("session_catalog_hive"),
pytest.lazy_fixture("session_catalog")])
def test_upgrade_table_version(catalog: Catalog) -> None: