Repository: spark
Updated Branches:
  refs/heads/master 014403951 -> 20a61dbd9


[SPARK-10626] [MLLIB] create java friendly method for random rdd

SPARK-3136 added a large number of functions for creating Java RandomRDDs, but 
for people who want to use custom RandomDataGenerators we should make a 
Java-friendly method.

Author: Holden Karau <hol...@pigscanfly.ca>

Closes #8782 from holdenk/SPARK-10626-create-java-friendly-method-for-randomRDD.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/20a61dbd
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/20a61dbd
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/20a61dbd

Branch: refs/heads/master
Commit: 20a61dbd9b57957fcc5b58ef8935533914172b07
Parents: 0144039
Author: Holden Karau <hol...@pigscanfly.ca>
Authored: Mon Sep 21 18:53:28 2015 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Mon Sep 21 18:53:28 2015 +0100

----------------------------------------------------------------------
 .../apache/spark/mllib/random/RandomRDDs.scala  | 52 +++++++++++++++++++-
 .../spark/mllib/random/JavaRandomRDDsSuite.java | 30 +++++++++++
 2 files changed, 81 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/20a61dbd/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala
index 4dd5ea2..f8ff26b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala
@@ -22,6 +22,7 @@ import scala.reflect.ClassTag
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
 import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD, JavaSparkContext}
+import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.rdd.{RandomRDD, RandomVectorRDD}
 import org.apache.spark.rdd.RDD
@@ -381,7 +382,7 @@ object RandomRDDs {
    * @param size Size of the RDD.
    * @param numPartitions Number of partitions in the RDD (default: 
`sc.defaultParallelism`).
    * @param seed Random seed (default: a random long integer).
-   * @return RDD[Double] comprised of `i.i.d.` samples produced by generator.
+   * @return RDD[T] comprised of `i.i.d.` samples produced by generator.
    */
   @DeveloperApi
   @Since("1.1.0")
@@ -394,6 +395,55 @@ object RandomRDDs {
     new RandomRDD[T](sc, size, numPartitionsOrDefault(sc, numPartitions), 
generator, seed)
   }
 
+  /**
+   * :: DeveloperApi ::
+   * Generates an RDD comprised of `i.i.d.` samples produced by the input 
RandomDataGenerator.
+   *
+   * @param jsc JavaSparkContext used to create the RDD.
+   * @param generator RandomDataGenerator used to populate the RDD.
+   * @param size Size of the RDD.
+   * @param numPartitions Number of partitions in the RDD (default: 
`sc.defaultParallelism`).
+   * @param seed Random seed (default: a random long integer).
+   * @return RDD[T] comprised of `i.i.d.` samples produced by generator.
+   */
+  @DeveloperApi
+  @Since("1.6.0")
+  def randomJavaRDD[T](
+      jsc: JavaSparkContext,
+      generator: RandomDataGenerator[T],
+      size: Long,
+      numPartitions: Int,
+      seed: Long): JavaRDD[T] = {
+    implicit val ctag: ClassTag[T] = fakeClassTag
+    val rdd = randomRDD(jsc.sc, generator, size, numPartitions, seed)
+    JavaRDD.fromRDD(rdd)
+  }
+
+  /**
+   * [[RandomRDDs#randomJavaRDD]] with the default seed.
+   */
+  @DeveloperApi
+  @Since("1.6.0")
+  def randomJavaRDD[T](
+    jsc: JavaSparkContext,
+    generator: RandomDataGenerator[T],
+    size: Long,
+    numPartitions: Int): JavaRDD[T] = {
+    randomJavaRDD(jsc, generator, size, numPartitions, Utils.random.nextLong())
+  }
+
+  /**
+   * [[RandomRDDs#randomJavaRDD]] with the default seed & numPartitions
+   */
+  @DeveloperApi
+  @Since("1.6.0")
+  def randomJavaRDD[T](
+    jsc: JavaSparkContext,
+    generator: RandomDataGenerator[T],
+    size: Long): JavaRDD[T] = {
+    randomJavaRDD(jsc, generator, size, 0);
+  }
+
   // TODO Generate RDD[Vector] from multivariate distributions.
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/20a61dbd/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java 
b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
index 33d81b1..fce5f67 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.mllib.random;
 
+import java.io.Serializable;
 import java.util.Arrays;
 
 import org.apache.spark.api.java.JavaRDD;
@@ -231,4 +232,33 @@ public class JavaRandomRDDsSuite {
     }
   }
 
+  @Test
+  public void testArbitrary() {
+    long size = 10;
+    long seed = 1L;
+    int numPartitions = 0;
+    StringGenerator gen = new StringGenerator();
+    JavaRDD<String> rdd1 = randomJavaRDD(sc, gen, size);
+    JavaRDD<String> rdd2 = randomJavaRDD(sc, gen, size, numPartitions);
+    JavaRDD<String> rdd3 = randomJavaRDD(sc, gen, size, numPartitions, seed);
+    for (JavaRDD<String> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+      Assert.assertEquals(size, rdd.count());
+      Assert.assertEquals(2, rdd.first().length());
+    }
+  }
+}
+
+// This is just a test generator, it always returns a string of 42
+class StringGenerator implements RandomDataGenerator<String>, Serializable {
+  @Override
+  public String nextValue() {
+    return "42";
+  }
+  @Override
+  public StringGenerator copy() {
+    return new StringGenerator();
+  }
+  @Override
+  public void setSeed(long seed) {
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to