This is an automated email from the ASF dual-hosted git repository.
fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
The following commit(s) were added to refs/heads/main by this push:
new 6708a6ea Update table metadata throughout transaction (#471)
6708a6ea is described below
commit 6708a6eaa76a2b4aab58601f10210f778b3f03a4
Author: Fokko Driesprong <[email protected]>
AuthorDate: Thu Feb 29 13:10:42 2024 +0100
Update table metadata throughout transaction (#471)
* Update table metadata throughout transaction
This PR adds support for updating the table metadata throughout
the transaction.
This way, if a schema is first evolved, and then a snapshot is
created based on the latest schema, it will be able to find the
schema.
* Fix integration tests
* Thanks Honah!
* Include the partition evolution
* Cleanup
---
pyiceberg/io/pyarrow.py | 21 +-
pyiceberg/table/__init__.py | 537 ++++++++++++++++------------------
pyiceberg/table/metadata.py | 43 +++
tests/catalog/test_sql.py | 52 ++++
tests/integration/test_rest_schema.py | 23 +-
tests/integration/test_writes.py | 2 +-
tests/table/test_init.py | 15 +-
tests/test_schema.py | 42 +--
8 files changed, 396 insertions(+), 339 deletions(-)
diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index d41f7a07..be944ffb 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -125,6 +125,7 @@ from pyiceberg.schema import (
visit_with_partner,
)
from pyiceberg.table import PropertyUtil, TableProperties, WriteTask
+from pyiceberg.table.metadata import TableMetadata
from pyiceberg.table.name_mapping import NameMapping
from pyiceberg.transforms import TruncateTransform
from pyiceberg.typedef import EMPTY_DICT, Properties, Record
@@ -1720,7 +1721,7 @@ def fill_parquet_file_metadata(
data_file.split_offsets = split_offsets
-def write_file(table: Table, tasks: Iterator[WriteTask], file_schema:
Optional[Schema] = None) -> Iterator[DataFile]:
+def write_file(io: FileIO, table_metadata: TableMetadata, tasks:
Iterator[WriteTask]) -> Iterator[DataFile]:
task = next(tasks)
try:
@@ -1730,15 +1731,15 @@ def write_file(table: Table, tasks:
Iterator[WriteTask], file_schema: Optional[S
except StopIteration:
pass
- parquet_writer_kwargs = _get_parquet_writer_kwargs(table.properties)
+ parquet_writer_kwargs =
_get_parquet_writer_kwargs(table_metadata.properties)
- file_path =
f'{table.location()}/data/{task.generate_data_file_filename("parquet")}'
- file_schema = file_schema or table.schema()
- arrow_file_schema = schema_to_pyarrow(file_schema)
+ file_path =
f'{table_metadata.location}/data/{task.generate_data_file_filename("parquet")}'
+ schema = table_metadata.schema()
+ arrow_file_schema = schema_to_pyarrow(schema)
- fo = table.io.new_output(file_path)
+ fo = io.new_output(file_path)
row_group_size = PropertyUtil.property_as_int(
- properties=table.properties,
+ properties=table_metadata.properties,
property_name=TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES,
default=TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT,
)
@@ -1757,7 +1758,7 @@ def write_file(table: Table, tasks: Iterator[WriteTask],
file_schema: Optional[S
# sort_order_id=task.sort_order_id,
sort_order_id=None,
# Just copy these from the table for now
- spec_id=table.spec().spec_id,
+ spec_id=table_metadata.default_spec_id,
equality_ids=None,
key_metadata=None,
)
@@ -1765,8 +1766,8 @@ def write_file(table: Table, tasks: Iterator[WriteTask],
file_schema: Optional[S
fill_parquet_file_metadata(
data_file=data_file,
parquet_metadata=writer.writer.metadata,
- stats_columns=compute_statistics_plan(file_schema, table.properties),
- parquet_column_mapping=parquet_path_to_id_mapping(file_schema),
+ stats_columns=compute_statistics_plan(schema,
table_metadata.properties),
+ parquet_column_mapping=parquet_path_to_id_mapping(schema),
)
return iter([data_file])
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index e29369a2..1a4183c9 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -31,6 +31,7 @@ from typing import (
Any,
Callable,
Dict,
+ Generic,
Iterable,
List,
Literal,
@@ -137,6 +138,7 @@ if TYPE_CHECKING:
from pyiceberg.catalog import Catalog
+
ALWAYS_TRUE = AlwaysTrue()
TABLE_ROOT_ID = -1
@@ -229,18 +231,23 @@ class PropertyUtil:
class Transaction:
_table: Table
+ table_metadata: TableMetadata
+ _autocommit: bool
_updates: Tuple[TableUpdate, ...]
_requirements: Tuple[TableRequirement, ...]
- def __init__(
- self,
- table: Table,
- actions: Optional[Tuple[TableUpdate, ...]] = None,
- requirements: Optional[Tuple[TableRequirement, ...]] = None,
- ):
+ def __init__(self, table: Table, autocommit: bool = False):
+ """Open a transaction to stage and commit changes to a table.
+
+ Args:
+ table: The table that will be altered.
+ autocommit: Option to automatically commit the changes when they
are staged.
+ """
+ self.table_metadata = table.metadata
self._table = table
- self._updates = actions or ()
- self._requirements = requirements or ()
+ self._autocommit = autocommit
+ self._updates = ()
+ self._requirements = ()
def __enter__(self) -> Transaction:
"""Start a transaction to update the table."""
@@ -248,49 +255,23 @@ class Transaction:
def __exit__(self, _: Any, value: Any, traceback: Any) -> None:
"""Close and commit the transaction."""
- fresh_table = self.commit_transaction()
- # Update the new data in place
- self._table.metadata = fresh_table.metadata
- self._table.metadata_location = fresh_table.metadata_location
+ self.commit_transaction()
- def _append_updates(self, *new_updates: TableUpdate) -> Transaction:
- """Append updates to the set of staged updates.
+ def _apply(self, updates: Tuple[TableUpdate, ...], requirements:
Tuple[TableRequirement, ...] = ()) -> Transaction:
+ """Check if the requirements are met, and applies the updates to the
metadata."""
+ for requirement in requirements:
+ requirement.validate(self.table_metadata)
- Args:
- *new_updates: Any new updates.
+ self._updates += updates
+ self._requirements += requirements
- Raises:
- ValueError: When the type of update is not unique.
+ self.table_metadata = update_table_metadata(self.table_metadata,
updates)
- Returns:
- Transaction object with the new updates appended.
- """
- for new_update in new_updates:
- # explicitly get type of new_update as new_update is an
instantiated class
- type_new_update = type(new_update)
- if any(isinstance(update, type_new_update) for update in
self._updates):
- raise ValueError(f"Updates in a single commit need to be
unique, duplicate: {type_new_update}")
- self._updates = self._updates + new_updates
- return self
+ if self._autocommit:
+ self.commit_transaction()
+ self._updates = ()
+ self._requirements = ()
- def _append_requirements(self, *new_requirements: TableRequirement) ->
Transaction:
- """Append requirements to the set of staged requirements.
-
- Args:
- *new_requirements: Any new requirements.
-
- Raises:
- ValueError: When the type of requirement is not unique.
-
- Returns:
- Transaction object with the new requirements appended.
- """
- for new_requirement in new_requirements:
- # explicitly get type of new_update as requirement is an
instantiated class
- type_new_requirement = type(new_requirement)
- if any(isinstance(requirement, type_new_requirement) for
requirement in self._requirements):
- raise ValueError(f"Requirements in a single commit need to be
unique, duplicate: {type_new_requirement}")
- self._requirements = self._requirements + new_requirements
return self
def upgrade_table_version(self, format_version: Literal[1, 2]) ->
Transaction:
@@ -307,10 +288,11 @@ class Transaction:
if format_version < self._table.metadata.format_version:
raise ValueError(f"Cannot downgrade
v{self._table.metadata.format_version} table to v{format_version}")
+
if format_version > self._table.metadata.format_version:
- return
self._append_updates(UpgradeFormatVersionUpdate(format_version=format_version))
- else:
- return self
+ return
self._apply((UpgradeFormatVersionUpdate(format_version=format_version),))
+
+ return self
def set_properties(self, **updates: str) -> Transaction:
"""Set properties.
@@ -323,56 +305,19 @@ class Transaction:
Returns:
The alter table builder.
"""
- return self._append_updates(SetPropertiesUpdate(updates=updates))
-
- def add_snapshot(self, snapshot: Snapshot) -> Transaction:
- """Add a new snapshot to the table.
-
- Returns:
- The transaction with the add-snapshot staged.
- """
- self._append_updates(AddSnapshotUpdate(snapshot=snapshot))
-
self._append_requirements(AssertTableUUID(uuid=self._table.metadata.table_uuid))
-
- return self
-
- def set_ref_snapshot(
- self,
- snapshot_id: int,
- parent_snapshot_id: Optional[int],
- ref_name: str,
- type: str,
- max_age_ref_ms: Optional[int] = None,
- max_snapshot_age_ms: Optional[int] = None,
- min_snapshots_to_keep: Optional[int] = None,
- ) -> Transaction:
- """Update a ref to a snapshot.
+ return self._apply((SetPropertiesUpdate(updates=updates),))
- Returns:
- The transaction with the set-snapshot-ref staged
- """
- self._append_updates(
- SetSnapshotRefUpdate(
- snapshot_id=snapshot_id,
- parent_snapshot_id=parent_snapshot_id,
- ref_name=ref_name,
- type=type,
- max_age_ref_ms=max_age_ref_ms,
- max_snapshot_age_ms=max_snapshot_age_ms,
- min_snapshots_to_keep=min_snapshots_to_keep,
- )
- )
-
-
self._append_requirements(AssertRefSnapshotId(snapshot_id=parent_snapshot_id,
ref="main"))
- return self
-
- def update_schema(self) -> UpdateSchema:
+ def update_schema(self, allow_incompatible_changes: bool = False,
case_sensitive: bool = True) -> UpdateSchema:
"""Create a new UpdateSchema to alter the columns of this table.
+ Args:
+ allow_incompatible_changes: If changes are allowed that might
break downstream consumers.
+ case_sensitive: If field names are case-sensitive.
+
Returns:
A new UpdateSchema.
"""
- return UpdateSchema(self._table, self)
+ return UpdateSchema(self,
allow_incompatible_changes=allow_incompatible_changes,
case_sensitive=case_sensitive)
def update_snapshot(self) -> UpdateSnapshot:
"""Create a new UpdateSnapshot to produce a new snapshot for the table.
@@ -380,7 +325,7 @@ class Transaction:
Returns:
A new UpdateSnapshot
"""
- return UpdateSnapshot(self._table, self)
+ return UpdateSnapshot(self, io=self._table.io)
def update_spec(self) -> UpdateSpec:
"""Create a new UpdateSpec to update the partitioning of the table.
@@ -388,7 +333,7 @@ class Transaction:
Returns:
A new UpdateSpec.
"""
- return UpdateSpec(self._table, self)
+ return UpdateSpec(self)
def remove_properties(self, *removals: str) -> Transaction:
"""Remove properties.
@@ -399,7 +344,7 @@ class Transaction:
Returns:
The alter table builder.
"""
- return self._append_updates(RemovePropertiesUpdate(removals=removals))
+ return self._apply((RemovePropertiesUpdate(removals=removals),))
def update_location(self, location: str) -> Transaction:
"""Set the new table location.
@@ -412,19 +357,12 @@ class Transaction:
"""
raise NotImplementedError("Not yet implemented")
- def schema(self) -> Schema:
- try:
- return next(update for update in self._updates if
isinstance(update, AddSchemaUpdate)).schema_
- except StopIteration:
- return self._table.schema()
-
def commit_transaction(self) -> Table:
"""Commit the changes to the catalog.
Returns:
The table with the updates applied.
"""
- # Strip the catalog name
if len(self._updates) > 0:
self._table._do_commit( # pylint: disable=W0212
updates=self._updates,
@@ -913,7 +851,7 @@ class AssertLastAssignedPartitionId(TableRequirement):
"""The table's last assigned partition id must match the requirement's
`last-assigned-partition-id`."""
type: Literal["assert-last-assigned-partition-id"] =
Field(default="assert-last-assigned-partition-id")
- last_assigned_partition_id: int = Field(...,
alias="last-assigned-partition-id")
+ last_assigned_partition_id: Optional[int] = Field(...,
alias="last-assigned-partition-id")
def validate(self, base_metadata: Optional[TableMetadata]) -> None:
if base_metadata is None:
@@ -954,6 +892,9 @@ class AssertDefaultSortOrderId(TableRequirement):
)
+UpdatesAndRequirements = Tuple[Tuple[TableUpdate, ...],
Tuple[TableRequirement, ...]]
+
+
class Namespace(IcebergRootModel[List[str]]):
"""Reference to one or more levels of a namespace."""
@@ -998,6 +939,11 @@ class Table:
self.catalog = catalog
def transaction(self) -> Transaction:
+ """Create a new transaction object to first stage the changes, and
then commit them to the catalog.
+
+ Returns:
+ The transaction object
+ """
return Transaction(self)
def refresh(self) -> Table:
@@ -1080,17 +1026,6 @@ class Table:
def last_sequence_number(self) -> int:
return self.metadata.last_sequence_number
- def next_sequence_number(self) -> int:
- return self.last_sequence_number + 1 if self.metadata.format_version >
1 else INITIAL_SEQUENCE_NUMBER
-
- def new_snapshot_id(self) -> int:
- """Generate a new snapshot-id that's not in use."""
- snapshot_id = _generate_snapshot_id()
- while self.snapshot_by_id(snapshot_id) is not None:
- snapshot_id = _generate_snapshot_id()
-
- return snapshot_id
-
def current_snapshot(self) -> Optional[Snapshot]:
"""Get the current snapshot for this table, or None if there is no
current snapshot."""
if self.metadata.current_snapshot_id is not None:
@@ -1114,18 +1049,19 @@ class Table:
def update_schema(self, allow_incompatible_changes: bool = False,
case_sensitive: bool = True) -> UpdateSchema:
"""Create a new UpdateSchema to alter the columns of this table.
- Returns:
- A new UpdateSchema.
- """
- return UpdateSchema(self,
allow_incompatible_changes=allow_incompatible_changes,
case_sensitive=case_sensitive)
-
- def update_snapshot(self) -> UpdateSnapshot:
- """Create a new UpdateSnapshot to produce a new snapshot for the table.
+ Args:
+ allow_incompatible_changes: If changes are allowed that might
break downstream consumers.
+ case_sensitive: If field names are case-sensitive.
Returns:
- A new UpdateSnapshot
+ A new UpdateSchema.
"""
- return UpdateSnapshot(self)
+ return UpdateSchema(
+ transaction=Transaction(self, autocommit=True),
+ allow_incompatible_changes=allow_incompatible_changes,
+ case_sensitive=case_sensitive,
+ name_mapping=self.name_mapping(),
+ )
def name_mapping(self) -> Optional[NameMapping]:
"""Return the table's field-id NameMapping."""
@@ -1154,12 +1090,15 @@ class Table:
_check_schema(self.schema(), other_schema=df.schema)
- with self.update_snapshot().fast_append() as update_snapshot:
- # skip writing data files if the dataframe is empty
- if df.shape[0] > 0:
- data_files = _dataframe_to_data_files(self,
write_uuid=update_snapshot.commit_uuid, df=df)
- for data_file in data_files:
- update_snapshot.append_data_file(data_file)
+ with self.transaction() as txn:
+ with txn.update_snapshot().fast_append() as update_snapshot:
+ # skip writing data files if the dataframe is empty
+ if df.shape[0] > 0:
+ data_files = _dataframe_to_data_files(
+ table_metadata=self.metadata,
write_uuid=update_snapshot.commit_uuid, df=df, io=self.io
+ )
+ for data_file in data_files:
+ update_snapshot.append_data_file(data_file)
def overwrite(self, df: pa.Table, overwrite_filter: BooleanExpression =
ALWAYS_TRUE) -> None:
"""
@@ -1186,15 +1125,18 @@ class Table:
_check_schema(self.schema(), other_schema=df.schema)
- with self.update_snapshot().overwrite() as update_snapshot:
- # skip writing data files if the dataframe is empty
- if df.shape[0] > 0:
- data_files = _dataframe_to_data_files(self,
write_uuid=update_snapshot.commit_uuid, df=df)
- for data_file in data_files:
- update_snapshot.append_data_file(data_file)
+ with self.transaction() as txn:
+ with txn.update_snapshot().overwrite() as update_snapshot:
+ # skip writing data files if the dataframe is empty
+ if df.shape[0] > 0:
+ data_files = _dataframe_to_data_files(
+ table_metadata=self.metadata,
write_uuid=update_snapshot.commit_uuid, df=df, io=self.io
+ )
+ for data_file in data_files:
+ update_snapshot.append_data_file(data_file)
def update_spec(self, case_sensitive: bool = True) -> UpdateSpec:
- return UpdateSpec(self, case_sensitive=case_sensitive)
+ return UpdateSpec(Transaction(self, autocommit=True),
case_sensitive=case_sensitive)
def refs(self) -> Dict[str, SnapshotRef]:
"""Return the snapshot references in the table."""
@@ -1613,8 +1555,31 @@ class Move:
other_field_id: Optional[int] = None
-class UpdateSchema:
- _table: Optional[Table]
+U = TypeVar('U')
+
+
+class UpdateTableMetadata(ABC, Generic[U]):
+ _transaction: Transaction
+
+ def __init__(self, transaction: Transaction) -> None:
+ self._transaction = transaction
+
+ @abstractmethod
+ def _commit(self) -> UpdatesAndRequirements: ...
+
+ def commit(self) -> None:
+ self._transaction._apply(*self._commit())
+
+ def __exit__(self, _: Any, value: Any, traceback: Any) -> None:
+ """Close and commit the change."""
+ self.commit()
+
+ def __enter__(self) -> U:
+ """Update the table."""
+ return self # type: ignore
+
+
+class UpdateSchema(UpdateTableMetadata["UpdateSchema"]):
_schema: Schema
_last_column_id: itertools.count[int]
_identifier_field_names: Set[str]
@@ -1629,27 +1594,25 @@ class UpdateSchema:
_id_to_parent: Dict[int, str] = {}
_allow_incompatible_changes: bool
_case_sensitive: bool
- _transaction: Optional[Transaction]
def __init__(
self,
- table: Optional[Table],
- transaction: Optional[Transaction] = None,
+ transaction: Transaction,
allow_incompatible_changes: bool = False,
case_sensitive: bool = True,
schema: Optional[Schema] = None,
+ name_mapping: Optional[NameMapping] = None,
) -> None:
- self._table = table
+ super().__init__(transaction)
if isinstance(schema, Schema):
self._schema = schema
self._last_column_id = itertools.count(1 + schema.highest_field_id)
- elif table is not None:
- self._schema = table.schema()
- self._last_column_id = itertools.count(1 +
table.metadata.last_column_id)
else:
- raise ValueError("Either provide a table or a schema")
+ self._schema = self._transaction.table_metadata.schema()
+ self._last_column_id = itertools.count(1 +
self._transaction.table_metadata.last_column_id)
+ self._name_mapping = name_mapping
self._identifier_field_names = self._schema.identifier_field_names()
self._adds = {}
@@ -1673,14 +1636,6 @@ class UpdateSchema:
self._case_sensitive = case_sensitive
self._transaction = transaction
- def __exit__(self, _: Any, value: Any, traceback: Any) -> None:
- """Close and commit the change."""
- return self.commit()
-
- def __enter__(self) -> UpdateSchema:
- """Update the table."""
- return self
-
def case_sensitive(self, case_sensitive: bool) -> UpdateSchema:
"""Determine if the case of schema needs to be considered when
comparing column names.
@@ -2069,38 +2024,36 @@ class UpdateSchema:
return self
- def commit(self) -> None:
+ def _commit(self) -> UpdatesAndRequirements:
"""Apply the pending changes and commit."""
- if self._table is None:
- raise ValueError("Requires a table to commit to")
-
new_schema = self._apply()
- existing_schema_id = next((schema.schema_id for schema in
self._table.metadata.schemas if schema == new_schema), None)
+ existing_schema_id = next(
+ (schema.schema_id for schema in
self._transaction.table_metadata.schemas if schema == new_schema), None
+ )
+
+ requirements: Tuple[TableRequirement, ...] = ()
+ updates: Tuple[TableUpdate, ...] = ()
# Check if it is different current schema ID
- if existing_schema_id != self._table.schema().schema_id:
- requirements =
(AssertCurrentSchemaId(current_schema_id=self._schema.schema_id),)
+ if existing_schema_id != self._schema.schema_id:
+ requirements +=
(AssertCurrentSchemaId(current_schema_id=self._schema.schema_id),)
if existing_schema_id is None:
- last_column_id = max(self._table.metadata.last_column_id,
new_schema.highest_field_id)
- updates = (
+ last_column_id =
max(self._transaction.table_metadata.last_column_id,
new_schema.highest_field_id)
+ updates += (
AddSchemaUpdate(schema=new_schema,
last_column_id=last_column_id),
SetCurrentSchemaUpdate(schema_id=-1),
)
else:
- updates =
(SetCurrentSchemaUpdate(schema_id=existing_schema_id),) # type: ignore
+ updates +=
(SetCurrentSchemaUpdate(schema_id=existing_schema_id),)
- if name_mapping := self._table.name_mapping():
+ if name_mapping := self._name_mapping:
updated_name_mapping = update_mapping(name_mapping,
self._updates, self._adds)
- updates += ( # type: ignore
+ updates += (
SetPropertiesUpdate(updates={TableProperties.DEFAULT_NAME_MAPPING:
updated_name_mapping.model_dump_json()}),
)
- if self._transaction is not None:
- self._transaction._append_updates(*updates) # pylint:
disable=W0212
- self._transaction._append_requirements(*requirements) #
pylint: disable=W0212
- else:
- self._table._do_commit(updates=updates,
requirements=requirements) # pylint: disable=W0212
+ return updates, requirements
def _apply(self) -> Schema:
"""Apply the pending changes to the original schema and returns the
result.
@@ -2126,7 +2079,13 @@ class UpdateSchema:
field_ids.add(field.field_id)
- next_schema_id = 1 + (max(self._table.schemas().keys()) if self._table
is not None else self._schema.schema_id)
+ if txn := self._transaction:
+ next_schema_id = 1 + (
+ max(schema.schema_id for schema in txn.table_metadata.schemas)
if txn.table_metadata is not None else 0
+ )
+ else:
+ next_schema_id = 0
+
return Schema(*struct.fields, schema_id=next_schema_id,
identifier_field_ids=field_ids)
def assign_new_column_id(self) -> int:
@@ -2456,20 +2415,6 @@ def _add_and_move_fields(
return None if len(adds) == 0 else tuple(*fields, *adds)
-def _generate_snapshot_id() -> int:
- """Generate a new Snapshot ID from a UUID.
-
- Returns: An 64 bit long
- """
- rnd_uuid = uuid.uuid4()
- snapshot_id = int.from_bytes(
- bytes(lhs ^ rhs for lhs, rhs in zip(rnd_uuid.bytes[0:8],
rnd_uuid.bytes[8:16])), byteorder='little', signed=True
- )
- snapshot_id = snapshot_id if snapshot_id >= 0 else snapshot_id * -1
-
- return snapshot_id
-
-
@dataclass(frozen=True)
class WriteTask:
write_uuid: uuid.UUID
@@ -2496,7 +2441,7 @@ def _generate_manifest_list_path(location: str,
snapshot_id: int, attempt: int,
def _dataframe_to_data_files(
- table: Table, df: pa.Table, write_uuid: Optional[uuid.UUID] = None,
file_schema: Optional[Schema] = None
+ table_metadata: TableMetadata, df: pa.Table, io: FileIO, write_uuid:
Optional[uuid.UUID] = None
) -> Iterable[DataFile]:
"""Convert a PyArrow table into a DataFile.
@@ -2505,7 +2450,7 @@ def _dataframe_to_data_files(
"""
from pyiceberg.io.pyarrow import write_file
- if len(table.spec().fields) > 0:
+ if len([spec for spec in table_metadata.partition_specs if spec.spec_id !=
0]) > 0:
raise ValueError("Cannot write to partitioned tables")
counter = itertools.count(0)
@@ -2513,41 +2458,33 @@ def _dataframe_to_data_files(
# This is an iter, so we don't have to materialize everything every time
# This will be more relevant when we start doing partitioned writes
- yield from write_file(table, iter([WriteTask(write_uuid, next(counter),
df)]), file_schema=file_schema)
+ yield from write_file(io=io, table_metadata=table_metadata,
tasks=iter([WriteTask(write_uuid, next(counter), df)]))
-class _MergingSnapshotProducer:
+class
_MergingSnapshotProducer(UpdateTableMetadata["_MergingSnapshotProducer"]):
commit_uuid: uuid.UUID
_operation: Operation
- _table: Table
_snapshot_id: int
_parent_snapshot_id: Optional[int]
_added_data_files: List[DataFile]
- _transaction: Optional[Transaction]
def __init__(
self,
operation: Operation,
- table: Table,
+ transaction: Transaction,
+ io: FileIO,
commit_uuid: Optional[uuid.UUID] = None,
- transaction: Optional[Transaction] = None,
) -> None:
+ super().__init__(transaction)
self.commit_uuid = commit_uuid or uuid.uuid4()
+ self._io = io
self._operation = operation
- self._table = table
- self._snapshot_id = table.new_snapshot_id()
+ self._snapshot_id = self._transaction.table_metadata.new_snapshot_id()
# Since we only support the main branch for now
- self._parent_snapshot_id = snapshot.snapshot_id if (snapshot :=
self._table.current_snapshot()) else None
+ self._parent_snapshot_id = (
+ snapshot.snapshot_id if (snapshot :=
self._transaction.table_metadata.current_snapshot()) else None
+ )
self._added_data_files = []
- self._transaction = transaction
-
- def __enter__(self) -> _MergingSnapshotProducer:
- """Start a transaction to update the table."""
- return self
-
- def __exit__(self, _: Any, value: Any, traceback: Any) -> None:
- """Close and commit the transaction."""
- self.commit()
def append_data_file(self, data_file: DataFile) ->
_MergingSnapshotProducer:
self._added_data_files.append(data_file)
@@ -2562,12 +2499,14 @@ class _MergingSnapshotProducer:
def _manifests(self) -> List[ManifestFile]:
def _write_added_manifest() -> List[ManifestFile]:
if self._added_data_files:
- output_file_location =
_new_manifest_path(location=self._table.location(), num=0,
commit_uuid=self.commit_uuid)
+ output_file_location = _new_manifest_path(
+ location=self._transaction.table_metadata.location, num=0,
commit_uuid=self.commit_uuid
+ )
with write_manifest(
- format_version=self._table.format_version,
- spec=self._table.spec(),
- schema=self._table.schema(),
-
output_file=self._table.io.new_output(output_file_location),
+
format_version=self._transaction.table_metadata.format_version,
+ spec=self._transaction.table_metadata.spec(),
+ schema=self._transaction.table_metadata.schema(),
+ output_file=self._io.new_output(output_file_location),
snapshot_id=self._snapshot_id,
) as writer:
for data_file in self._added_data_files:
@@ -2588,13 +2527,15 @@ class _MergingSnapshotProducer:
# Check if we need to mark the files as deleted
deleted_entries = self._deleted_entries()
if len(deleted_entries) > 0:
- output_file_location =
_new_manifest_path(location=self._table.location(), num=1,
commit_uuid=self.commit_uuid)
+ output_file_location = _new_manifest_path(
+ location=self._transaction.table_metadata.location, num=1,
commit_uuid=self.commit_uuid
+ )
with write_manifest(
- format_version=self._table.format_version,
- spec=self._table.spec(),
- schema=self._table.schema(),
-
output_file=self._table.io.new_output(output_file_location),
+
format_version=self._transaction.table_metadata.format_version,
+ spec=self._transaction.table_metadata.spec(),
+ schema=self._transaction.table_metadata.schema(),
+ output_file=self._io.new_output(output_file_location),
snapshot_id=self._snapshot_id,
) as writer:
for delete_entry in deleted_entries:
@@ -2617,7 +2558,11 @@ class _MergingSnapshotProducer:
for data_file in self._added_data_files:
ssc.add_file(data_file=data_file)
- previous_snapshot =
self._table.snapshot_by_id(self._parent_snapshot_id) if
self._parent_snapshot_id is not None else None
+ previous_snapshot = (
+
self._transaction.table_metadata.snapshot_by_id(self._parent_snapshot_id)
+ if self._parent_snapshot_id is not None
+ else None
+ )
return update_snapshot_summaries(
summary=Summary(operation=self._operation, **ssc.build()),
@@ -2625,18 +2570,21 @@ class _MergingSnapshotProducer:
truncate_full_table=self._operation == Operation.OVERWRITE,
)
- def commit(self) -> Snapshot:
+ def _commit(self) -> UpdatesAndRequirements:
new_manifests = self._manifests()
- next_sequence_number = self._table.next_sequence_number()
+ next_sequence_number =
self._transaction.table_metadata.next_sequence_number()
summary = self._summary()
manifest_list_file_path = _generate_manifest_list_path(
- location=self._table.location(), snapshot_id=self._snapshot_id,
attempt=0, commit_uuid=self.commit_uuid
+ location=self._transaction.table_metadata.location,
+ snapshot_id=self._snapshot_id,
+ attempt=0,
+ commit_uuid=self.commit_uuid,
)
with write_manifest_list(
- format_version=self._table.metadata.format_version,
- output_file=self._table.io.new_output(manifest_list_file_path),
+ format_version=self._transaction.table_metadata.format_version,
+ output_file=self._io.new_output(manifest_list_file_path),
snapshot_id=self._snapshot_id,
parent_snapshot_id=self._parent_snapshot_id,
sequence_number=next_sequence_number,
@@ -2649,22 +2597,21 @@ class _MergingSnapshotProducer:
manifest_list=manifest_list_file_path,
sequence_number=next_sequence_number,
summary=summary,
- schema_id=self._table.schema().schema_id,
+ schema_id=self._transaction.table_metadata.current_schema_id,
)
- if self._transaction is not None:
- self._transaction.add_snapshot(snapshot=snapshot)
- self._transaction.set_ref_snapshot(
- snapshot_id=self._snapshot_id,
parent_snapshot_id=self._parent_snapshot_id, ref_name="main", type="branch"
- )
- else:
- with self._table.transaction() as tx:
- tx.add_snapshot(snapshot=snapshot)
- tx.set_ref_snapshot(
+ return (
+ (
+ AddSnapshotUpdate(snapshot=snapshot),
+ SetSnapshotRefUpdate(
snapshot_id=self._snapshot_id,
parent_snapshot_id=self._parent_snapshot_id, ref_name="main", type="branch"
- )
-
- return snapshot
+ ),
+ ),
+ (
+
AssertTableUUID(uuid=self._transaction.table_metadata.table_uuid),
+ AssertRefSnapshotId(snapshot_id=self._parent_snapshot_id,
ref="main"),
+ ),
+ )
class FastAppendFiles(_MergingSnapshotProducer):
@@ -2677,12 +2624,12 @@ class FastAppendFiles(_MergingSnapshotProducer):
existing_manifests = []
if self._parent_snapshot_id is not None:
- previous_snapshot =
self._table.snapshot_by_id(self._parent_snapshot_id)
+ previous_snapshot =
self._transaction.table_metadata.snapshot_by_id(self._parent_snapshot_id)
if previous_snapshot is None:
raise ValueError(f"Snapshot could not be found:
{self._parent_snapshot_id}")
- for manifest in previous_snapshot.manifests(io=self._table.io):
+ for manifest in previous_snapshot.manifests(io=self._io):
if manifest.has_added_files() or manifest.has_existing_files()
or manifest.added_snapshot_id == self._snapshot_id:
existing_manifests.append(manifest)
@@ -2713,7 +2660,7 @@ class OverwriteFiles(_MergingSnapshotProducer):
which entries are affected.
"""
if self._parent_snapshot_id is not None:
- previous_snapshot =
self._table.snapshot_by_id(self._parent_snapshot_id)
+ previous_snapshot =
self._transaction.table_metadata.snapshot_by_id(self._parent_snapshot_id)
if previous_snapshot is None:
# This should never happen since you cannot overwrite an empty
table
raise ValueError(f"Could not find the previous snapshot:
{self._parent_snapshot_id}")
@@ -2729,37 +2676,39 @@ class OverwriteFiles(_MergingSnapshotProducer):
file_sequence_number=entry.file_sequence_number,
data_file=entry.data_file,
)
- for entry in manifest.fetch_manifest_entry(self._table.io,
discard_deleted=True)
+ for entry in manifest.fetch_manifest_entry(self._io,
discard_deleted=True)
if entry.data_file.content == DataFileContent.DATA
]
- list_of_entries = executor.map(_get_entries,
previous_snapshot.manifests(self._table.io))
+ list_of_entries = executor.map(_get_entries,
previous_snapshot.manifests(self._io))
return list(chain(*list_of_entries))
else:
return []
class UpdateSnapshot:
- _table: Table
- _transaction: Optional[Transaction]
+ _transaction: Transaction
+ _io: FileIO
- def __init__(self, table: Table, transaction: Optional[Transaction] =
None) -> None:
- self._table = table
+ def __init__(self, transaction: Transaction, io: FileIO) -> None:
self._transaction = transaction
+ self._io = io
def fast_append(self) -> FastAppendFiles:
- return FastAppendFiles(table=self._table, operation=Operation.APPEND,
transaction=self._transaction)
+ return FastAppendFiles(operation=Operation.APPEND,
transaction=self._transaction, io=self._io)
def overwrite(self) -> OverwriteFiles:
return OverwriteFiles(
- table=self._table,
- operation=Operation.OVERWRITE if self._table.current_snapshot() is
not None else Operation.APPEND,
+ operation=Operation.OVERWRITE
+ if self._transaction.table_metadata.current_snapshot() is not None
+ else Operation.APPEND,
transaction=self._transaction,
+ io=self._io,
)
-class UpdateSpec:
- _table: Table
+class UpdateSpec(UpdateTableMetadata["UpdateSpec"]):
+ _transaction: Transaction
_name_to_field: Dict[str, PartitionField] = {}
_name_to_added_field: Dict[str, PartitionField] = {}
_transform_to_field: Dict[Tuple[int, str], PartitionField] = {}
@@ -2770,17 +2719,18 @@ class UpdateSpec:
_adds: List[PartitionField]
_deletes: Set[int]
_last_assigned_partition_id: int
- _transaction: Optional[Transaction]
- def __init__(self, table: Table, transaction: Optional[Transaction] =
None, case_sensitive: bool = True) -> None:
- self._table = table
- self._name_to_field = {field.name: field for field in
table.spec().fields}
+ def __init__(self, transaction: Transaction, case_sensitive: bool = True)
-> None:
+ super().__init__(transaction)
+ self._name_to_field = {field.name: field for field in
transaction.table_metadata.spec().fields}
self._name_to_added_field = {}
- self._transform_to_field = {(field.source_id, repr(field.transform)):
field for field in table.spec().fields}
+ self._transform_to_field = {
+ (field.source_id, repr(field.transform)): field for field in
transaction.table_metadata.spec().fields
+ }
self._transform_to_added_field = {}
self._adds = []
self._deletes = set()
- self._last_assigned_partition_id = table.last_partition_id()
+ self._last_assigned_partition_id =
transaction.table_metadata.last_partition_id or PARTITION_FIELD_ID_START - 1
self._renames = {}
self._transaction = transaction
self._case_sensitive = case_sensitive
@@ -2793,7 +2743,7 @@ class UpdateSpec:
partition_field_name: Optional[str] = None,
) -> UpdateSpec:
ref = Reference(source_column_name)
- bound_ref = ref.bind(self._table.schema(), self._case_sensitive)
+ bound_ref = ref.bind(self._transaction.table_metadata.schema(),
self._case_sensitive)
# verify transform can actually bind it
output_type = bound_ref.field.field_type
if not transform.can_transform(output_type):
@@ -2864,31 +2814,24 @@ class UpdateSpec:
self._renames[name] = new_name
return self
- def commit(self) -> None:
+ def _commit(self) -> UpdatesAndRequirements:
new_spec = self._apply()
- if self._table.metadata.default_spec_id != new_spec.spec_id:
- if new_spec.spec_id not in self._table.specs():
- updates = [AddPartitionSpecUpdate(spec=new_spec),
SetDefaultSpecUpdate(spec_id=-1)]
- else:
- updates = [SetDefaultSpecUpdate(spec_id=new_spec.spec_id)]
-
- required_last_assigned_partitioned_id =
self._table.last_partition_id()
- requirements =
[AssertLastAssignedPartitionId(last_assigned_partition_id=required_last_assigned_partitioned_id)]
+ updates: Tuple[TableUpdate, ...] = ()
+ requirements: Tuple[TableRequirement, ...] = ()
- if self._transaction is not None:
- self._transaction._append_updates(*updates) # pylint:
disable=W0212
- self._transaction._append_requirements(*requirements) #
pylint: disable=W0212
+ if self._transaction.table_metadata.default_spec_id !=
new_spec.spec_id:
+ if new_spec.spec_id not in
self._transaction.table_metadata.specs():
+ updates = (
+ AddPartitionSpecUpdate(spec=new_spec),
+ SetDefaultSpecUpdate(spec_id=-1),
+ )
else:
-
requirements.append(AssertDefaultSpecId(default_spec_id=self._table.spec().spec_id))
- self._table._do_commit(updates=tuple(updates),
requirements=tuple(requirements)) # pylint: disable=W0212
+ updates = (SetDefaultSpecUpdate(spec_id=new_spec.spec_id),)
- def __exit__(self, _: Any, value: Any, traceback: Any) -> None:
- """Close and commit the change."""
- return self.commit()
+ required_last_assigned_partitioned_id =
self._transaction.table_metadata.last_partition_id
+ requirements =
(AssertLastAssignedPartitionId(last_assigned_partition_id=required_last_assigned_partitioned_id),)
- def __enter__(self) -> UpdateSpec:
- """Update the table."""
- return self
+ return updates, requirements
def _apply(self) -> PartitionSpec:
def _check_and_add_partition_name(schema: Schema, name: str,
source_id: int, partition_names: Set[str]) -> None:
@@ -2915,27 +2858,47 @@ class UpdateSpec:
partition_fields = []
partition_names: Set[str] = set()
- for field in self._table.spec().fields:
+ for field in self._transaction.table_metadata.spec().fields:
if field.field_id not in self._deletes:
renamed = self._renames.get(field.name)
if renamed:
new_field = _add_new_field(
- self._table.schema(), field.source_id, field.field_id,
renamed, field.transform, partition_names
+ self._transaction.table_metadata.schema(),
+ field.source_id,
+ field.field_id,
+ renamed,
+ field.transform,
+ partition_names,
)
else:
new_field = _add_new_field(
- self._table.schema(), field.source_id, field.field_id,
field.name, field.transform, partition_names
+ self._transaction.table_metadata.schema(),
+ field.source_id,
+ field.field_id,
+ field.name,
+ field.transform,
+ partition_names,
)
partition_fields.append(new_field)
- elif self._table.format_version == 1:
+ elif self._transaction.table_metadata.format_version == 1:
renamed = self._renames.get(field.name)
if renamed:
new_field = _add_new_field(
- self._table.schema(), field.source_id, field.field_id,
renamed, VoidTransform(), partition_names
+ self._transaction.table_metadata.schema(),
+ field.source_id,
+ field.field_id,
+ renamed,
+ VoidTransform(),
+ partition_names,
)
else:
new_field = _add_new_field(
- self._table.schema(), field.source_id, field.field_id,
field.name, VoidTransform(), partition_names
+ self._transaction.table_metadata.schema(),
+ field.source_id,
+ field.field_id,
+ field.name,
+ VoidTransform(),
+ partition_names,
)
partition_fields.append(new_field)
@@ -2952,7 +2915,7 @@ class UpdateSpec:
# Reuse spec id or create a new one.
new_spec = PartitionSpec(*partition_fields)
new_spec_id = INITIAL_PARTITION_SPEC_ID
- for spec in self._table.specs().values():
+ for spec in self._transaction.table_metadata.specs().values():
if new_spec.compatible_with(spec):
new_spec_id = spec.spec_id
break
@@ -2961,10 +2924,10 @@ class UpdateSpec:
return PartitionSpec(*partition_fields, spec_id=new_spec_id)
def _partition_field(self, transform_key: Tuple[int, Transform[Any, Any]],
name: Optional[str]) -> PartitionField:
- if self._table.metadata.format_version == 2:
+ if self._transaction.table_metadata.format_version == 2:
source_id, transform = transform_key
historical_fields = []
- for spec in self._table.specs().values():
+ for spec in self._transaction.table_metadata.specs().values():
for field in spec.fields:
historical_fields.append((field.source_id, field.field_id,
repr(field.transform), field.name))
@@ -2976,7 +2939,7 @@ class UpdateSpec:
new_field_id = self._new_field_id()
if name is None:
tmp_field = PartitionField(transform_key[0], new_field_id,
transform_key[1], 'unassigned_field_name')
- name = _visit_partition_field(self._table.schema(), tmp_field,
_PartitionNameGenerator())
+ name =
_visit_partition_field(self._transaction.table_metadata.schema(), tmp_field,
_PartitionNameGenerator())
return PartitionField(transform_key[0], new_field_id,
transform_key[1], name)
def _new_field_id(self) -> int:
diff --git a/pyiceberg/table/metadata.py b/pyiceberg/table/metadata.py
index 2fd8b13a..c7169151 100644
--- a/pyiceberg/table/metadata.py
+++ b/pyiceberg/table/metadata.py
@@ -226,11 +226,54 @@ class TableMetadataCommonFields(IcebergBaseModel):
"""Get the schema by schema_id."""
return next((schema for schema in self.schemas if schema.schema_id ==
schema_id), None)
+ def schema(self) -> Schema:
+ """Return the schema for this table."""
+ return next(schema for schema in self.schemas if schema.schema_id ==
self.current_schema_id)
+
+ def spec(self) -> PartitionSpec:
+ """Return the partition spec of this table."""
+ return next(spec for spec in self.partition_specs if spec.spec_id ==
self.default_spec_id)
+
+ def specs(self) -> Dict[int, PartitionSpec]:
+ """Return a dict of the partition specs of this table."""
+ return {spec.spec_id: spec for spec in self.partition_specs}
+
+ def new_snapshot_id(self) -> int:
+ """Generate a new snapshot-id that's not in use."""
+ snapshot_id = _generate_snapshot_id()
+ while self.snapshot_by_id(snapshot_id) is not None:
+ snapshot_id = _generate_snapshot_id()
+
+ return snapshot_id
+
+ def current_snapshot(self) -> Optional[Snapshot]:
+ """Get the current snapshot for this table, or None if there is no
current snapshot."""
+ if self.current_snapshot_id is not None:
+ return self.snapshot_by_id(self.current_snapshot_id)
+ return None
+
+ def next_sequence_number(self) -> int:
+ return self.last_sequence_number + 1 if self.format_version > 1 else
INITIAL_SEQUENCE_NUMBER
+
def sort_order_by_id(self, sort_order_id: int) -> Optional[SortOrder]:
"""Get the sort order by sort_order_id."""
return next((sort_order for sort_order in self.sort_orders if
sort_order.order_id == sort_order_id), None)
+def _generate_snapshot_id() -> int:
+ """Generate a new Snapshot ID from a UUID.
+
+ Returns: An 64 bit long
+ """
+ rnd_uuid = uuid.uuid4()
+ snapshot_id = int.from_bytes(
+ bytes(lhs ^ rhs for lhs, rhs in zip(rnd_uuid.bytes[0:8],
rnd_uuid.bytes[8:16])), byteorder='little', signed=True
+ )
+ snapshot_id = snapshot_id if snapshot_id >= 0 else snapshot_id * -1
+
+ return snapshot_id
+
+
class TableMetadataV1(TableMetadataCommonFields, IcebergBaseModel):
"""Represents version 1 of the Table Metadata.
diff --git a/tests/catalog/test_sql.py b/tests/catalog/test_sql.py
index 2b1167f1..9f4d4af4 100644
--- a/tests/catalog/test_sql.py
+++ b/tests/catalog/test_sql.py
@@ -39,6 +39,7 @@ from pyiceberg.io import FSSPEC_FILE_IO, PY_IO_IMPL
from pyiceberg.io.pyarrow import schema_to_pyarrow
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC
from pyiceberg.schema import Schema
+from pyiceberg.table import _dataframe_to_data_files
from pyiceberg.table.snapshots import Operation
from pyiceberg.table.sorting import (
NullOrder,
@@ -863,3 +864,54 @@ def test_concurrent_commit_table(catalog: SqlCatalog,
table_schema_simple: Schem
# This one should fail since it already has been updated
with table_b.update_schema() as update:
update.add_column(path="c", field_type=IntegerType())
+
+
[email protected](
+ 'catalog',
+ [
+ lazy_fixture('catalog_memory'),
+ lazy_fixture('catalog_sqlite'),
+ lazy_fixture('catalog_sqlite_without_rowcount'),
+ ],
+)
[email protected]("format_version", [1, 2])
+def test_write_and_evolve(catalog: SqlCatalog, format_version: int) -> None:
+ identifier =
f"default.arrow_write_data_and_evolve_schema_v{format_version}"
+
+ try:
+ catalog.create_namespace("default")
+ except NamespaceAlreadyExistsError:
+ pass
+
+ try:
+ catalog.drop_table(identifier=identifier)
+ except NoSuchTableError:
+ pass
+
+ pa_table = pa.Table.from_pydict(
+ {
+ 'foo': ['a', None, 'z'],
+ },
+ schema=pa.schema([pa.field("foo", pa.string(), nullable=True)]),
+ )
+
+ tbl = catalog.create_table(identifier=identifier, schema=pa_table.schema,
properties={"format-version": str(format_version)})
+
+ pa_table_with_column = pa.Table.from_pydict(
+ {
+ 'foo': ['a', None, 'z'],
+ 'bar': [19, None, 25],
+ },
+ schema=pa.schema([
+ pa.field("foo", pa.string(), nullable=True),
+ pa.field("bar", pa.int32(), nullable=True),
+ ]),
+ )
+
+ with tbl.transaction() as txn:
+ with txn.update_schema() as schema_txn:
+ schema_txn.union_by_name(pa_table_with_column.schema)
+
+ with txn.update_snapshot().fast_append() as snapshot_update:
+ for data_file in
_dataframe_to_data_files(table_metadata=txn.table_metadata,
df=pa_table_with_column, io=tbl.io):
+ snapshot_update.append_data_file(data_file)
diff --git a/tests/integration/test_rest_schema.py
b/tests/integration/test_rest_schema.py
index aae07cab..17fb3380 100644
--- a/tests/integration/test_rest_schema.py
+++ b/tests/integration/test_rest_schema.py
@@ -84,7 +84,7 @@ def _create_table_with_schema(catalog: Catalog, schema:
Schema) -> Table:
@pytest.mark.integration
def test_add_already_exists(catalog: Catalog, table_schema_nested: Schema) ->
None:
table = _create_table_with_schema(catalog, table_schema_nested)
- update = UpdateSchema(table)
+ update = table.update_schema()
with pytest.raises(ValueError) as exc_info:
update.add_column("foo", IntegerType())
@@ -98,7 +98,7 @@ def test_add_already_exists(catalog: Catalog,
table_schema_nested: Schema) -> No
@pytest.mark.integration
def test_add_to_non_struct_type(catalog: Catalog, table_schema_simple: Schema)
-> None:
table = _create_table_with_schema(catalog, table_schema_simple)
- update = UpdateSchema(table)
+ update = table.update_schema()
with pytest.raises(ValueError) as exc_info:
update.add_column(path=("foo", "lat"), field_type=IntegerType())
assert "Cannot add column 'lat' to non-struct type: foo" in
str(exc_info.value)
@@ -1066,13 +1066,13 @@ def test_add_nested_list_of_structs(catalog: Catalog)
-> None:
def test_add_required_column(catalog: Catalog) -> None:
schema_ = Schema(NestedField(field_id=1, name="a",
field_type=BooleanType(), required=False))
table = _create_table_with_schema(catalog, schema_)
- update = UpdateSchema(table)
+ update = table.update_schema()
with pytest.raises(ValueError) as exc_info:
update.add_column(path="data", field_type=IntegerType(), required=True)
assert "Incompatible change: cannot add required column: data" in
str(exc_info.value)
new_schema = (
- UpdateSchema(table, allow_incompatible_changes=True) # pylint:
disable=W0212
+ UpdateSchema(transaction=table.transaction(),
allow_incompatible_changes=True)
.add_column(path="data", field_type=IntegerType(), required=True)
._apply()
)
@@ -1088,12 +1088,13 @@ def test_add_required_column_case_insensitive(catalog:
Catalog) -> None:
table = _create_table_with_schema(catalog, schema_)
with pytest.raises(ValueError) as exc_info:
- with UpdateSchema(table, allow_incompatible_changes=True) as update:
- update.case_sensitive(False).add_column(path="ID",
field_type=IntegerType(), required=True)
+ with table.transaction() as txn:
+ with txn.update_schema(allow_incompatible_changes=True) as update:
+ update.case_sensitive(False).add_column(path="ID",
field_type=IntegerType(), required=True)
assert "already exists: ID" in str(exc_info.value)
new_schema = (
- UpdateSchema(table, allow_incompatible_changes=True) # pylint:
disable=W0212
+ UpdateSchema(transaction=table.transaction(),
allow_incompatible_changes=True)
.add_column(path="ID", field_type=IntegerType(), required=True)
._apply()
)
@@ -1264,7 +1265,7 @@ def test_mixed_changes(catalog: Catalog) -> None:
@pytest.mark.integration
def test_ambiguous_column(catalog: Catalog, table_schema_nested: Schema) ->
None:
table = _create_table_with_schema(catalog, table_schema_nested)
- update = UpdateSchema(table)
+ update = UpdateSchema(transaction=table.transaction())
with pytest.raises(ValueError) as exc_info:
update.add_column(path="location.latitude", field_type=IntegerType())
@@ -2507,16 +2508,14 @@ def
test_two_add_schemas_in_a_single_transaction(catalog: Catalog) -> None:
),
)
- with pytest.raises(ValueError) as exc_info:
+ with pytest.raises(CommitFailedException) as exc_info:
with tbl.transaction() as tr:
with tr.update_schema() as update:
update.add_column("bar", field_type=StringType())
with tr.update_schema() as update:
update.add_column("baz", field_type=StringType())
- assert "Updates in a single commit need to be unique, duplicate: <class
'pyiceberg.table.AddSchemaUpdate'>" in str(
- exc_info.value
- )
+ assert "CommitFailedException: Requirement failed: current schema changed:
expected id 1 != 0" in str(exc_info.value)
@pytest.mark.integration
diff --git a/tests/integration/test_writes.py b/tests/integration/test_writes.py
index a16ba48a..388e566b 100644
--- a/tests/integration/test_writes.py
+++ b/tests/integration/test_writes.py
@@ -652,5 +652,5 @@ def test_write_and_evolve(session_catalog: Catalog,
format_version: int) -> None
schema_txn.union_by_name(pa_table_with_column.schema)
with txn.update_snapshot().fast_append() as snapshot_update:
- for data_file in _dataframe_to_data_files(table=tbl,
df=pa_table_with_column, file_schema=txn.schema()):
+ for data_file in
_dataframe_to_data_files(table_metadata=txn.table_metadata,
df=pa_table_with_column, io=tbl.io):
snapshot_update.append_data_file(data_file)
diff --git a/tests/table/test_init.py b/tests/table/test_init.py
index 39aa72f8..e6407b60 100644
--- a/tests/table/test_init.py
+++ b/tests/table/test_init.py
@@ -62,12 +62,11 @@ from pyiceberg.table import (
UpdateSchema,
_apply_table_update,
_check_schema,
- _generate_snapshot_id,
_match_deletes_to_data_file,
_TableMetadataUpdateContext,
update_table_metadata,
)
-from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER,
TableMetadataUtil, TableMetadataV2
+from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER,
TableMetadataUtil, TableMetadataV2, _generate_snapshot_id
from pyiceberg.table.snapshots import (
Operation,
Snapshot,
@@ -435,7 +434,7 @@ def test_serialize_set_properties_updates() -> None:
def test_add_column(table_v2: Table) -> None:
- update = UpdateSchema(table_v2)
+ update = UpdateSchema(transaction=table_v2.transaction())
update.add_column(path="b", field_type=IntegerType())
apply_schema: Schema = update._apply() # pylint: disable=W0212
assert len(apply_schema.fields) == 4
@@ -469,7 +468,7 @@ def test_add_primitive_type_column(table_v2: Table) -> None:
for name, type_ in primitive_type.items():
field_name = f"new_column_{name}"
- update = UpdateSchema(table_v2)
+ update = UpdateSchema(transaction=table_v2.transaction())
update.add_column(path=field_name, field_type=type_,
doc=f"new_column_{name}")
new_schema = update._apply() # pylint: disable=W0212
@@ -481,7 +480,7 @@ def test_add_primitive_type_column(table_v2: Table) -> None:
def test_add_nested_type_column(table_v2: Table) -> None:
# add struct type column
field_name = "new_column_struct"
- update = UpdateSchema(table_v2)
+ update = UpdateSchema(transaction=table_v2.transaction())
struct_ = StructType(
NestedField(1, "lat", DoubleType()),
NestedField(2, "long", DoubleType()),
@@ -499,7 +498,7 @@ def test_add_nested_type_column(table_v2: Table) -> None:
def test_add_nested_map_type_column(table_v2: Table) -> None:
# add map type column
field_name = "new_column_map"
- update = UpdateSchema(table_v2)
+ update = UpdateSchema(transaction=table_v2.transaction())
map_ = MapType(1, StringType(), 2, IntegerType(), False)
update.add_column(path=field_name, field_type=map_)
new_schema = update._apply() # pylint: disable=W0212
@@ -511,7 +510,7 @@ def test_add_nested_map_type_column(table_v2: Table) ->
None:
def test_add_nested_list_type_column(table_v2: Table) -> None:
# add list type column
field_name = "new_column_list"
- update = UpdateSchema(table_v2)
+ update = UpdateSchema(transaction=table_v2.transaction())
list_ = ListType(
element_id=101,
element_type=StructType(
@@ -806,7 +805,7 @@ def test_metadata_isolation_from_illegal_updates(table_v1:
Table) -> None:
def test_generate_snapshot_id(table_v2: Table) -> None:
assert isinstance(_generate_snapshot_id(), int)
- assert isinstance(table_v2.new_snapshot_id(), int)
+ assert isinstance(table_v2.metadata.new_snapshot_id(), int)
def test_assert_create(table_v2: Table) -> None:
diff --git a/tests/test_schema.py b/tests/test_schema.py
index cfee6e7f..6394b72b 100644
--- a/tests/test_schema.py
+++ b/tests/test_schema.py
@@ -928,7 +928,7 @@ def primitive_fields() -> List[NestedField]:
def test_add_top_level_primitives(primitive_fields: NestedField) -> None:
for primitive_field in primitive_fields:
new_schema = Schema(primitive_field)
- applied = UpdateSchema(None,
schema=Schema()).union_by_name(new_schema)._apply()
+ applied = UpdateSchema(transaction=None,
schema=Schema()).union_by_name(new_schema)._apply() # type: ignore
assert applied == new_schema
@@ -942,7 +942,7 @@ def test_add_top_level_list_of_primitives(primitive_fields:
NestedField) -> None
required=False,
)
)
- applied = UpdateSchema(None,
schema=Schema()).union_by_name(new_schema)._apply()
+ applied = UpdateSchema(transaction=None,
schema=Schema()).union_by_name(new_schema)._apply() # type: ignore
assert applied.as_struct() == new_schema.as_struct()
@@ -958,7 +958,7 @@ def test_add_top_level_map_of_primitives(primitive_fields:
NestedField) -> None:
required=False,
)
)
- applied = UpdateSchema(None,
schema=Schema()).union_by_name(new_schema)._apply()
+ applied = UpdateSchema(transaction=None,
schema=Schema()).union_by_name(new_schema)._apply() # type: ignore
assert applied.as_struct() == new_schema.as_struct()
@@ -972,7 +972,7 @@ def test_add_top_struct_of_primitives(primitive_fields:
NestedField) -> None:
required=False,
)
)
- applied = UpdateSchema(None,
schema=Schema()).union_by_name(new_schema)._apply()
+ applied = UpdateSchema(transaction=None,
schema=Schema()).union_by_name(new_schema)._apply() # type: ignore
assert applied.as_struct() == new_schema.as_struct()
@@ -987,7 +987,7 @@ def test_add_nested_primitive(primitive_fields:
NestedField) -> None:
required=False,
)
)
- applied = UpdateSchema(None,
schema=current_schema).union_by_name(new_schema)._apply()
+ applied = UpdateSchema(None, None,
schema=current_schema).union_by_name(new_schema)._apply() # type: ignore
assert applied.as_struct() == new_schema.as_struct()
@@ -1007,7 +1007,7 @@ def test_add_nested_primitives(primitive_fields:
NestedField) -> None:
field_id=1, name="aStruct",
field_type=StructType(*_primitive_fields(TEST_PRIMITIVE_TYPES, 2)),
required=False
)
)
- applied = UpdateSchema(None,
schema=current_schema).union_by_name(new_schema)._apply()
+ applied = UpdateSchema(transaction=None,
schema=current_schema).union_by_name(new_schema)._apply() # type: ignore
assert applied.as_struct() == new_schema.as_struct()
@@ -1048,7 +1048,7 @@ def test_add_nested_lists(primitive_fields: NestedField)
-> None:
required=False,
)
)
- applied = UpdateSchema(None,
schema=Schema()).union_by_name(new_schema)._apply()
+ applied = UpdateSchema(transaction=None,
schema=Schema()).union_by_name(new_schema)._apply() # type: ignore
assert applied.as_struct() == new_schema.as_struct()
@@ -1098,7 +1098,7 @@ def test_add_nested_struct(primitive_fields: NestedField)
-> None:
required=False,
)
)
- applied = UpdateSchema(None,
schema=Schema()).union_by_name(new_schema)._apply()
+ applied = UpdateSchema(transaction=None,
schema=Schema()).union_by_name(new_schema)._apply() # type: ignore
assert applied.as_struct() == new_schema.as_struct()
@@ -1141,7 +1141,7 @@ def test_add_nested_maps(primitive_fields: NestedField)
-> None:
required=False,
)
)
- applied = UpdateSchema(None,
schema=Schema()).union_by_name(new_schema)._apply()
+ applied = UpdateSchema(transaction=None,
schema=Schema()).union_by_name(new_schema)._apply() # type: ignore
assert applied.as_struct() == new_schema.as_struct()
@@ -1164,7 +1164,7 @@ def test_detect_invalid_top_level_list() -> None:
)
with pytest.raises(ValidationError, match="Cannot change column type:
aList.element: string -> double"):
- _ = UpdateSchema(None,
schema=current_schema).union_by_name(new_schema)._apply()
+ _ = UpdateSchema(transaction=None,
schema=current_schema).union_by_name(new_schema)._apply() # type: ignore
def test_detect_invalid_top_level_maps() -> None:
@@ -1186,14 +1186,14 @@ def test_detect_invalid_top_level_maps() -> None:
)
with pytest.raises(ValidationError, match="Cannot change column type:
aMap.key: string -> uuid"):
- _ = UpdateSchema(None,
schema=current_schema).union_by_name(new_schema)._apply()
+ _ = UpdateSchema(transaction=None,
schema=current_schema).union_by_name(new_schema)._apply() # type: ignore
def test_promote_float_to_double() -> None:
current_schema = Schema(NestedField(field_id=1, name="aCol",
field_type=FloatType(), required=False))
new_schema = Schema(NestedField(field_id=1, name="aCol",
field_type=DoubleType(), required=False))
- applied = UpdateSchema(None,
schema=current_schema).union_by_name(new_schema)._apply()
+ applied = UpdateSchema(transaction=None,
schema=current_schema).union_by_name(new_schema)._apply() # type: ignore
assert applied.as_struct() == new_schema.as_struct()
assert len(applied.fields) == 1
@@ -1205,7 +1205,7 @@ def test_detect_invalid_promotion_double_to_float() ->
None:
new_schema = Schema(NestedField(field_id=1, name="aCol",
field_type=FloatType(), required=False))
with pytest.raises(ValidationError, match="Cannot change column type:
aCol: double -> float"):
- _ = UpdateSchema(None,
schema=current_schema).union_by_name(new_schema)._apply()
+ _ = UpdateSchema(transaction=None,
schema=current_schema).union_by_name(new_schema)._apply() # type: ignore
# decimal(P,S) Fixed-point decimal; precision P, scale S -> Scale is fixed [1],
@@ -1214,7 +1214,7 @@ def
test_type_promote_decimal_to_fixed_scale_with_wider_precision() -> None:
current_schema = Schema(NestedField(field_id=1, name="aCol",
field_type=DecimalType(precision=20, scale=1), required=False))
new_schema = Schema(NestedField(field_id=1, name="aCol",
field_type=DecimalType(precision=22, scale=1), required=False))
- applied = UpdateSchema(None,
schema=current_schema).union_by_name(new_schema)._apply()
+ applied = UpdateSchema(transaction=None,
schema=current_schema).union_by_name(new_schema)._apply() # type: ignore
assert applied.as_struct() == new_schema.as_struct()
assert len(applied.fields) == 1
@@ -1282,7 +1282,7 @@ def test_add_nested_structs(primitive_fields:
NestedField) -> None:
required=False,
)
)
- applied = UpdateSchema(None,
schema=schema).union_by_name(new_schema)._apply()
+ applied = UpdateSchema(transaction=None,
schema=schema).union_by_name(new_schema)._apply() # type: ignore
expected = Schema(
NestedField(
@@ -1322,7 +1322,7 @@ def test_replace_list_with_primitive() -> None:
new_schema = Schema(NestedField(field_id=1, name="aCol",
field_type=StringType()))
with pytest.raises(ValidationError, match="Cannot change column type:
list<string> is not a primitive"):
- _ = UpdateSchema(None,
schema=current_schema).union_by_name(new_schema)._apply()
+ _ = UpdateSchema(transaction=None,
schema=current_schema).union_by_name(new_schema)._apply() # type: ignore
def test_mirrored_schemas() -> None:
@@ -1345,7 +1345,7 @@ def test_mirrored_schemas() -> None:
NestedField(9, "string6", StringType(), required=False),
)
- applied = UpdateSchema(None,
schema=current_schema).union_by_name(mirrored_schema)._apply()
+ applied = UpdateSchema(transaction=None,
schema=current_schema).union_by_name(mirrored_schema)._apply() # type: ignore
assert applied.as_struct() == current_schema.as_struct()
@@ -1397,7 +1397,7 @@ def test_add_new_top_level_struct() -> None:
),
)
- applied = UpdateSchema(None,
schema=current_schema).union_by_name(observed_schema)._apply()
+ applied = UpdateSchema(transaction=None,
schema=current_schema).union_by_name(observed_schema)._apply() # type: ignore
assert applied.as_struct() == observed_schema.as_struct()
@@ -1476,7 +1476,7 @@ def test_append_nested_struct() -> None:
)
)
- applied = UpdateSchema(None,
schema=current_schema).union_by_name(observed_schema)._apply()
+ applied = UpdateSchema(transaction=None,
schema=current_schema).union_by_name(observed_schema)._apply() # type: ignore
assert applied.as_struct() == observed_schema.as_struct()
@@ -1541,7 +1541,7 @@ def test_append_nested_lists() -> None:
required=False,
)
)
- union = UpdateSchema(None,
schema=current_schema).union_by_name(observed_schema)._apply()
+ union = UpdateSchema(transaction=None,
schema=current_schema).union_by_name(observed_schema)._apply() # type: ignore
expected = Schema(
NestedField(
@@ -1591,7 +1591,7 @@ def test_union_with_pa_schema(primitive_fields:
NestedField) -> None:
pa.field("baz", pa.bool_(), nullable=True),
])
- new_schema = UpdateSchema(None,
schema=base_schema).union_by_name(pa_schema)._apply()
+ new_schema = UpdateSchema(transaction=None,
schema=base_schema).union_by_name(pa_schema)._apply() # type: ignore
expected_schema = Schema(
NestedField(field_id=1, name="foo", field_type=StringType(),
required=True),