Repository: spark Updated Branches: refs/heads/master e25ec0617 -> fda475987
[SPARK-2801][MLlib]: DistributionGenerator renamed to RandomDataGenerator. RandomRDD is now of generic type. The RandomRDDGenerators used to only output RDD[Double]. Now RandomRDDGenerators.randomRDD can be used to generate a random RDD[T] via a class that extends RandomDataGenerator, by supplying a type T and overriding the nextValue() function as desired. Author: Burak <brk...@gmail.com> Closes #1732 from brkyvz/SPARK-2801 and squashes the following commits: c94a694 [Burak] [SPARK-2801][MLlib] Missing ClassTags added 22d96fe [Burak] [SPARK-2801][MLlib]: DistributionGenerator renamed to RandomDataGenerator, generic types added for RandomRDD instead of Double Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fda47598 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fda47598 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fda47598 Branch: refs/heads/master Commit: fda475987f3b8b37d563033b0e45706ce433824a Parents: e25ec06 Author: Burak <brk...@gmail.com> Authored: Fri Aug 1 22:32:12 2014 -0700 Committer: Xiangrui Meng <m...@databricks.com> Committed: Fri Aug 1 22:32:12 2014 -0700 ---------------------------------------------------------------------- .../mllib/random/DistributionGenerator.scala | 101 ------------------- .../mllib/random/RandomDataGenerator.scala | 101 +++++++++++++++++++ .../mllib/random/RandomRDDGenerators.scala | 32 +++--- .../org/apache/spark/mllib/rdd/RandomRDD.scala | 34 ++++--- .../random/DistributionGeneratorSuite.scala | 90 ----------------- .../mllib/random/RandomDataGeneratorSuite.scala | 90 +++++++++++++++++ .../mllib/random/RandomRDDGeneratorsSuite.scala | 8 +- 7 files changed, 231 insertions(+), 225 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/fda47598/mllib/src/main/scala/org/apache/spark/mllib/random/DistributionGenerator.scala 
---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/DistributionGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/DistributionGenerator.scala deleted file mode 100644 index 7ecb409..0000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/DistributionGenerator.scala +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.random - -import cern.jet.random.Poisson -import cern.jet.random.engine.DRand - -import org.apache.spark.annotation.Experimental -import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom} - -/** - * :: Experimental :: - * Trait for random number generators that generate i.i.d. values from a distribution. - */ -@Experimental -trait DistributionGenerator extends Pseudorandom with Serializable { - - /** - * Returns an i.i.d. sample as a Double from an underlying distribution. - */ - def nextValue(): Double - - /** - * Returns a copy of the DistributionGenerator with a new instance of the rng object used in the - * class when applicable for non-locking concurrent usage. 
- */ - def copy(): DistributionGenerator -} - -/** - * :: Experimental :: - * Generates i.i.d. samples from U[0.0, 1.0] - */ -@Experimental -class UniformGenerator extends DistributionGenerator { - - // XORShiftRandom for better performance. Thread safety isn't necessary here. - private val random = new XORShiftRandom() - - override def nextValue(): Double = { - random.nextDouble() - } - - override def setSeed(seed: Long) = random.setSeed(seed) - - override def copy(): UniformGenerator = new UniformGenerator() -} - -/** - * :: Experimental :: - * Generates i.i.d. samples from the standard normal distribution. - */ -@Experimental -class StandardNormalGenerator extends DistributionGenerator { - - // XORShiftRandom for better performance. Thread safety isn't necessary here. - private val random = new XORShiftRandom() - - override def nextValue(): Double = { - random.nextGaussian() - } - - override def setSeed(seed: Long) = random.setSeed(seed) - - override def copy(): StandardNormalGenerator = new StandardNormalGenerator() -} - -/** - * :: Experimental :: - * Generates i.i.d. samples from the Poisson distribution with the given mean. - * - * @param mean mean for the Poisson distribution. 
- */ -@Experimental -class PoissonGenerator(val mean: Double) extends DistributionGenerator { - - private var rng = new Poisson(mean, new DRand) - - override def nextValue(): Double = rng.nextDouble() - - override def setSeed(seed: Long) { - rng = new Poisson(mean, new DRand(seed.toInt)) - } - - override def copy(): PoissonGenerator = new PoissonGenerator(mean) -} http://git-wip-us.apache.org/repos/asf/spark/blob/fda47598/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala new file mode 100644 index 0000000..9cab49f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.random + +import cern.jet.random.Poisson +import cern.jet.random.engine.DRand + +import org.apache.spark.annotation.Experimental +import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom} + +/** + * :: Experimental :: + * Trait for random data generators that generate i.i.d. data. + */ +@Experimental +trait RandomDataGenerator[T] extends Pseudorandom with Serializable { + + /** + * Returns an i.i.d. sample as a generic type from an underlying distribution. + */ + def nextValue(): T + + /** + * Returns a copy of the RandomDataGenerator with a new instance of the rng object used in the + * class when applicable for non-locking concurrent usage. + */ + def copy(): RandomDataGenerator[T] +} + +/** + * :: Experimental :: + * Generates i.i.d. samples from U[0.0, 1.0] + */ +@Experimental +class UniformGenerator extends RandomDataGenerator[Double] { + + // XORShiftRandom for better performance. Thread safety isn't necessary here. + private val random = new XORShiftRandom() + + override def nextValue(): Double = { + random.nextDouble() + } + + override def setSeed(seed: Long) = random.setSeed(seed) + + override def copy(): UniformGenerator = new UniformGenerator() +} + +/** + * :: Experimental :: + * Generates i.i.d. samples from the standard normal distribution. + */ +@Experimental +class StandardNormalGenerator extends RandomDataGenerator[Double] { + + // XORShiftRandom for better performance. Thread safety isn't necessary here. + private val random = new XORShiftRandom() + + override def nextValue(): Double = { + random.nextGaussian() + } + + override def setSeed(seed: Long) = random.setSeed(seed) + + override def copy(): StandardNormalGenerator = new StandardNormalGenerator() +} + +/** + * :: Experimental :: + * Generates i.i.d. samples from the Poisson distribution with the given mean. + * + * @param mean mean for the Poisson distribution. 
+ */ +@Experimental +class PoissonGenerator(val mean: Double) extends RandomDataGenerator[Double] { + + private var rng = new Poisson(mean, new DRand) + + override def nextValue(): Double = rng.nextDouble() + + override def setSeed(seed: Long) { + rng = new Poisson(mean, new DRand(seed.toInt)) + } + + override def copy(): PoissonGenerator = new PoissonGenerator(mean) +} http://git-wip-us.apache.org/repos/asf/spark/blob/fda47598/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala index 021d651..b0a0593 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala @@ -24,6 +24,8 @@ import org.apache.spark.mllib.rdd.{RandomVectorRDD, RandomRDD} import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils +import scala.reflect.ClassTag + /** * :: Experimental :: * Generator methods for creating RDDs comprised of i.i.d. samples from some distribution. @@ -200,12 +202,12 @@ object RandomRDDGenerators { * @return RDD[Double] comprised of i.i.d. samples produced by generator. */ @Experimental - def randomRDD(sc: SparkContext, - generator: DistributionGenerator, + def randomRDD[T: ClassTag](sc: SparkContext, + generator: RandomDataGenerator[T], size: Long, numPartitions: Int, - seed: Long): RDD[Double] = { - new RandomRDD(sc, size, numPartitions, generator, seed) + seed: Long): RDD[T] = { + new RandomRDD[T](sc, size, numPartitions, generator, seed) } /** @@ -219,11 +221,11 @@ object RandomRDDGenerators { * @return RDD[Double] comprised of i.i.d. samples produced by generator. 
*/ @Experimental - def randomRDD(sc: SparkContext, - generator: DistributionGenerator, + def randomRDD[T: ClassTag](sc: SparkContext, + generator: RandomDataGenerator[T], size: Long, - numPartitions: Int): RDD[Double] = { - randomRDD(sc, generator, size, numPartitions, Utils.random.nextLong) + numPartitions: Int): RDD[T] = { + randomRDD[T](sc, generator, size, numPartitions, Utils.random.nextLong) } /** @@ -237,10 +239,10 @@ object RandomRDDGenerators { * @return RDD[Double] comprised of i.i.d. samples produced by generator. */ @Experimental - def randomRDD(sc: SparkContext, - generator: DistributionGenerator, - size: Long): RDD[Double] = { - randomRDD(sc, generator, size, sc.defaultParallelism, Utils.random.nextLong) + def randomRDD[T: ClassTag](sc: SparkContext, + generator: RandomDataGenerator[T], + size: Long): RDD[T] = { + randomRDD[T](sc, generator, size, sc.defaultParallelism, Utils.random.nextLong) } // TODO Generate RDD[Vector] from multivariate distributions. @@ -439,7 +441,7 @@ object RandomRDDGenerators { */ @Experimental def randomVectorRDD(sc: SparkContext, - generator: DistributionGenerator, + generator: RandomDataGenerator[Double], numRows: Long, numCols: Int, numPartitions: Int, @@ -461,7 +463,7 @@ object RandomRDDGenerators { */ @Experimental def randomVectorRDD(sc: SparkContext, - generator: DistributionGenerator, + generator: RandomDataGenerator[Double], numRows: Long, numCols: Int, numPartitions: Int): RDD[Vector] = { @@ -482,7 +484,7 @@ object RandomRDDGenerators { */ @Experimental def randomVectorRDD(sc: SparkContext, - generator: DistributionGenerator, + generator: RandomDataGenerator[Double], numRows: Long, numCols: Int): RDD[Vector] = { randomVectorRDD(sc, generator, numRows, numCols, http://git-wip-us.apache.org/repos/asf/spark/blob/fda47598/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala ---------------------------------------------------------------------- diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala index f13282d..c8db391 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala @@ -19,35 +19,36 @@ package org.apache.spark.mllib.rdd import org.apache.spark.{Partition, SparkContext, TaskContext} import org.apache.spark.mllib.linalg.{DenseVector, Vector} -import org.apache.spark.mllib.random.DistributionGenerator +import org.apache.spark.mllib.random.RandomDataGenerator import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils +import scala.reflect.ClassTag import scala.util.Random -private[mllib] class RandomRDDPartition(override val index: Int, +private[mllib] class RandomRDDPartition[T](override val index: Int, val size: Int, - val generator: DistributionGenerator, + val generator: RandomDataGenerator[T], val seed: Long) extends Partition { require(size >= 0, "Non-negative partition size required.") } // These two classes are necessary since Range objects in Scala cannot have size > Int.MaxValue -private[mllib] class RandomRDD(@transient sc: SparkContext, +private[mllib] class RandomRDD[T: ClassTag](@transient sc: SparkContext, size: Long, numPartitions: Int, - @transient rng: DistributionGenerator, - @transient seed: Long = Utils.random.nextLong) extends RDD[Double](sc, Nil) { + @transient rng: RandomDataGenerator[T], + @transient seed: Long = Utils.random.nextLong) extends RDD[T](sc, Nil) { require(size > 0, "Positive RDD size required.") require(numPartitions > 0, "Positive number of partitions required") require(math.ceil(size.toDouble / numPartitions) <= Int.MaxValue, "Partition size cannot exceed Int.MaxValue") - override def compute(splitIn: Partition, context: TaskContext): Iterator[Double] = { - val split = splitIn.asInstanceOf[RandomRDDPartition] - RandomRDD.getPointIterator(split) + override def compute(splitIn: Partition, 
context: TaskContext): Iterator[T] = { + val split = splitIn.asInstanceOf[RandomRDDPartition[T]] + RandomRDD.getPointIterator[T](split) } override def getPartitions: Array[Partition] = { @@ -59,7 +60,7 @@ private[mllib] class RandomVectorRDD(@transient sc: SparkContext, size: Long, vectorSize: Int, numPartitions: Int, - @transient rng: DistributionGenerator, + @transient rng: RandomDataGenerator[Double], @transient seed: Long = Utils.random.nextLong) extends RDD[Vector](sc, Nil) { require(size > 0, "Positive RDD size required.") @@ -69,7 +70,7 @@ private[mllib] class RandomVectorRDD(@transient sc: SparkContext, "Partition size cannot exceed Int.MaxValue") override def compute(splitIn: Partition, context: TaskContext): Iterator[Vector] = { - val split = splitIn.asInstanceOf[RandomRDDPartition] + val split = splitIn.asInstanceOf[RandomRDDPartition[Double]] RandomRDD.getVectorIterator(split, vectorSize) } @@ -80,12 +81,12 @@ private[mllib] class RandomVectorRDD(@transient sc: SparkContext, private[mllib] object RandomRDD { - def getPartitions(size: Long, + def getPartitions[T](size: Long, numPartitions: Int, - rng: DistributionGenerator, + rng: RandomDataGenerator[T], seed: Long): Array[Partition] = { - val partitions = new Array[RandomRDDPartition](numPartitions) + val partitions = new Array[RandomRDDPartition[T]](numPartitions) var i = 0 var start: Long = 0 var end: Long = 0 @@ -101,7 +102,7 @@ private[mllib] object RandomRDD { // The RNG has to be reset every time the iterator is requested to guarantee same data // every time the content of the RDD is examined. 
- def getPointIterator(partition: RandomRDDPartition): Iterator[Double] = { + def getPointIterator[T: ClassTag](partition: RandomRDDPartition[T]): Iterator[T] = { val generator = partition.generator.copy() generator.setSeed(partition.seed) Array.fill(partition.size)(generator.nextValue()).toIterator @@ -109,7 +110,8 @@ private[mllib] object RandomRDD { // The RNG has to be reset every time the iterator is requested to guarantee same data // every time the content of the RDD is examined. - def getVectorIterator(partition: RandomRDDPartition, vectorSize: Int): Iterator[Vector] = { + def getVectorIterator(partition: RandomRDDPartition[Double], + vectorSize: Int): Iterator[Vector] = { val generator = partition.generator.copy() generator.setSeed(partition.seed) Array.fill(partition.size)(new DenseVector( http://git-wip-us.apache.org/repos/asf/spark/blob/fda47598/mllib/src/test/scala/org/apache/spark/mllib/random/DistributionGeneratorSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/DistributionGeneratorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/DistributionGeneratorSuite.scala deleted file mode 100644 index 974dec4..0000000 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/DistributionGeneratorSuite.scala +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.random - -import org.scalatest.FunSuite - -import org.apache.spark.util.StatCounter - -// TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged -class DistributionGeneratorSuite extends FunSuite { - - def apiChecks(gen: DistributionGenerator) { - - // resetting seed should generate the same sequence of random numbers - gen.setSeed(42L) - val array1 = (0 until 1000).map(_ => gen.nextValue()) - gen.setSeed(42L) - val array2 = (0 until 1000).map(_ => gen.nextValue()) - assert(array1.equals(array2)) - - // newInstance should contain a difference instance of the rng - // i.e. setting difference seeds for difference instances produces different sequences of - // random numbers. - val gen2 = gen.copy() - gen.setSeed(0L) - val array3 = (0 until 1000).map(_ => gen.nextValue()) - gen2.setSeed(1L) - val array4 = (0 until 1000).map(_ => gen2.nextValue()) - // Compare arrays instead of elements since individual elements can coincide by chance but the - // sequences should differ given two different seeds. 
- assert(!array3.equals(array4)) - - // test that setting the same seed in the copied instance produces the same sequence of numbers - gen.setSeed(0L) - val array5 = (0 until 1000).map(_ => gen.nextValue()) - gen2.setSeed(0L) - val array6 = (0 until 1000).map(_ => gen2.nextValue()) - assert(array5.equals(array6)) - } - - def distributionChecks(gen: DistributionGenerator, - mean: Double = 0.0, - stddev: Double = 1.0, - epsilon: Double = 0.01) { - for (seed <- 0 until 5) { - gen.setSeed(seed.toLong) - val sample = (0 until 100000).map { _ => gen.nextValue()} - val stats = new StatCounter(sample) - assert(math.abs(stats.mean - mean) < epsilon) - assert(math.abs(stats.stdev - stddev) < epsilon) - } - } - - test("UniformGenerator") { - val uniform = new UniformGenerator() - apiChecks(uniform) - // Stddev of uniform distribution = (ub - lb) / math.sqrt(12) - distributionChecks(uniform, 0.5, 1 / math.sqrt(12)) - } - - test("StandardNormalGenerator") { - val normal = new StandardNormalGenerator() - apiChecks(normal) - distributionChecks(normal, 0.0, 1.0) - } - - test("PoissonGenerator") { - // mean = 0.0 will not pass the API checks since 0.0 is always deterministically produced. 
- for (mean <- List(1.0, 5.0, 100.0)) { - val poisson = new PoissonGenerator(mean) - apiChecks(poisson) - distributionChecks(poisson, mean, math.sqrt(mean), 0.1) - } - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/fda47598/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala new file mode 100644 index 0000000..3df7c12 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.random + +import org.scalatest.FunSuite + +import org.apache.spark.util.StatCounter + +// TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged +class RandomDataGeneratorSuite extends FunSuite { + + def apiChecks(gen: RandomDataGenerator[Double]) { + + // resetting seed should generate the same sequence of random numbers + gen.setSeed(42L) + val array1 = (0 until 1000).map(_ => gen.nextValue()) + gen.setSeed(42L) + val array2 = (0 until 1000).map(_ => gen.nextValue()) + assert(array1.equals(array2)) + + // newInstance should contain a difference instance of the rng + // i.e. setting difference seeds for difference instances produces different sequences of + // random numbers. + val gen2 = gen.copy() + gen.setSeed(0L) + val array3 = (0 until 1000).map(_ => gen.nextValue()) + gen2.setSeed(1L) + val array4 = (0 until 1000).map(_ => gen2.nextValue()) + // Compare arrays instead of elements since individual elements can coincide by chance but the + // sequences should differ given two different seeds. 
+ assert(!array3.equals(array4)) + + // test that setting the same seed in the copied instance produces the same sequence of numbers + gen.setSeed(0L) + val array5 = (0 until 1000).map(_ => gen.nextValue()) + gen2.setSeed(0L) + val array6 = (0 until 1000).map(_ => gen2.nextValue()) + assert(array5.equals(array6)) + } + + def distributionChecks(gen: RandomDataGenerator[Double], + mean: Double = 0.0, + stddev: Double = 1.0, + epsilon: Double = 0.01) { + for (seed <- 0 until 5) { + gen.setSeed(seed.toLong) + val sample = (0 until 100000).map { _ => gen.nextValue()} + val stats = new StatCounter(sample) + assert(math.abs(stats.mean - mean) < epsilon) + assert(math.abs(stats.stdev - stddev) < epsilon) + } + } + + test("UniformGenerator") { + val uniform = new UniformGenerator() + apiChecks(uniform) + // Stddev of uniform distribution = (ub - lb) / math.sqrt(12) + distributionChecks(uniform, 0.5, 1 / math.sqrt(12)) + } + + test("StandardNormalGenerator") { + val normal = new StandardNormalGenerator() + apiChecks(normal) + distributionChecks(normal, 0.0, 1.0) + } + + test("PoissonGenerator") { + // mean = 0.0 will not pass the API checks since 0.0 is always deterministically produced. 
+ for (mean <- List(1.0, 5.0, 100.0)) { + val poisson = new PoissonGenerator(mean) + apiChecks(poisson) + distributionChecks(poisson, mean, math.sqrt(mean), 0.1) + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/fda47598/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala index 6aa4f80..96e0bc6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala @@ -78,7 +78,9 @@ class RandomRDDGeneratorsSuite extends FunSuite with LocalSparkContext with Seri assert(rdd.partitions.size === numPartitions) // check that partition sizes are balanced - val partSizes = rdd.partitions.map(p => p.asInstanceOf[RandomRDDPartition].size.toDouble) + val partSizes = rdd.partitions.map(p => + p.asInstanceOf[RandomRDDPartition[Double]].size.toDouble) + val partStats = new StatCounter(partSizes) assert(partStats.max - partStats.min <= 1) } @@ -89,7 +91,7 @@ class RandomRDDGeneratorsSuite extends FunSuite with LocalSparkContext with Seri val rdd = new RandomRDD(sc, size, numPartitions, new UniformGenerator, 0L) assert(rdd.partitions.size === numPartitions) val count = rdd.partitions.foldLeft(0L) { (count, part) => - count + part.asInstanceOf[RandomRDDPartition].size + count + part.asInstanceOf[RandomRDDPartition[Double]].size } assert(count === size) @@ -145,7 +147,7 @@ class RandomRDDGeneratorsSuite extends FunSuite with LocalSparkContext with Seri } } -private[random] class MockDistro extends DistributionGenerator { +private[random] class MockDistro extends RandomDataGenerator[Double] { var seed = 0L --------------------------------------------------------------------- To 
unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org