Repository: spark
Updated Branches:
  refs/heads/master df9adb5ec -> 223f63390


[SPARK-15415][SQL] Fix BroadcastHint when autoBroadcastJoinThreshold is 0 or -1

## What changes were proposed in this pull request?

This PR makes BroadcastHint more deterministic by adding a dedicated `isBroadcastable` flag to `Statistics`, instead of faking a `sizeInBytes` of 1.

See https://issues.apache.org/jira/browse/SPARK-15415
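
As a rough illustration (not part of this patch; the local session setup and sample data below are made up), the behaviour this change targets looks like the following sketch:

```scala
// Illustration only: any SparkSession configured like this should behave the same way.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.broadcast

val spark = SparkSession.builder().master("local[*]").appName("broadcast-hint-example").getOrCreate()
import spark.implicits._

// Disable size-based automatic broadcasting entirely (the case this PR fixes).
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")

val large = Seq((1, "a"), (2, "b"), (3, "c")).toDF("key", "value")
val small = Seq((1, "x"), (2, "y")).toDF("key", "lookup")

// broadcast() wraps `small` in a BroadcastHint node; with this patch its statistics
// carry isBroadcastable = true instead of a fake sizeInBytes of 1.
val joined = large.join(broadcast(small), Seq("key"))
joined.explain() // expected to show BroadcastHashJoin even though the threshold is -1
```

Because the hint is now a boolean flag on `Statistics` rather than a tiny size estimate, operators that cannot grow their input (limits, filters, samples, aggregates, except, intersect) keep propagating it, while general joins deliberately reset it.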

## How was this patch tested?

Added test cases that verify a broadcast hash join appears in the plan when a BroadcastHint is supplied (even with `spark.sql.autoBroadcastJoinThreshold` set to 0 or -1), plus tests that check whether the hint is propagated through subsequent operators and joins.

Author: Jurriaan Pruis <em...@jurriaanpruis.nl>

Closes #13244 from jurriaan/broadcast-hint.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/223f6339
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/223f6339
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/223f6339

Branch: refs/heads/master
Commit: 223f6339088434eb3590c2f42091a38f05f1e5db
Parents: df9adb5
Author: Jurriaan Pruis <em...@jurriaanpruis.nl>
Authored: Sat May 21 23:01:14 2016 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Sat May 21 23:01:14 2016 -0700

----------------------------------------------------------------------
 .../catalyst/plans/logical/LogicalPlan.scala    |   3 +-
 .../sql/catalyst/plans/logical/Statistics.scala |   2 +-
 .../plans/logical/basicLogicalOperators.scala   |  29 ++++--
 .../spark/sql/execution/SparkStrategies.scala   |   3 +-
 .../execution/joins/BroadcastJoinSuite.scala    | 103 ++++++++++++++++---
 5 files changed, 114 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/223f6339/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 45ac126..4984f23 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -313,7 +313,8 @@ abstract class UnaryNode extends LogicalPlan {
       // (product of children).
       sizeInBytes = 1
     }
-    Statistics(sizeInBytes = sizeInBytes)
+
+    child.statistics.copy(sizeInBytes = sizeInBytes)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/223f6339/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
index 9ac4c3a..63f86ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
@@ -32,4 +32,4 @@ package org.apache.spark.sql.catalyst.plans.logical
  * @param sizeInBytes Physical size in bytes. For leaf operators this defaults to 1, otherwise it
  *                    defaults to the product of children's `sizeInBytes`.
  */
-private[sql] case class Statistics(sizeInBytes: BigInt)
+private[sql] case class Statistics(sizeInBytes: BigInt, isBroadcastable: Boolean = false)

http://git-wip-us.apache.org/repos/asf/spark/blob/223f6339/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 732b0d7..bed48b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -163,7 +163,9 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
     val leftSize = left.statistics.sizeInBytes
     val rightSize = right.statistics.sizeInBytes
     val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize
-    Statistics(sizeInBytes = sizeInBytes)
+    val isBroadcastable = left.statistics.isBroadcastable || right.statistics.isBroadcastable
+
+    Statistics(sizeInBytes = sizeInBytes, isBroadcastable = isBroadcastable)
   }
 }
 
@@ -183,7 +185,7 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le
       duplicateResolved
 
   override def statistics: Statistics = {
-    Statistics(sizeInBytes = left.statistics.sizeInBytes)
+    left.statistics.copy()
   }
 }
 
@@ -330,6 +332,16 @@ case class Join(
     case UsingJoin(_, _) => false
     case _ => resolvedExceptNatural
   }
+
+  override def statistics: Statistics = joinType match {
+    case LeftAnti | LeftSemi =>
+      // LeftSemi and LeftAnti won't ever be bigger than left
+      left.statistics.copy()
+    case _ =>
+      // make sure we don't propagate isBroadcastable in other joins, because
+      // they could explode the size.
+      super.statistics.copy(isBroadcastable = false)
+  }
 }
 
 /**
@@ -338,9 +350,8 @@ case class Join(
 case class BroadcastHint(child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
 
-  // We manually set statistics of BroadcastHint to smallest value to make sure
-  // the plan wrapped by BroadcastHint will be considered to broadcast later.
-  override def statistics: Statistics = Statistics(sizeInBytes = 1)
+  // set isBroadcastable to true so the child will be broadcasted
+  override def statistics: Statistics = super.statistics.copy(isBroadcastable = true)
 }
 
 case class InsertIntoTable(
@@ -465,7 +476,7 @@ case class Aggregate(
 
   override def statistics: Statistics = {
     if (groupingExpressions.isEmpty) {
-      Statistics(sizeInBytes = 1)
+      super.statistics.copy(sizeInBytes = 1)
     } else {
       super.statistics
     }
@@ -638,7 +649,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN
   override lazy val statistics: Statistics = {
     val limit = limitExpr.eval().asInstanceOf[Int]
     val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum
-    Statistics(sizeInBytes = sizeInBytes)
+    child.statistics.copy(sizeInBytes = sizeInBytes)
   }
 }
 
@@ -653,7 +664,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo
   override lazy val statistics: Statistics = {
     val limit = limitExpr.eval().asInstanceOf[Int]
     val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum
-    Statistics(sizeInBytes = sizeInBytes)
+    child.statistics.copy(sizeInBytes = sizeInBytes)
   }
 }
 
@@ -690,7 +701,7 @@ case class Sample(
     if (sizeInBytes == 0) {
       sizeInBytes = 1
     }
-    Statistics(sizeInBytes = sizeInBytes)
+    child.statistics.copy(sizeInBytes = sizeInBytes)
   }
 
   override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/223f6339/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 3343039..664e7f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -92,7 +92,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
      * Matches a plan whose output should be small enough to be used in broadcast join.
      */
     private def canBroadcast(plan: LogicalPlan): Boolean = {
-      plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold
+      plan.statistics.isBroadcastable ||
+        plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold
     }
 
     /**

http://git-wip-us.apache.org/repos/asf/spark/blob/223f6339/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 730ec43..e681b88 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -22,9 +22,12 @@ import scala.reflect.ClassTag
 import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext}
-import org.apache.spark.sql.{QueryTest, SparkSession}
+import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession}
 import org.apache.spark.sql.execution.exchange.EnsureRequirements
+import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SQLTestUtils
 
 /**
  * Test various broadcast join operators.
@@ -33,7 +36,9 @@ import org.apache.spark.sql.functions._
  * unsafe map in [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered
  * without serializing the hashed relation, which does not happen in local mode.
  */
-class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
+class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
+  import testImplicits._
+
   protected var spark: SparkSession = null
 
   /**
@@ -56,30 +61,100 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
   /**
    * Test whether the specified broadcast join updates the peak execution memory accumulator.
    */
-  private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = {
+  private def testBroadcastJoinPeak[T: ClassTag](name: String, joinType: String): Unit = {
     AccumulatorSuite.verifyPeakExecutionMemorySet(spark.sparkContext, name) {
-      val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
-      val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
-      // Comparison at the end is for broadcast left semi join
-      val joinExpression = df1("key") === df2("key") && df1("value") > df2("value")
-      val df3 = df1.join(broadcast(df2), joinExpression, joinType)
-      val plan =
-        EnsureRequirements(spark.sessionState.conf).apply(df3.queryExecution.sparkPlan)
-      assert(plan.collect { case p: T => p }.size === 1)
+      val plan = testBroadcastJoin[T](joinType)
       plan.executeCollect()
     }
   }
 
+  private def testBroadcastJoin[T: ClassTag](joinType: String,
+                                             forceBroadcast: Boolean = false): SparkPlan = {
+    val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
+    var df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
+
+    // Comparison at the end is for broadcast left semi join
+    val joinExpression = df1("key") === df2("key") && df1("value") > df2("value")
+    val df3 = if (forceBroadcast) {
+      df1.join(broadcast(df2), joinExpression, joinType)
+    } else {
+      df1.join(df2, joinExpression, joinType)
+    }
+    val plan =
+      EnsureRequirements(spark.sessionState.conf).apply(df3.queryExecution.sparkPlan)
+    assert(plan.collect { case p: T => p }.size === 1)
+
+    return plan
+  }
+
   test("unsafe broadcast hash join updates peak execution memory") {
-    testBroadcastJoin[BroadcastHashJoinExec]("unsafe broadcast hash join", "inner")
+    testBroadcastJoinPeak[BroadcastHashJoinExec]("unsafe broadcast hash join", "inner")
   }
 
   test("unsafe broadcast hash outer join updates peak execution memory") {
-    testBroadcastJoin[BroadcastHashJoinExec]("unsafe broadcast hash outer join", "left_outer")
+    testBroadcastJoinPeak[BroadcastHashJoinExec]("unsafe broadcast hash outer join", "left_outer")
   }
 
   test("unsafe broadcast left semi join updates peak execution memory") {
-    testBroadcastJoin[BroadcastHashJoinExec]("unsafe broadcast left semi join", "leftsemi")
+    testBroadcastJoinPeak[BroadcastHashJoinExec]("unsafe broadcast left semi join", "leftsemi")
+  }
+
+  test("broadcast hint isn't bothered by autoBroadcastJoinThreshold set to low values") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
+      testBroadcastJoin[BroadcastHashJoinExec]("inner", true)
+    }
+  }
+
+  test("broadcast hint isn't bothered by a disabled autoBroadcastJoinThreshold") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      testBroadcastJoin[BroadcastHashJoinExec]("inner", true)
+    }
+  }
+
+  test("broadcast hint isn't propagated after a join") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
+      val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
+      val df3 = df1.join(broadcast(df2), Seq("key"), "inner").drop(df2("key"))
+
+      val df4 = spark.createDataFrame(Seq((1, "5"), (2, "5"))).toDF("key", "value")
+      val df5 = df4.join(df3, Seq("key"), "inner")
+
+      val plan =
+        EnsureRequirements(spark.sessionState.conf).apply(df5.queryExecution.sparkPlan)
+
+      assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1)
+      assert(plan.collect { case p: SortMergeJoinExec => p }.size === 1)
+    }
   }
 
+  private def assertBroadcastJoin(df : Dataset[Row]) : Unit = {
+    val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
+    val joined = df1.join(df, Seq("key"), "inner")
+
+    val plan =
+      EnsureRequirements(spark.sessionState.conf).apply(joined.queryExecution.sparkPlan)
+
+    assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1)
+  }
+
+  test("broadcast hint is propagated correctly") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"), (3, "2"))).toDF("key", "value")
+      val broadcasted = broadcast(df2)
+      val df3 = spark.createDataFrame(Seq((2, "2"), (3, "3"))).toDF("key", "value")
+
+      val cases = Seq(broadcasted.limit(2),
+                      broadcasted.filter("value < 10"),
+                      broadcasted.sample(true, 0.5),
+                      broadcasted.distinct(),
+                      broadcasted.groupBy("value").agg(min($"key").as("key")),
+                      // except and intersect are semi/anti-joins which won't return more data than
+                      // their left argument, so the broadcast hint should be propagated here
+                      broadcasted.except(df3),
+                      broadcasted.intersect(df3))
+
+      cases.foreach(assertBroadcastJoin)
+    }
+  }
 }

