Repository: spark
Updated Branches:
  refs/heads/master 758c9d25e -> d91967e15


[SPARK-10763] [ML] [JAVA] [TEST] Update Java MLLIB/ML tests to use simplified dataframe construction

As introduced in https://issues.apache.org/jira/browse/SPARK-10630, we now have
an easier way to create DataFrames directly from local Java lists. Let's update
the tests to use it instead of first parallelizing the data into a JavaRDD.

Author: Holden Karau <hol...@pigscanfly.ca>

Closes #8886 from holdenk/SPARK-10763-update-java-mllib-ml-tests-to-use-simplified-dataframe-construction.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d91967e1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d91967e1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d91967e1

Branch: refs/heads/master
Commit: d91967e159f416924bbd7f0db25156588d4bd7b1
Parents: 758c9d2
Author: Holden Karau <hol...@pigscanfly.ca>
Authored: Wed Sep 23 22:49:08 2015 -0700
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Wed Sep 23 22:49:08 2015 -0700

----------------------------------------------------------------------
 .../spark/ml/classification/JavaNaiveBayesSuite.java  |  8 ++++----
 .../apache/spark/ml/feature/JavaBucketizerSuite.java  | 14 +++++++-------
 .../org/apache/spark/ml/feature/JavaDCTSuite.java     | 11 +++++------
 .../apache/spark/ml/feature/JavaHashingTFSuite.java   |  7 ++++---
 .../ml/feature/JavaPolynomialExpansionSuite.java      |  5 +++--
 .../spark/ml/feature/JavaStopWordsRemoverSuite.java   |  7 ++++---
 .../spark/ml/feature/JavaStringIndexerSuite.java      |  7 ++++---
 .../spark/ml/feature/JavaVectorAssemblerSuite.java    |  3 +--
 .../spark/ml/feature/JavaVectorSlicerSuite.java       |  7 ++++---
 .../apache/spark/ml/feature/JavaWord2VecSuite.java    | 12 ++++++------
 10 files changed, 42 insertions(+), 39 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d91967e1/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
index 075a62c..f5f690e 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
@@ -19,6 +19,7 @@ package org.apache.spark.ml.classification;
 
 import java.io.Serializable;
 import java.util.Arrays;
+import java.util.List;
 
 import org.junit.After;
 import org.junit.Before;
@@ -75,21 +76,20 @@ public class JavaNaiveBayesSuite implements Serializable {
 
   @Test
   public void testNaiveBayes() {
-    JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
+    List<Row> data = Arrays.asList(
       RowFactory.create(0.0, Vectors.dense(1.0, 0.0, 0.0)),
       RowFactory.create(0.0, Vectors.dense(2.0, 0.0, 0.0)),
       RowFactory.create(1.0, Vectors.dense(0.0, 1.0, 0.0)),
       RowFactory.create(1.0, Vectors.dense(0.0, 2.0, 0.0)),
       RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 1.0)),
-      RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 2.0))
-    ));
+      RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 2.0)));
 
     StructType schema = new StructType(new StructField[]{
       new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
       new StructField("features", new VectorUDT(), false, Metadata.empty())
     });
 
-    DataFrame dataset = jsql.createDataFrame(jrdd, schema);
+    DataFrame dataset = jsql.createDataFrame(data, schema);
     NaiveBayes nb = new 
NaiveBayes().setSmoothing(0.5).setModelType("multinomial");
     NaiveBayesModel model = nb.fit(dataset);
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d91967e1/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
index 47d68de..8a1e5ef 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
@@ -55,16 +55,16 @@ public class JavaBucketizerSuite {
   public void bucketizerTest() {
     double[] splits = {-0.5, 0.0, 0.5};
 
-    JavaRDD<Row> data = jsc.parallelize(Arrays.asList(
-      RowFactory.create(-0.5),
-      RowFactory.create(-0.3),
-      RowFactory.create(0.0),
-      RowFactory.create(0.2)
-    ));
     StructType schema = new StructType(new StructField[] {
       new StructField("feature", DataTypes.DoubleType, false, Metadata.empty())
     });
-    DataFrame dataset = jsql.createDataFrame(data, schema);
+    DataFrame dataset = jsql.createDataFrame(
+      Arrays.asList(
+        RowFactory.create(-0.5),
+        RowFactory.create(-0.3),
+        RowFactory.create(0.0),
+        RowFactory.create(0.2)),
+      schema);
 
     Bucketizer bucketizer = new Bucketizer()
       .setInputCol("feature")

http://git-wip-us.apache.org/repos/asf/spark/blob/d91967e1/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
index 0f6ec64..39da473 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
@@ -57,12 +57,11 @@ public class JavaDCTSuite {
   @Test
   public void javaCompatibilityTest() {
     double[] input = new double[] {1D, 2D, 3D, 4D};
-    JavaRDD<Row> data = jsc.parallelize(Arrays.asList(
-      RowFactory.create(Vectors.dense(input))
-    ));
-    DataFrame dataset = jsql.createDataFrame(data, new StructType(new 
StructField[]{
-      new StructField("vec", (new VectorUDT()), false, Metadata.empty())
-    }));
+    DataFrame dataset = jsql.createDataFrame(
+      Arrays.asList(RowFactory.create(Vectors.dense(input))),
+      new StructType(new StructField[]{
+        new StructField("vec", (new VectorUDT()), false, Metadata.empty())
+      }));
 
     double[] expectedResult = input.clone();
     (new DoubleDCT_1D(input.length)).forward(expectedResult, true);

http://git-wip-us.apache.org/repos/asf/spark/blob/d91967e1/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
index 03dd536..d12332c 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
@@ -18,6 +18,7 @@
 package org.apache.spark.ml.feature;
 
 import java.util.Arrays;
+import java.util.List;
 
 import org.junit.After;
 import org.junit.Assert;
@@ -55,17 +56,17 @@ public class JavaHashingTFSuite {
 
   @Test
   public void hashingTF() {
-    JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
+    List<Row> data = Arrays.asList(
       RowFactory.create(0.0, "Hi I heard about Spark"),
       RowFactory.create(0.0, "I wish Java could use case classes"),
       RowFactory.create(1.0, "Logistic regression models are neat")
-    ));
+    );
     StructType schema = new StructType(new StructField[]{
       new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
       new StructField("sentence", DataTypes.StringType, false, 
Metadata.empty())
     });
 
-    DataFrame sentenceData = jsql.createDataFrame(jrdd, schema);
+    DataFrame sentenceData = jsql.createDataFrame(data, schema);
     Tokenizer tokenizer = new Tokenizer()
       .setInputCol("sentence")
       .setOutputCol("words");

http://git-wip-us.apache.org/repos/asf/spark/blob/d91967e1/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
index 834fedb..bf8eefd 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
@@ -18,6 +18,7 @@
 package org.apache.spark.ml.feature;
 
 import java.util.Arrays;
+import java.util.List;
 
 import org.junit.After;
 import org.junit.Assert;
@@ -60,7 +61,7 @@ public class JavaPolynomialExpansionSuite {
       .setOutputCol("polyFeatures")
       .setDegree(3);
 
-    JavaRDD<Row> data = jsc.parallelize(Arrays.asList(
+    List<Row> data = Arrays.asList(
       RowFactory.create(
         Vectors.dense(-2.0, 2.3),
         Vectors.dense(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17)
@@ -70,7 +71,7 @@ public class JavaPolynomialExpansionSuite {
         Vectors.dense(0.6, -1.1),
         Vectors.dense(0.6, 0.36, 0.216, -1.1, -0.66, -0.396, 1.21, 0.726, 
-1.331)
       )
-    ));
+    );
 
     StructType schema = new StructType(new StructField[] {
       new StructField("features", new VectorUDT(), false, Metadata.empty()),

http://git-wip-us.apache.org/repos/asf/spark/blob/d91967e1/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
index 76cdd0f..848d9f8 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
@@ -18,6 +18,7 @@
 package org.apache.spark.ml.feature;
 
 import java.util.Arrays;
+import java.util.List;
 
 import org.junit.After;
 import org.junit.Before;
@@ -58,14 +59,14 @@ public class JavaStopWordsRemoverSuite {
       .setInputCol("raw")
       .setOutputCol("filtered");
 
-    JavaRDD<Row> rdd = jsc.parallelize(Arrays.asList(
+    List<Row> data = Arrays.asList(
       RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")),
       RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb"))
-    ));
+    );
     StructType schema = new StructType(new StructField[] {
       new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), 
false, Metadata.empty())
     });
-    DataFrame dataset = jsql.createDataFrame(rdd, schema);
+    DataFrame dataset = jsql.createDataFrame(data, schema);
 
     remover.transform(dataset).collect();
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/d91967e1/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
index 35b18c5..6b2c48e 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
@@ -18,6 +18,7 @@
 package org.apache.spark.ml.feature;
 
 import java.util.Arrays;
+import java.util.List;
 
 import org.junit.After;
 import org.junit.Assert;
@@ -56,9 +57,9 @@ public class JavaStringIndexerSuite {
       createStructField("id", IntegerType, false),
       createStructField("label", StringType, false)
     });
-    JavaRDD<Row> rdd = jsc.parallelize(
-      Arrays.asList(c(0, "a"), c(1, "b"), c(2, "c"), c(3, "a"), c(4, "a"), 
c(5, "c")));
-    DataFrame dataset = sqlContext.createDataFrame(rdd, schema);
+    List<Row> data = Arrays.asList(
+      c(0, "a"), c(1, "b"), c(2, "c"), c(3, "a"), c(4, "a"), c(5, "c"));
+    DataFrame dataset = sqlContext.createDataFrame(data, schema);
 
     StringIndexer indexer = new StringIndexer()
       .setInputCol("label")

http://git-wip-us.apache.org/repos/asf/spark/blob/d91967e1/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
index b7c564c..e283777 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
@@ -65,8 +65,7 @@ public class JavaVectorAssemblerSuite {
     Row row = RowFactory.create(
       0, 0.0, Vectors.dense(1.0, 2.0), "a",
       Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L);
-    JavaRDD<Row> rdd = jsc.parallelize(Arrays.asList(row));
-    DataFrame dataset = sqlContext.createDataFrame(rdd, schema);
+    DataFrame dataset = sqlContext.createDataFrame(Arrays.asList(row), schema);
     VectorAssembler assembler = new VectorAssembler()
       .setInputCols(new String[] {"x", "y", "z", "n"})
       .setOutputCol("features");

http://git-wip-us.apache.org/repos/asf/spark/blob/d91967e1/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
index f953361..00174e6 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
@@ -18,6 +18,7 @@
 package org.apache.spark.ml.feature;
 
 import java.util.Arrays;
+import java.util.List;
 
 import org.junit.After;
 import org.junit.Assert;
@@ -63,12 +64,12 @@ public class JavaVectorSlicerSuite {
     };
     AttributeGroup group = new AttributeGroup("userFeatures", attrs);
 
-    JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
+    List<Row> data = Arrays.asList(
       RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 
2.3})),
       RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0))
-    ));
+    );
 
-    DataFrame dataset = jsql.createDataFrame(jrdd, (new 
StructType()).add(group.toStructField()));
+    DataFrame dataset = jsql.createDataFrame(data, (new 
StructType()).add(group.toStructField()));
 
     VectorSlicer vectorSlicer = new VectorSlicer()
       .setInputCol("userFeatures").setOutputCol("features");

http://git-wip-us.apache.org/repos/asf/spark/blob/d91967e1/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
index 70f5ad9..0c0c1c4 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
@@ -51,15 +51,15 @@ public class JavaWord2VecSuite {
 
   @Test
   public void testJavaWord2Vec() {
-    JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
-      RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))),
-      RowFactory.create(Arrays.asList("I wish Java could use case 
classes".split(" "))),
-      RowFactory.create(Arrays.asList("Logistic regression models are 
neat".split(" ")))
-    ));
     StructType schema = new StructType(new StructField[]{
       new StructField("text", new ArrayType(DataTypes.StringType, true), 
false, Metadata.empty())
     });
-    DataFrame documentDF = sqlContext.createDataFrame(jrdd, schema);
+    DataFrame documentDF = sqlContext.createDataFrame(
+      Arrays.asList(
+        RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))),
+        RowFactory.create(Arrays.asList("I wish Java could use case 
classes".split(" "))),
+        RowFactory.create(Arrays.asList("Logistic regression models are 
neat".split(" ")))),
+      schema);
 
     Word2Vec word2Vec = new Word2Vec()
       .setInputCol("text")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to