Repository: spark
Updated Branches:
  refs/heads/master 224375c55 -> cc30ef800


[SPARK-22916][SQL] shouldn't bias towards build right if user does not specify

## What changes were proposed in this pull request?

When there are no broadcast hints, the current Spark strategies prefer to build 
the right side without considering the sizes of the two tables. This patch adds 
logic that considers the sizes of both tables when choosing the build side. To 
make the logic clear, the build side is determined in the following steps:

1. If there are broadcast hints, the build side is determined by 
`broadcastSideByHints`;
2. If there are no broadcast hints, the build side is determined by 
`broadcastSideBySizes`;
3. If broadcasting is disabled by the configuration, the planner falls back to the subsequent join strategies.

## How was this patch tested?

Added a new test, "Shouldn't bias towards build right if user didn't specify", 
to `BroadcastJoinSuite`, and updated the expected input/output metrics in 
`SQLMetricsSuite` to match the new build side.

Please review http://spark.apache.org/contributing.html before opening a pull 
request.

Author: Feng Liu <feng...@databricks.com>

Closes #20099 from liufengdb/fix-spark-strategies.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cc30ef80
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cc30ef80
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cc30ef80

Branch: refs/heads/master
Commit: cc30ef8009b82c71a4b8e9caba82ed141761ab85
Parents: 224375c
Author: Feng Liu <feng...@databricks.com>
Authored: Fri Dec 29 18:48:47 2017 +0800
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Fri Dec 29 18:48:47 2017 +0800

----------------------------------------------------------------------
 .../spark/sql/execution/SparkStrategies.scala   | 75 ++++++++++++--------
 .../execution/joins/BroadcastJoinSuite.scala    | 75 +++++++++++++++-----
 .../sql/execution/metric/SQLMetricsSuite.scala  | 15 ++--
 3 files changed, 116 insertions(+), 49 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/cc30ef80/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 6b3f301..0ed7c2f 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -158,45 +158,65 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
       def smallerSide =
         if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else 
BuildLeft
 
-      val buildRight = canBuildRight && right.stats.hints.broadcast
-      val buildLeft = canBuildLeft && left.stats.hints.broadcast
-
-      if (buildRight && buildLeft) {
+      if (canBuildRight && canBuildLeft) {
         // Broadcast smaller side base on its estimated physical size
         // if both sides have broadcast hint
         smallerSide
-      } else if (buildRight) {
+      } else if (canBuildRight) {
         BuildRight
-      } else if (buildLeft) {
+      } else if (canBuildLeft) {
         BuildLeft
-      } else if (canBuildRight && canBuildLeft) {
+      } else {
         // for the last default broadcast nested loop join
         smallerSide
-      } else {
-        throw new AnalysisException("Can not decide which side to broadcast 
for this join")
       }
     }
 
+    private def canBroadcastByHints(joinType: JoinType, left: LogicalPlan, 
right: LogicalPlan)
+      : Boolean = {
+      val buildLeft = canBuildLeft(joinType) && left.stats.hints.broadcast
+      val buildRight = canBuildRight(joinType) && right.stats.hints.broadcast
+      buildLeft || buildRight
+    }
+
+    private def broadcastSideByHints(joinType: JoinType, left: LogicalPlan, 
right: LogicalPlan)
+      : BuildSide = {
+      val buildLeft = canBuildLeft(joinType) && left.stats.hints.broadcast
+      val buildRight = canBuildRight(joinType) && right.stats.hints.broadcast
+      broadcastSide(buildLeft, buildRight, left, right)
+    }
+
+    private def canBroadcastBySizes(joinType: JoinType, left: LogicalPlan, 
right: LogicalPlan)
+      : Boolean = {
+      val buildLeft = canBuildLeft(joinType) && canBroadcast(left)
+      val buildRight = canBuildRight(joinType) && canBroadcast(right)
+      buildLeft || buildRight
+    }
+
+    private def broadcastSideBySizes(joinType: JoinType, left: LogicalPlan, 
right: LogicalPlan)
+      : BuildSide = {
+      val buildLeft = canBuildLeft(joinType) && canBroadcast(left)
+      val buildRight = canBuildRight(joinType) && canBroadcast(right)
+      broadcastSide(buildLeft, buildRight, left, right)
+    }
+
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
 
       // --- BroadcastHashJoin 
--------------------------------------------------------------------
 
+      // broadcast hints were specified
       case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, 
right)
-        if (canBuildRight(joinType) && right.stats.hints.broadcast) ||
-          (canBuildLeft(joinType) && left.stats.hints.broadcast) =>
-        val buildSide = broadcastSide(canBuildLeft(joinType), 
canBuildRight(joinType), left, right)
+        if canBroadcastByHints(joinType, left, right) =>
+        val buildSide = broadcastSideByHints(joinType, left, right)
         Seq(joins.BroadcastHashJoinExec(
           leftKeys, rightKeys, joinType, buildSide, condition, 
planLater(left), planLater(right)))
 
+      // broadcast hints were not specified, so need to infer it from size and 
configuration.
       case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, 
right)
-        if canBuildRight(joinType) && canBroadcast(right) =>
-        Seq(joins.BroadcastHashJoinExec(
-          leftKeys, rightKeys, joinType, BuildRight, condition, 
planLater(left), planLater(right)))
-
-      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, 
right)
-        if canBuildLeft(joinType) && canBroadcast(left) =>
+        if canBroadcastBySizes(joinType, left, right) =>
+        val buildSide = broadcastSideBySizes(joinType, left, right)
         Seq(joins.BroadcastHashJoinExec(
-          leftKeys, rightKeys, joinType, BuildLeft, condition, 
planLater(left), planLater(right)))
+          leftKeys, rightKeys, joinType, buildSide, condition, 
planLater(left), planLater(right)))
 
       // --- ShuffledHashJoin 
---------------------------------------------------------------------
 
@@ -225,27 +245,24 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
 
       // Pick BroadcastNestedLoopJoin if one side could be broadcasted
       case j @ logical.Join(left, right, joinType, condition)
-        if (canBuildRight(joinType) && right.stats.hints.broadcast) ||
-          (canBuildLeft(joinType) && left.stats.hints.broadcast) =>
-        val buildSide = broadcastSide(canBuildLeft(joinType), 
canBuildRight(joinType), left, right)
+          if canBroadcastByHints(joinType, left, right) =>
+        val buildSide = broadcastSideByHints(joinType, left, right)
         joins.BroadcastNestedLoopJoinExec(
           planLater(left), planLater(right), buildSide, joinType, condition) 
:: Nil
 
       case j @ logical.Join(left, right, joinType, condition)
-          if canBuildRight(joinType) && canBroadcast(right) =>
+          if canBroadcastBySizes(joinType, left, right) =>
+        val buildSide = broadcastSideBySizes(joinType, left, right)
         joins.BroadcastNestedLoopJoinExec(
-          planLater(left), planLater(right), BuildRight, joinType, condition) 
:: Nil
-      case j @ logical.Join(left, right, joinType, condition)
-          if canBuildLeft(joinType) && canBroadcast(left) =>
-        joins.BroadcastNestedLoopJoinExec(
-          planLater(left), planLater(right), BuildLeft, joinType, condition) 
:: Nil
+          planLater(left), planLater(right), buildSide, joinType, condition) 
:: Nil
 
       // Pick CartesianProduct for InnerJoin
       case logical.Join(left, right, _: InnerLike, condition) =>
         joins.CartesianProductExec(planLater(left), planLater(right), 
condition) :: Nil
 
       case logical.Join(left, right, joinType, condition) =>
-        val buildSide = broadcastSide(canBuildLeft = true, canBuildRight = 
true, left, right)
+        val buildSide = broadcastSide(
+          left.stats.hints.broadcast, right.stats.hints.broadcast, left, right)
         // This join could be very slow or OOM
         joins.BroadcastNestedLoopJoinExec(
           planLater(left), planLater(right), buildSide, joinType, condition) 
:: Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/cc30ef80/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 67e2cdc..6da46ea 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -225,17 +225,6 @@ class BroadcastJoinSuite extends QueryTest with 
SQLTestUtils {
   }
 
   test("Shouldn't change broadcast join buildSide if user clearly specified") {
-    def assertJoinBuildSide(sqlStr: String, joinMethod: String, buildSide: 
BuildSide): Any = {
-      val executedPlan = sql(sqlStr).queryExecution.executedPlan
-      executedPlan match {
-        case b: BroadcastNestedLoopJoinExec =>
-          assert(b.getClass.getSimpleName === joinMethod)
-          assert(b.buildSide === buildSide)
-        case w: WholeStageCodegenExec =>
-          assert(w.children.head.getClass.getSimpleName === joinMethod)
-          assert(w.children.head.asInstanceOf[BroadcastHashJoinExec].buildSide 
=== buildSide)
-      }
-    }
 
     withTempView("t1", "t2") {
       spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", 
"value").createTempView("t1")
@@ -246,9 +235,6 @@ class BroadcastJoinSuite extends QueryTest with 
SQLTestUtils {
       val t2Size = 
spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes
       assert(t1Size < t2Size)
 
-      val bh = BroadcastHashJoinExec.toString
-      val bl = BroadcastNestedLoopJoinExec.toString
-
       // INNER JOIN && t1Size < t2Size => BuildLeft
       assertJoinBuildSide(
         "SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 JOIN t2 ON t1.key = t2.key", 
bh, BuildLeft)
@@ -266,8 +252,7 @@ class BroadcastJoinSuite extends QueryTest with 
SQLTestUtils {
         "SELECT /*+ MAPJOIN(t2) */ * FROM t1 JOIN t2 ON t1.key = t2.key", bh, 
BuildRight)
 
 
-      withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0",
-        SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+      withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
         // INNER JOIN && t1Size < t2Size => BuildLeft
         assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 JOIN t2", 
bl, BuildLeft)
         // FULL JOIN && t1Size < t2Size => BuildLeft
@@ -290,4 +275,62 @@ class BroadcastJoinSuite extends QueryTest with 
SQLTestUtils {
       }
     }
   }
+
+  test("Shouldn't bias towards build right if user didn't specify") {
+
+    withTempView("t1", "t2") {
+      spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", 
"value").createTempView("t1")
+      spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, 
"123"))).toDF("key", "value")
+        .createTempView("t2")
+
+      val t1Size = 
spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes
+      val t2Size = 
spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes
+      assert(t1Size < t2Size)
+
+      assertJoinBuildSide("SELECT * FROM t1 JOIN t2 ON t1.key = t2.key", bh, 
BuildLeft)
+      assertJoinBuildSide("SELECT * FROM t2 JOIN t1 ON t1.key = t2.key", bh, 
BuildRight)
+
+      assertJoinBuildSide("SELECT * FROM t1 LEFT JOIN t2 ON t1.key = t2.key", 
bh, BuildRight)
+      assertJoinBuildSide("SELECT * FROM t2 LEFT JOIN t1 ON t1.key = t2.key", 
bh, BuildRight)
+
+      assertJoinBuildSide("SELECT * FROM t1 RIGHT JOIN t2 ON t1.key = t2.key", 
bh, BuildLeft)
+      assertJoinBuildSide("SELECT * FROM t2 RIGHT JOIN t1 ON t1.key = t2.key", 
bh, BuildLeft)
+
+      withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+        assertJoinBuildSide("SELECT * FROM t1 FULL OUTER JOIN t2", bl, 
BuildLeft)
+        assertJoinBuildSide("SELECT * FROM t2 FULL OUTER JOIN t1", bl, 
BuildRight)
+
+        assertJoinBuildSide("SELECT * FROM t1 LEFT JOIN t2", bl, BuildRight)
+        assertJoinBuildSide("SELECT * FROM t2 LEFT JOIN t1", bl, BuildRight)
+
+        assertJoinBuildSide("SELECT * FROM t1 RIGHT JOIN t2", bl, BuildLeft)
+        assertJoinBuildSide("SELECT * FROM t2 RIGHT JOIN t1", bl, BuildLeft)
+      }
+    }
+  }
+
+  private val bh = BroadcastHashJoinExec.toString
+  private val bl = BroadcastNestedLoopJoinExec.toString
+
+  /**
+   * Asserts that `sqlStr` is planned as the join operator named `joinMethod` (compared with
+   * the operator's simple class name) and that the join builds `buildSide`.
+   */
+  private def assertJoinBuildSide(sqlStr: String, joinMethod: String, buildSide: BuildSide): Any = {
+    val executedPlan = sql(sqlStr).queryExecution.executedPlan
+    executedPlan match {
+      case b: BroadcastNestedLoopJoinExec =>
+        assert(b.getClass.getSimpleName === joinMethod)
+        assert(b.buildSide === buildSide)
+      case w: WholeStageCodegenExec =>
+        assert(w.children.head.getClass.getSimpleName === joinMethod)
+        w.children.head match {
+          case j: BroadcastNestedLoopJoinExec => assert(j.buildSide === buildSide)
+          case j: BroadcastHashJoinExec => assert(j.buildSide === buildSide)
+          case op => fail(s"Unexpected join operator under WholeStageCodegenExec: $op")
+        }
+      case other =>
+        fail(s"Unexpected physical plan: $other")
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/cc30ef80/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index fc34833..a3a3f38 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -478,15 +478,22 @@ class SQLMetricsSuite extends SparkFunSuite with 
SQLMetricsTestUtils with Shared
       spark.range(10).write.parquet(dir)
       spark.read.parquet(dir).createOrReplaceTempView("pqS")
 
+      // The executed plan looks like:
+      // Exchange RoundRobinPartitioning(2)
+      // +- BroadcastNestedLoopJoin BuildLeft, Cross
+      //   :- BroadcastExchange IdentityBroadcastMode
+      //   :  +- Exchange RoundRobinPartitioning(3)
+      //   :     +- *Range (0, 30, step=1, splits=2)
+      //   +- *FileScan parquet [id#465L] Batched: true, Format: Parquet, 
Location: ...(ignored)
       val res3 = InputOutputMetricsHelper.run(
         spark.range(30).repartition(3).crossJoin(sql("select * from 
pqS")).repartition(2).toDF()
       )
       // The query above is executed in the following stages:
-      //   1. sql("select * from pqS")    => (10, 0, 10)
-      //   2. range(30)                   => (30, 0, 30)
-      //   3. crossJoin(...) of 1. and 2. => (0, 30, 300)
+      //   1. range(30).repartition(3)    => (30, 0, 30)
+      //   2. broadcast of 1.             => (0, 30, 0)
+      //   3. crossJoin of 2. and pqS     => (10, 0, 300)
       //   4. shuffle & return results    => (0, 300, 0)
-      assert(res3 === (10L, 0L, 10L) :: (30L, 0L, 30L) :: (0L, 30L, 300L) :: 
(0L, 300L, 0L) :: Nil)
+      assert(res3 === (30L, 0L, 30L) :: (0L, 30L, 0L) :: (10L, 0L, 300L) :: 
(0L, 300L, 0L) :: Nil)
     }
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to