Repository: spark
Updated Branches:
  refs/heads/master 874350905 -> 682eb4f2e


[SPARK-22042][SQL] ReorderJoinPredicates can break when child's partitioning is not decided

## What changes were proposed in this pull request?

See the JIRA description for the bug: https://issues.apache.org/jira/browse/SPARK-22042

The fix in this PR: in `EnsureRequirements`, apply `ReorderJoinPredicates` over 
the input tree before running its core logic. Since the tree is transformed 
bottom-up, we can be sure that the children are already resolved by the time 
`ReorderJoinPredicates` runs.

Theoretically this guarantees covering all such cases while keeping the code 
simple. My one small grudge is cosmetic: this PR will look weird given that we 
don't (to my knowledge) call rules from other rules. I could have moved all of 
the `ReorderJoinPredicates` logic into `EnsureRequirements`, but that would make 
it a bit crowded. I am happy to discuss if there are better options.
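
For intuition, the key reordering itself (see `reorderJoinKeys` in the diff 
below) just permutes both key lists in lockstep so that the left keys line up 
with the child's hash-partitioning expressions. A self-contained toy version 
over plain strings, standing in for Catalyst `Expression`s:

```scala
// Toy model of reorderJoinKeys: permute (leftKeys, rightKeys) together so that
// leftKeys matches the order expected by the child's hash partitioning.
def reorder(
    expected: Seq[String],
    leftKeys: Seq[String],
    rightKeys: Seq[String]): (Seq[String], Seq[String]) = {
  val pairs = expected.map { e =>
    val i = leftKeys.indexOf(e) // the real rule uses semanticEquals here
    (leftKeys(i), rightKeys(i))
  }
  (pairs.map(_._1), pairs.map(_._2))
}

// A table bucketed by (j, k) joined with predicates written as k = x AND j = y:
// reorder(Seq("j", "k"), Seq("k", "j"), Seq("x", "y")) == (Seq("j", "k"), Seq("y", "x"))
```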

## How was this patch tested?

Added a new test case.
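
The test reads a bucketed table through a nested join, so the inner join's 
partitioning is not decided until `EnsureRequirements` plans its children; 
schematically (mirroring the test added to `BucketedReadSuite` in the diff 
below):

```scala
df.write.format("parquet").saveAsTable("table1")
df.write.format("parquet").saveAsTable("table2")
df.write.format("parquet").bucketBy(8, "j", "k").saveAsTable("bucketed_table")

// Broadcast joins disabled so the query actually exercises shuffle planning.
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
  sql("""
        |SELECT ab.i FROM (
        |  SELECT a.i FROM bucketed_table a JOIN table1 b ON a.i = b.i
        |) ab
        |JOIN table2 c ON ab.i = c.i
        |""".stripMargin).collect()
}
```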

Author: Tejas Patil <tej...@fb.com>

Closes #19257 from tejasapatil/SPARK-22042_ReorderJoinPredicates.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/682eb4f2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/682eb4f2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/682eb4f2

Branch: refs/heads/master
Commit: 682eb4f2ea152ce1043fbe689ea95318926b91b0
Parents: 8743509
Author: Tejas Patil <tej...@fb.com>
Authored: Tue Dec 12 23:30:06 2017 -0800
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Tue Dec 12 23:30:06 2017 -0800

----------------------------------------------------------------------
 .../spark/sql/execution/QueryExecution.scala    |  2 -
 .../execution/exchange/EnsureRequirements.scala | 76 +++++++++++++++-
 .../execution/joins/ReorderJoinPredicates.scala | 94 --------------------
 .../spark/sql/sources/BucketedReadSuite.scala   | 31 +++++++
 4 files changed, 106 insertions(+), 97 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/682eb4f2/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index f404621..946475a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand}
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
-import org.apache.spark.sql.execution.joins.ReorderJoinPredicates
 import org.apache.spark.sql.types.{BinaryType, DateType, DecimalType, TimestampType, _}
 import org.apache.spark.util.Utils
 
@@ -104,7 +103,6 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
   protected def preparations: Seq[Rule[SparkPlan]] = Seq(
     python.ExtractPythonUDFs,
     PlanSubqueries(sparkSession),
-    new ReorderJoinPredicates,
     EnsureRequirements(sparkSession.sessionState.conf),
     CollapseCodegenStages(sparkSession.sessionState.conf),
     ReuseExchange(sparkSession.sessionState.conf),

http://git-wip-us.apache.org/repos/asf/spark/blob/682eb4f2/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 4e2ca37..82f0b9f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -17,10 +17,14 @@
 
 package org.apache.spark.sql.execution.exchange
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec,
+  SortMergeJoinExec}
 import org.apache.spark.sql.internal.SQLConf
 
 /**
@@ -248,6 +252,75 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
     operator.withNewChildren(children)
   }
 
+  /**
+   * When the physical operators are created for JOIN, the ordering of join keys is based on order
+   * in which the join keys appear in the user query. That might not match with the output
+   * partitioning of the join node's children (thus leading to extra sort / shuffle being
+   * introduced). This rule will change the ordering of the join keys to match with the
+   * partitioning of the join nodes' children.
+   */
+  def reorderJoinPredicates(plan: SparkPlan): SparkPlan = {
+    def reorderJoinKeys(
+        leftKeys: Seq[Expression],
+        rightKeys: Seq[Expression],
+        leftPartitioning: Partitioning,
+        rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
+
+      def reorder(expectedOrderOfKeys: Seq[Expression],
+                  currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
+        val leftKeysBuffer = ArrayBuffer[Expression]()
+        val rightKeysBuffer = ArrayBuffer[Expression]()
+
+        expectedOrderOfKeys.foreach(expression => {
+          val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
+          leftKeysBuffer.append(leftKeys(index))
+          rightKeysBuffer.append(rightKeys(index))
+        })
+        (leftKeysBuffer, rightKeysBuffer)
+      }
+
+      if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
+        leftPartitioning match {
+          case HashPartitioning(leftExpressions, _)
+            if leftExpressions.length == leftKeys.length &&
+              leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
+            reorder(leftExpressions, leftKeys)
+
+          case _ => rightPartitioning match {
+            case HashPartitioning(rightExpressions, _)
+              if rightExpressions.length == rightKeys.length &&
+                rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
+              reorder(rightExpressions, rightKeys)
+
+            case _ => (leftKeys, rightKeys)
+          }
+        }
+      } else {
+        (leftKeys, rightKeys)
+      }
+    }
+
+    plan.transformUp {
+      case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left,
+        right) =>
+        val (reorderedLeftKeys, reorderedRightKeys) =
+          reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
+        BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
+          left, right)
+
+      case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) =>
+        val (reorderedLeftKeys, reorderedRightKeys) =
+          reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
+        ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
+          left, right)
+
+      case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) =>
+        val (reorderedLeftKeys, reorderedRightKeys) =
+          reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
+        SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right)
+    }
+  }
+
   def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
     case operator @ ShuffleExchangeExec(partitioning, child, _) =>
       child.children match {
@@ -255,6 +328,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
           if (childPartitioning.guarantees(partitioning)) child else operator
         case _ => operator
       }
-    case operator: SparkPlan => ensureDistributionAndOrdering(operator)
+    case operator: SparkPlan =>
+      ensureDistributionAndOrdering(reorderJoinPredicates(operator))
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/682eb4f2/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala
deleted file mode 100644
index 534d8c5..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala
+++ /dev/null
@@ -1,94 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.joins
-
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.SparkPlan
-
-/**
- * When the physical operators are created for JOIN, the ordering of join keys is based on order
- * in which the join keys appear in the user query. That might not match with the output
- * partitioning of the join node's children (thus leading to extra sort / shuffle being
- * introduced). This rule will change the ordering of the join keys to match with the
- * partitioning of the join nodes' children.
- */
-class ReorderJoinPredicates extends Rule[SparkPlan] {
-  private def reorderJoinKeys(
-      leftKeys: Seq[Expression],
-      rightKeys: Seq[Expression],
-      leftPartitioning: Partitioning,
-      rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
-
-    def reorder(
-        expectedOrderOfKeys: Seq[Expression],
-        currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
-      val leftKeysBuffer = ArrayBuffer[Expression]()
-      val rightKeysBuffer = ArrayBuffer[Expression]()
-
-      expectedOrderOfKeys.foreach(expression => {
-      val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
-        leftKeysBuffer.append(leftKeys(index))
-        rightKeysBuffer.append(rightKeys(index))
-      })
-      (leftKeysBuffer, rightKeysBuffer)
-    }
-
-    if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
-      leftPartitioning match {
-        case HashPartitioning(leftExpressions, _)
-          if leftExpressions.length == leftKeys.length &&
-            leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
-          reorder(leftExpressions, leftKeys)
-
-        case _ => rightPartitioning match {
-          case HashPartitioning(rightExpressions, _)
-            if rightExpressions.length == rightKeys.length &&
-              rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
-            reorder(rightExpressions, rightKeys)
-
-          case _ => (leftKeys, rightKeys)
-        }
-      }
-    } else {
-      (leftKeys, rightKeys)
-    }
-  }
-
-  def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
-    case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) =>
-      val (reorderedLeftKeys, reorderedRightKeys) =
-        reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
-      BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
-        left, right)
-
-    case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) =>
-      val (reorderedLeftKeys, reorderedRightKeys) =
-        reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
-      ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
-        left, right)
-
-    case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) =>
-      val (reorderedLeftKeys, reorderedRightKeys) =
-        reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
-      SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right)
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/682eb4f2/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index ab18905..9025859 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -602,6 +602,37 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
     )
   }
 
+  test("SPARK-22042 ReorderJoinPredicates can break when child's partitioning 
is not decided") {
+    withTable("bucketed_table", "table1", "table2") {
+      df.write.format("parquet").saveAsTable("table1")
+      df.write.format("parquet").saveAsTable("table2")
+      df.write.format("parquet").bucketBy(8, "j", 
"k").saveAsTable("bucketed_table")
+
+      withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
+        checkAnswer(
+          sql("""
+                |SELECT ab.i, ab.j, ab.k, c.i, c.j, c.k
+                |FROM (
+                |  SELECT a.i, a.j, a.k
+                |  FROM bucketed_table a
+                |  JOIN table1 b
+                |  ON a.i = b.i
+                |) ab
+                |JOIN table2 c
+                |ON ab.i = c.i
+                |""".stripMargin),
+          sql("""
+                |SELECT a.i, a.j, a.k, c.i, c.j, c.k
+                |FROM bucketed_table a
+                |JOIN table1 b
+                |ON a.i = b.i
+                |JOIN table2 c
+                |ON a.i = c.i
+                |""".stripMargin))
+      }
+    }
+  }
+
   test("error if there exists any malformed bucket files") {
     withTable("bucketed_table") {
      df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")

