This is an automated email from the ASF dual-hosted git repository.
blue pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/master by this push:
new 6603a81fe4 Python: Add sort order fields (#5124)
6603a81fe4 is described below
commit 6603a81fe495d0a1c2b60d7155908b15746f5076
Author: Fokko Driesprong <[email protected]>
AuthorDate: Tue Jul 12 01:21:31 2022 +0200
Python: Add sort order fields (#5124)
---
python/pyiceberg/schema.py | 6 +-
python/pyiceberg/table/metadata.py | 28 +--
python/pyiceberg/table/sorting.py | 111 +++++++++++
python/pyiceberg/transforms.py | 320 +++++++++++++-------------------
python/pyiceberg/types.py | 13 +-
python/pyiceberg/utils/parsing.py | 35 ++++
python/tests/table/test_metadata.py | 8 +-
python/tests/table/test_partitioning.py | 11 +-
python/tests/table/test_sorting.py | 70 +++++++
python/tests/test_transforms.py | 210 +++++++--------------
10 files changed, 453 insertions(+), 359 deletions(-)
diff --git a/python/pyiceberg/schema.py b/python/pyiceberg/schema.py
index 37372a0d5d..db2aaae35b 100644
--- a/python/pyiceberg/schema.py
+++ b/python/pyiceberg/schema.py
@@ -146,7 +146,11 @@ class Schema(IcebergBaseModel):
field_id = self._name_to_id.get(name_or_id)
else:
field_id = self._lazy_name_to_id_lower.get(name_or_id.lower())
- return self._lazy_id_to_field.get(field_id) # type: ignore
+
+ if not field_id:
+ raise ValueError(f"Could not find field with name or id
{name_or_id}, case_sensitive={case_sensitive}")
+
+ return self._lazy_id_to_field.get(field_id)
def find_type(self, name_or_id: Union[str, int], case_sensitive: bool =
True) -> IcebergType:
"""Find a field type using a field name or field ID
diff --git a/python/pyiceberg/table/metadata.py
b/python/pyiceberg/table/metadata.py
index fef4110baa..cc6bf33e8e 100644
--- a/python/pyiceberg/table/metadata.py
+++ b/python/pyiceberg/table/metadata.py
@@ -30,12 +30,12 @@ from pydantic import Field, root_validator
from pyiceberg.exceptions import ValidationError
from pyiceberg.schema import Schema
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType
+from pyiceberg.table.sorting import UNSORTED_SORT_ORDER,
UNSORTED_SORT_ORDER_ID, SortOrder
from pyiceberg.utils.iceberg_base_model import IcebergBaseModel
_INITIAL_SEQUENCE_NUMBER = 0
INITIAL_SPEC_ID = 0
DEFAULT_SCHEMA_ID = 0
-DEFAULT_SORT_ORDER_UNSORTED = 0
def check_schemas(values: Dict[str, Any]) -> Dict[str, Any]:
@@ -62,14 +62,15 @@ def check_partition_specs(values: Dict[str, Any]) ->
Dict[str, Any]:
def check_sort_orders(values: Dict[str, Any]) -> Dict[str, Any]:
"""Validator to check if the default_sort_order_id is present in
sort-orders"""
- default_sort_order_id = values["default_sort_order_id"]
+ default_sort_order_id: int = values["default_sort_order_id"]
- if default_sort_order_id != DEFAULT_SORT_ORDER_UNSORTED:
- for sort in values["sort_orders"]:
- if sort["order-id"] == default_sort_order_id:
+ if default_sort_order_id != UNSORTED_SORT_ORDER_ID:
+ sort_orders: List[SortOrder] = values["sort_orders"]
+ for sort_order in sort_orders:
+ if sort_order.order_id == default_sort_order_id:
return values
- raise ValidationError(f"default-sort-order-id {default_sort_order_id}
can't be found")
+ raise ValidationError(f"default-sort-order-id {default_sort_order_id}
can't be found in {sort_orders}")
return values
@@ -77,6 +78,9 @@ class TableMetadataCommonFields(IcebergBaseModel):
"""Metadata for an Iceberg table as specified in the Apache Iceberg
spec (https://iceberg.apache.org/spec/#iceberg-table-spec)"""
+ def current_schema(self) -> Schema:
+ return next(schema for schema in self.schemas if schema.schema_id ==
self.current_schema_id)
+
@root_validator(pre=True)
def cleanup_snapshot_id(cls, data: Dict[str, Any]):
if data.get("current-snapshot-id") == -1:
@@ -159,10 +163,10 @@ class TableMetadataCommonFields(IcebergBaseModel):
remove oldest metadata log entries and keep a fixed-size log of the most
recent entries after a commit."""
- sort_orders: List[Dict[str, Any]] = Field(alias="sort-orders",
default_factory=list)
+ sort_orders: List[SortOrder] = Field(alias="sort-orders",
default_factory=list)
"""A list of sort orders, stored as full sort order objects."""
- default_sort_order_id: int = Field(alias="default-sort-order-id",
default=DEFAULT_SORT_ORDER_UNSORTED)
+ default_sort_order_id: int = Field(alias="default-sort-order-id",
default=UNSORTED_SORT_ORDER_ID)
"""Default sort order id of the table. Note that this could be used by
writers, but is not used when reading because reads use the specs stored
in manifest files."""
@@ -267,12 +271,10 @@ class TableMetadataV1(TableMetadataCommonFields,
IcebergBaseModel):
Returns:
The TableMetadata with the sort_orders set, if not provided
"""
- # This is going to be much nicer as soon as sort-order is an actual
pydantic object
- # Probably we'll just create a UNSORTED_ORDER constant then
- if not data.get("sort_orders"):
- data["sort_orders"] = [{"order_id": 0, "fields": []}]
+ if sort_orders := data.get("sort_orders"):
+ check_sort_orders(sort_orders)
else:
- check_sort_orders(data["sort_orders"])
+ data["sort_orders"] = [UNSORTED_SORT_ORDER]
return data
def to_v2(self) -> "TableMetadataV2":
diff --git a/python/pyiceberg/table/sorting.py
b/python/pyiceberg/table/sorting.py
new file mode 100644
index 0000000000..e2a72fd24c
--- /dev/null
+++ b/python/pyiceberg/table/sorting.py
@@ -0,0 +1,111 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=keyword-arg-before-vararg
+from enum import Enum
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ List,
+ Optional,
+ Union,
+)
+
+from pydantic import Field, root_validator
+
+from pyiceberg.transforms import Transform
+from pyiceberg.types import IcebergType
+from pyiceberg.utils.iceberg_base_model import IcebergBaseModel
+
+
+class SortDirection(Enum):
+ ASC = "asc"
+ DESC = "desc"
+
+
+class NullOrder(Enum):
+ NULLS_FIRST = "nulls-first"
+ NULLS_LAST = "nulls-last"
+
+
+class SortField(IcebergBaseModel):
+ """Sort order field
+
+ Args:
+ source_id (int): Source column id from the table’s schema
+ transform (str): Transform that is used to produce values to be sorted
on from the source column.
+ This is the same transform as described in partition
transforms.
+ direction (SortDirection): Sort direction, that can only be either asc
or desc
+ null_order (NullOrder): Null order that describes the order of null
values when sorted. Can only be either nulls-first or nulls-last
+ """
+
+ def __init__(
+ self,
+ source_id: Optional[int] = None,
+ transform: Optional[Union[Transform, Callable[[IcebergType],
Transform]]] = None,
+ direction: Optional[SortDirection] = None,
+ null_order: Optional[NullOrder] = None,
+ **data: Any,
+ ):
+ if source_id is not None:
+ data["source-id"] = source_id
+ if transform is not None:
+ data["transform"] = transform
+ if direction is not None:
+ data["direction"] = direction
+ if null_order is not None:
+ data["null-order"] = null_order
+ super().__init__(**data)
+
+ @root_validator(pre=True)
+ def set_null_order(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ values["direction"] = values["direction"] if values.get("direction")
else SortDirection.ASC
+ if not values.get("null-order"):
+ values["null-order"] = NullOrder.NULLS_FIRST if
values["direction"] == SortDirection.ASC else NullOrder.NULLS_LAST
+ return values
+
+ source_id: int = Field(alias="source-id")
+ transform: Transform = Field()
+ direction: SortDirection = Field()
+ null_order: NullOrder = Field(alias="null-order")
+
+
+class SortOrder(IcebergBaseModel):
+ """Describes how the data is sorted within the table
+
+ Users can sort their data within partitions by columns to gain performance.
+
+ The order of the sort fields within the list defines the order in which
the sort is applied to the data.
+
+ Args:
+ order_id (int): The id of the sort-order. To keep track of historical
sorting
+        fields (List[SortField]): The fields by which the table is sorted
+ """
+
+ def __init__(self, order_id: Optional[int] = None, *fields: SortField,
**data: Any):
+ if order_id is not None:
+ data["order-id"] = order_id
+ if fields:
+ data["fields"] = fields
+ super().__init__(**data)
+
+ order_id: Optional[int] = Field(alias="order-id")
+ fields: List[SortField] = Field(default_factory=list)
+
+
+UNSORTED_SORT_ORDER_ID = 0
+UNSORTED_SORT_ORDER = SortOrder(order_id=UNSORTED_SORT_ORDER_ID)
diff --git a/python/pyiceberg/transforms.py b/python/pyiceberg/transforms.py
index ca337ed584..824afc3419 100644
--- a/python/pyiceberg/transforms.py
+++ b/python/pyiceberg/transforms.py
@@ -18,18 +18,17 @@
import base64
import struct
from abc import ABC, abstractmethod
-from decimal import Decimal
from functools import singledispatch
from typing import (
Any,
+ Callable,
Generic,
Literal,
Optional,
TypeVar,
)
-from uuid import UUID
-import mmh3 # type: ignore
+import mmh3
from pydantic import Field, PositiveInt, PrivateAttr
from pyiceberg.types import (
@@ -49,11 +48,20 @@ from pyiceberg.types import (
from pyiceberg.utils import datetime
from pyiceberg.utils.decimal import decimal_to_bytes, truncate_decimal
from pyiceberg.utils.iceberg_base_model import IcebergBaseModel
+from pyiceberg.utils.parsing import ParseNumberFromBrackets
from pyiceberg.utils.singleton import Singleton
S = TypeVar("S")
T = TypeVar("T")
+IDENTITY = "identity"
+VOID = "void"
+BUCKET = "bucket"
+TRUNCATE = "truncate"
+
+BUCKET_PARSER = ParseNumberFromBrackets(BUCKET)
+TRUNCATE_PARSER = ParseNumberFromBrackets(TRUNCATE)
+
class Transform(IcebergBaseModel, ABC, Generic[S, T]):
"""Transform base class for concrete transforms.
@@ -64,11 +72,32 @@ class Transform(IcebergBaseModel, ABC, Generic[S, T]):
__root__: str = Field()
- def __call__(self, value: Optional[S]) -> Optional[T]:
- return self.apply(value)
+ @classmethod
+ def __get_validators__(cls):
+ # one or more validators may be yielded which will be called in the
+ # order to validate the input, each validator will receive as an input
+ # the value returned from the previous validator
+ yield cls.validate
+
+ @classmethod
+ def validate(cls, v: Any):
+ # When Pydantic is unable to determine the subtype
+ # In this case we'll help pydantic a bit by parsing the transform type
ourselves
+ if isinstance(v, str):
+ if v == IDENTITY:
+ return IdentityTransform()
+ elif v == VOID:
+ return VoidTransform()
+ elif v.startswith(BUCKET):
+ return BucketTransform(num_buckets=BUCKET_PARSER.match(v))
+ elif v.startswith(TRUNCATE):
+            return TruncateTransform(width=TRUNCATE_PARSER.match(v))
+ else:
+ return UnknownTransform(transform=v)
+ return v
@abstractmethod
- def apply(self, value: Optional[S]) -> Optional[T]:
+ def transform(self, source: IcebergType) -> Callable[[Optional[S]],
Optional[T]]:
...
@abstractmethod
@@ -86,7 +115,7 @@ class Transform(IcebergBaseModel, ABC, Generic[S, T]):
def satisfies_order_of(self, other) -> bool:
return self == other
- def to_human_string(self, value: Optional[S]) -> str:
+ def to_human_string(self, _: IcebergType, value: Optional[S]) -> str:
return str(value) if value is not None else "null"
@property
@@ -96,25 +125,27 @@ class Transform(IcebergBaseModel, ABC, Generic[S, T]):
def __str__(self) -> str:
return self.__root__
+ def __eq__(self, other: Any) -> bool:
+ if isinstance(other, Transform):
+ return self.__root__ == other.__root__
+ return False
-class BaseBucketTransform(Transform[S, int]):
+
+class BucketTransform(Transform[S, int]):
    """Transform class to transform a value into a bucket partition value
Transforms are parameterized by a number of buckets. Bucket partition
transforms use a 32-bit
hash of the source value to produce a positive value by mod the bucket
number.
Args:
- source_type (Type): An Iceberg Type of IntegerType, LongType,
DecimalType, DateType, TimeType,
- TimestampType, TimestamptzType, StringType, BinaryType, FixedType,
UUIDType.
num_buckets (int): The number of buckets.
"""
_source_type: IcebergType = PrivateAttr()
_num_buckets: PositiveInt = PrivateAttr()
- def __init__(self, source_type: IcebergType, num_buckets: int, **data:
Any):
+ def __init__(self, num_buckets: int, **data: Any):
super().__init__(__root__=f"bucket[{num_buckets}]", **data)
- self._source_type = source_type
self._num_buckets = num_buckets
@property
@@ -130,99 +161,58 @@ class BaseBucketTransform(Transform[S, int]):
def result_type(self, source: IcebergType) -> IcebergType:
return IntegerType()
- @abstractmethod
- def can_transform(self, source: IcebergType) -> bool:
- pass
-
- def __repr__(self) -> str:
- return f"transforms.bucket(source_type={repr(self._source_type)},
num_buckets={self._num_buckets})"
-
-
-class BucketNumberTransform(BaseBucketTransform):
- """Transforms a value of IntegerType, LongType, DateType, TimeType,
TimestampType, or TimestamptzType
- into a bucket partition value
-
- Example:
- >>> transform = BucketNumberTransform(LongType(), 100)
- >>> transform.apply(81068000000)
- 59
- """
-
def can_transform(self, source: IcebergType) -> bool:
- return type(source) in {IntegerType, DateType, LongType, TimeType,
TimestampType, TimestamptzType}
-
- def hash(self, value) -> int:
- return mmh3.hash(struct.pack("<q", value))
-
-
-class BucketDecimalTransform(BaseBucketTransform):
- """Transforms a value of DecimalType into a bucket partition value.
-
- Example:
- >>> transform = BucketDecimalTransform(DecimalType(9, 2), 100)
- >>> transform.apply(Decimal("14.20"))
- 59
- """
-
- def can_transform(self, source: IcebergType) -> bool:
- return isinstance(source, DecimalType)
-
- def hash(self, value: Decimal) -> int:
- return mmh3.hash(decimal_to_bytes(value))
-
-
-class BucketStringTransform(BaseBucketTransform):
- """Transforms a value of StringType into a bucket partition value.
-
- Example:
- >>> transform = BucketStringTransform(StringType(), 100)
- >>> transform.apply("iceberg")
- 89
- """
-
- def can_transform(self, source: IcebergType) -> bool:
- return isinstance(source, StringType)
-
- def hash(self, value: str) -> int:
- return mmh3.hash(value)
-
-
-class BucketBytesTransform(BaseBucketTransform):
- """Transforms a value of FixedType or BinaryType into a bucket partition
value.
-
- Example:
- >>> transform = BucketBytesTransform(BinaryType(), 100)
- >>> transform.apply(b"\\x00\\x01\\x02\\x03")
- 41
- """
-
- def can_transform(self, source: IcebergType) -> bool:
- return type(source) in {FixedType, BinaryType}
-
- def hash(self, value: bytes) -> int:
- return mmh3.hash(value)
+ return type(source) in {
+ IntegerType,
+ DateType,
+ LongType,
+ TimeType,
+ TimestampType,
+ TimestamptzType,
+ DecimalType,
+ StringType,
+ FixedType,
+ BinaryType,
+ UUIDType,
+ }
+
+ def transform(self, source: IcebergType, bucket: bool = True) ->
Callable[[Optional[Any]], Optional[int]]:
+ source_type = type(source)
+ if source_type in {IntegerType, LongType, DateType, TimeType,
TimestampType, TimestamptzType}:
+
+ def hash_func(v):
+ return mmh3.hash(struct.pack("<q", v))
+
+ elif source_type == DecimalType:
+
+ def hash_func(v):
+ return mmh3.hash(decimal_to_bytes(v))
+
+ elif source_type in {StringType, FixedType, BinaryType}:
+
+ def hash_func(v):
+ return mmh3.hash(v)
+
+ elif source_type == UUIDType:
+
+ def hash_func(v):
+ return mmh3.hash(
+ struct.pack(
+ ">QQ",
+ (v.int >> 64) & 0xFFFFFFFFFFFFFFFF,
+ v.int & 0xFFFFFFFFFFFFFFFF,
+ )
+ )
+ else:
+ raise ValueError(f"Unknown type {source}")
-class BucketUUIDTransform(BaseBucketTransform):
- """Transforms a value of UUIDType into a bucket partition value.
-
- Example:
- >>> transform = BucketUUIDTransform(UUIDType(), 100)
- >>> transform.apply(UUID("f79c3e09-677c-4bbd-a479-3f349cb785e7"))
- 40
- """
-
- def can_transform(self, source: IcebergType) -> bool:
- return isinstance(source, UUIDType)
+        if bucket:
+            return lambda v: (hash_func(v) & IntegerType.max) %
self._num_buckets if v is not None else None
+        return hash_func
- def hash(self, value: UUID) -> int:
- return mmh3.hash(
- struct.pack(
- ">QQ",
- (value.int >> 64) & 0xFFFFFFFFFFFFFFFF,
- value.int & 0xFFFFFFFFFFFFFFFF,
- )
- )
+ def __repr__(self) -> str:
+ return f"BucketTransform(num_buckets={self._num_buckets})"
def _base64encode(buffer: bytes) -> str:
@@ -234,20 +224,16 @@ class IdentityTransform(Transform[S, S]):
"""Transforms a value into itself.
Example:
- >>> transform = IdentityTransform(StringType())
- >>> transform.apply('hello-world')
+ >>> transform = IdentityTransform()
+ >>> transform.transform(StringType())('hello-world')
'hello-world'
"""
__root__: Literal["identity"] = Field(default="identity")
_source_type: IcebergType = PrivateAttr()
- def __init__(self, source_type: IcebergType, **data: Any):
- super().__init__(**data)
- self._source_type = source_type
-
- def apply(self, value: Optional[S]) -> Optional[S]:
- return value
+ def transform(self, source: IcebergType) -> Callable[[Optional[S]],
Optional[S]]:
+ return lambda v: v
def can_transform(self, source: IcebergType) -> bool:
return source.is_primitive
@@ -263,20 +249,19 @@ class IdentityTransform(Transform[S, S]):
"""ordering by value is the same as long as the other preserves
order"""
return other.preserves_order
- def to_human_string(self, value: Optional[S]) -> str:
- return _human_string(value, self._source_type) if value is not None
else "null"
+ def to_human_string(self, source_type: IcebergType, value: Optional[S]) ->
str:
+ return _human_string(value, source_type) if value is not None else
"null"
def __str__(self) -> str:
return "identity"
def __repr__(self) -> str:
- return f"transforms.identity(source_type={repr(self._source_type)})"
+ return "IdentityTransform()"
class TruncateTransform(Transform[S, S]):
"""A transform for truncating a value to a specified width.
Args:
- source_type (Type): An Iceberg Type of IntegerType, LongType,
StringType, BinaryType or DecimalType
width (int): The truncate width, should be positive
Raises:
ValueError: If a type is provided that is incompatible with a Truncate
transform
@@ -286,16 +271,12 @@ class TruncateTransform(Transform[S, S]):
_source_type: IcebergType = PrivateAttr()
_width: PositiveInt = PrivateAttr()
- def __init__(self, source_type: IcebergType, width: int, **data: Any):
+ def __init__(self, width: int, **data: Any):
super().__init__(__root__=f"truncate[{width}]", **data)
- self._source_type = source_type
self._width = width
- def apply(self, value: Optional[S]) -> Optional[S]:
- return _truncate_value(value, self._width) if value is not None else
None
-
def can_transform(self, source: IcebergType) -> bool:
- return self._source_type == source
+ return type(source) in {IntegerType, LongType, StringType, BinaryType,
DecimalType}
def result_type(self, source: IcebergType) -> IcebergType:
return source
@@ -312,6 +293,28 @@ class TruncateTransform(Transform[S, S]):
def width(self) -> int:
return self._width
+ def transform(self, source: IcebergType) -> Callable[[Optional[S]],
Optional[S]]:
+ source_type = type(source)
+ if source_type in {IntegerType, LongType}:
+
+ def truncate_func(v):
+ return v - v % self._width
+
+ elif source_type in {StringType, BinaryType}:
+
+ def truncate_func(v):
+ return v[0 : min(self._width, len(v))]
+
+ elif source_type == DecimalType:
+
+ def truncate_func(v):
+ return truncate_decimal(v, self._width)
+
+ else:
+ raise ValueError(f"Cannot truncate for type: {source}")
+
+        return lambda v: truncate_func(v) if v is not None else None
+
def satisfies_order_of(self, other: Transform) -> bool:
if self == other:
return True
@@ -324,7 +327,7 @@ class TruncateTransform(Transform[S, S]):
return False
- def to_human_string(self, value: Optional[S]) -> str:
+ def to_human_string(self, _: IcebergType, value: Optional[S]) -> str:
if value is None:
return "null"
elif isinstance(value, bytes):
@@ -333,7 +336,7 @@ class TruncateTransform(Transform[S, S]):
return str(value)
def __repr__(self) -> str:
- return f"transforms.truncate(source_type={repr(self._source_type)},
width={self._width})"
+ return f"TruncateTransform(width={self._width})"
@singledispatch
@@ -376,35 +379,6 @@ def _(_type: IcebergType, value: int) -> str:
return datetime.to_human_timestamptz(value)
-@singledispatch
-def _truncate_value(value: Any, _width: int) -> S:
- raise ValueError(f"Cannot truncate value: {value}")
-
-
-@_truncate_value.register(int)
-def _(value: int, _width: int) -> int:
- """Truncate a given int value into a given width if feasible."""
- return value - value % _width
-
-
-@_truncate_value.register(str)
-def _(value: str, _width: int) -> str:
- """Truncate a given string to a given width."""
- return value[0 : min(_width, len(value))]
-
-
-@_truncate_value.register(bytes)
-def _(value: bytes, _width: int) -> bytes:
- """Truncate a given binary bytes into a given width."""
- return value[0 : min(_width, len(value))]
-
-
-@_truncate_value.register(Decimal)
-def _(value: Decimal, _width: int) -> Decimal:
- """Truncate a given decimal value into a given width."""
- return truncate_decimal(value, _width)
-
-
class UnknownTransform(Transform):
"""A transform that represents when an unknown transform is provided
Args:
@@ -418,22 +392,21 @@ class UnknownTransform(Transform):
_source_type: IcebergType = PrivateAttr()
_transform: str = PrivateAttr()
- def __init__(self, source_type: IcebergType, transform: str, **data: Any):
+ def __init__(self, transform: str, **data: Any):
super().__init__(**data)
- self._source_type = source_type
self._transform = transform
- def apply(self, value: Optional[S]):
+ def transform(self, source: IcebergType) -> Callable[[Optional[S]],
Optional[T]]:
raise AttributeError(f"Cannot apply unsupported transform: {self}")
def can_transform(self, source: IcebergType) -> bool:
- return self._source_type == source
+ return False
def result_type(self, source: IcebergType) -> IcebergType:
return StringType()
def __repr__(self) -> str:
- return
f"transforms.UnknownTransform(source_type={repr(self._source_type)},
transform={repr(self._transform)})"
+ return f"UnknownTransform(transform={repr(self._transform)})"
class VoidTransform(Transform, Singleton):
@@ -441,8 +414,8 @@ class VoidTransform(Transform, Singleton):
__root__ = "void"
- def apply(self, value: Optional[S]) -> None:
- return None
+ def transform(self, source: IcebergType) -> Callable[[Optional[S]],
Optional[T]]:
+ return lambda v: None
def can_transform(self, _: IcebergType) -> bool:
return True
@@ -450,37 +423,8 @@ class VoidTransform(Transform, Singleton):
def result_type(self, source: IcebergType) -> IcebergType:
return source
- def to_human_string(self, value: Optional[S]) -> str:
+ def to_human_string(self, _: IcebergType, value: Optional[S]) -> str:
return "null"
def __repr__(self) -> str:
- return "transforms.always_null()"
-
-
-def bucket(source_type: IcebergType, num_buckets: int) -> BaseBucketTransform:
- if type(source_type) in {IntegerType, LongType, DateType, TimeType,
TimestampType, TimestamptzType}:
- return BucketNumberTransform(source_type, num_buckets)
- elif isinstance(source_type, DecimalType):
- return BucketDecimalTransform(source_type, num_buckets)
- elif isinstance(source_type, StringType):
- return BucketStringTransform(source_type, num_buckets)
- elif isinstance(source_type, BinaryType):
- return BucketBytesTransform(source_type, num_buckets)
- elif isinstance(source_type, FixedType):
- return BucketBytesTransform(source_type, num_buckets)
- elif isinstance(source_type, UUIDType):
- return BucketUUIDTransform(source_type, num_buckets)
- else:
- raise ValueError(f"Cannot bucket by type: {source_type}")
-
-
-def identity(source_type: IcebergType) -> IdentityTransform:
- return IdentityTransform(source_type)
-
-
-def truncate(source_type: IcebergType, width: int) -> TruncateTransform:
- return TruncateTransform(source_type, width)
-
-
-def always_null() -> VoidTransform:
- return VoidTransform()
+ return "VoidTransform()"
diff --git a/python/pyiceberg/types.py b/python/pyiceberg/types.py
index b8068d6b78..94c7345726 100644
--- a/python/pyiceberg/types.py
+++ b/python/pyiceberg/types.py
@@ -31,6 +31,7 @@ Notes:
"""
import re
from typing import (
+ Any,
ClassVar,
Dict,
Literal,
@@ -41,10 +42,12 @@ from typing import (
from pydantic import Field, PrivateAttr
from pyiceberg.utils.iceberg_base_model import IcebergBaseModel
+from pyiceberg.utils.parsing import ParseNumberFromBrackets
from pyiceberg.utils.singleton import Singleton
DECIMAL_REGEX = re.compile(r"decimal\((\d+),\s*(\d+)\)")
-FIXED_REGEX = re.compile(r"fixed\[(\d+)\]")
+FIXED = "fixed"
+FIXED_PARSER = ParseNumberFromBrackets(FIXED)
class IcebergType(IcebergBaseModel, Singleton):
@@ -65,7 +68,7 @@ class IcebergType(IcebergBaseModel, Singleton):
yield cls.validate
@classmethod
- def validate(cls, v):
+ def validate(cls, v: Any) -> "IcebergType":
# When Pydantic is unable to determine the subtype
# In this case we'll help pydantic a bit by parsing the
# primitive type ourselves, or pointing it at the correct
@@ -123,11 +126,7 @@ class FixedType(PrimitiveType):
@staticmethod
def parse(str_repr: str) -> "FixedType":
- matches = FIXED_REGEX.search(str_repr)
- if matches:
- length = int(matches.group(1))
- return FixedType(length)
- raise ValueError(f"Could not parse {str_repr} into a FixedType")
+ return FixedType(length=FIXED_PARSER.match(str_repr))
def __init__(self, length: int):
super().__init__(__root__=f"fixed[{length}]")
diff --git a/python/pyiceberg/utils/parsing.py
b/python/pyiceberg/utils/parsing.py
new file mode 100644
index 0000000000..0566ed6c28
--- /dev/null
+++ b/python/pyiceberg/utils/parsing.py
@@ -0,0 +1,35 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import re
+from re import Pattern
+
+
+class ParseNumberFromBrackets:
+ """Extracts the size from a string in the form of prefix[22]"""
+
+ regex: Pattern
+ prefix: str
+
+ def __init__(self, prefix: str):
+ self.prefix = prefix
+ self.regex = re.compile(rf"{prefix}\[(\d+)\]")
+
+ def match(self, str_repr: str) -> int:
+ matches = self.regex.search(str_repr)
+ if matches:
+ return int(matches.group(1))
+ raise ValueError(f"Could not match {str_repr}, expected format
{self.prefix}[22]")
diff --git a/python/tests/table/test_metadata.py
b/python/tests/table/test_metadata.py
index e7dac41919..fa64b5b199 100644
--- a/python/tests/table/test_metadata.py
+++ b/python/tests/table/test_metadata.py
@@ -147,7 +147,7 @@ def test_v2_metadata_parsing():
assert table_metadata.current_snapshot_id == 3055729675574597004
assert table_metadata.snapshots[0]["snapshot-id"] == 3051729675574597004
assert table_metadata.snapshot_log[0]["timestamp-ms"] == 1515100955770
- assert table_metadata.sort_orders[0]["order-id"] == 3
+ assert table_metadata.sort_orders[0].order_id == 3
assert table_metadata.default_sort_order_id == 3
@@ -207,7 +207,7 @@ def test_serialize_v1():
table_metadata = TableMetadataV1(**EXAMPLE_TABLE_METADATA_V1).json()
assert (
table_metadata
- == """{"location": "s3://bucket/test/location", "table-uuid":
"d20125c8-7284-442c-9aea-15fee620737c", "last-updated-ms": 1602638573874,
"last-column-id": 3, "schemas": [{"fields": [{"id": 1, "name": "x", "type":
"long", "required": true}, {"id": 2, "name": "y", "type": "long", "required":
true, "doc": "comment"}, {"id": 3, "name": "z", "type": "long", "required":
true}], "schema-id": 0, "identifier-field-ids": []}], "current-schema-id": 0,
"partition-specs": [{"spec-id": 0, "fiel [...]
+ == """{"location": "s3://bucket/test/location", "table-uuid":
"d20125c8-7284-442c-9aea-15fee620737c", "last-updated-ms": 1602638573874,
"last-column-id": 3, "schemas": [{"fields": [{"id": 1, "name": "x", "type":
"long", "required": true}, {"id": 2, "name": "y", "type": "long", "required":
true, "doc": "comment"}, {"id": 3, "name": "z", "type": "long", "required":
true}], "schema-id": 0, "identifier-field-ids": []}], "current-schema-id": 0,
"partition-specs": [{"spec-id": 0, "fiel [...]
)
@@ -215,7 +215,7 @@ def test_serialize_v2():
table_metadata = TableMetadataV2(**EXAMPLE_TABLE_METADATA_V2).json()
assert (
table_metadata
- == """{"location": "s3://bucket/test/location", "table-uuid":
"9c12d441-03fe-4693-9a96-a0705ddf69c1", "last-updated-ms": 1602638573590,
"last-column-id": 3, "schemas": [{"fields": [{"id": 1, "name": "x", "type":
"long", "required": true}], "schema-id": 0, "identifier-field-ids": []},
{"fields": [{"id": 1, "name": "x", "type": "long", "required": true}, {"id": 2,
"name": "y", "type": "long", "required": true, "doc": "comment"}, {"id": 3,
"name": "z", "type": "long", "required": tr [...]
+ == """{"location": "s3://bucket/test/location", "table-uuid":
"9c12d441-03fe-4693-9a96-a0705ddf69c1", "last-updated-ms": 1602638573590,
"last-column-id": 3, "schemas": [{"fields": [{"id": 1, "name": "x", "type":
"long", "required": true}], "schema-id": 0, "identifier-field-ids": []},
{"fields": [{"id": 1, "name": "x", "type": "long", "required": true}, {"id": 2,
"name": "y", "type": "long", "required": true, "doc": "comment"}, {"id": 3,
"name": "z", "type": "long", "required": tr [...]
)
@@ -510,7 +510,7 @@ def test_v1_write_metadata_for_v2():
]
assert metadata_v2["default-spec-id"] == 0
assert metadata_v2["last-partition-id"] == 1000
- assert metadata_v2["sort-orders"] == [{"fields": [], "order_id": 0}]
+ assert metadata_v2["sort-orders"] == [{"order-id": 0, "fields": []}]
assert metadata_v2["default-sort-order-id"] == 0
# Deprecated fields
assert "schema" not in metadata_v2
diff --git a/python/tests/table/test_partitioning.py
b/python/tests/table/test_partitioning.py
index e6afac293f..12c74625ee 100644
--- a/python/tests/table/test_partitioning.py
+++ b/python/tests/table/test_partitioning.py
@@ -17,12 +17,11 @@
from pyiceberg.schema import Schema
from pyiceberg.table.partitioning import PartitionField, PartitionSpec
-from pyiceberg.transforms import bucket
-from pyiceberg.types import IntegerType
+from pyiceberg.transforms import BucketTransform
def test_partition_field_init():
- bucket_transform = bucket(IntegerType(), 100)
+ bucket_transform = BucketTransform(100)
partition_field = PartitionField(3, 1000, bucket_transform, "id")
assert partition_field.source_id == 3
@@ -33,12 +32,12 @@ def test_partition_field_init():
assert str(partition_field) == "1000: id: bucket[100](3)"
assert (
repr(partition_field)
- == "PartitionField(source_id=3, field_id=1000,
transform=transforms.bucket(source_type=IntegerType(), num_buckets=100),
name='id')"
+ == "PartitionField(source_id=3, field_id=1000,
transform=BucketTransform(num_buckets=100), name='id')"
)
def test_partition_spec_init(table_schema_simple: Schema):
- bucket_transform = bucket(IntegerType(), 4)
+ bucket_transform: BucketTransform = BucketTransform(4)
id_field1 = PartitionField(3, 1001, bucket_transform, "id")
partition_spec1 = PartitionSpec(table_schema_simple, 0, (id_field1,), 1001)
@@ -57,7 +56,7 @@ def test_partition_spec_init(table_schema_simple: Schema):
def test_partition_compatible_with(table_schema_simple: Schema):
- bucket_transform = bucket(IntegerType(), 4)
+ bucket_transform: BucketTransform = BucketTransform(4)
field1 = PartitionField(3, 100, bucket_transform, "id")
field2 = PartitionField(3, 102, bucket_transform, "id")
lhs = PartitionSpec(table_schema_simple, 0, (field1,), 1001)
diff --git a/python/tests/table/test_sorting.py
b/python/tests/table/test_sorting.py
new file mode 100644
index 0000000000..1cdf4d0c06
--- /dev/null
+++ b/python/tests/table/test_sorting.py
@@ -0,0 +1,70 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from pyiceberg.table.metadata import TableMetadata
+from pyiceberg.table.sorting import (
+ UNSORTED_SORT_ORDER,
+ NullOrder,
+ SortDirection,
+ SortField,
+ SortOrder,
+)
+from pyiceberg.transforms import BucketTransform, IdentityTransform,
VoidTransform
+from tests.table.test_metadata import EXAMPLE_TABLE_METADATA_V2
+
+
+def test_serialize_sort_order_unsorted():
+ assert UNSORTED_SORT_ORDER.json() == '{"order-id": 0, "fields": []}'
+
+
+def test_serialize_sort_order():
+ sort_order = SortOrder(
+ 22,
+ SortField(source_id=19, transform=IdentityTransform(),
null_order=NullOrder.NULLS_FIRST),
+ SortField(source_id=25, transform=BucketTransform(4),
direction=SortDirection.DESC),
+ SortField(source_id=22, transform=VoidTransform(),
direction=SortDirection.ASC),
+ )
+ expected = '{"order-id": 22, "fields": [{"source-id": 19, "transform":
"identity", "direction": "asc", "null-order": "nulls-first"}, {"source-id": 25,
"transform": "bucket[4]", "direction": "desc", "null-order": "nulls-last"},
{"source-id": 22, "transform": "void", "direction": "asc", "null-order":
"nulls-first"}]}'
+ assert sort_order.json() == expected
+
+
+def test_deserialize_sort_order():
+ expected = SortOrder(
+ 22,
+ SortField(source_id=19, transform=IdentityTransform(),
null_order=NullOrder.NULLS_FIRST),
+ SortField(source_id=25, transform=BucketTransform(4),
direction=SortDirection.DESC),
+ SortField(source_id=22, transform=VoidTransform(),
direction=SortDirection.ASC),
+ )
+ payload = '{"order-id": 22, "fields": [{"source-id": 19, "transform":
"identity", "direction": "asc", "null-order": "nulls-first"}, {"source-id": 25,
"transform": "bucket[4]", "direction": "desc", "null-order": "nulls-last"},
{"source-id": 22, "transform": "void", "direction": "asc", "null-order":
"nulls-first"}]}'
+
+ assert SortOrder.parse_raw(payload) == expected
+
+
+def test_sorting_schema():
+ table_metadata = TableMetadata.parse_obj(EXAMPLE_TABLE_METADATA_V2)
+
+ assert table_metadata.sort_orders == [
+ SortOrder(
+ 3,
+ SortField(2, IdentityTransform(), SortDirection.ASC,
null_order=NullOrder.NULLS_FIRST),
+ SortField(
+ 3,
+ BucketTransform(4),
+ direction=SortDirection.DESC,
+ null_order=NullOrder.NULLS_LAST,
+ ),
+ )
+ ]
diff --git a/python/tests/test_transforms.py b/python/tests/test_transforms.py
index ca2f441bce..bdebae94e4 100644
--- a/python/tests/test_transforms.py
+++ b/python/tests/test_transforms.py
@@ -14,8 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=W0123
-
+# pylint: disable=eval-used,protected-access
from decimal import Decimal
from uuid import UUID
@@ -24,12 +23,9 @@ import pytest
from pyiceberg import transforms
from pyiceberg.transforms import (
- BucketBytesTransform,
- BucketDecimalTransform,
- BucketNumberTransform,
- BucketStringTransform,
- BucketUUIDTransform,
+ BucketTransform,
IdentityTransform,
+ Transform,
TruncateTransform,
UnknownTransform,
VoidTransform,
@@ -56,6 +52,7 @@ from pyiceberg.utils.datetime import (
timestamp_to_micros,
timestamptz_to_micros,
)
+from pyiceberg.utils.iceberg_base_model import IcebergBaseModel
@pytest.mark.parametrize(
@@ -83,30 +80,30 @@ from pyiceberg.utils.datetime import (
],
)
def test_bucket_hash_values(test_input, test_type, expected):
- assert transforms.bucket(test_type, 8).hash(test_input) == expected
+ assert BucketTransform(num_buckets=8).transform(test_type,
bucket=False)(test_input) == expected
@pytest.mark.parametrize(
- "bucket,value,expected",
+ "transform,value,expected",
[
- (transforms.bucket(IntegerType(), 100), 34, 79),
- (transforms.bucket(LongType(), 100), 34, 79),
- (transforms.bucket(DateType(), 100), 17486, 26),
- (transforms.bucket(TimeType(), 100), 81068000000, 59),
- (transforms.bucket(TimestampType(), 100), 1510871468000000, 7),
- (transforms.bucket(DecimalType(9, 2), 100), Decimal("14.20"), 59),
- (transforms.bucket(StringType(), 100), "iceberg", 89),
+ (BucketTransform(100).transform(IntegerType()), 34, 79),
+ (BucketTransform(100).transform(LongType()), 34, 79),
+ (BucketTransform(100).transform(DateType()), 17486, 26),
+ (BucketTransform(100).transform(TimeType()), 81068000000, 59),
+ (BucketTransform(100).transform(TimestampType()), 1510871468000000, 7),
+ (BucketTransform(100).transform(DecimalType(9, 2)), Decimal("14.20"),
59),
+ (BucketTransform(100).transform(StringType()), "iceberg", 89),
(
- transforms.bucket(UUIDType(), 100),
+ BucketTransform(100).transform(UUIDType()),
UUID("f79c3e09-677c-4bbd-a479-3f349cb785e7"),
40,
),
- (transforms.bucket(FixedType(3), 128), b"foo", 32),
- (transforms.bucket(BinaryType(), 128), b"\x00\x01\x02\x03", 57),
+ (BucketTransform(128).transform(FixedType(3)), b"foo", 32),
+ (BucketTransform(128).transform(BinaryType()), b"\x00\x01\x02\x03",
57),
],
)
-def test_buckets(bucket, value, expected):
- assert bucket.apply(value) == expected
+def test_buckets(transform, value, expected):
+ assert transform(value) == expected
@pytest.mark.parametrize(
@@ -126,20 +123,20 @@ def test_buckets(bucket, value, expected):
],
)
def test_bucket_method(type_var):
- bucket_transform = transforms.bucket(type_var, 8)
+ bucket_transform = BucketTransform(8)
assert str(bucket_transform) == str(eval(repr(bucket_transform)))
assert bucket_transform.can_transform(type_var)
assert bucket_transform.result_type(type_var) == IntegerType()
assert bucket_transform.num_buckets == 8
assert bucket_transform.apply(None) is None
- assert bucket_transform.to_human_string("test") == "test"
+ assert bucket_transform.to_human_string(type_var, "test") == "test"
def test_string_with_surrogate_pair():
string_with_surrogate_pair = "string with a surrogate pair: 💰"
as_bytes = bytes(string_with_surrogate_pair, "UTF-8")
- bucket_transform = transforms.bucket(StringType(), 100)
- assert bucket_transform.hash(string_with_surrogate_pair) ==
mmh3.hash(as_bytes)
+ bucket_transform = BucketTransform(100).transform(StringType(),
bucket=False)
+ assert bucket_transform(string_with_surrogate_pair) == mmh3.hash(as_bytes)
@pytest.mark.parametrize(
@@ -157,8 +154,8 @@ def test_string_with_surrogate_pair():
],
)
def test_identity_human_string(type_var, value, expected):
- identity = transforms.identity(type_var)
- assert identity.to_human_string(value) == expected
+ identity = IdentityTransform()
+ assert identity.to_human_string(type_var, value) == expected
@pytest.mark.parametrize(
@@ -181,11 +178,11 @@ def test_identity_human_string(type_var, value, expected):
],
)
def test_identity_method(type_var):
- identity_transform = transforms.identity(type_var)
+ identity_transform = IdentityTransform()
assert str(identity_transform) == str(eval(repr(identity_transform)))
assert identity_transform.can_transform(type_var)
assert identity_transform.result_type(type_var) == type_var
- assert identity_transform.apply("test") == "test"
+ assert identity_transform.transform(type_var)("test") == "test"
@pytest.mark.parametrize("type_var", [IntegerType(), LongType()])
@@ -194,8 +191,8 @@ def test_identity_method(type_var):
[(1, 0), (5, 0), (9, 0), (10, 10), (11, 10), (-1, -10), (-10, -10), (-12,
-20)],
)
def test_truncate_integer(type_var, input_var, expected):
- trunc = transforms.truncate(type_var, 10)
- assert trunc.apply(input_var) == expected
+ trunc = TruncateTransform(10)
+ assert trunc.transform(type_var)(input_var) == expected
@pytest.mark.parametrize(
@@ -209,14 +206,14 @@ def test_truncate_integer(type_var, input_var, expected):
],
)
def test_truncate_decimal(input_var, expected):
- trunc = transforms.truncate(DecimalType(9, 2), 10)
- assert trunc.apply(input_var) == expected
+ trunc = TruncateTransform(10)
+ assert trunc.transform(DecimalType(9, 2))(input_var) == expected
@pytest.mark.parametrize("input_var,expected", [("abcdefg", "abcde"), ("abc",
"abc")])
def test_truncate_string(input_var, expected):
- trunc = transforms.truncate(StringType(), 5)
- assert trunc.apply(input_var) == expected
+ trunc = TruncateTransform(5)
+ assert trunc.transform(StringType())(input_var) == expected
@pytest.mark.parametrize(
@@ -232,159 +229,92 @@ def test_truncate_string(input_var, expected):
],
)
def test_truncate_method(type_var, value, expected_human_str, expected):
- truncate_transform = transforms.truncate(type_var, 1)
+ truncate_transform = TruncateTransform(1)
assert str(truncate_transform) == str(eval(repr(truncate_transform)))
assert truncate_transform.can_transform(type_var)
assert truncate_transform.result_type(type_var) == type_var
- assert truncate_transform.to_human_string(value) == expected_human_str
- assert truncate_transform.apply(value) == expected
- assert truncate_transform.to_human_string(None) == "null"
+ assert truncate_transform.to_human_string(type_var, value) ==
expected_human_str
+ assert truncate_transform.transform(type_var)(value) == expected
+ assert truncate_transform.to_human_string(type_var, None) == "null"
assert truncate_transform.width == 1
- assert truncate_transform.apply(None) is None
+ assert truncate_transform.transform(type_var)(None) is None
assert truncate_transform.preserves_order
assert truncate_transform.satisfies_order_of(truncate_transform)
def test_unknown_transform():
- unknown_transform = transforms.UnknownTransform(FixedType(8), "unknown")
+ unknown_transform = transforms.UnknownTransform("unknown")
assert str(unknown_transform) == str(eval(repr(unknown_transform)))
with pytest.raises(AttributeError):
- unknown_transform.apply("test")
- assert unknown_transform.can_transform(FixedType(8))
+ unknown_transform.transform(StringType())("test")
assert not unknown_transform.can_transform(FixedType(5))
assert isinstance(unknown_transform.result_type(BooleanType()), StringType)
def test_void_transform():
- void_transform = transforms.always_null()
- assert void_transform is transforms.always_null()
+ void_transform = VoidTransform()
+ assert void_transform is VoidTransform()
assert void_transform == eval(repr(void_transform))
- assert void_transform.apply("test") is None
+ assert void_transform.transform(StringType())("test") is None
assert void_transform.can_transform(BooleanType())
assert isinstance(void_transform.result_type(BooleanType()), BooleanType)
assert not void_transform.preserves_order
- assert void_transform.satisfies_order_of(transforms.always_null())
- assert not void_transform.satisfies_order_of(transforms.bucket(DateType(),
100))
- assert void_transform.to_human_string("test") == "null"
+ assert void_transform.satisfies_order_of(VoidTransform())
+ assert not void_transform.satisfies_order_of(BucketTransform(100))
+ assert void_transform.to_human_string(StringType(), "test") == "null"
assert void_transform.dedup_name == "void"
-def test_bucket_number_transform_json():
- assert BucketNumberTransform(source_type=IntegerType(),
num_buckets=22).json() == '"bucket[22]"'
-
-
-def test_bucket_number_transform_str():
- assert str(BucketNumberTransform(source_type=IntegerType(),
num_buckets=22)) == "bucket[22]"
-
-
-def test_bucket_number_transform_repr():
- assert (
- repr(BucketNumberTransform(source_type=IntegerType(), num_buckets=22))
- == "transforms.bucket(source_type=IntegerType(), num_buckets=22)"
- )
-
-
-def test_bucket_decimal_transform_json():
- assert BucketDecimalTransform(source_type=DecimalType(19, 25),
num_buckets=22).json() == '"bucket[22]"'
-
-
-def test_bucket_decimal_transform_str():
- assert str(BucketDecimalTransform(source_type=DecimalType(19, 25),
num_buckets=22)) == "bucket[22]"
-
-
-def test_bucket_decimal_transform_repr():
- assert (
- repr(BucketDecimalTransform(source_type=DecimalType(19, 25),
num_buckets=22))
- == "transforms.bucket(source_type=DecimalType(precision=19, scale=25),
num_buckets=22)"
- )
-
-
-def test_bucket_string_transform_json():
- assert BucketStringTransform(StringType(), num_buckets=22).json() ==
'"bucket[22]"'
-
-
-def test_bucket_string_transform_str():
- assert str(BucketStringTransform(StringType(), num_buckets=22)) ==
"bucket[22]"
-
-
-def test_bucket_string_transform_repr():
- assert (
- repr(BucketStringTransform(StringType(), num_buckets=22)) ==
"transforms.bucket(source_type=StringType(), num_buckets=22)"
- )
+class TestType(IcebergBaseModel):
+ __root__: Transform
-def test_bucket_bytes_transform_json():
- assert BucketBytesTransform(BinaryType(), num_buckets=22).json() ==
'"bucket[22]"'
+def test_bucket_transform_serialize():
+ assert BucketTransform(num_buckets=22).json() == '"bucket[22]"'
-def test_bucket_bytes_transform_str():
- assert str(BucketBytesTransform(BinaryType(), num_buckets=22)) ==
"bucket[22]"
+def test_bucket_transform_deserialize():
+ transform = TestType.parse_raw('"bucket[22]"').__root__
+ assert transform == BucketTransform(num_buckets=22)
-def test_bucket_bytes_transform_repr():
- assert (
- repr(BucketBytesTransform(BinaryType(), num_buckets=22)) ==
"transforms.bucket(source_type=BinaryType(), num_buckets=22)"
- )
+def test_bucket_transform_str():
+ assert str(BucketTransform(num_buckets=22)) == "bucket[22]"
-def test_bucket_uuid_transform_json():
- assert BucketUUIDTransform(UUIDType(), num_buckets=22).json() ==
'"bucket[22]"'
+def test_bucket_transform_repr():
+ assert repr(BucketTransform(num_buckets=22)) ==
"BucketTransform(num_buckets=22)"
-def test_bucket_uuid_transform_str():
- assert str(BucketUUIDTransform(UUIDType(), num_buckets=22)) == "bucket[22]"
+def test_truncate_transform_serialize():
+ assert UnknownTransform("unknown").json() == '"unknown"'
-def test_bucket_uuid_transform_repr():
- assert repr(BucketUUIDTransform(UUIDType(), num_buckets=22)) ==
"transforms.bucket(source_type=UUIDType(), num_buckets=22)"
-
-
-def test_identity_transform_json():
- assert IdentityTransform(StringType()).json() == '"identity"'
-
-
-def test_identity_transform_str():
- assert str(IdentityTransform(StringType())) == "identity"
-
-
-def test_identity_transform_repr():
- assert repr(IdentityTransform(StringType())) ==
"transforms.identity(source_type=StringType())"
-
-
-def test_truncate_transform_json():
- assert TruncateTransform(StringType(), 22).json() == '"truncate[22]"'
-
-
-def test_truncate_transform_str():
- assert str(TruncateTransform(StringType(), 22)) == "truncate[22]"
-
-
-def test_truncate_transform_repr():
- assert repr(TruncateTransform(StringType(), 22)) ==
"transforms.truncate(source_type=StringType(), width=22)"
-
-
-def test_unknown_transform_json():
- assert UnknownTransform(StringType(), "unknown").json() == '"unknown"'
+def test_unknown_transform_deserialize():
+ transform = TestType.parse_raw('"unknown"').__root__
+ assert transform == UnknownTransform("unknown")
def test_unknown_transform_str():
- assert str(UnknownTransform(StringType(), "unknown")) == "unknown"
+ assert str(UnknownTransform("unknown")) == "unknown"
def test_unknown_transform_repr():
- assert (
- repr(UnknownTransform(StringType(), "unknown"))
- == "transforms.UnknownTransform(source_type=StringType(),
transform='unknown')"
- )
+ assert repr(UnknownTransform("unknown")) ==
"UnknownTransform(transform='unknown')"
-def test_void_transform_json():
+def test_void_transform_serialize():
assert VoidTransform().json() == '"void"'
+def test_void_transform_deserialize():
+ transform = TestType.parse_raw('"void"').__root__
+ assert transform == VoidTransform()
+
+
def test_void_transform_str():
assert str(VoidTransform()) == "void"
def test_void_transform_repr():
- assert repr(VoidTransform()) == "transforms.always_null()"
+ assert repr(VoidTransform()) == "VoidTransform()"