Repository: spark
Updated Branches:
  refs/heads/master a5ef2d060 -> d88abb7e2


[SPARK-9990] [SQL] Create local hash join operator

This PR includes the following changes:
- Add SQLConf to LocalNode
- Add HashJoinNode
- Add ConvertToUnsafeNode and ConvertToSafeNode to test unsafe hash join.

Author: zsxwing <zsxw...@gmail.com>

Closes #8535 from zsxwing/SPARK-9990.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d88abb7e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d88abb7e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d88abb7e

Branch: refs/heads/master
Commit: d88abb7e212fb55f9b0398a0f76a753c86b85cf1
Parents: a5ef2d0
Author: zsxwing <zsxw...@gmail.com>
Authored: Thu Sep 10 12:06:49 2015 -0700
Committer: Andrew Or <and...@databricks.com>
Committed: Thu Sep 10 12:06:49 2015 -0700

----------------------------------------------------------------------
 .../sql/execution/joins/HashedRelation.scala    |   4 +-
 .../sql/execution/local/ConvertToSafeNode.scala |  40 ++++++
 .../execution/local/ConvertToUnsafeNode.scala   |  40 ++++++
 .../spark/sql/execution/local/FilterNode.scala  |   4 +-
 .../sql/execution/local/HashJoinNode.scala      | 137 +++++++++++++++++++
 .../spark/sql/execution/local/LimitNode.scala   |   3 +-
 .../spark/sql/execution/local/LocalNode.scala   |  83 ++++++++++-
 .../spark/sql/execution/local/ProjectNode.scala |   4 +-
 .../spark/sql/execution/local/SeqScanNode.scala |   4 +-
 .../spark/sql/execution/local/UnionNode.scala   |   3 +-
 .../sql/execution/local/FilterNodeSuite.scala   |   4 +-
 .../sql/execution/local/HashJoinNodeSuite.scala | 130 ++++++++++++++++++
 .../sql/execution/local/LimitNodeSuite.scala    |   4 +-
 .../sql/execution/local/LocalNodeTest.scala     |   9 +-
 .../sql/execution/local/ProjectNodeSuite.scala  |   4 +-
 .../sql/execution/local/UnionNodeSuite.scala    |   6 +-
 16 files changed, 455 insertions(+), 24 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d88abb7e/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 6c0196c..0cff21c 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -38,7 +38,7 @@ import org.apache.spark.{SparkConf, SparkEnv}
  * Interface for a hashed relation by some key. Use [[HashedRelation.apply]] 
to create a concrete
  * object.
  */
-private[joins] sealed trait HashedRelation {
+private[execution] sealed trait HashedRelation {
   def get(key: InternalRow): Seq[InternalRow]
 
   // This is a helper method to implement Externalizable, and is used by
@@ -111,7 +111,7 @@ final class UniqueKeyHashedRelation(private var hashTable: 
JavaHashMap[InternalR
 // TODO(rxin): a version of [[HashedRelation]] backed by arrays for 
consecutive integer keys.
 
 
-private[joins] object HashedRelation {
+private[execution] object HashedRelation {
 
   def apply(
       input: Iterator[InternalRow],

http://git-wip-us.apache.org/repos/asf/spark/blob/d88abb7e/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala
new file mode 100644
index 0000000..b31c5a8
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala
@@ -0,0 +1,40 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, 
FromUnsafeProjection, Projection}
+
+case class ConvertToSafeNode(conf: SQLConf, child: LocalNode) extends 
UnaryLocalNode(conf) {
+
+  override def output: Seq[Attribute] = child.output
+
+  private[this] var convertToSafe: Projection = _
+
+  override def open(): Unit = {
+    child.open()
+    convertToSafe = FromUnsafeProjection(child.schema)
+  }
+
+  override def next(): Boolean = child.next()
+
+  override def fetch(): InternalRow = convertToSafe(child.fetch())
+
+  override def close(): Unit = child.close()
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/d88abb7e/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala
new file mode 100644
index 0000000..de2f4e6
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala
@@ -0,0 +1,40 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Projection, 
UnsafeProjection}
+
+case class ConvertToUnsafeNode(conf: SQLConf, child: LocalNode) extends 
UnaryLocalNode(conf) {
+
+  override def output: Seq[Attribute] = child.output
+
+  private[this] var convertToUnsafe: Projection = _
+
+  override def open(): Unit = {
+    child.open()
+    convertToUnsafe = UnsafeProjection.create(child.schema)
+  }
+
+  override def next(): Boolean = child.next()
+
+  override def fetch(): InternalRow = convertToUnsafe(child.fetch())
+
+  override def close(): Unit = child.close()
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/d88abb7e/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala
index 81dd37c..dd1113b 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala
@@ -17,12 +17,14 @@
 
 package org.apache.spark.sql.execution.local
 
+import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
 
 
-case class FilterNode(condition: Expression, child: LocalNode) extends 
UnaryLocalNode {
+case class FilterNode(conf: SQLConf, condition: Expression, child: LocalNode)
+  extends UnaryLocalNode(conf) {
 
   private[this] var predicate: (InternalRow) => Boolean = _
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d88abb7e/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala
new file mode 100644
index 0000000..a3e68d6
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala
@@ -0,0 +1,137 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.joins._
+import org.apache.spark.sql.execution.metric.SQLMetrics
+
+/**
+ * Much of this code is similar to 
[[org.apache.spark.sql.execution.joins.HashJoin]].
+ */
+case class HashJoinNode(
+    conf: SQLConf,
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    buildSide: BuildSide,
+    left: LocalNode,
+    right: LocalNode) extends BinaryLocalNode(conf) {
+
+  private[this] lazy val (buildNode, buildKeys, streamedNode, streamedKeys) = 
buildSide match {
+    case BuildLeft => (left, leftKeys, right, rightKeys)
+    case BuildRight => (right, rightKeys, left, leftKeys)
+  }
+
+  private[this] var currentStreamedRow: InternalRow = _
+  private[this] var currentHashMatches: Seq[InternalRow] = _
+  private[this] var currentMatchPosition: Int = -1
+
+  private[this] var joinRow: JoinedRow = _
+  private[this] var resultProjection: (InternalRow) => InternalRow = _
+
+  private[this] var hashed: HashedRelation = _
+  private[this] var joinKeys: Projection = _
+
+  override def output: Seq[Attribute] = left.output ++ right.output
+
+  private[this] def isUnsafeMode: Boolean = {
+    (codegenEnabled && unsafeEnabled
+      && UnsafeProjection.canSupport(buildKeys)
+      && UnsafeProjection.canSupport(schema))
+  }
+
+  private[this] def buildSideKeyGenerator: Projection = {
+    if (isUnsafeMode) {
+      UnsafeProjection.create(buildKeys, buildNode.output)
+    } else {
+      newMutableProjection(buildKeys, buildNode.output)()
+    }
+  }
+
+  private[this] def streamSideKeyGenerator: Projection = {
+    if (isUnsafeMode) {
+      UnsafeProjection.create(streamedKeys, streamedNode.output)
+    } else {
+      newMutableProjection(streamedKeys, streamedNode.output)()
+    }
+  }
+
+  override def open(): Unit = {
+    buildNode.open()
+    hashed = HashedRelation.apply(
+      new LocalNodeIterator(buildNode), SQLMetrics.nullLongMetric, 
buildSideKeyGenerator)
+    streamedNode.open()
+    joinRow = new JoinedRow
+    resultProjection = {
+      if (isUnsafeMode) {
+        UnsafeProjection.create(schema)
+      } else {
+        identity[InternalRow]
+      }
+    }
+    joinKeys = streamSideKeyGenerator
+  }
+
+  override def next(): Boolean = {
+    currentMatchPosition += 1
+    if (currentHashMatches == null || currentMatchPosition >= 
currentHashMatches.size) {
+      fetchNextMatch()
+    } else {
+      true
+    }
+  }
+
+  /**
+   * Populate `currentHashMatches` with build-side rows matching the next 
streamed row.
+   * @return whether matches are found such that subsequent calls to `fetch` 
are valid.
+   */
+  private def fetchNextMatch(): Boolean = {
+    currentHashMatches = null
+    currentMatchPosition = -1
+
+    while (currentHashMatches == null && streamedNode.next()) {
+      currentStreamedRow = streamedNode.fetch()
+      val key = joinKeys(currentStreamedRow)
+      if (!key.anyNull) {
+        currentHashMatches = hashed.get(key)
+      }
+    }
+
+    if (currentHashMatches == null) {
+      false
+    } else {
+      currentMatchPosition = 0
+      true
+    }
+  }
+
+  override def fetch(): InternalRow = {
+    val ret = buildSide match {
+      case BuildRight => joinRow(currentStreamedRow, 
currentHashMatches(currentMatchPosition))
+      case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), 
currentStreamedRow)
+    }
+    resultProjection(ret)
+  }
+
+  override def close(): Unit = {
+    left.close()
+    right.close()
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/d88abb7e/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala
index fffc52a..401b10a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala
@@ -17,11 +17,12 @@
 
 package org.apache.spark.sql.execution.local
 
+import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 
 
-case class LimitNode(limit: Int, child: LocalNode) extends UnaryLocalNode {
+case class LimitNode(conf: SQLConf, limit: Int, child: LocalNode) extends 
UnaryLocalNode(conf) {
 
   private[this] var count = 0
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d88abb7e/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
index 1c4469a..c4f8ae3 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
@@ -17,9 +17,13 @@
 
 package org.apache.spark.sql.execution.local
 
-import org.apache.spark.sql.Row
+import scala.util.control.NonFatal
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.{SQLConf, Row}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions._
+import 
org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.types.StructType
 
@@ -29,7 +33,15 @@ import org.apache.spark.sql.types.StructType
  * Before consuming the iterator, open function must be called.
  * After consuming the iterator, close function must be called.
  */
-abstract class LocalNode extends TreeNode[LocalNode] {
+abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with 
Logging {
+
+  protected val codegenEnabled: Boolean = conf.codegenEnabled
+
+  protected val unsafeEnabled: Boolean = conf.unsafeEnabled
+
+  lazy val schema: StructType = StructType.fromAttributes(output)
+
+  private[this] lazy val isTesting: Boolean = 
sys.props.contains("spark.testing")
 
   def output: Seq[Attribute]
 
@@ -73,17 +85,78 @@ abstract class LocalNode extends TreeNode[LocalNode] {
     }
     result
   }
+
+  protected def newMutableProjection(
+      expressions: Seq[Expression],
+      inputSchema: Seq[Attribute]): () => MutableProjection = {
+    log.debug(
+      s"Creating MutableProj: $expressions, inputSchema: $inputSchema, 
codegen:$codegenEnabled")
+    if (codegenEnabled) {
+      try {
+        GenerateMutableProjection.generate(expressions, inputSchema)
+      } catch {
+        case NonFatal(e) =>
+          if (isTesting) {
+            throw e
+          } else {
+            log.error("Failed to generate mutable projection, fallback to 
interpreted", e)
+            () => new InterpretedMutableProjection(expressions, inputSchema)
+          }
+      }
+    } else {
+      () => new InterpretedMutableProjection(expressions, inputSchema)
+    }
+  }
+
 }
 
 
-abstract class LeafLocalNode extends LocalNode {
+abstract class LeafLocalNode(conf: SQLConf) extends LocalNode(conf) {
   override def children: Seq[LocalNode] = Seq.empty
 }
 
 
-abstract class UnaryLocalNode extends LocalNode {
+abstract class UnaryLocalNode(conf: SQLConf) extends LocalNode(conf) {
 
   def child: LocalNode
 
   override def children: Seq[LocalNode] = Seq(child)
 }
+
+abstract class BinaryLocalNode(conf: SQLConf) extends LocalNode(conf) {
+
+  def left: LocalNode
+
+  def right: LocalNode
+
+  override def children: Seq[LocalNode] = Seq(left, right)
+}
+
+/**
+ * A thin wrapper around a [[LocalNode]] that provides an `Iterator` 
interface.
+ */
+private[local] class LocalNodeIterator(localNode: LocalNode) extends 
Iterator[InternalRow] {
+  private var nextRow: InternalRow = _
+
+  override def hasNext: Boolean = {
+    if (nextRow == null) {
+      val res = localNode.next()
+      if (res) {
+        nextRow = localNode.fetch()
+      }
+      res
+    } else {
+      true
+    }
+  }
+
+  override def next(): InternalRow = {
+    if (hasNext) {
+      val res = nextRow
+      nextRow = null
+      res
+    } else {
+      throw new NoSuchElementException
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/d88abb7e/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala
index 9b8a4fe..11529d6 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala
@@ -17,11 +17,13 @@
 
 package org.apache.spark.sql.execution.local
 
+import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, Attribute, 
NamedExpression}
 
 
-case class ProjectNode(projectList: Seq[NamedExpression], child: LocalNode) 
extends UnaryLocalNode {
+case class ProjectNode(conf: SQLConf, projectList: Seq[NamedExpression], 
child: LocalNode)
+  extends UnaryLocalNode(conf) {
 
   private[this] var project: UnsafeProjection = _
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d88abb7e/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala
index 242cb66..b8467f6 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala
@@ -17,13 +17,15 @@
 
 package org.apache.spark.sql.execution.local
 
+import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 
 /**
  * An operator that scans some local data collection in the form of Scala Seq.
  */
-case class SeqScanNode(output: Seq[Attribute], data: Seq[InternalRow]) extends 
LeafLocalNode {
+case class SeqScanNode(conf: SQLConf, output: Seq[Attribute], data: 
Seq[InternalRow])
+  extends LeafLocalNode(conf) {
 
   private[this] var iterator: Iterator[InternalRow] = _
   private[this] var currentRow: InternalRow = _

http://git-wip-us.apache.org/repos/asf/spark/blob/d88abb7e/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala
index ba4aa76..0f2b830 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala
@@ -17,10 +17,11 @@
 
 package org.apache.spark.sql.execution.local
 
+import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 
-case class UnionNode(children: Seq[LocalNode]) extends LocalNode {
+case class UnionNode(conf: SQLConf, children: Seq[LocalNode]) extends 
LocalNode(conf) {
 
   override def output: Seq[Attribute] = children.head.output
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d88abb7e/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala
index 07209f3..a12670e 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala
@@ -25,7 +25,7 @@ class FilterNodeSuite extends LocalNodeTest with 
SharedSQLContext {
     val condition = (testData.col("key") % 2) === 0
     checkAnswer(
       testData,
-      node => FilterNode(condition.expr, node),
+      node => FilterNode(conf, condition.expr, node),
       testData.filter(condition).collect()
     )
   }
@@ -34,7 +34,7 @@ class FilterNodeSuite extends LocalNodeTest with 
SharedSQLContext {
     val condition = (emptyTestData.col("key") % 2) === 0
     checkAnswer(
       emptyTestData,
-      node => FilterNode(condition.expr, node),
+      node => FilterNode(conf, condition.expr, node),
       emptyTestData.filter(condition).collect()
     )
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/d88abb7e/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala
new file mode 100644
index 0000000..43b6f06
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala
@@ -0,0 +1,130 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.execution.joins
+
+class HashJoinNodeSuite extends LocalNodeTest {
+
+  import testImplicits._
+
+  private def wrapForUnsafe(
+      f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => 
LocalNode = {
+    if (conf.unsafeEnabled) {
+      (left: LocalNode, right: LocalNode) => {
+        val _left = ConvertToUnsafeNode(conf, left)
+        val _right = ConvertToUnsafeNode(conf, right)
+        val r = f(_left, _right)
+        ConvertToSafeNode(conf, r)
+      }
+    } else {
+      f
+    }
+  }
+
+  def joinSuite(suiteName: String, confPairs: (String, String)*): Unit = {
+    test(s"$suiteName: inner join with one match per row") {
+      withSQLConf(confPairs: _*) {
+        checkAnswer2(
+          upperCaseData,
+          lowerCaseData,
+          wrapForUnsafe(
+            (node1, node2) => HashJoinNode(
+              conf,
+              Seq(upperCaseData.col("N").expr),
+              Seq(lowerCaseData.col("n").expr),
+              joins.BuildLeft,
+              node1,
+              node2)
+          ),
+          upperCaseData.join(lowerCaseData, $"n" === $"N").collect()
+        )
+      }
+    }
+
+    test(s"$suiteName: inner join with multiple matches") {
+      withSQLConf(confPairs: _*) {
+        val x = testData2.where($"a" === 1).as("x")
+        val y = testData2.where($"a" === 1).as("y")
+        checkAnswer2(
+          x,
+          y,
+          wrapForUnsafe(
+            (node1, node2) => HashJoinNode(
+              conf,
+              Seq(x.col("a").expr),
+              Seq(y.col("a").expr),
+              joins.BuildLeft,
+              node1,
+              node2)
+          ),
+          x.join(y).where($"x.a" === $"y.a").collect()
+        )
+      }
+    }
+
+    test(s"$suiteName: inner join, no matches") {
+      withSQLConf(confPairs: _*) {
+        val x = testData2.where($"a" === 1).as("x")
+        val y = testData2.where($"a" === 2).as("y")
+        checkAnswer2(
+          x,
+          y,
+          wrapForUnsafe(
+            (node1, node2) => HashJoinNode(
+              conf,
+              Seq(x.col("a").expr),
+              Seq(y.col("a").expr),
+              joins.BuildLeft,
+              node1,
+              node2)
+          ),
+          Nil
+        )
+      }
+    }
+
+    test(s"$suiteName: big inner join, 4 matches per row") {
+      withSQLConf(confPairs: _*) {
+        val bigData = 
testData.unionAll(testData).unionAll(testData).unionAll(testData)
+        val bigDataX = bigData.as("x")
+        val bigDataY = bigData.as("y")
+
+        checkAnswer2(
+          bigDataX,
+          bigDataY,
+          wrapForUnsafe(
+            (node1, node2) =>
+              HashJoinNode(
+                conf,
+                Seq(bigDataX.col("key").expr),
+                Seq(bigDataY.col("key").expr),
+                joins.BuildLeft,
+                node1,
+                node2)
+          ),
+          bigDataX.join(bigDataY).where($"x.key" === $"y.key").collect())
+      }
+    }
+  }
+
+  joinSuite(
+    "general", SQLConf.CODEGEN_ENABLED.key -> "false", 
SQLConf.UNSAFE_ENABLED.key -> "false")
+  joinSuite("tungsten", SQLConf.CODEGEN_ENABLED.key -> "true", 
SQLConf.UNSAFE_ENABLED.key -> "true")
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/d88abb7e/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala
index 523c02f..3b18390 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala
@@ -24,7 +24,7 @@ class LimitNodeSuite extends LocalNodeTest with 
SharedSQLContext {
   test("basic") {
     checkAnswer(
       testData,
-      node => LimitNode(10, node),
+      node => LimitNode(conf, 10, node),
       testData.limit(10).collect()
     )
   }
@@ -32,7 +32,7 @@ class LimitNodeSuite extends LocalNodeTest with 
SharedSQLContext {
   test("empty") {
     checkAnswer(
       emptyTestData,
-      node => LimitNode(10, node),
+      node => LimitNode(conf, 10, node),
       emptyTestData.limit(10).collect()
     )
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/d88abb7e/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
index 95f0608..b95d4ea 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
@@ -20,10 +20,12 @@ package org.apache.spark.sql.execution.local
 import scala.util.control.NonFatal
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.{DataFrame, Row, SQLConf}
+import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
 
-class LocalNodeTest extends SparkFunSuite {
+class LocalNodeTest extends SparkFunSuite with SharedSQLContext {
+
+  def conf: SQLConf = sqlContext.conf
 
   /**
    * Runs the LocalNode and makes sure the answer matches the expected result.
@@ -92,6 +94,7 @@ class LocalNodeTest extends SparkFunSuite {
 
   protected def dataFrameToSeqScanNode(df: DataFrame): SeqScanNode = {
     new SeqScanNode(
+      conf,
       df.queryExecution.sparkPlan.output,
       df.queryExecution.toRdd.map(_.copy()).collect())
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/d88abb7e/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala
index ffcf092..38e0a23 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala
@@ -26,7 +26,7 @@ class ProjectNodeSuite extends LocalNodeTest with 
SharedSQLContext {
     val columns = Seq(output(1), output(0))
     checkAnswer(
       testData,
-      node => ProjectNode(columns, node),
+      node => ProjectNode(conf, columns, node),
       testData.select("value", "key").collect()
     )
   }
@@ -36,7 +36,7 @@ class ProjectNodeSuite extends LocalNodeTest with 
SharedSQLContext {
     val columns = Seq(output(1), output(0))
     checkAnswer(
       emptyTestData,
-      node => ProjectNode(columns, node),
+      node => ProjectNode(conf, columns, node),
       emptyTestData.select("value", "key").collect()
     )
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/d88abb7e/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala
index 3467028..eedd732 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala
@@ -25,7 +25,7 @@ class UnionNodeSuite extends LocalNodeTest with 
SharedSQLContext {
     checkAnswer2(
       testData,
       testData,
-      (node1, node2) => UnionNode(Seq(node1, node2)),
+      (node1, node2) => UnionNode(conf, Seq(node1, node2)),
       testData.unionAll(testData).collect()
     )
   }
@@ -34,7 +34,7 @@ class UnionNodeSuite extends LocalNodeTest with 
SharedSQLContext {
     checkAnswer2(
       emptyTestData,
       emptyTestData,
-      (node1, node2) => UnionNode(Seq(node1, node2)),
+      (node1, node2) => UnionNode(conf, Seq(node1, node2)),
       emptyTestData.unionAll(emptyTestData).collect()
     )
   }
@@ -44,7 +44,7 @@ class UnionNodeSuite extends LocalNodeTest with 
SharedSQLContext {
       emptyTestData, emptyTestData, testData, emptyTestData)
     doCheckAnswer(
       dfs,
-      nodes => UnionNode(nodes),
+      nodes => UnionNode(conf, nodes),
       dfs.reduce(_.unionAll(_)).collect()
     )
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to