This is an automated email from the ASF dual-hosted git repository. sunchao pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 5ec62a760bcd [SPARK-48065][SQL] SPJ: allowJoinKeysSubsetOfPartitionKeys is too strict 5ec62a760bcd is described below commit 5ec62a760bcd718d6a600979df03bcabc2192d6b Author: Szehon Ho <szehon.apa...@gmail.com> AuthorDate: Thu May 2 16:14:36 2024 -0700 [SPARK-48065][SQL] SPJ: allowJoinKeysSubsetOfPartitionKeys is too strict ### What changes were proposed in this pull request? If spark.sql.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled is true, change the KeyGroupedPartitioning.satisfies0(distribution) check from requiring all clustering keys (here, join keys) to be partition keys to requiring only that the two sets overlap. ### Why are the changes needed? If spark.sql.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled is true, SPJ (storage-partitioned join) is not triggered when there are more join keys than partition keys, even though SPJ is supported in this case when the flag is false. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added tests in KeyGroupedPartitioningSuite. ### Was this patch authored or co-authored using generative AI tooling? No Closes #46325 from szehon-ho/fix_spj_less_join_key. 
Authored-by: Szehon Ho <szehon.apa...@gmail.com> Signed-off-by: Chao Sun <c...@openai.com> --- .../sql/catalyst/plans/physical/partitioning.scala | 5 +- .../connector/KeyGroupedPartitioningSuite.scala | 60 ++++++++++++++++++++++ 2 files changed, 63 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 2364130f79e4..43aba478c37b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -385,8 +385,9 @@ case class KeyGroupedPartitioning( val attributes = expressions.flatMap(_.collectLeaves()) if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { - // check that all join keys (required clustering keys) contained in partitioning - requiredClustering.forall(x => attributes.exists(_.semanticEquals(x))) && + // check that join keys (required clustering keys) + // overlap with partition keys (KeyGroupedPartitioning attributes) + requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) && expressions.forall(_.collectLeaves().size == 1) } else { attributes.forall(x => requiredClustering.exists(_.semanticEquals(x))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index ec275fe101fd..10a32441b6cd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -1227,6 +1227,66 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } + test("SPARK-48065: SPJ: allowJoinKeysSubsetOfPartitionKeys is too strict") { + val table1 = 
"tab1e1" + val table2 = "table2" + val partition = Array(identity("id")) + createTable(table1, columns, partition) + sql(s"INSERT INTO testcat.ns.$table1 VALUES " + + "(1, 'aa', cast('2020-01-01' as timestamp)), " + + "(2, 'bb', cast('2020-01-01' as timestamp)), " + + "(2, 'cc', cast('2020-01-01' as timestamp)), " + + "(3, 'dd', cast('2020-01-01' as timestamp)), " + + "(3, 'dd', cast('2020-01-01' as timestamp)), " + + "(3, 'ee', cast('2020-01-01' as timestamp)), " + + "(3, 'ee', cast('2020-01-01' as timestamp))") + + createTable(table2, columns, partition) + sql(s"INSERT INTO testcat.ns.$table2 VALUES " + + "(4, 'zz', cast('2020-01-01' as timestamp)), " + + "(4, 'zz', cast('2020-01-01' as timestamp)), " + + "(3, 'dd', cast('2020-01-01' as timestamp)), " + + "(3, 'dd', cast('2020-01-01' as timestamp)), " + + "(3, 'xx', cast('2020-01-01' as timestamp)), " + + "(3, 'xx', cast('2020-01-01' as timestamp)), " + + "(2, 'ww', cast('2020-01-01' as timestamp))") + + Seq(true, false).foreach { pushDownValues => + Seq(true, false).foreach { partiallyClustered => + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClustered.toString, + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") { + val df = sql( + s""" + |${selectWithMergeJoinHint("t1", "t2")} + |t1.id AS id, t1.data AS t1data, t2.data AS t2data + |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 + |ON t1.id = t2.id AND t1.data = t2.data ORDER BY t1.id, t1data, t2data + |""".stripMargin) + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "SPJ should be triggered") + + val scans = collectScans(df.queryExecution.executedPlan) + .map(_.inputRDD.partitions.length) + if (partiallyClustered) { + assert(scans == Seq(8, 8)) + } else { + assert(scans == Seq(4, 4)) + } 
+ checkAnswer(df, Seq( + Row(3, "dd", "dd"), + Row(3, "dd", "dd"), + Row(3, "dd", "dd"), + Row(3, "dd", "dd") + )) + } + } + } + } + test("SPARK-44647: test join key is subset of cluster key " + "with push values and partially-clustered") { val table1 = "tab1e1" --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org