Repository: spark Updated Branches: refs/heads/master 224375c55 -> cc30ef800
[SPARK-22916][SQL] shouldn't bias towards build right if user does not specify ## What changes were proposed in this pull request? When there are no broadcast hints, the current Spark strategies prefer building the right side without considering the sizes of the two tables. This patch adds logic that takes the sizes of both tables into account when choosing the build side. To make the logic clear, the build side is determined in three steps: 1. If there are broadcast hints, the build side is determined by `broadcastSideByHints`; 2. If there are no broadcast hints, the build side is determined by `broadcastSideBySizes`; 3. If broadcasting is disabled by the configuration, planning falls back to the subsequent cases. ## How was this patch tested? Added a new test, "Shouldn't bias towards build right if user didn't specify", to `BroadcastJoinSuite` (moving `assertJoinBuildSide` into a shared private helper), and updated the expected input/output metrics of the cross-join query in `SQLMetricsSuite`. Author: Feng Liu <feng...@databricks.com> Closes #20099 from liufengdb/fix-spark-strategies. 
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cc30ef80 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cc30ef80 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cc30ef80 Branch: refs/heads/master Commit: cc30ef8009b82c71a4b8e9caba82ed141761ab85 Parents: 224375c Author: Feng Liu <feng...@databricks.com> Authored: Fri Dec 29 18:48:47 2017 +0800 Committer: gatorsmile <gatorsm...@gmail.com> Committed: Fri Dec 29 18:48:47 2017 +0800 ---------------------------------------------------------------------- .../spark/sql/execution/SparkStrategies.scala | 75 ++++++++++++-------- .../execution/joins/BroadcastJoinSuite.scala | 75 +++++++++++++++----- .../sql/execution/metric/SQLMetricsSuite.scala | 15 ++-- 3 files changed, 116 insertions(+), 49 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/cc30ef80/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 6b3f301..0ed7c2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -158,45 +158,65 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def smallerSide = if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft - val buildRight = canBuildRight && right.stats.hints.broadcast - val buildLeft = canBuildLeft && left.stats.hints.broadcast - - if (buildRight && buildLeft) { + if (canBuildRight && canBuildLeft) { // Broadcast smaller side base on its estimated physical size // if both sides have broadcast hint smallerSide - } else if 
(buildRight) { + } else if (canBuildRight) { BuildRight - } else if (buildLeft) { + } else if (canBuildLeft) { BuildLeft - } else if (canBuildRight && canBuildLeft) { + } else { // for the last default broadcast nested loop join smallerSide - } else { - throw new AnalysisException("Can not decide which side to broadcast for this join") } } + private def canBroadcastByHints(joinType: JoinType, left: LogicalPlan, right: LogicalPlan) + : Boolean = { + val buildLeft = canBuildLeft(joinType) && left.stats.hints.broadcast + val buildRight = canBuildRight(joinType) && right.stats.hints.broadcast + buildLeft || buildRight + } + + private def broadcastSideByHints(joinType: JoinType, left: LogicalPlan, right: LogicalPlan) + : BuildSide = { + val buildLeft = canBuildLeft(joinType) && left.stats.hints.broadcast + val buildRight = canBuildRight(joinType) && right.stats.hints.broadcast + broadcastSide(buildLeft, buildRight, left, right) + } + + private def canBroadcastBySizes(joinType: JoinType, left: LogicalPlan, right: LogicalPlan) + : Boolean = { + val buildLeft = canBuildLeft(joinType) && canBroadcast(left) + val buildRight = canBuildRight(joinType) && canBroadcast(right) + buildLeft || buildRight + } + + private def broadcastSideBySizes(joinType: JoinType, left: LogicalPlan, right: LogicalPlan) + : BuildSide = { + val buildLeft = canBuildLeft(joinType) && canBroadcast(left) + val buildRight = canBuildRight(joinType) && canBroadcast(right) + broadcastSide(buildLeft, buildRight, left, right) + } + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { // --- BroadcastHashJoin -------------------------------------------------------------------- + // broadcast hints were specified case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) - if (canBuildRight(joinType) && right.stats.hints.broadcast) || - (canBuildLeft(joinType) && left.stats.hints.broadcast) => - val buildSide = broadcastSide(canBuildLeft(joinType), canBuildRight(joinType), left, 
right) + if canBroadcastByHints(joinType, left, right) => + val buildSide = broadcastSideByHints(joinType, left, right) Seq(joins.BroadcastHashJoinExec( leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right))) + // broadcast hints were not specified, so need to infer it from size and configuration. case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) - if canBuildRight(joinType) && canBroadcast(right) => - Seq(joins.BroadcastHashJoinExec( - leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right))) - - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) - if canBuildLeft(joinType) && canBroadcast(left) => + if canBroadcastBySizes(joinType, left, right) => + val buildSide = broadcastSideBySizes(joinType, left, right) Seq(joins.BroadcastHashJoinExec( - leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right))) + leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right))) // --- ShuffledHashJoin --------------------------------------------------------------------- @@ -225,27 +245,24 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // Pick BroadcastNestedLoopJoin if one side could be broadcasted case j @ logical.Join(left, right, joinType, condition) - if (canBuildRight(joinType) && right.stats.hints.broadcast) || - (canBuildLeft(joinType) && left.stats.hints.broadcast) => - val buildSide = broadcastSide(canBuildLeft(joinType), canBuildRight(joinType), left, right) + if canBroadcastByHints(joinType, left, right) => + val buildSide = broadcastSideByHints(joinType, left, right) joins.BroadcastNestedLoopJoinExec( planLater(left), planLater(right), buildSide, joinType, condition) :: Nil case j @ logical.Join(left, right, joinType, condition) - if canBuildRight(joinType) && canBroadcast(right) => + if canBroadcastBySizes(joinType, left, right) => + val buildSide = 
broadcastSideBySizes(joinType, left, right) joins.BroadcastNestedLoopJoinExec( - planLater(left), planLater(right), BuildRight, joinType, condition) :: Nil - case j @ logical.Join(left, right, joinType, condition) - if canBuildLeft(joinType) && canBroadcast(left) => - joins.BroadcastNestedLoopJoinExec( - planLater(left), planLater(right), BuildLeft, joinType, condition) :: Nil + planLater(left), planLater(right), buildSide, joinType, condition) :: Nil // Pick CartesianProduct for InnerJoin case logical.Join(left, right, _: InnerLike, condition) => joins.CartesianProductExec(planLater(left), planLater(right), condition) :: Nil case logical.Join(left, right, joinType, condition) => - val buildSide = broadcastSide(canBuildLeft = true, canBuildRight = true, left, right) + val buildSide = broadcastSide( + left.stats.hints.broadcast, right.stats.hints.broadcast, left, right) // This join could be very slow or OOM joins.BroadcastNestedLoopJoinExec( planLater(left), planLater(right), buildSide, joinType, condition) :: Nil http://git-wip-us.apache.org/repos/asf/spark/blob/cc30ef80/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 67e2cdc..6da46ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -225,17 +225,6 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } test("Shouldn't change broadcast join buildSide if user clearly specified") { - def assertJoinBuildSide(sqlStr: String, joinMethod: String, buildSide: BuildSide): Any = { - val executedPlan = sql(sqlStr).queryExecution.executedPlan - executedPlan match { - case b: 
BroadcastNestedLoopJoinExec => - assert(b.getClass.getSimpleName === joinMethod) - assert(b.buildSide === buildSide) - case w: WholeStageCodegenExec => - assert(w.children.head.getClass.getSimpleName === joinMethod) - assert(w.children.head.asInstanceOf[BroadcastHashJoinExec].buildSide === buildSide) - } - } withTempView("t1", "t2") { spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1") @@ -246,9 +235,6 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes assert(t1Size < t2Size) - val bh = BroadcastHashJoinExec.toString - val bl = BroadcastNestedLoopJoinExec.toString - // INNER JOIN && t1Size < t2Size => BuildLeft assertJoinBuildSide( "SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 JOIN t2 ON t1.key = t2.key", bh, BuildLeft) @@ -266,8 +252,7 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { "SELECT /*+ MAPJOIN(t2) */ * FROM t1 JOIN t2 ON t1.key = t2.key", bh, BuildRight) - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0", - SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { // INNER JOIN && t1Size < t2Size => BuildLeft assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 JOIN t2", bl, BuildLeft) // FULL JOIN && t1Size < t2Size => BuildLeft @@ -290,4 +275,62 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } } } + + test("Shouldn't bias towards build right if user didn't specify") { + + withTempView("t1", "t2") { + spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1") + spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value") + .createTempView("t2") + + val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes + val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes + assert(t1Size < t2Size) + + 
assertJoinBuildSide("SELECT * FROM t1 JOIN t2 ON t1.key = t2.key", bh, BuildLeft) + assertJoinBuildSide("SELECT * FROM t2 JOIN t1 ON t1.key = t2.key", bh, BuildRight) + + assertJoinBuildSide("SELECT * FROM t1 LEFT JOIN t2 ON t1.key = t2.key", bh, BuildRight) + assertJoinBuildSide("SELECT * FROM t2 LEFT JOIN t1 ON t1.key = t2.key", bh, BuildRight) + + assertJoinBuildSide("SELECT * FROM t1 RIGHT JOIN t2 ON t1.key = t2.key", bh, BuildLeft) + assertJoinBuildSide("SELECT * FROM t2 RIGHT JOIN t1 ON t1.key = t2.key", bh, BuildLeft) + + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + assertJoinBuildSide("SELECT * FROM t1 FULL OUTER JOIN t2", bl, BuildLeft) + assertJoinBuildSide("SELECT * FROM t2 FULL OUTER JOIN t1", bl, BuildRight) + + assertJoinBuildSide("SELECT * FROM t1 LEFT JOIN t2", bl, BuildRight) + assertJoinBuildSide("SELECT * FROM t2 LEFT JOIN t1", bl, BuildRight) + + assertJoinBuildSide("SELECT * FROM t1 RIGHT JOIN t2", bl, BuildLeft) + assertJoinBuildSide("SELECT * FROM t2 RIGHT JOIN t1", bl, BuildLeft) + } + } + } + + private val bh = BroadcastHashJoinExec.toString + private val bl = BroadcastNestedLoopJoinExec.toString + + private def assertJoinBuildSide(sqlStr: String, joinMethod: String, buildSide: BuildSide): Any = { + val executedPlan = sql(sqlStr).queryExecution.executedPlan + executedPlan match { + case b: BroadcastNestedLoopJoinExec => + assert(b.getClass.getSimpleName === joinMethod) + assert(b.buildSide === buildSide) + case b: BroadcastNestedLoopJoinExec => + assert(b.getClass.getSimpleName === joinMethod) + assert(b.buildSide === buildSide) + case w: WholeStageCodegenExec => + assert(w.children.head.getClass.getSimpleName === joinMethod) + if (w.children.head.isInstanceOf[BroadcastNestedLoopJoinExec]) { + assert( + w.children.head.asInstanceOf[BroadcastNestedLoopJoinExec].buildSide === buildSide) + } else if (w.children.head.isInstanceOf[BroadcastHashJoinExec]) { + assert(w.children.head.asInstanceOf[BroadcastHashJoinExec].buildSide === 
buildSide) + } else { + fail() + } + } + } } http://git-wip-us.apache.org/repos/asf/spark/blob/cc30ef80/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index fc34833..a3a3f38 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -478,15 +478,22 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared spark.range(10).write.parquet(dir) spark.read.parquet(dir).createOrReplaceTempView("pqS") + // The executed plan looks like: + // Exchange RoundRobinPartitioning(2) + // +- BroadcastNestedLoopJoin BuildLeft, Cross + // :- BroadcastExchange IdentityBroadcastMode + // : +- Exchange RoundRobinPartitioning(3) + // : +- *Range (0, 30, step=1, splits=2) + // +- *FileScan parquet [id#465L] Batched: true, Format: Parquet, Location: ...(ignored) val res3 = InputOutputMetricsHelper.run( spark.range(30).repartition(3).crossJoin(sql("select * from pqS")).repartition(2).toDF() ) // The query above is executed in the following stages: - // 1. sql("select * from pqS") => (10, 0, 10) - // 2. range(30) => (30, 0, 30) - // 3. crossJoin(...) of 1. and 2. => (0, 30, 300) + // 1. range(30) => (30, 0, 30) + // 2. sql("select * from pqS") => (0, 30, 0) + // 3. crossJoin(...) of 1. and 2. => (10, 0, 300) // 4. 
shuffle & return results => (0, 300, 0) - assert(res3 === (10L, 0L, 10L) :: (30L, 0L, 30L) :: (0L, 30L, 300L) :: (0L, 300L, 0L) :: Nil) + assert(res3 === (30L, 0L, 30L) :: (0L, 30L, 0L) :: (10L, 0L, 300L) :: (0L, 300L, 0L) :: Nil) } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org