Repository: spark
Updated Branches:
  refs/heads/master e25ec0617 -> fda475987


[SPARK-2801][MLlib]: DistributionGenerator renamed to RandomDataGenerator; RandomRDD is now generic

Previously, RandomRDDGenerators could only output RDD[Double].
Now RandomRDDGenerators.randomRDD can generate a random RDD[T] from any
class that extends RandomDataGenerator[T]: supply the element type T and
override nextValue() to return values of that type.

Author: Burak <brk...@gmail.com>

Closes #1732 from brkyvz/SPARK-2801 and squashes the following commits:

c94a694 [Burak] [SPARK-2801][MLlib] Missing ClassTags added
22d96fe [Burak] [SPARK-2801][MLlib]: DistributionGenerator renamed to 
RandomDataGenerator, generic types added for RandomRDD instead of Double


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fda47598
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fda47598
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fda47598

Branch: refs/heads/master
Commit: fda475987f3b8b37d563033b0e45706ce433824a
Parents: e25ec06
Author: Burak <brk...@gmail.com>
Authored: Fri Aug 1 22:32:12 2014 -0700
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Fri Aug 1 22:32:12 2014 -0700

----------------------------------------------------------------------
 .../mllib/random/DistributionGenerator.scala    | 101 -------------------
 .../mllib/random/RandomDataGenerator.scala      | 101 +++++++++++++++++++
 .../mllib/random/RandomRDDGenerators.scala      |  32 +++---
 .../org/apache/spark/mllib/rdd/RandomRDD.scala  |  34 ++++---
 .../random/DistributionGeneratorSuite.scala     |  90 -----------------
 .../mllib/random/RandomDataGeneratorSuite.scala |  90 +++++++++++++++++
 .../mllib/random/RandomRDDGeneratorsSuite.scala |   8 +-
 7 files changed, 231 insertions(+), 225 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/fda47598/mllib/src/main/scala/org/apache/spark/mllib/random/DistributionGenerator.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/random/DistributionGenerator.scala
 
b/mllib/src/main/scala/org/apache/spark/mllib/random/DistributionGenerator.scala
deleted file mode 100644
index 7ecb409..0000000
--- 
a/mllib/src/main/scala/org/apache/spark/mllib/random/DistributionGenerator.scala
+++ /dev/null
@@ -1,101 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.random
-
-import cern.jet.random.Poisson
-import cern.jet.random.engine.DRand
-
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom}
-
-/**
- * :: Experimental ::
- * Trait for random number generators that generate i.i.d. values from a 
distribution.
- */
-@Experimental
-trait DistributionGenerator extends Pseudorandom with Serializable {
-
-  /**
-   * Returns an i.i.d. sample as a Double from an underlying distribution.
-   */
-  def nextValue(): Double
-
-  /**
-   * Returns a copy of the DistributionGenerator with a new instance of the 
rng object used in the
-   * class when applicable for non-locking concurrent usage.
-   */
-  def copy(): DistributionGenerator
-}
-
-/**
- * :: Experimental ::
- * Generates i.i.d. samples from U[0.0, 1.0]
- */
-@Experimental
-class UniformGenerator extends DistributionGenerator {
-
-  // XORShiftRandom for better performance. Thread safety isn't necessary here.
-  private val random = new XORShiftRandom()
-
-  override def nextValue(): Double = {
-    random.nextDouble()
-  }
-
-  override def setSeed(seed: Long) = random.setSeed(seed)
-
-  override def copy(): UniformGenerator = new UniformGenerator()
-}
-
-/**
- * :: Experimental ::
- * Generates i.i.d. samples from the standard normal distribution.
- */
-@Experimental
-class StandardNormalGenerator extends DistributionGenerator {
-
-  // XORShiftRandom for better performance. Thread safety isn't necessary here.
-  private val random = new XORShiftRandom()
-
-  override def nextValue(): Double = {
-      random.nextGaussian()
-  }
-
-  override def setSeed(seed: Long) = random.setSeed(seed)
-
-  override def copy(): StandardNormalGenerator = new StandardNormalGenerator()
-}
-
-/**
- * :: Experimental ::
- * Generates i.i.d. samples from the Poisson distribution with the given mean.
- *
- * @param mean mean for the Poisson distribution.
- */
-@Experimental
-class PoissonGenerator(val mean: Double) extends DistributionGenerator {
-
-  private var rng = new Poisson(mean, new DRand)
-
-  override def nextValue(): Double = rng.nextDouble()
-
-  override def setSeed(seed: Long) {
-    rng = new Poisson(mean, new DRand(seed.toInt))
-  }
-
-  override def copy(): PoissonGenerator = new PoissonGenerator(mean)
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/fda47598/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala
new file mode 100644
index 0000000..9cab49f
--- /dev/null
+++ 
b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.random
+
+import cern.jet.random.Poisson
+import cern.jet.random.engine.DRand
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom}
+
+/**
+ * :: Experimental ::
+ * Trait for random data generators that generate i.i.d. data.
+ */
+@Experimental
+trait RandomDataGenerator[T] extends Pseudorandom with Serializable {
+
+  /**
+   * Returns an i.i.d. sample as a generic type from an underlying 
distribution.
+   */
+  def nextValue(): T
+
+  /**
+   * Returns a copy of the RandomDataGenerator with a new instance of the rng 
object used in the
+   * class when applicable for non-locking concurrent usage.
+   */
+  def copy(): RandomDataGenerator[T]
+}
+
+/**
+ * :: Experimental ::
+ * Generates i.i.d. samples from U[0.0, 1.0]
+ */
+@Experimental
+class UniformGenerator extends RandomDataGenerator[Double] {
+
+  // XORShiftRandom for better performance. Thread safety isn't necessary here.
+  private val random = new XORShiftRandom()
+
+  override def nextValue(): Double = {
+    random.nextDouble()
+  }
+
+  override def setSeed(seed: Long) = random.setSeed(seed)
+
+  override def copy(): UniformGenerator = new UniformGenerator()
+}
+
+/**
+ * :: Experimental ::
+ * Generates i.i.d. samples from the standard normal distribution.
+ */
+@Experimental
+class StandardNormalGenerator extends RandomDataGenerator[Double] {
+
+  // XORShiftRandom for better performance. Thread safety isn't necessary here.
+  private val random = new XORShiftRandom()
+
+  override def nextValue(): Double = {
+      random.nextGaussian()
+  }
+
+  override def setSeed(seed: Long) = random.setSeed(seed)
+
+  override def copy(): StandardNormalGenerator = new StandardNormalGenerator()
+}
+
+/**
+ * :: Experimental ::
+ * Generates i.i.d. samples from the Poisson distribution with the given mean.
+ *
+ * @param mean mean for the Poisson distribution.
+ */
+@Experimental
+class PoissonGenerator(val mean: Double) extends RandomDataGenerator[Double] {
+
+  private var rng = new Poisson(mean, new DRand)
+
+  override def nextValue(): Double = rng.nextDouble()
+
+  override def setSeed(seed: Long) {
+    rng = new Poisson(mean, new DRand(seed.toInt))
+  }
+
+  override def copy(): PoissonGenerator = new PoissonGenerator(mean)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/fda47598/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala
index 021d651..b0a0593 100644
--- 
a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala
@@ -24,6 +24,8 @@ import org.apache.spark.mllib.rdd.{RandomVectorRDD, RandomRDD}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.Utils
 
+import scala.reflect.ClassTag
+
 /**
  * :: Experimental ::
  * Generator methods for creating RDDs comprised of i.i.d. samples from some 
distribution.
@@ -200,12 +202,12 @@ object RandomRDDGenerators {
    * @return RDD[Double] comprised of i.i.d. samples produced by generator.
    */
   @Experimental
-  def randomRDD(sc: SparkContext,
-      generator: DistributionGenerator,
+  def randomRDD[T: ClassTag](sc: SparkContext,
+      generator: RandomDataGenerator[T],
       size: Long,
       numPartitions: Int,
-      seed: Long): RDD[Double] = {
-    new RandomRDD(sc, size, numPartitions, generator, seed)
+      seed: Long): RDD[T] = {
+    new RandomRDD[T](sc, size, numPartitions, generator, seed)
   }
 
   /**
@@ -219,11 +221,11 @@ object RandomRDDGenerators {
    * @return RDD[Double] comprised of i.i.d. samples produced by generator.
    */
   @Experimental
-  def randomRDD(sc: SparkContext,
-      generator: DistributionGenerator,
+  def randomRDD[T: ClassTag](sc: SparkContext,
+      generator: RandomDataGenerator[T],
       size: Long,
-      numPartitions: Int): RDD[Double] = {
-    randomRDD(sc, generator, size, numPartitions, Utils.random.nextLong)
+      numPartitions: Int): RDD[T] = {
+    randomRDD[T](sc, generator, size, numPartitions, Utils.random.nextLong)
   }
 
   /**
@@ -237,10 +239,10 @@ object RandomRDDGenerators {
    * @return RDD[Double] comprised of i.i.d. samples produced by generator.
    */
   @Experimental
-  def randomRDD(sc: SparkContext,
-      generator: DistributionGenerator,
-      size: Long): RDD[Double] = {
-    randomRDD(sc, generator, size, sc.defaultParallelism, 
Utils.random.nextLong)
+  def randomRDD[T: ClassTag](sc: SparkContext,
+      generator: RandomDataGenerator[T],
+      size: Long): RDD[T] = {
+    randomRDD[T](sc, generator, size, sc.defaultParallelism, 
Utils.random.nextLong)
   }
 
   // TODO Generate RDD[Vector] from multivariate distributions.
@@ -439,7 +441,7 @@ object RandomRDDGenerators {
    */
   @Experimental
   def randomVectorRDD(sc: SparkContext,
-      generator: DistributionGenerator,
+      generator: RandomDataGenerator[Double],
       numRows: Long,
       numCols: Int,
       numPartitions: Int,
@@ -461,7 +463,7 @@ object RandomRDDGenerators {
    */
   @Experimental
   def randomVectorRDD(sc: SparkContext,
-      generator: DistributionGenerator,
+      generator: RandomDataGenerator[Double],
       numRows: Long,
       numCols: Int,
       numPartitions: Int): RDD[Vector] = {
@@ -482,7 +484,7 @@ object RandomRDDGenerators {
    */
   @Experimental
   def randomVectorRDD(sc: SparkContext,
-      generator: DistributionGenerator,
+      generator: RandomDataGenerator[Double],
       numRows: Long,
       numCols: Int): RDD[Vector] = {
     randomVectorRDD(sc, generator, numRows, numCols,

http://git-wip-us.apache.org/repos/asf/spark/blob/fda47598/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala
index f13282d..c8db391 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala
@@ -19,35 +19,36 @@ package org.apache.spark.mllib.rdd
 
 import org.apache.spark.{Partition, SparkContext, TaskContext}
 import org.apache.spark.mllib.linalg.{DenseVector, Vector}
-import org.apache.spark.mllib.random.DistributionGenerator
+import org.apache.spark.mllib.random.RandomDataGenerator
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.Utils
 
+import scala.reflect.ClassTag
 import scala.util.Random
 
-private[mllib] class RandomRDDPartition(override val index: Int,
+private[mllib] class RandomRDDPartition[T](override val index: Int,
     val size: Int,
-    val generator: DistributionGenerator,
+    val generator: RandomDataGenerator[T],
     val seed: Long) extends Partition {
 
   require(size >= 0, "Non-negative partition size required.")
 }
 
 // These two classes are necessary since Range objects in Scala cannot have 
size > Int.MaxValue
-private[mllib] class RandomRDD(@transient sc: SparkContext,
+private[mllib] class RandomRDD[T: ClassTag](@transient sc: SparkContext,
     size: Long,
     numPartitions: Int,
-    @transient rng: DistributionGenerator,
-    @transient seed: Long = Utils.random.nextLong) extends RDD[Double](sc, 
Nil) {
+    @transient rng: RandomDataGenerator[T],
+    @transient seed: Long = Utils.random.nextLong) extends RDD[T](sc, Nil) {
 
   require(size > 0, "Positive RDD size required.")
   require(numPartitions > 0, "Positive number of partitions required")
   require(math.ceil(size.toDouble / numPartitions) <= Int.MaxValue,
     "Partition size cannot exceed Int.MaxValue")
 
-  override def compute(splitIn: Partition, context: TaskContext): 
Iterator[Double] = {
-    val split = splitIn.asInstanceOf[RandomRDDPartition]
-    RandomRDD.getPointIterator(split)
+  override def compute(splitIn: Partition, context: TaskContext): Iterator[T] 
= {
+    val split = splitIn.asInstanceOf[RandomRDDPartition[T]]
+    RandomRDD.getPointIterator[T](split)
   }
 
   override def getPartitions: Array[Partition] = {
@@ -59,7 +60,7 @@ private[mllib] class RandomVectorRDD(@transient sc: 
SparkContext,
     size: Long,
     vectorSize: Int,
     numPartitions: Int,
-    @transient rng: DistributionGenerator,
+    @transient rng: RandomDataGenerator[Double],
     @transient seed: Long = Utils.random.nextLong) extends RDD[Vector](sc, 
Nil) {
 
   require(size > 0, "Positive RDD size required.")
@@ -69,7 +70,7 @@ private[mllib] class RandomVectorRDD(@transient sc: 
SparkContext,
     "Partition size cannot exceed Int.MaxValue")
 
   override def compute(splitIn: Partition, context: TaskContext): 
Iterator[Vector] = {
-    val split = splitIn.asInstanceOf[RandomRDDPartition]
+    val split = splitIn.asInstanceOf[RandomRDDPartition[Double]]
     RandomRDD.getVectorIterator(split, vectorSize)
   }
 
@@ -80,12 +81,12 @@ private[mllib] class RandomVectorRDD(@transient sc: 
SparkContext,
 
 private[mllib] object RandomRDD {
 
-  def getPartitions(size: Long,
+  def getPartitions[T](size: Long,
       numPartitions: Int,
-      rng: DistributionGenerator,
+      rng: RandomDataGenerator[T],
       seed: Long): Array[Partition] = {
 
-    val partitions = new Array[RandomRDDPartition](numPartitions)
+    val partitions = new Array[RandomRDDPartition[T]](numPartitions)
     var i = 0
     var start: Long = 0
     var end: Long = 0
@@ -101,7 +102,7 @@ private[mllib] object RandomRDD {
 
   // The RNG has to be reset every time the iterator is requested to guarantee 
same data
   // every time the content of the RDD is examined.
-  def getPointIterator(partition: RandomRDDPartition): Iterator[Double] = {
+  def getPointIterator[T: ClassTag](partition: RandomRDDPartition[T]): 
Iterator[T] = {
     val generator = partition.generator.copy()
     generator.setSeed(partition.seed)
     Array.fill(partition.size)(generator.nextValue()).toIterator
@@ -109,7 +110,8 @@ private[mllib] object RandomRDD {
 
   // The RNG has to be reset every time the iterator is requested to guarantee 
same data
   // every time the content of the RDD is examined.
-  def getVectorIterator(partition: RandomRDDPartition, vectorSize: Int): 
Iterator[Vector] = {
+  def getVectorIterator(partition: RandomRDDPartition[Double],
+                        vectorSize: Int): Iterator[Vector] = {
     val generator = partition.generator.copy()
     generator.setSeed(partition.seed)
     Array.fill(partition.size)(new DenseVector(

http://git-wip-us.apache.org/repos/asf/spark/blob/fda47598/mllib/src/test/scala/org/apache/spark/mllib/random/DistributionGeneratorSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/random/DistributionGeneratorSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/mllib/random/DistributionGeneratorSuite.scala
deleted file mode 100644
index 974dec4..0000000
--- 
a/mllib/src/test/scala/org/apache/spark/mllib/random/DistributionGeneratorSuite.scala
+++ /dev/null
@@ -1,90 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.random
-
-import org.scalatest.FunSuite
-
-import org.apache.spark.util.StatCounter
-
-// TODO update tests to use TestingUtils for floating point comparison after 
PR 1367 is merged
-class DistributionGeneratorSuite extends FunSuite {
-
-  def apiChecks(gen: DistributionGenerator) {
-
-    // resetting seed should generate the same sequence of random numbers
-    gen.setSeed(42L)
-    val array1 = (0 until 1000).map(_ => gen.nextValue())
-    gen.setSeed(42L)
-    val array2 = (0 until 1000).map(_ => gen.nextValue())
-    assert(array1.equals(array2))
-
-    // newInstance should contain a difference instance of the rng
-    // i.e. setting difference seeds for difference instances produces 
different sequences of
-    // random numbers.
-    val gen2 = gen.copy()
-    gen.setSeed(0L)
-    val array3 = (0 until 1000).map(_ => gen.nextValue())
-    gen2.setSeed(1L)
-    val array4 = (0 until 1000).map(_ => gen2.nextValue())
-    // Compare arrays instead of elements since individual elements can 
coincide by chance but the
-    // sequences should differ given two different seeds.
-    assert(!array3.equals(array4))
-
-    // test that setting the same seed in the copied instance produces the 
same sequence of numbers
-    gen.setSeed(0L)
-    val array5 = (0 until 1000).map(_ => gen.nextValue())
-    gen2.setSeed(0L)
-    val array6 = (0 until 1000).map(_ => gen2.nextValue())
-    assert(array5.equals(array6))
-  }
-
-  def distributionChecks(gen: DistributionGenerator,
-      mean: Double = 0.0,
-      stddev: Double = 1.0,
-      epsilon: Double = 0.01) {
-    for (seed <- 0 until 5) {
-      gen.setSeed(seed.toLong)
-      val sample = (0 until 100000).map { _ => gen.nextValue()}
-      val stats = new StatCounter(sample)
-      assert(math.abs(stats.mean - mean) < epsilon)
-      assert(math.abs(stats.stdev - stddev) < epsilon)
-    }
-  }
-
-  test("UniformGenerator") {
-    val uniform = new UniformGenerator()
-    apiChecks(uniform)
-    // Stddev of uniform distribution = (ub - lb) / math.sqrt(12)
-    distributionChecks(uniform, 0.5, 1 / math.sqrt(12))
-  }
-
-  test("StandardNormalGenerator") {
-    val normal = new StandardNormalGenerator()
-    apiChecks(normal)
-    distributionChecks(normal, 0.0, 1.0)
-  }
-
-  test("PoissonGenerator") {
-    // mean = 0.0 will not pass the API checks since 0.0 is always 
deterministically produced.
-    for (mean <- List(1.0, 5.0, 100.0)) {
-      val poisson = new PoissonGenerator(mean)
-      apiChecks(poisson)
-      distributionChecks(poisson, mean, math.sqrt(mean), 0.1)
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/fda47598/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
new file mode 100644
index 0000000..3df7c12
--- /dev/null
+++ 
b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.random
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.util.StatCounter
+
+// TODO update tests to use TestingUtils for floating point comparison after 
PR 1367 is merged
+class RandomDataGeneratorSuite extends FunSuite {
+
+  def apiChecks(gen: RandomDataGenerator[Double]) {
+
+    // resetting seed should generate the same sequence of random numbers
+    gen.setSeed(42L)
+    val array1 = (0 until 1000).map(_ => gen.nextValue())
+    gen.setSeed(42L)
+    val array2 = (0 until 1000).map(_ => gen.nextValue())
+    assert(array1.equals(array2))
+
+    // newInstance should contain a difference instance of the rng
+    // i.e. setting difference seeds for difference instances produces 
different sequences of
+    // random numbers.
+    val gen2 = gen.copy()
+    gen.setSeed(0L)
+    val array3 = (0 until 1000).map(_ => gen.nextValue())
+    gen2.setSeed(1L)
+    val array4 = (0 until 1000).map(_ => gen2.nextValue())
+    // Compare arrays instead of elements since individual elements can 
coincide by chance but the
+    // sequences should differ given two different seeds.
+    assert(!array3.equals(array4))
+
+    // test that setting the same seed in the copied instance produces the 
same sequence of numbers
+    gen.setSeed(0L)
+    val array5 = (0 until 1000).map(_ => gen.nextValue())
+    gen2.setSeed(0L)
+    val array6 = (0 until 1000).map(_ => gen2.nextValue())
+    assert(array5.equals(array6))
+  }
+
+  def distributionChecks(gen: RandomDataGenerator[Double],
+      mean: Double = 0.0,
+      stddev: Double = 1.0,
+      epsilon: Double = 0.01) {
+    for (seed <- 0 until 5) {
+      gen.setSeed(seed.toLong)
+      val sample = (0 until 100000).map { _ => gen.nextValue()}
+      val stats = new StatCounter(sample)
+      assert(math.abs(stats.mean - mean) < epsilon)
+      assert(math.abs(stats.stdev - stddev) < epsilon)
+    }
+  }
+
+  test("UniformGenerator") {
+    val uniform = new UniformGenerator()
+    apiChecks(uniform)
+    // Stddev of uniform distribution = (ub - lb) / math.sqrt(12)
+    distributionChecks(uniform, 0.5, 1 / math.sqrt(12))
+  }
+
+  test("StandardNormalGenerator") {
+    val normal = new StandardNormalGenerator()
+    apiChecks(normal)
+    distributionChecks(normal, 0.0, 1.0)
+  }
+
+  test("PoissonGenerator") {
+    // mean = 0.0 will not pass the API checks since 0.0 is always 
deterministically produced.
+    for (mean <- List(1.0, 5.0, 100.0)) {
+      val poisson = new PoissonGenerator(mean)
+      apiChecks(poisson)
+      distributionChecks(poisson, mean, math.sqrt(mean), 0.1)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/fda47598/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala
index 6aa4f80..96e0bc6 100644
--- 
a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala
@@ -78,7 +78,9 @@ class RandomRDDGeneratorsSuite extends FunSuite with 
LocalSparkContext with Seri
       assert(rdd.partitions.size === numPartitions)
 
       // check that partition sizes are balanced
-      val partSizes = rdd.partitions.map(p => 
p.asInstanceOf[RandomRDDPartition].size.toDouble)
+      val partSizes = rdd.partitions.map(p =>
+        p.asInstanceOf[RandomRDDPartition[Double]].size.toDouble)
+
       val partStats = new StatCounter(partSizes)
       assert(partStats.max - partStats.min <= 1)
     }
@@ -89,7 +91,7 @@ class RandomRDDGeneratorsSuite extends FunSuite with 
LocalSparkContext with Seri
     val rdd = new RandomRDD(sc, size, numPartitions, new UniformGenerator, 0L)
     assert(rdd.partitions.size === numPartitions)
     val count = rdd.partitions.foldLeft(0L) { (count, part) =>
-      count + part.asInstanceOf[RandomRDDPartition].size
+      count + part.asInstanceOf[RandomRDDPartition[Double]].size
     }
     assert(count === size)
 
@@ -145,7 +147,7 @@ class RandomRDDGeneratorsSuite extends FunSuite with 
LocalSparkContext with Seri
   }
 }
 
-private[random] class MockDistro extends DistributionGenerator {
+private[random] class MockDistro extends RandomDataGenerator[Double] {
 
   var seed = 0L
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to