This is an automated email from the ASF dual-hosted git repository.
laskoviymishka pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-go.git
The following commit(s) were added to refs/heads/main by this push:
new eee234ea feat(cli): add upgrade and rollback commands (#1071)
eee234ea is described below
commit eee234ea08b27ad35a6d8f76fd2e9812fb3f35fd
Author: Tanmay Rauth <[email protected]>
AuthorDate: Wed May 13 15:05:15 2026 -0700
feat(cli): add upgrade and rollback commands (#1071)
Add `iceberg upgrade TABLE_ID VERSION` with --dry-run and --yes. Add
`iceberg rollback TABLE_ID --snapshot-id ID` with --yes.
Related: #957
Depends On: #1073
---
cmd/iceberg/upgrade_rollback.go | 173 ++++++++++++++-
cmd/iceberg/upgrade_rollback_test.go | 410 +++++++++++++++++++++++++++++++++++
table/transaction.go | 23 ++
3 files changed, 595 insertions(+), 11 deletions(-)
diff --git a/cmd/iceberg/upgrade_rollback.go b/cmd/iceberg/upgrade_rollback.go
index bfc7fc59..c05abcdb 100644
--- a/cmd/iceberg/upgrade_rollback.go
+++ b/cmd/iceberg/upgrade_rollback.go
@@ -19,23 +19,174 @@ package main
import (
"context"
- "errors"
+ "encoding/json"
+ "fmt"
"os"
+ "strconv"
"github.com/apache/iceberg-go/catalog"
+ "github.com/apache/iceberg-go/table"
+ "github.com/pterm/pterm"
)
-func runUpgrade(_ context.Context, output Output, _ catalog.Catalog, _
*UpgradeCmd) {
- output.Error(errors.New("upgrade: not yet implemented"))
- os.Exit(1)
+var osExit = os.Exit
+
+func runUpgrade(ctx context.Context, output Output, cat catalog.Catalog, cmd
*UpgradeCmd) {
+ tbl := loadTable(ctx, output, cat, cmd.TableID)
+ meta := tbl.Metadata()
+ currentVersion := meta.Version()
+
+ result := UpgradeResult{
+ DryRun: cmd.DryRun,
+ Table: tableIDString(tbl),
+ PreviousVersion: currentVersion,
+ TargetVersion: cmd.FormatVersion,
+ SpecURL: specURL(cmd.FormatVersion),
+ }
+
+ if cmd.DryRun {
+ output.UpgradeResult(result)
+
+ return
+ }
+
+ if cmd.FormatVersion <= currentVersion {
+ output.Error(fmt.Errorf("target format version %d must be
greater than current version %d",
+ cmd.FormatVersion, currentVersion))
+ osExit(1)
+
+ return
+ }
+
+ prompt := fmt.Sprintf("Upgrade %s from format version %d to %d?",
+ tableIDString(tbl), currentVersion, cmd.FormatVersion)
+ if err := confirmAction(prompt, cmd.Yes); err != nil {
+ output.Error(err)
+ osExit(1)
+
+ return
+ }
+
+ tx := tbl.NewTransaction()
+ if err := tx.UpgradeFormatVersion(cmd.FormatVersion); err != nil {
+ output.Error(fmt.Errorf("upgrade failed: %w", err))
+ osExit(1)
+
+ return
+ }
+
+ if _, err := tx.Commit(ctx); err != nil {
+ output.Error(fmt.Errorf("commit failed: %w", err))
+ osExit(1)
+
+ return
+ }
+
+ output.UpgradeResult(result)
}
-func runRollback(_ context.Context, output Output, _ catalog.Catalog, _
*RollbackCmd) {
- output.Error(errors.New("rollback: not yet implemented"))
- os.Exit(1)
+func runRollback(ctx context.Context, output Output, cat catalog.Catalog, cmd
*RollbackCmd) {
+ tbl := loadTable(ctx, output, cat, cmd.TableID)
+ meta := tbl.Metadata()
+
+ snap := meta.SnapshotByID(cmd.SnapshotID)
+ if snap == nil {
+ output.Error(fmt.Errorf("snapshot %d not found in table %s",
cmd.SnapshotID, tableIDString(tbl)))
+ osExit(1)
+
+ return
+ }
+
+ if cs := meta.CurrentSnapshot(); cs != nil {
+ if !table.IsAncestorOf(cs.SnapshotID, cmd.SnapshotID,
meta.SnapshotByID) {
+ output.Error(fmt.Errorf("snapshot %d is not an ancestor
of current snapshot %d", cmd.SnapshotID, cs.SnapshotID))
+ osExit(1)
+
+ return
+ }
+ }
+
+ var previousSnapshotID *int64
+ if cs := meta.CurrentSnapshot(); cs != nil {
+ id := cs.SnapshotID
+ previousSnapshotID = &id
+ }
+
+ prompt := fmt.Sprintf("Roll back %s to snapshot %d?",
tableIDString(tbl), cmd.SnapshotID)
+ if err := confirmAction(prompt, cmd.Yes); err != nil {
+ output.Error(err)
+ osExit(1)
+
+ return
+ }
+
+ tx := tbl.NewTransaction()
+ if err := tx.RollbackToSnapshot(cmd.SnapshotID); err != nil {
+ output.Error(fmt.Errorf("rollback failed: %w", err))
+ osExit(1)
+
+ return
+ }
+
+ if _, err := tx.Commit(ctx); err != nil {
+ output.Error(fmt.Errorf("commit failed: %w", err))
+ osExit(1)
+
+ return
+ }
+
+ result := RollbackResult{
+ Table: tableIDString(tbl),
+ PreviousSnapshotID: previousSnapshotID,
+ RolledBackToSnapshotID: cmd.SnapshotID,
+ }
+
+ output.RollbackResult(result)
+}
+
+func specURL(version int) string {
+ switch version {
+ case 1:
+ return
"https://iceberg.apache.org/spec/#version-1-analytic-data-tables"
+ case 2:
+ return
"https://iceberg.apache.org/spec/#version-2-row-level-deletes"
+ case 3:
+ return
"https://iceberg.apache.org/spec/#version-3-extended-types-and-features"
+ default:
+ return "https://iceberg.apache.org/spec/"
+ }
}
-func (textOutput) UpgradeResult(_ UpgradeResult) {}
-func (jsonOutput) UpgradeResult(_ UpgradeResult) {}
-func (textOutput) RollbackResult(_ RollbackResult) {}
-func (jsonOutput) RollbackResult(_ RollbackResult) {}
+func (textOutput) UpgradeResult(result UpgradeResult) {
+ if result.DryRun {
+ pterm.Printfln("[DRY RUN] Would upgrade %s from format version
%d to %d.",
+ result.Table, result.PreviousVersion,
result.TargetVersion)
+ } else {
+ pterm.Printfln("Upgraded %s from format version %d to %d.",
+ result.Table, result.PreviousVersion,
result.TargetVersion)
+ }
+
+ pterm.Printfln("Spec: %s", result.SpecURL)
+}
+
+func (j jsonOutput) UpgradeResult(result UpgradeResult) {
+ if err := json.NewEncoder(os.Stdout).Encode(result); err != nil {
+ j.Error(err)
+ }
+}
+
+func (textOutput) RollbackResult(result RollbackResult) {
+ prev := "none"
+ if result.PreviousSnapshotID != nil {
+ prev = strconv.FormatInt(*result.PreviousSnapshotID, 10)
+ }
+
+ pterm.Printfln("Rolled back %s to snapshot %d (previous: %s).",
+ result.Table, result.RolledBackToSnapshotID, prev)
+}
+
+func (j jsonOutput) RollbackResult(result RollbackResult) {
+ if err := json.NewEncoder(os.Stdout).Encode(result); err != nil {
+ j.Error(err)
+ }
+}
diff --git a/cmd/iceberg/upgrade_rollback_test.go
b/cmd/iceberg/upgrade_rollback_test.go
new file mode 100644
index 00000000..881f2e97
--- /dev/null
+++ b/cmd/iceberg/upgrade_rollback_test.go
@@ -0,0 +1,410 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package main
+
+import (
+ "bytes"
+ "context"
+ "iter"
+ "os"
+ "testing"
+
+ "github.com/apache/iceberg-go"
+ "github.com/apache/iceberg-go/catalog"
+ "github.com/apache/iceberg-go/table"
+ "github.com/pterm/pterm"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestSpecURL(t *testing.T) {
+ assert.Contains(t, specURL(1), "version-1")
+ assert.Contains(t, specURL(2), "version-2")
+ assert.Contains(t, specURL(3), "version-3")
+ assert.Equal(t, "https://iceberg.apache.org/spec/", specURL(99))
+}
+
+func TestTextOutputUpgradeResultDryRun(t *testing.T) {
+ var buf bytes.Buffer
+ pterm.SetDefaultOutput(&buf)
+ pterm.DisableColor()
+ t.Cleanup(func() { pterm.SetDefaultOutput(os.Stdout);
pterm.EnableColor() })
+
+ result := UpgradeResult{
+ DryRun: true,
+ Table: "db.events",
+ PreviousVersion: 1,
+ TargetVersion: 2,
+ SpecURL: specURL(2),
+ }
+
+ buf.Reset()
+ textOutput{}.UpgradeResult(result)
+
+ output := buf.String()
+ assert.Contains(t, output, "[DRY RUN]")
+ assert.Contains(t, output, "format version 1 to 2")
+ assert.Contains(t, output, "version-2")
+}
+
+func TestTextOutputUpgradeResultCommitted(t *testing.T) {
+ var buf bytes.Buffer
+ pterm.SetDefaultOutput(&buf)
+ pterm.DisableColor()
+ t.Cleanup(func() { pterm.SetDefaultOutput(os.Stdout);
pterm.EnableColor() })
+
+ result := UpgradeResult{
+ DryRun: false,
+ Table: "db.events",
+ PreviousVersion: 1,
+ TargetVersion: 2,
+ SpecURL: specURL(2),
+ }
+
+ buf.Reset()
+ textOutput{}.UpgradeResult(result)
+
+ output := buf.String()
+ assert.Contains(t, output, "Upgraded db.events")
+ assert.Contains(t, output, "format version 1 to 2")
+}
+
+func TestJSONOutputUpgradeResult(t *testing.T) {
+ r, w, err := os.Pipe()
+ require.NoError(t, err)
+ oldStdout := os.Stdout
+ os.Stdout = w
+ t.Cleanup(func() { w.Close(); os.Stdout = oldStdout })
+
+ result := UpgradeResult{
+ DryRun: true,
+ Table: "db.events",
+ PreviousVersion: 1,
+ TargetVersion: 2,
+ SpecURL:
"https://iceberg.apache.org/spec/#version-2-row-level-deletes",
+ }
+
+ jsonOutput{}.UpgradeResult(result)
+
+ w.Close()
+ var buf bytes.Buffer
+ _, _ = buf.ReadFrom(r)
+
+ output := buf.String()
+ assert.Contains(t, output, `"dry_run":true`)
+ assert.Contains(t, output, `"previous_version":1`)
+ assert.Contains(t, output, `"target_version":2`)
+}
+
+func TestTextOutputRollbackResult(t *testing.T) {
+ var buf bytes.Buffer
+ pterm.SetDefaultOutput(&buf)
+ pterm.DisableColor()
+ t.Cleanup(func() { pterm.SetDefaultOutput(os.Stdout);
pterm.EnableColor() })
+
+ prevID := int64(100)
+ result := RollbackResult{
+ Table: "db.events",
+ PreviousSnapshotID: &prevID,
+ RolledBackToSnapshotID: 50,
+ }
+
+ buf.Reset()
+ textOutput{}.RollbackResult(result)
+
+ output := buf.String()
+ assert.Contains(t, output, "Rolled back db.events to snapshot 50")
+ assert.Contains(t, output, "previous: 100")
+}
+
+func TestTextOutputRollbackResultNoPrevious(t *testing.T) {
+ var buf bytes.Buffer
+ pterm.SetDefaultOutput(&buf)
+ pterm.DisableColor()
+ t.Cleanup(func() { pterm.SetDefaultOutput(os.Stdout);
pterm.EnableColor() })
+
+ result := RollbackResult{
+ Table: "db.events",
+ PreviousSnapshotID: nil,
+ RolledBackToSnapshotID: 50,
+ }
+
+ buf.Reset()
+ textOutput{}.RollbackResult(result)
+
+ output := buf.String()
+ assert.Contains(t, output, "previous: none")
+}
+
+func TestJSONOutputRollbackResult(t *testing.T) {
+ r, w, err := os.Pipe()
+ require.NoError(t, err)
+ oldStdout := os.Stdout
+ os.Stdout = w
+ t.Cleanup(func() { w.Close(); os.Stdout = oldStdout })
+
+ prevID := int64(100)
+ result := RollbackResult{
+ Table: "db.events",
+ PreviousSnapshotID: &prevID,
+ RolledBackToSnapshotID: 50,
+ }
+
+ jsonOutput{}.RollbackResult(result)
+
+ w.Close()
+ var buf bytes.Buffer
+ _, _ = buf.ReadFrom(r)
+
+ output := buf.String()
+ assert.Contains(t, output, `"table":"db.events"`)
+ assert.Contains(t, output, `"previous_snapshot_id":100`)
+ assert.Contains(t, output, `"rolled_back_to_snapshot_id":50`)
+}
+
+const upgradeRollbackTestMetadata = `{
+ "format-version": 1,
+ "table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1",
+ "location": "s3://bucket/test/location",
+ "last-updated-ms": 1602638573590,
+ "last-column-id": 3,
+ "schemas": [
+ {
+ "type": "struct",
+ "schema-id": 0,
+ "fields": [
+ {"id": 1, "name": "x", "required": true, "type": "long"}
+ ]
+ }
+ ],
+ "current-schema-id": 0,
+ "partition-specs": [{"spec-id": 0, "fields": []}],
+ "default-spec-id": 0,
+ "last-partition-id": 999,
+ "sort-orders": [{"order-id": 0, "fields": []}],
+ "default-sort-order-id": 0,
+ "current-snapshot-id": 200,
+ "snapshots": [
+ {
+ "snapshot-id": 100,
+ "timestamp-ms": 1515100955770,
+ "summary": {"operation": "append"},
+ "manifest-list": "s3://a/b/1.avro"
+ },
+ {
+ "snapshot-id": 200,
+ "parent-snapshot-id": 100,
+ "timestamp-ms": 1555100955770,
+ "summary": {"operation": "append"},
+ "manifest-list": "s3://a/b/2.avro"
+ }
+ ],
+ "snapshot-log": [
+ {"snapshot-id": 100, "timestamp-ms": 1515100955770},
+ {"snapshot-id": 200, "timestamp-ms": 1555100955770}
+ ]
+}`
+
+type panicCatalog struct {
+ tbl *table.Table
+}
+
+func (p *panicCatalog) CatalogType() catalog.Type { return "panic" }
+func (p *panicCatalog) CreateTable(context.Context, table.Identifier,
*iceberg.Schema, ...catalog.CreateTableOpt) (*table.Table, error) {
+ panic("CreateTable must not be called")
+}
+
+func (p *panicCatalog) CommitTable(context.Context, table.Identifier,
[]table.Requirement, []table.Update) (table.Metadata, string, error) {
+ panic("CommitTable must not be called")
+}
+
+func (p *panicCatalog) ListTables(context.Context, table.Identifier)
iter.Seq2[table.Identifier, error] {
+ panic("ListTables must not be called")
+}
+
+func (p *panicCatalog) LoadTable(_ context.Context, _ table.Identifier)
(*table.Table, error) {
+ return p.tbl, nil
+}
+
+func (p *panicCatalog) DropTable(context.Context, table.Identifier) error {
+ panic("DropTable must not be called")
+}
+
+func (p *panicCatalog) RenameTable(context.Context, table.Identifier,
table.Identifier) (*table.Table, error) {
+ panic("RenameTable must not be called")
+}
+
+func (p *panicCatalog) CheckTableExists(context.Context, table.Identifier)
(bool, error) {
+ panic("CheckTableExists must not be called")
+}
+
+func (p *panicCatalog) ListNamespaces(context.Context, table.Identifier)
([]table.Identifier, error) {
+ panic("ListNamespaces must not be called")
+}
+
+func (p *panicCatalog) CreateNamespace(context.Context, table.Identifier,
iceberg.Properties) error {
+ panic("CreateNamespace must not be called")
+}
+
+func (p *panicCatalog) DropNamespace(context.Context, table.Identifier) error {
+ panic("DropNamespace must not be called")
+}
+
+func (p *panicCatalog) CheckNamespaceExists(context.Context, table.Identifier)
(bool, error) {
+ panic("CheckNamespaceExists must not be called")
+}
+
+func (p *panicCatalog) LoadNamespaceProperties(context.Context,
table.Identifier) (iceberg.Properties, error) {
+ panic("LoadNamespaceProperties must not be called")
+}
+
+func (p *panicCatalog) UpdateNamespaceProperties(context.Context,
table.Identifier, []string, iceberg.Properties)
(catalog.PropertiesUpdateSummary, error) {
+ panic("UpdateNamespaceProperties must not be called")
+}
+
+func TestRunUpgradeDryRunDoesNotCommit(t *testing.T) {
+ meta, err :=
table.ParseMetadataBytes([]byte(upgradeRollbackTestMetadata))
+ require.NoError(t, err)
+
+ tbl := table.New([]string{"db", "events"}, meta, "", nil, nil)
+ cat := &panicCatalog{tbl: tbl}
+
+ var buf bytes.Buffer
+ pterm.SetDefaultOutput(&buf)
+ pterm.DisableColor()
+ t.Cleanup(func() { pterm.SetDefaultOutput(os.Stdout);
pterm.EnableColor() })
+
+ runUpgrade(context.Background(), textOutput{}, cat, &UpgradeCmd{
+ TableID: "db.events",
+ FormatVersion: 2,
+ DryRun: true,
+ })
+
+ output := buf.String()
+ assert.Contains(t, output, "[DRY RUN]")
+ assert.Contains(t, output, "format version 1 to 2")
+}
+
+func TestRunUpgradeDryRunAlreadyAtVersion(t *testing.T) {
+ meta, err :=
table.ParseMetadataBytes([]byte(upgradeRollbackTestMetadata))
+ require.NoError(t, err)
+
+ tbl := table.New([]string{"db", "events"}, meta, "", nil, nil)
+ cat := &panicCatalog{tbl: tbl}
+
+ var buf bytes.Buffer
+ pterm.SetDefaultOutput(&buf)
+ pterm.DisableColor()
+ t.Cleanup(func() { pterm.SetDefaultOutput(os.Stdout);
pterm.EnableColor() })
+
+ runUpgrade(context.Background(), textOutput{}, cat, &UpgradeCmd{
+ TableID: "db.events",
+ FormatVersion: 1,
+ DryRun: true,
+ })
+
+ output := buf.String()
+ assert.Contains(t, output, "[DRY RUN]")
+ assert.Contains(t, output, "format version 1 to 1")
+}
+
+func TestConfirmActionNonTTYWithoutYes(t *testing.T) {
+ err := confirmAction("do something?", false)
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "stdin is not a terminal")
+}
+
+func TestRunRollbackRejectsNonAncestor(t *testing.T) {
+ const metaWithBranch = `{
+ "format-version": 2,
+ "table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1",
+ "location": "s3://bucket/test/location",
+ "last-sequence-number": 2,
+ "last-updated-ms": 1602638573590,
+ "last-column-id": 1,
+ "schemas": [{"type": "struct", "schema-id": 0, "fields": [{"id": 1,
"name": "x", "required": true, "type": "long"}]}],
+ "current-schema-id": 0,
+ "partition-specs": [{"spec-id": 0, "fields": []}],
+ "default-spec-id": 0,
+ "last-partition-id": 999,
+ "sort-orders": [{"order-id": 0, "fields": []}],
+ "default-sort-order-id": 0,
+ "current-snapshot-id": 200,
+ "snapshots": [
+ {"snapshot-id": 100, "timestamp-ms": 1515100955770, "sequence-number":
0, "summary": {"operation": "append"}, "manifest-list": "s3://a/b/1.avro"},
+ {"snapshot-id": 200, "parent-snapshot-id": 100, "timestamp-ms":
1555100955770, "sequence-number": 1, "summary": {"operation": "append"},
"manifest-list": "s3://a/b/2.avro"},
+ {"snapshot-id": 300, "timestamp-ms": 1565100955770, "sequence-number":
2, "summary": {"operation": "append"}, "manifest-list": "s3://a/b/3.avro"}
+ ],
+ "snapshot-log": [
+ {"snapshot-id": 100, "timestamp-ms": 1515100955770},
+ {"snapshot-id": 200, "timestamp-ms": 1555100955770}
+ ],
+ "refs": {"main": {"snapshot-id": 200, "type": "branch"}}
+}`
+
+ meta, err := table.ParseMetadataBytes([]byte(metaWithBranch))
+ require.NoError(t, err)
+
+ tbl := table.New([]string{"db", "events"}, meta, "", nil, nil)
+ cat := &panicCatalog{tbl: tbl}
+
+ var errOut errCapture
+ exitCode := captureExit(func() {
+ runRollback(context.Background(), &errOut, cat, &RollbackCmd{
+ TableID: "db.events",
+ SnapshotID: 300,
+ Yes: true,
+ })
+ })
+
+ assert.Equal(t, 1, exitCode)
+ assert.Contains(t, errOut.lastErr.Error(), "not an ancestor")
+}
+
+type errCapture struct {
+ textOutput
+ lastErr error
+}
+
+func (e *errCapture) Error(err error) {
+ e.lastErr = err
+}
+
+func captureExit(f func()) (exitCode int) {
+ origExit := osExit
+ defer func() { osExit = origExit }()
+
+ osExit = func(code int) {
+ exitCode = code
+ panic(exitSentinel{})
+ }
+
+ defer func() {
+ if r := recover(); r != nil {
+ if _, ok := r.(exitSentinel); !ok {
+ panic(r)
+ }
+ }
+ }()
+
+ f()
+
+ return 0
+}
+
+type exitSentinel struct{}
diff --git a/table/transaction.go b/table/transaction.go
index 2fdc9872..42d1cd61 100644
--- a/table/transaction.go
+++ b/table/transaction.go
@@ -187,6 +187,29 @@ func (t *Transaction) UpgradeFormatVersion(version int)
error {
return t.apply([]Update{NewUpgradeFormatVersionUpdate(version)}, nil)
}
+func (t *Transaction) RollbackToSnapshot(snapshotID int64) error {
+ cs := t.meta.currentSnapshot()
+ if cs == nil {
+ return errors.New("cannot rollback: table has no current
snapshot")
+ }
+
+ lookup := func(id int64) *Snapshot {
+ s, _ := t.meta.SnapshotByID(id)
+
+ return s
+ }
+
+ if !IsAncestorOf(cs.SnapshotID, snapshotID, lookup) {
+ return fmt.Errorf("snapshot %d is not an ancestor of current
snapshot %d",
+ snapshotID, cs.SnapshotID)
+ }
+
+ update := NewSetSnapshotRefUpdate(MainBranch, snapshotID, BranchRef, 0,
0, 0)
+ req := AssertRefSnapshotID(MainBranch, &cs.SnapshotID)
+
+ return t.apply([]Update{update}, []Requirement{req})
+}
+
func (t *Transaction) UpdateSpec(caseSensitive bool) *UpdateSpec {
return NewUpdateSpec(t, caseSensitive)
}