This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new 72f547d  [SPARK-27485] EnsureRequirements.reorder should handle duplicate expressions gracefully
72f547d is described below

commit 72f547d4a960ba0ba9cace53a0a5553eca1b4dd6
Author: herman <her...@databricks.com>
AuthorDate: Tue Jul 16 17:09:52 2019 +0800

    [SPARK-27485] EnsureRequirements.reorder should handle duplicate expressions gracefully
    
    ## What changes were proposed in this pull request?
    When reordering joins, EnsureRequirements only checks whether all the join keys are present in the partitioning expression seq. This is problematic when the join keys and the partitioning expressions both contain duplicates, but not the same number of duplicates for each expression, e.g. `Seq(a, a, b)` vs `Seq(a, b, b)`. This fails with an index lookup failure in the `reorder` function.
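
    For illustration only (not part of this commit), a minimal standalone sketch of why the old lookup logic breaks on such inputs; it models the old `find`/`.get` pattern with plain string keys instead of Catalyst expressions:

    ```scala
    import scala.collection.mutable

    // Simplified model of the old reorder lookup: for each expected key, find a
    // not-yet-picked occurrence in the current key order and call .get on the result.
    val expected = Seq("a", "a", "b") // partitioning expressions
    val current  = Seq("a", "b", "b") // join keys
    val picked   = mutable.Set[Int]()
    expected.foreach { e =>
      // On the second "a" there is no unpicked "a" left, so find returns None
      // and .get throws a NoSuchElementException (the reported index lookup failure).
      val idx = current.zipWithIndex
        .find { case (k, i) => k == e && !picked.contains(i) }
        .map(_._2).get
      picked += idx
    }
    ```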
    
    This PR fixes this by removing the equality checking logic from the `reorderJoinKeys` function, and by doing the multiset equality check in the `reorder` function while building the reordered key sequences.
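
    A simplified standalone sketch of that multiset-matching idea (a hypothetical helper, using plain equality and a queue of positions instead of Spark's canonicalized expressions and BitSet):

    ```scala
    import scala.collection.mutable

    // Reorder `current` to follow `expected`, treating the keys as a multiset so each
    // occurrence is matched at most once; give up (None) if the multisets differ.
    def reorderIndices[T](expected: Seq[T], current: Seq[T]): Option[Seq[Int]] = {
      if (expected.size != current.size) return None
      // Map each key to the positions it occupies in the current order.
      val keyToIndices = mutable.Map.empty[T, mutable.Queue[Int]]
      current.zipWithIndex.foreach { case (k, i) =>
        keyToIndices.getOrElseUpdate(k, mutable.Queue.empty) += i
      }
      val picked = expected.map(k => keyToIndices.get(k).filter(_.nonEmpty).map(_.dequeue()))
      if (picked.forall(_.isDefined)) Some(picked.map(_.get)) else None
    }

    // Seq(a, a, b) vs Seq(a, b, b): no crash, the reorder simply gives up.
    assert(reorderIndices(Seq("a", "a", "b"), Seq("a", "b", "b")).isEmpty)
    assert(reorderIndices(Seq("b", "a"), Seq("a", "b")).contains(Seq(1, 0)))
    ```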
    
    ## How was this patch tested?
    Added a unit test to `PlannerSuite` and an integration test to `JoinSuite`.
    
    Closes #25167 from hvanhovell/SPARK-27485.
    
    Authored-by: herman <her...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../execution/exchange/EnsureRequirements.scala    | 72 ++++++++++++----------
 .../scala/org/apache/spark/sql/JoinSuite.scala     | 20 ++++++
 .../apache/spark/sql/execution/PlannerSuite.scala  | 26 ++++++++
 3 files changed, 86 insertions(+), 32 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index d2d5011..bdb9a31 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -24,8 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec,
-  SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.internal.SQLConf
 
 /**
@@ -221,25 +220,41 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
   }
 
   private def reorder(
-      leftKeys: Seq[Expression],
-      rightKeys: Seq[Expression],
+      leftKeys: IndexedSeq[Expression],
+      rightKeys: IndexedSeq[Expression],
       expectedOrderOfKeys: Seq[Expression],
       currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
-    val leftKeysBuffer = ArrayBuffer[Expression]()
-    val rightKeysBuffer = ArrayBuffer[Expression]()
-    val pickedIndexes = mutable.Set[Int]()
-    val keysAndIndexes = currentOrderOfKeys.zipWithIndex
+    if (expectedOrderOfKeys.size != currentOrderOfKeys.size) {
+      return (leftKeys, rightKeys)
+    }
+
+    // Build a lookup between an expression and the positions it holds in the current key seq.
+    val keyToIndexMap = mutable.Map.empty[Expression, mutable.BitSet]
+    currentOrderOfKeys.zipWithIndex.foreach {
+      case (key, index) =>
+        keyToIndexMap.getOrElseUpdate(key.canonicalized, mutable.BitSet.empty).add(index)
+    }
+
+    // Reorder the keys.
+    val leftKeysBuffer = new ArrayBuffer[Expression](leftKeys.size)
+    val rightKeysBuffer = new ArrayBuffer[Expression](rightKeys.size)
+    val iterator = expectedOrderOfKeys.iterator
+    while (iterator.hasNext) {
+      // Lookup the current index of this key.
+      keyToIndexMap.get(iterator.next().canonicalized) match {
+        case Some(indices) if indices.nonEmpty =>
+          // Take the first available index from the map.
+          val index = indices.firstKey
+          indices.remove(index)
 
-    expectedOrderOfKeys.foreach(expression => {
-      val index = keysAndIndexes.find { case (e, idx) =>
-        // As we may have the same key used many times, we need to filter out its occurrence we
-        // have already used.
-        e.semanticEquals(expression) && !pickedIndexes.contains(idx)
-      }.map(_._2).get
-      pickedIndexes += index
-      leftKeysBuffer.append(leftKeys(index))
-      rightKeysBuffer.append(rightKeys(index))
-    })
+          // Add the keys for that index to the reordered keys.
+          leftKeysBuffer += leftKeys(index)
+          rightKeysBuffer += rightKeys(index)
+        case _ =>
+          // The expression cannot be found, or we have exhausted all indices for that expression.
+          return (leftKeys, rightKeys)
+      }
+    }
     (leftKeysBuffer, rightKeysBuffer)
   }
 
@@ -249,20 +264,13 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
       leftPartitioning: Partitioning,
       rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
     if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
-      leftPartitioning match {
-        case HashPartitioning(leftExpressions, _)
-          if leftExpressions.length == leftKeys.length &&
-            leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
-          reorder(leftKeys, rightKeys, leftExpressions, leftKeys)
-
-        case _ => rightPartitioning match {
-          case HashPartitioning(rightExpressions, _)
-            if rightExpressions.length == rightKeys.length &&
-              rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
-            reorder(leftKeys, rightKeys, rightExpressions, rightKeys)
-
-          case _ => (leftKeys, rightKeys)
-        }
+      (leftPartitioning, rightPartitioning) match {
+        case (HashPartitioning(leftExpressions, _), _) =>
+          reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leftExpressions, leftKeys)
+        case (_, HashPartitioning(rightExpressions, _)) =>
+          reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys)
+        case _ =>
+          (leftKeys, rightKeys)
       }
     } else {
       (leftKeys, rightKeys)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 52fa22c..a44deaf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -894,6 +894,26 @@ class JoinSuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("SPARK-27485: EnsureRequirements should not fail join with duplicate 
keys") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "2",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      val tbl_a = spark.range(40)
+        .select($"id" as "x", $"id" % 10 as "y")
+        .repartition(2, $"x", $"y", $"x")
+        .as("tbl_a")
+
+      val tbl_b = spark.range(20)
+        .select($"id" as "x", $"id" % 2 as "y1", $"id" % 20 as "y2")
+        .as("tbl_b")
+
+      val res = tbl_a
+        .join(tbl_b,
+          $"tbl_a.x" === $"tbl_b.x" && $"tbl_a.y" === $"tbl_b.y1" && 
$"tbl_a.y" === $"tbl_b.y2")
+        .select($"tbl_a.x")
+      checkAnswer(res, Row(0L) :: Row(1L) :: Nil)
+    }
+  }
+
   test("SPARK-26352: join reordering should not change the order of columns") {
     withTable("tab1", "tab2", "tab3") {
       spark.sql("select 1 as x, 100 as y").write.saveAsTable("tab1")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index d9fb172..3c3af80 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -697,6 +697,32 @@ class PlannerSuite extends SharedSQLContext {
     }
   }
 
+  test("SPARK-27485: EnsureRequirements.reorder should handle duplicate 
expressions") {
+    val plan1 = DummySparkPlan(
+      outputPartitioning = HashPartitioning(exprA :: exprB :: exprA :: Nil, 5))
+    val plan2 = DummySparkPlan()
+    val smjExec = SortMergeJoinExec(
+      leftKeys = exprA :: exprB :: exprB :: Nil,
+      rightKeys = exprA :: exprC :: exprC :: Nil,
+      joinType = Inner,
+      condition = None,
+      left = plan1,
+      right = plan2)
+    val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec)
+    outputPlan match {
+      case SortMergeJoinExec(leftKeys, rightKeys, _, _,
+             SortExec(_, _,
+               ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _), _),
+             SortExec(_, _,
+               ShuffleExchangeExec(HashPartitioning(rightPartitioningExpressions, _), _), _)) =>
+        assert(leftKeys === smjExec.leftKeys)
+        assert(rightKeys === smjExec.rightKeys)
+        assert(leftKeys === leftPartitioningExpressions)
+        assert(rightKeys === rightPartitioningExpressions)
+      case _ => fail(outputPlan.toString)
+    }
+  }
+
   test("SPARK-24500: create union with stream of children") {
     val df = Union(Stream(
       Range(1, 1, 1, 1),

