This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new 7d0d0e9075 [spark] support rescale procedure in spark (#6612)
7d0d0e9075 is described below
commit 7d0d0e907550bc05c930f5bc602c5d987f60e6ab
Author: XiaoHongbo <[email protected]>
AuthorDate: Wed Nov 19 17:05:52 2025 +0800
[spark] support rescale procedure in spark (#6612)
---
.../org/apache/paimon/spark/SparkProcedures.java | 2 +
.../paimon/spark/procedure/CompactProcedure.java | 22 +-
.../paimon/spark/procedure/RescaleProcedure.java | 206 +++++++++++++
.../paimon/spark/utils/SparkProcedureUtils.java | 19 ++
.../spark/procedure/CompactProcedureTestBase.scala | 3 +-
.../spark/procedure/RescaleProcedureTest.scala | 333 +++++++++++++++++++++
6 files changed, 563 insertions(+), 22 deletions(-)
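For reference, the new procedure is invoked like the existing compact procedure. A minimal usage sketch (database and table names are placeholders, mirroring the Javadoc and tests added in this commit):

    // Rescale the whole table to 4 buckets.
    spark.sql("CALL sys.rescale(table => 'my_db.my_table', bucket_num => 4)")
    // Rescale only matching partitions; 'partitions' and 'where' cannot be used together.
    spark.sql("CALL sys.rescale(table => 'my_db.my_table', bucket_num => 4, where => 'dt>20250217')")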
diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkProcedures.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkProcedures.java
index ece6fe66c7..b5db4192b1 100644
--- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkProcedures.java
+++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkProcedures.java
@@ -45,6 +45,7 @@ import org.apache.paimon.spark.procedure.RemoveUnexistingFilesProcedure;
import org.apache.paimon.spark.procedure.RenameTagProcedure;
import org.apache.paimon.spark.procedure.RepairProcedure;
import org.apache.paimon.spark.procedure.ReplaceTagProcedure;
+import org.apache.paimon.spark.procedure.RescaleProcedure;
import org.apache.paimon.spark.procedure.ResetConsumerProcedure;
import org.apache.paimon.spark.procedure.RewriteFileIndexProcedure;
import org.apache.paimon.spark.procedure.RollbackProcedure;
@@ -92,6 +93,7 @@ public class SparkProcedures {
procedureBuilders.put("create_branch", CreateBranchProcedure::builder);
procedureBuilders.put("delete_branch", DeleteBranchProcedure::builder);
procedureBuilders.put("compact", CompactProcedure::builder);
+ procedureBuilders.put("rescale", RescaleProcedure::builder);
procedureBuilders.put("migrate_database",
MigrateDatabaseProcedure::builder);
procedureBuilders.put("migrate_table", MigrateTableProcedure::builder);
procedureBuilders.put("remove_orphan_files",
RemoveOrphanFilesProcedure::builder);
diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/procedure/CompactProcedure.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/procedure/CompactProcedure.java
index 56dcebe954..8e98727332 100644
--- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/procedure/CompactProcedure.java
+++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/procedure/CompactProcedure.java
@@ -20,7 +20,6 @@ package org.apache.paimon.spark.procedure;
import org.apache.paimon.CoreOptions;
import org.apache.paimon.CoreOptions.OrderType;
-import org.apache.paimon.annotation.VisibleForTesting;
import org.apache.paimon.append.AppendCompactCoordinator;
import org.apache.paimon.append.AppendCompactTask;
import org.apache.paimon.append.cluster.IncrementalClusterManager;
@@ -54,7 +53,6 @@ import org.apache.paimon.table.source.DataSplit;
import org.apache.paimon.table.source.EndOfScanException;
import org.apache.paimon.table.source.snapshot.SnapshotReader;
import org.apache.paimon.utils.Pair;
-import org.apache.paimon.utils.ParameterUtils;
import org.apache.paimon.utils.ProcedureUtils;
import org.apache.paimon.utils.SerializationUtils;
import org.apache.paimon.utils.StringUtils;
@@ -90,7 +88,6 @@ import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
-import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
@@ -183,7 +180,7 @@ public class CompactProcedure extends BaseProcedure {
checkArgument(
partitions == null || where == null,
"partitions and where cannot be used together.");
- String finalWhere = partitions != null ? toWhere(partitions) : where;
+        String finalWhere = partitions != null ? SparkProcedureUtils.toWhere(partitions) : where;
return modifyPaimonTable(
tableIdent,
t -> {
@@ -644,23 +641,6 @@ public class CompactProcedure extends BaseProcedure {
                                                list -> list.toArray(new DataSplit[0]))));
}
- @VisibleForTesting
- static String toWhere(String partitions) {
-        List<Map<String, String>> maps = ParameterUtils.getPartitions(partitions.split(";"));
-
- return maps.stream()
- .map(
- a ->
- a.entrySet().stream()
-                                        .map(entry -> entry.getKey() + "=" + entry.getValue())
- .reduce((s0, s1) -> s0 + " AND " + s1))
- .filter(Optional::isPresent)
- .map(Optional::get)
- .map(a -> "(" + a + ")")
- .reduce((a, b) -> a + " OR " + b)
- .orElse(null);
- }
-
public static ProcedureBuilder builder() {
return new BaseProcedure.Builder<CompactProcedure>() {
@Override
diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/procedure/RescaleProcedure.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/procedure/RescaleProcedure.java
new file mode 100644
index 0000000000..8f8a8874be
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/procedure/RescaleProcedure.java
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.procedure;
+
+import org.apache.paimon.CoreOptions;
+import org.apache.paimon.Snapshot;
+import org.apache.paimon.partition.PartitionPredicate;
+import org.apache.paimon.spark.commands.PaimonSparkWriter;
+import org.apache.paimon.spark.util.ScanPlanHelper$;
+import org.apache.paimon.spark.utils.SparkProcedureUtils;
+import org.apache.paimon.table.BucketMode;
+import org.apache.paimon.table.FileStoreTable;
+import org.apache.paimon.table.source.DataSplit;
+import org.apache.paimon.table.source.snapshot.SnapshotReader;
+import org.apache.paimon.utils.StringUtils;
+
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.PaimonUtils;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.connector.catalog.Identifier;
+import org.apache.spark.sql.connector.catalog.TableCatalog;
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nullable;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+import static org.apache.paimon.utils.Preconditions.checkArgument;
+import static org.apache.spark.sql.types.DataTypes.IntegerType;
+import static org.apache.spark.sql.types.DataTypes.StringType;
+
+/**
+ * Rescale procedure. Usage:
+ *
+ * <pre><code>
+ * CALL sys.rescale(table => 'databaseName.tableName', [bucket_num => 16], [partitions => 'dt=20250217,hh=08;dt=20250218,hh=08'], [where => 'dt>20250217'])
+ * </code></pre>
+ */
+public class RescaleProcedure extends BaseProcedure {
+
+    private static final Logger LOG = LoggerFactory.getLogger(RescaleProcedure.class);
+
+ private static final ProcedureParameter[] PARAMETERS =
+ new ProcedureParameter[] {
+ ProcedureParameter.required("table", StringType),
+ ProcedureParameter.optional("bucket_num", IntegerType),
+ ProcedureParameter.optional("partitions", StringType),
+ ProcedureParameter.optional("where", StringType),
+ };
+
+ private static final StructType OUTPUT_TYPE =
+ new StructType(
+ new StructField[] {
+                        new StructField("result", DataTypes.BooleanType, true, Metadata.empty())
+ });
+
+ protected RescaleProcedure(TableCatalog tableCatalog) {
+ super(tableCatalog);
+ }
+
+ @Override
+ public ProcedureParameter[] parameters() {
+ return PARAMETERS;
+ }
+
+ @Override
+ public StructType outputType() {
+ return OUTPUT_TYPE;
+ }
+
+ @Override
+ public InternalRow[] call(InternalRow args) {
+        Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name());
+ Integer bucketNum = args.isNullAt(1) ? null : args.getInt(1);
+ String partitions = blank(args, 2) ? null : args.getString(2);
+ String where = blank(args, 3) ? null : args.getString(3);
+
+ checkArgument(
+ partitions == null || where == null,
+ "partitions and where cannot be used together.");
+        String finalWhere = partitions != null ? SparkProcedureUtils.toWhere(partitions) : where;
+
+ return modifyPaimonTable(
+ tableIdent,
+ table -> {
+ checkArgument(table instanceof FileStoreTable);
+ FileStoreTable fileStoreTable = (FileStoreTable) table;
+
+                    Optional<Snapshot> optionalSnapshot = fileStoreTable.latestSnapshot();
+ if (!optionalSnapshot.isPresent()) {
+ throw new IllegalArgumentException(
+ "Table "
+ + table.fullName()
+                                        + " has no snapshot, no need to rescale.");
+ }
+ Snapshot snapshot = optionalSnapshot.get();
+
+                    // If someone commits while the rescale job is running, this commit will be
+ // lost.
+ // So we use strict mode to make sure nothing is lost.
+ Map<String, String> dynamicOptions = new HashMap<>();
+ dynamicOptions.put(
+                            CoreOptions.COMMIT_STRICT_MODE_LAST_SAFE_SNAPSHOT.key(),
+ String.valueOf(snapshot.id()));
+ fileStoreTable = fileStoreTable.copy(dynamicOptions);
+
+ DataSourceV2Relation relation = createRelation(tableIdent);
+ PartitionPredicate partitionPredicate =
+ SparkProcedureUtils.convertToPartitionPredicate(
+ finalWhere,
+                                    fileStoreTable.schema().logicalPartitionType(),
+ spark(),
+ relation);
+
+ if (bucketNum == null) {
+ checkArgument(
+                                fileStoreTable.coreOptions().bucket() != BucketMode.POSTPONE_BUCKET,
+                                "When rescaling postpone bucket tables, you must provide the resulting bucket number.");
+ }
+
+                    execute(fileStoreTable, bucketNum, partitionPredicate, tableIdent);
+
+ InternalRow internalRow = newInternalRow(true);
+ return new InternalRow[] {internalRow};
+ });
+ }
+
+ private void execute(
+ FileStoreTable table,
+ @Nullable Integer bucketNum,
+ PartitionPredicate partitionPredicate,
+ Identifier tableIdent) {
+ DataSourceV2Relation relation = createRelation(tableIdent);
+
+ SnapshotReader snapshotReader = table.newSnapshotReader();
+ if (partitionPredicate != null) {
+            snapshotReader = snapshotReader.withPartitionFilter(partitionPredicate);
+ }
+ List<DataSplit> dataSplits = snapshotReader.read().dataSplits();
+
+ if (dataSplits.isEmpty()) {
+            LOG.info("No data splits found for the specified partition. No need to rescale.");
+ return;
+ }
+
+ Dataset<Row> datasetForRead =
+ PaimonUtils.createDataset(
+ spark(),
+ ScanPlanHelper$.MODULE$.createNewScanPlan(
+                                dataSplits.toArray(new DataSplit[0]), relation));
+
+ Map<String, String> bucketOptions = new HashMap<>(table.options());
+ if (bucketNum != null) {
+            bucketOptions.put(CoreOptions.BUCKET.key(), String.valueOf(bucketNum));
+ }
+        FileStoreTable rescaledTable = table.copy(table.schema().copy(bucketOptions));
+
+ PaimonSparkWriter writer = PaimonSparkWriter.apply(rescaledTable);
+ writer.writeBuilder().withOverwrite();
+ writer.commit(writer.write(datasetForRead));
+ }
+
+ private boolean blank(InternalRow args, int index) {
+        return args.isNullAt(index) || StringUtils.isNullOrWhitespaceOnly(args.getString(index));
+ }
+
+ @Override
+ public String description() {
+        return "This procedure rescales partitions of a table by changing the bucket number.";
+ }
+
+ public static ProcedureBuilder builder() {
+ return new BaseProcedure.Builder<RescaleProcedure>() {
+ @Override
+ public RescaleProcedure doBuild() {
+ return new RescaleProcedure(tableCatalog());
+ }
+ };
+ }
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/utils/SparkProcedureUtils.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/utils/SparkProcedureUtils.java
index eafb7c9b0a..e4bc86b7a5 100644
--- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/utils/SparkProcedureUtils.java
+++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/utils/SparkProcedureUtils.java
@@ -22,6 +22,7 @@ import org.apache.paimon.partition.PartitionPredicate;
import org.apache.paimon.predicate.Predicate;
import org.apache.paimon.spark.catalyst.analysis.expressions.ExpressionUtils;
import org.apache.paimon.types.RowType;
+import org.apache.paimon.utils.ParameterUtils;
import org.apache.paimon.utils.StringUtils;
import org.apache.spark.sql.SparkSession;
@@ -34,6 +35,8 @@ import org.slf4j.LoggerFactory;
import javax.annotation.Nullable;
import java.util.List;
+import java.util.Map;
+import java.util.Optional;
import static org.apache.paimon.utils.Preconditions.checkArgument;
@@ -88,4 +91,20 @@ public class SparkProcedureUtils {
}
return readParallelism;
}
+
+ public static String toWhere(String partitions) {
+        List<Map<String, String>> maps = ParameterUtils.getPartitions(partitions.split(";"));
+
+ return maps.stream()
+ .map(
+ a ->
+ a.entrySet().stream()
+                                        .map(entry -> entry.getKey() + "=" + entry.getValue())
+ .reduce((s0, s1) -> s0 + " AND " + s1))
+ .filter(Optional::isPresent)
+ .map(Optional::get)
+ .map(a -> "(" + a + ")")
+ .reduce((a, b) -> a + " OR " + b)
+ .orElse(null);
+ }
}
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/procedure/CompactProcedureTestBase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/procedure/CompactProcedureTestBase.scala
index 74f80befed..81e475e983 100644
--- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/procedure/CompactProcedureTestBase.scala
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/procedure/CompactProcedureTestBase.scala
@@ -21,6 +21,7 @@ package org.apache.paimon.spark.procedure
import org.apache.paimon.Snapshot.CommitKind
import org.apache.paimon.fs.Path
import org.apache.paimon.spark.PaimonSparkTestBase
+import org.apache.paimon.spark.utils.SparkProcedureUtils
import org.apache.paimon.table.FileStoreTable
import org.apache.paimon.table.source.DataSplit
@@ -591,7 +592,7 @@ abstract class CompactProcedureTestBase extends PaimonSparkTestBase with StreamT
test("Paimon test: toWhere method in CompactProcedure") {
val conditions = "f0=0,f1=0,f2=0;f0=1,f1=1,f2=1;f0=1,f1=2,f2=2;f3=3"
- val where = CompactProcedure.toWhere(conditions)
+ val where = SparkProcedureUtils.toWhere(conditions)
val whereExpected =
"(f0=0 AND f1=0 AND f2=0) OR (f0=1 AND f1=1 AND f2=1) OR (f0=1 AND f1=2
AND f2=2) OR (f3=3)"
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/procedure/RescaleProcedureTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/procedure/RescaleProcedureTest.scala
new file mode 100644
index 0000000000..b2d9af14ad
--- /dev/null
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/procedure/RescaleProcedureTest.scala
@@ -0,0 +1,333 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.procedure
+
+import org.apache.paimon.Snapshot.CommitKind
+import org.apache.paimon.partition.PartitionPredicate
+import org.apache.paimon.spark.PaimonSparkTestBase
+import org.apache.paimon.table.FileStoreTable
+
+import org.apache.spark.sql.Row
+import org.assertj.core.api.Assertions
+
+import java.util.{Arrays, Collections, Map => JMap}
+
+import scala.collection.JavaConverters._
+
+/** Tests for the rescale procedure. See [[RescaleProcedure]]. */
+class RescaleProcedureTest extends PaimonSparkTestBase {
+
+ test("Paimon Procedure: rescale basic functionality") {
+ withTable("T") {
+ spark.sql(s"""
+ |CREATE TABLE T (id INT, value STRING)
+ |TBLPROPERTIES ('primary-key'='id', 'bucket'='2')
+ |""".stripMargin)
+
+ val table = loadTable("T")
+      spark.sql(s"INSERT INTO T VALUES (1, 'a'), (2, 'b'), (3, 'c'), (4, 'd'), (5, 'e')")
+
+ val initialData = spark.sql("SELECT * FROM T ORDER BY id").collect()
+ val initialSnapshotId = lastSnapshotId(table)
+
+ // Rescale with explicit bucket_num
+ spark.sql("ALTER TABLE T SET TBLPROPERTIES ('bucket' = '4')")
+      checkAnswer(spark.sql("CALL sys.rescale(table => 'T', bucket_num => 4)"), Row(true) :: Nil)
+
+ val reloadedTable = loadTable("T")
+      Assertions.assertThat(lastSnapshotCommand(reloadedTable)).isEqualTo(CommitKind.OVERWRITE)
+      Assertions.assertThat(lastSnapshotId(reloadedTable)).isGreaterThan(initialSnapshotId)
+ Assertions.assertThat(getBucketCount(reloadedTable)).isEqualTo(4)
+
+ val afterData = spark.sql("SELECT * FROM T ORDER BY id").collect()
+      Assertions.assertThat(afterData).containsExactlyElementsOf(Arrays.asList(initialData: _*))
+
+ // Rescale without bucket_num (use current bucket)
+ spark.sql("ALTER TABLE T SET TBLPROPERTIES ('bucket' = '3')")
+      checkAnswer(spark.sql("CALL sys.rescale(table => 'T')"), Row(true) :: Nil)
+ Assertions.assertThat(getBucketCount(loadTable("T"))).isEqualTo(3)
+
+ // Rescale with explicit bucket_num
+ spark.sql("ALTER TABLE T SET TBLPROPERTIES ('bucket' = '4')")
+      checkAnswer(spark.sql("CALL sys.rescale(table => 'T', bucket_num => 4)"), Row(true) :: Nil)
+ Assertions.assertThat(getBucketCount(loadTable("T"))).isEqualTo(4)
+
+ // Decrease bucket count (4 -> 2)
+ spark.sql("ALTER TABLE T SET TBLPROPERTIES ('bucket' = '2')")
+      checkAnswer(spark.sql("CALL sys.rescale(table => 'T', bucket_num => 2)"), Row(true) :: Nil)
+ val reloadedTableAfterDecrease = loadTable("T")
+      Assertions.assertThat(getBucketCount(reloadedTableAfterDecrease)).isEqualTo(2)
+      reloadedTableAfterDecrease.newSnapshotReader.read.dataSplits.asScala.toList.foreach(
+ split => Assertions.assertThat(split.bucket()).isLessThan(2))
+
+ // Verify data integrity after bucket decrease
+      val afterDecreaseData = spark.sql("SELECT * FROM T ORDER BY id").collect()
+ Assertions
+ .assertThat(afterDecreaseData)
+ .containsExactlyElementsOf(Arrays.asList(initialData: _*))
+ }
+ }
+
+ test("Paimon Procedure: rescale partitioned tables") {
+ withTable("T") {
+ spark.sql(s"""
+                   |CREATE TABLE T (id INT, value STRING, pt STRING, dt STRING, hh INT)
+                   |TBLPROPERTIES ('primary-key'='id, pt, dt, hh', 'bucket'='2')
+ |PARTITIONED BY (pt, dt, hh)
+ |""".stripMargin)
+
+ val table = loadTable("T")
+ spark.sql(
+        s"INSERT INTO T VALUES (1, 'a', 'p1', '2024-01-01', 0), (2, 'b', 'p1', '2024-01-01', 0)")
+ spark.sql(
+        s"INSERT INTO T VALUES (3, 'c', 'p2', '2024-01-01', 1), (4, 'd', 'p2', '2024-01-02', 0)")
+
+ val initialData = spark.sql("SELECT * FROM T ORDER BY id").collect()
+ val initialSnapshotId = lastSnapshotId(table)
+
+ // Rescale single partition field
+ spark.sql("ALTER TABLE T SET TBLPROPERTIES ('bucket' = '4')")
+ checkAnswer(
+        spark.sql("CALL sys.rescale(table => 'T', bucket_num => 4, partitions => 'pt=\"p1\"')"),
+ Row(true) :: Nil)
+
+ val reloadedTable = loadTable("T")
+      Assertions.assertThat(lastSnapshotCommand(reloadedTable)).isEqualTo(CommitKind.OVERWRITE)
+      Assertions.assertThat(lastSnapshotId(reloadedTable)).isGreaterThan(initialSnapshotId)
+
+ val p1Predicate = PartitionPredicate.fromMap(
+ reloadedTable.schema().logicalPartitionType(),
+ Collections.singletonMap("pt", "p1"),
+ reloadedTable.coreOptions().partitionDefaultName())
+ val p1Splits = reloadedTable.newSnapshotReader
+ .withPartitionFilter(p1Predicate)
+ .read
+ .dataSplits
+ .asScala
+ .toList
+      p1Splits.foreach(split => Assertions.assertThat(split.bucket()).isLessThan(4))
+
+ // Rescale multiple partition fields
+ val snapshotBeforeTest2 = lastSnapshotId(reloadedTable)
+ checkAnswer(
+ spark.sql(
+          "CALL sys.rescale(table => 'T', bucket_num => 4, partitions => 'dt=\"2024-01-01\",hh=0')"),
+ Row(true) :: Nil)
+
+ val reloadedTable2 = loadTable("T")
+      Assertions.assertThat(lastSnapshotCommand(reloadedTable2)).isEqualTo(CommitKind.OVERWRITE)
+      Assertions.assertThat(lastSnapshotId(reloadedTable2)).isGreaterThan(snapshotBeforeTest2)
+
+ // Rescale empty partition (should not create new snapshot)
+ val snapshotBeforeEmpty = lastSnapshotId(reloadedTable2)
+ checkAnswer(
+        spark.sql("CALL sys.rescale(table => 'T', bucket_num => 4, partitions => 'pt=\"p3\"')"),
+ Row(true) :: Nil)
+      Assertions.assertThat(lastSnapshotId(loadTable("T"))).isEqualTo(snapshotBeforeEmpty)
+
+ val afterData = spark.sql("SELECT * FROM T ORDER BY id").collect()
+      Assertions.assertThat(afterData).containsExactlyElementsOf(Arrays.asList(initialData: _*))
+ }
+ }
+
+ test("Paimon Procedure: rescale with where clause") {
+ withTable("T") {
+ spark.sql(s"""
+ |CREATE TABLE T (id INT, value STRING, dt STRING, hh INT)
+ |TBLPROPERTIES ('primary-key'='id, dt, hh', 'bucket'='2')
+ |PARTITIONED BY (dt, hh)
+ |""".stripMargin)
+
+ val table = loadTable("T")
+      spark.sql(s"INSERT INTO T VALUES (1, 'a', '2024-01-01', 0), (2, 'b', '2024-01-01', 0)")
+      spark.sql(s"INSERT INTO T VALUES (3, 'c', '2024-01-01', 1), (4, 'd', '2024-01-01', 1)")
+      spark.sql(s"INSERT INTO T VALUES (5, 'e', '2024-01-02', 0), (6, 'f', '2024-01-02', 1)")
+
+ val initialData = spark.sql("SELECT * FROM T ORDER BY id").collect()
+ val initialSnapshotId = lastSnapshotId(table)
+
+ // Test 1: Rescale with where clause using single partition column
+ spark.sql("ALTER TABLE T SET TBLPROPERTIES ('bucket' = '4')")
+ checkAnswer(
+ spark.sql(
+          "CALL sys.rescale(table => 'T', bucket_num => 4, where => 'dt = \"2024-01-01\"')"),
+ Row(true) :: Nil)
+
+ val reloadedTable = loadTable("T")
+      Assertions.assertThat(lastSnapshotCommand(reloadedTable)).isEqualTo(CommitKind.OVERWRITE)
+      Assertions.assertThat(lastSnapshotId(reloadedTable)).isGreaterThan(initialSnapshotId)
+
+ // Test 2: Rescale with where clause using multiple partition columns
+ val snapshotBeforeTest2 = lastSnapshotId(reloadedTable)
+ checkAnswer(
+ spark.sql(
+          "CALL sys.rescale(table => 'T', bucket_num => 4, where => 'dt = \"2024-01-01\" AND hh >= 1')"),
+ Row(true) :: Nil)
+
+ val reloadedTable2 = loadTable("T")
+      Assertions.assertThat(lastSnapshotCommand(reloadedTable2)).isEqualTo(CommitKind.OVERWRITE)
+      Assertions.assertThat(lastSnapshotId(reloadedTable2)).isGreaterThan(snapshotBeforeTest2)
+
+ // Verify data integrity
+ val afterData = spark.sql("SELECT * FROM T ORDER BY id").collect()
+      Assertions.assertThat(afterData).containsExactlyElementsOf(Arrays.asList(initialData: _*))
+ }
+ }
+
+ test("Paimon Procedure: rescale with ALTER TABLE and write validation") {
+ withTable("T") {
+ spark.sql(s"""
+ |CREATE TABLE T (f0 INT)
+ |TBLPROPERTIES ('bucket'='2', 'bucket-key'='f0')
+ |""".stripMargin)
+
+ val table = loadTable("T")
+
+ spark.sql(s"INSERT INTO T VALUES (1), (2), (3), (4), (5)")
+
+ val snapshot = lastSnapshotId(table)
+ Assertions.assertThat(snapshot).isGreaterThanOrEqualTo(0)
+
+ val initialBuckets = getBucketCount(table)
+ Assertions.assertThat(initialBuckets).isEqualTo(2)
+
+ val initialData = spark.sql("SELECT * FROM T ORDER BY f0").collect()
+ Assertions.assertThat(initialData.length).isEqualTo(5)
+
+ spark.sql("ALTER TABLE T SET TBLPROPERTIES ('bucket' = '4')")
+
+ val reloadedTable = loadTable("T")
+ val newBuckets = getBucketCount(reloadedTable)
+ Assertions.assertThat(newBuckets).isEqualTo(4)
+
+ val afterAlterData = spark.sql("SELECT * FROM T ORDER BY f0").collect()
+ val initialDataList = Arrays.asList(initialData: _*)
+      Assertions.assertThat(afterAlterData).containsExactlyElementsOf(initialDataList)
+
+ val writeError = intercept[org.apache.spark.SparkException] {
+ spark.sql("INSERT INTO T VALUES (6)")
+ }
+ val errorMessage = Option(writeError.getMessage).getOrElse("")
+ val causeMessage =
+        Option(writeError.getCause).flatMap(c => Option(c.getMessage)).getOrElse("")
+ val expectedMessage =
+        "Try to write table with a new bucket num 4, but the previous bucket num is 2"
+      val fullMessage = Seq(errorMessage, causeMessage).filter(_.nonEmpty).mkString(" ")
+ Assertions
+ .assertThat(fullMessage)
+ .contains(expectedMessage)
+
+      checkAnswer(spark.sql("CALL sys.rescale(table => 'T', bucket_num => 4)"), Row(true) :: Nil)
+
+ val finalTable = loadTable("T")
+ val finalSnapshot = lastSnapshotId(finalTable)
+ Assertions.assertThat(finalSnapshot).isGreaterThan(snapshot)
+      Assertions.assertThat(lastSnapshotCommand(finalTable)).isEqualTo(CommitKind.OVERWRITE)
+
+ val finalBuckets = getBucketCount(finalTable)
+ Assertions.assertThat(finalBuckets).isEqualTo(4)
+
+ val afterRescaleData = spark.sql("SELECT * FROM T ORDER BY f0").collect()
+      Assertions.assertThat(afterRescaleData).containsExactlyElementsOf(initialDataList)
+
+ spark.sql("INSERT INTO T VALUES (6)")
+ val finalData = spark.sql("SELECT * FROM T ORDER BY f0").collect()
+ Assertions.assertThat(finalData.length).isEqualTo(6)
+ }
+ }
+
+ test("Paimon Procedure: rescale error cases") {
+ // Table with no snapshot
+ withTable("T1") {
+ spark.sql(s"""
+ |CREATE TABLE T1 (id INT, value STRING)
+ |TBLPROPERTIES ('primary-key'='id', 'bucket'='2')
+ |""".stripMargin)
+ assert(intercept[IllegalArgumentException] {
+ spark.sql("CALL sys.rescale(table => 'T1', bucket_num => 4)")
+ }.getMessage.contains("has no snapshot"))
+ }
+
+ // Postpone bucket table requires bucket_num
+ withTable("T2") {
+ spark.sql(s"""
+ |CREATE TABLE T2 (id INT, value STRING)
+ |TBLPROPERTIES ('primary-key'='id', 'bucket'='-2')
+ |""".stripMargin)
+ spark.sql(s"INSERT INTO T2 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+ assert(
+ intercept[IllegalArgumentException] {
+ spark.sql("CALL sys.rescale(table => 'T2')")
+ }.getMessage.contains(
+          "When rescaling postpone bucket tables, you must provide the resulting bucket number"))
+      checkAnswer(spark.sql("CALL sys.rescale(table => 'T2', bucket_num => 4)"), Row(true) :: Nil)
+ }
+
+ // partitions and where cannot be used together
+ withTable("T3") {
+ spark.sql(s"""
+ |CREATE TABLE T3 (id INT, value STRING, pt STRING)
+ |TBLPROPERTIES ('primary-key'='id', 'bucket'='2')
+ |PARTITIONED BY (pt)
+ |""".stripMargin)
+ spark.sql(s"INSERT INTO T3 VALUES (1, 'a', 'p1'), (2, 'b', 'p2')")
+ assert(intercept[IllegalArgumentException] {
+ spark.sql(
+          "CALL sys.rescale(table => 'T3', bucket_num => 4, partitions => 'pt=\"p1\"', where => 'pt = \"p1\"')")
+ }.getMessage.contains("partitions and where cannot be used together"))
+ }
+
+ // where clause with non-partition column should fail
+ withTable("T4") {
+ spark.sql(s"""
+ |CREATE TABLE T4 (id INT, value STRING, pt STRING)
+ |TBLPROPERTIES ('primary-key'='id', 'bucket'='2')
+ |PARTITIONED BY (pt)
+ |""".stripMargin)
+ spark.sql(s"INSERT INTO T4 VALUES (1, 'a', 'p1'), (2, 'b', 'p2')")
+ assert(intercept[IllegalArgumentException] {
+        spark.sql("CALL sys.rescale(table => 'T4', bucket_num => 4, where => 'id = 1')")
+ }.getMessage.contains("Only partition predicate is supported"))
+ }
+ }
+
+ // ----------------------- Helper Methods -----------------------
+
+ def getBucketCount(table: FileStoreTable): Int = {
+ val bucketOption = table.coreOptions().bucket()
+ if (bucketOption == -1) {
+ val dataSplits = table.newSnapshotReader.read.dataSplits.asScala.toList
+ if (dataSplits.isEmpty) {
+ -1
+ } else {
+ dataSplits.map(_.bucket()).max + 1
+ }
+ } else {
+ bucketOption
+ }
+ }
+
+ def lastSnapshotCommand(table: FileStoreTable): CommitKind = {
+ table.snapshotManager().latestSnapshot().commitKind()
+ }
+
+ def lastSnapshotId(table: FileStoreTable): Long = {
+ table.snapshotManager().latestSnapshotId()
+ }
+}