This is an automated email from the ASF dual-hosted git repository.

fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 6708a6ea Update table metadata throughout transaction (#471)
6708a6ea is described below

commit 6708a6eaa76a2b4aab58601f10210f778b3f03a4
Author: Fokko Driesprong <[email protected]>
AuthorDate: Thu Feb 29 13:10:42 2024 +0100

    Update table metadata throughout transaction (#471)
    
    * Update table metadata throughout transaction
    
    This PR adds support for updating the table metadata throughout
    the transaction.
    
    This way, if a schema is first evolved and a snapshot is then
    created based on the latest schema, the snapshot will be able to
    find that schema.
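    
    As a sketch of the usage pattern this enables (adapted from the
    new test_write_and_evolve test below; tbl and df are illustrative
    names for an existing table and a PyArrow dataframe):
    
        with tbl.transaction() as txn:
            with txn.update_schema() as update:
                update.union_by_name(df.schema)
            with txn.update_snapshot().fast_append() as snapshot_update:
                for data_file in _dataframe_to_data_files(
                    table_metadata=txn.table_metadata, df=df, io=tbl.io
                ):
                    snapshot_update.append_data_file(data_file)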
    
    * Fix integration tests
    
    * Thanks Honah!
    
    * Include the partition evolution
    
    * Cleanup
---
 pyiceberg/io/pyarrow.py               |  21 +-
 pyiceberg/table/__init__.py           | 537 ++++++++++++++++------------------
 pyiceberg/table/metadata.py           |  43 +++
 tests/catalog/test_sql.py             |  52 ++++
 tests/integration/test_rest_schema.py |  23 +-
 tests/integration/test_writes.py      |   2 +-
 tests/table/test_init.py              |  15 +-
 tests/test_schema.py                  |  42 +--
 8 files changed, 396 insertions(+), 339 deletions(-)

diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index d41f7a07..be944ffb 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -125,6 +125,7 @@ from pyiceberg.schema import (
     visit_with_partner,
 )
 from pyiceberg.table import PropertyUtil, TableProperties, WriteTask
+from pyiceberg.table.metadata import TableMetadata
 from pyiceberg.table.name_mapping import NameMapping
 from pyiceberg.transforms import TruncateTransform
 from pyiceberg.typedef import EMPTY_DICT, Properties, Record
@@ -1720,7 +1721,7 @@ def fill_parquet_file_metadata(
     data_file.split_offsets = split_offsets
 
 
-def write_file(table: Table, tasks: Iterator[WriteTask], file_schema: Optional[Schema] = None) -> Iterator[DataFile]:
+def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
     task = next(tasks)
 
     try:
@@ -1730,15 +1731,15 @@ def write_file(table: Table, tasks: Iterator[WriteTask], file_schema: Optional[S
     except StopIteration:
         pass
 
-    parquet_writer_kwargs = _get_parquet_writer_kwargs(table.properties)
+    parquet_writer_kwargs = _get_parquet_writer_kwargs(table_metadata.properties)
 
-    file_path = f'{table.location()}/data/{task.generate_data_file_filename("parquet")}'
-    file_schema = file_schema or table.schema()
-    arrow_file_schema = schema_to_pyarrow(file_schema)
+    file_path = f'{table_metadata.location}/data/{task.generate_data_file_filename("parquet")}'
+    schema = table_metadata.schema()
+    arrow_file_schema = schema_to_pyarrow(schema)
 
-    fo = table.io.new_output(file_path)
+    fo = io.new_output(file_path)
     row_group_size = PropertyUtil.property_as_int(
-        properties=table.properties,
+        properties=table_metadata.properties,
         property_name=TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES,
         default=TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT,
     )
@@ -1757,7 +1758,7 @@ def write_file(table: Table, tasks: Iterator[WriteTask], file_schema: Optional[S
         # sort_order_id=task.sort_order_id,
         sort_order_id=None,
         # Just copy these from the table for now
-        spec_id=table.spec().spec_id,
+        spec_id=table_metadata.default_spec_id,
         equality_ids=None,
         key_metadata=None,
     )
@@ -1765,8 +1766,8 @@ def write_file(table: Table, tasks: Iterator[WriteTask], file_schema: Optional[S
     fill_parquet_file_metadata(
         data_file=data_file,
         parquet_metadata=writer.writer.metadata,
-        stats_columns=compute_statistics_plan(file_schema, table.properties),
-        parquet_column_mapping=parquet_path_to_id_mapping(file_schema),
+        stats_columns=compute_statistics_plan(schema, table_metadata.properties),
+        parquet_column_mapping=parquet_path_to_id_mapping(schema),
     )
     return iter([data_file])
 
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index e29369a2..1a4183c9 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -31,6 +31,7 @@ from typing import (
     Any,
     Callable,
     Dict,
+    Generic,
     Iterable,
     List,
     Literal,
@@ -137,6 +138,7 @@ if TYPE_CHECKING:
 
     from pyiceberg.catalog import Catalog
 
+
 ALWAYS_TRUE = AlwaysTrue()
 TABLE_ROOT_ID = -1
 
@@ -229,18 +231,23 @@ class PropertyUtil:
 
 class Transaction:
     _table: Table
+    table_metadata: TableMetadata
+    _autocommit: bool
     _updates: Tuple[TableUpdate, ...]
     _requirements: Tuple[TableRequirement, ...]
 
-    def __init__(
-        self,
-        table: Table,
-        actions: Optional[Tuple[TableUpdate, ...]] = None,
-        requirements: Optional[Tuple[TableRequirement, ...]] = None,
-    ):
+    def __init__(self, table: Table, autocommit: bool = False):
+        """Open a transaction to stage and commit changes to a table.
+
+        Args:
+            table: The table that will be altered.
+            autocommit: Option to automatically commit the changes when they are staged.
+        """
+        self.table_metadata = table.metadata
         self._table = table
-        self._updates = actions or ()
-        self._requirements = requirements or ()
+        self._autocommit = autocommit
+        self._updates = ()
+        self._requirements = ()
 
     def __enter__(self) -> Transaction:
         """Start a transaction to update the table."""
@@ -248,49 +255,23 @@ class Transaction:
 
     def __exit__(self, _: Any, value: Any, traceback: Any) -> None:
         """Close and commit the transaction."""
-        fresh_table = self.commit_transaction()
-        # Update the new data in place
-        self._table.metadata = fresh_table.metadata
-        self._table.metadata_location = fresh_table.metadata_location
+        self.commit_transaction()
 
-    def _append_updates(self, *new_updates: TableUpdate) -> Transaction:
-        """Append updates to the set of staged updates.
+    def _apply(self, updates: Tuple[TableUpdate, ...], requirements: Tuple[TableRequirement, ...] = ()) -> Transaction:
+        """Check if the requirements are met, and apply the updates to the metadata."""
+        for requirement in requirements:
+            requirement.validate(self.table_metadata)
 
-        Args:
-            *new_updates: Any new updates.
+        self._updates += updates
+        self._requirements += requirements
 
-        Raises:
-            ValueError: When the type of update is not unique.
+        self.table_metadata = update_table_metadata(self.table_metadata, updates)
 
-        Returns:
-            Transaction object with the new updates appended.
-        """
-        for new_update in new_updates:
-            # explicitly get type of new_update as new_update is an instantiated class
-            type_new_update = type(new_update)
-            if any(isinstance(update, type_new_update) for update in self._updates):
-                raise ValueError(f"Updates in a single commit need to be unique, duplicate: {type_new_update}")
-        self._updates = self._updates + new_updates
-        return self
+        if self._autocommit:
+            self.commit_transaction()
+            self._updates = ()
+            self._requirements = ()
 
-    def _append_requirements(self, *new_requirements: TableRequirement) -> Transaction:
-        """Append requirements to the set of staged requirements.
-
-        Args:
-            *new_requirements: Any new requirements.
-
-        Raises:
-            ValueError: When the type of requirement is not unique.
-
-        Returns:
-            Transaction object with the new requirements appended.
-        """
-        for new_requirement in new_requirements:
-            # explicitly get type of new_update as requirement is an instantiated class
-            type_new_requirement = type(new_requirement)
-            if any(isinstance(requirement, type_new_requirement) for requirement in self._requirements):
-                raise ValueError(f"Requirements in a single commit need to be unique, duplicate: {type_new_requirement}")
-        self._requirements = self._requirements + new_requirements
         return self
 
    def upgrade_table_version(self, format_version: Literal[1, 2]) -> Transaction:
@@ -307,10 +288,11 @@ class Transaction:
 
         if format_version < self._table.metadata.format_version:
             raise ValueError(f"Cannot downgrade 
v{self._table.metadata.format_version} table to v{format_version}")
+
         if format_version > self._table.metadata.format_version:
-            return self._append_updates(UpgradeFormatVersionUpdate(format_version=format_version))
-        else:
-            return self
+            return self._apply((UpgradeFormatVersionUpdate(format_version=format_version),))
+
+        return self
 
     def set_properties(self, **updates: str) -> Transaction:
         """Set properties.
@@ -323,56 +305,19 @@ class Transaction:
         Returns:
             The alter table builder.
         """
-        return self._append_updates(SetPropertiesUpdate(updates=updates))
-
-    def add_snapshot(self, snapshot: Snapshot) -> Transaction:
-        """Add a new snapshot to the table.
-
-        Returns:
-            The transaction with the add-snapshot staged.
-        """
-        self._append_updates(AddSnapshotUpdate(snapshot=snapshot))
-        self._append_requirements(AssertTableUUID(uuid=self._table.metadata.table_uuid))
-
-        return self
-
-    def set_ref_snapshot(
-        self,
-        snapshot_id: int,
-        parent_snapshot_id: Optional[int],
-        ref_name: str,
-        type: str,
-        max_age_ref_ms: Optional[int] = None,
-        max_snapshot_age_ms: Optional[int] = None,
-        min_snapshots_to_keep: Optional[int] = None,
-    ) -> Transaction:
-        """Update a ref to a snapshot.
+        return self._apply((SetPropertiesUpdate(updates=updates),))
 
-        Returns:
-            The transaction with the set-snapshot-ref staged
-        """
-        self._append_updates(
-            SetSnapshotRefUpdate(
-                snapshot_id=snapshot_id,
-                parent_snapshot_id=parent_snapshot_id,
-                ref_name=ref_name,
-                type=type,
-                max_age_ref_ms=max_age_ref_ms,
-                max_snapshot_age_ms=max_snapshot_age_ms,
-                min_snapshots_to_keep=min_snapshots_to_keep,
-            )
-        )
-
-        self._append_requirements(AssertRefSnapshotId(snapshot_id=parent_snapshot_id, ref="main"))
-        return self
-
-    def update_schema(self) -> UpdateSchema:
+    def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive: bool = True) -> UpdateSchema:
         """Create a new UpdateSchema to alter the columns of this table.
 
+        Args:
+            allow_incompatible_changes: If changes are allowed that might break downstream consumers.
+            case_sensitive: If field names are case-sensitive.
+
         Returns:
             A new UpdateSchema.
         """
-        return UpdateSchema(self._table, self)
+        return UpdateSchema(self, allow_incompatible_changes=allow_incompatible_changes, case_sensitive=case_sensitive)
 
     def update_snapshot(self) -> UpdateSnapshot:
         """Create a new UpdateSnapshot to produce a new snapshot for the table.
@@ -380,7 +325,7 @@ class Transaction:
         Returns:
             A new UpdateSnapshot
         """
-        return UpdateSnapshot(self._table, self)
+        return UpdateSnapshot(self, io=self._table.io)
 
     def update_spec(self) -> UpdateSpec:
         """Create a new UpdateSpec to update the partitioning of the table.
@@ -388,7 +333,7 @@ class Transaction:
         Returns:
             A new UpdateSpec.
         """
-        return UpdateSpec(self._table, self)
+        return UpdateSpec(self)
 
     def remove_properties(self, *removals: str) -> Transaction:
         """Remove properties.
@@ -399,7 +344,7 @@ class Transaction:
         Returns:
             The alter table builder.
         """
-        return self._append_updates(RemovePropertiesUpdate(removals=removals))
+        return self._apply((RemovePropertiesUpdate(removals=removals),))
 
     def update_location(self, location: str) -> Transaction:
         """Set the new table location.
@@ -412,19 +357,12 @@ class Transaction:
         """
         raise NotImplementedError("Not yet implemented")
 
-    def schema(self) -> Schema:
-        try:
-            return next(update for update in self._updates if isinstance(update, AddSchemaUpdate)).schema_
-        except StopIteration:
-            return self._table.schema()
-
     def commit_transaction(self) -> Table:
         """Commit the changes to the catalog.
 
         Returns:
             The table with the updates applied.
         """
-        # Strip the catalog name
         if len(self._updates) > 0:
             self._table._do_commit(  # pylint: disable=W0212
                 updates=self._updates,
@@ -913,7 +851,7 @@ class AssertLastAssignedPartitionId(TableRequirement):
     """The table's last assigned partition id must match the requirement's 
`last-assigned-partition-id`."""
 
     type: Literal["assert-last-assigned-partition-id"] = 
Field(default="assert-last-assigned-partition-id")
-    last_assigned_partition_id: int = Field(..., 
alias="last-assigned-partition-id")
+    last_assigned_partition_id: Optional[int] = Field(..., 
alias="last-assigned-partition-id")
 
     def validate(self, base_metadata: Optional[TableMetadata]) -> None:
         if base_metadata is None:
@@ -954,6 +892,9 @@ class AssertDefaultSortOrderId(TableRequirement):
             )
 
 
+UpdatesAndRequirements = Tuple[Tuple[TableUpdate, ...], Tuple[TableRequirement, ...]]
+
+
 class Namespace(IcebergRootModel[List[str]]):
     """Reference to one or more levels of a namespace."""
 
@@ -998,6 +939,11 @@ class Table:
         self.catalog = catalog
 
     def transaction(self) -> Transaction:
+        """Create a new transaction object to first stage the changes, and 
then commit them to the catalog.
+
+        Returns:
+            The transaction object
+        """
         return Transaction(self)
 
     def refresh(self) -> Table:
@@ -1080,17 +1026,6 @@ class Table:
     def last_sequence_number(self) -> int:
         return self.metadata.last_sequence_number
 
-    def next_sequence_number(self) -> int:
-        return self.last_sequence_number + 1 if self.metadata.format_version > 1 else INITIAL_SEQUENCE_NUMBER
-
-    def new_snapshot_id(self) -> int:
-        """Generate a new snapshot-id that's not in use."""
-        snapshot_id = _generate_snapshot_id()
-        while self.snapshot_by_id(snapshot_id) is not None:
-            snapshot_id = _generate_snapshot_id()
-
-        return snapshot_id
-
     def current_snapshot(self) -> Optional[Snapshot]:
         """Get the current snapshot for this table, or None if there is no 
current snapshot."""
         if self.metadata.current_snapshot_id is not None:
@@ -1114,18 +1049,19 @@ class Table:
    def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive: bool = True) -> UpdateSchema:
         """Create a new UpdateSchema to alter the columns of this table.
 
-        Returns:
-            A new UpdateSchema.
-        """
-        return UpdateSchema(self, allow_incompatible_changes=allow_incompatible_changes, case_sensitive=case_sensitive)
-
-    def update_snapshot(self) -> UpdateSnapshot:
-        """Create a new UpdateSnapshot to produce a new snapshot for the table.
+        Args:
+            allow_incompatible_changes: If changes are allowed that might break downstream consumers.
+            case_sensitive: If field names are case-sensitive.
 
         Returns:
-            A new UpdateSnapshot
+            A new UpdateSchema.
         """
-        return UpdateSnapshot(self)
+        return UpdateSchema(
+            transaction=Transaction(self, autocommit=True),
+            allow_incompatible_changes=allow_incompatible_changes,
+            case_sensitive=case_sensitive,
+            name_mapping=self.name_mapping(),
+        )
 
     def name_mapping(self) -> Optional[NameMapping]:
         """Return the table's field-id NameMapping."""
@@ -1154,12 +1090,15 @@ class Table:
 
         _check_schema(self.schema(), other_schema=df.schema)
 
-        with self.update_snapshot().fast_append() as update_snapshot:
-            # skip writing data files if the dataframe is empty
-            if df.shape[0] > 0:
-                data_files = _dataframe_to_data_files(self, write_uuid=update_snapshot.commit_uuid, df=df)
-                for data_file in data_files:
-                    update_snapshot.append_data_file(data_file)
+        with self.transaction() as txn:
+            with txn.update_snapshot().fast_append() as update_snapshot:
+                # skip writing data files if the dataframe is empty
+                if df.shape[0] > 0:
+                    data_files = _dataframe_to_data_files(
+                        table_metadata=self.metadata, write_uuid=update_snapshot.commit_uuid, df=df, io=self.io
+                    )
+                    for data_file in data_files:
+                        update_snapshot.append_data_file(data_file)
 
    def overwrite(self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE) -> None:
         """
@@ -1186,15 +1125,18 @@ class Table:
 
         _check_schema(self.schema(), other_schema=df.schema)
 
-        with self.update_snapshot().overwrite() as update_snapshot:
-            # skip writing data files if the dataframe is empty
-            if df.shape[0] > 0:
-                data_files = _dataframe_to_data_files(self, write_uuid=update_snapshot.commit_uuid, df=df)
-                for data_file in data_files:
-                    update_snapshot.append_data_file(data_file)
+        with self.transaction() as txn:
+            with txn.update_snapshot().overwrite() as update_snapshot:
+                # skip writing data files if the dataframe is empty
+                if df.shape[0] > 0:
+                    data_files = _dataframe_to_data_files(
+                        table_metadata=self.metadata, write_uuid=update_snapshot.commit_uuid, df=df, io=self.io
+                    )
+                    for data_file in data_files:
+                        update_snapshot.append_data_file(data_file)
 
     def update_spec(self, case_sensitive: bool = True) -> UpdateSpec:
-        return UpdateSpec(self, case_sensitive=case_sensitive)
+        return UpdateSpec(Transaction(self, autocommit=True), case_sensitive=case_sensitive)
 
     def refs(self) -> Dict[str, SnapshotRef]:
         """Return the snapshot references in the table."""
@@ -1613,8 +1555,31 @@ class Move:
     other_field_id: Optional[int] = None
 
 
-class UpdateSchema:
-    _table: Optional[Table]
+U = TypeVar('U')
+
+
+class UpdateTableMetadata(ABC, Generic[U]):
+    _transaction: Transaction
+
+    def __init__(self, transaction: Transaction) -> None:
+        self._transaction = transaction
+
+    @abstractmethod
+    def _commit(self) -> UpdatesAndRequirements: ...
+
+    def commit(self) -> None:
+        self._transaction._apply(*self._commit())
+
+    def __exit__(self, _: Any, value: Any, traceback: Any) -> None:
+        """Close and commit the change."""
+        self.commit()
+
+    def __enter__(self) -> U:
+        """Update the table."""
+        return self  # type: ignore
+
+
+class UpdateSchema(UpdateTableMetadata["UpdateSchema"]):
     _schema: Schema
     _last_column_id: itertools.count[int]
     _identifier_field_names: Set[str]
@@ -1629,27 +1594,25 @@ class UpdateSchema:
     _id_to_parent: Dict[int, str] = {}
     _allow_incompatible_changes: bool
     _case_sensitive: bool
-    _transaction: Optional[Transaction]
 
     def __init__(
         self,
-        table: Optional[Table],
-        transaction: Optional[Transaction] = None,
+        transaction: Transaction,
         allow_incompatible_changes: bool = False,
         case_sensitive: bool = True,
         schema: Optional[Schema] = None,
+        name_mapping: Optional[NameMapping] = None,
     ) -> None:
-        self._table = table
+        super().__init__(transaction)
 
         if isinstance(schema, Schema):
             self._schema = schema
             self._last_column_id = itertools.count(1 + schema.highest_field_id)
-        elif table is not None:
-            self._schema = table.schema()
-            self._last_column_id = itertools.count(1 + table.metadata.last_column_id)
         else:
-            raise ValueError("Either provide a table or a schema")
+            self._schema = self._transaction.table_metadata.schema()
+            self._last_column_id = itertools.count(1 + self._transaction.table_metadata.last_column_id)
 
+        self._name_mapping = name_mapping
         self._identifier_field_names = self._schema.identifier_field_names()
 
         self._adds = {}
@@ -1673,14 +1636,6 @@ class UpdateSchema:
         self._case_sensitive = case_sensitive
         self._transaction = transaction
 
-    def __exit__(self, _: Any, value: Any, traceback: Any) -> None:
-        """Close and commit the change."""
-        return self.commit()
-
-    def __enter__(self) -> UpdateSchema:
-        """Update the table."""
-        return self
-
     def case_sensitive(self, case_sensitive: bool) -> UpdateSchema:
         """Determine if the case of schema needs to be considered when 
comparing column names.
 
@@ -2069,38 +2024,36 @@ class UpdateSchema:
 
         return self
 
-    def commit(self) -> None:
+    def _commit(self) -> UpdatesAndRequirements:
         """Apply the pending changes and commit."""
-        if self._table is None:
-            raise ValueError("Requires a table to commit to")
-
         new_schema = self._apply()
 
-        existing_schema_id = next((schema.schema_id for schema in self._table.metadata.schemas if schema == new_schema), None)
+        existing_schema_id = next(
+            (schema.schema_id for schema in self._transaction.table_metadata.schemas if schema == new_schema), None
+        )
+
+        requirements: Tuple[TableRequirement, ...] = ()
+        updates: Tuple[TableUpdate, ...] = ()
 
         # Check if it is different current schema ID
-        if existing_schema_id != self._table.schema().schema_id:
-            requirements = (AssertCurrentSchemaId(current_schema_id=self._schema.schema_id),)
+        if existing_schema_id != self._schema.schema_id:
+            requirements += (AssertCurrentSchemaId(current_schema_id=self._schema.schema_id),)
             if existing_schema_id is None:
-                last_column_id = max(self._table.metadata.last_column_id, new_schema.highest_field_id)
-                updates = (
+                last_column_id = max(self._transaction.table_metadata.last_column_id, new_schema.highest_field_id)
+                updates += (
                     AddSchemaUpdate(schema=new_schema, last_column_id=last_column_id),
                     SetCurrentSchemaUpdate(schema_id=-1),
                 )
             else:
-                updates = (SetCurrentSchemaUpdate(schema_id=existing_schema_id),)  # type: ignore
+                updates += (SetCurrentSchemaUpdate(schema_id=existing_schema_id),)
 
-            if name_mapping := self._table.name_mapping():
+            if name_mapping := self._name_mapping:
                updated_name_mapping = update_mapping(name_mapping, self._updates, self._adds)
-                updates += (  # type: ignore
+                updates += (
                     SetPropertiesUpdate(updates={TableProperties.DEFAULT_NAME_MAPPING: updated_name_mapping.model_dump_json()}),
                 )
 
-            if self._transaction is not None:
-                self._transaction._append_updates(*updates)  # pylint: disable=W0212
-                self._transaction._append_requirements(*requirements)  # pylint: disable=W0212
-            else:
-                self._table._do_commit(updates=updates, requirements=requirements)  # pylint: disable=W0212
+        return updates, requirements
 
     def _apply(self) -> Schema:
         """Apply the pending changes to the original schema and returns the 
result.
@@ -2126,7 +2079,13 @@ class UpdateSchema:
 
             field_ids.add(field.field_id)
 
-        next_schema_id = 1 + (max(self._table.schemas().keys()) if self._table is not None else self._schema.schema_id)
+        if txn := self._transaction:
+            next_schema_id = 1 + (
+                max(schema.schema_id for schema in txn.table_metadata.schemas) if txn.table_metadata is not None else 0
+            )
+        else:
+            next_schema_id = 0
+
        return Schema(*struct.fields, schema_id=next_schema_id, identifier_field_ids=field_ids)
 
     def assign_new_column_id(self) -> int:
@@ -2456,20 +2415,6 @@ def _add_and_move_fields(
     return None if len(adds) == 0 else tuple(*fields, *adds)
 
 
-def _generate_snapshot_id() -> int:
-    """Generate a new Snapshot ID from a UUID.
-
-    Returns: An 64 bit long
-    """
-    rnd_uuid = uuid.uuid4()
-    snapshot_id = int.from_bytes(
-        bytes(lhs ^ rhs for lhs, rhs in zip(rnd_uuid.bytes[0:8], rnd_uuid.bytes[8:16])), byteorder='little', signed=True
-    )
-    snapshot_id = snapshot_id if snapshot_id >= 0 else snapshot_id * -1
-
-    return snapshot_id
-
-
 @dataclass(frozen=True)
 class WriteTask:
     write_uuid: uuid.UUID
@@ -2496,7 +2441,7 @@ def _generate_manifest_list_path(location: str, snapshot_id: int, attempt: int,
 
 
 def _dataframe_to_data_files(
-    table: Table, df: pa.Table, write_uuid: Optional[uuid.UUID] = None, file_schema: Optional[Schema] = None
+    table_metadata: TableMetadata, df: pa.Table, io: FileIO, write_uuid: Optional[uuid.UUID] = None
 ) -> Iterable[DataFile]:
     """Convert a PyArrow table into a DataFile.
 
@@ -2505,7 +2450,7 @@ def _dataframe_to_data_files(
     """
     from pyiceberg.io.pyarrow import write_file
 
-    if len(table.spec().fields) > 0:
+    if len([spec for spec in table_metadata.partition_specs if spec.spec_id != 0]) > 0:
         raise ValueError("Cannot write to partitioned tables")
 
     counter = itertools.count(0)
@@ -2513,41 +2458,33 @@ def _dataframe_to_data_files(
 
     # This is an iter, so we don't have to materialize everything every time
     # This will be more relevant when we start doing partitioned writes
-    yield from write_file(table, iter([WriteTask(write_uuid, next(counter), df)]), file_schema=file_schema)
+    yield from write_file(io=io, table_metadata=table_metadata, tasks=iter([WriteTask(write_uuid, next(counter), df)]))
 
 
-class _MergingSnapshotProducer:
+class _MergingSnapshotProducer(UpdateTableMetadata["_MergingSnapshotProducer"]):
     commit_uuid: uuid.UUID
     _operation: Operation
-    _table: Table
     _snapshot_id: int
     _parent_snapshot_id: Optional[int]
     _added_data_files: List[DataFile]
-    _transaction: Optional[Transaction]
 
     def __init__(
         self,
         operation: Operation,
-        table: Table,
+        transaction: Transaction,
+        io: FileIO,
         commit_uuid: Optional[uuid.UUID] = None,
-        transaction: Optional[Transaction] = None,
     ) -> None:
+        super().__init__(transaction)
         self.commit_uuid = commit_uuid or uuid.uuid4()
+        self._io = io
         self._operation = operation
-        self._table = table
-        self._snapshot_id = table.new_snapshot_id()
+        self._snapshot_id = self._transaction.table_metadata.new_snapshot_id()
         # Since we only support the main branch for now
-        self._parent_snapshot_id = snapshot.snapshot_id if (snapshot := self._table.current_snapshot()) else None
+        self._parent_snapshot_id = (
+            snapshot.snapshot_id if (snapshot := self._transaction.table_metadata.current_snapshot()) else None
+        )
         self._added_data_files = []
-        self._transaction = transaction
-
-    def __enter__(self) -> _MergingSnapshotProducer:
-        """Start a transaction to update the table."""
-        return self
-
-    def __exit__(self, _: Any, value: Any, traceback: Any) -> None:
-        """Close and commit the transaction."""
-        self.commit()
 
    def append_data_file(self, data_file: DataFile) -> _MergingSnapshotProducer:
         self._added_data_files.append(data_file)
@@ -2562,12 +2499,14 @@ class _MergingSnapshotProducer:
     def _manifests(self) -> List[ManifestFile]:
         def _write_added_manifest() -> List[ManifestFile]:
             if self._added_data_files:
-                output_file_location = _new_manifest_path(location=self._table.location(), num=0, commit_uuid=self.commit_uuid)
+                output_file_location = _new_manifest_path(
+                    location=self._transaction.table_metadata.location, num=0, commit_uuid=self.commit_uuid
+                )
                 with write_manifest(
-                    format_version=self._table.format_version,
-                    spec=self._table.spec(),
-                    schema=self._table.schema(),
-                    output_file=self._table.io.new_output(output_file_location),
+                    format_version=self._transaction.table_metadata.format_version,
+                    spec=self._transaction.table_metadata.spec(),
+                    schema=self._transaction.table_metadata.schema(),
+                    output_file=self._io.new_output(output_file_location),
                     snapshot_id=self._snapshot_id,
                 ) as writer:
                     for data_file in self._added_data_files:
@@ -2588,13 +2527,15 @@ class _MergingSnapshotProducer:
             # Check if we need to mark the files as deleted
             deleted_entries = self._deleted_entries()
             if len(deleted_entries) > 0:
-                output_file_location = _new_manifest_path(location=self._table.location(), num=1, commit_uuid=self.commit_uuid)
+                output_file_location = _new_manifest_path(
+                    location=self._transaction.table_metadata.location, num=1, commit_uuid=self.commit_uuid
+                )
 
                 with write_manifest(
-                    format_version=self._table.format_version,
-                    spec=self._table.spec(),
-                    schema=self._table.schema(),
-                    output_file=self._table.io.new_output(output_file_location),
+                    format_version=self._transaction.table_metadata.format_version,
+                    spec=self._transaction.table_metadata.spec(),
+                    schema=self._transaction.table_metadata.schema(),
+                    output_file=self._io.new_output(output_file_location),
                     snapshot_id=self._snapshot_id,
                 ) as writer:
                     for delete_entry in deleted_entries:
@@ -2617,7 +2558,11 @@ class _MergingSnapshotProducer:
         for data_file in self._added_data_files:
             ssc.add_file(data_file=data_file)
 
-        previous_snapshot = self._table.snapshot_by_id(self._parent_snapshot_id) if self._parent_snapshot_id is not None else None
+        previous_snapshot = (
+            self._transaction.table_metadata.snapshot_by_id(self._parent_snapshot_id)
+            if self._parent_snapshot_id is not None
+            else None
+        )
 
         return update_snapshot_summaries(
             summary=Summary(operation=self._operation, **ssc.build()),
@@ -2625,18 +2570,21 @@ class _MergingSnapshotProducer:
             truncate_full_table=self._operation == Operation.OVERWRITE,
         )
 
-    def commit(self) -> Snapshot:
+    def _commit(self) -> UpdatesAndRequirements:
         new_manifests = self._manifests()
-        next_sequence_number = self._table.next_sequence_number()
+        next_sequence_number = self._transaction.table_metadata.next_sequence_number()
 
         summary = self._summary()
 
         manifest_list_file_path = _generate_manifest_list_path(
-            location=self._table.location(), snapshot_id=self._snapshot_id, attempt=0, commit_uuid=self.commit_uuid
+            location=self._transaction.table_metadata.location,
+            snapshot_id=self._snapshot_id,
+            attempt=0,
+            commit_uuid=self.commit_uuid,
         )
         with write_manifest_list(
-            format_version=self._table.metadata.format_version,
-            output_file=self._table.io.new_output(manifest_list_file_path),
+            format_version=self._transaction.table_metadata.format_version,
+            output_file=self._io.new_output(manifest_list_file_path),
             snapshot_id=self._snapshot_id,
             parent_snapshot_id=self._parent_snapshot_id,
             sequence_number=next_sequence_number,
@@ -2649,22 +2597,21 @@ class _MergingSnapshotProducer:
             manifest_list=manifest_list_file_path,
             sequence_number=next_sequence_number,
             summary=summary,
-            schema_id=self._table.schema().schema_id,
+            schema_id=self._transaction.table_metadata.current_schema_id,
         )
 
-        if self._transaction is not None:
-            self._transaction.add_snapshot(snapshot=snapshot)
-            self._transaction.set_ref_snapshot(
-                snapshot_id=self._snapshot_id, parent_snapshot_id=self._parent_snapshot_id, ref_name="main", type="branch"
-            )
-        else:
-            with self._table.transaction() as tx:
-                tx.add_snapshot(snapshot=snapshot)
-                tx.set_ref_snapshot(
+        return (
+            (
+                AddSnapshotUpdate(snapshot=snapshot),
+                SetSnapshotRefUpdate(
                     snapshot_id=self._snapshot_id, parent_snapshot_id=self._parent_snapshot_id, ref_name="main", type="branch"
-                )
-
-        return snapshot
+                ),
+            ),
+            (
+                AssertTableUUID(uuid=self._transaction.table_metadata.table_uuid),
+                AssertRefSnapshotId(snapshot_id=self._parent_snapshot_id, ref="main"),
+            ),
+        )
 
 
 class FastAppendFiles(_MergingSnapshotProducer):
@@ -2677,12 +2624,12 @@ class FastAppendFiles(_MergingSnapshotProducer):
         existing_manifests = []
 
         if self._parent_snapshot_id is not None:
-            previous_snapshot = self._table.snapshot_by_id(self._parent_snapshot_id)
+            previous_snapshot = self._transaction.table_metadata.snapshot_by_id(self._parent_snapshot_id)
 
             if previous_snapshot is None:
                 raise ValueError(f"Snapshot could not be found: 
{self._parent_snapshot_id}")
 
-            for manifest in previous_snapshot.manifests(io=self._table.io):
+            for manifest in previous_snapshot.manifests(io=self._io):
                if manifest.has_added_files() or manifest.has_existing_files() or manifest.added_snapshot_id == self._snapshot_id:
                     existing_manifests.append(manifest)
 
@@ -2713,7 +2660,7 @@ class OverwriteFiles(_MergingSnapshotProducer):
         which entries are affected.
         """
         if self._parent_snapshot_id is not None:
-            previous_snapshot = self._table.snapshot_by_id(self._parent_snapshot_id)
+            previous_snapshot = self._transaction.table_metadata.snapshot_by_id(self._parent_snapshot_id)
             if previous_snapshot is None:
                # This should never happen since you cannot overwrite an empty table
                raise ValueError(f"Could not find the previous snapshot: {self._parent_snapshot_id}")
@@ -2729,37 +2676,39 @@ class OverwriteFiles(_MergingSnapshotProducer):
                         file_sequence_number=entry.file_sequence_number,
                         data_file=entry.data_file,
                     )
-                    for entry in manifest.fetch_manifest_entry(self._table.io, discard_deleted=True)
+                    for entry in manifest.fetch_manifest_entry(self._io, discard_deleted=True)
                     if entry.data_file.content == DataFileContent.DATA
                 ]
 
-            list_of_entries = executor.map(_get_entries, previous_snapshot.manifests(self._table.io))
+            list_of_entries = executor.map(_get_entries, previous_snapshot.manifests(self._io))
             return list(chain(*list_of_entries))
         else:
             return []
 
 
 class UpdateSnapshot:
-    _table: Table
-    _transaction: Optional[Transaction]
+    _transaction: Transaction
+    _io: FileIO
 
-    def __init__(self, table: Table, transaction: Optional[Transaction] = None) -> None:
-        self._table = table
+    def __init__(self, transaction: Transaction, io: FileIO) -> None:
         self._transaction = transaction
+        self._io = io
 
     def fast_append(self) -> FastAppendFiles:
-        return FastAppendFiles(table=self._table, operation=Operation.APPEND, transaction=self._transaction)
+        return FastAppendFiles(operation=Operation.APPEND, transaction=self._transaction, io=self._io)
 
     def overwrite(self) -> OverwriteFiles:
         return OverwriteFiles(
-            table=self._table,
-            operation=Operation.OVERWRITE if self._table.current_snapshot() is not None else Operation.APPEND,
+            operation=Operation.OVERWRITE
+            if self._transaction.table_metadata.current_snapshot() is not None
+            else Operation.APPEND,
             transaction=self._transaction,
+            io=self._io,
         )
 
 
-class UpdateSpec:
-    _table: Table
+class UpdateSpec(UpdateTableMetadata["UpdateSpec"]):
+    _transaction: Transaction
     _name_to_field: Dict[str, PartitionField] = {}
     _name_to_added_field: Dict[str, PartitionField] = {}
     _transform_to_field: Dict[Tuple[int, str], PartitionField] = {}
@@ -2770,17 +2719,18 @@ class UpdateSpec:
     _adds: List[PartitionField]
     _deletes: Set[int]
     _last_assigned_partition_id: int
-    _transaction: Optional[Transaction]
 
-    def __init__(self, table: Table, transaction: Optional[Transaction] = None, case_sensitive: bool = True) -> None:
-        self._table = table
-        self._name_to_field = {field.name: field for field in table.spec().fields}
+    def __init__(self, transaction: Transaction, case_sensitive: bool = True) -> None:
+        super().__init__(transaction)
+        self._name_to_field = {field.name: field for field in transaction.table_metadata.spec().fields}
         self._name_to_added_field = {}
-        self._transform_to_field = {(field.source_id, repr(field.transform)): field for field in table.spec().fields}
+        self._transform_to_field = {
+            (field.source_id, repr(field.transform)): field for field in transaction.table_metadata.spec().fields
+        }
         self._transform_to_added_field = {}
         self._adds = []
         self._deletes = set()
-        self._last_assigned_partition_id = table.last_partition_id()
+        self._last_assigned_partition_id = transaction.table_metadata.last_partition_id or PARTITION_FIELD_ID_START - 1
         self._renames = {}
         self._transaction = transaction
         self._case_sensitive = case_sensitive
@@ -2793,7 +2743,7 @@ class UpdateSpec:
         partition_field_name: Optional[str] = None,
     ) -> UpdateSpec:
         ref = Reference(source_column_name)
-        bound_ref = ref.bind(self._table.schema(), self._case_sensitive)
+        bound_ref = ref.bind(self._transaction.table_metadata.schema(), self._case_sensitive)
         # verify transform can actually bind it
         output_type = bound_ref.field.field_type
         if not transform.can_transform(output_type):
@@ -2864,31 +2814,24 @@ class UpdateSpec:
         self._renames[name] = new_name
         return self
 
-    def commit(self) -> None:
+    def _commit(self) -> UpdatesAndRequirements:
         new_spec = self._apply()
-        if self._table.metadata.default_spec_id != new_spec.spec_id:
-            if new_spec.spec_id not in self._table.specs():
-                updates = [AddPartitionSpecUpdate(spec=new_spec), SetDefaultSpecUpdate(spec_id=-1)]
-            else:
-                updates = [SetDefaultSpecUpdate(spec_id=new_spec.spec_id)]
-
-            required_last_assigned_partitioned_id = self._table.last_partition_id()
-            requirements = [AssertLastAssignedPartitionId(last_assigned_partition_id=required_last_assigned_partitioned_id)]
+        updates: Tuple[TableUpdate, ...] = ()
+        requirements: Tuple[TableRequirement, ...] = ()
 
-            if self._transaction is not None:
-                self._transaction._append_updates(*updates)  # pylint: disable=W0212
-                self._transaction._append_requirements(*requirements)  # pylint: disable=W0212
+        if self._transaction.table_metadata.default_spec_id != new_spec.spec_id:
+            if new_spec.spec_id not in self._transaction.table_metadata.specs():
+                updates = (
+                    AddPartitionSpecUpdate(spec=new_spec),
+                    SetDefaultSpecUpdate(spec_id=-1),
+                )
             else:
-                requirements.append(AssertDefaultSpecId(default_spec_id=self._table.spec().spec_id))
-                self._table._do_commit(updates=tuple(updates), requirements=tuple(requirements))  # pylint: disable=W0212
+                updates = (SetDefaultSpecUpdate(spec_id=new_spec.spec_id),)
 
-    def __exit__(self, _: Any, value: Any, traceback: Any) -> None:
-        """Close and commit the change."""
-        return self.commit()
+            required_last_assigned_partitioned_id = self._transaction.table_metadata.last_partition_id
+            requirements = (AssertLastAssignedPartitionId(last_assigned_partition_id=required_last_assigned_partitioned_id),)
 
-    def __enter__(self) -> UpdateSpec:
-        """Update the table."""
-        return self
+        return updates, requirements
 
     def _apply(self) -> PartitionSpec:
        def _check_and_add_partition_name(schema: Schema, name: str, source_id: int, partition_names: Set[str]) -> None:
@@ -2915,27 +2858,47 @@ class UpdateSpec:
 
         partition_fields = []
         partition_names: Set[str] = set()
-        for field in self._table.spec().fields:
+        for field in self._transaction.table_metadata.spec().fields:
             if field.field_id not in self._deletes:
                 renamed = self._renames.get(field.name)
                 if renamed:
                     new_field = _add_new_field(
-                        self._table.schema(), field.source_id, field.field_id, renamed, field.transform, partition_names
+                        self._transaction.table_metadata.schema(),
+                        field.source_id,
+                        field.field_id,
+                        renamed,
+                        field.transform,
+                        partition_names,
                     )
                 else:
                     new_field = _add_new_field(
-                        self._table.schema(), field.source_id, field.field_id, field.name, field.transform, partition_names
+                        self._transaction.table_metadata.schema(),
+                        field.source_id,
+                        field.field_id,
+                        field.name,
+                        field.transform,
+                        partition_names,
                     )
                 partition_fields.append(new_field)
-            elif self._table.format_version == 1:
+            elif self._transaction.table_metadata.format_version == 1:
                 renamed = self._renames.get(field.name)
                 if renamed:
                     new_field = _add_new_field(
-                        self._table.schema(), field.source_id, field.field_id, renamed, VoidTransform(), partition_names
+                        self._transaction.table_metadata.schema(),
+                        field.source_id,
+                        field.field_id,
+                        renamed,
+                        VoidTransform(),
+                        partition_names,
                     )
                 else:
                     new_field = _add_new_field(
-                        self._table.schema(), field.source_id, field.field_id, field.name, VoidTransform(), partition_names
+                        self._transaction.table_metadata.schema(),
+                        field.source_id,
+                        field.field_id,
+                        field.name,
+                        VoidTransform(),
+                        partition_names,
                     )
 
                 partition_fields.append(new_field)
@@ -2952,7 +2915,7 @@ class UpdateSpec:
         # Reuse spec id or create a new one.
         new_spec = PartitionSpec(*partition_fields)
         new_spec_id = INITIAL_PARTITION_SPEC_ID
-        for spec in self._table.specs().values():
+        for spec in self._transaction.table_metadata.specs().values():
             if new_spec.compatible_with(spec):
                 new_spec_id = spec.spec_id
                 break
@@ -2961,10 +2924,10 @@ class UpdateSpec:
         return PartitionSpec(*partition_fields, spec_id=new_spec_id)
 
    def _partition_field(self, transform_key: Tuple[int, Transform[Any, Any]], name: Optional[str]) -> PartitionField:
-        if self._table.metadata.format_version == 2:
+        if self._transaction.table_metadata.format_version == 2:
             source_id, transform = transform_key
             historical_fields = []
-            for spec in self._table.specs().values():
+            for spec in self._transaction.table_metadata.specs().values():
                 for field in spec.fields:
                    historical_fields.append((field.source_id, field.field_id, repr(field.transform), field.name))
 
@@ -2976,7 +2939,7 @@ class UpdateSpec:
         new_field_id = self._new_field_id()
         if name is None:
             tmp_field = PartitionField(transform_key[0], new_field_id, 
transform_key[1], 'unassigned_field_name')
-            name = _visit_partition_field(self._table.schema(), tmp_field, _PartitionNameGenerator())
+            name = _visit_partition_field(self._transaction.table_metadata.schema(), tmp_field, _PartitionNameGenerator())
        return PartitionField(transform_key[0], new_field_id, transform_key[1], name)
 
     def _new_field_id(self) -> int:
diff --git a/pyiceberg/table/metadata.py b/pyiceberg/table/metadata.py
index 2fd8b13a..c7169151 100644
--- a/pyiceberg/table/metadata.py
+++ b/pyiceberg/table/metadata.py
@@ -226,11 +226,54 @@ class TableMetadataCommonFields(IcebergBaseModel):
         """Get the schema by schema_id."""
        return next((schema for schema in self.schemas if schema.schema_id == schema_id), None)
 
+    def schema(self) -> Schema:
+        """Return the schema for this table."""
+        return next(schema for schema in self.schemas if schema.schema_id == self.current_schema_id)
+
+    def spec(self) -> PartitionSpec:
+        """Return the partition spec of this table."""
+        return next(spec for spec in self.partition_specs if spec.spec_id == self.default_spec_id)
+
+    def specs(self) -> Dict[int, PartitionSpec]:
+        """Return a dict the partition specs this table."""
+        return {spec.spec_id: spec for spec in self.partition_specs}
+
+    def new_snapshot_id(self) -> int:
+        """Generate a new snapshot-id that's not in use."""
+        snapshot_id = _generate_snapshot_id()
+        while self.snapshot_by_id(snapshot_id) is not None:
+            snapshot_id = _generate_snapshot_id()
+
+        return snapshot_id
+
+    def current_snapshot(self) -> Optional[Snapshot]:
+        """Get the current snapshot for this table, or None if there is no 
current snapshot."""
+        if self.current_snapshot_id is not None:
+            return self.snapshot_by_id(self.current_snapshot_id)
+        return None
+
+    def next_sequence_number(self) -> int:
+        return self.last_sequence_number + 1 if self.format_version > 1 else INITIAL_SEQUENCE_NUMBER
+
     def sort_order_by_id(self, sort_order_id: int) -> Optional[SortOrder]:
         """Get the sort order by sort_order_id."""
        return next((sort_order for sort_order in self.sort_orders if sort_order.order_id == sort_order_id), None)
 
 
+def _generate_snapshot_id() -> int:
+    """Generate a new Snapshot ID from a UUID.
+
+    Returns: A 64-bit long
+    """
+    rnd_uuid = uuid.uuid4()
+    snapshot_id = int.from_bytes(
+        bytes(lhs ^ rhs for lhs, rhs in zip(rnd_uuid.bytes[0:8], rnd_uuid.bytes[8:16])), byteorder='little', signed=True
+    )
+    snapshot_id = snapshot_id if snapshot_id >= 0 else snapshot_id * -1
+
+    return snapshot_id
+
+
 class TableMetadataV1(TableMetadataCommonFields, IcebergBaseModel):
     """Represents version 1 of the Table Metadata.
 
diff --git a/tests/catalog/test_sql.py b/tests/catalog/test_sql.py
index 2b1167f1..9f4d4af4 100644
--- a/tests/catalog/test_sql.py
+++ b/tests/catalog/test_sql.py
@@ -39,6 +39,7 @@ from pyiceberg.io import FSSPEC_FILE_IO, PY_IO_IMPL
 from pyiceberg.io.pyarrow import schema_to_pyarrow
 from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC
 from pyiceberg.schema import Schema
+from pyiceberg.table import _dataframe_to_data_files
 from pyiceberg.table.snapshots import Operation
 from pyiceberg.table.sorting import (
     NullOrder,
@@ -863,3 +864,54 @@ def test_concurrent_commit_table(catalog: SqlCatalog, table_schema_simple: Schem
         # This one should fail since it already has been updated
         with table_b.update_schema() as update:
             update.add_column(path="c", field_type=IntegerType())
+
+
[email protected](
+    'catalog',
+    [
+        lazy_fixture('catalog_memory'),
+        lazy_fixture('catalog_sqlite'),
+        lazy_fixture('catalog_sqlite_without_rowcount'),
+    ],
+)
[email protected]("format_version", [1, 2])
+def test_write_and_evolve(catalog: SqlCatalog, format_version: int) -> None:
+    identifier = f"default.arrow_write_data_and_evolve_schema_v{format_version}"
+
+    try:
+        catalog.create_namespace("default")
+    except NamespaceAlreadyExistsError:
+        pass
+
+    try:
+        catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    pa_table = pa.Table.from_pydict(
+        {
+            'foo': ['a', None, 'z'],
+        },
+        schema=pa.schema([pa.field("foo", pa.string(), nullable=True)]),
+    )
+
+    tbl = catalog.create_table(identifier=identifier, schema=pa_table.schema, properties={"format-version": str(format_version)})
+
+    pa_table_with_column = pa.Table.from_pydict(
+        {
+            'foo': ['a', None, 'z'],
+            'bar': [19, None, 25],
+        },
+        schema=pa.schema([
+            pa.field("foo", pa.string(), nullable=True),
+            pa.field("bar", pa.int32(), nullable=True),
+        ]),
+    )
+
+    with tbl.transaction() as txn:
+        with txn.update_schema() as schema_txn:
+            schema_txn.union_by_name(pa_table_with_column.schema)
+
+        with txn.update_snapshot().fast_append() as snapshot_update:
+            for data_file in _dataframe_to_data_files(table_metadata=txn.table_metadata, df=pa_table_with_column, io=tbl.io):
+                snapshot_update.append_data_file(data_file)
diff --git a/tests/integration/test_rest_schema.py b/tests/integration/test_rest_schema.py
index aae07cab..17fb3380 100644
--- a/tests/integration/test_rest_schema.py
+++ b/tests/integration/test_rest_schema.py
@@ -84,7 +84,7 @@ def _create_table_with_schema(catalog: Catalog, schema: Schema) -> Table:
 @pytest.mark.integration
def test_add_already_exists(catalog: Catalog, table_schema_nested: Schema) -> None:
     table = _create_table_with_schema(catalog, table_schema_nested)
-    update = UpdateSchema(table)
+    update = table.update_schema()
 
     with pytest.raises(ValueError) as exc_info:
         update.add_column("foo", IntegerType())
@@ -98,7 +98,7 @@ def test_add_already_exists(catalog: Catalog, table_schema_nested: Schema) -> No
 @pytest.mark.integration
def test_add_to_non_struct_type(catalog: Catalog, table_schema_simple: Schema) -> None:
     table = _create_table_with_schema(catalog, table_schema_simple)
-    update = UpdateSchema(table)
+    update = table.update_schema()
     with pytest.raises(ValueError) as exc_info:
         update.add_column(path=("foo", "lat"), field_type=IntegerType())
     assert "Cannot add column 'lat' to non-struct type: foo" in 
str(exc_info.value)
@@ -1066,13 +1066,13 @@ def test_add_nested_list_of_structs(catalog: Catalog) -> None:
 def test_add_required_column(catalog: Catalog) -> None:
     schema_ = Schema(NestedField(field_id=1, name="a", 
field_type=BooleanType(), required=False))
     table = _create_table_with_schema(catalog, schema_)
-    update = UpdateSchema(table)
+    update = table.update_schema()
     with pytest.raises(ValueError) as exc_info:
         update.add_column(path="data", field_type=IntegerType(), required=True)
     assert "Incompatible change: cannot add required column: data" in 
str(exc_info.value)
 
     new_schema = (
-        UpdateSchema(table, allow_incompatible_changes=True)  # pylint: disable=W0212
+        UpdateSchema(transaction=table.transaction(), allow_incompatible_changes=True)
         .add_column(path="data", field_type=IntegerType(), required=True)
         ._apply()
     )
@@ -1088,12 +1088,13 @@ def test_add_required_column_case_insensitive(catalog: Catalog) -> None:
     table = _create_table_with_schema(catalog, schema_)
 
     with pytest.raises(ValueError) as exc_info:
-        with UpdateSchema(table, allow_incompatible_changes=True) as update:
-            update.case_sensitive(False).add_column(path="ID", 
field_type=IntegerType(), required=True)
+        with table.transaction() as txn:
+            with txn.update_schema(allow_incompatible_changes=True) as update:
+                update.case_sensitive(False).add_column(path="ID", 
field_type=IntegerType(), required=True)
     assert "already exists: ID" in str(exc_info.value)
 
     new_schema = (
-        UpdateSchema(table, allow_incompatible_changes=True)  # pylint: disable=W0212
+        UpdateSchema(transaction=table.transaction(), allow_incompatible_changes=True)
         .add_column(path="ID", field_type=IntegerType(), required=True)
         ._apply()
     )
@@ -1264,7 +1265,7 @@ def test_mixed_changes(catalog: Catalog) -> None:
 @pytest.mark.integration
def test_ambiguous_column(catalog: Catalog, table_schema_nested: Schema) -> None:
     table = _create_table_with_schema(catalog, table_schema_nested)
-    update = UpdateSchema(table)
+    update = UpdateSchema(transaction=table.transaction())
 
     with pytest.raises(ValueError) as exc_info:
         update.add_column(path="location.latitude", field_type=IntegerType())
@@ -2507,16 +2508,14 @@ def test_two_add_schemas_in_a_single_transaction(catalog: Catalog) -> None:
         ),
     )
 
-    with pytest.raises(ValueError) as exc_info:
+    with pytest.raises(CommitFailedException) as exc_info:
         with tbl.transaction() as tr:
             with tr.update_schema() as update:
                 update.add_column("bar", field_type=StringType())
             with tr.update_schema() as update:
                 update.add_column("baz", field_type=StringType())
 
-    assert "Updates in a single commit need to be unique, duplicate: <class 
'pyiceberg.table.AddSchemaUpdate'>" in str(
-        exc_info.value
-    )
+    assert "CommitFailedException: Requirement failed: current schema changed: 
expected id 1 != 0" in str(exc_info.value)
 
 
 @pytest.mark.integration
diff --git a/tests/integration/test_writes.py b/tests/integration/test_writes.py
index a16ba48a..388e566b 100644
--- a/tests/integration/test_writes.py
+++ b/tests/integration/test_writes.py
@@ -652,5 +652,5 @@ def test_write_and_evolve(session_catalog: Catalog, format_version: int) -> None
             schema_txn.union_by_name(pa_table_with_column.schema)
 
         with txn.update_snapshot().fast_append() as snapshot_update:
-            for data_file in _dataframe_to_data_files(table=tbl, df=pa_table_with_column, file_schema=txn.schema()):
+            for data_file in _dataframe_to_data_files(table_metadata=txn.table_metadata, df=pa_table_with_column, io=tbl.io):
                 snapshot_update.append_data_file(data_file)
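
In the hunk above, _dataframe_to_data_files is called with the transaction's
table_metadata and an explicit FileIO (io=) instead of the Table and a
file_schema, so data files are written against the metadata as evolved inside
the transaction. A sketch of the updated call, assuming an existing table tbl,
a pyarrow table arrow_table, and that _dataframe_to_data_files is imported
from pyiceberg.table as the test module does:

    with tbl.transaction() as txn:
        with txn.update_snapshot().fast_append() as snapshot_update:
            # Write arrow_table against the transaction's current metadata:
            for data_file in _dataframe_to_data_files(
                table_metadata=txn.table_metadata, df=arrow_table, io=tbl.io
            ):
                snapshot_update.append_data_file(data_file)
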
diff --git a/tests/table/test_init.py b/tests/table/test_init.py
index 39aa72f8..e6407b60 100644
--- a/tests/table/test_init.py
+++ b/tests/table/test_init.py
@@ -62,12 +62,11 @@ from pyiceberg.table import (
     UpdateSchema,
     _apply_table_update,
     _check_schema,
-    _generate_snapshot_id,
     _match_deletes_to_data_file,
     _TableMetadataUpdateContext,
     update_table_metadata,
 )
-from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER, TableMetadataUtil, TableMetadataV2
+from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER, TableMetadataUtil, TableMetadataV2, _generate_snapshot_id
 from pyiceberg.table.snapshots import (
     Operation,
     Snapshot,
@@ -435,7 +434,7 @@ def test_serialize_set_properties_updates() -> None:
 
 
 def test_add_column(table_v2: Table) -> None:
-    update = UpdateSchema(table_v2)
+    update = UpdateSchema(transaction=table_v2.transaction())
     update.add_column(path="b", field_type=IntegerType())
     apply_schema: Schema = update._apply()  # pylint: disable=W0212
     assert len(apply_schema.fields) == 4
@@ -469,7 +468,7 @@ def test_add_primitive_type_column(table_v2: Table) -> None:
 
     for name, type_ in primitive_type.items():
         field_name = f"new_column_{name}"
-        update = UpdateSchema(table_v2)
+        update = UpdateSchema(transaction=table_v2.transaction())
         update.add_column(path=field_name, field_type=type_, doc=f"new_column_{name}")
         new_schema = update._apply()  # pylint: disable=W0212
 
@@ -481,7 +480,7 @@ def test_add_primitive_type_column(table_v2: Table) -> None:
 def test_add_nested_type_column(table_v2: Table) -> None:
     # add struct type column
     field_name = "new_column_struct"
-    update = UpdateSchema(table_v2)
+    update = UpdateSchema(transaction=table_v2.transaction())
     struct_ = StructType(
         NestedField(1, "lat", DoubleType()),
         NestedField(2, "long", DoubleType()),
@@ -499,7 +498,7 @@ def test_add_nested_type_column(table_v2: Table) -> None:
 def test_add_nested_map_type_column(table_v2: Table) -> None:
     # add map type column
     field_name = "new_column_map"
-    update = UpdateSchema(table_v2)
+    update = UpdateSchema(transaction=table_v2.transaction())
     map_ = MapType(1, StringType(), 2, IntegerType(), False)
     update.add_column(path=field_name, field_type=map_)
     new_schema = update._apply()  # pylint: disable=W0212
@@ -511,7 +510,7 @@ def test_add_nested_map_type_column(table_v2: Table) -> None:
 def test_add_nested_list_type_column(table_v2: Table) -> None:
     # add list type column
     field_name = "new_column_list"
-    update = UpdateSchema(table_v2)
+    update = UpdateSchema(transaction=table_v2.transaction())
     list_ = ListType(
         element_id=101,
         element_type=StructType(
@@ -806,7 +805,7 @@ def test_metadata_isolation_from_illegal_updates(table_v1: Table) -> None:
 
 def test_generate_snapshot_id(table_v2: Table) -> None:
     assert isinstance(_generate_snapshot_id(), int)
-    assert isinstance(table_v2.new_snapshot_id(), int)
+    assert isinstance(table_v2.metadata.new_snapshot_id(), int)
 
 
 def test_assert_create(table_v2: Table) -> None:
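
Two relocations appear in these hunks: _generate_snapshot_id moved from
pyiceberg.table to pyiceberg.table.metadata, and snapshot-id generation on an
existing table now hangs off its metadata object. A short sketch, assuming a
loaded table:

    from pyiceberg.table.metadata import _generate_snapshot_id

    assert isinstance(_generate_snapshot_id(), int)            # module-level helper
    assert isinstance(table.metadata.new_snapshot_id(), int)   # was table.new_snapshot_id()
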
diff --git a/tests/test_schema.py b/tests/test_schema.py
index cfee6e7f..6394b72b 100644
--- a/tests/test_schema.py
+++ b/tests/test_schema.py
@@ -928,7 +928,7 @@ def primitive_fields() -> List[NestedField]:
 def test_add_top_level_primitives(primitive_fields: NestedField) -> None:
     for primitive_field in primitive_fields:
         new_schema = Schema(primitive_field)
-        applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply()
+        applied = UpdateSchema(transaction=None, schema=Schema()).union_by_name(new_schema)._apply()  # type: ignore
         assert applied == new_schema
 
 
@@ -942,7 +942,7 @@ def test_add_top_level_list_of_primitives(primitive_fields: NestedField) -> None
                 required=False,
             )
         )
-        applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply()
+        applied = UpdateSchema(transaction=None, schema=Schema()).union_by_name(new_schema)._apply()  # type: ignore
         assert applied.as_struct() == new_schema.as_struct()
 
 
@@ -958,7 +958,7 @@ def test_add_top_level_map_of_primitives(primitive_fields: NestedField) -> None:
                 required=False,
             )
         )
-        applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply()
+        applied = UpdateSchema(transaction=None, schema=Schema()).union_by_name(new_schema)._apply()  # type: ignore
         assert applied.as_struct() == new_schema.as_struct()
 
 
@@ -972,7 +972,7 @@ def test_add_top_struct_of_primitives(primitive_fields: NestedField) -> None:
                 required=False,
             )
         )
-        applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply()
+        applied = UpdateSchema(transaction=None, schema=Schema()).union_by_name(new_schema)._apply()  # type: ignore
         assert applied.as_struct() == new_schema.as_struct()
 
 
@@ -987,7 +987,7 @@ def test_add_nested_primitive(primitive_fields: NestedField) -> None:
                 required=False,
             )
         )
-        applied = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply()
+        applied = UpdateSchema(None, None, schema=current_schema).union_by_name(new_schema)._apply()  # type: ignore
         assert applied.as_struct() == new_schema.as_struct()
 
 
@@ -1007,7 +1007,7 @@ def test_add_nested_primitives(primitive_fields: NestedField) -> None:
             field_id=1, name="aStruct", 
field_type=StructType(*_primitive_fields(TEST_PRIMITIVE_TYPES, 2)), 
required=False
         )
     )
-    applied = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply()
+    applied = UpdateSchema(transaction=None, schema=current_schema).union_by_name(new_schema)._apply()  # type: ignore
     assert applied.as_struct() == new_schema.as_struct()
 
 
@@ -1048,7 +1048,7 @@ def test_add_nested_lists(primitive_fields: NestedField) -> None:
             required=False,
         )
     )
-    applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply()
+    applied = UpdateSchema(transaction=None, schema=Schema()).union_by_name(new_schema)._apply()  # type: ignore
     assert applied.as_struct() == new_schema.as_struct()
 
 
@@ -1098,7 +1098,7 @@ def test_add_nested_struct(primitive_fields: NestedField) -> None:
             required=False,
         )
     )
-    applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply()
+    applied = UpdateSchema(transaction=None, schema=Schema()).union_by_name(new_schema)._apply()  # type: ignore
     assert applied.as_struct() == new_schema.as_struct()
 
 
@@ -1141,7 +1141,7 @@ def test_add_nested_maps(primitive_fields: NestedField) -> None:
             required=False,
         )
     )
-    applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply()
+    applied = UpdateSchema(transaction=None, schema=Schema()).union_by_name(new_schema)._apply()  # type: ignore
     assert applied.as_struct() == new_schema.as_struct()
 
 
@@ -1164,7 +1164,7 @@ def test_detect_invalid_top_level_list() -> None:
     )
 
     with pytest.raises(ValidationError, match="Cannot change column type: aList.element: string -> double"):
-        _ = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply()
+        _ = UpdateSchema(transaction=None, schema=current_schema).union_by_name(new_schema)._apply()  # type: ignore
 
 
 def test_detect_invalid_top_level_maps() -> None:
@@ -1186,14 +1186,14 @@ def test_detect_invalid_top_level_maps() -> None:
     )
 
     with pytest.raises(ValidationError, match="Cannot change column type: aMap.key: string -> uuid"):
-        _ = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply()
+        _ = UpdateSchema(transaction=None, schema=current_schema).union_by_name(new_schema)._apply()  # type: ignore
 
 
 def test_promote_float_to_double() -> None:
     current_schema = Schema(NestedField(field_id=1, name="aCol", field_type=FloatType(), required=False))
     new_schema = Schema(NestedField(field_id=1, name="aCol", field_type=DoubleType(), required=False))
 
-    applied = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply()
+    applied = UpdateSchema(transaction=None, schema=current_schema).union_by_name(new_schema)._apply()  # type: ignore
 
     assert applied.as_struct() == new_schema.as_struct()
     assert len(applied.fields) == 1
@@ -1205,7 +1205,7 @@ def test_detect_invalid_promotion_double_to_float() -> None:
     new_schema = Schema(NestedField(field_id=1, name="aCol", field_type=FloatType(), required=False))
 
     with pytest.raises(ValidationError, match="Cannot change column type: aCol: double -> float"):
-        _ = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply()
+        _ = UpdateSchema(transaction=None, schema=current_schema).union_by_name(new_schema)._apply()  # type: ignore
 
 
 # decimal(P,S) Fixed-point decimal; precision P, scale S -> Scale is fixed [1],
@@ -1214,7 +1214,7 @@ def test_type_promote_decimal_to_fixed_scale_with_wider_precision() -> None:
     current_schema = Schema(NestedField(field_id=1, name="aCol", field_type=DecimalType(precision=20, scale=1), required=False))
     new_schema = Schema(NestedField(field_id=1, name="aCol", field_type=DecimalType(precision=22, scale=1), required=False))
 
-    applied = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply()
+    applied = UpdateSchema(transaction=None, schema=current_schema).union_by_name(new_schema)._apply()  # type: ignore
 
     assert applied.as_struct() == new_schema.as_struct()
     assert len(applied.fields) == 1
@@ -1282,7 +1282,7 @@ def test_add_nested_structs(primitive_fields: NestedField) -> None:
             required=False,
         )
     )
-    applied = UpdateSchema(None, schema=schema).union_by_name(new_schema)._apply()
+    applied = UpdateSchema(transaction=None, schema=schema).union_by_name(new_schema)._apply()  # type: ignore
 
     expected = Schema(
         NestedField(
@@ -1322,7 +1322,7 @@ def test_replace_list_with_primitive() -> None:
     new_schema = Schema(NestedField(field_id=1, name="aCol", field_type=StringType()))
 
     with pytest.raises(ValidationError, match="Cannot change column type: list<string> is not a primitive"):
-        _ = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply()
+        _ = UpdateSchema(transaction=None, schema=current_schema).union_by_name(new_schema)._apply()  # type: ignore
 
 
 def test_mirrored_schemas() -> None:
@@ -1345,7 +1345,7 @@ def test_mirrored_schemas() -> None:
         NestedField(9, "string6", StringType(), required=False),
     )
 
-    applied = UpdateSchema(None, schema=current_schema).union_by_name(mirrored_schema)._apply()
+    applied = UpdateSchema(transaction=None, schema=current_schema).union_by_name(mirrored_schema)._apply()  # type: ignore
 
     assert applied.as_struct() == current_schema.as_struct()
 
@@ -1397,7 +1397,7 @@ def test_add_new_top_level_struct() -> None:
         ),
     )
 
-    applied = UpdateSchema(None, schema=current_schema).union_by_name(observed_schema)._apply()
+    applied = UpdateSchema(transaction=None, schema=current_schema).union_by_name(observed_schema)._apply()  # type: ignore
 
     assert applied.as_struct() == observed_schema.as_struct()
 
@@ -1476,7 +1476,7 @@ def test_append_nested_struct() -> None:
         )
     )
 
-    applied = UpdateSchema(None, schema=current_schema).union_by_name(observed_schema)._apply()
+    applied = UpdateSchema(transaction=None, schema=current_schema).union_by_name(observed_schema)._apply()  # type: ignore
 
     assert applied.as_struct() == observed_schema.as_struct()
 
@@ -1541,7 +1541,7 @@ def test_append_nested_lists() -> None:
             required=False,
         )
     )
-    union = UpdateSchema(None, schema=current_schema).union_by_name(observed_schema)._apply()
+    union = UpdateSchema(transaction=None, schema=current_schema).union_by_name(observed_schema)._apply()  # type: ignore
 
     expected = Schema(
         NestedField(
@@ -1591,7 +1591,7 @@ def test_union_with_pa_schema(primitive_fields: NestedField) -> None:
         pa.field("baz", pa.bool_(), nullable=True),
     ])
 
-    new_schema = UpdateSchema(None, schema=base_schema).union_by_name(pa_schema)._apply()
+    new_schema = UpdateSchema(transaction=None, schema=base_schema).union_by_name(pa_schema)._apply()  # type: ignore
 
     expected_schema = Schema(
         NestedField(field_id=1, name="foo", field_type=StringType(), 
required=True),
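
Throughout tests/test_schema.py the standalone constructor switches from the
positional UpdateSchema(None, schema=...) to the keyword form with an explicit
"# type: ignore", since the transaction parameter is annotated as a
Transaction but these schema-union tests run without one. The recurring
pattern in isolation (current_schema and new_schema stand for any two
compatible schemas):

    from pyiceberg.table import UpdateSchema

    applied = UpdateSchema(transaction=None, schema=current_schema).union_by_name(new_schema)._apply()  # type: ignore
    assert applied.as_struct() == new_schema.as_struct()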
