This is an automated email from the ASF dual-hosted git repository.

blue pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/master by this push:
     new 63acd3174f Python: Move Transforms to Pydantic (#5170)
63acd3174f is described below

commit 63acd3174f5d4dd213ddfb3e9e66225313fc6489
Author: Fokko Driesprong <[email protected]>
AuthorDate: Fri Jul 1 17:21:15 2022 +0200

    Python: Move Transforms to Pydantic (#5170)
---
 python/pyiceberg/transforms.py  | 145 +++++++++++++++++++++-------------------
 python/tests/test_transforms.py | 132 ++++++++++++++++++++++++++++++++++++
 2 files changed, 210 insertions(+), 67 deletions(-)

diff --git a/python/pyiceberg/transforms.py b/python/pyiceberg/transforms.py
index d90736d613..ca337ed584 100644
--- a/python/pyiceberg/transforms.py
+++ b/python/pyiceberg/transforms.py
@@ -23,12 +23,14 @@ from functools import singledispatch
 from typing import (
     Any,
     Generic,
+    Literal,
     Optional,
     TypeVar,
 )
 from uuid import UUID
 
 import mmh3  # type: ignore
+from pydantic import Field, PositiveInt, PrivateAttr
 
 from pyiceberg.types import (
     BinaryType,
@@ -46,32 +48,21 @@ from pyiceberg.types import (
 )
 from pyiceberg.utils import datetime
 from pyiceberg.utils.decimal import decimal_to_bytes, truncate_decimal
+from pyiceberg.utils.iceberg_base_model import IcebergBaseModel
 from pyiceberg.utils.singleton import Singleton
 
 S = TypeVar("S")
 T = TypeVar("T")
 
 
-class Transform(ABC, Generic[S, T]):
+class Transform(IcebergBaseModel, ABC, Generic[S, T]):
     """Transform base class for concrete transforms.
 
     A base class to transform values and project predicates on partition values.
     This class is not used directly. Instead, use one of module method to create the child classes.
-
-    Args:
-        transform_string (str): name of the transform type
-        repr_string (str): string representation of a transform instance
     """
 
-    def __init__(self, transform_string: str, repr_string: str):
-        self._transform_string = transform_string
-        self._repr_string = repr_string
-
-    def __repr__(self):
-        return self._repr_string
-
-    def __str__(self):
-        return self._transform_string
+    __root__: str = Field()
 
     def __call__(self, value: Optional[S]) -> Optional[T]:
         return self.apply(value)
@@ -100,7 +91,10 @@ class Transform(ABC, Generic[S, T]):
 
     @property
     def dedup_name(self) -> str:
-        return self._transform_string
+        return self.__str__()
+
+    def __str__(self) -> str:
+        return self.__root__
 
 
 class BaseBucketTransform(Transform[S, int]):
@@ -115,11 +109,12 @@ class BaseBucketTransform(Transform[S, int]):
       num_buckets (int): The number of buckets.
     """
 
-    def __init__(self, source_type: IcebergType, num_buckets: int):
-        super().__init__(
-            f"bucket[{num_buckets}]",
-            f"transforms.bucket(source_type={repr(source_type)}, num_buckets={num_buckets})",
-        )
+    _source_type: IcebergType = PrivateAttr()
+    _num_buckets: PositiveInt = PrivateAttr()
+
+    def __init__(self, source_type: IcebergType, num_buckets: int, **data: Any):
+        super().__init__(__root__=f"bucket[{num_buckets}]", **data)
+        self._source_type = source_type
         self._num_buckets = num_buckets
 
     @property
@@ -139,6 +134,9 @@ class BaseBucketTransform(Transform[S, int]):
     def can_transform(self, source: IcebergType) -> bool:
         pass
 
+    def __repr__(self) -> str:
+        return f"transforms.bucket(source_type={repr(self._source_type)}, num_buckets={self._num_buckets})"
+
 
 class BucketNumberTransform(BaseBucketTransform):
     """Transforms a value of IntegerType, LongType, DateType, TimeType, TimestampType, or TimestamptzType
@@ -177,14 +175,11 @@ class BucketStringTransform(BaseBucketTransform):
     """Transforms a value of StringType into a bucket partition value.
 
     Example:
-        >>> transform = BucketStringTransform(100)
+        >>> transform = BucketStringTransform(StringType(), 100)
         >>> transform.apply("iceberg")
         89
     """
 
-    def __init__(self, num_buckets: int):
-        super().__init__(StringType(), num_buckets)
-
     def can_transform(self, source: IcebergType) -> bool:
         return isinstance(source, StringType)
 
@@ -212,14 +207,11 @@ class BucketUUIDTransform(BaseBucketTransform):
     """Transforms a value of UUIDType into a bucket partition value.
 
     Example:
-        >>> transform = BucketUUIDTransform(100)
+        >>> transform = BucketUUIDTransform(UUIDType(), 100)
         >>> transform.apply(UUID("f79c3e09-677c-4bbd-a479-3f349cb785e7"))
         40
     """
 
-    def __init__(self, num_buckets: int):
-        super().__init__(UUIDType(), num_buckets)
-
     def can_transform(self, source: IcebergType) -> bool:
         return isinstance(source, UUIDType)
 
@@ -247,12 +239,12 @@ class IdentityTransform(Transform[S, S]):
         'hello-world'
     """
 
-    def __init__(self, source_type: IcebergType):
-        super().__init__(
-            "identity",
-            f"transforms.identity(source_type={repr(source_type)})",
-        )
-        self._type = source_type
+    __root__: Literal["identity"] = Field(default="identity")
+    _source_type: IcebergType = PrivateAttr()
+
+    def __init__(self, source_type: IcebergType, **data: Any):
+        super().__init__(**data)
+        self._source_type = source_type
 
     def apply(self, value: Optional[S]) -> Optional[S]:
         return value
@@ -272,40 +264,38 @@ class IdentityTransform(Transform[S, S]):
         return other.preserves_order
 
     def to_human_string(self, value: Optional[S]) -> str:
-        return _human_string(value, self._type) if value is not None else "null"
+        return _human_string(value, self._source_type) if value is not None else "null"
+
+    def __str__(self) -> str:
+        return "identity"
+
+    def __repr__(self) -> str:
+        return f"transforms.identity(source_type={repr(self._source_type)})"
 
 
 class TruncateTransform(Transform[S, S]):
     """A transform for truncating a value to a specified width.
     Args:
       source_type (Type): An Iceberg Type of IntegerType, LongType, StringType, BinaryType or DecimalType
-      width (int): The truncate width
+      width (int): The truncate width, should be positive
     Raises:
       ValueError: If a type is provided that is incompatible with a Truncate transform
     """
 
-    def __init__(self, source_type: IcebergType, width: int):
-        assert width > 0, f"width ({width}) should be greater than 0"
-        super().__init__(
-            f"truncate[{width}]",
-            f"transforms.truncate(source_type={repr(source_type)}, width={width})",
-        )
-        self._type = source_type
-        self._width = width
+    __root__: str = Field()
+    _source_type: IcebergType = PrivateAttr()
+    _width: PositiveInt = PrivateAttr()
 
-    @property
-    def width(self) -> int:
-        return self._width
-
-    @property
-    def type(self) -> IcebergType:
-        return self._type
+    def __init__(self, source_type: IcebergType, width: int, **data: Any):
+        super().__init__(__root__=f"truncate[{width}]", **data)
+        self._source_type = source_type
+        self._width = width
 
     def apply(self, value: Optional[S]) -> Optional[S]:
         return _truncate_value(value, self._width) if value is not None else None
 
     def can_transform(self, source: IcebergType) -> bool:
-        return self._type == source
+        return self._source_type == source
 
     def result_type(self, source: IcebergType) -> IcebergType:
         return source
@@ -314,11 +304,23 @@ class TruncateTransform(Transform[S, S]):
     def preserves_order(self) -> bool:
         return True
 
+    @property
+    def source_type(self) -> IcebergType:
+        return self._source_type
+
+    @property
+    def width(self) -> int:
+        return self._width
+
     def satisfies_order_of(self, other: Transform) -> bool:
         if self == other:
             return True
-        elif isinstance(self._type, StringType) and isinstance(other, TruncateTransform) and isinstance(other.type, StringType):
-            return self._width >= other.width
+        elif (
+            isinstance(self.source_type, StringType)
+            and isinstance(other, TruncateTransform)
+            and isinstance(other.source_type, StringType)
+        ):
+            return self.width >= other.width
 
         return False
 
@@ -330,6 +332,9 @@ class TruncateTransform(Transform[S, S]):
         else:
             return str(value)
 
+    def __repr__(self) -> str:
+        return f"transforms.truncate(source_type={repr(self._source_type)}, width={self._width})"
+
 
 @singledispatch
 def _human_string(value: Any, _type: IcebergType) -> str:
@@ -403,35 +408,38 @@ def _(value: Decimal, _width: int) -> Decimal:
 class UnknownTransform(Transform):
     """A transform that represents when an unknown transform is provided
     Args:
-      source_type (Type): An Iceberg `Type`
+      source_type (IcebergType): An Iceberg `Type`
       transform (str): A string name of a transform
     Raises:
       AttributeError: If the apply method is called.
     """
 
-    def __init__(self, source_type: IcebergType, transform: str):
-        super().__init__(
-            transform,
-            f"transforms.UnknownTransform(source_type={repr(source_type)}, transform={repr(transform)})",
-        )
-        self._type = source_type
+    __root__: Literal["unknown"] = Field(default="unknown")
+    _source_type: IcebergType = PrivateAttr()
+    _transform: str = PrivateAttr()
+
+    def __init__(self, source_type: IcebergType, transform: str, **data: Any):
+        super().__init__(**data)
+        self._source_type = source_type
         self._transform = transform
 
     def apply(self, value: Optional[S]):
         raise AttributeError(f"Cannot apply unsupported transform: {self}")
 
     def can_transform(self, source: IcebergType) -> bool:
-        return self._type == source
+        return self._source_type == source
 
     def result_type(self, source: IcebergType) -> IcebergType:
         return StringType()
 
+    def __repr__(self) -> str:
+        return f"transforms.UnknownTransform(source_type={repr(self._source_type)}, transform={repr(self._transform)})"
+
 
 class VoidTransform(Transform, Singleton):
     """A transform that always returns None"""
 
-    def __init__(self):
-        super().__init__("void", "transforms.always_null()")
+    __root__ = "void"
 
     def apply(self, value: Optional[S]) -> None:
         return None
@@ -445,6 +453,9 @@ class VoidTransform(Transform, Singleton):
     def to_human_string(self, value: Optional[S]) -> str:
         return "null"
 
+    def __repr__(self) -> str:
+        return "transforms.always_null()"
+
 
 def bucket(source_type: IcebergType, num_buckets: int) -> BaseBucketTransform:
     if type(source_type) in {IntegerType, LongType, DateType, TimeType, TimestampType, TimestamptzType}:
@@ -452,13 +463,13 @@ def bucket(source_type: IcebergType, num_buckets: int) -> BaseBucketTransform:
     elif isinstance(source_type, DecimalType):
         return BucketDecimalTransform(source_type, num_buckets)
     elif isinstance(source_type, StringType):
-        return BucketStringTransform(num_buckets)
+        return BucketStringTransform(source_type, num_buckets)
     elif isinstance(source_type, BinaryType):
         return BucketBytesTransform(source_type, num_buckets)
     elif isinstance(source_type, FixedType):
         return BucketBytesTransform(source_type, num_buckets)
     elif isinstance(source_type, UUIDType):
-        return BucketUUIDTransform(num_buckets)
+        return BucketUUIDTransform(source_type, num_buckets)
     else:
         raise ValueError(f"Cannot bucket by type: {source_type}")
 
@@ -471,5 +482,5 @@ def truncate(source_type: IcebergType, width: int) -> TruncateTransform:
     return TruncateTransform(source_type, width)
 
 
-def always_null() -> Transform:
+def always_null() -> VoidTransform:
     return VoidTransform()
diff --git a/python/tests/test_transforms.py b/python/tests/test_transforms.py
index f7a772212a..ca2f441bce 100644
--- a/python/tests/test_transforms.py
+++ b/python/tests/test_transforms.py
@@ -23,6 +23,17 @@ import mmh3 as mmh3
 import pytest
 
 from pyiceberg import transforms
+from pyiceberg.transforms import (
+    BucketBytesTransform,
+    BucketDecimalTransform,
+    BucketNumberTransform,
+    BucketStringTransform,
+    BucketUUIDTransform,
+    IdentityTransform,
+    TruncateTransform,
+    UnknownTransform,
+    VoidTransform,
+)
 from pyiceberg.types import (
     BinaryType,
     BooleanType,
@@ -256,3 +267,124 @@ def test_void_transform():
     assert not void_transform.satisfies_order_of(transforms.bucket(DateType(), 100))
     assert void_transform.to_human_string("test") == "null"
     assert void_transform.dedup_name == "void"
+
+
+def test_bucket_number_transform_json():
+    assert BucketNumberTransform(source_type=IntegerType(), num_buckets=22).json() == '"bucket[22]"'
+
+
+def test_bucket_number_transform_str():
+    assert str(BucketNumberTransform(source_type=IntegerType(), num_buckets=22)) == "bucket[22]"
+
+
+def test_bucket_number_transform_repr():
+    assert (
+        repr(BucketNumberTransform(source_type=IntegerType(), num_buckets=22))
+        == "transforms.bucket(source_type=IntegerType(), num_buckets=22)"
+    )
+
+
+def test_bucket_decimal_transform_json():
+    assert BucketDecimalTransform(source_type=DecimalType(19, 25), num_buckets=22).json() == '"bucket[22]"'
+
+
+def test_bucket_decimal_transform_str():
+    assert str(BucketDecimalTransform(source_type=DecimalType(19, 25), num_buckets=22)) == "bucket[22]"
+
+
+def test_bucket_decimal_transform_repr():
+    assert (
+        repr(BucketDecimalTransform(source_type=DecimalType(19, 25), num_buckets=22))
+        == "transforms.bucket(source_type=DecimalType(precision=19, scale=25), num_buckets=22)"
+    )
+
+
+def test_bucket_string_transform_json():
+    assert BucketStringTransform(StringType(), num_buckets=22).json() == '"bucket[22]"'
+
+
+def test_bucket_string_transform_str():
+    assert str(BucketStringTransform(StringType(), num_buckets=22)) == "bucket[22]"
+
+
+def test_bucket_string_transform_repr():
+    assert (
+        repr(BucketStringTransform(StringType(), num_buckets=22)) == "transforms.bucket(source_type=StringType(), num_buckets=22)"
+    )
+
+
+def test_bucket_bytes_transform_json():
+    assert BucketBytesTransform(BinaryType(), num_buckets=22).json() == '"bucket[22]"'
+
+
+def test_bucket_bytes_transform_str():
+    assert str(BucketBytesTransform(BinaryType(), num_buckets=22)) == "bucket[22]"
+
+
+def test_bucket_bytes_transform_repr():
+    assert (
+        repr(BucketBytesTransform(BinaryType(), num_buckets=22)) == "transforms.bucket(source_type=BinaryType(), num_buckets=22)"
+    )
+
+
+def test_bucket_uuid_transform_json():
+    assert BucketUUIDTransform(UUIDType(), num_buckets=22).json() == '"bucket[22]"'
+
+
+def test_bucket_uuid_transform_str():
+    assert str(BucketUUIDTransform(UUIDType(), num_buckets=22)) == "bucket[22]"
+
+
+def test_bucket_uuid_transform_repr():
+    assert repr(BucketUUIDTransform(UUIDType(), num_buckets=22)) == "transforms.bucket(source_type=UUIDType(), num_buckets=22)"
+
+
+def test_identity_transform_json():
+    assert IdentityTransform(StringType()).json() == '"identity"'
+
+
+def test_identity_transform_str():
+    assert str(IdentityTransform(StringType())) == "identity"
+
+
+def test_identity_transform_repr():
+    assert repr(IdentityTransform(StringType())) == "transforms.identity(source_type=StringType())"
+
+
+def test_truncate_transform_json():
+    assert TruncateTransform(StringType(), 22).json() == '"truncate[22]"'
+
+
+def test_truncate_transform_str():
+    assert str(TruncateTransform(StringType(), 22)) == "truncate[22]"
+
+
+def test_truncate_transform_repr():
+    assert repr(TruncateTransform(StringType(), 22)) == "transforms.truncate(source_type=StringType(), width=22)"
+
+
+def test_unknown_transform_json():
+    assert UnknownTransform(StringType(), "unknown").json() == '"unknown"'
+
+
+def test_unknown_transform_str():
+    assert str(UnknownTransform(StringType(), "unknown")) == "unknown"
+
+
+def test_unknown_transform_repr():
+    assert (
+        repr(UnknownTransform(StringType(), "unknown"))
+        == "transforms.UnknownTransform(source_type=StringType(), transform='unknown')"
+    )
+
+
+def test_void_transform_json():
+    assert VoidTransform().json() == '"void"'
+
+
+def test_void_transform_str():
+    assert str(VoidTransform()) == "void"
+
+
+def test_void_transform_repr():
+    assert repr(VoidTransform()) == "transforms.always_null()"

Reply via email to