This is an automated email from the ASF dual-hosted git repository.
kevinjqliu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
The following commit(s) were added to refs/heads/main by this push:
new b0880c85 feat: Add Set Current Snapshot to ManageSnapshots API (#2871)
b0880c85 is described below
commit b0880c855b8dfd5c03019afcc5dd67a26432ce23
Author: geruh <[email protected]>
AuthorDate: Mon Jan 12 09:01:32 2026 -0800
feat: Add Set Current Snapshot to ManageSnapshots API (#2871)
# Rationale for this change
This PR adds the ability to set the current snapshot of a
table. The bulk of this work was done in #758, but we have broken
it out to focus on the set-current-snapshot logic first. Additionally, I added a
few more tests, following the existing expire-snapshots tests.
## Are these changes tested?
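Yes. Unit tests were added in tests/table/test_manage_snapshots.py and integration tests in tests/integration/test_snapshot_operations.py.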
## Are there any user-facing changes?
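Yes: a new `ManageSnapshots.set_current_snapshot()` API, which accepts either a snapshot ID or a ref (branch or tag) name: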
```
table.manage_snapshots().set_current_snapshot(snapshot_id=123456789).commit()
table.manage_snapshots().set_current_snapshot(ref_name="my-tag").commit()
# chaining
table.manage_snapshots() \
.create_tag(snapshot_id=older_id, tag_name="my-tag") \
.set_current_snapshot(ref_name="my-tag") \
.commit()
```
---------
Co-authored-by: Chinmay Bhat <[email protected]>
---
pyiceberg/table/__init__.py | 26 +++-
pyiceberg/table/update/snapshot.py | 44 +++++++
tests/integration/test_snapshot_operations.py | 88 +++++++++++++
tests/table/test_manage_snapshots.py | 179 ++++++++++++++++++++++++++
4 files changed, 335 insertions(+), 2 deletions(-)
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 88a7bd00..ae5eb400 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -275,8 +275,20 @@ class Transaction:
if exctype is None and excinst is None and exctb is None:
self.commit_transaction()
- def _apply(self, updates: tuple[TableUpdate, ...], requirements:
tuple[TableRequirement, ...] = ()) -> Transaction:
- """Check if the requirements are met, and applies the updates to the
metadata."""
+ def _stage(
+ self,
+ updates: tuple[TableUpdate, ...],
+ requirements: tuple[TableRequirement, ...] = (),
+ ) -> Transaction:
+ """Stage updates to the transaction state without committing to the
catalog.
+
+ Args:
+ updates: The updates to stage.
+ requirements: The requirements that must be met.
+
+ Returns:
+ This transaction for method chaining.
+ """
for requirement in requirements:
requirement.validate(self.table_metadata)
@@ -289,6 +301,16 @@ class Transaction:
if type(new_requirement) not in existing_requirements:
self._requirements = self._requirements + (new_requirement,)
+ return self
+
+ def _apply(
+ self,
+ updates: tuple[TableUpdate, ...],
+ requirements: tuple[TableRequirement, ...] = (),
+ ) -> Transaction:
+ """Check if the requirements are met, and applies the updates to the
metadata."""
+ self._stage(updates, requirements)
+
if self._autocommit:
self.commit_transaction()
diff --git a/pyiceberg/table/update/snapshot.py
b/pyiceberg/table/update/snapshot.py
index 84298e08..bc05aab9 100644
--- a/pyiceberg/table/update/snapshot.py
+++ b/pyiceberg/table/update/snapshot.py
@@ -843,6 +843,13 @@ class
ManageSnapshots(UpdateTableMetadata["ManageSnapshots"]):
"""Apply the pending changes and commit."""
return self._updates, self._requirements
+ def _commit_if_ref_updates_exist(self) -> None:
+ """Stage any pending ref updates to the transaction state."""
+ if self._updates:
+ self._transaction._stage(*self._commit())
+ self._updates = ()
+ self._requirements = ()
+
def _remove_ref_snapshot(self, ref_name: str) -> ManageSnapshots:
"""Remove a snapshot ref.
@@ -941,6 +948,43 @@ class
ManageSnapshots(UpdateTableMetadata["ManageSnapshots"]):
"""
return self._remove_ref_snapshot(ref_name=branch_name)
+ def set_current_snapshot(self, snapshot_id: int | None = None, ref_name:
str | None = None) -> ManageSnapshots:
+ """Set the current snapshot to a specific snapshot ID or ref.
+
+ Args:
+ snapshot_id: The ID of the snapshot to set as current.
+ ref_name: The snapshot reference (branch or tag) to set as current.
+
+ Returns:
+ This for method chaining.
+
+ Raises:
+ ValueError: If neither or both arguments are provided, or if the
snapshot/ref does not exist.
+ """
+ self._commit_if_ref_updates_exist()
+
+ if (snapshot_id is None) == (ref_name is None):
+ raise ValueError("Either snapshot_id or ref_name must be provided,
not both")
+
+ target_snapshot_id: int
+ if snapshot_id is not None:
+ target_snapshot_id = snapshot_id
+ else:
+ if ref_name not in self._transaction.table_metadata.refs:
+ raise ValueError(f"Cannot find matching snapshot ID for ref:
{ref_name}")
+ target_snapshot_id =
self._transaction.table_metadata.refs[ref_name].snapshot_id
+
+ if self._transaction.table_metadata.snapshot_by_id(target_snapshot_id)
is None:
+ raise ValueError(f"Cannot set current snapshot to unknown snapshot
id: {target_snapshot_id}")
+
+ update, requirement = self._transaction._set_ref_snapshot(
+ snapshot_id=target_snapshot_id,
+ ref_name=MAIN_BRANCH,
+ type=SnapshotRefType.BRANCH,
+ )
+ self._transaction._stage(update, requirement)
+ return self
+
class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]):
"""Expire snapshots by ID.
diff --git a/tests/integration/test_snapshot_operations.py
b/tests/integration/test_snapshot_operations.py
index 1b7f2d3a..2f0447ec 100644
--- a/tests/integration/test_snapshot_operations.py
+++ b/tests/integration/test_snapshot_operations.py
@@ -72,3 +72,91 @@ def test_remove_branch(catalog: Catalog) -> None:
# now, remove the branch
tbl.manage_snapshots().remove_branch(branch_name=branch_name).commit()
assert tbl.metadata.refs.get(branch_name, None) is None
+
+
[email protected]
[email protected]("catalog",
[pytest.lazy_fixture("session_catalog_hive"),
pytest.lazy_fixture("session_catalog")])
+def test_set_current_snapshot(catalog: Catalog) -> None:
+ identifier = "default.test_table_snapshot_operations"
+ tbl = catalog.load_table(identifier)
+ assert len(tbl.history()) > 2
+
+ # first get the current snapshot and an older one
+ current_snapshot_id = tbl.history()[-1].snapshot_id
+ older_snapshot_id = tbl.history()[-2].snapshot_id
+
+ # set the current snapshot to the older one
+
tbl.manage_snapshots().set_current_snapshot(snapshot_id=older_snapshot_id).commit()
+
+ tbl = catalog.load_table(identifier)
+ updated_snapshot = tbl.current_snapshot()
+ assert updated_snapshot and updated_snapshot.snapshot_id ==
older_snapshot_id
+
+ # restore table
+
tbl.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot_id).commit()
+ tbl = catalog.load_table(identifier)
+ restored_snapshot = tbl.current_snapshot()
+ assert restored_snapshot and restored_snapshot.snapshot_id ==
current_snapshot_id
+
+
[email protected]
[email protected]("catalog",
[pytest.lazy_fixture("session_catalog_hive"),
pytest.lazy_fixture("session_catalog")])
+def test_set_current_snapshot_by_ref(catalog: Catalog) -> None:
+ identifier = "default.test_table_snapshot_operations"
+ tbl = catalog.load_table(identifier)
+ assert len(tbl.history()) > 2
+
+ # first get the current snapshot and an older one
+ current_snapshot_id = tbl.history()[-1].snapshot_id
+ older_snapshot_id = tbl.history()[-2].snapshot_id
+ assert older_snapshot_id != current_snapshot_id
+
+ # create a tag pointing to the older snapshot
+ tag_name = "my-tag"
+ tbl.manage_snapshots().create_tag(snapshot_id=older_snapshot_id,
tag_name=tag_name).commit()
+
+ # set current snapshot using the tag name
+ tbl = catalog.load_table(identifier)
+ tbl.manage_snapshots().set_current_snapshot(ref_name=tag_name).commit()
+
+ tbl = catalog.load_table(identifier)
+ updated_snapshot = tbl.current_snapshot()
+ assert updated_snapshot and updated_snapshot.snapshot_id ==
older_snapshot_id
+
+ # restore table
+
tbl.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot_id).commit()
+ tbl = catalog.load_table(identifier)
+ tbl.manage_snapshots().remove_tag(tag_name=tag_name).commit()
+ assert tbl.metadata.refs.get(tag_name, None) is None
+
+
[email protected]
[email protected]("catalog",
[pytest.lazy_fixture("session_catalog_hive"),
pytest.lazy_fixture("session_catalog")])
+def test_set_current_snapshot_chained_with_create_tag(catalog: Catalog) ->
None:
+ identifier = "default.test_table_snapshot_operations"
+ tbl = catalog.load_table(identifier)
+ assert len(tbl.history()) > 2
+
+ current_snapshot_id = tbl.history()[-1].snapshot_id
+ older_snapshot_id = tbl.history()[-2].snapshot_id
+ assert older_snapshot_id != current_snapshot_id
+
+ # create a tag and use it to set current snapshot
+ tag_name = "my-tag"
+ (
+ tbl.manage_snapshots()
+ .create_tag(snapshot_id=older_snapshot_id, tag_name=tag_name)
+ .set_current_snapshot(ref_name=tag_name)
+ .commit()
+ )
+
+ tbl = catalog.load_table(identifier)
+ updated_snapshot = tbl.current_snapshot()
+ assert updated_snapshot
+ assert updated_snapshot.snapshot_id == older_snapshot_id
+
+ # restore table
+
tbl.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot_id).commit()
+ tbl = catalog.load_table(identifier)
+ tbl.manage_snapshots().remove_tag(tag_name=tag_name).commit()
+ assert tbl.metadata.refs.get(tag_name, None) is None
diff --git a/tests/table/test_manage_snapshots.py
b/tests/table/test_manage_snapshots.py
new file mode 100644
index 00000000..93301a01
--- /dev/null
+++ b/tests/table/test_manage_snapshots.py
@@ -0,0 +1,179 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from unittest.mock import MagicMock
+from uuid import uuid4
+
+import pytest
+
+from pyiceberg.table import CommitTableResponse, Table
+from pyiceberg.table.update import SetSnapshotRefUpdate, TableUpdate
+
+
+def _mock_commit_response(table: Table) -> CommitTableResponse:
+ return CommitTableResponse(
+ metadata=table.metadata,
+ metadata_location="s3://bucket/tbl",
+ uuid=uuid4(),
+ )
+
+
+def _get_updates(mock_catalog: MagicMock) -> tuple[TableUpdate, ...]:
+ args, _ = mock_catalog.commit_table.call_args
+ return args[2]
+
+
+def test_set_current_snapshot_basic(table_v2: Table) -> None:
+ snapshot_one = 3051729675574597004
+
+ table_v2.catalog = MagicMock()
+ table_v2.catalog.commit_table.return_value =
_mock_commit_response(table_v2)
+
+
table_v2.manage_snapshots().set_current_snapshot(snapshot_id=snapshot_one).commit()
+
+ table_v2.catalog.commit_table.assert_called_once()
+
+ updates = _get_updates(table_v2.catalog)
+ set_ref_updates = [u for u in updates if isinstance(u,
SetSnapshotRefUpdate)]
+
+ assert len(set_ref_updates) == 1
+ update = set_ref_updates[0]
+ assert update.snapshot_id == snapshot_one
+ assert update.ref_name == "main"
+ assert update.type == "branch"
+
+
+def test_set_current_snapshot_unknown_id(table_v2: Table) -> None:
+ invalid_snapshot_id = 1234567890000
+ table_v2.catalog = MagicMock()
+
+ with pytest.raises(ValueError, match="Cannot set current snapshot to
unknown snapshot id"):
+
table_v2.manage_snapshots().set_current_snapshot(snapshot_id=invalid_snapshot_id).commit()
+
+ table_v2.catalog.commit_table.assert_not_called()
+
+
+def test_set_current_snapshot_to_current(table_v2: Table) -> None:
+ current_snapshot = table_v2.current_snapshot()
+ assert current_snapshot is not None
+
+ table_v2.catalog = MagicMock()
+ table_v2.catalog.commit_table.return_value =
_mock_commit_response(table_v2)
+
+
table_v2.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot.snapshot_id).commit()
+
+ table_v2.catalog.commit_table.assert_called_once()
+
+
+def test_set_current_snapshot_chained_with_tag(table_v2: Table) -> None:
+ snapshot_one = 3051729675574597004
+ table_v2.catalog = MagicMock()
+ table_v2.catalog.commit_table.return_value =
_mock_commit_response(table_v2)
+
+
(table_v2.manage_snapshots().set_current_snapshot(snapshot_id=snapshot_one).create_tag(snapshot_one,
"my-tag").commit())
+
+ table_v2.catalog.commit_table.assert_called_once()
+
+ updates = _get_updates(table_v2.catalog)
+ set_ref_updates = [u for u in updates if isinstance(u,
SetSnapshotRefUpdate)]
+
+ assert len(set_ref_updates) == 2
+ assert {u.ref_name for u in set_ref_updates} == {"main", "my-tag"}
+
+
+def
test_set_current_snapshot_with_extensive_snapshots(table_v2_with_extensive_snapshots:
Table) -> None:
+ snapshots = table_v2_with_extensive_snapshots.metadata.snapshots
+ assert len(snapshots) > 100
+
+ target_snapshot = snapshots[50].snapshot_id
+
+ table_v2_with_extensive_snapshots.catalog = MagicMock()
+ table_v2_with_extensive_snapshots.catalog.commit_table.return_value =
_mock_commit_response(table_v2_with_extensive_snapshots)
+
+
table_v2_with_extensive_snapshots.manage_snapshots().set_current_snapshot(snapshot_id=target_snapshot).commit()
+
+ table_v2_with_extensive_snapshots.catalog.commit_table.assert_called_once()
+
+ updates = _get_updates(table_v2_with_extensive_snapshots.catalog)
+ set_ref_updates = [u for u in updates if isinstance(u,
SetSnapshotRefUpdate)]
+
+ assert len(set_ref_updates) == 1
+ assert set_ref_updates[0].snapshot_id == target_snapshot
+
+
+def test_set_current_snapshot_by_ref_name(table_v2: Table) -> None:
+ current_snapshot = table_v2.current_snapshot()
+ assert current_snapshot is not None
+
+ table_v2.catalog = MagicMock()
+ table_v2.catalog.commit_table.return_value =
_mock_commit_response(table_v2)
+
+ table_v2.manage_snapshots().set_current_snapshot(ref_name="main").commit()
+
+ updates = _get_updates(table_v2.catalog)
+ set_ref_updates = [u for u in updates if isinstance(u,
SetSnapshotRefUpdate)]
+
+ assert len(set_ref_updates) == 1
+ assert set_ref_updates[0].snapshot_id == current_snapshot.snapshot_id
+ assert set_ref_updates[0].ref_name == "main"
+
+
+def test_set_current_snapshot_unknown_ref(table_v2: Table) -> None:
+ table_v2.catalog = MagicMock()
+
+ with pytest.raises(ValueError, match="Cannot find matching snapshot ID for
ref: nonexistent"):
+
table_v2.manage_snapshots().set_current_snapshot(ref_name="nonexistent").commit()
+
+ table_v2.catalog.commit_table.assert_not_called()
+
+
+def test_set_current_snapshot_requires_one_argument(table_v2: Table) -> None:
+ table_v2.catalog = MagicMock()
+
+ with pytest.raises(ValueError, match="Either snapshot_id or ref_name must
be provided, not both"):
+ table_v2.manage_snapshots().set_current_snapshot().commit()
+
+ with pytest.raises(ValueError, match="Either snapshot_id or ref_name must
be provided, not both"):
+ table_v2.manage_snapshots().set_current_snapshot(snapshot_id=123,
ref_name="main").commit()
+
+ table_v2.catalog.commit_table.assert_not_called()
+
+
+def test_set_current_snapshot_chained_with_create_tag(table_v2: Table) -> None:
+ snapshot_one = 3051729675574597004
+ table_v2.catalog = MagicMock()
+ table_v2.catalog.commit_table.return_value =
_mock_commit_response(table_v2)
+
+ # create a tag and immediately use it to set current snapshot
+ (
+ table_v2.manage_snapshots()
+ .create_tag(snapshot_id=snapshot_one, tag_name="new-tag")
+ .set_current_snapshot(ref_name="new-tag")
+ .commit()
+ )
+
+ table_v2.catalog.commit_table.assert_called_once()
+
+ updates = _get_updates(table_v2.catalog)
+ set_ref_updates = [u for u in updates if isinstance(u,
SetSnapshotRefUpdate)]
+
+ # should have the tag and the main branch update
+ assert len(set_ref_updates) == 2
+ assert {u.ref_name for u in set_ref_updates} == {"new-tag", "main"}
+
+ # The main branch should point to the same snapshot as the tag
+ main_update = next(u for u in set_ref_updates if u.ref_name == "main")
+ assert main_update.snapshot_id == snapshot_one