Repository: spark
Updated Branches:
  refs/heads/master e626ac5f5 -> c2af42b5f


[SPARK-9990] [SQL] Local hash join follow-ups

1. Hide `LocalNodeIterator` behind the `LocalNode#asIterator` method
2. Add tests for this

Author: Andrew Or <and...@databricks.com>

Closes #8708 from andrewor14/local-hash-join-follow-up.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c2af42b5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c2af42b5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c2af42b5

Branch: refs/heads/master
Commit: c2af42b5f32287ff595ad027a8191d4b75702d8d
Parents: e626ac5
Author: Andrew Or <and...@databricks.com>
Authored: Fri Sep 11 15:01:37 2015 -0700
Committer: Andrew Or <and...@databricks.com>
Committed: Fri Sep 11 15:01:37 2015 -0700

----------------------------------------------------------------------
 .../sql/execution/joins/HashedRelation.scala    |   7 +-
 .../sql/execution/local/HashJoinNode.scala      |   3 +-
 .../spark/sql/execution/local/LocalNode.scala   |   4 +-
 .../sql/execution/local/LocalNodeSuite.scala    | 116 +++++++++++++++++++
 4 files changed, 125 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c2af42b5/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 0cff21c..bc255b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -25,7 +25,8 @@ import org.apache.spark.shuffle.ShuffleMemoryManager
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkSqlSerializer
-import org.apache.spark.sql.execution.metric.LongSQLMetric
+import org.apache.spark.sql.execution.local.LocalNode
+import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.map.BytesToBytesMap
 import org.apache.spark.unsafe.memory.{MemoryLocation, ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
@@ -113,6 +114,10 @@ final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalR
 
 private[execution] object HashedRelation {
 
+  def apply(localNode: LocalNode, keyGenerator: Projection): HashedRelation = {
+    apply(localNode.asIterator, SQLMetrics.nullLongMetric, keyGenerator)
+  }
+
   def apply(
       input: Iterator[InternalRow],
       numInputRows: LongSQLMetric,

http://git-wip-us.apache.org/repos/asf/spark/blob/c2af42b5/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala
index a3e68d6..e7b24e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala
@@ -75,8 +75,7 @@ case class HashJoinNode(
 
   override def open(): Unit = {
     buildNode.open()
-    hashed = HashedRelation.apply(
-      new LocalNodeIterator(buildNode), SQLMetrics.nullLongMetric, buildSideKeyGenerator)
+    hashed = HashedRelation(buildNode, buildSideKeyGenerator)
     streamedNode.open()
     joinRow = new JoinedRow
     resultProjection = {

http://git-wip-us.apache.org/repos/asf/spark/blob/c2af42b5/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
index a2c275d..e540ef8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
@@ -77,7 +77,7 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging
   /**
    * Returns the content of the iterator from the beginning to the end in the form of a Scala Seq.
    */
-  def collect(): Seq[Row] = {
+  final def collect(): Seq[Row] = {
     val converter = CatalystTypeConverters.createToScalaConverter(StructType.fromAttributes(output))
     val result = new scala.collection.mutable.ArrayBuffer[Row]
     open()
@@ -140,7 +140,7 @@ abstract class BinaryLocalNode(conf: SQLConf) extends LocalNode(conf) {
 /**
  * An thin wrapper around a [[LocalNode]] that provides an `Iterator` interface.
  */
-private[local] class LocalNodeIterator(localNode: LocalNode) extends Iterator[InternalRow] {
+private class LocalNodeIterator(localNode: LocalNode) extends Iterator[InternalRow] {
   private var nextRow: InternalRow = _
 
   override def hasNext: Boolean = {

http://git-wip-us.apache.org/repos/asf/spark/blob/c2af42b5/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala
new file mode 100644
index 0000000..b89fa46
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala
@@ -0,0 +1,116 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.IntegerType
+
+class LocalNodeSuite extends SparkFunSuite {
+  private val data = (1 to 100).toArray
+
+  test("basic open, next, fetch, close") {
+    val node = new DummyLocalNode(data)
+    assert(!node.isOpen)
+    node.open()
+    assert(node.isOpen)
+    data.foreach { i =>
+      assert(node.next())
+      // fetch should be idempotent
+      val fetched = node.fetch()
+      assert(node.fetch() === fetched)
+      assert(node.fetch() === fetched)
+      assert(node.fetch().numFields === 1)
+      assert(node.fetch().getInt(0) === i)
+    }
+    assert(!node.next())
+    node.close()
+    assert(!node.isOpen)
+  }
+
+  test("asIterator") {
+    val node = new DummyLocalNode(data)
+    val iter = node.asIterator
+    node.open()
+    data.foreach { i =>
+      // hasNext should be idempotent
+      assert(iter.hasNext)
+      assert(iter.hasNext)
+      val item = iter.next()
+      assert(item.numFields === 1)
+      assert(item.getInt(0) === i)
+    }
+    intercept[NoSuchElementException] {
+      iter.next()
+    }
+    node.close()
+  }
+
+  test("collect") {
+    val node = new DummyLocalNode(data)
+    node.open()
+    val collected = node.collect()
+    assert(collected.size === data.size)
+    assert(collected.forall(_.size === 1))
+    assert(collected.map(_.getInt(0)) === data)
+    node.close()
+  }
+
+}
+
+/**
+ * A dummy [[LocalNode]] that just returns one row per integer in the input.
+ */
+private case class DummyLocalNode(conf: SQLConf, input: Array[Int]) extends LocalNode(conf) {
+  private var index = Int.MinValue
+
+  def this(input: Array[Int]) {
+    this(new SQLConf, input)
+  }
+
+  def isOpen: Boolean = {
+    index != Int.MinValue
+  }
+
+  override def output: Seq[Attribute] = {
+    Seq(AttributeReference("something", IntegerType)())
+  }
+
+  override def children: Seq[LocalNode] = Seq.empty
+
+  override def open(): Unit = {
+    index = -1
+  }
+
+  override def next(): Boolean = {
+    index += 1
+    index < input.size
+  }
+
+  override def fetch(): InternalRow = {
+    assert(index >= 0 && index < input.size)
+    val values = Array(input(index).asInstanceOf[Any])
+    new GenericInternalRow(values)
+  }
+
+  override def close(): Unit = {
+    index = Int.MinValue
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to