Repository: spark
Updated Branches:
  refs/heads/master 8f29b7caf -> 07fa1910d


[SPARK-4570][SQL] Add BroadcastLeftSemiJoinHash

JIRA issue: [SPARK-4570](https://issues.apache.org/jira/browse/SPARK-4570)
This patch adds a `BroadcastLeftSemiJoinHash` operator that implements a 
broadcast join for `left semi join`.
For a left semi join: if the size of the data on the right side is smaller 
than the user-settable threshold `AUTO_BROADCASTJOIN_THRESHOLD`, the planner 
marks the right side as the `broadcast` relation and the other relation as 
the stream side. The broadcast table is shipped to all of the executors 
involved in the join as an `org.apache.spark.broadcast.Broadcast` object, and 
the join uses `joins.BroadcastLeftSemiJoinHash`; otherwise it uses 
`joins.LeftSemiJoinHash`.
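
For reference, the threshold is exposed as the SQL option 
`spark.sql.autoBroadcastJoinThreshold` (the key behind 
`SQLConf.AUTO_BROADCASTJOIN_THRESHOLD`) and can be changed at runtime. A 
minimal sketch, assuming a `HiveContext` is already in scope as `hiveContext`:
<pre><code>
import hiveContext._

// Broadcast right sides smaller than ~1 GB.
sql("SET spark.sql.autoBroadcastJoinThreshold=1000000000")

// A value of -1 disables broadcasting, so the planner falls back to
// joins.LeftSemiJoinHash for left semi joins.
sql("SET spark.sql.autoBroadcastJoinThreshold=-1")
</code></pre>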

The benchmark below suggests the optimized version is roughly 4.7x faster for 
`left semi join`:
<pre><code>
Original:
left semi join : 9288 ms
Optimized:
left semi join : 1963 ms
</code></pre>
The micro benchmark loads `/data1/kv3.txt` into a normal Hive table.
Benchmark code:
<pre><code>
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext

def benchmark(f: => Unit): Long = {
  val begin = System.currentTimeMillis()
  f
  val end = System.currentTimeMillis()
  end - begin
}

val sc = new SparkContext(
  new SparkConf()
    .setMaster("local")
    .setAppName(getClass.getSimpleName.stripSuffix("$")))
val hiveContext = new HiveContext(sc)
import hiveContext._

sql("drop table if exists left_table")
sql("drop table if exists right_table")
sql("create table left_table (key int, value string)")
sql("""load data local inpath "/data1/kv3.txt" into table left_table""")
sql("create table right_table (key int, value string)")
sql(
  """
    |from left_table
    |insert overwrite table right_table
    |select left_table.key, left_table.value
  """.stripMargin)

val leftSemiJoin = sql(
  """select a.key from left_table a
    |left semi join right_table b on a.key = b.key""".stripMargin)
val leftSemiJoinDuration = benchmark(leftSemiJoin.count())
println(s"left semi join : $leftSemiJoinDuration ms")
</code></pre>

Author: wangxiaojing <u9j...@gmail.com>

Closes #3442 from wangxiaojing/SPARK-4570 and squashes the following commits:

a4a43c9 [wangxiaojing] rebase
f103983 [wangxiaojing] change style
fbe4887 [wangxiaojing] change style
ff2e618 [wangxiaojing] add testsuite
1a8da2a [wangxiaojing] add BroadcastLeftSemiJoinHash


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/07fa1910
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/07fa1910
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/07fa1910

Branch: refs/heads/master
Commit: 07fa1910d9c4092d670381c447403105f01c584e
Parents: 8f29b7c
Author: wangxiaojing <u9j...@gmail.com>
Authored: Tue Dec 30 13:54:12 2014 -0800
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Tue Dec 30 13:54:12 2014 -0800

----------------------------------------------------------------------
 .../spark/sql/execution/SparkStrategies.scala   |  6 ++
 .../joins/BroadcastLeftSemiJoinHash.scala       | 67 ++++++++++++++++++++
 .../scala/org/apache/spark/sql/JoinSuite.scala  | 38 +++++++++++
 .../apache/spark/sql/hive/StatisticsSuite.scala | 50 ++++++++++++++-
 4 files changed, 160 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/07fa1910/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 9151da6..ce878c1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -33,6 +33,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
 
   object LeftSemiJoin extends Strategy with PredicateHelper {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+      case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right)
+        if sqlContext.autoBroadcastJoinThreshold > 0 &&
+          right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold =>
+        val semiJoin = joins.BroadcastLeftSemiJoinHash(
+          leftKeys, rightKeys, planLater(left), planLater(right))
+        condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
       // Find left semi joins where at least some predicates can be evaluated by matching join keys
       case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
         val semiJoin = joins.LeftSemiJoinHash(

http://git-wip-us.apache.org/repos/asf/spark/blob/07fa1910/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
new file mode 100644
index 0000000..2ab064f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.expressions.{Expression, Row}
+import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+
+/**
+ * :: DeveloperApi ::
+ * Build the right table's join keys into a HashSet, and iteratively go through the left
+ * table to check whether each row's join key is in the hash set.
+ */
+@DeveloperApi
+case class BroadcastLeftSemiJoinHash(
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    left: SparkPlan,
+    right: SparkPlan) extends BinaryNode with HashJoin {
+
+  override val buildSide = BuildRight
+
+  override def output = left.output
+
+  override def execute() = {
+    val buildIter = buildPlan.execute().map(_.copy()).collect().toIterator
+    val hashSet = new java.util.HashSet[Row]()
+    var currentRow: Row = null
+
+    // Create a Hash set of buildKeys
+    while (buildIter.hasNext) {
+      currentRow = buildIter.next()
+      val rowKey = buildSideKeyGenerator(currentRow)
+      if (!rowKey.anyNull) {
+        val keyExists = hashSet.contains(rowKey)
+        if (!keyExists) {
+          hashSet.add(rowKey)
+        }
+      }
+    }
+
+    val broadcastedRelation = sparkContext.broadcast(hashSet)
+
+    streamedPlan.execute().mapPartitions { streamIter =>
+      val joinKeys = streamSideKeyGenerator()
+      streamIter.filter(current => {
+        !joinKeys(current).anyNull && broadcastedRelation.value.contains(joinKeys.currentValue)
+      })
+    }
+  }
+}

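For intuition, the hash-set semi-join above can be sketched with plain Scala 
collections (a toy illustration, not code from this patch):
<pre><code>
// Build a hash set of join keys from the small (right/build) side, then
// stream the left side and keep only rows whose key appears in the set.
val left  = Seq((1, "a"), (2, "b"), (3, "c"))
val right = Seq((2, "x"), (3, "y"), (4, "z"))

val rightKeys: Set[Int] = right.map(_._1).toSet  // the "broadcast" side

val semiJoined = left.filter { case (key, _) => rightKeys.contains(key) }
// semiJoined == Seq((2, "b"), (3, "c")); each left row appears at most once.
</code></pre>
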
http://git-wip-us.apache.org/repos/asf/spark/blob/07fa1910/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 0378fd7..1a4232d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -48,6 +48,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       case j: LeftSemiJoinBNL => j
       case j: CartesianProduct => j
       case j: BroadcastNestedLoopJoin => j
+      case j: BroadcastLeftSemiJoinHash => j
     }
 
     assert(operators.size === 1)
@@ -382,4 +383,41 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
         """.stripMargin),
       (null, 10) :: Nil)
   }
+
+  test("broadcasted left semi join operator selection") {
+    clearCache()
+    sql("CACHE TABLE testData")
+    val tmp = autoBroadcastJoinThreshold
+
+    sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=1000000000")
+    Seq(
+      ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
+        classOf[BroadcastLeftSemiJoinHash])
+    ).foreach {
+      case (query, joinClass) => assertJoin(query, joinClass)
+    }
+
+    sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1")
+
+    Seq(
+      ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", 
classOf[LeftSemiJoinHash])
+    ).foreach {
+      case (query, joinClass) => assertJoin(query, joinClass)
+    }
+
+    setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp.toString)
+    sql("UNCACHE TABLE testData")
+  }
+
+  test("left semi join") {
+    val rdd = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a")
+    checkAnswer(rdd,
+      (1, 1) ::
+      (1, 2) ::
+      (2, 1) ::
+      (2, 2) ::
+      (3, 1) ::
+      (3, 2) :: Nil)
+
+  }
 }

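Outside the test harness, the operator selection can also be verified by 
inspecting the physical plan directly. A sketch, assuming the `testData` and 
`testData2` tables registered by the suite above:
<pre><code>
// With a large threshold, the plan should contain BroadcastLeftSemiJoinHash.
sql("SET spark.sql.autoBroadcastJoinThreshold=1000000000")
val rdd = sql("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a")
println(rdd.queryExecution.sparkPlan)  // expect BroadcastLeftSemiJoinHash
</code></pre>
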
http://git-wip-us.apache.org/repos/asf/spark/blob/07fa1910/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index ff4071d..4b6a930 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -22,7 +22,7 @@ import org.scalatest.BeforeAndAfterAll
 import scala.reflect.ClassTag
 
 import org.apache.spark.sql.{SQLConf, QueryTest}
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
+import org.apache.spark.sql.execution.joins._
 import org.apache.spark.sql.hive.test.TestHive
 import org.apache.spark.sql.hive.test.TestHive._
 import org.apache.spark.sql.hive.execution._
@@ -193,4 +193,52 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
     )
   }
 
+  test("auto converts to broadcast left semi join, by size estimate of a 
relation") {
+    val leftSemiJoinQuery =
+      """SELECT * FROM src a
+        |left semi JOIN src b ON a.key=86 and a.key = b.key""".stripMargin
+    val answer = (86, "val_86") :: Nil
+
+    var rdd = sql(leftSemiJoinQuery)
+
+    // Assert src has a size smaller than the threshold.
+    val sizes = rdd.queryExecution.analyzed.collect {
+      case r if implicitly[ClassTag[MetastoreRelation]].runtimeClass.isAssignableFrom(r.getClass) =>
+        r.statistics.sizeInBytes
+    }
+    assert(sizes.size === 2 && sizes(1) <= autoBroadcastJoinThreshold
+      && sizes(0) <= autoBroadcastJoinThreshold,
+      s"query should contain two relations, each of which has size smaller 
than autoConvertSize")
+
+    // Using `sparkPlan` because for relevant patterns in HashJoin to be
+    // matched, other strategies need to be applied.
+    var bhj = rdd.queryExecution.sparkPlan.collect {
+      case j: BroadcastLeftSemiJoinHash => j
+    }
+    assert(bhj.size === 1,
+      s"actual query plans do not contain broadcast join: 
${rdd.queryExecution}")
+
+    checkAnswer(rdd, answer) // check correctness of output
+
+    TestHive.settings.synchronized {
+      val tmp = autoBroadcastJoinThreshold
+
+      sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1")
+      rdd = sql(leftSemiJoinQuery)
+      bhj = rdd.queryExecution.sparkPlan.collect {
+        case j: BroadcastLeftSemiJoinHash => j
+      }
+      assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off")
+
+      val shj = rdd.queryExecution.sparkPlan.collect {
+        case j: LeftSemiJoinHash => j
+      }
+      assert(shj.size === 1,
+        "LeftSemiJoinHash should be planned when BroadcastHashJoin is turned 
off")
+
+      sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp")
+    }
+
+  }
 }

