This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new f789c21  fix: CometExec's outputPartitioning might not be same as 
Spark expects after AQE interferes (#299)
f789c21 is described below

commit f789c21736819c0ffa5ba56aaa2c5ec4bcb7127a
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Tue Apr 23 09:04:14 2024 -0700

    fix: CometExec's outputPartitioning might not be same as Spark expects 
after AQE interferes (#299)
    
    * fix: CometExec's outputPartitioning might not be same as Spark expects 
after AQE interferes
    
    * Add compatibility with Spark 3.2 and 3.3
    
    * Remove unused import
---
 .../shims/ShimCometBroadcastHashJoinExec.scala     |  39 +++++
 .../org/apache/spark/sql/comet/operators.scala     | 157 ++++++++++++++++++++-
 .../comet/plans/AliasAwareOutputExpression.scala   | 150 ++++++++++++++++++++
 .../PartitioningPreservingUnaryExecNode.scala      |  76 ++++++++++
 .../org/apache/comet/exec/CometExecSuite.scala     |  25 +++-
 5 files changed, 440 insertions(+), 7 deletions(-)

diff --git 
a/spark/src/main/scala/org/apache/comet/shims/ShimCometBroadcastHashJoinExec.scala
 
b/spark/src/main/scala/org/apache/comet/shims/ShimCometBroadcastHashJoinExec.scala
new file mode 100644
index 0000000..eef0ee9
--- /dev/null
+++ 
b/spark/src/main/scala/org/apache/comet/shims/ShimCometBroadcastHashJoinExec.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+
+trait ShimCometBroadcastHashJoinExec {
+
+  /**
+   * Returns the expressions that are used for hash partitioning, covering both
+   * `HashPartitioning` and `CoalescedHashPartitioning`. They share the same trait
+   * `HashPartitioningLike` since Spark 3.4, but Spark 3.2/3.3 have neither
+   * `HashPartitioningLike` nor `CoalescedHashPartitioning`.
+   *
+   * The lookup is reflective: every method named `expressions` that is declared directly on
+   * the runtime class of `partitioning` is invoked and the results are concatenated. A
+   * partitioning whose class declares no such method yields an empty sequence.
+   *
+   * @param partitioning
+   *   the partitioning whose `expressions` should be read
+   * @return
+   *   the expressions returned by the `expressions` method(s) declared on the class of
+   *   `partitioning`
+   *
+   * TODO: remove after dropping Spark 3.2 and 3.3 support.
+   */
+  def getHashPartitioningLikeExpressions(partitioning: Partitioning): 
Seq[Expression] = {
+    partitioning.getClass.getDeclaredMethods
+      .filter(_.getName == "expressions")
+      .flatMap(_.invoke(partitioning).asInstanceOf[Seq[Expression]])
+  }
+}
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala 
b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 1065367..571ec22 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -22,6 +22,7 @@ package org.apache.spark.sql.comet
 import java.io.{ByteArrayOutputStream, DataInputStream}
 import java.nio.channels.Channels
 
+import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.{SparkEnv, TaskContext}
@@ -30,13 +31,15 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, 
Expression, NamedExpression, SortOrder}
 import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
AggregateMode}
-import org.apache.spark.sql.catalyst.optimizer.BuildSide
-import org.apache.spark.sql.catalyst.plans.JoinType
-import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, 
BuildSide}
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, 
Partitioning, PartitioningCollection, UnknownPartitioning}
 import org.apache.spark.sql.comet.execution.shuffle.{ArrowReaderIterator, 
CometShuffleExchangeExec}
+import org.apache.spark.sql.comet.plans.PartitioningPreservingUnaryExecNode
 import org.apache.spark.sql.comet.util.Utils
 import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, 
ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, 
UnaryExecNode}
 import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, 
BroadcastQueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.internal.SQLConf
@@ -47,6 +50,7 @@ import com.google.common.base.Objects
 
 import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException}
 import org.apache.comet.serde.OperatorOuterClass.Operator
+import org.apache.comet.shims.ShimCometBroadcastHashJoinExec
 
 /**
  * A Comet physical operator
@@ -69,6 +73,10 @@ abstract class CometExec extends CometPlan {
 
   override def outputOrdering: Seq[SortOrder] = originalPlan.outputOrdering
 
+  // `CometExec` reuses the outputPartitioning of the original SparkPlan.
+  // Note that if the outputPartitioning of the original SparkPlan depends on 
its children,
+  // we should override this method in the specific CometExec, because Spark 
AQE may change the
+  // outputPartitioning of SparkPlan, e.g., AQEShuffleReadExec.
   override def outputPartitioning: Partitioning = 
originalPlan.outputPartitioning
 
   /**
@@ -377,7 +385,8 @@ case class CometProjectExec(
     override val output: Seq[Attribute],
     child: SparkPlan,
     override val serializedPlanOpt: SerializedPlan)
-    extends CometUnaryExec {
+    extends CometUnaryExec
+    with PartitioningPreservingUnaryExecNode {
   override def producedAttributes: AttributeSet = outputSet
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     this.copy(child = newChild)
@@ -396,6 +405,8 @@ case class CometProjectExec(
   }
 
   override def hashCode(): Int = Objects.hashCode(projectList, output, child)
+
+  override protected def outputExpressions: Seq[NamedExpression] = projectList
 }
 
 case class CometFilterExec(
@@ -405,6 +416,9 @@ case class CometFilterExec(
     child: SparkPlan,
     override val serializedPlanOpt: SerializedPlan)
     extends CometUnaryExec {
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     this.copy(child = newChild)
 
@@ -439,6 +453,9 @@ case class CometSortExec(
     child: SparkPlan,
     override val serializedPlanOpt: SerializedPlan)
     extends CometUnaryExec {
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     this.copy(child = newChild)
 
@@ -471,6 +488,9 @@ case class CometLocalLimitExec(
     child: SparkPlan,
     override val serializedPlanOpt: SerializedPlan)
     extends CometUnaryExec {
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     this.copy(child = newChild)
 
@@ -498,6 +518,9 @@ case class CometGlobalLimitExec(
     child: SparkPlan,
     override val serializedPlanOpt: SerializedPlan)
     extends CometUnaryExec {
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     this.copy(child = newChild)
 
@@ -586,7 +609,8 @@ case class CometHashAggregateExec(
     mode: Option[AggregateMode],
     child: SparkPlan,
     override val serializedPlanOpt: SerializedPlan)
-    extends CometUnaryExec {
+    extends CometUnaryExec
+    with PartitioningPreservingUnaryExecNode {
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     this.copy(child = newChild)
 
@@ -618,6 +642,9 @@ case class CometHashAggregateExec(
 
   override def hashCode(): Int =
     Objects.hashCode(groupingExpressions, aggregateExpressions, input, mode, 
child)
+
+  override protected def outputExpressions: Seq[NamedExpression] =
+    originalPlan.asInstanceOf[HashAggregateExec].resultExpressions
 }
 
 case class CometHashJoinExec(
@@ -632,6 +659,18 @@ case class CometHashJoinExec(
     override val right: SparkPlan,
     override val serializedPlanOpt: SerializedPlan)
     extends CometBinaryExec {
+
+  override def outputPartitioning: Partitioning = joinType match {
+    case _: InnerLike =>
+      PartitioningCollection(Seq(left.outputPartitioning, 
right.outputPartitioning))
+    case LeftOuter => left.outputPartitioning
+    case RightOuter => right.outputPartitioning
+    case FullOuter => 
UnknownPartitioning(left.outputPartitioning.numPartitions)
+    case LeftExistence(_) => left.outputPartitioning
+    case x =>
+      throw new IllegalArgumentException(s"ShuffledJoin should not take $x as 
the JoinType")
+  }
+
   override def withNewChildrenInternal(newLeft: SparkPlan, newRight: 
SparkPlan): SparkPlan =
     this.copy(left = newLeft, right = newRight)
 
@@ -668,7 +707,101 @@ case class CometBroadcastHashJoinExec(
     override val left: SparkPlan,
     override val right: SparkPlan,
     override val serializedPlanOpt: SerializedPlan)
-    extends CometBinaryExec {
+    extends CometBinaryExec
+    with ShimCometBroadcastHashJoinExec {
+
+  // The following logic of `outputPartitioning` is copied from Spark 
`BroadcastHashJoinExec`.
+  protected lazy val streamedPlan: SparkPlan = buildSide match {
+    case BuildLeft => right
+    case BuildRight => left
+  }
+
+  override lazy val outputPartitioning: Partitioning = {
+    joinType match {
+      case _: InnerLike if conf.broadcastHashJoinOutputPartitioningExpandLimit 
> 0 =>
+        streamedPlan.outputPartitioning match {
+          case h: HashPartitioning => expandOutputPartitioning(h)
+          case h: Expression if 
h.getClass.getName.contains("CoalescedHashPartitioning") =>
+            expandOutputPartitioning(h)
+          case c: PartitioningCollection => expandOutputPartitioning(c)
+          case other => other
+        }
+      case _ => streamedPlan.outputPartitioning
+    }
+  }
+
+  protected lazy val (buildKeys, streamedKeys) = {
+    require(
+      leftKeys.length == rightKeys.length &&
+        leftKeys
+          .map(_.dataType)
+          .zip(rightKeys.map(_.dataType))
+          .forall(types => types._1.sameType(types._2)),
+      "Join keys from two sides should have same length and types")
+    buildSide match {
+      case BuildLeft => (leftKeys, rightKeys)
+      case BuildRight => (rightKeys, leftKeys)
+    }
+  }
+
+  // A one-to-many mapping from a streamed key to build keys.
+  private lazy val streamedKeyToBuildKeyMapping = {
+    val mapping = mutable.Map.empty[Expression, Seq[Expression]]
+    streamedKeys.zip(buildKeys).foreach { case (streamedKey, buildKey) =>
+      val key = streamedKey.canonicalized
+      mapping.get(key) match {
+        case Some(v) => mapping.put(key, v :+ buildKey)
+        case None => mapping.put(key, Seq(buildKey))
+      }
+    }
+    mapping.toMap
+  }
+
+  // Expands the given partitioning collection recursively.
+  private def expandOutputPartitioning(
+      partitioning: PartitioningCollection): PartitioningCollection = {
+    PartitioningCollection(partitioning.partitionings.flatMap {
+      case h: HashPartitioning => expandOutputPartitioning(h).partitionings
+      case h: Expression if 
h.getClass.getName.contains("CoalescedHashPartitioning") =>
+        expandOutputPartitioning(h).partitionings
+      case c: PartitioningCollection => Seq(expandOutputPartitioning(c))
+      case other => Seq(other)
+    })
+  }
+
+  // Expands the given hash partitioning by substituting streamed keys with 
build keys.
+  // For example, if the expressions for the given partitioning are Seq("a", 
"b", "c")
+  // where the streamed keys are Seq("b", "c") and the build keys are Seq("x", 
"y"),
+  // the expanded partitioning will have the following expressions:
+  // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", 
"y").
+  // The expanded expressions are returned as PartitioningCollection.
+  private def expandOutputPartitioning(
+      partitioning: Partitioning with Expression): PartitioningCollection = {
+    val maxNumCombinations = 
conf.broadcastHashJoinOutputPartitioningExpandLimit
+    var currentNumCombinations = 0
+
+    def generateExprCombinations(
+        current: Seq[Expression],
+        accumulated: Seq[Expression]): Seq[Seq[Expression]] = {
+      if (currentNumCombinations >= maxNumCombinations) {
+        Nil
+      } else if (current.isEmpty) {
+        currentNumCombinations += 1
+        Seq(accumulated)
+      } else {
+        val buildKeysOpt = 
streamedKeyToBuildKeyMapping.get(current.head.canonicalized)
+        generateExprCombinations(current.tail, accumulated :+ current.head) ++
+          buildKeysOpt
+            .map(_.flatMap(b => generateExprCombinations(current.tail, 
accumulated :+ b)))
+            .getOrElse(Nil)
+      }
+    }
+
+    PartitioningCollection(
+      
generateExprCombinations(getHashPartitioningLikeExpressions(partitioning), Nil)
+        .map(exprs => 
partitioning.withNewChildren(exprs).asInstanceOf[Partitioning]))
+  }
+
   override def withNewChildrenInternal(newLeft: SparkPlan, newRight: 
SparkPlan): SparkPlan =
     this.copy(left = newLeft, right = newRight)
 
@@ -705,6 +838,18 @@ case class CometSortMergeJoinExec(
     override val right: SparkPlan,
     override val serializedPlanOpt: SerializedPlan)
     extends CometBinaryExec {
+
+  override def outputPartitioning: Partitioning = joinType match {
+    case _: InnerLike =>
+      PartitioningCollection(Seq(left.outputPartitioning, 
right.outputPartitioning))
+    case LeftOuter => left.outputPartitioning
+    case RightOuter => right.outputPartitioning
+    case FullOuter => 
UnknownPartitioning(left.outputPartitioning.numPartitions)
+    case LeftExistence(_) => left.outputPartitioning
+    case x =>
+      throw new IllegalArgumentException(s"ShuffledJoin should not take $x as 
the JoinType")
+  }
+
   override def withNewChildrenInternal(newLeft: SparkPlan, newRight: 
SparkPlan): SparkPlan =
     this.copy(left = newLeft, right = newRight)
 
diff --git 
a/spark/src/main/scala/org/apache/spark/sql/comet/plans/AliasAwareOutputExpression.scala
 
b/spark/src/main/scala/org/apache/spark/sql/comet/plans/AliasAwareOutputExpression.scala
new file mode 100644
index 0000000..6e5b44c
--- /dev/null
+++ 
b/spark/src/main/scala/org/apache/spark/sql/comet/plans/AliasAwareOutputExpression.scala
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.plans
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.SQLConfHelper
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, 
AttributeSet, Expression, NamedExpression}
+import org.apache.spark.sql.catalyst.trees.CurrentOrigin
+
+/**
+ * A trait that provides functionality to handle aliases in the `outputExpressions`.
+ *
+ * It builds a map from the canonicalized (and `strip`ped) child of every `Alias` in
+ * `outputExpressions` to the attributes that alias it, and `projectExpression` uses that map
+ * to rewrite an expression in terms of this node's output attributes. Implementors provide
+ * `outputExpressions` and may override `strip`.
+ */
+trait AliasAwareOutputExpression extends SQLConfHelper {
+  // `SQLConf.EXPRESSION_PROJECTION_CANDIDATE_LIMIT` is Spark 3.4+ only.
+  // Use a default value for now.
+  protected val aliasCandidateLimit = 100
+  protected def outputExpressions: Seq[NamedExpression]
+
+  /**
+   * This method can be used to strip parts of an expression which do not affect the result,
+   * for example: strip the expression which is ordering agnostic for output ordering.
+   * The default implementation returns `expr` unchanged.
+   */
+  protected def strip(expr: Expression): Expression = expr
+
+  // Build an `Expression` -> `Attribute` alias map.
+  // There can be multiple aliases defined for the same expression but it 
doesn't make sense to store
+  // more than `aliasCandidateLimit` attributes for an expression. In those 
cases the old logic
+  // handled only the last alias so we need to make sure that we give 
precedence to that.
+  // If the `outputExpressions` contain simple attributes we need to add those 
too to the map.
+  @transient
+  private lazy val aliasMap = {
+    val aliases = mutable.Map[Expression, mutable.ArrayBuffer[Attribute]]()
+    outputExpressions.reverse.foreach { // reversed so that later aliases are registered first
+      case a @ Alias(child, _) =>
+        val buffer =
+          aliases.getOrElseUpdate(strip(child).canonicalized, 
mutable.ArrayBuffer.empty)
+        if (buffer.size < aliasCandidateLimit) {
+          buffer += a.toAttribute
+        }
+      case _ =>
+    }
+    outputExpressions.foreach { // second pass: plain attributes that also have an alias entry
+      case a: Attribute if aliases.contains(a.canonicalized) =>
+        val buffer = aliases(a.canonicalized)
+        if (buffer.size < aliasCandidateLimit) {
+          buffer += a
+        }
+      case _ =>
+    }
+    aliases
+  }
+
+  protected def hasAlias: Boolean = aliasMap.nonEmpty
+
+  /**
+   * Return a stream of expressions in which the original expression is projected with
+   * `aliasMap`.
+   *
+   * A sub-expression whose canonicalized form is a key of `aliasMap` is replaced by each of
+   * the attributes recorded for it; if it has children it is additionally kept as-is so that
+   * its children can still be rewritten. An attribute that has no entry in `aliasMap` and is
+   * not part of the output set yields no alternatives, which prunes the candidate it occurs in.
+   *
+   * @param expr
+   *   the expression to rewrite
+   * @return
+   *   the rewritten alternatives as a `Stream`; empty when `expr` cannot be expressed with
+   *   the output attributes
+   */
+  protected def projectExpression(expr: Expression): Stream[Expression] = {
+    val outputSet = AttributeSet(outputExpressions.map(_.toAttribute))
+    multiTransformDown(expr) {
+      // Mapping with aliases
+      case e: Expression if aliasMap.contains(e.canonicalized) =>
+        aliasMap(e.canonicalized).toSeq ++ (if (e.containsChild.nonEmpty) 
Seq(e) else Seq.empty)
+
+      // Prune if we encounter an attribute that we can't map and it is not in 
output set.
+      // This prune will go up to the closest `multiTransformDown()` call and 
returns `Stream.empty`
+      // there.
+      case a: Attribute if !outputSet.contains(a) => Seq.empty
+    }
+  }
+
+  // Copied from Spark 3.4+ to make it available in Spark 3.2+.
+  def multiTransformDown(expr: Expression)(
+      rule: PartialFunction[Expression, Seq[Expression]]): Stream[Expression] 
= {
+
+    // We could return `Seq(this)` if the `rule` doesn't apply and handle both
+    // - the rule doesn't apply
+    // - and the rule returns a one element `Seq(originalNode)`
+    // cases together. The returned `Seq` can be a `Stream` and unfortunately 
it doesn't seem like
+    // there is a way to match on a one element stream without eagerly 
computing the tail's head.
+    // This contradicts with the purpose of only taking the necessary elements 
from the
+    // alternatives. I.e. the "multiTransformDown is lazy" test case in 
`TreeNodeSuite` would fail.
+    // Please note that this behaviour has a downside as well that we can only 
mark the rule on the
+    // original node ineffective if the rule didn't match.
+    var ruleApplied = true
+    val afterRules = CurrentOrigin.withOrigin(expr.origin) {
+      rule.applyOrElse(
+        expr,
+        (_: Expression) => {
+          ruleApplied = false
+          Seq.empty
+        })
+    }
+
+    val afterRulesStream = if (afterRules.isEmpty) {
+      if (ruleApplied) {
+        // If the rule returned with empty alternatives then prune
+        Stream.empty
+      } else {
+        // If the rule was not applied then keep the original node
+        Stream(expr)
+      }
+    } else {
+      // If the rule was applied then use the returned alternatives
+      afterRules.toStream.map { afterRule =>
+        if (expr fastEquals afterRule) {
+          expr
+        } else {
+          afterRule.copyTagsFrom(expr)
+          afterRule
+        }
+      }
+    }
+
+    afterRulesStream.flatMap { afterRule =>
+      if (afterRule.containsChild.nonEmpty) { // recurse into children and recombine alternatives
+        generateCartesianProduct(afterRule.children.map(c => () => 
multiTransformDown(c)(rule)))
+          .map(afterRule.withNewChildren)
+      } else {
+        Stream(afterRule)
+      }
+    }
+  }
+
+  def generateCartesianProduct[T](elementSeqs: Seq[() => Seq[T]]): 
Stream[Seq[T]] = {
+    elementSeqs.foldRight(Stream(Seq.empty[T]))((elements, elementTails) =>
+      for { // prepend every element of this sequence to every already-built tail
+        elementTail <- elementTails
+        element <- elements()
+      } yield element +: elementTail)
+  }
+}
diff --git 
a/spark/src/main/scala/org/apache/spark/sql/comet/plans/PartitioningPreservingUnaryExecNode.scala
 
b/spark/src/main/scala/org/apache/spark/sql/comet/plans/PartitioningPreservingUnaryExecNode.scala
new file mode 100644
index 0000000..8c6f0af
--- /dev/null
+++ 
b/spark/src/main/scala/org/apache/spark/sql/comet/plans/PartitioningPreservingUnaryExecNode.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.plans
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression}
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, 
PartitioningCollection, UnknownPartitioning}
+import org.apache.spark.sql.execution.UnaryExecNode
+
+/**
+ * A trait that handles aliases in the `outputExpressions` to produce `outputPartitioning` that
+ * satisfies distribution requirements.
+ *
+ * The child's partitioning is first flattened, so nested `PartitioningCollection`s are
+ * expanded. When the node has aliases, every partitioning that is also an `Expression` is
+ * rewritten with `projectExpression`, de-duplicated by canonicalized form and capped at
+ * `aliasCandidateLimit` alternatives; other partitionings are kept unchanged. Without aliases,
+ * expression partitionings are kept only if they reference nothing outside this node's
+ * output. The result is `UnknownPartitioning` (with the child's partition count) when nothing
+ * remains, the single partitioning when exactly one remains, and a `PartitioningCollection`
+ * otherwise.
+ *
+ * This is copied from Spark's `PartitioningPreservingUnaryExecNode` because it is only
+ * available in Spark 3.4+. This is a workaround to make it available in Spark 3.2+.
+ */
+trait PartitioningPreservingUnaryExecNode extends UnaryExecNode with 
AliasAwareOutputExpression {
+  final override def outputPartitioning: Partitioning = {
+    val partitionings: Seq[Partitioning] = if (hasAlias) {
+      flattenPartitioning(child.outputPartitioning).flatMap {
+        case e: Expression =>
+          // We need unique partitionings but if the input partitioning is
+          // `HashPartitioning(Seq(id + id))` and we have `id -> a` and `id -> 
b` aliases then after
+          // the projection we have 4 partitionings:
+          // `HashPartitioning(Seq(a + a))`, `HashPartitioning(Seq(a + b))`,
+          // `HashPartitioning(Seq(b + a))`, `HashPartitioning(Seq(b + b))`, 
but
+          // `HashPartitioning(Seq(a + b))` is the same as 
`HashPartitioning(Seq(b + a))`.
+          val partitioningSet = mutable.Set.empty[Expression]
+          projectExpression(e)
+            .filter(e => partitioningSet.add(e.canonicalized)) // `add` is false for duplicates
+            .take(aliasCandidateLimit)
+            .asInstanceOf[Stream[Partitioning]] // NOTE(review): assumes each result is still a Partitioning - confirm
+        case o => Seq(o)
+      }
+    } else {
+      // Filter valid partitionings (only reference output attributes of the 
current plan node)
+      val outputSet = AttributeSet(outputExpressions.map(_.toAttribute))
+      flattenPartitioning(child.outputPartitioning).filter {
+        case e: Expression => e.references.subsetOf(outputSet)
+        case _ => true
+      }
+    }
+    partitionings match {
+      case Seq() => UnknownPartitioning(child.outputPartitioning.numPartitions) // nothing survived
+      case Seq(p) => p
+      case ps => PartitioningCollection(ps)
+    }
+  }
+
+  private def flattenPartitioning(partitioning: Partitioning): 
Seq[Partitioning] = {
+    partitioning match {
+      case PartitioningCollection(childPartitionings) => // recurse so nested collections become one flat list
+        childPartitionings.flatMap(flattenPartitioning)
+      case rest =>
+        rest +: Nil
+    }
+  }
+}
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala 
b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index cc968a6..264ea4c 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -19,6 +19,7 @@
 
 package org.apache.comet.exec
 
+import java.sql.Date
 import java.time.{Duration, Period}
 
 import scala.collection.JavaConverters._
@@ -34,7 +35,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, 
CatalogTable}
 import org.apache.spark.sql.catalyst.expressions.Hex
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode
-import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, 
CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, 
CometProjectExec, CometRowToColumnarExec, CometScanExec, 
CometSortMergeJoinExec, CometTakeOrderedAndProjectExec}
+import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, 
CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, 
CometProjectExec, CometRowToColumnarExec, CometScanExec, CometSortExec, 
CometSortMergeJoinExec, CometTakeOrderedAndProjectExec}
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, 
CometShuffleExchangeExec}
 import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, 
SQLExecution, UnionExec}
 import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
@@ -61,6 +62,28 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
+  test("Ensure that the correct outputPartitioning of CometSort") {
+    withTable("test_data") {
+      val tableDF = spark.sparkContext
+        .parallelize(
+          (1 to 10).map { i =>
+            (if (i > 4) 5 else i, i.toString, Date.valueOf(s"${2020 + 
i}-$i-$i"))
+          },
+          3)
+        .toDF("id", "data", "day")
+      tableDF.write.saveAsTable("test_data")
+
+      val df = sql("SELECT * FROM test_data")
+        .repartition($"data")
+        .sortWithinPartitions($"id", $"data", $"day")
+      df.collect() // run the query before inspecting its executed plan
+      val sort = stripAQEPlan(df.queryExecution.executedPlan).collect { case 
s: CometSortExec =>
+        s
+      }.head // throws, failing the test, if the plan contains no CometSortExec
+      assert(sort.outputPartitioning == sort.child.outputPartitioning) // sort must mirror its child
+    }
+  }
+
   test("Repeated shuffle exchange don't fail") {
     assume(isSpark33Plus)
     Seq("true", "false").foreach { aqeEnabled =>


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to