This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4a6b2b9  [SPARK-33832][SQL] Support optimize skewed join even if 
introduce extra shuffle
4a6b2b9 is described below

commit 4a6b2b9fc8b68d59857c5ee71e817b9b06db5ba8
Author: ulysses-you <ulyssesyo...@gmail.com>
AuthorDate: Mon Sep 13 17:21:27 2021 +0800

    [SPARK-33832][SQL] Support optimize skewed join even if introduce extra 
shuffle
    
    ### What changes were proposed in this pull request?
    
    - move the rule `OptimizeSkewedJoin` from stage optimization phase to stage 
preparation phase.
    - run the rule `EnsureRequirements` one more time after the 
`OptimizeSkewedJoin` rule in the stage preparation phase.
    - add `SkewJoinAwareCost` to support estimating the cost of skewed joins
    - add a new config to decide whether to force-optimize skewed joins
    - in `OptimizeSkewedJoin`, we generate 2 physical plans, one with skew join 
optimization and one without. Then we use the cost evaluator w.r.t. the 
force-skew-join flag and pick the plan with lower cost.
    
    ### Why are the changes needed?
    
    In general, a skewed join hurts performance more than one extra shuffle
    does. It makes sense to force the skewed join optimization even if it
    introduces an extra shuffle.
    
    A common case:
    ```
    HashAggregate
      SortMergeJoin
        Sort
          Exchange
        Sort
          Exchange
    ```
    and after this PR, the plan looks like:
    ```
    HashAggregate
      Exchange
        SortMergeJoin (isSkew=true)
          Sort
            Exchange
          Sort
            Exchange
    ```
    
    Note that the newly introduced shuffle can also be optimized by AQE.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, a new config.
    
    ### How was this patch tested?
    
    * Add new test
    * pass existing test `SPARK-30524: Do not optimize skew join if introduce 
additional shuffle`
    * pass existing test `SPARK-33551: Do not use custom shuffle reader for 
repartition`
    
    Closes #32816 from ulysses-you/support-extra-shuffle.
    
    Authored-by: ulysses-you <ulyssesyo...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../org/apache/spark/sql/internal/SQLConf.scala    |  7 ++
 .../execution/adaptive/AdaptiveSparkPlanExec.scala | 31 +++----
 .../execution/adaptive/OptimizeSkewedJoin.scala    | 25 ++++--
 .../sql/execution/adaptive/simpleCosting.scala     | 48 +++++++++--
 .../execution/exchange/EnsureRequirements.scala    | 96 ++++++++++++++--------
 .../adaptive/AdaptiveQueryExecSuite.scala          | 68 +++++++++++++++
 6 files changed, 217 insertions(+), 58 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 8ba2b9f..9f71ecb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -666,6 +666,13 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)
 
+  val ADAPTIVE_FORCE_OPTIMIZE_SKEWED_JOIN =
+    buildConf("spark.sql.adaptive.forceOptimizeSkewedJoin")
+      .doc("When true, force enable OptimizeSkewedJoin even if it introduces 
extra shuffle.")
+      .version("3.3.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS =
     buildConf("spark.sql.adaptive.customCostEvaluatorClass")
       .doc("The custom cost evaluator class to be used for adaptive execution. 
If not being set," +
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index bf810f3..13c9528 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -97,27 +97,36 @@ case class AdaptiveSparkPlanExec(
     AQEUtils.getRequiredDistribution(inputPlan)
   }
 
+  @transient private val costEvaluator =
+    conf.getConf(SQLConf.ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS) match {
+      case Some(className) => CostEvaluator.instantiate(className, 
session.sparkContext.getConf)
+      case _ => 
SimpleCostEvaluator(conf.getConf(SQLConf.ADAPTIVE_FORCE_OPTIMIZE_SKEWED_JOIN))
+    }
+
   // A list of physical plan rules to be applied before creation of query 
stages. The physical
   // plan should reach a final status of query stages (i.e., no more addition 
or removal of
   // Exchange nodes) after running these rules.
-  @transient private val queryStagePreparationRules: Seq[Rule[SparkPlan]] = 
Seq(
-    RemoveRedundantProjects,
+  @transient private val queryStagePreparationRules: Seq[Rule[SparkPlan]] = {
     // For cases like `df.repartition(a, b).select(c)`, there is no 
distribution requirement for
     // the final plan, but we do need to respect the user-specified 
repartition. Here we ask
     // `EnsureRequirements` to not optimize out the user-specified 
repartition-by-col to work
     // around this case.
-    EnsureRequirements(optimizeOutRepartition = 
requiredDistribution.isDefined),
-    RemoveRedundantSorts,
-    DisableUnnecessaryBucketedScan
-  ) ++ context.session.sessionState.queryStagePrepRules
+    val ensureRequirements =
+      EnsureRequirements(requiredDistribution.isDefined, requiredDistribution)
+    Seq(
+      RemoveRedundantProjects,
+      ensureRequirements,
+      RemoveRedundantSorts,
+      DisableUnnecessaryBucketedScan,
+      OptimizeSkewedJoin(ensureRequirements, costEvaluator)
+    ) ++ context.session.sessionState.queryStagePrepRules
+  }
 
   // A list of physical optimizer rules to be applied to a new stage before 
its execution. These
   // optimizations should be stage-independent.
   @transient private val queryStageOptimizerRules: Seq[Rule[SparkPlan]] = Seq(
     PlanAdaptiveDynamicPruningFilters(this),
     ReuseAdaptiveSubquery(context.subqueryCache),
-    // Skew join does not handle `AQEShuffleRead` so needs to be applied first.
-    OptimizeSkewedJoin,
     OptimizeSkewInRebalancePartitions,
     CoalesceShufflePartitions(context.session),
     // `OptimizeShuffleWithLocalRead` needs to make use of 
'AQEShuffleReadExec.partitionSpecs'
@@ -169,12 +178,6 @@ case class AdaptiveSparkPlanExec(
     optimized
   }
 
-  @transient private val costEvaluator =
-    conf.getConf(SQLConf.ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS) match {
-      case Some(className) => CostEvaluator.instantiate(className, 
session.sparkContext.getConf)
-      case _ => SimpleCostEvaluator
-    }
-
   @transient val initialPlan = context.session.withActive {
     applyPhysicalRules(
       inputPlan, queryStagePreparationRules, Some((planChangeLogger, "AQE 
Preparations")))
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
index 88abe68..2fe5b18 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
@@ -22,8 +22,9 @@ import scala.collection.mutable
 import org.apache.commons.io.FileUtils
 
 import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, 
ShuffleOrigin}
+import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, 
EnsureRequirements}
 import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, 
SortMergeJoinExec}
 import org.apache.spark.sql.internal.SQLConf
 
@@ -48,9 +49,10 @@ import org.apache.spark.sql.internal.SQLConf
  * (L3, R3-1), (L3, R3-2),
  * (L4-1, R4-1), (L4-2, R4-1), (L4-1, R4-2), (L4-2, R4-2)
  */
-object OptimizeSkewedJoin extends AQEShuffleReadRule {
-
-  override val supportedShuffleOrigins: Seq[ShuffleOrigin] = 
Seq(ENSURE_REQUIREMENTS)
+case class OptimizeSkewedJoin(
+    ensureRequirements: EnsureRequirements,
+    costEvaluator: CostEvaluator)
+  extends Rule[SparkPlan] {
 
   /**
    * A partition is considered as a skewed partition if its size is larger 
than the median
@@ -250,7 +252,17 @@ object OptimizeSkewedJoin extends AQEShuffleReadRule {
       // SHJ
       //   Shuffle
       //   Shuffle
-      optimizeSkewJoin(plan)
+      val optimized = ensureRequirements.apply(optimizeSkewJoin(plan))
+      val originCost = costEvaluator.evaluateCost(plan)
+      val optimizedCost = costEvaluator.evaluateCost(optimized)
+      // two cases we will pick new plan:
+      //   1. optimize the skew join without extra shuffle
+      //   2. optimize the skew join with extra shuffle but the costEvaluator 
think it's better
+      if (optimizedCost <= originCost) {
+        optimized
+      } else {
+        plan
+      }
     } else {
       plan
     }
@@ -258,7 +270,8 @@ object OptimizeSkewedJoin extends AQEShuffleReadRule {
 
   object ShuffleStage {
     def unapply(plan: SparkPlan): Option[ShuffleQueryStageExec] = plan match {
-      case s: ShuffleQueryStageExec if s.mapStats.isDefined && 
isSupported(s.shuffle) =>
+      case s: ShuffleQueryStageExec if s.isMaterialized && 
s.mapStats.isDefined &&
+        s.shuffle.shuffleOrigin == ENSURE_REQUIREMENTS =>
         Some(s)
       case _ => None
     }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala
index 7f02683..864563b 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.adaptive
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
+import org.apache.spark.sql.execution.joins.ShuffledJoin
 
 /**
  * A simple implementation of [[Cost]], which takes a number of [[Long]] as 
the cost value.
@@ -35,15 +36,52 @@ case class SimpleCost(value: Long) extends Cost {
 }
 
 /**
- * A simple implementation of [[CostEvaluator]], which counts the number of
- * [[ShuffleExchangeLike]] nodes in the plan.
+ * A skew join aware implementation of [[Cost]], which consider shuffle number 
and skew join number.
+ *
+ * We always pick the cost which has more skew join even if it introduces one 
or more extra shuffle.
+ * Otherwise, if two costs have the same number of skew join or no skew join, 
we will pick the one
+ * with small number of shuffle.
  */
-object SimpleCostEvaluator extends CostEvaluator {
+case class SkewJoinAwareCost(
+    numShuffles: Int,
+    numSkewJoins: Int) extends Cost {
+  override def compare(that: Cost): Int = that match {
+    case other: SkewJoinAwareCost =>
+      // If more skew joins are optimized or less shuffle nodes, it means the 
cost is lower
+      if (numSkewJoins > other.numSkewJoins) {
+        -1
+      } else if (numSkewJoins < other.numSkewJoins) {
+        1
+      } else if (numShuffles < other.numShuffles) {
+        -1
+      } else if (numShuffles > other.numShuffles) {
+        1
+      } else {
+        0
+      }
+
+    case _ =>
+      throw 
QueryExecutionErrors.cannotCompareCostWithTargetCostError(that.toString)
+  }
+}
 
+/**
+ * A skew join aware implementation of [[CostEvaluator]], which counts the 
number of
+ * [[ShuffleExchangeLike]] nodes and skew join nodes in the plan.
+ */
+case class SimpleCostEvaluator(forceOptimizeSkewedJoin: Boolean) extends 
CostEvaluator {
   override def evaluateCost(plan: SparkPlan): Cost = {
-    val cost = plan.collect {
+    val numShuffles = plan.collect {
       case s: ShuffleExchangeLike => s
     }.size
-    SimpleCost(cost)
+
+    if (forceOptimizeSkewedJoin) {
+      val numSkewJoins = plan.collect {
+        case j: ShuffledJoin if j.isSkewJoin => j
+      }.size
+      SkewJoinAwareCost(numShuffles, numSkewJoins)
+    } else {
+      SimpleCost(numShuffles)
+    }
   }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 23716f1..86b2344 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -38,18 +38,23 @@ import 
org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoin
  *                               but can be false in AQE when AQE optimization 
may change the plan
  *                               output partitioning and need to retain the 
user-specified
  *                               repartition shuffles in the plan.
+ * @param requiredDistribution The root required distribution we should 
ensure. This value is used
+ *                             in AQE in case we change final stage output 
partitioning.
  */
-case class EnsureRequirements(optimizeOutRepartition: Boolean = true) extends 
Rule[SparkPlan] {
-
-  private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
-    val requiredChildDistributions: Seq[Distribution] = 
operator.requiredChildDistribution
-    val requiredChildOrderings: Seq[Seq[SortOrder]] = 
operator.requiredChildOrdering
-    var children: Seq[SparkPlan] = operator.children
-    assert(requiredChildDistributions.length == children.length)
-    assert(requiredChildOrderings.length == children.length)
+case class EnsureRequirements(
+    optimizeOutRepartition: Boolean = true,
+    requiredDistribution: Option[Distribution] = None)
+  extends Rule[SparkPlan] {
 
+  private def ensureDistributionAndOrdering(
+      originalChildren: Seq[SparkPlan],
+      requiredChildDistributions: Seq[Distribution],
+      requiredChildOrderings: Seq[Seq[SortOrder]],
+      shuffleOrigin: ShuffleOrigin): Seq[SparkPlan] = {
+    assert(requiredChildDistributions.length == originalChildren.length)
+    assert(requiredChildOrderings.length == originalChildren.length)
     // Ensure that the operator's children satisfy their output distribution 
requirements.
-    children = children.zip(requiredChildDistributions).map {
+    var newChildren = originalChildren.zip(requiredChildDistributions).map {
       case (child, distribution) if 
child.outputPartitioning.satisfies(distribution) =>
         child
       case (child, BroadcastDistribution(mode)) =>
@@ -57,7 +62,7 @@ case class EnsureRequirements(optimizeOutRepartition: Boolean 
= true) extends Ru
       case (child, distribution) =>
         val numPartitions = distribution.requiredNumPartitions
           .getOrElse(conf.numShufflePartitions)
-        ShuffleExchangeExec(distribution.createPartitioning(numPartitions), 
child)
+        ShuffleExchangeExec(distribution.createPartitioning(numPartitions), 
child, shuffleOrigin)
     }
 
     // Get the indexes of children which have specified distribution 
requirements and need to have
@@ -69,7 +74,7 @@ case class EnsureRequirements(optimizeOutRepartition: Boolean 
= true) extends Ru
     }.map(_._2)
 
     val childrenNumPartitions =
-      childrenIndexes.map(children(_).outputPartitioning.numPartitions).toSet
+      
childrenIndexes.map(newChildren(_).outputPartitioning.numPartitions).toSet
 
     if (childrenNumPartitions.size > 1) {
       // Get the number of partitions which is explicitly required by the 
distributions.
@@ -78,7 +83,7 @@ case class EnsureRequirements(optimizeOutRepartition: Boolean 
= true) extends Ru
           index => requiredChildDistributions(index).requiredNumPartitions
         }.toSet
         assert(numPartitionsSet.size <= 1,
-          s"$operator have incompatible requirements of the number of 
partitions for its children")
+          s"$requiredChildDistributions have incompatible requirements of the 
number of partitions")
         numPartitionsSet.headOption
       }
 
@@ -87,7 +92,7 @@ case class EnsureRequirements(optimizeOutRepartition: Boolean 
= true) extends Ru
       // 1. We should avoid shuffling these children.
       // 2. We should have a reasonable parallelism.
       val nonShuffleChildrenNumPartitions =
-        
childrenIndexes.map(children).filterNot(_.isInstanceOf[ShuffleExchangeExec])
+        
childrenIndexes.map(newChildren).filterNot(_.isInstanceOf[ShuffleExchangeExec])
           .map(_.outputPartitioning.numPartitions)
       val expectedChildrenNumPartitions = if 
(nonShuffleChildrenNumPartitions.nonEmpty) {
         if (nonShuffleChildrenNumPartitions.length == childrenIndexes.length) {
@@ -106,7 +111,7 @@ case class EnsureRequirements(optimizeOutRepartition: 
Boolean = true) extends Ru
 
       val targetNumPartitions = 
requiredNumPartitions.getOrElse(expectedChildrenNumPartitions)
 
-      children = children.zip(requiredChildDistributions).zipWithIndex.map {
+      newChildren = 
newChildren.zip(requiredChildDistributions).zipWithIndex.map {
         case ((child, distribution), index) if childrenIndexes.contains(index) 
=>
           if (child.outputPartitioning.numPartitions == targetNumPartitions) {
             child
@@ -124,7 +129,7 @@ case class EnsureRequirements(optimizeOutRepartition: 
Boolean = true) extends Ru
     }
 
     // Now that we've performed any necessary shuffles, add sorts to guarantee 
output orderings:
-    children = children.zip(requiredChildOrderings).map { case (child, 
requiredOrdering) =>
+    newChildren = newChildren.zip(requiredChildOrderings).map { case (child, 
requiredOrdering) =>
       // If child.outputOrdering already satisfies the requiredOrdering, we do 
not need to sort.
       if (SortOrder.orderingSatisfies(child.outputOrdering, requiredOrdering)) 
{
         child
@@ -133,7 +138,7 @@ case class EnsureRequirements(optimizeOutRepartition: 
Boolean = true) extends Ru
       }
     }
 
-    operator.withNewChildren(children)
+    newChildren
   }
 
   private def reorder(
@@ -254,25 +259,50 @@ case class EnsureRequirements(optimizeOutRepartition: 
Boolean = true) extends Ru
     }
   }
 
-  def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
-    case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, 
shuffleOrigin)
-        if optimizeOutRepartition &&
-          (shuffleOrigin == REPARTITION_BY_COL || shuffleOrigin == 
REPARTITION_BY_NUM) =>
-      def hasSemanticEqualPartitioning(partitioning: Partitioning): Boolean = {
-        partitioning match {
-          case lower: HashPartitioning if upper.semanticEquals(lower) => true
-          case lower: PartitioningCollection =>
-            lower.partitionings.exists(hasSemanticEqualPartitioning)
-          case _ => false
+  def apply(plan: SparkPlan): SparkPlan = {
+    val newPlan = plan.transformUp {
+      case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, 
shuffleOrigin)
+          if optimizeOutRepartition &&
+            (shuffleOrigin == REPARTITION_BY_COL || shuffleOrigin == 
REPARTITION_BY_NUM) =>
+        def hasSemanticEqualPartitioning(partitioning: Partitioning): Boolean 
= {
+          partitioning match {
+            case lower: HashPartitioning if upper.semanticEquals(lower) => true
+            case lower: PartitioningCollection =>
+              lower.partitionings.exists(hasSemanticEqualPartitioning)
+            case _ => false
+          }
         }
-      }
-      if (hasSemanticEqualPartitioning(child.outputPartitioning)) {
-        child
+        if (hasSemanticEqualPartitioning(child.outputPartitioning)) {
+          child
+        } else {
+          operator
+        }
+
+      case operator: SparkPlan =>
+        val reordered = reorderJoinPredicates(operator)
+        val newChildren = ensureDistributionAndOrdering(
+          reordered.children,
+          reordered.requiredChildDistribution,
+          reordered.requiredChildOrdering,
+          ENSURE_REQUIREMENTS)
+        reordered.withNewChildren(newChildren)
+    }
+
+    if (requiredDistribution.isDefined) {
+      val shuffleOrigin = if 
(requiredDistribution.get.requiredNumPartitions.isDefined) {
+        REPARTITION_BY_NUM
       } else {
-        operator
+        REPARTITION_BY_COL
       }
-
-    case operator: SparkPlan =>
-      ensureDistributionAndOrdering(reorderJoinPredicates(operator))
+      val finalPlan = ensureDistributionAndOrdering(
+        newPlan :: Nil,
+        requiredDistribution.get :: Nil,
+        Seq(Nil),
+        shuffleOrigin)
+      assert(finalPlan.size == 1)
+      finalPlan.head
+    } else {
+      newPlan
+    }
   }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 4471fda..548ba87 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -1908,6 +1908,74 @@ class AdaptiveQueryExecSuite
     }
   }
 
+  test("SPARK-33832: Support optimize skew join even if introduce extra 
shuffle") {
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+      SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key -> 
"false",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "100",
+      SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100",
+      SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
+      SQLConf.SHUFFLE_PARTITIONS.key -> "10",
+      SQLConf.ADAPTIVE_FORCE_OPTIMIZE_SKEWED_JOIN.key -> "true") {
+      withTempView("skewData1", "skewData2") {
+        spark
+          .range(0, 1000, 1, 10)
+          .selectExpr("id % 3 as key1", "id as value1")
+          .createOrReplaceTempView("skewData1")
+        spark
+          .range(0, 1000, 1, 10)
+          .selectExpr("id % 1 as key2", "id as value2")
+          .createOrReplaceTempView("skewData2")
+
+        // check if optimized skewed join does not satisfy the required 
distribution
+        Seq(true, false).foreach { hasRequiredDistribution =>
+          Seq(true, false).foreach { hasPartitionNumber =>
+            val repartition = if (hasRequiredDistribution) {
+              s"/*+ repartition(${ if (hasPartitionNumber) "10," else ""}key1) 
*/"
+            } else {
+              ""
+            }
+
+            // check required distribution and extra shuffle
+            val (_, adaptive1) =
+              runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM 
skewData1 " +
+                s"JOIN skewData2 ON key1 = key2 GROUP BY key1")
+            val shuffles1 = collect(adaptive1) {
+              case s: ShuffleExchangeExec => s
+            }
+            assert(shuffles1.size == 3)
+            // shuffles1.head is the top-level shuffle under the Aggregate 
operator
+            assert(shuffles1.head.shuffleOrigin == ENSURE_REQUIREMENTS)
+            val smj1 = findTopLevelSortMergeJoin(adaptive1)
+            assert(smj1.size == 1 && smj1.head.isSkewJoin)
+
+            // only check required distribution
+            val (_, adaptive2) =
+              runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM 
skewData1 " +
+                s"JOIN skewData2 ON key1 = key2")
+            val shuffles2 = collect(adaptive2) {
+              case s: ShuffleExchangeExec => s
+            }
+            if (hasRequiredDistribution) {
+              assert(shuffles2.size == 3)
+              val finalShuffle = shuffles2.head
+              if (hasPartitionNumber) {
+                assert(finalShuffle.shuffleOrigin == REPARTITION_BY_NUM)
+              } else {
+                assert(finalShuffle.shuffleOrigin == REPARTITION_BY_COL)
+              }
+            } else {
+              assert(shuffles2.size == 2)
+            }
+            val smj2 = findTopLevelSortMergeJoin(adaptive2)
+            assert(smj2.size == 1 && smj2.head.isSkewJoin)
+          }
+        }
+      }
+    }
+  }
+
   test("SPARK-35968: AQE coalescing should not produce too small partitions by 
default") {
     withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
       val (_, adaptive) =

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to