This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new f789c21 fix: CometExec's outputPartitioning might not be same as Spark expects after AQE interferes (#299)
f789c21 is described below
commit f789c21736819c0ffa5ba56aaa2c5ec4bcb7127a
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Tue Apr 23 09:04:14 2024 -0700
fix: CometExec's outputPartitioning might not be same as Spark expects after AQE interferes (#299)
* fix: CometExec's outputPartitioning might not be same as Spark expects after AQE interferes
* Add compatibility with Spark 3.2 and 3.3
* Remove unused import
---
.../shims/ShimCometBroadcastHashJoinExec.scala | 39 +++++
.../org/apache/spark/sql/comet/operators.scala | 157 ++++++++++++++++++++-
.../comet/plans/AliasAwareOutputExpression.scala | 150 ++++++++++++++++++++
.../PartitioningPreservingUnaryExecNode.scala | 76 ++++++++++
.../org/apache/comet/exec/CometExecSuite.scala | 25 +++-
5 files changed, 440 insertions(+), 7 deletions(-)
diff --git a/spark/src/main/scala/org/apache/comet/shims/ShimCometBroadcastHashJoinExec.scala b/spark/src/main/scala/org/apache/comet/shims/ShimCometBroadcastHashJoinExec.scala
new file mode 100644
index 0000000..eef0ee9
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/shims/ShimCometBroadcastHashJoinExec.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+
+trait ShimCometBroadcastHashJoinExec {
+
+ /**
+ * Returns the expressions that are used for hash partitioning, including `HashPartitioning` and
+ * `CoalescedHashPartitioning`. They share the same trait `HashPartitioningLike` since Spark 3.4,
+ * but Spark 3.2/3.3 don't have `HashPartitioningLike` and `CoalescedHashPartitioning`.
+ *
+ * TODO: remove after dropping Spark 3.2 and 3.3 support.
+ */
+ def getHashPartitioningLikeExpressions(partitioning: Partitioning): Seq[Expression] = {
+ partitioning.getClass.getDeclaredMethods
+ .filter(_.getName == "expressions")
+ .flatMap(_.invoke(partitioning).asInstanceOf[Seq[Expression]])
+ }
+}
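For context, a minimal sketch (not part of the patch; the object name and values are illustrative only) of what the reflection-based shim above returns for a plain `HashPartitioning`:

// Illustrative only: the reflective lookup of the `expressions` method behaves like
// calling `partitioning.expressions` directly, and takes the same path for
// `CoalescedHashPartitioning` on Spark 3.4+ without referencing that class at compile time.
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.types.IntegerType

object ShimSketch extends org.apache.comet.shims.ShimCometBroadcastHashJoinExec {
  def demo(): Unit = {
    val id = AttributeReference("id", IntegerType)()
    val partitioning = HashPartitioning(Seq(id), 10)
    // Expected to contain the partitioning expression `id`.
    val exprs = getHashPartitioningLikeExpressions(partitioning)
    assert(exprs.contains(id))
  }
}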
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 1065367..571ec22 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -22,6 +22,7 @@ package org.apache.spark.sql.comet
import java.io.{ByteArrayOutputStream, DataInputStream}
import java.nio.channels.Channels
+import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.{SparkEnv, TaskContext}
@@ -30,13 +31,15 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode}
-import org.apache.spark.sql.catalyst.optimizer.BuildSide
-import org.apache.spark.sql.catalyst.plans.JoinType
-import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, PartitioningCollection, UnknownPartitioning}
import org.apache.spark.sql.comet.execution.shuffle.{ArrowReaderIterator, CometShuffleExchangeExec}
+import org.apache.spark.sql.comet.plans.PartitioningPreservingUnaryExecNode
import org.apache.spark.sql.comet.util.Utils
import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.internal.SQLConf
@@ -47,6 +50,7 @@ import com.google.common.base.Objects
import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException}
import org.apache.comet.serde.OperatorOuterClass.Operator
+import org.apache.comet.shims.ShimCometBroadcastHashJoinExec
/**
* A Comet physical operator
@@ -69,6 +73,10 @@ abstract class CometExec extends CometPlan {
override def outputOrdering: Seq[SortOrder] = originalPlan.outputOrdering
+ // `CometExec` reuses the outputPartitioning of the original SparkPlan.
+ // Note that if the outputPartitioning of the original SparkPlan depends on its children,
+ // we should override this method in the specific CometExec, because Spark AQE may change the
+ // outputPartitioning of SparkPlan, e.g., AQEShuffleReadExec.
override def outputPartitioning: Partitioning = originalPlan.outputPartitioning
/**
@@ -377,7 +385,8 @@ case class CometProjectExec(
override val output: Seq[Attribute],
child: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
- extends CometUnaryExec {
+ extends CometUnaryExec
+ with PartitioningPreservingUnaryExecNode {
override def producedAttributes: AttributeSet = outputSet
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)
@@ -396,6 +405,8 @@ case class CometProjectExec(
}
override def hashCode(): Int = Objects.hashCode(projectList, output, child)
+
+ override protected def outputExpressions: Seq[NamedExpression] = projectList
}
case class CometFilterExec(
@@ -405,6 +416,9 @@ case class CometFilterExec(
child: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometUnaryExec {
+
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)
@@ -439,6 +453,9 @@ case class CometSortExec(
child: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometUnaryExec {
+
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)
@@ -471,6 +488,9 @@ case class CometLocalLimitExec(
child: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometUnaryExec {
+
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)
@@ -498,6 +518,9 @@ case class CometGlobalLimitExec(
child: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometUnaryExec {
+
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)
@@ -586,7 +609,8 @@ case class CometHashAggregateExec(
mode: Option[AggregateMode],
child: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
- extends CometUnaryExec {
+ extends CometUnaryExec
+ with PartitioningPreservingUnaryExecNode {
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)
@@ -618,6 +642,9 @@ case class CometHashAggregateExec(
override def hashCode(): Int =
Objects.hashCode(groupingExpressions, aggregateExpressions, input, mode, child)
+
+ override protected def outputExpressions: Seq[NamedExpression] =
+ originalPlan.asInstanceOf[HashAggregateExec].resultExpressions
}
case class CometHashJoinExec(
@@ -632,6 +659,18 @@ case class CometHashJoinExec(
override val right: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometBinaryExec {
+
+ override def outputPartitioning: Partitioning = joinType match {
+ case _: InnerLike =>
+ PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
+ case LeftOuter => left.outputPartitioning
+ case RightOuter => right.outputPartitioning
+ case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
+ case LeftExistence(_) => left.outputPartitioning
+ case x =>
+ throw new IllegalArgumentException(s"ShuffledJoin should not take $x as the JoinType")
+ }
+
override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan =
this.copy(left = newLeft, right = newRight)
@@ -668,7 +707,101 @@ case class CometBroadcastHashJoinExec(
override val left: SparkPlan,
override val right: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
- extends CometBinaryExec {
+ extends CometBinaryExec
+ with ShimCometBroadcastHashJoinExec {
+
+ // The following logic of `outputPartitioning` is copied from Spark `BroadcastHashJoinExec`.
+ protected lazy val streamedPlan: SparkPlan = buildSide match {
+ case BuildLeft => right
+ case BuildRight => left
+ }
+
+ override lazy val outputPartitioning: Partitioning = {
+ joinType match {
+ case _: InnerLike if conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 =>
+ streamedPlan.outputPartitioning match {
+ case h: HashPartitioning => expandOutputPartitioning(h)
+ case h: Expression if h.getClass.getName.contains("CoalescedHashPartitioning") =>
+ expandOutputPartitioning(h)
+ case c: PartitioningCollection => expandOutputPartitioning(c)
+ case other => other
+ }
+ case _ => streamedPlan.outputPartitioning
+ }
+ }
+
+ protected lazy val (buildKeys, streamedKeys) = {
+ require(
+ leftKeys.length == rightKeys.length &&
+ leftKeys
+ .map(_.dataType)
+ .zip(rightKeys.map(_.dataType))
+ .forall(types => types._1.sameType(types._2)),
+ "Join keys from two sides should have same length and types")
+ buildSide match {
+ case BuildLeft => (leftKeys, rightKeys)
+ case BuildRight => (rightKeys, leftKeys)
+ }
+ }
+
+ // A one-to-many mapping from a streamed key to build keys.
+ private lazy val streamedKeyToBuildKeyMapping = {
+ val mapping = mutable.Map.empty[Expression, Seq[Expression]]
+ streamedKeys.zip(buildKeys).foreach { case (streamedKey, buildKey) =>
+ val key = streamedKey.canonicalized
+ mapping.get(key) match {
+ case Some(v) => mapping.put(key, v :+ buildKey)
+ case None => mapping.put(key, Seq(buildKey))
+ }
+ }
+ mapping.toMap
+ }
+
+ // Expands the given partitioning collection recursively.
+ private def expandOutputPartitioning(
+ partitioning: PartitioningCollection): PartitioningCollection = {
+ PartitioningCollection(partitioning.partitionings.flatMap {
+ case h: HashPartitioning => expandOutputPartitioning(h).partitionings
+ case h: Expression if h.getClass.getName.contains("CoalescedHashPartitioning") =>
+ expandOutputPartitioning(h).partitionings
+ case c: PartitioningCollection => Seq(expandOutputPartitioning(c))
+ case other => Seq(other)
+ })
+ }
+
+ // Expands the given hash partitioning by substituting streamed keys with build keys.
+ // For example, if the expressions for the given partitioning are Seq("a", "b", "c")
+ // where the streamed keys are Seq("b", "c") and the build keys are Seq("x", "y"),
+ // the expanded partitioning will have the following expressions:
+ // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y").
+ // The expanded expressions are returned as PartitioningCollection.
+ private def expandOutputPartitioning(
+ partitioning: Partitioning with Expression): PartitioningCollection = {
+ val maxNumCombinations = conf.broadcastHashJoinOutputPartitioningExpandLimit
+ var currentNumCombinations = 0
+
+ def generateExprCombinations(
+ current: Seq[Expression],
+ accumulated: Seq[Expression]): Seq[Seq[Expression]] = {
+ if (currentNumCombinations >= maxNumCombinations) {
+ Nil
+ } else if (current.isEmpty) {
+ currentNumCombinations += 1
+ Seq(accumulated)
+ } else {
+ val buildKeysOpt = streamedKeyToBuildKeyMapping.get(current.head.canonicalized)
+ generateExprCombinations(current.tail, accumulated :+ current.head) ++
+ buildKeysOpt
+ .map(_.flatMap(b => generateExprCombinations(current.tail, accumulated :+ b)))
+ .getOrElse(Nil)
+ }
+ }
+
+ PartitioningCollection(
+ generateExprCombinations(getHashPartitioningLikeExpressions(partitioning), Nil)
+ .map(exprs => partitioning.withNewChildren(exprs).asInstanceOf[Partitioning]))
+ }
+
override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan =
this.copy(left = newLeft, right = newRight)
@@ -705,6 +838,18 @@ case class CometSortMergeJoinExec(
override val right: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometBinaryExec {
+
+ override def outputPartitioning: Partitioning = joinType match {
+ case _: InnerLike =>
+ PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
+ case LeftOuter => left.outputPartitioning
+ case RightOuter => right.outputPartitioning
+ case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
+ case LeftExistence(_) => left.outputPartitioning
+ case x =>
+ throw new IllegalArgumentException(s"ShuffledJoin should not take $x as the JoinType")
+ }
+
override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan =
this.copy(left = newLeft, right = newRight)
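To make the key-substitution expansion in `CometBroadcastHashJoinExec.expandOutputPartitioning` above concrete, here is a small standalone sketch (illustrative only: plain strings instead of Catalyst expressions, and the expand limit is omitted):

// Illustrative sketch (not part of the patch): mirrors the combination generation
// documented above. With partitioning expressions Seq("a", "b", "c"), streamed keys
// Seq("b", "c") and build keys Seq("x", "y"), it yields:
// Seq(a, b, c), Seq(a, b, y), Seq(a, x, c), Seq(a, x, y).
object ExpandPartitioningSketch {
  def expand(
      exprs: Seq[String],
      streamedToBuild: Map[String, Seq[String]]): Seq[Seq[String]] = {
    def combinations(current: Seq[String], acc: Seq[String]): Seq[Seq[String]] = {
      if (current.isEmpty) {
        Seq(acc)
      } else {
        // Keep the streamed key itself, plus every build key it maps to.
        combinations(current.tail, acc :+ current.head) ++
          streamedToBuild
            .getOrElse(current.head, Nil)
            .flatMap(b => combinations(current.tail, acc :+ b))
      }
    }
    combinations(exprs, Nil)
  }

  def main(args: Array[String]): Unit = {
    val result = expand(Seq("a", "b", "c"), Map("b" -> Seq("x"), "c" -> Seq("y")))
    result.foreach(r => println(r.mkString("Seq(", ", ", ")")))
  }
}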
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/plans/AliasAwareOutputExpression.scala b/spark/src/main/scala/org/apache/spark/sql/comet/plans/AliasAwareOutputExpression.scala
new file mode 100644
index 0000000..6e5b44c
--- /dev/null
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/plans/AliasAwareOutputExpression.scala
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.plans
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.SQLConfHelper
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Expression, NamedExpression}
+import org.apache.spark.sql.catalyst.trees.CurrentOrigin
+
+/**
+ * A trait that provides functionality to handle aliases in the `outputExpressions`.
+ */
+trait AliasAwareOutputExpression extends SQLConfHelper {
+ // `SQLConf.EXPRESSION_PROJECTION_CANDIDATE_LIMIT` is Spark 3.4+ only.
+ // Use a default value for now.
+ protected val aliasCandidateLimit = 100
+ protected def outputExpressions: Seq[NamedExpression]
+
+ /**
+ * This method can be used to strip expressions which do not affect the result, for example:
+ * strip expressions which are ordering agnostic for output ordering.
+ */
+ protected def strip(expr: Expression): Expression = expr
+
+ // Build an `Expression` -> `Attribute` alias map.
+ // There can be multiple aliases defined for the same expression but it doesn't make sense to store
+ // more than `aliasCandidateLimit` attributes for an expression. In those cases the old logic
+ // handled only the last alias, so we need to make sure that we give precedence to that.
+ // If the `outputExpressions` contain simple attributes we need to add those too to the map.
+ @transient
+ private lazy val aliasMap = {
+ val aliases = mutable.Map[Expression, mutable.ArrayBuffer[Attribute]]()
+ outputExpressions.reverse.foreach {
+ case a @ Alias(child, _) =>
+ val buffer =
+ aliases.getOrElseUpdate(strip(child).canonicalized, mutable.ArrayBuffer.empty)
+ if (buffer.size < aliasCandidateLimit) {
+ buffer += a.toAttribute
+ }
+ case _ =>
+ }
+ outputExpressions.foreach {
+ case a: Attribute if aliases.contains(a.canonicalized) =>
+ val buffer = aliases(a.canonicalized)
+ if (buffer.size < aliasCandidateLimit) {
+ buffer += a
+ }
+ case _ =>
+ }
+ aliases
+ }
+
+ protected def hasAlias: Boolean = aliasMap.nonEmpty
+
+ /**
+ * Return a stream of expressions in which the original expression is projected with `aliasMap`.
+ */
+ protected def projectExpression(expr: Expression): Stream[Expression] = {
+ val outputSet = AttributeSet(outputExpressions.map(_.toAttribute))
+ multiTransformDown(expr) {
+ // Mapping with aliases
+ case e: Expression if aliasMap.contains(e.canonicalized) =>
+ aliasMap(e.canonicalized).toSeq ++ (if (e.containsChild.nonEmpty) Seq(e) else Seq.empty)
+
+ // Prune if we encounter an attribute that we can't map and it is not in the output set.
+ // This prune will go up to the closest `multiTransformDown()` call and return `Stream.empty`
+ // there.
+ case a: Attribute if !outputSet.contains(a) => Seq.empty
+ }
+ }
+
+ // Copied from Spark 3.4+ to make it available in Spark 3.2+.
+ def multiTransformDown(expr: Expression)(
+ rule: PartialFunction[Expression, Seq[Expression]]): Stream[Expression] = {
+
+ // We could return `Seq(this)` if the `rule` doesn't apply and handle both
+ // - the rule doesn't apply
+ // - and the rule returns a one element `Seq(originalNode)`
+ // cases together. The returned `Seq` can be a `Stream` and unfortunately it doesn't seem like
+ // there is a way to match on a one element stream without eagerly computing the tail's head.
+ // This contradicts the purpose of only taking the necessary elements from the
+ // alternatives. I.e. the "multiTransformDown is lazy" test case in `TreeNodeSuite` would fail.
+ // Please note that this behaviour has a downside as well: we can only mark the rule on the
+ // original node ineffective if the rule didn't match.
+ var ruleApplied = true
+ val afterRules = CurrentOrigin.withOrigin(expr.origin) {
+ rule.applyOrElse(
+ expr,
+ (_: Expression) => {
+ ruleApplied = false
+ Seq.empty
+ })
+ }
+
+ val afterRulesStream = if (afterRules.isEmpty) {
+ if (ruleApplied) {
+ // If the rule returned with empty alternatives then prune
+ Stream.empty
+ } else {
+ // If the rule was not applied then keep the original node
+ Stream(expr)
+ }
+ } else {
+ // If the rule was applied then use the returned alternatives
+ afterRules.toStream.map { afterRule =>
+ if (expr fastEquals afterRule) {
+ expr
+ } else {
+ afterRule.copyTagsFrom(expr)
+ afterRule
+ }
+ }
+ }
+
+ afterRulesStream.flatMap { afterRule =>
+ if (afterRule.containsChild.nonEmpty) {
+ generateCartesianProduct(afterRule.children.map(c => () => multiTransformDown(c)(rule)))
+ .map(afterRule.withNewChildren)
+ } else {
+ Stream(afterRule)
+ }
+ }
+ }
+
+ def generateCartesianProduct[T](elementSeqs: Seq[() => Seq[T]]): Stream[Seq[T]] = {
+ elementSeqs.foldRight(Stream(Seq.empty[T]))((elements, elementTails) =>
+ for {
+ elementTail <- elementTails
+ element <- elements()
+ } yield element +: elementTail)
+ }
+}
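As a side note, a minimal standalone illustration (assumption: plain strings rather than Catalyst expressions) of the `foldRight`-based cartesian product that `generateCartesianProduct` uses to enumerate expression alternatives lazily:

// Illustrative only: the same foldRight-based cartesian product as above, shown on
// plain strings. Because the result is a Stream, alternatives are only computed as
// they are consumed, which is what keeps multiTransformDown lazy.
object CartesianSketch {
  def cartesian[T](elementSeqs: Seq[() => Seq[T]]): Stream[Seq[T]] =
    elementSeqs.foldRight(Stream(Seq.empty[T]))((elements, tails) =>
      for {
        tail <- tails
        element <- elements()
      } yield element +: tail)

  def main(args: Array[String]): Unit = {
    val combos = cartesian(Seq(() => Seq("a", "x"), () => Seq("b", "y")))
    // Prints: List(a, b), List(x, b), List(a, y), List(x, y)
    combos.take(4).foreach(println)
  }
}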
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/plans/PartitioningPreservingUnaryExecNode.scala b/spark/src/main/scala/org/apache/spark/sql/comet/plans/PartitioningPreservingUnaryExecNode.scala
new file mode 100644
index 0000000..8c6f0af
--- /dev/null
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/plans/PartitioningPreservingUnaryExecNode.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.plans
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression}
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection, UnknownPartitioning}
+import org.apache.spark.sql.execution.UnaryExecNode
+
+/**
+ * A trait that handles aliases in the `outputExpressions` to produce `outputPartitioning` that
+ * satisfies distribution requirements.
+ *
+ * This is copied from Spark's `PartitioningPreservingUnaryExecNode` because it is only available
+ * in Spark 3.4+. This is a workaround to make it available in Spark 3.2+.
+ */
+trait PartitioningPreservingUnaryExecNode extends UnaryExecNode with AliasAwareOutputExpression {
+ final override def outputPartitioning: Partitioning = {
+ val partitionings: Seq[Partitioning] = if (hasAlias) {
+ flattenPartitioning(child.outputPartitioning).flatMap {
+ case e: Expression =>
+ // We need unique partitionings but if the input partitioning is
+ // `HashPartitioning(Seq(id + id))` and we have `id -> a` and `id -> b` aliases then after
+ // the projection we have 4 partitionings:
+ // `HashPartitioning(Seq(a + a))`, `HashPartitioning(Seq(a + b))`,
+ // `HashPartitioning(Seq(b + a))`, `HashPartitioning(Seq(b + b))`, but
+ // `HashPartitioning(Seq(a + b))` is the same as `HashPartitioning(Seq(b + a))`.
+ val partitioningSet = mutable.Set.empty[Expression]
+ projectExpression(e)
+ .filter(e => partitioningSet.add(e.canonicalized))
+ .take(aliasCandidateLimit)
+ .asInstanceOf[Stream[Partitioning]]
+ case o => Seq(o)
+ }
+ } else {
+ // Filter valid partitionings (only reference output attributes of the current plan node)
+ val outputSet = AttributeSet(outputExpressions.map(_.toAttribute))
+ flattenPartitioning(child.outputPartitioning).filter {
+ case e: Expression => e.references.subsetOf(outputSet)
+ case _ => true
+ }
+ }
+ partitionings match {
+ case Seq() => UnknownPartitioning(child.outputPartitioning.numPartitions)
+ case Seq(p) => p
+ case ps => PartitioningCollection(ps)
+ }
+ }
+
+ private def flattenPartitioning(partitioning: Partitioning): Seq[Partitioning] = {
+ partitioning match {
+ case PartitioningCollection(childPartitionings) =>
+ childPartitionings.flatMap(flattenPartitioning)
+ case rest =>
+ rest +: Nil
+ }
+ }
+}
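For intuition, a hedged end-to-end sketch (plain Spark; the object and app names are illustrative) of the behaviour this trait preserves for Comet operators: a projection that merely aliases the partitioning column should keep the hash partitioning, so a downstream aggregation does not need an extra shuffle:

// Illustrative only: after repartitioning by `id` and aliasing it to `a`, the project's
// outputPartitioning should still be a hash partitioning on `a`, so the groupBy on `a`
// should reuse the existing exchange instead of adding another one.
import org.apache.spark.sql.SparkSession

object AliasPartitioningSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("alias-partitioning").getOrCreate()
    import spark.implicits._

    val df = (1 to 100).toDF("id")
      .repartition($"id")     // hash partitioned on `id`
      .select($"id".as("a"))  // alias-aware propagation: still hash partitioned, now on `a`
      .groupBy($"a")
      .count()

    // With alias-aware outputPartitioning on the project, the plan should contain only the
    // exchange introduced by repartition(), not a second one for the aggregation.
    df.explain()
    spark.stop()
  }
}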
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index cc968a6..264ea4c 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -19,6 +19,7 @@
package org.apache.comet.exec
+import java.sql.Date
import java.time.{Duration, Period}
import scala.collection.JavaConverters._
@@ -34,7 +35,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable}
import org.apache.spark.sql.catalyst.expressions.Hex
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode
-import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometRowToColumnarExec, CometScanExec, CometSortMergeJoinExec, CometTakeOrderedAndProjectExec}
+import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometRowToColumnarExec, CometScanExec, CometSortExec, CometSortMergeJoinExec, CometTakeOrderedAndProjectExec}
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec}
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
@@ -61,6 +62,28 @@ class CometExecSuite extends CometTestBase {
}
}
+ test("Ensure that the correct outputPartitioning of CometSort") {
+ withTable("test_data") {
+ val tableDF = spark.sparkContext
+ .parallelize(
+ (1 to 10).map { i =>
+ (if (i > 4) 5 else i, i.toString, Date.valueOf(s"${2020 + i}-$i-$i"))
+ },
+ 3)
+ .toDF("id", "data", "day")
+ tableDF.write.saveAsTable("test_data")
+
+ val df = sql("SELECT * FROM test_data")
+ .repartition($"data")
+ .sortWithinPartitions($"id", $"data", $"day")
+ df.collect()
+ val sort = stripAQEPlan(df.queryExecution.executedPlan).collect { case s: CometSortExec =>
+ s
+ }.head
+ assert(sort.outputPartitioning == sort.child.outputPartitioning)
+ }
+ }
+
test("Repeated shuffle exchange don't fail") {
assume(isSpark33Plus)
Seq("true", "false").foreach { aqeEnabled =>
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]