This is an automated email from the ASF dual-hosted git repository.

fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new b34d8dde Fix `Table.scan` to enable case sensitive argument (#1423)
b34d8dde is described below

commit b34d8dde5ca53b9dd9a823457dadd3b9e76abceb
Author: Jiakai Li <[email protected]>
AuthorDate: Mon Dec 16 21:34:01 2024 +1300

    Fix `Table.scan` to enable case sensitive argument (#1423)
    
    * fix-table-scan-enable-case-sensitivity
    
    * Updates included:
    - Add more readable integration test for case-sensitive and 
case-insensitive `Table.scan`
    - Remove less readable test
    - Enable `case_sensitive` delete and overwrite
    
    * Remove less readable test
    
    * Add integration test `Table.delete` and `Table.overwrite`
    
    * Fix typo
    
    * Add test cases for default `Table.delete` case-sensitivity
    
    * Update `case_sensitive` argument position
---
 pyiceberg/table/__init__.py        |  42 ++++++++----
 pyiceberg/table/update/snapshot.py |  15 +++--
 tests/integration/test_deletes.py  | 134 ++++++++++++++++++++++++++++++++++++-
 tests/integration/test_reads.py    |  44 ++++++++++++
 4 files changed, 214 insertions(+), 21 deletions(-)

diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 3eb74eee..766ffba6 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -268,12 +268,10 @@ class Transaction:
 
         return self
 
-    def _scan(self, row_filter: Union[str, BooleanExpression] = ALWAYS_TRUE) 
-> DataScan:
+    def _scan(self, row_filter: Union[str, BooleanExpression] = ALWAYS_TRUE, 
case_sensitive: bool = True) -> DataScan:
         """Minimal data scan of the table with the current state of the 
transaction."""
         return DataScan(
-            table_metadata=self.table_metadata,
-            io=self._table.io,
-            row_filter=row_filter,
+            table_metadata=self.table_metadata, io=self._table.io, 
row_filter=row_filter, case_sensitive=case_sensitive
         )
 
     def upgrade_table_version(self, format_version: TableVersion) -> 
Transaction:
@@ -422,6 +420,7 @@ class Transaction:
         df: pa.Table,
         overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
+        case_sensitive: bool = True,
     ) -> None:
         """
         Shorthand for adding a table overwrite with a PyArrow table to the 
transaction.
@@ -436,6 +435,7 @@ class Transaction:
             df: The Arrow dataframe that will be used to overwrite the table
             overwrite_filter: ALWAYS_TRUE when you overwrite all the data,
                               or a boolean expression in case of a partial 
overwrite
+            case_sensitive: A bool determining if the provided 
`overwrite_filter` is case-sensitive
             snapshot_properties: Custom properties to be added to the snapshot 
summary
         """
         try:
@@ -459,7 +459,7 @@ class Transaction:
             self.table_metadata.schema(), provided_schema=df.schema, 
downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
         )
 
-        self.delete(delete_filter=overwrite_filter, 
snapshot_properties=snapshot_properties)
+        self.delete(delete_filter=overwrite_filter, 
case_sensitive=case_sensitive, snapshot_properties=snapshot_properties)
 
         with 
self.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as 
update_snapshot:
             # skip writing data files if the dataframe is empty
@@ -470,11 +470,16 @@ class Transaction:
                 for data_file in data_files:
                     update_snapshot.append_data_file(data_file)
 
-    def delete(self, delete_filter: Union[str, BooleanExpression], 
snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
+    def delete(
+        self,
+        delete_filter: Union[str, BooleanExpression],
+        snapshot_properties: Dict[str, str] = EMPTY_DICT,
+        case_sensitive: bool = True,
+    ) -> None:
         """
         Shorthand for deleting record from a table.
 
-        An deletee may produce zero or more snapshots based on the operation:
+        A delete may produce zero or more snapshots based on the operation:
 
             - DELETE: In case existing Parquet files can be dropped completely.
             - REPLACE: In case existing Parquet files need to be rewritten
@@ -482,6 +487,7 @@ class Transaction:
         Args:
             delete_filter: A boolean expression to delete rows from a table
             snapshot_properties: Custom properties to be added to the snapshot 
summary
+            case_sensitive: A bool determining if the provided `delete_filter` 
is case-sensitive
         """
         from pyiceberg.io.pyarrow import (
             ArrowScan,
@@ -499,14 +505,14 @@ class Transaction:
             delete_filter = _parse_row_filter(delete_filter)
 
         with 
self.update_snapshot(snapshot_properties=snapshot_properties).delete() as 
delete_snapshot:
-            delete_snapshot.delete_by_predicate(delete_filter)
+            delete_snapshot.delete_by_predicate(delete_filter, case_sensitive)
 
         # Check if there are any files that require an actual rewrite of a 
data file
         if delete_snapshot.rewrites_needed is True:
-            bound_delete_filter = bind(self.table_metadata.schema(), 
delete_filter, case_sensitive=True)
+            bound_delete_filter = bind(self.table_metadata.schema(), 
delete_filter, case_sensitive)
             preserve_row_filter = 
_expression_to_complementary_pyarrow(bound_delete_filter)
 
-            files = self._scan(row_filter=delete_filter).plan_files()
+            files = self._scan(row_filter=delete_filter, 
case_sensitive=case_sensitive).plan_files()
 
             commit_uuid = uuid.uuid4()
             counter = itertools.count(0)
@@ -988,6 +994,7 @@ class Table:
         df: pa.Table,
         overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
+        case_sensitive: bool = True,
     ) -> None:
         """
         Shorthand for overwriting the table with a PyArrow table.
@@ -1003,12 +1010,18 @@ class Table:
             overwrite_filter: ALWAYS_TRUE when you overwrite all the data,
                               or a boolean expression in case of a partial 
overwrite
             snapshot_properties: Custom properties to be added to the snapshot 
summary
+            case_sensitive: A bool determining if the provided 
`overwrite_filter` is case-sensitive
         """
         with self.transaction() as tx:
-            tx.overwrite(df=df, overwrite_filter=overwrite_filter, 
snapshot_properties=snapshot_properties)
+            tx.overwrite(
+                df=df, overwrite_filter=overwrite_filter, 
case_sensitive=case_sensitive, snapshot_properties=snapshot_properties
+            )
 
     def delete(
-        self, delete_filter: Union[BooleanExpression, str] = ALWAYS_TRUE, 
snapshot_properties: Dict[str, str] = EMPTY_DICT
+        self,
+        delete_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
+        snapshot_properties: Dict[str, str] = EMPTY_DICT,
+        case_sensitive: bool = True,
     ) -> None:
         """
         Shorthand for deleting rows from the table.
@@ -1016,9 +1029,10 @@ class Table:
         Args:
             delete_filter: The predicate that used to remove rows
             snapshot_properties: Custom properties to be added to the snapshot 
summary
+            case_sensitive: A bool determining if the provided `delete_filter` 
is case-sensitive
         """
         with self.transaction() as tx:
-            tx.delete(delete_filter=delete_filter, 
snapshot_properties=snapshot_properties)
+            tx.delete(delete_filter=delete_filter, 
case_sensitive=case_sensitive, snapshot_properties=snapshot_properties)
 
     def add_files(
         self, file_paths: List[str], snapshot_properties: Dict[str, str] = 
EMPTY_DICT, check_duplicate_files: bool = True
@@ -1311,7 +1325,7 @@ def _match_deletes_to_data_file(data_entry: 
ManifestEntry, positional_delete_ent
 
 class DataScan(TableScan):
     def _build_partition_projection(self, spec_id: int) -> BooleanExpression:
-        project = inclusive_projection(self.table_metadata.schema(), 
self.table_metadata.specs()[spec_id])
+        project = inclusive_projection(self.table_metadata.schema(), 
self.table_metadata.specs()[spec_id], self.case_sensitive)
         return project(self.row_filter)
 
     @cached_property
diff --git a/pyiceberg/table/update/snapshot.py 
b/pyiceberg/table/update/snapshot.py
index 47e5fc55..c0d0056e 100644
--- a/pyiceberg/table/update/snapshot.py
+++ b/pyiceberg/table/update/snapshot.py
@@ -318,6 +318,7 @@ class _DeleteFiles(_SnapshotProducer["_DeleteFiles"]):
     """
 
     _predicate: BooleanExpression
+    _case_sensitive: bool
 
     def __init__(
         self,
@@ -329,6 +330,7 @@ class _DeleteFiles(_SnapshotProducer["_DeleteFiles"]):
     ):
         super().__init__(operation, transaction, io, commit_uuid, 
snapshot_properties)
         self._predicate = AlwaysFalse()
+        self._case_sensitive = True
 
     def _commit(self) -> UpdatesAndRequirements:
         # Only produce a commit when there is something to delete
@@ -340,7 +342,7 @@ class _DeleteFiles(_SnapshotProducer["_DeleteFiles"]):
     def _build_partition_projection(self, spec_id: int) -> BooleanExpression:
         schema = self._transaction.table_metadata.schema()
         spec = self._transaction.table_metadata.specs()[spec_id]
-        project = inclusive_projection(schema, spec)
+        project = inclusive_projection(schema, spec, self._case_sensitive)
         return project(self._predicate)
 
     @cached_property
@@ -350,10 +352,11 @@ class _DeleteFiles(_SnapshotProducer["_DeleteFiles"]):
     def _build_manifest_evaluator(self, spec_id: int) -> 
Callable[[ManifestFile], bool]:
         schema = self._transaction.table_metadata.schema()
         spec = self._transaction.table_metadata.specs()[spec_id]
-        return manifest_evaluator(spec, schema, 
self.partition_filters[spec_id], case_sensitive=True)
+        return manifest_evaluator(spec, schema, 
self.partition_filters[spec_id], self._case_sensitive)
 
-    def delete_by_predicate(self, predicate: BooleanExpression) -> None:
+    def delete_by_predicate(self, predicate: BooleanExpression, 
case_sensitive: bool = True) -> None:
         self._predicate = Or(self._predicate, predicate)
+        self._case_sensitive = case_sensitive
 
     @cached_property
     def _compute_deletes(self) -> Tuple[List[ManifestFile], 
List[ManifestEntry], bool]:
@@ -376,8 +379,10 @@ class _DeleteFiles(_SnapshotProducer["_DeleteFiles"]):
             )
 
         manifest_evaluators: Dict[int, Callable[[ManifestFile], bool]] = 
KeyDefaultDict(self._build_manifest_evaluator)
-        strict_metrics_evaluator = _StrictMetricsEvaluator(schema, 
self._predicate, case_sensitive=True).eval
-        inclusive_metrics_evaluator = _InclusiveMetricsEvaluator(schema, 
self._predicate, case_sensitive=True).eval
+        strict_metrics_evaluator = _StrictMetricsEvaluator(schema, 
self._predicate, case_sensitive=self._case_sensitive).eval
+        inclusive_metrics_evaluator = _InclusiveMetricsEvaluator(
+            schema, self._predicate, case_sensitive=self._case_sensitive
+        ).eval
 
         existing_manifests = []
         total_deleted_entries = []
diff --git a/tests/integration/test_deletes.py 
b/tests/integration/test_deletes.py
index 2cdf9916..affc480f 100644
--- a/tests/integration/test_deletes.py
+++ b/tests/integration/test_deletes.py
@@ -16,7 +16,7 @@
 # under the License.
 # pylint:disable=redefined-outer-name
 from datetime import datetime
-from typing import List
+from typing import Generator, List
 
 import pyarrow as pa
 import pytest
@@ -28,9 +28,10 @@ from pyiceberg.expressions import AlwaysTrue, EqualTo
 from pyiceberg.manifest import ManifestEntryStatus
 from pyiceberg.partitioning import PartitionField, PartitionSpec
 from pyiceberg.schema import Schema
+from pyiceberg.table import Table
 from pyiceberg.table.snapshots import Operation, Summary
 from pyiceberg.transforms import IdentityTransform
-from pyiceberg.types import FloatType, IntegerType, LongType, NestedField, 
TimestampType
+from pyiceberg.types import FloatType, IntegerType, LongType, NestedField, 
StringType, TimestampType
 
 
 def run_spark_commands(spark: SparkSession, sqls: List[str]) -> None:
@@ -38,6 +39,24 @@ def run_spark_commands(spark: SparkSession, sqls: List[str]) 
-> None:
         spark.sql(sql)
 
 
[email protected]()
+def test_table(session_catalog: RestCatalog) -> Generator[Table, None, None]:
+    identifier = "default.__test_table"
+    arrow_table = pa.Table.from_arrays([pa.array([1, 2, 3, 4, 5]), 
pa.array(["a", "b", "c", "d", "e"])], names=["idx", "value"])
+    test_table = session_catalog.create_table(
+        identifier,
+        schema=Schema(
+            NestedField(1, "idx", LongType()),
+            NestedField(2, "value", StringType()),
+        ),
+    )
+    test_table.append(arrow_table)
+
+    yield test_table
+
+    session_catalog.drop_table(identifier)
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize("format_version", [1, 2])
 def test_partitioned_table_delete_full_file(spark: SparkSession, 
session_catalog: RestCatalog, format_version: int) -> None:
@@ -770,3 +789,114 @@ def 
test_delete_after_partition_evolution_from_partitioned(session_catalog: Rest
 
     # Expect 8 records: 10 records - 2
     assert len(tbl.scan().to_arrow()) == 8
+
+
[email protected]
+def test_delete_with_filter_case_sensitive_by_default(test_table: Table) -> 
None:
+    record_to_delete = {"idx": 2, "value": "b"}
+    assert record_to_delete in test_table.scan().to_arrow().to_pylist()
+
+    with pytest.raises(ValueError) as e:
+        test_table.delete(f"Idx == {record_to_delete['idx']}")
+    assert "Could not find field with name Idx" in str(e.value)
+    assert record_to_delete in test_table.scan().to_arrow().to_pylist()
+
+    test_table.delete(f"idx == {record_to_delete['idx']}")
+    assert record_to_delete not in test_table.scan().to_arrow().to_pylist()
+
+
[email protected]
+def test_delete_with_filter_case_sensitive(test_table: Table) -> None:
+    record_to_delete = {"idx": 2, "value": "b"}
+    assert record_to_delete in test_table.scan().to_arrow().to_pylist()
+
+    with pytest.raises(ValueError) as e:
+        test_table.delete(f"Idx == {record_to_delete['idx']}", 
case_sensitive=True)
+    assert "Could not find field with name Idx" in str(e.value)
+    assert record_to_delete in test_table.scan().to_arrow().to_pylist()
+
+    test_table.delete(f"idx == {record_to_delete['idx']}", case_sensitive=True)
+    assert record_to_delete not in test_table.scan().to_arrow().to_pylist()
+
+
[email protected]
+def test_delete_with_filter_case_insensitive(test_table: Table) -> None:
+    record_to_delete_1 = {"idx": 2, "value": "b"}
+    record_to_delete_2 = {"idx": 3, "value": "c"}
+    assert record_to_delete_1 in test_table.scan().to_arrow().to_pylist()
+    assert record_to_delete_2 in test_table.scan().to_arrow().to_pylist()
+
+    test_table.delete(f"Idx == {record_to_delete_1['idx']}", 
case_sensitive=False)
+    assert record_to_delete_1 not in test_table.scan().to_arrow().to_pylist()
+
+    test_table.delete(f"idx == {record_to_delete_2['idx']}", 
case_sensitive=False)
+    assert record_to_delete_2 not in test_table.scan().to_arrow().to_pylist()
+
+
[email protected]
+def test_overwrite_with_filter_case_sensitive_by_default(test_table: Table) -> 
None:
+    record_to_overwrite = {"idx": 2, "value": "b"}
+    assert record_to_overwrite in test_table.scan().to_arrow().to_pylist()
+
+    new_record_to_insert = {"idx": 10, "value": "x"}
+    new_table = pa.Table.from_arrays(
+        [
+            pa.array([new_record_to_insert["idx"]]),
+            pa.array([new_record_to_insert["value"]]),
+        ],
+        names=["idx", "value"],
+    )
+
+    with pytest.raises(ValueError) as e:
+        test_table.overwrite(df=new_table, overwrite_filter=f"Idx == 
{record_to_overwrite['idx']}")
+    assert "Could not find field with name Idx" in str(e.value)
+    assert record_to_overwrite in test_table.scan().to_arrow().to_pylist()
+    assert new_record_to_insert not in test_table.scan().to_arrow().to_pylist()
+
+    test_table.overwrite(df=new_table, overwrite_filter=f"idx == 
{record_to_overwrite['idx']}")
+    assert record_to_overwrite not in test_table.scan().to_arrow().to_pylist()
+    assert new_record_to_insert in test_table.scan().to_arrow().to_pylist()
+
+
[email protected]
+def test_overwrite_with_filter_case_sensitive(test_table: Table) -> None:
+    record_to_overwrite = {"idx": 2, "value": "b"}
+    assert record_to_overwrite in test_table.scan().to_arrow().to_pylist()
+
+    new_record_to_insert = {"idx": 10, "value": "x"}
+    new_table = pa.Table.from_arrays(
+        [
+            pa.array([new_record_to_insert["idx"]]),
+            pa.array([new_record_to_insert["value"]]),
+        ],
+        names=["idx", "value"],
+    )
+
+    with pytest.raises(ValueError) as e:
+        test_table.overwrite(df=new_table, overwrite_filter=f"Idx == 
{record_to_overwrite['idx']}", case_sensitive=True)
+    assert "Could not find field with name Idx" in str(e.value)
+    assert record_to_overwrite in test_table.scan().to_arrow().to_pylist()
+    assert new_record_to_insert not in test_table.scan().to_arrow().to_pylist()
+
+    test_table.overwrite(df=new_table, overwrite_filter=f"idx == 
{record_to_overwrite['idx']}", case_sensitive=True)
+    assert record_to_overwrite not in test_table.scan().to_arrow().to_pylist()
+    assert new_record_to_insert in test_table.scan().to_arrow().to_pylist()
+
+
[email protected]
+def test_overwrite_with_filter_case_insensitive(test_table: Table) -> None:
+    record_to_overwrite = {"idx": 2, "value": "b"}
+    assert record_to_overwrite in test_table.scan().to_arrow().to_pylist()
+
+    new_record_to_insert = {"idx": 10, "value": "x"}
+    new_table = pa.Table.from_arrays(
+        [
+            pa.array([new_record_to_insert["idx"]]),
+            pa.array([new_record_to_insert["value"]]),
+        ],
+        names=["idx", "value"],
+    )
+
+    test_table.overwrite(df=new_table, overwrite_filter=f"Idx == 
{record_to_overwrite['idx']}", case_sensitive=False)
+    assert record_to_overwrite not in test_table.scan().to_arrow().to_pylist()
+    assert new_record_to_insert in test_table.scan().to_arrow().to_pylist()
diff --git a/tests/integration/test_reads.py b/tests/integration/test_reads.py
index f8bc57bb..0279c219 100644
--- a/tests/integration/test_reads.py
+++ b/tests/integration/test_reads.py
@@ -621,6 +621,50 @@ def test_filter_on_new_column(catalog: Catalog) -> None:
     assert arrow_table["b"].to_pylist() == [None]
 
 
[email protected]
[email protected]("catalog", 
[pytest.lazy_fixture("session_catalog_hive"), 
pytest.lazy_fixture("session_catalog")])
+def test_filter_case_sensitive_by_default(catalog: Catalog) -> None:
+    test_table_add_column = catalog.load_table("default.test_table_add_column")
+    arrow_table = test_table_add_column.scan().to_arrow()
+    assert "2" in arrow_table["b"].to_pylist()
+
+    arrow_table = test_table_add_column.scan(row_filter="b == '2'").to_arrow()
+    assert arrow_table["b"].to_pylist() == ["2"]
+
+    with pytest.raises(ValueError) as e:
+        _ = test_table_add_column.scan(row_filter="B == '2'").to_arrow()
+    assert "Could not find field with name B" in str(e.value)
+
+
[email protected]
[email protected]("catalog", 
[pytest.lazy_fixture("session_catalog_hive"), 
pytest.lazy_fixture("session_catalog")])
+def test_filter_case_sensitive(catalog: Catalog) -> None:
+    test_table_add_column = catalog.load_table("default.test_table_add_column")
+    arrow_table = test_table_add_column.scan().to_arrow()
+    assert "2" in arrow_table["b"].to_pylist()
+
+    arrow_table = test_table_add_column.scan(row_filter="b == '2'", 
case_sensitive=True).to_arrow()
+    assert arrow_table["b"].to_pylist() == ["2"]
+
+    with pytest.raises(ValueError) as e:
+        _ = test_table_add_column.scan(row_filter="B == '2'", 
case_sensitive=True).to_arrow()
+    assert "Could not find field with name B" in str(e.value)
+
+
[email protected]
[email protected]("catalog", 
[pytest.lazy_fixture("session_catalog_hive"), 
pytest.lazy_fixture("session_catalog")])
+def test_filter_case_insensitive(catalog: Catalog) -> None:
+    test_table_add_column = catalog.load_table("default.test_table_add_column")
+    arrow_table = test_table_add_column.scan().to_arrow()
+    assert "2" in arrow_table["b"].to_pylist()
+
+    arrow_table = test_table_add_column.scan(row_filter="b == '2'", 
case_sensitive=False).to_arrow()
+    assert arrow_table["b"].to_pylist() == ["2"]
+
+    arrow_table = test_table_add_column.scan(row_filter="B == '2'", 
case_sensitive=False).to_arrow()
+    assert arrow_table["b"].to_pylist() == ["2"]
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize("catalog", 
[pytest.lazy_fixture("session_catalog_hive"), 
pytest.lazy_fixture("session_catalog")])
 def test_upgrade_table_version(catalog: Catalog) -> None:

Reply via email to