This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9e2aafb1373 [SPARK-45036][SQL] SPJ: Simplify the logic to handle 
partially clustered distribution
9e2aafb1373 is described below

commit 9e2aafb13739f9c07f8218cd325c5532063b1a51
Author: Chao Sun <sunc...@apple.com>
AuthorDate: Mon Sep 4 14:05:14 2023 -0700

    [SPARK-45036][SQL] SPJ: Simplify the logic to handle partially clustered 
distribution
    
    ### What changes were proposed in this pull request?
    
    In SPJ, currently the logic to handle partially clustered distribution is a 
bit complicated. For instance, when the feature is enabled (by enabling both 
`conf.v2BucketingPushPartValuesEnabled` and 
`conf.v2BucketingPartiallyClusteredDistributionEnabled`), Spark should postpone 
the combining of input splits until it is about to create an input RDD in 
`BatchScanExec`. To implement this, `groupPartitions` in 
`DataSourceV2ScanExecBase` currently takes the flag as input and has two 
differen [...]
    
    This PR introduces a new field in `KeyGroupedPartitioning`, named 
`originalPartitionValues`, that is used to store the original partition values 
from the input before split combining has been applied. The field is used when 
partially clustered distribution is enabled. With this, `groupPartitions` 
becomes easier to understand.
    
    In addition, this simplifies `BatchScanExec.inputRDD` by combining two 
branches where partially clustered distribution is not enabled.
    
    ### Why are the changes needed?
    
    To simplify the current logic in SPJ with respect to partially clustered 
distribution.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Closes #42757 from sunchao/SPARK-45036.
    
    Authored-by: Chao Sun <sunc...@apple.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../sql/catalyst/plans/physical/partitioning.scala |  35 +++---
 .../execution/datasources/v2/BatchScanExec.scala   | 117 +++++++++------------
 .../datasources/v2/DataSourceV2ScanExecBase.scala  |  65 +++++++-----
 .../execution/exchange/EnsureRequirements.scala    |   9 +-
 .../execution/exchange/ShuffleExchangeExec.scala   |   4 +-
 .../DistributionAndOrderingSuiteBase.scala         |   6 +-
 .../connector/KeyGroupedPartitioningSuite.scala    |   2 +-
 .../exchange/EnsureRequirementsSuite.scala         |   2 +-
 8 files changed, 122 insertions(+), 118 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index ce557422a08..0be4a61f275 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -312,26 +312,37 @@ case class HashPartitioning(expressions: Seq[Expression], 
numPartitions: Int)
  * Represents a partitioning where rows are split across partitions based on 
transforms defined
  * by `expressions`. `partitionValues`, if defined, should contain value of 
partition key(s) in
  * ascending order, after evaluated by the transforms in `expressions`, for 
each input partition.
- * In addition, its length must be the same as the number of input partitions 
(and thus is a 1-1
- * mapping). The `partitionValues` may contain duplicated partition values.
+ * In addition, its length must be the same as the number of Spark partitions 
(and thus is a 1-1
+ * mapping), and each row in `partitionValues` must be unique.
  *
- * For example, if `expressions` is `[years(ts_col)]`, then a valid value of 
`partitionValues` is
- * `[0, 1, 2]`, which represents 3 input partitions with distinct partition 
values. All rows
- * in each partition have the same value for column `ts_col` (which is of 
timestamp type), after
- * being applied by the `years` transform.
+ * The `originalPartitionValues`, on the other hand, are partition values from 
the original input
+ * splits returned by data sources. It may contain duplicated values.
  *
- * On the other hand, `[0, 0, 1]` is not a valid value for `partitionValues` 
since `0` is
- * duplicated twice.
+ * For example, if a data source reports partition transform expressions 
`[years(ts_col)]` with 4
+ * input splits whose corresponding partition values are `[0, 1, 2, 2]`, then 
the `expressions`
+ * in this case is `[years(ts_col)]`, while `partitionValues` is `[0, 1, 2]`, 
which
+ * represents 3 input partitions with distinct partition values. All rows in 
each partition have
+ * the same value for column `ts_col` (which is of timestamp type), after 
being applied by the
+ * `years` transform. This is generated after combining the two splits with 
partition value `2`
+ * into a single Spark partition.
+ *
+ * On the other hand, in this example `[0, 1, 2, 2]` is the value of 
`originalPartitionValues`
+ * which is calculated from the original input splits.
  *
  * @param expressions partition expressions for the partitioning.
  * @param numPartitions the number of partitions
- * @param partitionValues the values for the cluster keys of the distribution, 
must be
- *                        in ascending order.
+ * @param partitionValues the values for the final cluster keys (that is, 
after applying grouping
+ *                        on the input splits according to `expressions`) of 
the distribution,
+ *                        must be in ascending order, and must NOT contain 
duplicated values.
+ * @param originalPartitionValues the original input partition values before 
any grouping has been
+ *                                applied, must be in ascending order, and may 
contain duplicated
+ *                                values
  */
 case class KeyGroupedPartitioning(
     expressions: Seq[Expression],
     numPartitions: Int,
-    partitionValues: Seq[InternalRow] = Seq.empty) extends Partitioning {
+    partitionValues: Seq[InternalRow] = Seq.empty,
+    originalPartitionValues: Seq[InternalRow] = Seq.empty) extends 
Partitioning {
 
   override def satisfies0(required: Distribution): Boolean = {
     super.satisfies0(required) || {
@@ -368,7 +379,7 @@ object KeyGroupedPartitioning {
   def apply(
       expressions: Seq[Expression],
       partitionValues: Seq[InternalRow]): KeyGroupedPartitioning = {
-    KeyGroupedPartitioning(expressions, partitionValues.size, partitionValues)
+    KeyGroupedPartitioning(expressions, partitionValues.size, partitionValues, 
partitionValues)
   }
 
   def supportsExpressions(expressions: Seq[Expression]): Boolean = {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
index cc674961f8e..932ac0f5a1b 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
@@ -28,7 +28,6 @@ import 
org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Par
 import org.apache.spark.sql.catalyst.util.{truncatedString, 
InternalRowComparableWrapper}
 import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.connector.read._
-import org.apache.spark.sql.internal.SQLConf
 
 /**
  * Physical plan node for scanning a batch of data from a data source v2.
@@ -101,7 +100,7 @@ case class BatchScanExec(
                 "partition values that are not present in the original 
partitioning.")
           }
 
-          groupPartitions(newPartitions).get.map(_._2)
+          groupPartitions(newPartitions).get.groupedParts.map(_.parts)
 
         case _ =>
           // no validation is needed as the data source did not report any 
specific partitioning
@@ -137,81 +136,63 @@ case class BatchScanExec(
 
       outputPartitioning match {
         case p: KeyGroupedPartitioning =>
-          if (conf.v2BucketingPushPartValuesEnabled &&
-              conf.v2BucketingPartiallyClusteredDistributionEnabled) {
-            assert(filteredPartitions.forall(_.size == 1),
-              "Expect partitions to be not grouped when " +
-                  
s"${SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " +
-                  "is enabled")
-
-            val groupedPartitions = 
groupPartitions(finalPartitions.map(_.head),
-              groupSplits = true).get
-
-            // This means the input partitions are not grouped by partition 
values. We'll need to
-            // check `groupByPartitionValues` and decide whether to group and 
replicate splits
-            // within a partition.
-            if (spjParams.commonPartitionValues.isDefined &&
-              spjParams.applyPartialClustering) {
-              // A mapping from the common partition values to how many splits 
the partition
-              // should contain.
-              val commonPartValuesMap = spjParams.commonPartitionValues
+          val groupedPartitions = filteredPartitions.map(splits => {
+            assert(splits.nonEmpty && 
splits.head.isInstanceOf[HasPartitionKey])
+            (splits.head.asInstanceOf[HasPartitionKey].partitionKey(), splits)
+          })
+
+          // When partially clustered, the input partitions are not grouped by 
partition
+          // values. Here we'll need to check `commonPartitionValues` and 
decide how to group
+          // and replicate splits within a partition.
+          if (spjParams.commonPartitionValues.isDefined && 
spjParams.applyPartialClustering) {
+            // A mapping from the common partition values to how many splits 
the partition
+            // should contain.
+            val commonPartValuesMap = spjParams.commonPartitionValues
                 .get
                 .map(t => (InternalRowComparableWrapper(t._1, p.expressions), 
t._2))
                 .toMap
-              val nestGroupedPartitions = groupedPartitions.map {
-                case (partValue, splits) =>
-                  // `commonPartValuesMap` should contain the part value since 
it's the super set.
-                  val numSplits = commonPartValuesMap
-                    .get(InternalRowComparableWrapper(partValue, 
p.expressions))
-                  assert(numSplits.isDefined, s"Partition value $partValue 
does not exist in " +
-                      "common partition values from Spark plan")
-
-                  val newSplits = if (spjParams.replicatePartitions) {
-                    // We need to also replicate partitions according to the 
other side of join
-                    Seq.fill(numSplits.get)(splits)
-                  } else {
-                    // Not grouping by partition values: this could be the 
side with partially
-                    // clustered distribution. Because of dynamic filtering, 
we'll need to check if
-                    // the final number of splits of a partition is smaller 
than the original
-                    // number, and fill with empty splits if so. This is 
necessary so that both
-                    // sides of a join will have the same number of partitions 
& splits.
-                    splits.map(Seq(_)).padTo(numSplits.get, Seq.empty)
-                  }
-                  (InternalRowComparableWrapper(partValue, p.expressions), 
newSplits)
+            val nestGroupedPartitions = groupedPartitions.map { case 
(partValue, splits) =>
+              // `commonPartValuesMap` should contain the part value since 
it's the super set.
+              val numSplits = commonPartValuesMap
+                  .get(InternalRowComparableWrapper(partValue, p.expressions))
+              assert(numSplits.isDefined, s"Partition value $partValue does 
not exist in " +
+                  "common partition values from Spark plan")
+
+              val newSplits = if (spjParams.replicatePartitions) {
+                // We need to also replicate partitions according to the other 
side of join
+                Seq.fill(numSplits.get)(splits)
+              } else {
+                // Not grouping by partition values: this could be the side 
with partially
+                // clustered distribution. Because of dynamic filtering, we'll 
need to check if
+                // the final number of splits of a partition is smaller than 
the original
+                // number, and fill with empty splits if so. This is necessary 
so that both
+                // sides of a join will have the same number of partitions & 
splits.
+                splits.map(Seq(_)).padTo(numSplits.get, Seq.empty)
               }
+              (InternalRowComparableWrapper(partValue, p.expressions), 
newSplits)
+            }
 
-              // Now fill missing partition keys with empty partitions
-              val partitionMapping = nestGroupedPartitions.toMap
-              finalPartitions = spjParams.commonPartitionValues.get.flatMap {
-                case (partValue, numSplits) =>
-                  // Use empty partition for those partition values that are 
not present.
-                  partitionMapping.getOrElse(
-                    InternalRowComparableWrapper(partValue, p.expressions),
-                    Seq.fill(numSplits)(Seq.empty))
-              }
-            } else {
-              // either `commonPartitionValues` is not defined, or it is 
defined but
-              // `applyPartialClustering` is false.
-              val partitionMapping = groupedPartitions.map { case (row, parts) 
=>
-                InternalRowComparableWrapper(row, p.expressions) -> parts
-              }.toMap
-
-              // In case `commonPartitionValues` is not defined (e.g., SPJ is 
not used), there
-              // could exist duplicated partition values, as partition 
grouping is not done
-              // at the beginning and postponed to this method. It is 
important to use unique
-              // partition values here so that grouped partitions won't get 
duplicated.
-              finalPartitions = p.uniquePartitionValues.map { partValue =>
-                // Use empty partition for those partition values that are not 
present
+            // Now fill missing partition keys with empty partitions
+            val partitionMapping = nestGroupedPartitions.toMap
+            finalPartitions = spjParams.commonPartitionValues.get.flatMap {
+              case (partValue, numSplits) =>
+                // Use empty partition for those partition values that are not 
present.
                 partitionMapping.getOrElse(
-                  InternalRowComparableWrapper(partValue, p.expressions), 
Seq.empty)
-              }
+                  InternalRowComparableWrapper(partValue, p.expressions),
+                  Seq.fill(numSplits)(Seq.empty))
             }
           } else {
-            val partitionMapping = finalPartitions.map { parts =>
-              val row = parts.head.asInstanceOf[HasPartitionKey].partitionKey()
-              InternalRowComparableWrapper(row, p.expressions) -> parts
+            // either `commonPartitionValues` is not defined, or it is defined 
but
+            // `applyPartialClustering` is false.
+            val partitionMapping = groupedPartitions.map { case (partValue, 
splits) =>
+              InternalRowComparableWrapper(partValue, p.expressions) -> splits
             }.toMap
-            finalPartitions = p.partitionValues.map { partValue =>
+
+            // In case `commonPartitionValues` is not defined (e.g., SPJ is 
not used), there
+            // could exist duplicated partition values, as partition grouping 
is not done
+            // at the beginning and postponed to this method. It is important 
to use unique
+            // partition values here so that grouped partitions won't get 
duplicated.
+            finalPartitions = p.uniquePartitionValues.map { partValue =>
               // Use empty partition for those partition values that are not 
present
               partitionMapping.getOrElse(
                 InternalRowComparableWrapper(partValue, p.expressions), 
Seq.empty)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala
index f688d3514d9..94667fbd00c 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala
@@ -62,8 +62,9 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
     redact(result)
   }
 
-  def partitions: Seq[Seq[InputPartition]] =
-    groupedPartitions.map(_.map(_._2)).getOrElse(inputPartitions.map(Seq(_)))
+  def partitions: Seq[Seq[InputPartition]] = {
+    
groupedPartitions.map(_.groupedParts.map(_.parts)).getOrElse(inputPartitions.map(Seq(_)))
+  }
 
   /**
    * Shorthand for calling redact() without specifying redacting rules
@@ -94,8 +95,10 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
     keyGroupedPartitioning match {
       case Some(exprs) if KeyGroupedPartitioning.supportsExpressions(exprs) =>
         groupedPartitions
-          .map { partitionValues =>
-            KeyGroupedPartitioning(exprs, partitionValues.size, 
partitionValues.map(_._1))
+          .map { keyGroupedPartsInfo =>
+            val keyGroupedParts = keyGroupedPartsInfo.groupedParts
+            KeyGroupedPartitioning(exprs, keyGroupedParts.size, 
keyGroupedParts.map(_.value),
+              keyGroupedPartsInfo.originalParts.map(_.partitionKey()))
           }
           .getOrElse(super.outputPartitioning)
       case _ =>
@@ -103,7 +106,7 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
     }
   }
 
-  @transient lazy val groupedPartitions: Option[Seq[(InternalRow, 
Seq[InputPartition])]] = {
+  @transient lazy val groupedPartitions: Option[KeyGroupedPartitionInfo] = {
     // Early check if we actually need to materialize the input partitions.
     keyGroupedPartitioning match {
       case Some(_) => groupPartitions(inputPartitions)
@@ -117,24 +120,21 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
    *   - all input partitions implement [[HasPartitionKey]]
    *   - `keyGroupedPartitioning` is set
    *
-   * The result, if defined, is a list of tuples where the first element is a 
partition value,
-   * and the second element is a list of input partitions that share the same 
partition value.
+   * The result, if defined, is a [[KeyGroupedPartitionInfo]] which contains a 
list of
+   * [[KeyGroupedPartition]], as well as a list of partition values from the 
original input splits,
+   * sorted according to the partition keys in ascending order.
    *
    * A non-empty result means each partition is clustered on a single key and 
therefore eligible
    * for further optimizations to eliminate shuffling in some operations such 
as join and aggregate.
    */
-  def groupPartitions(
-      inputPartitions: Seq[InputPartition],
-      groupSplits: Boolean = !conf.v2BucketingPushPartValuesEnabled ||
-          !conf.v2BucketingPartiallyClusteredDistributionEnabled):
-    Option[Seq[(InternalRow, Seq[InputPartition])]] = {
-
+  def groupPartitions(inputPartitions: Seq[InputPartition]): 
Option[KeyGroupedPartitionInfo] = {
     if (!SQLConf.get.v2BucketingEnabled) return None
+
     keyGroupedPartitioning.flatMap { expressions =>
       val results = inputPartitions.takeWhile {
         case _: HasPartitionKey => true
         case _ => false
-      }.map(p => (p.asInstanceOf[HasPartitionKey].partitionKey(), p))
+      }.map(p => (p.asInstanceOf[HasPartitionKey].partitionKey(), 
p.asInstanceOf[HasPartitionKey]))
 
       if (results.length != inputPartitions.length || inputPartitions.isEmpty) 
{
         // Not all of the `InputPartitions` implements `HasPartitionKey`, 
therefore skip here.
@@ -143,32 +143,25 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
         // also sort the input partitions according to their partition key 
order. This ensures
         // a canonical order from both sides of a bucketed join, for example.
         val partitionDataTypes = expressions.map(_.dataType)
-        val partitionOrdering: Ordering[(InternalRow, Seq[InputPartition])] = {
+        val partitionOrdering: Ordering[(InternalRow, InputPartition)] = {
           
RowOrdering.createNaturalAscendingOrdering(partitionDataTypes).on(_._1)
         }
-
-        val partitions = if (groupSplits) {
-          // Group the splits by their partition value
-          results
+        val sortedKeyToPartitions = results.sorted(partitionOrdering)
+        val groupedPartitions = sortedKeyToPartitions
             .map(t => (InternalRowComparableWrapper(t._1, expressions), t._2))
             .groupBy(_._1)
             .toSeq
-            .map {
-              case (key, s) => (key.row, s.map(_._2))
-            }
-        } else {
-          // No splits grouping, each split will become a separate Spark 
partition
-          results.map(t => (t._1, Seq(t._2)))
-        }
+            .map { case (key, s) => KeyGroupedPartition(key.row, s.map(_._2)) }
 
-        Some(partitions.sorted(partitionOrdering))
+        Some(KeyGroupedPartitionInfo(groupedPartitions, 
sortedKeyToPartitions.map(_._2)))
       }
     }
   }
 
   override def outputOrdering: Seq[SortOrder] = {
     // when multiple partitions are grouped together, ordering inside 
partitions is not preserved
-    val partitioningPreservesOrdering = 
groupedPartitions.forall(_.forall(_._2.length <= 1))
+    val partitioningPreservesOrdering = groupedPartitions
+        .forall(_.groupedParts.forall(_.parts.length <= 1))
     ordering.filter(_ => 
partitioningPreservesOrdering).getOrElse(super.outputOrdering)
   }
 
@@ -217,3 +210,19 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
     }
   }
 }
+
+/**
+ * A key-grouped Spark partition, which could consist of multiple input splits
+ *
+ * @param value the partition value shared by all the input splits
+ * @param parts the input splits that are grouped into a single Spark partition
+ */
+private[v2] case class KeyGroupedPartition(value: InternalRow, parts: 
Seq[InputPartition])
+
+/**
+ * Information about key-grouped partitions, which contains a list of grouped 
partitions as well
+ * as the original input partitions before the grouping.
+ */
+private[v2] case class KeyGroupedPartitionInfo(
+    groupedParts: Seq[KeyGroupedPartition],
+    originalParts: Seq[HasPartitionKey])
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 42c880e7c62..f8e6fd1d016 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -288,12 +288,12 @@ case class EnsureRequirements(
         reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, 
rightExpressions, rightKeys)
           .orElse(reorderJoinKeysRecursively(
             leftKeys, rightKeys, leftPartitioning, None))
-      case (Some(KeyGroupedPartitioning(clustering, _, _)), _) =>
+      case (Some(KeyGroupedPartitioning(clustering, _, _, _)), _) =>
         val leafExprs = clustering.flatMap(_.collectLeaves())
         reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, 
leftKeys)
             .orElse(reorderJoinKeysRecursively(
               leftKeys, rightKeys, None, rightPartitioning))
-      case (_, Some(KeyGroupedPartitioning(clustering, _, _))) =>
+      case (_, Some(KeyGroupedPartitioning(clustering, _, _, _))) =>
         val leafExprs = clustering.flatMap(_.collectLeaves())
         reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, 
rightKeys)
             .orElse(reorderJoinKeysRecursively(
@@ -483,7 +483,10 @@ case class EnsureRequirements(
                   s"'$joinType'. Skipping partially clustered distribution.")
               replicateRightSide = false
             } else {
-              val partValues = if (replicateLeftSide) rightPartValues else 
leftPartValues
+              // In partially clustered distribution, we should use un-grouped 
partition values
+              val spec = if (replicateLeftSide) rightSpec else leftSpec
+              val partValues = spec.partitioning.originalPartitionValues
+
               val numExpectedPartitions = partValues
                 .map(InternalRowComparableWrapper(_, partitionExprs))
                 .groupBy(identity)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index 750b96dc83d..509f1e6a1e4 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -301,7 +301,7 @@ object ShuffleExchangeExec {
           ascending = true,
           samplePointsPerPartitionHint = 
SQLConf.get.rangeExchangeSampleSizePerPartition)
       case SinglePartition => new ConstantPartitioner
-      case k @ KeyGroupedPartitioning(expressions, n, _) =>
+      case k @ KeyGroupedPartitioning(expressions, n, _, _) =>
         val valueMap = k.uniquePartitionValues.zipWithIndex.map {
           case (partition, index) => 
(partition.toSeq(expressions.map(_.dataType)), index)
         }.toMap
@@ -332,7 +332,7 @@ object ShuffleExchangeExec {
         val projection = 
UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
         row => projection(row)
       case SinglePartition => identity
-      case KeyGroupedPartitioning(expressions, _, _) =>
+      case KeyGroupedPartitioning(expressions, _, _, _) =>
         row => bindReferences(expressions, outputAttributes).map(_.eval(row))
       case _ => throw new IllegalStateException(s"Exchange not implemented for 
$newPartitioning")
     }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala
index f4317e63276..1a0efa7c4aa 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala
@@ -51,9 +51,9 @@ abstract class DistributionAndOrderingSuiteBase
       plan: QueryPlan[T]): Partitioning = partitioning match {
     case HashPartitioning(exprs, numPartitions) =>
       HashPartitioning(exprs.map(resolveAttrs(_, plan)), numPartitions)
-    case KeyGroupedPartitioning(clustering, numPartitions, partitionValues) =>
-      KeyGroupedPartitioning(clustering.map(resolveAttrs(_, plan)), 
numPartitions,
-        partitionValues)
+    case KeyGroupedPartitioning(clustering, numPartitions, partValues, 
originalPartValues) =>
+      KeyGroupedPartitioning(clustering.map(resolveAttrs(_, plan)), 
numPartitions, partValues,
+        originalPartValues)
     case PartitioningCollection(partitionings) =>
       PartitioningCollection(partitionings.map(resolvePartitioning(_, plan)))
     case RangePartitioning(ordering, numPartitions) =>
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index 5b5e4021173..b22aba61aab 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -131,7 +131,7 @@ class KeyGroupedPartitioningSuite extends 
DistributionAndOrderingSuiteBase {
     // Has exactly one partition.
     val partitionValues = Seq(31).map(v => InternalRow.fromSeq(Seq(v)))
     checkQueryPlan(df, distribution,
-      physical.KeyGroupedPartitioning(distribution.clustering, 1, 
partitionValues))
+      physical.KeyGroupedPartitioning(distribution.clustering, 1, 
partitionValues, partitionValues))
   }
 
   test("non-clustered distribution: no V2 catalog") {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
index 3c9b92e5f66..3b0bb088a10 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
@@ -1127,7 +1127,7 @@ class EnsureRequirementsSuite extends SharedSparkSession {
       EnsureRequirements.apply(smjExec) match {
         case ShuffledHashJoinExec(_, _, _, _, _,
         DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _),
-        ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv),
+        ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv, _),
         DummySparkPlan(_, _, SinglePartition, _, _), _, _), _) =>
           assert(left.expressions == a1 :: Nil)
           assert(attrs == a1 :: Nil)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to