This is an automated email from the ASF dual-hosted git repository.
blue pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/master by this push:
new bfdd9b2ad7 Python: Implement project in Transform implementations
(#6128)
bfdd9b2ad7 is described below
commit bfdd9b2ad72e9d54fcec6cc2fe4a8466b35d2809
Author: Fokko Driesprong <[email protected]>
AuthorDate: Tue Nov 22 00:37:27 2022 +0100
Python: Implement project in Transform implementations (#6128)
---
python/pyiceberg/expressions/__init__.py | 68 ++++++
python/pyiceberg/expressions/literals.py | 29 +++
python/pyiceberg/transforms.py | 214 ++++++++++++++--
python/tests/expressions/test_literals.py | 18 ++
python/tests/test_transforms.py | 389 +++++++++++++++++++++++++++++-
5 files changed, 701 insertions(+), 17 deletions(-)
diff --git a/python/pyiceberg/expressions/__init__.py
b/python/pyiceberg/expressions/__init__.py
index f39e15006b..5d25d33915 100644
--- a/python/pyiceberg/expressions/__init__.py
+++ b/python/pyiceberg/expressions/__init__.py
@@ -309,6 +309,11 @@ class BoundPredicate(Generic[L], Bound, BooleanExpression,
ABC):
return self.term == other.term
return False
+ @property
+ @abstractmethod
+ def as_unbound(self) -> Type[UnboundPredicate[Any]]:
+ ...
+
class UnboundPredicate(Generic[L], Unbound[BooleanExpression],
BooleanExpression, ABC):
term: UnboundTerm[Any]
@@ -347,6 +352,11 @@ class BoundUnaryPredicate(BoundPredicate[L], ABC):
def __repr__(self) -> str:
return f"{str(self.__class__.__name__)}(term={repr(self.term)})"
+ @property
+ @abstractmethod
+ def as_unbound(self) -> Type[UnaryPredicate]:
+ ...
+
class BoundIsNull(BoundUnaryPredicate[L]):
def __new__(cls, term: BoundTerm[L]): # pylint: disable=W0221
@@ -357,6 +367,10 @@ class BoundIsNull(BoundUnaryPredicate[L]):
def __invert__(self) -> BoundNotNull[L]:
return BoundNotNull(self.term)
+ @property
+ def as_unbound(self) -> Type[IsNull]:
+ return IsNull
+
class BoundNotNull(BoundUnaryPredicate[L]):
def __new__(cls, term: BoundTerm[L]): # pylint: disable=W0221
@@ -367,6 +381,10 @@ class BoundNotNull(BoundUnaryPredicate[L]):
def __invert__(self) -> BoundIsNull[L]:
return BoundIsNull(self.term)
+ @property
+ def as_unbound(self) -> Type[NotNull]:
+ return NotNull
+
class IsNull(UnaryPredicate):
def __invert__(self) -> NotNull:
@@ -396,6 +414,10 @@ class BoundIsNaN(BoundUnaryPredicate[L]):
def __invert__(self) -> BoundNotNaN[L]:
return BoundNotNaN(self.term)
+ @property
+ def as_unbound(self) -> Type[IsNaN]:
+ return IsNaN
+
class BoundNotNaN(BoundUnaryPredicate[L]):
def __new__(cls, term: BoundTerm[L]): # pylint: disable=W0221
@@ -407,6 +429,10 @@ class BoundNotNaN(BoundUnaryPredicate[L]):
def __invert__(self) -> BoundIsNaN[L]:
return BoundIsNaN(self.term)
+ @property
+ def as_unbound(self) -> Type[NotNaN]:
+ return NotNaN
+
class IsNaN(UnaryPredicate):
def __invert__(self) -> NotNaN:
@@ -477,6 +503,11 @@ class BoundSetPredicate(BoundPredicate[L], ABC):
def __eq__(self, other: Any) -> bool:
return self.term == other.term and self.literals == other.literals if
isinstance(other, BoundSetPredicate) else False
+ @property
+ @abstractmethod
+ def as_unbound(self) -> Type[SetPredicate[L]]:
+ ...
+
class BoundIn(BoundSetPredicate[L]):
def __new__(cls, term: BoundTerm[L], literals: Set[Literal[L]]): #
pylint: disable=W0221
@@ -494,6 +525,10 @@ class BoundIn(BoundSetPredicate[L]):
def __eq__(self, other: Any) -> bool:
return self.term == other.term and self.literals == other.literals if
isinstance(other, BoundIn) else False
+ @property
+ def as_unbound(self) -> Type[In[L]]:
+ return In
+
class BoundNotIn(BoundSetPredicate[L]):
def __new__( # pylint: disable=W0221
@@ -512,6 +547,10 @@ class BoundNotIn(BoundSetPredicate[L]):
def __invert__(self) -> BoundIn[L]:
return BoundIn(self.term, self.literals)
+ @property
+ def as_unbound(self) -> Type[NotIn[L]]:
+ return NotIn
+
class In(SetPredicate[L]):
def __new__(
@@ -601,36 +640,65 @@ class BoundLiteralPredicate(BoundPredicate[L], ABC):
def __repr__(self) -> str:
return f"{str(self.__class__.__name__)}(term={repr(self.term)},
literal={repr(self.literal)})"
+ @property
+ @abstractmethod
+ def as_unbound(self) -> Type[LiteralPredicate[L]]:
+ ...
+
class BoundEqualTo(BoundLiteralPredicate[L]):
def __invert__(self) -> BoundNotEqualTo[L]:
return BoundNotEqualTo[L](self.term, self.literal)
+ @property
+ def as_unbound(self) -> Type[EqualTo[L]]:
+ return EqualTo
+
class BoundNotEqualTo(BoundLiteralPredicate[L]):
def __invert__(self) -> BoundEqualTo[L]:
return BoundEqualTo[L](self.term, self.literal)
+ @property
+ def as_unbound(self) -> Type[NotEqualTo[L]]:
+ return NotEqualTo
+
class BoundGreaterThanOrEqual(BoundLiteralPredicate[L]):
def __invert__(self) -> BoundLessThan[L]:
return BoundLessThan[L](self.term, self.literal)
+ @property
+ def as_unbound(self) -> Type[GreaterThanOrEqual[L]]:
+ return GreaterThanOrEqual[L]
+
class BoundGreaterThan(BoundLiteralPredicate[L]):
def __invert__(self) -> BoundLessThanOrEqual[L]:
return BoundLessThanOrEqual(self.term, self.literal)
+ @property
+ def as_unbound(self) -> Type[GreaterThan[L]]:
+ return GreaterThan[L]
+
class BoundLessThan(BoundLiteralPredicate[L]):
def __invert__(self) -> BoundGreaterThanOrEqual[L]:
return BoundGreaterThanOrEqual[L](self.term, self.literal)
+ @property
+ def as_unbound(self) -> Type[LessThan[L]]:
+ return LessThan[L]
+
class BoundLessThanOrEqual(BoundLiteralPredicate[L]):
def __invert__(self) -> BoundGreaterThan[L]:
return BoundGreaterThan[L](self.term, self.literal)
+ @property
+ def as_unbound(self) -> Type[LessThanOrEqual[L]]:
+ return LessThanOrEqual[L]
+
class EqualTo(LiteralPredicate[L]):
def __invert__(self) -> NotEqualTo[L]:
diff --git a/python/pyiceberg/expressions/literals.py
b/python/pyiceberg/expressions/literals.py
index 44ab9a15e6..c59c6bcf8d 100644
--- a/python/pyiceberg/expressions/literals.py
+++ b/python/pyiceberg/expressions/literals.py
@@ -53,6 +53,7 @@ from pyiceberg.utils.datetime import (
timestamp_to_micros,
timestamptz_to_micros,
)
+from pyiceberg.utils.decimal import decimal_to_unscaled, unscaled_to_decimal
from pyiceberg.utils.singleton import Singleton
@@ -210,6 +211,12 @@ class LongLiteral(Literal[int]):
def to(self, type_var: IcebergType) -> Literal: # type: ignore
raise TypeError(f"Cannot convert LongLiteral into {type_var}")
+ def increment(self) -> Literal[int]:
+ return LongLiteral(self.value + 1)
+
+ def decrement(self) -> Literal[int]:
+ return LongLiteral(self.value - 1)
+
@to.register(LongType)
def _(self, _: LongType) -> Literal[int]:
return self
@@ -319,6 +326,12 @@ class DateLiteral(Literal[int]):
def __init__(self, value: int):
super().__init__(value, int)
+ def increment(self) -> Literal[int]:
+ return DateLiteral(self.value + 1)
+
+ def decrement(self) -> Literal[int]:
+ return DateLiteral(self.value - 1)
+
@singledispatchmethod
def to(self, type_var: IcebergType) -> Literal: # type: ignore
raise TypeError(f"Cannot convert DateLiteral into {type_var}")
@@ -345,6 +358,12 @@ class TimestampLiteral(Literal[int]):
def __init__(self, value: int):
super().__init__(value, int)
+ def increment(self) -> Literal[int]:
+ return TimestampLiteral(self.value + 1)
+
+ def decrement(self) -> Literal[int]:
+ return TimestampLiteral(self.value - 1)
+
@singledispatchmethod
def to(self, type_var: IcebergType) -> Literal: # type: ignore
raise TypeError(f"Cannot convert TimestampLiteral into {type_var}")
@@ -362,6 +381,16 @@ class DecimalLiteral(Literal[Decimal]):
def __init__(self, value: Decimal):
super().__init__(value, Decimal)
+ def increment(self) -> Literal[Decimal]:
+ original_scale = abs(self.value.as_tuple().exponent)
+ unscaled = decimal_to_unscaled(self.value)
+ return DecimalLiteral(unscaled_to_decimal(unscaled + 1,
original_scale))
+
+ def decrement(self) -> Literal[Decimal]:
+ original_scale = abs(self.value.as_tuple().exponent)
+ unscaled = decimal_to_unscaled(self.value)
+ return DecimalLiteral(unscaled_to_decimal(unscaled - 1,
original_scale))
+
@singledispatchmethod
def to(self, type_var: IcebergType) -> Literal: # type: ignore
raise TypeError(f"Cannot convert DecimalLiteral into {type_var}")
diff --git a/python/pyiceberg/transforms.py b/python/pyiceberg/transforms.py
index 14d76fd8cb..79f901913c 100644
--- a/python/pyiceberg/transforms.py
+++ b/python/pyiceberg/transforms.py
@@ -20,18 +20,41 @@ import struct
from abc import ABC, abstractmethod
from enum import IntEnum
from functools import singledispatch
-from typing import (
- Any,
- Callable,
- Generic,
- Literal,
- Optional,
- TypeVar,
-)
+from typing import Any, Callable, Generic
+from typing import Literal as LiteralType
+from typing import Optional, TypeVar
import mmh3
from pydantic import Field, PositiveInt, PrivateAttr
+from pyiceberg.expressions import (
+ BoundEqualTo,
+ BoundGreaterThan,
+ BoundGreaterThanOrEqual,
+ BoundIn,
+ BoundLessThan,
+ BoundLessThanOrEqual,
+ BoundLiteralPredicate,
+ BoundNotIn,
+ BoundPredicate,
+ BoundSetPredicate,
+ BoundTerm,
+ BoundUnaryPredicate,
+ EqualTo,
+ GreaterThanOrEqual,
+ LessThanOrEqual,
+ Reference,
+ UnboundPredicate,
+)
+from pyiceberg.expressions.literals import (
+ DateLiteral,
+ DecimalLiteral,
+ Literal,
+ LongLiteral,
+ TimestampLiteral,
+ literal,
+)
+from pyiceberg.typedef import L
from pyiceberg.types import (
BinaryType,
DateType,
@@ -68,6 +91,11 @@ BUCKET_PARSER = ParseNumberFromBrackets(BUCKET)
TRUNCATE_PARSER = ParseNumberFromBrackets(TRUNCATE)
+def _transform_literal(func: Callable[[L], L], lit: Literal[L]) -> Literal[L]:
+ """Small helper to upwrap the value from the literal, and wrap it again"""
+ return literal(func(lit.value))
+
+
class Transform(IcebergBaseModel, ABC, Generic[S, T]):
"""Transform base class for concrete transforms.
@@ -121,6 +149,10 @@ class Transform(IcebergBaseModel, ABC, Generic[S, T]):
def result_type(self, source: IcebergType) -> IcebergType:
...
+ @abstractmethod
+ def project(self, name: str, pred: BoundPredicate[L]) ->
Optional[UnboundPredicate[Any]]:
+ ...
+
@property
def preserves_order(self) -> bool:
return False
@@ -173,6 +205,23 @@ class BucketTransform(Transform[S, int]):
def result_type(self, source: IcebergType) -> IcebergType:
return IntegerType()
+ def project(self, name: str, pred: BoundPredicate[L]) ->
Optional[UnboundPredicate[Any]]:
+ transformer = self.transform(pred.term.ref().field.field_type)
+
+ if isinstance(pred.term, BoundTransform):
+ return _project_transform_predicate(self, name, pred)
+ elif isinstance(pred, BoundUnaryPredicate):
+ return pred.as_unbound(Reference(name))
+ elif isinstance(pred, BoundEqualTo):
+ return pred.as_unbound(Reference(name),
_transform_literal(transformer, pred.literal))
+ elif isinstance(pred, BoundIn): # NotIn can't be projected
+ return pred.as_unbound(Reference(name),
{_transform_literal(transformer, literal) for literal in pred.literals})
+ else:
+ # - Comparison predicates can't be projected, notEq can't be
projected
+ # - Small ranges can be projected:
+ # For example, (x > 0) and (x < 3) can be turned into in({1, 2})
and projected.
+ return None
+
def can_transform(self, source: IcebergType) -> bool:
return type(source) in {
IntegerType,
@@ -246,9 +295,26 @@ class TimeTransform(Transform[S, int], Singleton):
def satisfies_order_of(self, other: Transform[S, T]) -> bool:
return self.granularity <= other.granularity if hasattr(other,
"granularity") else False
- def result_type(self, source: IcebergType) -> IcebergType:
+ def result_type(self, source: IcebergType) -> IntegerType:
return IntegerType()
+ @abstractmethod
+ def transform(self, source: IcebergType) -> Callable[[Optional[Any]],
Optional[int]]:
+ ...
+
+ def project(self, name: str, pred: BoundPredicate[L]) ->
Optional[UnboundPredicate[Any]]:
+ transformer = self.transform(pred.term.ref().field.field_type)
+ if isinstance(pred.term, BoundTransform):
+ return _project_transform_predicate(self, name, pred)
+ elif isinstance(pred, BoundUnaryPredicate):
+ return pred.as_unbound(Reference(name))
+ elif isinstance(pred, BoundLiteralPredicate):
+ return _truncate_number(name, pred, transformer)
+ elif isinstance(pred, BoundIn): # NotIn can't be projected
+ return _set_apply_transform(name, pred, transformer)
+ else:
+ return None
+
@property
def dedup_name(self) -> str:
return "time"
@@ -267,7 +333,7 @@ class YearTransform(TimeTransform[S]):
47
"""
- __root__: Literal["year"] = Field(default="year")
+ __root__: LiteralType["year"] = Field(default="year") # noqa: F821
def transform(self, source: IcebergType) -> Callable[[Optional[S]],
Optional[int]]:
source_type = type(source)
@@ -313,7 +379,7 @@ class MonthTransform(TimeTransform[S]):
575
"""
- __root__: Literal["month"] = Field(default="month")
+ __root__: LiteralType["month"] = Field(default="month") # noqa: F821
def transform(self, source: IcebergType) -> Callable[[Optional[S]],
Optional[int]]:
source_type = type(source)
@@ -359,7 +425,7 @@ class DayTransform(TimeTransform[S]):
17501
"""
- __root__: Literal["day"] = Field(default="day")
+ __root__: LiteralType["day"] = Field(default="day") # noqa: F821
def transform(self, source: IcebergType) -> Callable[[Optional[S]],
Optional[int]]:
source_type = type(source)
@@ -408,7 +474,7 @@ class HourTransform(TimeTransform[S]):
420042
"""
- __root__: Literal["hour"] = Field(default="hour")
+ __root__: LiteralType["hour"] = Field(default="hour") # noqa: F821
def transform(self, source: IcebergType) -> Callable[[Optional[S]],
Optional[int]]:
if type(source) in {TimestampType, TimestamptzType}:
@@ -452,7 +518,7 @@ class IdentityTransform(Transform[S, S]):
'hello-world'
"""
- __root__: Literal["identity"] = Field(default="identity")
+ __root__: LiteralType["identity"] = Field(default="identity") # noqa: F821
def transform(self, source: IcebergType) -> Callable[[Optional[S]],
Optional[S]]:
return lambda v: v
@@ -463,6 +529,18 @@ class IdentityTransform(Transform[S, S]):
def result_type(self, source: IcebergType) -> IcebergType:
return source
+ def project(self, name: str, pred: BoundPredicate[L]) ->
Optional[UnboundPredicate[Any]]:
+ if isinstance(pred.term, BoundTransform):
+ return _project_transform_predicate(self, name, pred)
+ elif isinstance(pred, BoundUnaryPredicate):
+ return pred.as_unbound(Reference(name))
+ elif isinstance(pred, BoundEqualTo):
+ return pred.as_unbound(Reference(name), pred.literal)
+ elif isinstance(pred, (BoundIn, BoundNotIn)):
+ return pred.as_unbound(Reference(name), pred.literals)
+ else:
+ raise ValueError(f"Could not project: {self}")
+
@property
def preserves_order(self) -> bool:
return True
@@ -511,6 +589,29 @@ class TruncateTransform(Transform[S, S]):
def source_type(self) -> IcebergType:
return self._source_type
+ def project(self, name: str, pred: BoundPredicate[L]) ->
Optional[UnboundPredicate[Any]]:
+ field_type = pred.term.ref().field.field_type
+
+ if isinstance(pred.term, BoundTransform):
+ return _project_transform_predicate(self, name, pred)
+
+ # Implement startswith and notstartswith for string (and probably
binary)
+ # https://github.com/apache/iceberg/issues/6112
+
+ if isinstance(pred, BoundUnaryPredicate):
+ return pred.as_unbound(Reference(name))
+ elif isinstance(field_type, (IntegerType, LongType, DecimalType)):
+ if isinstance(pred, BoundLiteralPredicate):
+ return _truncate_number(name, pred, self.transform(field_type))
+ elif isinstance(pred, BoundIn):
+ return _set_apply_transform(name, pred,
self.transform(field_type))
+ elif isinstance(field_type, (BinaryType, StringType)):
+ if isinstance(pred, BoundLiteralPredicate):
+ return _truncate_array(name, pred, self.transform(field_type))
+ elif isinstance(pred, BoundIn):
+ return _set_apply_transform(name, pred,
self.transform(field_type))
+ return None
+
@property
def width(self) -> int:
return self._width
@@ -610,7 +711,7 @@ class UnknownTransform(Transform[S, T]):
AttributeError: If the apply method is called.
"""
- __root__: Literal["unknown"] = Field(default="unknown")
+ __root__: LiteralType["unknown"] = Field(default="unknown") # noqa: F821
_transform: str = PrivateAttr()
def __init__(self, transform: str, **data: Any):
@@ -623,9 +724,12 @@ class UnknownTransform(Transform[S, T]):
def can_transform(self, source: IcebergType) -> bool:
return False
- def result_type(self, source: IcebergType) -> IcebergType:
+ def result_type(self, source: IcebergType) -> StringType:
return StringType()
+ def project(self, name: str, pred: BoundPredicate[L]) ->
Optional[UnboundPredicate[Any]]:
+ return None
+
def __repr__(self) -> str:
return f"UnknownTransform(transform={repr(self._transform)})"
@@ -644,8 +748,86 @@ class VoidTransform(Transform[S, None], Singleton):
def result_type(self, source: IcebergType) -> IcebergType:
return source
+ def project(self, name: str, pred: BoundPredicate[L]) ->
Optional[UnboundPredicate[Any]]:
+ return None
+
def to_human_string(self, _: IcebergType, value: Optional[S]) -> str:
return "null"
def __repr__(self) -> str:
return "VoidTransform()"
+
+
+def _truncate_number(
+ name: str, pred: BoundLiteralPredicate[L], transform:
Callable[[Optional[L]], Optional[L]]
+) -> Optional[UnboundPredicate[Any]]:
+ boundary = pred.literal
+
+ if not isinstance(boundary, (LongLiteral, DecimalLiteral, DateLiteral,
TimestampLiteral)):
+ raise ValueError(f"Expected a numeric literal, got: {type(boundary)}")
+
+ if isinstance(pred, BoundLessThan):
+ return LessThanOrEqual(Reference(name), _transform_literal(transform,
boundary.decrement())) # type: ignore
+ elif isinstance(pred, BoundLessThanOrEqual):
+ return LessThanOrEqual(Reference(name), _transform_literal(transform,
boundary))
+ elif isinstance(pred, BoundGreaterThan):
+ return GreaterThanOrEqual(Reference(name),
_transform_literal(transform, boundary.increment())) # type: ignore
+ elif isinstance(pred, BoundGreaterThanOrEqual):
+ return GreaterThanOrEqual(Reference(name),
_transform_literal(transform, boundary))
+ elif isinstance(pred, BoundEqualTo):
+ return EqualTo(Reference(name), _transform_literal(transform,
boundary))
+ else:
+ return None
+
+
+def _truncate_array(
+ name: str, pred: BoundLiteralPredicate[L], transform:
Callable[[Optional[L]], Optional[L]]
+) -> Optional[UnboundPredicate[Any]]:
+ boundary = pred.literal
+
+ if type(pred) in {BoundLessThan, BoundLessThanOrEqual}:
+ return LessThanOrEqual(Reference(name), _transform_literal(transform,
boundary))
+ elif type(pred) in {BoundGreaterThan, BoundGreaterThanOrEqual}:
+ return GreaterThanOrEqual(Reference(name),
_transform_literal(transform, boundary))
+ if isinstance(pred, BoundEqualTo):
+ return EqualTo(Reference(name), _transform_literal(transform,
boundary))
+ else:
+ return None
+
+
+def _project_transform_predicate(
+ transform: Transform[Any, Any], partition_name: str, pred:
BoundPredicate[L]
+) -> Optional[UnboundPredicate[Any]]:
+ term = pred.term
+ if isinstance(term, BoundTransform) and transform == term.transform:
+ return _remove_transform(partition_name, pred)
+ return None
+
+
+def _remove_transform(partition_name: str, pred: BoundPredicate[L]):
+ if isinstance(pred, BoundUnaryPredicate):
+ return pred.as_unbound(Reference(partition_name))
+ elif isinstance(pred, BoundLiteralPredicate):
+ return pred.as_unbound(Reference(partition_name), pred.literal)
+ elif isinstance(pred, (BoundIn, BoundNotIn)):
+ return pred.as_unbound(Reference(partition_name), pred.literals)
+ else:
+ raise ValueError(f"Cannot replace transform in unknown predicate:
{pred}")
+
+
+def _set_apply_transform(name: str, pred: BoundSetPredicate[L], transform:
Callable[[L], L]) -> UnboundPredicate[Any]:
+ literals = pred.literals
+ if isinstance(pred, BoundSetPredicate):
+ return pred.as_unbound(Reference(name), {_transform_literal(transform,
literal) for literal in literals})
+ else:
+ raise ValueError(f"Unknown BoundSetPredicate: {pred}")
+
+
+class BoundTransform(BoundTerm[L]):
+ """A transform expression"""
+
+ transform: Transform[L, Any]
+
+ def __init__(self, term: BoundTerm[L], transform: Transform[L, Any]):
+ self.term: BoundTerm[L] = term
+ self.transform = transform
diff --git a/python/tests/expressions/test_literals.py
b/python/tests/expressions/test_literals.py
index d203195930..efcacc4574 100644
--- a/python/tests/expressions/test_literals.py
+++ b/python/tests/expressions/test_literals.py
@@ -805,6 +805,24 @@ def test_string_to_decimal_type_invalid_value():
assert "Could not convert 18.15 into a decimal(10, 0), scales differ 0 <>
2" in str(e.value)
+def test_decimal_literal_increment():
+ dec = DecimalLiteral(Decimal("10.123"))
+ # Twice to check that we don't mutate the value
+ assert dec.increment() == DecimalLiteral(Decimal("10.124"))
+ assert dec.increment() == DecimalLiteral(Decimal("10.124"))
+ # To check that the scale is still the same
+ assert dec.increment().value.as_tuple() == Decimal("10.124").as_tuple()
+
+
+def test_decimal_literal_dencrement():
+ dec = DecimalLiteral(Decimal("10.123"))
+ # Twice to check that we don't mutate the value
+ assert dec.decrement() == DecimalLiteral(Decimal("10.122"))
+ assert dec.decrement() == DecimalLiteral(Decimal("10.122"))
+ # To check that the scale is still the same
+ assert dec.decrement().value.as_tuple() == Decimal("10.122").as_tuple()
+
+
# __ __ ___
# | \/ |_ _| _ \_ _
# | |\/| | || | _/ || |
diff --git a/python/tests/test_transforms.py b/python/tests/test_transforms.py
index f8bc42e2dc..9395db13c2 100644
--- a/python/tests/test_transforms.py
+++ b/python/tests/test_transforms.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=eval-used,protected-access
+# pylint: disable=eval-used,protected-access,redefined-outer-name
from datetime import date
from decimal import Decimal
from typing import Any
@@ -24,6 +24,31 @@ import mmh3 as mmh3
import pytest
from pyiceberg import transforms
+from pyiceberg.expressions import (
+ BoundEqualTo,
+ BoundGreaterThan,
+ BoundGreaterThanOrEqual,
+ BoundIn,
+ BoundLessThan,
+ BoundLessThanOrEqual,
+ BoundNotIn,
+ BoundNotNull,
+ BoundReference,
+ EqualTo,
+ GreaterThanOrEqual,
+ In,
+ LessThanOrEqual,
+ NotIn,
+ NotNull,
+ Reference,
+)
+from pyiceberg.expressions.literals import (
+ DateLiteral,
+ DecimalLiteral,
+ TimestampLiteral,
+ literal,
+)
+from pyiceberg.schema import Accessor
from pyiceberg.transforms import (
BucketTransform,
DayTransform,
@@ -46,6 +71,7 @@ from pyiceberg.types import (
FloatType,
IntegerType,
LongType,
+ NestedField,
StringType,
TimestampType,
TimestamptzType,
@@ -507,3 +533,364 @@ def test_datetime_transform_str(transform, transform_str):
)
def test_datetime_transform_repr(transform, transform_repr):
assert repr(transform) == transform_repr
+
+
[email protected]
+def bound_reference_str() -> BoundReference[str]:
+ return BoundReference(field=NestedField(1, "field", StringType(),
required=False), accessor=Accessor(position=0, inner=None))
+
+
[email protected]
+def bound_reference_date() -> BoundReference[int]:
+ return BoundReference(field=NestedField(1, "field", DateType(),
required=False), accessor=Accessor(position=0, inner=None))
+
+
[email protected]
+def bound_reference_timestamp() -> BoundReference[int]:
+ return BoundReference(
+ field=NestedField(1, "field", TimestampType(), required=False),
accessor=Accessor(position=0, inner=None)
+ )
+
+
[email protected]
+def bound_reference_decimal() -> BoundReference[Decimal]:
+ return BoundReference(
+ field=NestedField(1, "field", DecimalType(8, 2), required=False),
accessor=Accessor(position=0, inner=None)
+ )
+
+
[email protected]
+def bound_reference_long() -> BoundReference[int]:
+ return BoundReference(
+ field=NestedField(1, "field", DecimalType(8, 2), required=False),
accessor=Accessor(position=0, inner=None)
+ )
+
+
+def test_projection_bucket_unary(bound_reference_str: BoundReference[str]) ->
None:
+ assert BucketTransform(2).project("name",
BoundNotNull(term=bound_reference_str)) == NotNull(term=Reference(name="name"))
+
+
+def test_projection_bucket_literal(bound_reference_str: BoundReference[str])
-> None:
+ assert BucketTransform(2).project("name",
BoundEqualTo(term=bound_reference_str, literal=literal("data"))) == EqualTo(
+ term="name", literal=1
+ )
+
+
+def test_projection_bucket_set_same_bucket(bound_reference_str:
BoundReference[str]) -> None:
+ assert BucketTransform(2).project(
+ "name", BoundIn(term=bound_reference_str, literals={literal("hello"),
literal("world")})
+ ) == EqualTo(term="name", literal=1)
+
+
+def test_projection_bucket_set_in(bound_reference_str: BoundReference[str]) ->
None:
+ assert BucketTransform(3).project(
+ "name", BoundIn(term=bound_reference_str, literals={literal("hello"),
literal("world")})
+ ) == In(term="name", literals={1, 2})
+
+
+def test_projection_bucket_set_not_in(bound_reference_str:
BoundReference[str]) -> None:
+ assert (
+ BucketTransform(3).project("name",
BoundNotIn(term=bound_reference_str, literals={literal("hello"),
literal("world")}))
+ is None
+ )
+
+
+def test_projection_year_unary(bound_reference_date: BoundReference[int]) ->
None:
+ assert YearTransform().project("name",
BoundNotNull(term=bound_reference_date)) == NotNull(term="name")
+
+
+def test_projection_year_literal(bound_reference_date: BoundReference[int]) ->
None:
+ assert YearTransform().project("name",
BoundEqualTo(term=bound_reference_date, literal=DateLiteral(1925))) == EqualTo(
+ term="name", literal=5
+ )
+
+
+def test_projection_year_set_same_year(bound_reference_date:
BoundReference[int]) -> None:
+ assert YearTransform().project(
+ "name", BoundIn(term=bound_reference_date,
literals={DateLiteral(1925), DateLiteral(1926)})
+ ) == EqualTo(term="name", literal=5)
+
+
+def test_projection_year_set_in(bound_reference_date: BoundReference[int]) ->
None:
+ assert YearTransform().project(
+ "name", BoundIn(term=bound_reference_date,
literals={DateLiteral(1925), DateLiteral(2925)})
+ ) == In(term="name", literals={8, 5})
+
+
+def test_projection_year_set_not_in(bound_reference_date: BoundReference[int])
-> None:
+ assert (
+ YearTransform().project("name", BoundNotIn(term=bound_reference_date,
literals={DateLiteral(1925), DateLiteral(2925)}))
+ is None
+ )
+
+
+def test_projection_month_unary(bound_reference_date: BoundReference[int]) ->
None:
+ assert MonthTransform().project("name",
BoundNotNull(term=bound_reference_date)) == NotNull(term="name")
+
+
+def test_projection_month_literal(bound_reference_date: BoundReference[int])
-> None:
+ assert MonthTransform().project("name",
BoundEqualTo(term=bound_reference_date, literal=DateLiteral(1925))) == EqualTo(
+ term="name", literal=63
+ )
+
+
+def test_projection_month_set_same_month(bound_reference_date:
BoundReference[int]) -> None:
+ assert MonthTransform().project(
+ "name", BoundIn(term=bound_reference_date,
literals={DateLiteral(1925), DateLiteral(1926)})
+ ) == EqualTo(term="name", literal=63)
+
+
+def test_projection_month_set_in(bound_reference_date: BoundReference[int]) ->
None:
+ assert MonthTransform().project(
+ "name", BoundIn(term=bound_reference_date,
literals={DateLiteral(1925), DateLiteral(2925)})
+ ) == In(term="name", literals={96, 63})
+
+
+def test_projection_day_month_not_in(bound_reference_date:
BoundReference[int]) -> None:
+ assert (
+ MonthTransform().project("name", BoundNotIn(term=bound_reference_date,
literals={DateLiteral(1925), DateLiteral(2925)}))
+ is None
+ )
+
+
+def test_projection_day_unary(bound_reference_timestamp) -> None:
+ assert DayTransform().project("name",
BoundNotNull(term=bound_reference_timestamp)) == NotNull(term="name")
+
+
+def test_projection_day_literal(bound_reference_timestamp) -> None:
+ assert DayTransform().project(
+ "name", BoundEqualTo(term=bound_reference_timestamp,
literal=TimestampLiteral(1667696874000))
+ ) == EqualTo(term="name", literal=19)
+
+
+def test_projection_day_set_same_day(bound_reference_timestamp) -> None:
+ assert DayTransform().project(
+ "name",
+ BoundIn(term=bound_reference_timestamp,
literals={TimestampLiteral(1667696874001), TimestampLiteral(1667696874000)}),
+ ) == EqualTo(term="name", literal=19)
+
+
+def test_projection_day_set_in(bound_reference_timestamp) -> None:
+ assert DayTransform().project(
+ "name",
+ BoundIn(term=bound_reference_timestamp,
literals={TimestampLiteral(1667696874001), TimestampLiteral(1567696874000)}),
+ ) == In(term="name", literals={18, 19})
+
+
+def test_projection_day_set_not_in(bound_reference_timestamp) -> None:
+ assert (
+ DayTransform().project(
+ "name",
+ BoundNotIn(term=bound_reference_timestamp,
literals={TimestampLiteral(1567696874), TimestampLiteral(1667696874)}),
+ )
+ is None
+ )
+
+
+def test_projection_day_human(bound_reference_date: BoundReference[int]) ->
None:
+ date_literal = DateLiteral(17532)
+ assert DayTransform().project("dt",
BoundEqualTo(term=bound_reference_date, literal=date_literal)) == EqualTo(
+ term="dt", literal=17532
+ ) # == 2018, 1, 1
+
+ assert DayTransform().project("dt",
BoundLessThanOrEqual(term=bound_reference_date, literal=date_literal)) ==
LessThanOrEqual(
+ term="dt", literal=17532
+ ) # <= 2018, 1, 1
+
+ assert DayTransform().project("dt",
BoundLessThan(term=bound_reference_date, literal=date_literal)) ==
LessThanOrEqual(
+ term="dt", literal=17531
+ ) # <= 2017, 12, 31
+
+ assert DayTransform().project(
+ "dt", BoundGreaterThanOrEqual(term=bound_reference_date,
literal=date_literal)
+ ) == GreaterThanOrEqual(
+ term="dt", literal=17532
+ ) # >= 2018, 1, 1
+
+ assert DayTransform().project("dt",
BoundGreaterThan(term=bound_reference_date, literal=date_literal)) ==
GreaterThanOrEqual(
+ term="dt", literal=17533
+ ) # >= 2018, 1, 2
+
+
+def test_projection_hour_unary(bound_reference_timestamp) -> None:
+ assert HourTransform().project("name",
BoundNotNull(term=bound_reference_timestamp)) == NotNull(term="name")
+
+
+TIMESTAMP_EXAMPLE = 1667696874000000 # Sun Nov 06 2022 01:07:54
+HOUR_IN_MICROSECONDS = 60 * 60 * 1000 * 1000
+
+
+def test_projection_hour_literal(bound_reference_timestamp) -> None:
+ assert HourTransform().project(
+ "name", BoundEqualTo(term=bound_reference_timestamp,
literal=TimestampLiteral(TIMESTAMP_EXAMPLE))
+ ) == EqualTo(term="name", literal=463249)
+
+
+def test_projection_hour_set_same_hour(bound_reference_timestamp) -> None:
+ assert HourTransform().project(
+ "name",
+ BoundIn(
+ term=bound_reference_timestamp,
+ literals={TimestampLiteral(TIMESTAMP_EXAMPLE + 1),
TimestampLiteral(TIMESTAMP_EXAMPLE)},
+ ),
+ ) == EqualTo(term="name", literal=463249)
+
+
+def test_projection_hour_set_in(bound_reference_timestamp) -> None:
+ assert HourTransform().project(
+ "name",
+ BoundIn(
+ term=bound_reference_timestamp,
+ literals={TimestampLiteral(TIMESTAMP_EXAMPLE +
HOUR_IN_MICROSECONDS), TimestampLiteral(TIMESTAMP_EXAMPLE)},
+ ),
+ ) == In(term="name", literals={463249, 463250})
+
+
+def test_projection_hour_set_not_in(bound_reference_timestamp) -> None:
+ assert (
+ HourTransform().project(
+ "name",
+ BoundNotIn(
+ term=bound_reference_timestamp,
+ literals={TimestampLiteral(TIMESTAMP_EXAMPLE +
HOUR_IN_MICROSECONDS), TimestampLiteral(TIMESTAMP_EXAMPLE)},
+ ),
+ )
+ is None
+ )
+
+
+def test_projection_identity_unary(bound_reference_timestamp) -> None:
+ assert IdentityTransform().project("name",
BoundNotNull(term=bound_reference_timestamp)) == NotNull(term="name")
+
+
+def test_projection_identity_literal(bound_reference_timestamp) -> None:
+ assert IdentityTransform().project(
+ "name", BoundEqualTo(term=bound_reference_timestamp,
literal=TimestampLiteral(TIMESTAMP_EXAMPLE))
+ ) == EqualTo(
+ term="name", literal=TimestampLiteral(TIMESTAMP_EXAMPLE) # type:
ignore
+ )
+
+
+def test_projection_identity_set_in(bound_reference_timestamp) -> None:
+ assert IdentityTransform().project(
+ "name",
+ BoundIn(
+ term=bound_reference_timestamp,
+ literals={TimestampLiteral(TIMESTAMP_EXAMPLE +
HOUR_IN_MICROSECONDS), TimestampLiteral(TIMESTAMP_EXAMPLE)},
+ ),
+ ) == In(
+ term="name", literals={TimestampLiteral(TIMESTAMP_EXAMPLE +
HOUR_IN_MICROSECONDS), TimestampLiteral(TIMESTAMP_EXAMPLE)} # type: ignore
+ )
+
+
+def test_projection_identity_set_not_in(bound_reference_timestamp) -> None:
+ assert IdentityTransform().project(
+ "name",
+ BoundNotIn(
+ term=bound_reference_timestamp,
+ literals={TimestampLiteral(TIMESTAMP_EXAMPLE +
HOUR_IN_MICROSECONDS), TimestampLiteral(TIMESTAMP_EXAMPLE)},
+ ),
+ ) == NotIn(
+ term="name", literals={TimestampLiteral(TIMESTAMP_EXAMPLE +
HOUR_IN_MICROSECONDS), TimestampLiteral(TIMESTAMP_EXAMPLE)} # type: ignore
+ )
+
+
+def test_projection_truncate_string_unary(bound_reference_str:
BoundReference[str]) -> None:
+ assert TruncateTransform(2).project("name",
BoundNotNull(term=bound_reference_str)) == NotNull(term="name")
+
+
+def test_projection_truncate_string_literal_eq(bound_reference_str:
BoundReference[str]) -> None:
+ assert TruncateTransform(2).project("name",
BoundEqualTo(term=bound_reference_str, literal=literal("data"))) == EqualTo(
+ term="name", literal=literal("da")
+ )
+
+
+def test_projection_truncate_string_literal_gt(bound_reference_str:
BoundReference[str]) -> None:
+ assert TruncateTransform(2).project("name",
BoundGreaterThan(term=bound_reference_str, literal=literal("data"))) == EqualTo(
+ term="name", literal=literal("da")
+ )
+
+
+def test_projection_truncate_string_literal_gte(bound_reference_str:
BoundReference[str]) -> None:
+ assert TruncateTransform(2).project(
+ "name", BoundGreaterThanOrEqual(term=bound_reference_str,
literal=literal("data"))
+ ) == EqualTo(term="name", literal=literal("da"))
+
+
+def test_projection_truncate_string_set_same_result(bound_reference_str:
BoundReference[str]) -> None:
+ assert TruncateTransform(2).project(
+ "name", BoundIn(term=bound_reference_str, literals={literal("hello"),
literal("helloworld")})
+ ) == EqualTo(term="name", literal=literal("he"))
+
+
+def test_projection_truncate_string_set_in(bound_reference_str:
BoundReference[str]) -> None:
+ assert TruncateTransform(3).project(
+ "name", BoundIn(term=bound_reference_str, literals={literal("hello"),
literal("world")})
+ ) == In(term="name", literals={literal("hel"), literal("wor")})
+
+
+def test_projection_truncate_string_set_not_in(bound_reference_str:
BoundReference[str]) -> None:
+ assert (
+ TruncateTransform(3).project("name",
BoundNotIn(term=bound_reference_str, literals={literal("hello"),
literal("world")}))
+ is None
+ )
+
+
+def test_projection_truncate_decimal_literal_eq(bound_reference_decimal:
BoundReference[Decimal]) -> None:
+ assert TruncateTransform(2).project(
+ "name", BoundEqualTo(term=bound_reference_decimal,
literal=DecimalLiteral(Decimal(19.25)))
+ ) == EqualTo(term="name", literal=Decimal("19.24"))
+
+
+def test_projection_truncate_decimal_literal_gt(bound_reference_decimal:
BoundReference[Decimal]) -> None:
+ assert TruncateTransform(2).project(
+ "name", BoundGreaterThan(term=bound_reference_decimal,
literal=DecimalLiteral(Decimal(19.25)))
+ ) == GreaterThanOrEqual(term="name", literal=Decimal("19.26"))
+
+
+def test_projection_truncate_decimal_literal_gte(bound_reference_decimal:
BoundReference[Decimal]) -> None:
+ assert TruncateTransform(2).project(
+ "name", BoundGreaterThanOrEqual(term=bound_reference_decimal,
literal=DecimalLiteral(Decimal(19.25)))
+ ) == GreaterThanOrEqual(term="name", literal=Decimal("19.24"))
+
+
+def test_projection_truncate_decimal_in(bound_reference_decimal:
BoundReference[Decimal]) -> None:
+ assert TruncateTransform(2).project(
+ "name", BoundIn(term=bound_reference_decimal,
literals={literal(Decimal(19.25)), literal(Decimal(18.15))})
+ ) == In(
+ term="name",
+ literals={
+ Decimal("19.24"),
+ Decimal("18.14999999999999857891452847979962825775146484374"),
+ },
+ )
+
+
+def test_projection_truncate_long_literal_eq(bound_reference_decimal:
BoundReference[Decimal]) -> None:
+ assert TruncateTransform(2).project(
+ "name", BoundEqualTo(term=bound_reference_decimal,
literal=DecimalLiteral(Decimal(19.25)))
+ ) == EqualTo(term="name", literal=Decimal("19.24"))
+
+
+def test_projection_truncate_long_literal_gt(bound_reference_decimal:
BoundReference[Decimal]) -> None:
+ assert TruncateTransform(2).project(
+ "name", BoundGreaterThan(term=bound_reference_decimal,
literal=DecimalLiteral(Decimal(19.25)))
+ ) == GreaterThanOrEqual(term="name", literal=Decimal("19.26"))
+
+
+def test_projection_truncate_long_literal_gte(bound_reference_decimal:
BoundReference[Decimal]) -> None:
+ assert TruncateTransform(2).project(
+ "name", BoundGreaterThanOrEqual(term=bound_reference_decimal,
literal=DecimalLiteral(Decimal(19.25)))
+ ) == GreaterThanOrEqual(term="name", literal=Decimal("19.24"))
+
+
+def test_projection_truncate_long_in(bound_reference_decimal:
BoundReference[Decimal]) -> None:
+ assert TruncateTransform(2).project(
+ "name", BoundIn(term=bound_reference_decimal,
literals={DecimalLiteral(Decimal(19.25)), DecimalLiteral(Decimal(18.15))})
+ ) == In(
+ term="name",
+ literals={
+ Decimal("19.24"),
+ Decimal("18.14999999999999857891452847979962825775146484374"),
+ },
+ )