This is an automated email from the ASF dual-hosted git repository.

honahx pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 8b660f8  Fix TruncateTransform for falsey values (#276)
8b660f8 is described below

commit 8b660f8fe93596bab6b1c34a6c969078b220419e
Author: Sung Yun <[email protected]>
AuthorDate: Thu Jan 18 00:51:24 2024 -0500

    Fix TruncateTransform for falsey values (#276)
---
 pyiceberg/transforms.py  | 2 +-
 tests/test_transforms.py | 5 +++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py
index 2c91f1c..4386547 100644
--- a/pyiceberg/transforms.py
+++ b/pyiceberg/transforms.py
@@ -613,7 +613,7 @@ class TruncateTransform(Transform[S, S]):
         else:
             raise ValueError(f"Cannot truncate for type: {source}")
 
-        return lambda v: truncate_func(v) if v else None
+        return lambda v: truncate_func(v) if v is not None else None
 
     def satisfies_order_of(self, other: Transform[S, T]) -> bool:
         if self == other:
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
index d7a1478..d4910cc 100644
--- a/tests/test_transforms.py
+++ b/tests/test_transforms.py
@@ -362,7 +362,7 @@ def test_identity_method(type_var: PrimitiveType) -> None:
 @pytest.mark.parametrize("type_var", [IntegerType(), LongType()])
 @pytest.mark.parametrize(
     "input_var,expected",
-    [(1, 0), (5, 0), (9, 0), (10, 10), (11, 10), (-1, -10), (-10, -10), (-12, -20)],
+    [(1, 0), (5, 0), (9, 0), (10, 10), (11, 10), (-1, -10), (-10, -10), (-12, -20), (0, 0)],
 )
 def test_truncate_integer(type_var: PrimitiveType, input_var: int, expected: int) -> None:
     trunc = TruncateTransform(10)  # type: ignore
@@ -377,6 +377,7 @@ def test_truncate_integer(type_var: PrimitiveType, input_var: int, expected: int
         (Decimal("12.29"), Decimal("12.20")),
         (Decimal("0.05"), Decimal("0.00")),
         (Decimal("-0.05"), Decimal("-0.10")),
+        (Decimal("0.0"), Decimal("0.0")),
     ],
 )
 def test_truncate_decimal(input_var: Decimal, expected: Decimal) -> None:
@@ -384,7 +385,7 @@ def test_truncate_decimal(input_var: Decimal, expected: Decimal) -> None:
     assert trunc.transform(DecimalType(9, 2))(input_var) == expected
 
 
-@pytest.mark.parametrize("input_var,expected", [("abcdefg", "abcde"), ("abc", "abc")])
+@pytest.mark.parametrize("input_var,expected", [("abcdefg", "abcde"), ("abc", "abc"), ("", "")])
 def test_truncate_string(input_var: str, expected: str) -> None:
     trunc = TruncateTransform(5)  # type: ignore
     assert trunc.transform(StringType())(input_var) == expected

Reply via email to