Repository: spark
Updated Branches:
  refs/heads/branch-2.1 99c293eea -> 51754d6df


[SPARK-18678][ML] Skewed reservoir sampling in SamplingUtils

## What changes were proposed in this pull request?

Fix reservoir sampling bias for small k. An off-by-one error meant that the 
probability of replacement was slightly too high -- k/(l-1) after the l-th 
element instead of k/l, which matters for small k.

## How was this patch tested?

Existing test plus new test case.

Author: Sean Owen <so...@cloudera.com>

Closes #16129 from srowen/SPARK-18678.

(cherry picked from commit 79f5f281bb69cb2de9f64006180abd753e8ae427)
Signed-off-by: Sean Owen <so...@cloudera.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/51754d6d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/51754d6d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/51754d6d

Branch: refs/heads/branch-2.1
Commit: 51754d6df703c02ecb23ec1779889602ff8fb038
Parents: 99c293e
Author: Sean Owen <so...@cloudera.com>
Authored: Wed Dec 7 17:34:45 2016 +0800
Committer: Sean Owen <so...@cloudera.com>
Committed: Wed Dec 7 17:34:57 2016 +0800

----------------------------------------------------------------------
 R/pkg/inst/tests/testthat/test_mllib.R                 |  9 +++++----
 .../org/apache/spark/util/random/SamplingUtils.scala   |  5 ++++-
 .../apache/spark/util/random/SamplingUtilsSuite.scala  | 13 +++++++++++++
 3 files changed, 22 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/51754d6d/R/pkg/inst/tests/testthat/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R 
b/R/pkg/inst/tests/testthat/test_mllib.R
index d7aa965..9f810be 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -1007,10 +1007,11 @@ test_that("spark.randomForest", {
   model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, 
maxBins = 16,
                               numTrees = 20, seed = 123)
   predictions <- collect(predict(model, data))
-  expect_equal(predictions$prediction, c(60.379, 61.096, 60.636, 62.258,
-                                         63.736, 64.296, 64.868, 64.300,
-                                         66.709, 67.697, 67.966, 67.252,
-                                         68.866, 69.593, 69.195, 69.658),
+  expect_equal(predictions$prediction, c(60.32820, 61.22315, 60.69025, 
62.11070,
+                                         63.53160, 64.05470, 65.12710, 
64.30450,
+                                         66.70910, 67.86125, 68.08700, 
67.21865,
+                                         68.89275, 69.53180, 69.39640, 
69.68250),
+
                tolerance = 1e-4)
   stats <- summary(model)
   expect_equal(stats$numTrees, 20)

http://git-wip-us.apache.org/repos/asf/spark/blob/51754d6d/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala 
b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
index 297524c..a7e0075 100644
--- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
@@ -56,11 +56,14 @@ private[spark] object SamplingUtils {
       val rand = new XORShiftRandom(seed)
       while (input.hasNext) {
         val item = input.next()
+        l += 1
+        // There are k elements in the reservoir, and the l-th element has been
+        // consumed. It should be chosen with probability k/l. The expression
+        // below is a random long chosen uniformly from [0,l)
         val replacementIndex = (rand.nextDouble() * l).toLong
         if (replacementIndex < k) {
           reservoir(replacementIndex.toInt) = item
         }
-        l += 1
       }
       (reservoir, l)
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/51754d6d/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala 
b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
index 667a4db..55c5dd5 100644
--- a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
@@ -44,6 +44,19 @@ class SamplingUtilsSuite extends SparkFunSuite {
     assert(sample3.length === 10)
   }
 
+  test("SPARK-18678 reservoirSampleAndCount with tiny input") {
+    val input = Seq(0, 1)
+    val counts = new Array[Int](input.size)
+    for (i <- 0 until 500) {
+      val (samples, inputSize) = 
SamplingUtils.reservoirSampleAndCount(input.iterator, 1)
+      assert(inputSize === 2)
+      assert(samples.length === 1)
+      counts(samples.head) += 1
+    }
+    // If correct, should be true with prob ~ 0.99999707
+    assert(math.abs(counts(0) - counts(1)) <= 100)
+  }
+
   test("computeFraction") {
     // test that the computed fraction guarantees enough data points
     // in the sample with a failure rate <= 0.0001


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to