This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new ed0b0cf79f [spark] Fix the bucket join may produce wrong result after bucket rescaled (#5669)
ed0b0cf79f is described below
commit ed0b0cf79fe2303c9fd8bf6db8932112975e0fcd
Author: WenjunMin <[email protected]>
AuthorDate: Thu May 29 10:01:11 2025 +0800
[spark] Fix the bucket join may produce wrong result after bucket rescaled (#5669)
---
.../scala/org/apache/paimon/spark/PaimonScan.scala | 34 +++++++++++++---
.../org/apache/paimon/spark/PaimonBaseScan.scala | 21 ++++++----
.../scala/org/apache/paimon/spark/PaimonScan.scala | 34 +++++++++++++---
.../paimon/spark/sql/BucketedTableQueryTest.scala | 47 +++++++++++++++++++++-
4 files changed, 117 insertions(+), 19 deletions(-)
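For readers skimming the patch: the core idea is that the scan should only report a bucket transform to Spark's V2 bucketing when every planned split agrees on totalBuckets, so a table whose 'bucket' property has been rescaled no longer advertises a stale bucket count to the storage-partitioned join planner. Below is a minimal, self-contained Scala sketch of that rule; the FakeDataSplit type and its field are illustrative stand-ins, not the real Paimon or Spark API.

// Illustrative sketch only: stand-in types, not Paimon's DataSplit.
object BucketNumberSketch {

  // Stand-in for a planned split; totalBuckets may be unknown (None).
  final case class FakeDataSplit(totalBuckets: Option[Int])

  // Return the common bucket count, or None when splits disagree or it is unknown.
  def extractBucketNumber(splits: Seq[FakeDataSplit]): Option[Int] =
    splits.map(_.totalBuckets).distinct match {
      case Seq(Some(num)) => Some(num) // every split reports the same known value
      case _              => None      // mixed, unknown, or empty: report no bucket transform
    }

  def main(args: Array[String]): Unit = {
    // All data written with 2 buckets: the transform can safely use 2.
    println(extractBucketNumber(Seq(FakeDataSplit(Some(2)), FakeDataSplit(Some(2))))) // Some(2)
    // After a rescale to 3 buckets, old and new splits disagree, so the scan
    // reports no bucket transform and Spark falls back to a shuffle.
    println(extractBucketNumber(Seq(FakeDataSplit(Some(2)), FakeDataSplit(Some(3))))) // None
  }
}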
diff --git a/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/PaimonScan.scala b/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
index 4c62d58a81..ec589442e8 100644
--- a/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
+++ b/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
@@ -60,11 +60,16 @@ case class PaimonScan(
// so we only support one bucket key case.
assert(bucketSpec.getNumBuckets > 0)
assert(bucketSpec.getBucketKeys.size() == 1)
- val bucketKey = bucketSpec.getBucketKeys.get(0)
- if (requiredSchema.exists(f => conf.resolver(f.name, bucketKey))) {
- Some(Expressions.bucket(bucketSpec.getNumBuckets, bucketKey))
- } else {
- None
+ extractBucketNumber() match {
+ case Some(num) =>
+ val bucketKey = bucketSpec.getBucketKeys.get(0)
+ if (requiredSchema.exists(f => conf.resolver(f.name, bucketKey))) {
+ Some(Expressions.bucket(num, bucketKey))
+ } else {
+ None
+ }
+
+ case _ => None
}
}
@@ -72,6 +77,24 @@ case class PaimonScan(
}
}
+ /**
+ * Extract the bucket number from the splits only if all splits have the same totalBuckets number.
+ */
+ private def extractBucketNumber(): Option[Int] = {
+ val splits = getOriginSplits
+ if (splits.exists(!_.isInstanceOf[DataSplit])) {
+ None
+ } else {
+ val deduplicated =
+ splits.map(s => Option(s.asInstanceOf[DataSplit].totalBuckets())).toSeq.distinct
+
+ deduplicated match {
+ case Seq(Some(num)) => Some(num)
+ case _ => None
+ }
+ }
+ }
+
private def shouldDoBucketedScan: Boolean = {
!bucketedScanDisabled && conf.v2BucketingEnabled && extractBucketTransform.isDefined
}
@@ -120,6 +143,7 @@ case class PaimonScan(
readBuilder.withFilter(partitionFilter.head)
// set inputPartitions null to trigger to get the new splits.
inputPartitions = null
+ inputSplits = null
}
}
}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala
index 74741f5364..b0447c8830 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala
@@ -51,6 +51,8 @@ abstract class PaimonBaseScan(
protected var inputPartitions: Seq[PaimonInputPartition] = _
+ protected var inputSplits: Array[Split] = _
+
override val coreOptions: CoreOptions = CoreOptions.fromMap(table.options())
lazy val statistics: Optional[stats.Statistics] = table.statistics()
@@ -65,14 +67,17 @@ abstract class PaimonBaseScan(
@VisibleForTesting
def getOriginSplits: Array[Split] = {
- readBuilder
- .newScan()
- .asInstanceOf[InnerTableScan]
- .withMetricRegistry(paimonMetricsRegistry)
- .plan()
- .splits()
- .asScala
- .toArray
+ if (inputSplits == null) {
+ inputSplits = readBuilder
+ .newScan()
+ .asInstanceOf[InnerTableScan]
+ .withMetricRegistry(paimonMetricsRegistry)
+ .plan()
+ .splits()
+ .asScala
+ .toArray
+ }
+ inputSplits
}
final def lazyInputPartitions: Seq[PaimonInputPartition] = {
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
index 20c1cfffad..616c660255 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
@@ -62,11 +62,16 @@ case class PaimonScan(
// so we only support one bucket key case.
assert(bucketSpec.getNumBuckets > 0)
assert(bucketSpec.getBucketKeys.size() == 1)
- val bucketKey = bucketSpec.getBucketKeys.get(0)
- if (requiredSchema.exists(f => conf.resolver(f.name, bucketKey))) {
- Some(Expressions.bucket(bucketSpec.getNumBuckets, bucketKey))
- } else {
- None
+ extractBucketNumber() match {
+ case Some(num) =>
+ val bucketKey = bucketSpec.getBucketKeys.get(0)
+ if (requiredSchema.exists(f => conf.resolver(f.name, bucketKey))) {
+ Some(Expressions.bucket(num, bucketKey))
+ } else {
+ None
+ }
+
+ case _ => None
}
}
@@ -74,6 +79,24 @@ case class PaimonScan(
}
}
+ /**
+ * Extract the bucket number from the splits only if all splits have the same totalBuckets number.
+ */
+ private def extractBucketNumber(): Option[Int] = {
+ val splits = getOriginSplits
+ if (splits.exists(!_.isInstanceOf[DataSplit])) {
+ None
+ } else {
+ val deduplicated =
+ splits.map(s => Option(s.asInstanceOf[DataSplit].totalBuckets())).toSeq.distinct
+
+ deduplicated match {
+ case Seq(Some(num)) => Some(num)
+ case _ => None
+ }
+ }
+ }
+
private def shouldDoBucketedScan: Boolean = {
!bucketedScanDisabled && conf.v2BucketingEnabled && extractBucketTransform.isDefined
}
@@ -169,6 +192,7 @@ case class PaimonScan(
readBuilder.withFilter(partitionFilter.toList.asJava)
// set inputPartitions null to trigger to get the new splits.
inputPartitions = null
+ inputSplits = null
}
}
}
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/BucketedTableQueryTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/BucketedTableQueryTest.scala
index 35931924c4..3f87f8ec6f 100644
--- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/BucketedTableQueryTest.scala
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/BucketedTableQueryTest.scala
@@ -36,7 +36,9 @@ class BucketedTableQueryTest extends PaimonSparkTestBase with AdaptiveSparkPlanH
}
withSparkSQLConf(
"spark.sql.sources.v2.bucketing.enabled" -> "true",
- "spark.sql.autoBroadcastJoinThreshold" -> "-1") {
+ "spark.sql.requireAllClusterKeysForCoPartition" -> "false",
+ "spark.sql.autoBroadcastJoinThreshold" -> "-1"
+ ) {
val df = spark.sql(query)
checkAnswer(df, expectedResult.toSeq)
assert(collect(df.queryExecution.executedPlan) {
@@ -50,6 +52,49 @@ class BucketedTableQueryTest extends PaimonSparkTestBase with AdaptiveSparkPlanH
}
}
+ test("Query on a rescaled bucket table") {
+ assume(gteqSpark3_3)
+
+ withTable("t1", "t2") {
+
+ spark.sql(
+ "CREATE TABLE t1 (id INT, c STRING, dt STRING) partitioned by (dt)
TBLPROPERTIES ('bucket'='2', 'bucket-key' = 'id')")
+ spark.sql(
+ "CREATE TABLE t2 (id INT, c STRING, dt STRING) partitioned by (dt)
TBLPROPERTIES ('bucket'='3', 'bucket-key' = 'id')")
+ spark.sql("INSERT INTO t1 VALUES (1, 'x1', '20250101'), (3, 'x2',
'20250101')")
+ spark.sql("INSERT INTO t2 VALUES (1, 'x1', '20250101'), (4, 'x2',
'20250101')")
+ checkAnswerAndShuffleSorts(
+ "SELECT * FROM t1 JOIN t2 on t1.id = t2.id and t1.dt = '20250101' and
t2.dt = '20250101'",
+ 2,
+ 2)
+ spark.sql("ALTER TABLE t1 SET TBLPROPERTIES ('bucket' = '3')")
+ checkAnswerAndShuffleSorts(
+ "SELECT * FROM t1 JOIN t2 on t1.id = t2.id and t1.dt = t2.dt ",
+ 2,
+ 2)
+ }
+
+ withTable("t1", "t2") {
+
+ spark.sql(
+ "CREATE TABLE t1 (id INT, c STRING, dt STRING) partitioned by (dt)
TBLPROPERTIES ('bucket'='2', 'bucket-key' = 'id')")
+ spark.sql(
+ "CREATE TABLE t2 (id INT, c STRING, dt STRING) partitioned by (dt)
TBLPROPERTIES ('bucket'='2', 'bucket-key' = 'id')")
+ // TODO if the input partition is not aligned by bucket value, the bucket join will not be applied.
+ spark.sql("INSERT INTO t1 VALUES (1, 'x1', '20250101'), (2, 'x2',
'20250101')")
+ spark.sql("INSERT INTO t2 VALUES (1, 'x1', '20250101'), (5, 'x2',
'20250101')")
+ checkAnswerAndShuffleSorts(
+ "SELECT * FROM t1 JOIN t2 on t1.id = t2.id and t1.dt = '20250101' and
t2.dt = '20250101'",
+ 0,
+ 2)
+ spark.sql("ALTER TABLE t1 SET TBLPROPERTIES ('bucket' = '3')")
+ checkAnswerAndShuffleSorts(
+ "SELECT * FROM t1 JOIN t2 on t1.id = t2.id and t1.dt = t2.dt ",
+ 0,
+ 2)
+ }
+ }
+
test("Query on a bucketed table - join - positive case") {
assume(gteqSpark3_3)