Repository: spark
Updated Branches:
  refs/heads/master 3ca995b78 -> e9c36938b


[SPARK-9752][SQL] Support UnsafeRow in Sample operator.

In order for this to work, I had to disable gap sampling.
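
At the DataFrame level the API is unchanged; the sketch below mirrors the new
tests added in DataFrameStatSuite by this patch (it assumes a SQLContext named
sqlCtx and import sqlCtx.implicits._ for toDF, as in the test suite):

    val df = sqlCtx.sparkContext.parallelize(1 to 100, 2).toDF("id")
    // With replacement: Poisson sampling; gap sampling is now disabled internally
    // so that UnsafeRows do not have to be defensively copied.
    val withRepl = df.sample(withReplacement = true, fraction = 0.05, seed = 13)
    // Without replacement: Bernoulli-range sampling; the blanket row copy is gone.
    val withoutRepl = df.sample(withReplacement = false, fraction = 0.05, seed = 13)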

Author: Reynold Xin <r...@databricks.com>

Closes #8040 from rxin/SPARK-9752 and squashes the following commits:

f9e248c [Reynold Xin] Fix the test case for real this time.
adbccb3 [Reynold Xin] Fixed test case.
589fb23 [Reynold Xin] Merge branch 'SPARK-9752' of github.com:rxin/spark into SPARK-9752
55ccddc [Reynold Xin] Fixed core test.
78fa895 [Reynold Xin] [SPARK-9752][SQL] Support UnsafeRow in Sample operator.
c9e7112 [Reynold Xin] [SPARK-9752][SQL] Support UnsafeRow in Sample operator.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e9c36938
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e9c36938
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e9c36938

Branch: refs/heads/master
Commit: e9c36938ba972b6fe3c9f6228508e3c9f1c876b2
Parents: 3ca995b
Author: Reynold Xin <r...@databricks.com>
Authored: Sun Aug 9 10:58:36 2015 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Sun Aug 9 10:58:36 2015 -0700

----------------------------------------------------------------------
 .../spark/util/random/RandomSampler.scala       | 18 ++++++----
 .../spark/sql/execution/basicOperators.scala    | 18 +++++++---
 .../apache/spark/sql/DataFrameStatSuite.scala   | 35 ++++++++++++++++++++
 .../org/apache/spark/sql/DataFrameSuite.scala   | 17 ----------
 4 files changed, 61 insertions(+), 27 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e9c36938/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
index 786b97a..c156b03 100644
--- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
@@ -176,10 +176,15 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T
  * A sampler for sampling with replacement, based on values drawn from Poisson distribution.
  *
  * @param fraction the sampling fraction (with replacement)
+ * @param useGapSamplingIfPossible if true, use gap sampling when sampling ratio is low.
  * @tparam T item type
  */
 @DeveloperApi
-class PoissonSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] {
+class PoissonSampler[T: ClassTag](
+    fraction: Double,
+    useGapSamplingIfPossible: Boolean) extends RandomSampler[T, T] {
+
+  def this(fraction: Double) = this(fraction, useGapSamplingIfPossible = true)
 
   /** Epsilon slop to avoid failure from floating point jitter. */
   require(
@@ -199,17 +204,18 @@ class PoissonSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T]
   override def sample(items: Iterator[T]): Iterator[T] = {
     if (fraction <= 0.0) {
       Iterator.empty
-    } else if (fraction <= RandomSampler.defaultMaxGapSamplingFraction) {
-        new GapSamplingReplacementIterator(items, fraction, rngGap, RandomSampler.rngEpsilon)
+    } else if (useGapSamplingIfPossible &&
+               fraction <= RandomSampler.defaultMaxGapSamplingFraction) {
+      new GapSamplingReplacementIterator(items, fraction, rngGap, RandomSampler.rngEpsilon)
     } else {
-      items.flatMap { item => {
+      items.flatMap { item =>
         val count = rng.sample()
         if (count == 0) Iterator.empty else Iterator.fill(count)(item)
-      }}
+      }
     }
   }
 
-  override def clone: PoissonSampler[T] = new PoissonSampler[T](fraction)
+  override def clone: PoissonSampler[T] = new PoissonSampler[T](fraction, useGapSamplingIfPossible)
 }
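
The new constructor parameter can also be exercised directly against spark-core;
a minimal standalone sketch (not part of this patch):

    import org.apache.spark.util.random.PoissonSampler

    // With-replacement sampling at a 5% rate with gap sampling explicitly
    // turned off, which is what the SQL Sample operator now requests.
    val sampler = new PoissonSampler[Int](0.05, useGapSamplingIfPossible = false)
    val sampled: Iterator[Int] = sampler.sample(Iterator.range(1, 1001))

    // The auxiliary constructor keeps the old default (gap sampling allowed for
    // low fractions), so existing callers are unaffected.
    val legacySampler = new PoissonSampler[Int](0.05)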
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e9c36938/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 0680f31..c5d1ed0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.rdd.{RDD, ShuffledRDD}
+import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD}
 import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
@@ -30,6 +30,7 @@ import org.apache.spark.sql.metric.SQLMetrics
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.collection.ExternalSorter
 import org.apache.spark.util.collection.unsafe.sort.PrefixComparator
+import org.apache.spark.util.random.PoissonSampler
 import org.apache.spark.util.{CompletionIterator, MutablePair}
 import org.apache.spark.{HashPartitioner, SparkEnv}
 
@@ -130,12 +131,21 @@ case class Sample(
 {
   override def output: Seq[Attribute] = child.output
 
-  // TODO: How to pick seed?
+  override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows
+  override def canProcessUnsafeRows: Boolean = true
+  override def canProcessSafeRows: Boolean = true
+
   protected override def doExecute(): RDD[InternalRow] = {
     if (withReplacement) {
-      child.execute().map(_.copy()).sample(withReplacement, upperBound - lowerBound, seed)
+      // Disable gap sampling since the gap sampling method buffers two rows internally,
+      // requiring us to copy the row, which is more expensive than the random number generator.
+      new PartitionwiseSampledRDD[InternalRow, InternalRow](
+        child.execute(),
+        new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false),
+        preservesPartitioning = true,
+        seed)
     } else {
-      child.execute().map(_.copy()).randomSampleWithRange(lowerBound, upperBound, seed)
+      child.execute().randomSampleWithRange(lowerBound, upperBound, seed)
     }
   }
 }
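
For reference, the with-replacement branch is roughly equivalent to applying the
sampler once per partition; a hypothetical RDD-level sketch (sampleWithReplacement
is an illustrative helper, not an API in this patch, and the operator itself uses
the Spark-internal PartitionwiseSampledRDD, which derives its per-partition seeds
slightly differently):

    import scala.reflect.ClassTag
    import org.apache.spark.rdd.RDD
    import org.apache.spark.util.random.PoissonSampler

    def sampleWithReplacement[T: ClassTag](rdd: RDD[T], fraction: Double, seed: Long): RDD[T] = {
      rdd.mapPartitionsWithIndex({ (idx, iter) =>
        // One sampler per partition, seeded per partition, with gap sampling off
        // so the sampler never buffers elements (buffered InternalRows would have
        // to be copied, which is what this patch avoids).
        val sampler = new PoissonSampler[T](fraction, useGapSamplingIfPossible = false)
        sampler.setSeed(seed + idx)
        sampler.sample(iter)
      }, preservesPartitioning = true)
    }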

http://git-wip-us.apache.org/repos/asf/spark/blob/e9c36938/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 0e7659f..8f5984e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -30,6 +30,41 @@ class DataFrameStatSuite extends QueryTest {
 
   private def toLetter(i: Int): String = (i + 97).toChar.toString
 
+  test("sample with replacement") {
+    val n = 100
+    val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id")
+    checkAnswer(
+      data.sample(withReplacement = true, 0.05, seed = 13),
+      Seq(5, 10, 52, 73).map(Row(_))
+    )
+  }
+
+  test("sample without replacement") {
+    val n = 100
+    val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id")
+    checkAnswer(
+      data.sample(withReplacement = false, 0.05, seed = 13),
+      Seq(16, 23, 88, 100).map(Row(_))
+    )
+  }
+
+  test("randomSplit") {
+    val n = 600
+    val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id")
+    for (seed <- 1 to 5) {
+      val splits = data.randomSplit(Array[Double](1, 2, 3), seed)
+      assert(splits.length == 3, "wrong number of splits")
+
+      assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList ==
+        data.collect().toList, "incomplete or wrong split")
+
+      val s = splits.map(_.count())
+      assert(math.abs(s(0) - 100) < 50) // std =  9.13
+      assert(math.abs(s(1) - 200) < 50) // std = 11.55
+      assert(math.abs(s(2) - 300) < 50) // std = 12.25
+    }
+  }
+
   test("pearson correlation") {
     val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c")
     val corr1 = df.stat.corr("a", "b", "pearson")

http://git-wip-us.apache.org/repos/asf/spark/blob/e9c36938/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index f9cc6d1..0212637 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -415,23 +415,6 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
     assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol"))
   }
 
-  test("randomSplit") {
-    val n = 600
-    val data = sqlContext.sparkContext.parallelize(1 to n, 2).toDF("id")
-    for (seed <- 1 to 5) {
-      val splits = data.randomSplit(Array[Double](1, 2, 3), seed)
-      assert(splits.length == 3, "wrong number of splits")
-
-      assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList ==
-        data.collect().toList, "incomplete or wrong split")
-
-      val s = splits.map(_.count())
-      assert(math.abs(s(0) - 100) < 50) // std =  9.13
-      assert(math.abs(s(1) - 200) < 50) // std = 11.55
-      assert(math.abs(s(2) - 300) < 50) // std = 12.25
-    }
-  }
-
   test("describe") {
     val describeTestData = Seq(
       ("Bob", 16, 176),

