Repository: spark
Updated Branches:
  refs/heads/master 1eede3b25 -> e626ac5f5


[SPARK-9992] [SPARK-9994] [SPARK-9998] [SQL] Implement the local TopK, sample and intersect operators

This PR is in conflict with #8535. I will update this one when #8535 gets 
merged.

Author: zsxwing <zsxw...@gmail.com>

Closes #8573 from zsxwing/more-local-operators.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e626ac5f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e626ac5f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e626ac5f

Branch: refs/heads/master
Commit: e626ac5f5c27dcc74113070f2fec03682bcd12bd
Parents: 1eede3b
Author: zsxwing <zsxw...@gmail.com>
Authored: Fri Sep 11 15:00:13 2015 -0700
Committer: Andrew Or <and...@databricks.com>
Committed: Fri Sep 11 15:00:13 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/execution/basicOperators.scala    |  2 +-
 .../sql/execution/local/IntersectNode.scala     | 63 +++++++++++++++
 .../spark/sql/execution/local/LocalNode.scala   |  5 ++
 .../spark/sql/execution/local/SampleNode.scala  | 82 ++++++++++++++++++++
 .../local/TakeOrderedAndProjectNode.scala       | 73 +++++++++++++++++
 .../execution/local/IntersectNodeSuite.scala    | 35 +++++++++
 .../sql/execution/local/SampleNodeSuite.scala   | 40 ++++++++++
 .../local/TakeOrderedAndProjectNodeSuite.scala  | 54 +++++++++++++
 8 files changed, 353 insertions(+), 1 deletion(-)
----------------------------------------------------------------------
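
For readers skimming the diff: all three new operators implement LocalNode's pull-based contract of open()/next()/fetch()/close(). As a rough illustration of how a caller drives any LocalNode (the drain helper below is hypothetical and not part of this patch):

    // Hypothetical helper, not in this patch: drains a LocalNode through
    // the pull-based open()/next()/fetch()/close() contract.
    def drain(node: LocalNode): Seq[InternalRow] = {
      val rows = scala.collection.mutable.ArrayBuffer.empty[InternalRow]
      node.open()                    // initialize this node and its children
      while (node.next()) {          // advance; false once input is exhausted
        rows += node.fetch().copy()  // fetch() may return a reused mutable row, so copy
      }
      node.close()                   // release resources
      rows.toSeq
    }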


http://git-wip-us.apache.org/repos/asf/spark/blob/e626ac5f/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 3f68b05..bf6d44c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -138,7 +138,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
  *                   will be ub - lb.
  * @param withReplacement Whether to sample with replacement.
  * @param seed the random seed
- * @param child the QueryPlan
+ * @param child the SparkPlan
  */
 @DeveloperApi
 case class Sample(

http://git-wip-us.apache.org/repos/asf/spark/blob/e626ac5f/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala
new file mode 100644
index 0000000..740d485
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala
@@ -0,0 +1,63 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution.local
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+
+case class IntersectNode(conf: SQLConf, left: LocalNode, right: LocalNode)
+  extends BinaryLocalNode(conf) {
+
+  override def output: Seq[Attribute] = left.output
+
+  private[this] var leftRows: mutable.HashSet[InternalRow] = _
+
+  private[this] var currentRow: InternalRow = _
+
+  override def open(): Unit = {
+    left.open()
+    leftRows = mutable.HashSet[InternalRow]()
+    while (left.next()) {
+      leftRows += left.fetch().copy()
+    }
+    left.close()
+    right.open()
+  }
+
+  override def next(): Boolean = {
+    currentRow = null
+    while (currentRow == null && right.next()) {
+      currentRow = right.fetch()
+      if (!leftRows.contains(currentRow)) {
+        currentRow = null
+      }
+    }
+    currentRow != null
+  }
+
+  override def fetch(): InternalRow = currentRow
+
+  override def close(): Unit = {
+    left.close()
+    right.close()
+  }
+
+}
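
A note on the approach: open() materializes the left child into a mutable.HashSet (copying each row, since fetch() may hand back a reused buffer), and next() then streams the right child, keeping only rows present in the set. Stripped of the LocalNode plumbing, the same hash-based intersection over plain iterators is just (illustrative sketch, not part of this patch):

    // Hash-based intersection sketch: materialize one side, stream the other.
    def intersect[T](left: Iterator[T], right: Iterator[T]): Iterator[T] = {
      val seen = left.toSet          // build the lookup side once
      right.filter(seen.contains)    // keep right-side rows seen on the left
    }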

http://git-wip-us.apache.org/repos/asf/spark/blob/e626ac5f/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
index c4f8ae3..a2c275d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
@@ -70,6 +70,11 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging
   def close(): Unit
 
   /**
+   * Returns the content through the [[Iterator]] interface.
+   */
+  final def asIterator: Iterator[InternalRow] = new LocalNodeIterator(this)
+
+  /**
   * Returns the content of the iterator from the beginning to the end in the form of a Scala Seq.
    */
   def collect(): Seq[Row] = {
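
The LocalNodeIterator referenced by asIterator is not shown in this diff. For intuition only, an adapter consistent with the open()/next()/fetch() contract could look roughly like the following (a hypothetical sketch, not the actual class):

    // Hypothetical sketch of an Iterator adapter over a LocalNode; the real
    // LocalNodeIterator lives elsewhere and may differ.
    private class LocalNodeIterator(node: LocalNode) extends Iterator[InternalRow] {
      private var nextRow: InternalRow = _
      override def hasNext: Boolean = {
        if (nextRow == null && node.next()) {
          nextRow = node.fetch()
        }
        nextRow != null
      }
      override def next(): InternalRow = {
        if (!hasNext) throw new NoSuchElementException("End of the iterator")
        val row = nextRow
        nextRow = null
        row
      }
    }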

http://git-wip-us.apache.org/repos/asf/spark/blob/e626ac5f/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala
new file mode 100644
index 0000000..abf3df1
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import java.util.Random
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
+
+/**
+ * Sample the dataset.
+ *
+ * @param conf the SQLConf
+ * @param lowerBound Lower-bound of the sampling probability (usually 0.0)
+ * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled
+ *                   will be ub - lb.
+ * @param withReplacement Whether to sample with replacement.
+ * @param seed the random seed
+ * @param child the LocalNode
+ */
+case class SampleNode(
+    conf: SQLConf,
+    lowerBound: Double,
+    upperBound: Double,
+    withReplacement: Boolean,
+    seed: Long,
+    child: LocalNode) extends UnaryLocalNode(conf) {
+
+  override def output: Seq[Attribute] = child.output
+
+  private[this] var iterator: Iterator[InternalRow] = _
+
+  private[this] var currentRow: InternalRow = _
+
+  override def open(): Unit = {
+    child.open()
+    val (sampler, _seed) = if (withReplacement) {
+        val random = new Random(seed)
+        // Disable gap sampling since the gap sampling method buffers two rows internally,
+        // requiring us to copy the row, which is more expensive than the random number generator.
+        (new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false),
+          // Use the seed for partition 0 like PartitionwiseSampledRDD to generate the same result
+          // of DataFrame
+          random.nextLong())
+      } else {
+        (new BernoulliCellSampler[InternalRow](lowerBound, upperBound), seed)
+      }
+    sampler.setSeed(_seed)
+    iterator = sampler.sample(child.asIterator)
+  }
+
+  override def next(): Boolean = {
+    if (iterator.hasNext) {
+      currentRow = iterator.next()
+      true
+    } else {
+      false
+    }
+  }
+
+  override def fetch(): InternalRow = currentRow
+
+  override def close(): Unit = child.close()
+
+}
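
On the lowerBound/upperBound pair: a BernoulliCellSampler keeps a row exactly when a uniform draw in [0, 1) lands inside [lowerBound, upperBound), which is why the expected sampled fraction is ub - lb. The idea in isolation (illustrative only, not the actual sampler implementation):

    // Bernoulli-cell sketch: keep a row when a uniform draw falls in [lb, ub),
    // so the expected kept fraction is ub - lb.
    def bernoulliCell[T](rows: Iterator[T], lb: Double, ub: Double, seed: Long): Iterator[T] = {
      val rng = new java.util.Random(seed)
      rows.filter { _ =>
        val x = rng.nextDouble()
        x >= lb && x < ub
      }
    }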

http://git-wip-us.apache.org/repos/asf/spark/blob/e626ac5f/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala
new file mode 100644
index 0000000..53f1dcc
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.util.BoundedPriorityQueue
+
+case class TakeOrderedAndProjectNode(
+    conf: SQLConf,
+    limit: Int,
+    sortOrder: Seq[SortOrder],
+    projectList: Option[Seq[NamedExpression]],
+    child: LocalNode) extends UnaryLocalNode(conf) {
+
+  private[this] var projection: Option[Projection] = _
+  private[this] var ord: InterpretedOrdering = _
+  private[this] var iterator: Iterator[InternalRow] = _
+  private[this] var currentRow: InternalRow = _
+
+  override def output: Seq[Attribute] = {
+    val projectOutput = projectList.map(_.map(_.toAttribute))
+    projectOutput.getOrElse(child.output)
+  }
+
+  override def open(): Unit = {
+    child.open()
+    projection = projectList.map(new InterpretedProjection(_, child.output))
+    ord = new InterpretedOrdering(sortOrder, child.output)
+    // Priority keeps the largest elements, so let's reverse the ordering.
+    val queue = new BoundedPriorityQueue[InternalRow](limit)(ord.reverse)
+    while (child.next()) {
+      queue += child.fetch()
+    }
+    // Close it eagerly since we don't need it.
+    child.close()
+    iterator = queue.iterator
+  }
+
+  override def next(): Boolean = {
+    if (iterator.hasNext) {
+      val _currentRow = iterator.next()
+      currentRow = projection match {
+        case Some(p) => p(_currentRow)
+        case None => _currentRow
+      }
+      true
+    } else {
+      false
+    }
+  }
+
+  override def fetch(): InternalRow = currentRow
+
+  override def close(): Unit = child.close()
+
+}
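
The top-K trick above is the classic bounded heap: BoundedPriorityQueue retains the largest elements under its ordering once full, so passing ord.reverse makes it retain the `limit` smallest rows under ord. The same idea with only the standard library (illustrative sketch, not this patch's BoundedPriorityQueue):

    // Keep the k smallest elements under `ord` with a size-capped max-heap.
    def topK[T](rows: Iterator[T], k: Int)(implicit ord: Ordering[T]): Seq[T] = {
      val heap = scala.collection.mutable.PriorityQueue.empty[T](ord) // max-heap under ord
      rows.foreach { row =>
        heap += row
        if (heap.size > k) heap.dequeue() // evict the current largest
      }
      heap.dequeueAll.reverse // ascending under ord
    }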

http://git-wip-us.apache.org/repos/asf/spark/blob/e626ac5f/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala
new file mode 100644
index 0000000..7deaa37
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala
@@ -0,0 +1,35 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution.local
+
+class IntersectNodeSuite extends LocalNodeTest {
+
+  import testImplicits._
+
+  test("basic") {
+    val input1 = (1 to 10).map(i => (i, i.toString)).toDF("key", "value")
+    val input2 = (1 to 10).filter(_ % 2 == 0).map(i => (i, i.toString)).toDF("key", "value")
+
+    checkAnswer2(
+      input1,
+      input2,
+      (node1, node2) => IntersectNode(conf, node1, node2),
+      input1.intersect(input2).collect()
+    )
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/e626ac5f/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala
new file mode 100644
index 0000000..87a7da4
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+class SampleNodeSuite extends LocalNodeTest {
+
+  import testImplicits._
+
+  private def testSample(withReplacement: Boolean): Unit = {
+    test(s"withReplacement: $withReplacement") {
+      val seed = 0L
+      val input = sqlContext.sparkContext.
+        parallelize((1 to 10).map(i => (i, i.toString)), 1). // Should be only 1 partition
+        toDF("key", "value")
+      checkAnswer(
+        input,
+        node => SampleNode(conf, 0.0, 0.3, withReplacement, seed, node),
+        input.sample(withReplacement, 0.3, seed).collect()
+      )
+    }
+  }
+
+  testSample(withReplacement = true)
+  testSample(withReplacement = false)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/e626ac5f/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala
new file mode 100644
index 0000000..ff28b24
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, SortOrder}
+
+class TakeOrderedAndProjectNodeSuite extends LocalNodeTest {
+
+  import testImplicits._
+
+  private def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = {
+    val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
+      col.expr match {
+        case expr: SortOrder =>
+          expr
+        case expr: Expression =>
+          SortOrder(expr, Ascending)
+      }
+    }
+    sortOrder
+  }
+
+  private def testTakeOrderedAndProjectNode(desc: Boolean): Unit = {
+    val testCaseName = if (desc) "desc" else "asc"
+    test(testCaseName) {
+      val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value")
+      val sortColumn = if (desc) input.col("key").desc else input.col("key")
+      checkAnswer(
+        input,
+        node => TakeOrderedAndProjectNode(conf, 5, columnToSortOrder(sortColumn), None, node),
+        input.sort(sortColumn).limit(5).collect()
+      )
+    }
+  }
+
+  testTakeOrderedAndProjectNode(desc = false)
+  testTakeOrderedAndProjectNode(desc = true)
+}

