This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9520087409d [SPARK-44647][SQL] Support SPJ where join keys are less than cluster keys
9520087409d is described below

commit 9520087409d5bd7e6a2651dacf2c295d564d5559
Author: Szehon Ho <szehon.apa...@gmail.com>
AuthorDate: Mon Sep 11 11:18:46 2023 -0700

    [SPARK-44647][SQL] Support SPJ where join keys are less than cluster keys
    
    ### What changes were proposed in this pull request?
    - Add a new conf, spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled (see the usage sketch below).
    - Change the key compatibility checks in EnsureRequirements: remove the requirement that all partition keys appear among the join keys, so that isKeyCompatible can be true in this case (when the flag is enabled).
    - Change BatchScanExec/DataSourceV2Relation to group splits by the join keys when they differ from the partition keys (previously splits were grouped only by partition values). Do the same for all auxiliary data structures, such as commonPartValues.
    - Implement partiallyClustered skew handling.
      - Group only the replicate side (now by join key as well), and replicate by the total size of the other-side partitions that share the join key.
      - Add an additional sort of partitions based on the join key, because once the replicate side is grouped, its partition ordering no longer matches the non-replicate side.
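    
    For illustration only (not part of this commit), here is a minimal usage sketch of the new behavior. It assumes a SparkSession named `spark` and a DSv2 catalog named `testcat` whose tables report KeyGroupedPartitioning, similar to the in-memory catalog used by KeyGroupedPartitioningSuite; the table and column names are hypothetical.
    
    ```scala
    // Hypothetical sketch: enable the SPJ-related confs, including the new one added here.
    import org.apache.spark.sql.internal.SQLConf
    
    spark.conf.set(SQLConf.V2_BUCKETING_ENABLED.key, "true")
    spark.conf.set(SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key, "true")
    spark.conf.set(SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key, "true")
    
    // Both tables are assumed to be partitioned on (id, data), but the join below only
    // uses `id`, i.e. the join keys are a strict subset of the partition keys.
    val df = spark.sql(
      """SELECT t1.id, t1.data AS t1data, t2.data AS t2data
        |FROM testcat.ns.table1 t1 JOIN testcat.ns.table2 t2
        |ON t1.id = t2.id""".stripMargin)
    
    // With the flag on, EnsureRequirements accepts the key subset and BatchScanExec
    // re-groups the input splits by the projected join key, so no Exchange should
    // appear between the two scans in the physical plan.
    df.explain()
    ```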
    
    ### Why are the changes needed?
    - Support Storage Partition Join in cases where the join condition does not contain all of the partition keys, but only some of them.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    - Added tests in KeyGroupedPartitioningSuite.
    - Found two existing problems, to be addressed in separate PRs:
      - Because of https://github.com/apache/spark/pull/37886 we have to select all join keys to trigger SPJ in this case; otherwise the DSV2 scan does not report KeyGroupedPartitioning and SPJ is not triggered. Need to see how to relax this.
      - https://issues.apache.org/jira/browse/SPARK-44641 was found while testing this change. This PR refactors some of that code to add grouping by join key but does not change the underlying logic, so the issue continues to exist; hopefully it will be fixed separately.
    
    Closes #42306 from szehon-ho/spj_attempt_master.
    
    Authored-by: Szehon Ho <szehon.apa...@gmail.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../sql/catalyst/plans/physical/partitioning.scala |  59 ++++-
 .../org/apache/spark/sql/internal/SQLConf.scala    |  15 ++
 .../execution/datasources/v2/BatchScanExec.scala   |  56 +++--
 .../execution/exchange/EnsureRequirements.scala    |  15 +-
 .../connector/KeyGroupedPartitioningSuite.scala    | 265 ++++++++++++++++++++-
 5 files changed, 378 insertions(+), 32 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 0be4a61f275..a61bd3b7324 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -355,7 +355,14 @@ case class KeyGroupedPartitioning(
           } else {
             // We'll need to find leaf attributes from the partition expressions first.
             val attributes = expressions.flatMap(_.collectLeaves())
-            attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))
+
+            if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
+              // check that all join keys (required clustering keys) contained in partitioning
+              requiredClustering.forall(x => attributes.exists(_.semanticEquals(x))) &&
+                  expressions.forall(_.collectLeaves().size == 1)
+            } else {
+              attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))
+            }
           }
 
         case _ =>
@@ -364,8 +371,20 @@ case class KeyGroupedPartitioning(
     }
   }
 
-  override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec =
-    KeyGroupedShuffleSpec(this, distribution)
+  override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = {
+    val result = KeyGroupedShuffleSpec(this, distribution)
+    if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
+      // If allowing join keys to be subset of clustering keys, we should create a new
+      // `KeyGroupedPartitioning` here that is grouped on the join keys instead, and use that as
+      // the returned shuffle spec.
+      val joinKeyPositions = result.keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2)
+      val projectedPartitioning = KeyGroupedPartitioning(expressions, joinKeyPositions,
+          partitionValues, originalPartitionValues)
+      result.copy(partitioning = projectedPartitioning, joinKeyPositions = Some(joinKeyPositions))
+    } else {
+      result
+    }
+  }
 
   lazy val uniquePartitionValues: Seq[InternalRow] = {
     partitionValues
@@ -378,8 +397,25 @@ case class KeyGroupedPartitioning(
 object KeyGroupedPartitioning {
   def apply(
       expressions: Seq[Expression],
-      partitionValues: Seq[InternalRow]): KeyGroupedPartitioning = {
-    KeyGroupedPartitioning(expressions, partitionValues.size, partitionValues, partitionValues)
+      projectionPositions: Seq[Int],
+      partitionValues: Seq[InternalRow],
+      originalPartitionValues: Seq[InternalRow]): KeyGroupedPartitioning = {
+    val projectedExpressions = projectionPositions.map(expressions(_))
+    val projectedPartitionValues = partitionValues.map(project(expressions, projectionPositions, _))
+    val projectedOriginalPartitionValues =
+      originalPartitionValues.map(project(expressions, projectionPositions, _))
+
+    KeyGroupedPartitioning(projectedExpressions, projectedPartitionValues.length,
+      projectedPartitionValues, projectedOriginalPartitionValues)
+  }
+
+  def project(
+      expressions: Seq[Expression],
+      positions: Seq[Int],
+      input: InternalRow): InternalRow = {
+    val projectedValues: Array[Any] = positions.map(i => input.get(i, expressions(i).dataType))
+      .toArray
+    new GenericInternalRow(projectedValues)
   }
 
   def supportsExpressions(expressions: Seq[Expression]): Boolean = {
@@ -672,9 +708,18 @@ case class HashShuffleSpec(
   override def numPartitions: Int = partitioning.numPartitions
 }
 
+/**
+ * [[ShuffleSpec]] created by [[KeyGroupedPartitioning]].
+ *
+ * @param partitioning key grouped partitioning
+ * @param distribution distribution
+ * @param joinKeyPositions position of join keys among cluster keys.
+ *                         This is set if joining on a subset of cluster keys is allowed.
+ */
 case class KeyGroupedShuffleSpec(
     partitioning: KeyGroupedPartitioning,
-    distribution: ClusteredDistribution) extends ShuffleSpec {
+    distribution: ClusteredDistribution,
+    joinKeyPositions: Option[Seq[Int]] = None) extends ShuffleSpec {
 
   /**
    * A sequence where each element is a set of positions of the partition expression to the cluster
@@ -709,7 +754,7 @@ case class KeyGroupedShuffleSpec(
     //    3.3 each pair of partition expressions at the same index must share compatible
     //        transform functions.
     //  4. the partition values from both sides are following the same order.
-    case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution) =>
+    case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution, _) =>
       distribution.clustering.length == otherDistribution.clustering.length &&
         numPartitions == other.numPartitions && areKeysCompatible(otherSpec) &&
           partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 8c8b33921e3..49a4b0bf98b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1530,6 +1530,18 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS =
+    buildConf("spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled")
+      .doc("Whether to allow storage-partition join in the case where join keys are " +
+        "a subset of the partition keys of the source tables. At planning time, " +
+        "Spark will group the partitions by only those keys that are in the join keys. " +
+        s"This is currently enabled only if ${REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION.key} " +
+        "is false."
+      )
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets")
     .doc("The maximum number of buckets allowed.")
     .version("2.4.0")
@@ -4936,6 +4948,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
   def v2BucketingShuffleEnabled: Boolean =
     getConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED)
 
+  def v2BucketingAllowJoinKeysSubsetOfPartitionKeys: Boolean =
+    getConf(SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS)
+
   def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
     getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
index 932ac0f5a1b..094a7b20808 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
@@ -120,7 +120,12 @@ case class BatchScanExec(
         val newPartValues = spjParams.commonPartitionValues.get.flatMap {
           case (partValue, numSplits) => Seq.fill(numSplits)(partValue)
         }
-        k.copy(numPartitions = newPartValues.length, partitionValues = newPartValues)
+        val expressions = spjParams.joinKeyPositions match {
+          case Some(projectionPositions) => projectionPositions.map(i => k.expressions(i))
+          case _ => k.expressions
+        }
+        k.copy(expressions = expressions, numPartitions = newPartValues.length,
+          partitionValues = newPartValues)
       case p => p
     }
   }
@@ -132,14 +137,29 @@ case class BatchScanExec(
       // return an empty RDD with 1 partition if dynamic filtering removed the only split
       sparkContext.parallelize(Array.empty[InternalRow], 1)
     } else {
-      var finalPartitions = filteredPartitions
-
-      outputPartitioning match {
+      val finalPartitions = outputPartitioning match {
         case p: KeyGroupedPartitioning =>
-          val groupedPartitions = filteredPartitions.map(splits => {
-            assert(splits.nonEmpty && splits.head.isInstanceOf[HasPartitionKey])
-            (splits.head.asInstanceOf[HasPartitionKey].partitionKey(), splits)
-          })
+          assert(spjParams.keyGroupedPartitioning.isDefined)
+          val expressions = spjParams.keyGroupedPartitioning.get
+
+          // Re-group the input partitions if we are projecting on a subset of join keys
+          val (groupedPartitions, partExpressions) = spjParams.joinKeyPositions match {
+            case Some(projectPositions) =>
+              val projectedExpressions = projectPositions.map(i => expressions(i))
+              val parts = filteredPartitions.flatten.groupBy(part => {
+                val row = part.asInstanceOf[HasPartitionKey].partitionKey()
+                val projectedRow = KeyGroupedPartitioning.project(
+                  expressions, projectPositions, row)
+                InternalRowComparableWrapper(projectedRow, projectedExpressions)
+              }).map { case (wrapper, splits) => (wrapper.row, splits) }.toSeq
+              (parts, projectedExpressions)
+            case _ =>
+              val groupedParts = filteredPartitions.map(splits => {
+                assert(splits.nonEmpty && splits.head.isInstanceOf[HasPartitionKey])
+                (splits.head.asInstanceOf[HasPartitionKey].partitionKey(), splits)
+              })
+              (groupedParts, expressions)
+          }
 
           // When partially clustered, the input partitions are not grouped by partition
           // values. Here we'll need to check `commonPartitionValues` and decide how to group
@@ -149,12 +169,12 @@ case class BatchScanExec(
             // should contain.
             val commonPartValuesMap = spjParams.commonPartitionValues
                 .get
-                .map(t => (InternalRowComparableWrapper(t._1, p.expressions), t._2))
+                .map(t => (InternalRowComparableWrapper(t._1, partExpressions), t._2))
                 .toMap
             val nestGroupedPartitions = groupedPartitions.map { case (partValue, splits) =>
               // `commonPartValuesMap` should contain the part value since it's the super set.
               val numSplits = commonPartValuesMap
-                  .get(InternalRowComparableWrapper(partValue, p.expressions))
+                  .get(InternalRowComparableWrapper(partValue, partExpressions))
               assert(numSplits.isDefined, s"Partition value $partValue does not exist in " +
                   "common partition values from Spark plan")
 
@@ -169,37 +189,37 @@ case class BatchScanExec(
                 // sides of a join will have the same number of partitions & splits.
                 splits.map(Seq(_)).padTo(numSplits.get, Seq.empty)
               }
-              (InternalRowComparableWrapper(partValue, p.expressions), newSplits)
+              (InternalRowComparableWrapper(partValue, partExpressions), newSplits)
             }
 
             // Now fill missing partition keys with empty partitions
             val partitionMapping = nestGroupedPartitions.toMap
-            finalPartitions = spjParams.commonPartitionValues.get.flatMap {
+            spjParams.commonPartitionValues.get.flatMap {
               case (partValue, numSplits) =>
                 // Use empty partition for those partition values that are not present.
                 partitionMapping.getOrElse(
-                  InternalRowComparableWrapper(partValue, p.expressions),
+                  InternalRowComparableWrapper(partValue, partExpressions),
                   Seq.fill(numSplits)(Seq.empty))
             }
           } else {
             // either `commonPartitionValues` is not defined, or it is defined but
             // `applyPartialClustering` is false.
             val partitionMapping = groupedPartitions.map { case (partValue, splits) =>
-              InternalRowComparableWrapper(partValue, p.expressions) -> splits
+              InternalRowComparableWrapper(partValue, partExpressions) -> splits
             }.toMap
 
             // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there
             // could exist duplicated partition values, as partition grouping is not done
             // at the beginning and postponed to this method. It is important to use unique
             // partition values here so that grouped partitions won't get duplicated.
-            finalPartitions = p.uniquePartitionValues.map { partValue =>
+            p.uniquePartitionValues.map { partValue =>
               // Use empty partition for those partition values that are not present
               partitionMapping.getOrElse(
-                InternalRowComparableWrapper(partValue, p.expressions), Seq.empty)
+                InternalRowComparableWrapper(partValue, partExpressions), Seq.empty)
             }
           }
 
-        case _ =>
+        case _ => filteredPartitions
       }
 
       new DataSourceRDD(
@@ -234,6 +254,7 @@ case class BatchScanExec(
 
 case class StoragePartitionJoinParams(
     keyGroupedPartitioning: Option[Seq[Expression]] = None,
+    joinKeyPositions: Option[Seq[Int]] = None,
     commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
     applyPartialClustering: Boolean = false,
     replicatePartitions: Boolean = false) {
@@ -247,6 +268,7 @@ case class StoragePartitionJoinParams(
   }
 
   override def hashCode(): Int = Objects.hashCode(
+    joinKeyPositions: Option[Seq[Int]],
     commonPartitionValues: Option[Seq[(InternalRow, Int)]],
     applyPartialClustering: java.lang.Boolean,
     replicatePartitions: java.lang.Boolean)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index f8e6fd1d016..8552c950f67 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -380,7 +380,8 @@ case class EnsureRequirements(
     val rightSpec = specs(1)
 
     var isCompatible = false
-    if (!conf.v2BucketingPushPartValuesEnabled) {
+    if (!conf.v2BucketingPushPartValuesEnabled &&
+        !conf.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
       isCompatible = leftSpec.isCompatibleWith(rightSpec)
     } else {
       logInfo("Pushing common partition values for storage-partitioned join")
@@ -505,10 +506,10 @@ case class EnsureRequirements(
         }
 
         // Now we need to push-down the common partition key to the scan in each child
-        newLeft = populatePartitionValues(
-          left, mergedPartValues, applyPartialClustering, replicateLeftSide)
-        newRight = populatePartitionValues(
-          right, mergedPartValues, applyPartialClustering, replicateRightSide)
+        newLeft = populatePartitionValues(left, mergedPartValues, leftSpec.joinKeyPositions,
+          applyPartialClustering, replicateLeftSide)
+        newRight = populatePartitionValues(right, mergedPartValues, rightSpec.joinKeyPositions,
+          applyPartialClustering, replicateRightSide)
       }
     }
 
@@ -530,19 +531,21 @@ case class EnsureRequirements(
   private def populatePartitionValues(
       plan: SparkPlan,
       values: Seq[(InternalRow, Int)],
+      joinKeyPositions: Option[Seq[Int]],
       applyPartialClustering: Boolean,
       replicatePartitions: Boolean): SparkPlan = plan match {
     case scan: BatchScanExec =>
       scan.copy(
         spjParams = scan.spjParams.copy(
           commonPartitionValues = Some(values),
+          joinKeyPositions = joinKeyPositions,
           applyPartialClustering = applyPartialClustering,
           replicatePartitions = replicatePartitions
         )
       )
     case node =>
       node.mapChildren(child => populatePartitionValues(
-        child, values, applyPartialClustering, replicatePartitions))
+        child, values, joinKeyPositions, applyPartialClustering, replicatePartitions))
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index b22aba61aab..ffd1c8e31e9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -98,14 +98,17 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
     val catalystDistribution = physical.ClusteredDistribution(
       Seq(TransformExpression(YearsFunction, Seq(attr("ts")))))
     val partitionValues = Seq(50, 51, 52).map(v => InternalRow.fromSeq(Seq(v)))
+    val projectedPositions = catalystDistribution.clustering.indices
 
     checkQueryPlan(df, catalystDistribution,
-      physical.KeyGroupedPartitioning(catalystDistribution.clustering, partitionValues))
+      physical.KeyGroupedPartitioning(catalystDistribution.clustering, projectedPositions,
+        partitionValues, partitionValues))
 
     // multiple group keys should work too as long as partition keys are subset of them
     df = sql(s"SELECT count(*) FROM testcat.ns.$table GROUP BY id, ts")
     checkQueryPlan(df, catalystDistribution,
-      physical.KeyGroupedPartitioning(catalystDistribution.clustering, partitionValues))
+      physical.KeyGroupedPartitioning(catalystDistribution.clustering, projectedPositions,
+        partitionValues, partitionValues))
   }
 
   test("non-clustered distribution: no partition") {
@@ -1276,4 +1279,262 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
       }
     }
   }
+
+  test("SPARK-44647: test join key is subset of cluster key " +
+      "with push values and partially-clustered") {
+    val table1 = "tab1e1"
+    val table2 = "table2"
+    val partition = Array(identity("id"), identity("data"))
+    createTable(table1, schema, partition)
+    sql(s"INSERT INTO testcat.ns.$table1 VALUES " +
+        "(1, 'aa', cast('2020-01-01' as timestamp)), " +
+        "(2, 'bb', cast('2020-01-01' as timestamp)), " +
+        "(2, 'cc', cast('2020-01-01' as timestamp)), " +
+        "(3, 'dd', cast('2020-01-01' as timestamp)), " +
+        "(3, 'dd', cast('2020-01-01' as timestamp)), " +
+        "(3, 'ee', cast('2020-01-01' as timestamp)), " +
+        "(3, 'ee', cast('2020-01-01' as timestamp))")
+
+    createTable(table2, schema, partition)
+    sql(s"INSERT INTO testcat.ns.$table2 VALUES " +
+        "(4, 'zz', cast('2020-01-01' as timestamp)), " +
+        "(4, 'zz', cast('2020-01-01' as timestamp)), " +
+        "(3, 'yy', cast('2020-01-01' as timestamp)), " +
+        "(3, 'yy', cast('2020-01-01' as timestamp)), " +
+        "(3, 'xx', cast('2020-01-01' as timestamp)), " +
+        "(3, 'xx', cast('2020-01-01' as timestamp)), " +
+        "(2, 'ww', cast('2020-01-01' as timestamp))")
+
+    Seq(true, false).foreach { pushDownValues =>
+      Seq(true, false).foreach { partiallyClustered =>
+        Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys =>
+
+          withSQLConf(
+            SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+            SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString,
+            SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key ->
+                partiallyClustered.toString,
+            SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key ->
+                allowJoinKeysSubsetOfPartitionKeys.toString) {
+
+            val df = sql("SELECT t1.id AS id, t1.data AS t1data, t2.data AS t2data " +
+                s"FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 " +
+                "ON t1.id = t2.id ORDER BY t1.id, t1data, t2data")
+
+            val shuffles = collectShuffles(df.queryExecution.executedPlan)
+            if (allowJoinKeysSubsetOfPartitionKeys) {
+              assert(shuffles.isEmpty, "SPJ should be triggered")
+            } else {
+              assert(shuffles.nonEmpty, "SPJ should not be triggered")
+            }
+
+            val scans = collectScans(df.queryExecution.executedPlan)
+                .map(_.inputRDD.partitions.length)
+
+            (allowJoinKeysSubsetOfPartitionKeys, partiallyClustered) match {
+              // SPJ and partially-clustered
+              case (true, true) => assert(scans == Seq(8, 8))
+              // SPJ and not partially-clustered
+              case (true, false) => assert(scans == Seq(4, 4))
+              // No SPJ
+              case _ => assert(scans == Seq(5, 4))
+            }
+
+            checkAnswer(df, Seq(
+              Row(2, "bb", "ww"),
+              Row(2, "cc", "ww"),
+              Row(3, "dd", "xx"),
+              Row(3, "dd", "xx"),
+              Row(3, "dd", "xx"),
+              Row(3, "dd", "xx"),
+              Row(3, "dd", "yy"),
+              Row(3, "dd", "yy"),
+              Row(3, "dd", "yy"),
+              Row(3, "dd", "yy"),
+              Row(3, "ee", "xx"),
+              Row(3, "ee", "xx"),
+              Row(3, "ee", "xx"),
+              Row(3, "ee", "xx"),
+              Row(3, "ee", "yy"),
+              Row(3, "ee", "yy"),
+              Row(3, "ee", "yy"),
+              Row(3, "ee", "yy")
+            ))
+          }
+        }
+      }
+    }
+  }
+
+  test("SPARK-44647: test join key is the second cluster key") {
+    val table1 = "tab1e1"
+    val table2 = "table2"
+    val partition = Array(identity("id"), identity("data"))
+    createTable(table1, schema, partition)
+    sql(s"INSERT INTO testcat.ns.$table1 VALUES " +
+        "(1, 'aa', cast('2020-01-01' as timestamp)), " +
+        "(2, 'bb', cast('2020-01-02' as timestamp)), " +
+        "(3, 'cc', cast('2020-01-03' as timestamp))")
+
+    createTable(table2, schema, partition)
+    sql(s"INSERT INTO testcat.ns.$table2 VALUES " +
+        "(4, 'aa', cast('2020-01-01' as timestamp)), " +
+        "(5, 'bb', cast('2020-01-02' as timestamp)), " +
+        "(6, 'cc', cast('2020-01-03' as timestamp))")
+
+    Seq(true, false).foreach { pushDownValues =>
+      Seq(true, false).foreach { partiallyClustered =>
+        Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys =>
+          withSQLConf(
+            SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+            SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key ->
+                pushDownValues.toString,
+            SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key ->
+                partiallyClustered.toString,
+            SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key ->
+                allowJoinKeysSubsetOfPartitionKeys.toString) {
+
+            val df = sql("SELECT t1.id AS t1id, t2.id as t2id, t1.data AS data " +
+                s"FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 " +
+                "ON t1.data = t2.data ORDER BY t1id, t1id, data")
+
+            checkAnswer(df, Seq(Row(1, 4, "aa"), Row(2, 5, "bb"), Row(3, 6, "cc")))
+
+            val shuffles = collectShuffles(df.queryExecution.executedPlan)
+            if (allowJoinKeysSubsetOfPartitionKeys) {
+              assert(shuffles.isEmpty, "SPJ should be triggered")
+            } else {
+              assert(shuffles.nonEmpty, "SPJ should not be triggered")
+            }
+
+            val scans = collectScans(df.queryExecution.executedPlan)
+                .map(_.inputRDD.partitions.length)
+            (pushDownValues, allowJoinKeysSubsetOfPartitionKeys, partiallyClustered) match {
+              // SPJ and partially-clustered
+              case (true, true, true) => assert(scans == Seq(3, 3))
+              // non-SPJ or SPJ/partially-clustered
+              case _ => assert(scans == Seq(3, 3))
+            }
+          }
+        }
+      }
+    }
+  }
+
+  test("SPARK-44647: test join key is the second partition key and a transform") {
+    val items_partitions = Array(bucket(8, "id"), days("arrive_time"))
+    createTable(items, items_schema, items_partitions)
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+        s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+        s"(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " +
+        s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+        s"(2, 'bb', 10.5, cast('2020-01-01' as timestamp)), " +
+        s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+    val purchases_partitions = Array(bucket(8, "item_id"), days("time"))
+    createTable(purchases, purchases_schema, purchases_partitions)
+    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+        s"(1, 42.0, cast('2020-01-01' as timestamp)), " +
+        s"(1, 44.0, cast('2020-01-15' as timestamp)), " +
+        s"(1, 45.0, cast('2020-01-15' as timestamp)), " +
+        s"(2, 11.0, cast('2020-01-01' as timestamp)), " +
+        s"(3, 19.5, cast('2020-02-01' as timestamp))")
+
+    Seq(true, false).foreach { pushDownValues =>
+      Seq(true, false).foreach { partiallyClustered =>
+        Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys =>
+
+          withSQLConf(
+            SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+            SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString,
+            SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key ->
+                partiallyClustered.toString,
+            SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key ->
+                allowJoinKeysSubsetOfPartitionKeys.toString) {
+            val df = sql("SELECT id, name, i.price as purchase_price, " +
+                "p.item_id, p.price as sale_price " +
+                s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
+                "ON i.arrive_time = p.time " +
+                "ORDER BY id, purchase_price, p.item_id, sale_price")
+
+            // Currently SPJ for case where join key not same as partition key
+            // only supported when push-part-values enabled
+            val shuffles = collectShuffles(df.queryExecution.executedPlan)
+            if (allowJoinKeysSubsetOfPartitionKeys) {
+              assert(shuffles.isEmpty, "SPJ should be triggered")
+            } else {
+              assert(shuffles.nonEmpty, "SPJ should not be triggered")
+            }
+
+            val scans = collectScans(df.queryExecution.executedPlan)
+                .map(_.inputRDD.partitions.length)
+            (allowJoinKeysSubsetOfPartitionKeys, partiallyClustered) match {
+              // SPJ and partially-clustered
+              case (true, true) => assert(scans == Seq(5, 5))
+              // SPJ and not partially-clustered
+              case (true, false) => assert(scans == Seq(3, 3))
+              // No SPJ
+              case _ => assert(scans == Seq(4, 4))
+            }
+
+            checkAnswer(df,
+              Seq(
+                Row(1, "aa", 40.0, 1, 42.0),
+                Row(1, "aa", 40.0, 2, 11.0),
+                Row(1, "aa", 41.0, 1, 44.0),
+                Row(1, "aa", 41.0, 1, 45.0),
+                Row(2, "bb", 10.0, 1, 42.0),
+                Row(2, "bb", 10.0, 2, 11.0),
+                Row(2, "bb", 10.5, 1, 42.0),
+                Row(2, "bb", 10.5, 2, 11.0),
+                Row(3, "cc", 15.5, 3, 19.5)
+              )
+            )
+          }
+        }
+      }
+    }
+  }
+
+  test("SPARK-44647: shuffle one side and join keys are less than partition keys") {
+    val items_partitions = Array(identity("id"), identity("name"))
+    createTable(items, items_schema, items_partitions)
+
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+      "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+      "(1, 'aa', 30.0, cast('2020-01-02' as timestamp)), " +
+      "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+      "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+    createTable(purchases, purchases_schema, Array.empty)
+    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+      "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+      "(1, 89.0, cast('2020-01-03' as timestamp)), " +
+      "(3, 19.5, cast('2020-02-01' as timestamp)), " +
+      "(5, 26.0, cast('2023-01-01' as timestamp)), " +
+      "(6, 50.0, cast('2023-02-01' as timestamp))")
+
+    Seq(true, false).foreach { pushdownValues =>
+      Seq(true, false).foreach { partiallyClustered =>
+        withSQLConf(
+          SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
+          SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString,
+          SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key
+            -> partiallyClustered.toString,
+          SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {
+          val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " +
+            s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
+            "ON i.id = p.item_id ORDER BY id, purchase_price, sale_price")
+
+          val shuffles = collectShuffles(df.queryExecution.executedPlan)
+          assert(shuffles.size == 1, "SPJ should be triggered")
+          checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0),
+            Row(1, "aa", 30.0, 89.0),
+            Row(1, "aa", 40.0, 42.0),
+            Row(1, "aa", 40.0, 89.0),
+            Row(3, "bb", 10.0, 19.5)))
+        }
+      }
+    }
+  }
 }

