This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 7d0d0e9075 [spark] support rescale procedure in spark (#6612)
7d0d0e9075 is described below

commit 7d0d0e907550bc05c930f5bc602c5d987f60e6ab
Author: XiaoHongbo <[email protected]>
AuthorDate: Wed Nov 19 17:05:52 2025 +0800

    [spark] support rescale procedure in spark (#6612)
---
 .../org/apache/paimon/spark/SparkProcedures.java   |   2 +
 .../paimon/spark/procedure/CompactProcedure.java   |  22 +-
 .../paimon/spark/procedure/RescaleProcedure.java   | 206 +++++++++++++
 .../paimon/spark/utils/SparkProcedureUtils.java    |  19 ++
 .../spark/procedure/CompactProcedureTestBase.scala |   3 +-
 .../spark/procedure/RescaleProcedureTest.scala     | 333 +++++++++++++++++++++
 6 files changed, 563 insertions(+), 22 deletions(-)
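
For reference, the new procedure is invoked through Spark SQL's CALL syntax. The sketch below is based on the usage notes in RescaleProcedure and the tests in this commit; the database/table name, bucket number, and partition values are illustrative only:

    // Rescale the whole table to 16 buckets.
    spark.sql("CALL sys.rescale(table => 'db.t', bucket_num => 16)")

    // Rescale only selected partitions; 'partitions' and 'where' are mutually exclusive.
    spark.sql("CALL sys.rescale(table => 'db.t', bucket_num => 16, partitions => 'dt=20250217,hh=08')")
    spark.sql("CALL sys.rescale(table => 'db.t', bucket_num => 16, where => 'dt > 20250217')")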

diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkProcedures.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkProcedures.java
index ece6fe66c7..b5db4192b1 100644
--- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkProcedures.java
+++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkProcedures.java
@@ -45,6 +45,7 @@ import org.apache.paimon.spark.procedure.RemoveUnexistingFilesProcedure;
 import org.apache.paimon.spark.procedure.RenameTagProcedure;
 import org.apache.paimon.spark.procedure.RepairProcedure;
 import org.apache.paimon.spark.procedure.ReplaceTagProcedure;
+import org.apache.paimon.spark.procedure.RescaleProcedure;
 import org.apache.paimon.spark.procedure.ResetConsumerProcedure;
 import org.apache.paimon.spark.procedure.RewriteFileIndexProcedure;
 import org.apache.paimon.spark.procedure.RollbackProcedure;
@@ -92,6 +93,7 @@ public class SparkProcedures {
         procedureBuilders.put("create_branch", CreateBranchProcedure::builder);
         procedureBuilders.put("delete_branch", DeleteBranchProcedure::builder);
         procedureBuilders.put("compact", CompactProcedure::builder);
+        procedureBuilders.put("rescale", RescaleProcedure::builder);
         procedureBuilders.put("migrate_database", MigrateDatabaseProcedure::builder);
         procedureBuilders.put("migrate_table", MigrateTableProcedure::builder);
         procedureBuilders.put("remove_orphan_files", RemoveOrphanFilesProcedure::builder);
diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/procedure/CompactProcedure.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/procedure/CompactProcedure.java
index 56dcebe954..8e98727332 100644
--- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/procedure/CompactProcedure.java
+++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/procedure/CompactProcedure.java
@@ -20,7 +20,6 @@ package org.apache.paimon.spark.procedure;
 
 import org.apache.paimon.CoreOptions;
 import org.apache.paimon.CoreOptions.OrderType;
-import org.apache.paimon.annotation.VisibleForTesting;
 import org.apache.paimon.append.AppendCompactCoordinator;
 import org.apache.paimon.append.AppendCompactTask;
 import org.apache.paimon.append.cluster.IncrementalClusterManager;
@@ -54,7 +53,6 @@ import org.apache.paimon.table.source.DataSplit;
 import org.apache.paimon.table.source.EndOfScanException;
 import org.apache.paimon.table.source.snapshot.SnapshotReader;
 import org.apache.paimon.utils.Pair;
-import org.apache.paimon.utils.ParameterUtils;
 import org.apache.paimon.utils.ProcedureUtils;
 import org.apache.paimon.utils.SerializationUtils;
 import org.apache.paimon.utils.StringUtils;
@@ -90,7 +88,6 @@ import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
-import java.util.Optional;
 import java.util.Set;
 import java.util.stream.Collectors;
 
@@ -183,7 +180,7 @@ public class CompactProcedure extends BaseProcedure {
         checkArgument(
                 partitions == null || where == null,
                 "partitions and where cannot be used together.");
-        String finalWhere = partitions != null ? toWhere(partitions) : where;
+        String finalWhere = partitions != null ? SparkProcedureUtils.toWhere(partitions) : where;
         return modifyPaimonTable(
                 tableIdent,
                 t -> {
@@ -644,23 +641,6 @@ public class CompactProcedure extends BaseProcedure {
                                         list -> list.toArray(new DataSplit[0]))));
     }
 
-    @VisibleForTesting
-    static String toWhere(String partitions) {
-        List<Map<String, String>> maps = ParameterUtils.getPartitions(partitions.split(";"));
-
-        return maps.stream()
-                .map(
-                        a ->
-                                a.entrySet().stream()
-                                        .map(entry -> entry.getKey() + "=" + entry.getValue())
-                                        .reduce((s0, s1) -> s0 + " AND " + s1))
-                .filter(Optional::isPresent)
-                .map(Optional::get)
-                .map(a -> "(" + a + ")")
-                .reduce((a, b) -> a + " OR " + b)
-                .orElse(null);
-    }
-
     public static ProcedureBuilder builder() {
         return new BaseProcedure.Builder<CompactProcedure>() {
             @Override
diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/procedure/RescaleProcedure.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/procedure/RescaleProcedure.java
new file mode 100644
index 0000000000..8f8a8874be
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/procedure/RescaleProcedure.java
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.procedure;
+
+import org.apache.paimon.CoreOptions;
+import org.apache.paimon.Snapshot;
+import org.apache.paimon.partition.PartitionPredicate;
+import org.apache.paimon.spark.commands.PaimonSparkWriter;
+import org.apache.paimon.spark.util.ScanPlanHelper$;
+import org.apache.paimon.spark.utils.SparkProcedureUtils;
+import org.apache.paimon.table.BucketMode;
+import org.apache.paimon.table.FileStoreTable;
+import org.apache.paimon.table.source.DataSplit;
+import org.apache.paimon.table.source.snapshot.SnapshotReader;
+import org.apache.paimon.utils.StringUtils;
+
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.PaimonUtils;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.connector.catalog.Identifier;
+import org.apache.spark.sql.connector.catalog.TableCatalog;
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nullable;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+import static org.apache.paimon.utils.Preconditions.checkArgument;
+import static org.apache.spark.sql.types.DataTypes.IntegerType;
+import static org.apache.spark.sql.types.DataTypes.StringType;
+
+/**
+ * Rescale procedure. Usage:
+ *
+ * <pre><code>
+ *  CALL sys.rescale(table => 'databaseName.tableName', [bucket_num => 16], [partitions => 'dt=20250217,hh=08;dt=20250218,hh=08'], [where => 'dt>20250217'])
+ * </code></pre>
+ */
+public class RescaleProcedure extends BaseProcedure {
+
+    private static final Logger LOG = LoggerFactory.getLogger(RescaleProcedure.class);
+
+    private static final ProcedureParameter[] PARAMETERS =
+            new ProcedureParameter[] {
+                ProcedureParameter.required("table", StringType),
+                ProcedureParameter.optional("bucket_num", IntegerType),
+                ProcedureParameter.optional("partitions", StringType),
+                ProcedureParameter.optional("where", StringType),
+            };
+
+    private static final StructType OUTPUT_TYPE =
+            new StructType(
+                    new StructField[] {
+                        new StructField("result", DataTypes.BooleanType, true, Metadata.empty())
+                    });
+
+    protected RescaleProcedure(TableCatalog tableCatalog) {
+        super(tableCatalog);
+    }
+
+    @Override
+    public ProcedureParameter[] parameters() {
+        return PARAMETERS;
+    }
+
+    @Override
+    public StructType outputType() {
+        return OUTPUT_TYPE;
+    }
+
+    @Override
+    public InternalRow[] call(InternalRow args) {
+        Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name());
+        Integer bucketNum = args.isNullAt(1) ? null : args.getInt(1);
+        String partitions = blank(args, 2) ? null : args.getString(2);
+        String where = blank(args, 3) ? null : args.getString(3);
+
+        checkArgument(
+                partitions == null || where == null,
+                "partitions and where cannot be used together.");
+        String finalWhere = partitions != null ? SparkProcedureUtils.toWhere(partitions) : where;
+
+        return modifyPaimonTable(
+                tableIdent,
+                table -> {
+                    checkArgument(table instanceof FileStoreTable);
+                    FileStoreTable fileStoreTable = (FileStoreTable) table;
+
+                    Optional<Snapshot> optionalSnapshot = fileStoreTable.latestSnapshot();
+                    if (!optionalSnapshot.isPresent()) {
+                        throw new IllegalArgumentException(
+                                "Table "
+                                        + table.fullName()
+                                        + " has no snapshot, no need to rescale.");
+                    }
+                    Snapshot snapshot = optionalSnapshot.get();
+
+                    // If someone commits while the rescale job is running, this commit will be
+                    // lost.
+                    // So we use strict mode to make sure nothing is lost.
+                    Map<String, String> dynamicOptions = new HashMap<>();
+                    dynamicOptions.put(
+                            CoreOptions.COMMIT_STRICT_MODE_LAST_SAFE_SNAPSHOT.key(),
+                            String.valueOf(snapshot.id()));
+                    fileStoreTable = fileStoreTable.copy(dynamicOptions);
+
+                    DataSourceV2Relation relation = createRelation(tableIdent);
+                    PartitionPredicate partitionPredicate =
+                            SparkProcedureUtils.convertToPartitionPredicate(
+                                    finalWhere,
+                                    fileStoreTable.schema().logicalPartitionType(),
+                                    spark(),
+                                    relation);
+
+                    if (bucketNum == null) {
+                        checkArgument(
+                                fileStoreTable.coreOptions().bucket() != BucketMode.POSTPONE_BUCKET,
+                                "When rescaling postpone bucket tables, you must provide the resulting bucket number.");
+                    }
+
+                    execute(fileStoreTable, bucketNum, partitionPredicate, tableIdent);
+
+                    InternalRow internalRow = newInternalRow(true);
+                    return new InternalRow[] {internalRow};
+                });
+    }
+
+    private void execute(
+            FileStoreTable table,
+            @Nullable Integer bucketNum,
+            PartitionPredicate partitionPredicate,
+            Identifier tableIdent) {
+        DataSourceV2Relation relation = createRelation(tableIdent);
+
+        SnapshotReader snapshotReader = table.newSnapshotReader();
+        if (partitionPredicate != null) {
+            snapshotReader = snapshotReader.withPartitionFilter(partitionPredicate);
+        }
+        List<DataSplit> dataSplits = snapshotReader.read().dataSplits();
+
+        if (dataSplits.isEmpty()) {
+            LOG.info("No data splits found for the specified partition. No need to rescale.");
+            return;
+        }
+
+        Dataset<Row> datasetForRead =
+                PaimonUtils.createDataset(
+                        spark(),
+                        ScanPlanHelper$.MODULE$.createNewScanPlan(
+                                dataSplits.toArray(new DataSplit[0]), relation));
+
+        Map<String, String> bucketOptions = new HashMap<>(table.options());
+        if (bucketNum != null) {
+            bucketOptions.put(CoreOptions.BUCKET.key(), String.valueOf(bucketNum));
+        }
+        FileStoreTable rescaledTable = table.copy(table.schema().copy(bucketOptions));
+
+        PaimonSparkWriter writer = PaimonSparkWriter.apply(rescaledTable);
+        writer.writeBuilder().withOverwrite();
+        writer.commit(writer.write(datasetForRead));
+    }
+
+    private boolean blank(InternalRow args, int index) {
+        return args.isNullAt(index) || StringUtils.isNullOrWhitespaceOnly(args.getString(index));
+    }
+
+    @Override
+    public String description() {
+        return "This procedure rescales partitions of a table by changing the bucket number.";
+    }
+
+    public static ProcedureBuilder builder() {
+        return new BaseProcedure.Builder<RescaleProcedure>() {
+            @Override
+            public RescaleProcedure doBuild() {
+                return new RescaleProcedure(tableCatalog());
+            }
+        };
+    }
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/utils/SparkProcedureUtils.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/utils/SparkProcedureUtils.java
index eafb7c9b0a..e4bc86b7a5 100644
--- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/utils/SparkProcedureUtils.java
+++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/utils/SparkProcedureUtils.java
@@ -22,6 +22,7 @@ import org.apache.paimon.partition.PartitionPredicate;
 import org.apache.paimon.predicate.Predicate;
 import org.apache.paimon.spark.catalyst.analysis.expressions.ExpressionUtils;
 import org.apache.paimon.types.RowType;
+import org.apache.paimon.utils.ParameterUtils;
 import org.apache.paimon.utils.StringUtils;
 
 import org.apache.spark.sql.SparkSession;
@@ -34,6 +35,8 @@ import org.slf4j.LoggerFactory;
 import javax.annotation.Nullable;
 
 import java.util.List;
+import java.util.Map;
+import java.util.Optional;
 
 import static org.apache.paimon.utils.Preconditions.checkArgument;
 
@@ -88,4 +91,20 @@ public class SparkProcedureUtils {
         }
         return readParallelism;
     }
+
+    public static String toWhere(String partitions) {
+        List<Map<String, String>> maps = ParameterUtils.getPartitions(partitions.split(";"));
+
+        return maps.stream()
+                .map(
+                        a ->
+                                a.entrySet().stream()
+                                        .map(entry -> entry.getKey() + "=" + entry.getValue())
+                                        .reduce((s0, s1) -> s0 + " AND " + s1))
+                .filter(Optional::isPresent)
+                .map(Optional::get)
+                .map(a -> "(" + a + ")")
+                .reduce((a, b) -> a + " OR " + b)
+                .orElse(null);
+    }
 }
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/procedure/CompactProcedureTestBase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/procedure/CompactProcedureTestBase.scala
index 74f80befed..81e475e983 100644
--- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/procedure/CompactProcedureTestBase.scala
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/procedure/CompactProcedureTestBase.scala
@@ -21,6 +21,7 @@ package org.apache.paimon.spark.procedure
 import org.apache.paimon.Snapshot.CommitKind
 import org.apache.paimon.fs.Path
 import org.apache.paimon.spark.PaimonSparkTestBase
+import org.apache.paimon.spark.utils.SparkProcedureUtils
 import org.apache.paimon.table.FileStoreTable
 import org.apache.paimon.table.source.DataSplit
 
@@ -591,7 +592,7 @@ abstract class CompactProcedureTestBase extends PaimonSparkTestBase with StreamT
   test("Paimon test: toWhere method in CompactProcedure") {
     val conditions = "f0=0,f1=0,f2=0;f0=1,f1=1,f2=1;f0=1,f1=2,f2=2;f3=3"
 
-    val where = CompactProcedure.toWhere(conditions)
+    val where = SparkProcedureUtils.toWhere(conditions)
     val whereExpected =
       "(f0=0 AND f1=0 AND f2=0) OR (f0=1 AND f1=1 AND f2=1) OR (f0=1 AND f1=2 AND f2=2) OR (f3=3)"
 
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/procedure/RescaleProcedureTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/procedure/RescaleProcedureTest.scala
new file mode 100644
index 0000000000..b2d9af14ad
--- /dev/null
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/procedure/RescaleProcedureTest.scala
@@ -0,0 +1,333 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.procedure
+
+import org.apache.paimon.Snapshot.CommitKind
+import org.apache.paimon.partition.PartitionPredicate
+import org.apache.paimon.spark.PaimonSparkTestBase
+import org.apache.paimon.table.FileStoreTable
+
+import org.apache.spark.sql.Row
+import org.assertj.core.api.Assertions
+
+import java.util.{Arrays, Collections, Map => JMap}
+
+import scala.collection.JavaConverters._
+
+/** Tests for the rescale procedure. See [[RescaleProcedure]]. */
+class RescaleProcedureTest extends PaimonSparkTestBase {
+
+  test("Paimon Procedure: rescale basic functionality") {
+    withTable("T") {
+      spark.sql(s"""
+                   |CREATE TABLE T (id INT, value STRING)
+                   |TBLPROPERTIES ('primary-key'='id', 'bucket'='2')
+                   |""".stripMargin)
+
+      val table = loadTable("T")
+      spark.sql(s"INSERT INTO T VALUES (1, 'a'), (2, 'b'), (3, 'c'), (4, 'd'), (5, 'e')")
+
+      val initialData = spark.sql("SELECT * FROM T ORDER BY id").collect()
+      val initialSnapshotId = lastSnapshotId(table)
+
+      // Rescale with explicit bucket_num
+      spark.sql("ALTER TABLE T SET TBLPROPERTIES ('bucket' = '4')")
+      checkAnswer(spark.sql("CALL sys.rescale(table => 'T', bucket_num => 4)"), Row(true) :: Nil)
+
+      val reloadedTable = loadTable("T")
+      Assertions.assertThat(lastSnapshotCommand(reloadedTable)).isEqualTo(CommitKind.OVERWRITE)
+      Assertions.assertThat(lastSnapshotId(reloadedTable)).isGreaterThan(initialSnapshotId)
+      Assertions.assertThat(getBucketCount(reloadedTable)).isEqualTo(4)
+
+      val afterData = spark.sql("SELECT * FROM T ORDER BY id").collect()
+      Assertions.assertThat(afterData).containsExactlyElementsOf(Arrays.asList(initialData: _*))
+
+      // Rescale without bucket_num (use current bucket)
+      spark.sql("ALTER TABLE T SET TBLPROPERTIES ('bucket' = '3')")
+      checkAnswer(spark.sql("CALL sys.rescale(table => 'T')"), Row(true) :: Nil)
+      Assertions.assertThat(getBucketCount(loadTable("T"))).isEqualTo(3)
+
+      // Rescale with explicit bucket_num
+      spark.sql("ALTER TABLE T SET TBLPROPERTIES ('bucket' = '4')")
+      checkAnswer(spark.sql("CALL sys.rescale(table => 'T', bucket_num => 4)"), Row(true) :: Nil)
+      Assertions.assertThat(getBucketCount(loadTable("T"))).isEqualTo(4)
+
+      // Decrease bucket count (4 -> 2)
+      spark.sql("ALTER TABLE T SET TBLPROPERTIES ('bucket' = '2')")
+      checkAnswer(spark.sql("CALL sys.rescale(table => 'T', bucket_num => 2)"), Row(true) :: Nil)
+      val reloadedTableAfterDecrease = loadTable("T")
+      Assertions.assertThat(getBucketCount(reloadedTableAfterDecrease)).isEqualTo(2)
+      reloadedTableAfterDecrease.newSnapshotReader.read.dataSplits.asScala.toList.foreach(
+        split => Assertions.assertThat(split.bucket()).isLessThan(2))
+
+      // Verify data integrity after bucket decrease
+      val afterDecreaseData = spark.sql("SELECT * FROM T ORDER BY id").collect()
+      Assertions
+        .assertThat(afterDecreaseData)
+        .containsExactlyElementsOf(Arrays.asList(initialData: _*))
+    }
+  }
+
+  test("Paimon Procedure: rescale partitioned tables") {
+    withTable("T") {
+      spark.sql(s"""
+                   |CREATE TABLE T (id INT, value STRING, pt STRING, dt STRING, hh INT)
+                   |TBLPROPERTIES ('primary-key'='id, pt, dt, hh', 'bucket'='2')
+                   |PARTITIONED BY (pt, dt, hh)
+                   |""".stripMargin)
+
+      val table = loadTable("T")
+      spark.sql(
+        s"INSERT INTO T VALUES (1, 'a', 'p1', '2024-01-01', 0), (2, 'b', 'p1', '2024-01-01', 0)")
+      spark.sql(
+        s"INSERT INTO T VALUES (3, 'c', 'p2', '2024-01-01', 1), (4, 'd', 'p2', '2024-01-02', 0)")
+
+      val initialData = spark.sql("SELECT * FROM T ORDER BY id").collect()
+      val initialSnapshotId = lastSnapshotId(table)
+
+      // Rescale single partition field
+      spark.sql("ALTER TABLE T SET TBLPROPERTIES ('bucket' = '4')")
+      checkAnswer(
+        spark.sql("CALL sys.rescale(table => 'T', bucket_num => 4, partitions => 'pt=\"p1\"')"),
+        Row(true) :: Nil)
+
+      val reloadedTable = loadTable("T")
+      Assertions.assertThat(lastSnapshotCommand(reloadedTable)).isEqualTo(CommitKind.OVERWRITE)
+      Assertions.assertThat(lastSnapshotId(reloadedTable)).isGreaterThan(initialSnapshotId)
+
+      val p1Predicate = PartitionPredicate.fromMap(
+        reloadedTable.schema().logicalPartitionType(),
+        Collections.singletonMap("pt", "p1"),
+        reloadedTable.coreOptions().partitionDefaultName())
+      val p1Splits = reloadedTable.newSnapshotReader
+        .withPartitionFilter(p1Predicate)
+        .read
+        .dataSplits
+        .asScala
+        .toList
+      p1Splits.foreach(split => Assertions.assertThat(split.bucket()).isLessThan(4))
+
+      // Rescale multiple partition fields
+      val snapshotBeforeTest2 = lastSnapshotId(reloadedTable)
+      checkAnswer(
+        spark.sql(
+          "CALL sys.rescale(table => 'T', bucket_num => 4, partitions => 'dt=\"2024-01-01\",hh=0')"),
+        Row(true) :: Nil)
+
+      val reloadedTable2 = loadTable("T")
+      Assertions.assertThat(lastSnapshotCommand(reloadedTable2)).isEqualTo(CommitKind.OVERWRITE)
+      Assertions.assertThat(lastSnapshotId(reloadedTable2)).isGreaterThan(snapshotBeforeTest2)
+
+      // Rescale empty partition (should not create new snapshot)
+      val snapshotBeforeEmpty = lastSnapshotId(reloadedTable2)
+      checkAnswer(
+        spark.sql("CALL sys.rescale(table => 'T', bucket_num => 4, partitions => 'pt=\"p3\"')"),
+        Row(true) :: Nil)
+      Assertions.assertThat(lastSnapshotId(loadTable("T"))).isEqualTo(snapshotBeforeEmpty)
+
+      val afterData = spark.sql("SELECT * FROM T ORDER BY id").collect()
+      Assertions.assertThat(afterData).containsExactlyElementsOf(Arrays.asList(initialData: _*))
+    }
+  }
+
+  test("Paimon Procedure: rescale with where clause") {
+    withTable("T") {
+      spark.sql(s"""
+                   |CREATE TABLE T (id INT, value STRING, dt STRING, hh INT)
+                   |TBLPROPERTIES ('primary-key'='id, dt, hh', 'bucket'='2')
+                   |PARTITIONED BY (dt, hh)
+                   |""".stripMargin)
+
+      val table = loadTable("T")
+      spark.sql(s"INSERT INTO T VALUES (1, 'a', '2024-01-01', 0), (2, 'b', '2024-01-01', 0)")
+      spark.sql(s"INSERT INTO T VALUES (3, 'c', '2024-01-01', 1), (4, 'd', '2024-01-01', 1)")
+      spark.sql(s"INSERT INTO T VALUES (5, 'e', '2024-01-02', 0), (6, 'f', '2024-01-02', 1)")
+
+      val initialData = spark.sql("SELECT * FROM T ORDER BY id").collect()
+      val initialSnapshotId = lastSnapshotId(table)
+
+      // Test 1: Rescale with where clause using single partition column
+      spark.sql("ALTER TABLE T SET TBLPROPERTIES ('bucket' = '4')")
+      checkAnswer(
+        spark.sql(
+          "CALL sys.rescale(table => 'T', bucket_num => 4, where => 'dt = \"2024-01-01\"')"),
+        Row(true) :: Nil)
+
+      val reloadedTable = loadTable("T")
+      Assertions.assertThat(lastSnapshotCommand(reloadedTable)).isEqualTo(CommitKind.OVERWRITE)
+      Assertions.assertThat(lastSnapshotId(reloadedTable)).isGreaterThan(initialSnapshotId)
+
+      // Test 2: Rescale with where clause using multiple partition columns
+      val snapshotBeforeTest2 = lastSnapshotId(reloadedTable)
+      checkAnswer(
+        spark.sql(
+          "CALL sys.rescale(table => 'T', bucket_num => 4, where => 'dt = \"2024-01-01\" AND hh >= 1')"),
+        Row(true) :: Nil)
+
+      val reloadedTable2 = loadTable("T")
+      Assertions.assertThat(lastSnapshotCommand(reloadedTable2)).isEqualTo(CommitKind.OVERWRITE)
+      Assertions.assertThat(lastSnapshotId(reloadedTable2)).isGreaterThan(snapshotBeforeTest2)
+
+      // Verify data integrity
+      val afterData = spark.sql("SELECT * FROM T ORDER BY id").collect()
+      Assertions.assertThat(afterData).containsExactlyElementsOf(Arrays.asList(initialData: _*))
+    }
+  }
+
+  test("Paimon Procedure: rescale with ALTER TABLE and write validation") {
+    withTable("T") {
+      spark.sql(s"""
+                   |CREATE TABLE T (f0 INT)
+                   |TBLPROPERTIES ('bucket'='2', 'bucket-key'='f0')
+                   |""".stripMargin)
+
+      val table = loadTable("T")
+
+      spark.sql(s"INSERT INTO T VALUES (1), (2), (3), (4), (5)")
+
+      val snapshot = lastSnapshotId(table)
+      Assertions.assertThat(snapshot).isGreaterThanOrEqualTo(0)
+
+      val initialBuckets = getBucketCount(table)
+      Assertions.assertThat(initialBuckets).isEqualTo(2)
+
+      val initialData = spark.sql("SELECT * FROM T ORDER BY f0").collect()
+      Assertions.assertThat(initialData.length).isEqualTo(5)
+
+      spark.sql("ALTER TABLE T SET TBLPROPERTIES ('bucket' = '4')")
+
+      val reloadedTable = loadTable("T")
+      val newBuckets = getBucketCount(reloadedTable)
+      Assertions.assertThat(newBuckets).isEqualTo(4)
+
+      val afterAlterData = spark.sql("SELECT * FROM T ORDER BY f0").collect()
+      val initialDataList = Arrays.asList(initialData: _*)
+      Assertions.assertThat(afterAlterData).containsExactlyElementsOf(initialDataList)
+
+      val writeError = intercept[org.apache.spark.SparkException] {
+        spark.sql("INSERT INTO T VALUES (6)")
+      }
+      val errorMessage = Option(writeError.getMessage).getOrElse("")
+      val causeMessage =
+        Option(writeError.getCause).flatMap(c => Option(c.getMessage)).getOrElse("")
+      val expectedMessage =
+        "Try to write table with a new bucket num 4, but the previous bucket num is 2"
+      val fullMessage = Seq(errorMessage, causeMessage).filter(_.nonEmpty).mkString(" ")
+      Assertions
+        .assertThat(fullMessage)
+        .contains(expectedMessage)
+
+      checkAnswer(spark.sql("CALL sys.rescale(table => 'T', bucket_num => 4)"), Row(true) :: Nil)
+
+      val finalTable = loadTable("T")
+      val finalSnapshot = lastSnapshotId(finalTable)
+      Assertions.assertThat(finalSnapshot).isGreaterThan(snapshot)
+      Assertions.assertThat(lastSnapshotCommand(finalTable)).isEqualTo(CommitKind.OVERWRITE)
+
+      val finalBuckets = getBucketCount(finalTable)
+      Assertions.assertThat(finalBuckets).isEqualTo(4)
+
+      val afterRescaleData = spark.sql("SELECT * FROM T ORDER BY f0").collect()
+      Assertions.assertThat(afterRescaleData).containsExactlyElementsOf(initialDataList)
+
+      spark.sql("INSERT INTO T VALUES (6)")
+      val finalData = spark.sql("SELECT * FROM T ORDER BY f0").collect()
+      Assertions.assertThat(finalData.length).isEqualTo(6)
+    }
+  }
+
+  test("Paimon Procedure: rescale error cases") {
+    // Table with no snapshot
+    withTable("T1") {
+      spark.sql(s"""
+                   |CREATE TABLE T1 (id INT, value STRING)
+                   |TBLPROPERTIES ('primary-key'='id', 'bucket'='2')
+                   |""".stripMargin)
+      assert(intercept[IllegalArgumentException] {
+        spark.sql("CALL sys.rescale(table => 'T1', bucket_num => 4)")
+      }.getMessage.contains("has no snapshot"))
+    }
+
+    // Postpone bucket table requires bucket_num
+    withTable("T2") {
+      spark.sql(s"""
+                   |CREATE TABLE T2 (id INT, value STRING)
+                   |TBLPROPERTIES ('primary-key'='id', 'bucket'='-2')
+                   |""".stripMargin)
+      spark.sql(s"INSERT INTO T2 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+      assert(
+        intercept[IllegalArgumentException] {
+          spark.sql("CALL sys.rescale(table => 'T2')")
+        }.getMessage.contains(
+          "When rescaling postpone bucket tables, you must provide the resulting bucket number"))
+      checkAnswer(spark.sql("CALL sys.rescale(table => 'T2', bucket_num => 4)"), Row(true) :: Nil)
+    }
+
+    // partitions and where cannot be used together
+    withTable("T3") {
+      spark.sql(s"""
+                   |CREATE TABLE T3 (id INT, value STRING, pt STRING)
+                   |TBLPROPERTIES ('primary-key'='id', 'bucket'='2')
+                   |PARTITIONED BY (pt)
+                   |""".stripMargin)
+      spark.sql(s"INSERT INTO T3 VALUES (1, 'a', 'p1'), (2, 'b', 'p2')")
+      assert(intercept[IllegalArgumentException] {
+        spark.sql(
+          "CALL sys.rescale(table => 'T3', bucket_num => 4, partitions => 'pt=\"p1\"', where => 'pt = \"p1\"')")
+      }.getMessage.contains("partitions and where cannot be used together"))
+    }
+
+    // where clause with non-partition column should fail
+    withTable("T4") {
+      spark.sql(s"""
+                   |CREATE TABLE T4 (id INT, value STRING, pt STRING)
+                   |TBLPROPERTIES ('primary-key'='id', 'bucket'='2')
+                   |PARTITIONED BY (pt)
+                   |""".stripMargin)
+      spark.sql(s"INSERT INTO T4 VALUES (1, 'a', 'p1'), (2, 'b', 'p2')")
+      assert(intercept[IllegalArgumentException] {
+        spark.sql("CALL sys.rescale(table => 'T4', bucket_num => 4, where => 'id = 1')")
+      }.getMessage.contains("Only partition predicate is supported"))
+    }
+  }
+
+  // ----------------------- Helper Methods -----------------------
+
+  def getBucketCount(table: FileStoreTable): Int = {
+    val bucketOption = table.coreOptions().bucket()
+    if (bucketOption == -1) {
+      val dataSplits = table.newSnapshotReader.read.dataSplits.asScala.toList
+      if (dataSplits.isEmpty) {
+        -1
+      } else {
+        dataSplits.map(_.bucket()).max + 1
+      }
+    } else {
+      bucketOption
+    }
+  }
+
+  def lastSnapshotCommand(table: FileStoreTable): CommitKind = {
+    table.snapshotManager().latestSnapshot().commitKind()
+  }
+
+  def lastSnapshotId(table: FileStoreTable): Long = {
+    table.snapshotManager().latestSnapshotId()
+  }
+}
