http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java index 9f81751..00c59f0 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java @@ -30,25 +30,26 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; -import static org.apache.spark.mllib.classification.LogisticRegressionSuite - .generateLogisticInputAsList; - +import org.apache.spark.sql.SparkSession; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; public class JavaLinearRegressionSuite implements Serializable { + private transient SparkSession spark; private transient JavaSparkContext jsc; - private transient SQLContext jsql; private transient Dataset<Row> dataset; private transient JavaRDD<LabeledPoint> datasetRDD; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaLinearRegressionSuite"); - jsql = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaLinearRegressionSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42); datasetRDD = jsc.parallelize(points, 2); - dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class); + dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class); dataset.registerTempTable("dataset"); } @@ -65,7 +66,7 @@ public class JavaLinearRegressionSuite implements Serializable { assertEquals("auto", lr.getSolver()); LinearRegressionModel model = lr.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - Dataset<Row> predictions = jsql.sql("SELECT label, prediction FROM prediction"); + Dataset<Row> predictions = spark.sql("SELECT label, prediction FROM prediction"); predictions.collect(); // Check defaults assertEquals("features", model.getFeaturesCol()); @@ -76,8 +77,8 @@ public class JavaLinearRegressionSuite implements Serializable { public void linearRegressionWithSetters() { // Set params, train, and check as many params as we can. LinearRegression lr = new LinearRegression() - .setMaxIter(10) - .setRegParam(1.0).setSolver("l-bfgs"); + .setMaxIter(10) + .setRegParam(1.0).setSolver("l-bfgs"); LinearRegressionModel model = lr.fit(dataset); LinearRegression parent = (LinearRegression) model.parent(); assertEquals(10, parent.getMaxIter()); @@ -85,7 +86,7 @@ public class JavaLinearRegressionSuite implements Serializable { // Call fit() with new params, and check as many params as we can. LinearRegressionModel model2 = - lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred")); + lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred")); LinearRegression parent2 = (LinearRegression) model2.parent(); assertEquals(5, parent2.getMaxIter()); assertEquals(0.1, parent2.getRegParam(), 0.0);
http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java index 38b895f..fdb41ff 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java @@ -28,27 +28,33 @@ import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.ml.tree.impl.TreeTests; +import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; public class JavaRandomForestRegressorSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaRandomForestRegressorSuite"); + spark = SparkSession.builder() + .master("local") + .appName("JavaRandomForestRegressorSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test @@ -57,7 +63,7 @@ public class JavaRandomForestRegressorSuite implements Serializable { double A = 2.0; double B = -1.5; - JavaRDD<LabeledPoint> data = sc.parallelize( + JavaRDD<LabeledPoint> data = jsc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map<Integer, Integer> categoricalFeatures = new HashMap<>(); Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); @@ -75,22 +81,22 @@ public class JavaRandomForestRegressorSuite implements Serializable { .setSeed(1234) .setNumTrees(3) .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern - for (String impurity: RandomForestRegressor.supportedImpurities()) { + for (String impurity : RandomForestRegressor.supportedImpurities()) { rf.setImpurity(impurity); } - for (String featureSubsetStrategy: RandomForestRegressor.supportedFeatureSubsetStrategies()) { + for (String featureSubsetStrategy : RandomForestRegressor.supportedFeatureSubsetStrategies()) { rf.setFeatureSubsetStrategy(featureSubsetStrategy); } String[] realStrategies = {".1", ".10", "0.10", "0.1", "0.9", "1.0"}; - for (String strategy: realStrategies) { + for (String strategy : realStrategies) { rf.setFeatureSubsetStrategy(strategy); } String[] integerStrategies = {"1", "10", "100", "1000", "10000"}; - for (String strategy: integerStrategies) { + for (String strategy : integerStrategies) { rf.setFeatureSubsetStrategy(strategy); } String[] invalidStrategies = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"}; - for (String strategy: invalidStrategies) { + for (String strategy : invalidStrategies) { try { rf.setFeatureSubsetStrategy(strategy); Assert.fail("Expected exception to be thrown for invalid strategies"); http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java index 1c18b2b..058f2dd 100644 --- a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java @@ -28,12 +28,11 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.DenseVector; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.util.Utils; @@ -41,16 +40,17 @@ import org.apache.spark.util.Utils; * Test LibSVMRelation in Java. */ public class JavaLibSVMRelationSuite { - private transient JavaSparkContext jsc; - private transient SQLContext sqlContext; + private transient SparkSession spark; private File tempDir; private String path; @Before public void setUp() throws IOException { - jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite"); - sqlContext = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaLibSVMRelationSuite") + .getOrCreate(); tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource"); File file = new File(tempDir, "part-00000"); @@ -61,14 +61,14 @@ public class JavaLibSVMRelationSuite { @After public void tearDown() { - jsc.stop(); - jsc = null; + spark.stop(); + spark = null; Utils.deleteRecursively(tempDir); } @Test public void verifyLibSVMDF() { - Dataset<Row> dataset = sqlContext.read().format("libsvm").option("vectorType", "dense") + Dataset<Row> dataset = spark.read().format("libsvm").option("vectorType", "dense") .load(path); Assert.assertEquals("label", dataset.columns()[0]); Assert.assertEquals("features", dataset.columns()[1]); http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java index 24b0097..8b4d034 100644 --- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java @@ -32,21 +32,25 @@ import org.apache.spark.ml.param.ParamMap; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; public class JavaCrossValidatorSuite implements Serializable { + private transient SparkSession spark; private transient JavaSparkContext jsc; - private transient SQLContext jsql; private transient Dataset<Row> dataset; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaCrossValidatorSuite"); - jsql = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaCrossValidatorSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); + List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42); - dataset = jsql.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class); + dataset = spark.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class); } @After @@ -59,8 +63,8 @@ public class JavaCrossValidatorSuite implements Serializable { public void crossValidationWithLogisticRegression() { LogisticRegression lr = new LogisticRegression(); ParamMap[] lrParamMaps = new ParamGridBuilder() - .addGrid(lr.regParam(), new double[] {0.001, 1000.0}) - .addGrid(lr.maxIter(), new int[] {0, 10}) + .addGrid(lr.regParam(), new double[]{0.001, 1000.0}) + .addGrid(lr.maxIter(), new int[]{0, 10}) .build(); BinaryClassificationEvaluator eval = new BinaryClassificationEvaluator(); CrossValidator cv = new CrossValidator() http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala index 9283015..878bc66 100644 --- a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala +++ b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala @@ -37,4 +37,5 @@ object IdentifiableSuite { class Test(override val uid: String) extends Identifiable { def this() = this(Identifiable.randomUID("test")) } + } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java index 01ff1ea..7151e27 100644 --- a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java @@ -27,31 +27,34 @@ import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.util.Utils; public class JavaDefaultReadWriteSuite { JavaSparkContext jsc = null; - SQLContext sqlContext = null; + SparkSession spark = null; File tempDir = null; @Before public void setUp() { - jsc = new JavaSparkContext("local[2]", "JavaDefaultReadWriteSuite"); SQLContext.clearActive(); - sqlContext = new SQLContext(jsc); - SQLContext.setActive(sqlContext); + spark = SparkSession.builder() + .master("local[2]") + .appName("JavaDefaultReadWriteSuite") + .getOrCreate(); + SQLContext.setActive(spark.wrapped()); + tempDir = Utils.createTempDir( System.getProperty("java.io.tmpdir"), "JavaDefaultReadWriteSuite"); } @After public void tearDown() { - sqlContext = null; SQLContext.clearActive(); - if (jsc != null) { - jsc.stop(); - jsc = null; + if (spark != null) { + spark.stop(); + spark = null; } Utils.deleteRecursively(tempDir); } @@ -70,7 +73,7 @@ public class JavaDefaultReadWriteSuite { } catch (IOException e) { // expected } - instance.write().context(sqlContext).overwrite().save(outputPath); + instance.write().context(spark.wrapped()).overwrite().save(outputPath); MyParams newInstance = MyParams.load(outputPath); Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid()); Assert.assertEquals("Params should be preserved.", http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java index 862221d..2f10d14 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java @@ -27,26 +27,31 @@ import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; - import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.SparkSession; public class JavaLogisticRegressionSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); + spark = SparkSession.builder() + .master("local") + .appName("JavaLogisticRegressionSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } int validatePrediction(List<LabeledPoint> validationData, LogisticRegressionModel model) { int numAccurate = 0; - for (LabeledPoint point: validationData) { + for (LabeledPoint point : validationData) { Double prediction = model.predict(point.features()); if (prediction == point.label()) { numAccurate++; @@ -61,16 +66,16 @@ public class JavaLogisticRegressionSuite implements Serializable { double A = 2.0; double B = -1.5; - JavaRDD<LabeledPoint> testRDD = sc.parallelize( - LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + JavaRDD<LabeledPoint> testRDD = jsc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); List<LabeledPoint> validationData = - LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17); + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17); LogisticRegressionWithSGD lrImpl = new LogisticRegressionWithSGD(); lrImpl.setIntercept(true); lrImpl.optimizer().setStepSize(1.0) - .setRegParam(1.0) - .setNumIterations(100); + .setRegParam(1.0) + .setNumIterations(100); LogisticRegressionModel model = lrImpl.run(testRDD.rdd()); int numAccurate = validatePrediction(validationData, model); @@ -83,13 +88,13 @@ public class JavaLogisticRegressionSuite implements Serializable { double A = 0.0; double B = -2.5; - JavaRDD<LabeledPoint> testRDD = sc.parallelize( - LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + JavaRDD<LabeledPoint> testRDD = jsc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); List<LabeledPoint> validationData = - LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17); + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17); LogisticRegressionModel model = LogisticRegressionWithSGD.train( - testRDD.rdd(), 100, 1.0, 1.0); + testRDD.rdd(), 100, 1.0, 1.0); int numAccurate = validatePrediction(validationData, model); Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java index 3771c0e..5e212e2 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java @@ -32,20 +32,26 @@ import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.SparkSession; public class JavaNaiveBayesSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaNaiveBayesSuite"); + spark = SparkSession.builder() + .master("local") + .appName("JavaNaiveBayesSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } private static final List<LabeledPoint> POINTS = Arrays.asList( @@ -59,7 +65,7 @@ public class JavaNaiveBayesSuite implements Serializable { private int validatePrediction(List<LabeledPoint> points, NaiveBayesModel model) { int correct = 0; - for (LabeledPoint p: points) { + for (LabeledPoint p : points) { if (model.predict(p.features()) == p.label()) { correct += 1; } @@ -69,7 +75,7 @@ public class JavaNaiveBayesSuite implements Serializable { @Test public void runUsingConstructor() { - JavaRDD<LabeledPoint> testRDD = sc.parallelize(POINTS, 2).cache(); + JavaRDD<LabeledPoint> testRDD = jsc.parallelize(POINTS, 2).cache(); NaiveBayes nb = new NaiveBayes().setLambda(1.0); NaiveBayesModel model = nb.run(testRDD.rdd()); @@ -80,7 +86,7 @@ public class JavaNaiveBayesSuite implements Serializable { @Test public void runUsingStaticMethods() { - JavaRDD<LabeledPoint> testRDD = sc.parallelize(POINTS, 2).cache(); + JavaRDD<LabeledPoint> testRDD = jsc.parallelize(POINTS, 2).cache(); NaiveBayesModel model1 = NaiveBayes.train(testRDD.rdd()); int numAccurate1 = validatePrediction(POINTS, model1); @@ -93,13 +99,14 @@ public class JavaNaiveBayesSuite implements Serializable { @Test public void testPredictJavaRDD() { - JavaRDD<LabeledPoint> examples = sc.parallelize(POINTS, 2).cache(); + JavaRDD<LabeledPoint> examples = jsc.parallelize(POINTS, 2).cache(); NaiveBayesModel model = NaiveBayes.train(examples.rdd()); JavaRDD<Vector> vectors = examples.map(new Function<LabeledPoint, Vector>() { @Override public Vector call(LabeledPoint v) throws Exception { return v.features(); - }}); + } + }); JavaRDD<Double> predictions = model.predict(vectors); // Should be able to get the first prediction. predictions.first(); http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java index 31b9f3e..2a090c0 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java @@ -28,24 +28,30 @@ import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.SparkSession; public class JavaSVMSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaSVMSuite"); + spark = SparkSession.builder() + .master("local") + .appName("JavaSVMSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } int validatePrediction(List<LabeledPoint> validationData, SVMModel model) { int numAccurate = 0; - for (LabeledPoint point: validationData) { + for (LabeledPoint point : validationData) { Double prediction = model.predict(point.features()); if (prediction == point.label()) { numAccurate++; @@ -60,16 +66,16 @@ public class JavaSVMSuite implements Serializable { double A = 2.0; double[] weights = {-1.5, 1.0}; - JavaRDD<LabeledPoint> testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A, - weights, nPoints, 42), 2).cache(); + JavaRDD<LabeledPoint> testRDD = jsc.parallelize(SVMSuite.generateSVMInputAsList(A, + weights, nPoints, 42), 2).cache(); List<LabeledPoint> validationData = - SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17); + SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17); SVMWithSGD svmSGDImpl = new SVMWithSGD(); svmSGDImpl.setIntercept(true); svmSGDImpl.optimizer().setStepSize(1.0) - .setRegParam(1.0) - .setNumIterations(100); + .setRegParam(1.0) + .setNumIterations(100); SVMModel model = svmSGDImpl.run(testRDD.rdd()); int numAccurate = validatePrediction(validationData, model); @@ -82,10 +88,10 @@ public class JavaSVMSuite implements Serializable { double A = 0.0; double[] weights = {-1.5, 1.0}; - JavaRDD<LabeledPoint> testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A, - weights, nPoints, 42), 2).cache(); + JavaRDD<LabeledPoint> testRDD = jsc.parallelize(SVMSuite.generateSVMInputAsList(A, + weights, nPoints, 42), 2).cache(); List<LabeledPoint> validationData = - SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17); + SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17); SVMModel model = SVMWithSGD.train(testRDD.rdd(), 100, 1.0, 1.0, 1.0); http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java index a714620..7f29b05 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java @@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering; import java.io.Serializable; import com.google.common.collect.Lists; + import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -29,27 +30,33 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.SparkSession; public class JavaBisectingKMeansSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", this.getClass().getSimpleName()); + spark = SparkSession.builder() + .master("local") + .appName("JavaBisectingKMeansSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test public void twoDimensionalData() { - JavaRDD<Vector> points = sc.parallelize(Lists.newArrayList( + JavaRDD<Vector> points = jsc.parallelize(Lists.newArrayList( Vectors.dense(4, -1), Vectors.dense(4, 1), - Vectors.sparse(2, new int[] {0}, new double[] {1.0}) + Vectors.sparse(2, new int[]{0}, new double[]{1.0}) ), 2); BisectingKMeans bkm = new BisectingKMeans() @@ -58,15 +65,15 @@ public class JavaBisectingKMeansSuite implements Serializable { .setSeed(1L); BisectingKMeansModel model = bkm.run(points); Assert.assertEquals(3, model.k()); - Assert.assertArrayEquals(new double[] {3.0, 0.0}, model.root().center().toArray(), 1e-12); - for (ClusteringTreeNode child: model.root().children()) { + Assert.assertArrayEquals(new double[]{3.0, 0.0}, model.root().center().toArray(), 1e-12); + for (ClusteringTreeNode child : model.root().children()) { double[] center = child.center().toArray(); if (center[0] > 2) { Assert.assertEquals(2, child.size()); - Assert.assertArrayEquals(new double[] {4.0, 0.0}, center, 1e-12); + Assert.assertArrayEquals(new double[]{4.0, 0.0}, center, 1e-12); } else { Assert.assertEquals(1, child.size()); - Assert.assertArrayEquals(new double[] {1.0, 0.0}, center, 1e-12); + Assert.assertArrayEquals(new double[]{1.0, 0.0}, center, 1e-12); } } } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java index 123f78d..20edd08 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java @@ -21,29 +21,35 @@ import java.io.Serializable; import java.util.Arrays; import java.util.List; +import static org.junit.Assert.assertEquals; + import org.junit.After; import org.junit.Before; import org.junit.Test; -import static org.junit.Assert.assertEquals; - import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.SparkSession; public class JavaGaussianMixtureSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaGaussianMixture"); + spark = SparkSession.builder() + .master("local") + .appName("JavaGaussianMixture") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test @@ -54,7 +60,7 @@ public class JavaGaussianMixtureSuite implements Serializable { Vectors.dense(1.0, 4.0, 6.0) ); - JavaRDD<Vector> data = sc.parallelize(points, 2); + JavaRDD<Vector> data = jsc.parallelize(points, 2); GaussianMixtureModel model = new GaussianMixture().setK(2).setMaxIterations(1).setSeed(1234) .run(data); assertEquals(model.gaussians().length, 2); http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java index ad06676..4e5b87f 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java @@ -21,28 +21,35 @@ import java.io.Serializable; import java.util.Arrays; import java.util.List; +import static org.junit.Assert.assertEquals; + import org.junit.After; import org.junit.Before; import org.junit.Test; -import static org.junit.Assert.*; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.SparkSession; public class JavaKMeansSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaKMeans"); + spark = SparkSession.builder() + .master("local") + .appName("JavaKMeans") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test @@ -55,7 +62,7 @@ public class JavaKMeansSuite implements Serializable { Vector expectedCenter = Vectors.dense(1.0, 3.0, 4.0); - JavaRDD<Vector> data = sc.parallelize(points, 2); + JavaRDD<Vector> data = jsc.parallelize(points, 2); KMeansModel model = KMeans.train(data.rdd(), 1, 1, 1, KMeans.K_MEANS_PARALLEL()); assertEquals(1, model.clusterCenters().length); assertEquals(expectedCenter, model.clusterCenters()[0]); @@ -74,7 +81,7 @@ public class JavaKMeansSuite implements Serializable { Vector expectedCenter = Vectors.dense(1.0, 3.0, 4.0); - JavaRDD<Vector> data = sc.parallelize(points, 2); + JavaRDD<Vector> data = jsc.parallelize(points, 2); KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd()); assertEquals(1, model.clusterCenters().length); assertEquals(expectedCenter, model.clusterCenters()[0]); @@ -94,7 +101,7 @@ public class JavaKMeansSuite implements Serializable { Vectors.dense(1.0, 3.0, 0.0), Vectors.dense(1.0, 4.0, 6.0) ); - JavaRDD<Vector> data = sc.parallelize(points, 2); + JavaRDD<Vector> data = jsc.parallelize(points, 2); KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd()); JavaRDD<Integer> predictions = model.predict(data); // Should be able to get the first prediction. http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index db19b30..f16585a 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -27,37 +27,42 @@ import scala.Tuple3; import org.junit.After; import org.junit.Before; import org.junit.Test; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; -import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.linalg.Matrix; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.SparkSession; public class JavaLDASuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaLDA"); + spark = SparkSession.builder() + .master("local") + .appName("JavaLDASuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); + ArrayList<Tuple2<Long, Vector>> tinyCorpus = new ArrayList<>(); for (int i = 0; i < LDASuite.tinyCorpus().length; i++) { - tinyCorpus.add(new Tuple2<>((Long)LDASuite.tinyCorpus()[i]._1(), - LDASuite.tinyCorpus()[i]._2())); + tinyCorpus.add(new Tuple2<>((Long) LDASuite.tinyCorpus()[i]._1(), + LDASuite.tinyCorpus()[i]._2())); } - JavaRDD<Tuple2<Long, Vector>> tmpCorpus = sc.parallelize(tinyCorpus, 2); + JavaRDD<Tuple2<Long, Vector>> tmpCorpus = jsc.parallelize(tinyCorpus, 2); corpus = JavaPairRDD.fromJavaRDD(tmpCorpus); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test @@ -95,7 +100,7 @@ public class JavaLDASuite implements Serializable { .setMaxIterations(5) .setSeed(12345); - DistributedLDAModel model = (DistributedLDAModel)lda.run(corpus); + DistributedLDAModel model = (DistributedLDAModel) lda.run(corpus); // Check: basic parameters LocalLDAModel localModel = model.toLocal(); @@ -124,7 +129,7 @@ public class JavaLDASuite implements Serializable { public Boolean call(Tuple2<Long, Vector> tuple2) { return Vectors.norm(tuple2._2(), 1.0) != 0.0; } - }); + }); assertEquals(topicDistributions.count(), nonEmptyCorpus.count()); // Check: javaTopTopicsPerDocuments @@ -179,7 +184,7 @@ public class JavaLDASuite implements Serializable { @Test public void localLdaMethods() { - JavaRDD<Tuple2<Long, Vector>> docs = sc.parallelize(toyData, 2); + JavaRDD<Tuple2<Long, Vector>> docs = jsc.parallelize(toyData, 2); JavaPairRDD<Long, Vector> pairedDocs = JavaPairRDD.fromJavaRDD(docs); // check: topicDistributions @@ -191,7 +196,7 @@ public class JavaLDASuite implements Serializable { // check: logLikelihood. ArrayList<Tuple2<Long, Vector>> docsSingleWord = new ArrayList<>(); docsSingleWord.add(new Tuple2<>(0L, Vectors.dense(1.0, 0.0, 0.0))); - JavaPairRDD<Long, Vector> single = JavaPairRDD.fromJavaRDD(sc.parallelize(docsSingleWord)); + JavaPairRDD<Long, Vector> single = JavaPairRDD.fromJavaRDD(jsc.parallelize(docsSingleWord)); double logLikelihood = toyModel.logLikelihood(single); } @@ -199,7 +204,7 @@ public class JavaLDASuite implements Serializable { private static int tinyVocabSize = LDASuite.tinyVocabSize(); private static Matrix tinyTopics = LDASuite.tinyTopics(); private static Tuple2<int[], double[]>[] tinyTopicDescription = - LDASuite.tinyTopicDescription(); + LDASuite.tinyTopicDescription(); private JavaPairRDD<Long, Vector> corpus; private LocalLDAModel toyModel = LDASuite.toyModel(); private ArrayList<Tuple2<Long, Vector>> toyData = LDASuite.javaToyData(); http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java index 62edbd3..d1d618f 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java @@ -27,8 +27,6 @@ import org.junit.After; import org.junit.Before; import org.junit.Test; -import static org.apache.spark.streaming.JavaTestUtils.*; - import org.apache.spark.SparkConf; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; @@ -36,6 +34,7 @@ import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaPairDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; +import static org.apache.spark.streaming.JavaTestUtils.*; public class JavaStreamingKMeansSuite implements Serializable { http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java index fa4d334..6a096d6 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java @@ -31,27 +31,34 @@ import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SparkSession; public class JavaRankingMetricsSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; private transient JavaRDD<Tuple2<List<Integer>, List<Integer>>> predictionAndLabels; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaRankingMetricsSuite"); - predictionAndLabels = sc.parallelize(Arrays.asList( + spark = SparkSession.builder() + .master("local") + .appName("JavaPCASuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); + + predictionAndLabels = jsc.parallelize(Arrays.asList( Tuple2$.MODULE$.apply( Arrays.asList(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Arrays.asList(1, 2, 3, 4, 5)), Tuple2$.MODULE$.apply( - Arrays.asList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Arrays.asList(1, 2, 3)), + Arrays.asList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Arrays.asList(1, 2, 3)), Tuple2$.MODULE$.apply( - Arrays.asList(1, 2, 3, 4, 5), Arrays.<Integer>asList())), 2); + Arrays.asList(1, 2, 3, 4, 5), Arrays.<Integer>asList())), 2); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java index 8a320af..de50fb8 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java @@ -29,19 +29,25 @@ import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.sql.SparkSession; public class JavaTfIdfSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaTfIdfSuite"); + spark = SparkSession.builder() + .master("local") + .appName("JavaPCASuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test @@ -49,7 +55,7 @@ public class JavaTfIdfSuite implements Serializable { // The tests are to check Java compatibility. HashingTF tf = new HashingTF(); @SuppressWarnings("unchecked") - JavaRDD<List<String>> documents = sc.parallelize(Arrays.asList( + JavaRDD<List<String>> documents = jsc.parallelize(Arrays.asList( Arrays.asList("this is a sentence".split(" ")), Arrays.asList("this is another sentence".split(" ")), Arrays.asList("this is still a sentence".split(" "))), 2); @@ -59,7 +65,7 @@ public class JavaTfIdfSuite implements Serializable { JavaRDD<Vector> tfIdfs = idf.fit(termFreqs).transform(termFreqs); List<Vector> localTfIdfs = tfIdfs.collect(); int indexOfThis = tf.indexOf("this"); - for (Vector v: localTfIdfs) { + for (Vector v : localTfIdfs) { Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15); } } @@ -69,7 +75,7 @@ public class JavaTfIdfSuite implements Serializable { // The tests are to check Java compatibility. HashingTF tf = new HashingTF(); @SuppressWarnings("unchecked") - JavaRDD<List<String>> documents = sc.parallelize(Arrays.asList( + JavaRDD<List<String>> documents = jsc.parallelize(Arrays.asList( Arrays.asList("this is a sentence".split(" ")), Arrays.asList("this is another sentence".split(" ")), Arrays.asList("this is still a sentence".split(" "))), 2); @@ -79,7 +85,7 @@ public class JavaTfIdfSuite implements Serializable { JavaRDD<Vector> tfIdfs = idf.fit(termFreqs).transform(termFreqs); List<Vector> localTfIdfs = tfIdfs.collect(); int indexOfThis = tf.indexOf("this"); - for (Vector v: localTfIdfs) { + for (Vector v : localTfIdfs) { Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15); } } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java index e13ed07..64885cc 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java @@ -21,9 +21,10 @@ import java.io.Serializable; import java.util.Arrays; import java.util.List; +import com.google.common.base.Strings; + import scala.Tuple2; -import com.google.common.base.Strings; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -31,19 +32,25 @@ import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SparkSession; public class JavaWord2VecSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaWord2VecSuite"); + spark = SparkSession.builder() + .master("local") + .appName("JavaPCASuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test @@ -53,7 +60,7 @@ public class JavaWord2VecSuite implements Serializable { String sentence = Strings.repeat("a b ", 100) + Strings.repeat("a c ", 10); List<String> words = Arrays.asList(sentence.split(" ")); List<List<String>> localDoc = Arrays.asList(words, words); - JavaRDD<List<String>> doc = sc.parallelize(localDoc); + JavaRDD<List<String>> doc = jsc.parallelize(localDoc); Word2Vec word2vec = new Word2Vec() .setVectorSize(10) .setSeed(42L); http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java index 2bef7a8..fdc19a5 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java @@ -26,32 +26,37 @@ import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset; +import org.apache.spark.sql.SparkSession; public class JavaAssociationRulesSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaFPGrowth"); + spark = SparkSession.builder() + .master("local") + .appName("JavaAssociationRulesSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test public void runAssociationRules() { @SuppressWarnings("unchecked") - JavaRDD<FPGrowth.FreqItemset<String>> freqItemsets = sc.parallelize(Arrays.asList( - new FreqItemset<String>(new String[] {"a"}, 15L), - new FreqItemset<String>(new String[] {"b"}, 35L), - new FreqItemset<String>(new String[] {"a", "b"}, 12L) + JavaRDD<FPGrowth.FreqItemset<String>> freqItemsets = jsc.parallelize(Arrays.asList( + new FreqItemset<String>(new String[]{"a"}, 15L), + new FreqItemset<String>(new String[]{"b"}, 35L), + new FreqItemset<String>(new String[]{"a", "b"}, 12L) )); JavaRDD<AssociationRules.Rule<String>> results = (new AssociationRules()).run(freqItemsets); } } - http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java index 916fff1..f235251 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java @@ -22,34 +22,41 @@ import java.io.Serializable; import java.util.Arrays; import java.util.List; +import static org.junit.Assert.assertEquals; + import org.junit.After; import org.junit.Before; import org.junit.Test; -import static org.junit.Assert.*; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.util.Utils; public class JavaFPGrowthSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaFPGrowth"); + spark = SparkSession.builder() + .master("local") + .appName("JavaFPGrowth") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test public void runFPGrowth() { @SuppressWarnings("unchecked") - JavaRDD<List<String>> rdd = sc.parallelize(Arrays.asList( + JavaRDD<List<String>> rdd = jsc.parallelize(Arrays.asList( Arrays.asList("r z h k p".split(" ")), Arrays.asList("z y x w v u t s".split(" ")), Arrays.asList("s x o n r".split(" ")), @@ -65,7 +72,7 @@ public class JavaFPGrowthSuite implements Serializable { List<FPGrowth.FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect(); assertEquals(18, freqItemsets.size()); - for (FPGrowth.FreqItemset<String> itemset: freqItemsets) { + for (FPGrowth.FreqItemset<String> itemset : freqItemsets) { // Test return types. List<String> items = itemset.javaItems(); long freq = itemset.freq(); @@ -76,7 +83,7 @@ public class JavaFPGrowthSuite implements Serializable { public void runFPGrowthSaveLoad() { @SuppressWarnings("unchecked") - JavaRDD<List<String>> rdd = sc.parallelize(Arrays.asList( + JavaRDD<List<String>> rdd = jsc.parallelize(Arrays.asList( Arrays.asList("r z h k p".split(" ")), Arrays.asList("z y x w v u t s".split(" ")), Arrays.asList("s x o n r".split(" ")), @@ -94,15 +101,15 @@ public class JavaFPGrowthSuite implements Serializable { String outputPath = tempDir.getPath(); try { - model.save(sc.sc(), outputPath); + model.save(spark.sparkContext(), outputPath); @SuppressWarnings("unchecked") FPGrowthModel<String> newModel = - (FPGrowthModel<String>) FPGrowthModel.load(sc.sc(), outputPath); + (FPGrowthModel<String>) FPGrowthModel.load(spark.sparkContext(), outputPath); List<FPGrowth.FreqItemset<String>> freqItemsets = newModel.freqItemsets().toJavaRDD() .collect(); assertEquals(18, freqItemsets.size()); - for (FPGrowth.FreqItemset<String> itemset: freqItemsets) { + for (FPGrowth.FreqItemset<String> itemset : freqItemsets) { // Test return types. List<String> items = itemset.javaItems(); long freq = itemset.freq(); http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java index 8a67793..bf7f1fc 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java @@ -29,25 +29,31 @@ import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.fpm.PrefixSpan.FreqSequence; +import org.apache.spark.sql.SparkSession; import org.apache.spark.util.Utils; public class JavaPrefixSpanSuite { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaPrefixSpan"); + spark = SparkSession.builder() + .master("local") + .appName("JavaPrefixSpan") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test public void runPrefixSpan() { - JavaRDD<List<List<Integer>>> sequences = sc.parallelize(Arrays.asList( + JavaRDD<List<List<Integer>>> sequences = jsc.parallelize(Arrays.asList( Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)), Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)), Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)), @@ -61,7 +67,7 @@ public class JavaPrefixSpanSuite { List<FreqSequence<Integer>> localFreqSeqs = freqSeqs.collect(); Assert.assertEquals(5, localFreqSeqs.size()); // Check that each frequent sequence could be materialized. - for (PrefixSpan.FreqSequence<Integer> freqSeq: localFreqSeqs) { + for (PrefixSpan.FreqSequence<Integer> freqSeq : localFreqSeqs) { List<List<Integer>> seq = freqSeq.javaSequence(); long freq = freqSeq.freq(); } @@ -69,7 +75,7 @@ public class JavaPrefixSpanSuite { @Test public void runPrefixSpanSaveLoad() { - JavaRDD<List<List<Integer>>> sequences = sc.parallelize(Arrays.asList( + JavaRDD<List<List<Integer>>> sequences = jsc.parallelize(Arrays.asList( Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)), Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)), Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)), @@ -85,13 +91,13 @@ public class JavaPrefixSpanSuite { String outputPath = tempDir.getPath(); try { - model.save(sc.sc(), outputPath); - PrefixSpanModel newModel = PrefixSpanModel.load(sc.sc(), outputPath); + model.save(spark.sparkContext(), outputPath); + PrefixSpanModel newModel = PrefixSpanModel.load(spark.sparkContext(), outputPath); JavaRDD<FreqSequence<Integer>> freqSeqs = newModel.freqSequences().toJavaRDD(); List<FreqSequence<Integer>> localFreqSeqs = freqSeqs.collect(); Assert.assertEquals(5, localFreqSeqs.size()); // Check that each frequent sequence could be materialized. - for (PrefixSpan.FreqSequence<Integer> freqSeq: localFreqSeqs) { + for (PrefixSpan.FreqSequence<Integer> freqSeq : localFreqSeqs) { List<List<Integer>> seq = freqSeq.javaSequence(); long freq = freqSeq.freq(); } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java index 8beea10..92fc578 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java @@ -17,147 +17,149 @@ package org.apache.spark.mllib.linalg; -import static org.junit.Assert.*; -import org.junit.Test; - import java.io.Serializable; import java.util.Random; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +import org.junit.Test; + public class JavaMatricesSuite implements Serializable { - @Test - public void randMatrixConstruction() { - Random rng = new Random(24); - Matrix r = Matrices.rand(3, 4, rng); - rng.setSeed(24); - DenseMatrix dr = DenseMatrix.rand(3, 4, rng); - assertArrayEquals(r.toArray(), dr.toArray(), 0.0); - - rng.setSeed(24); - Matrix rn = Matrices.randn(3, 4, rng); - rng.setSeed(24); - DenseMatrix drn = DenseMatrix.randn(3, 4, rng); - assertArrayEquals(rn.toArray(), drn.toArray(), 0.0); - - rng.setSeed(24); - Matrix s = Matrices.sprand(3, 4, 0.5, rng); - rng.setSeed(24); - SparseMatrix sr = SparseMatrix.sprand(3, 4, 0.5, rng); - assertArrayEquals(s.toArray(), sr.toArray(), 0.0); - - rng.setSeed(24); - Matrix sn = Matrices.sprandn(3, 4, 0.5, rng); - rng.setSeed(24); - SparseMatrix srn = SparseMatrix.sprandn(3, 4, 0.5, rng); - assertArrayEquals(sn.toArray(), srn.toArray(), 0.0); - } - - @Test - public void identityMatrixConstruction() { - Matrix r = Matrices.eye(2); - DenseMatrix dr = DenseMatrix.eye(2); - SparseMatrix sr = SparseMatrix.speye(2); - assertArrayEquals(r.toArray(), dr.toArray(), 0.0); - assertArrayEquals(sr.toArray(), dr.toArray(), 0.0); - assertArrayEquals(r.toArray(), new double[]{1.0, 0.0, 0.0, 1.0}, 0.0); - } - - @Test - public void diagonalMatrixConstruction() { - Vector v = Vectors.dense(1.0, 0.0, 2.0); - Vector sv = Vectors.sparse(3, new int[]{0, 2}, new double[]{1.0, 2.0}); - - Matrix m = Matrices.diag(v); - Matrix sm = Matrices.diag(sv); - DenseMatrix d = DenseMatrix.diag(v); - DenseMatrix sd = DenseMatrix.diag(sv); - SparseMatrix s = SparseMatrix.spdiag(v); - SparseMatrix ss = SparseMatrix.spdiag(sv); - - assertArrayEquals(m.toArray(), sm.toArray(), 0.0); - assertArrayEquals(d.toArray(), sm.toArray(), 0.0); - assertArrayEquals(d.toArray(), sd.toArray(), 0.0); - assertArrayEquals(sd.toArray(), s.toArray(), 0.0); - assertArrayEquals(s.toArray(), ss.toArray(), 0.0); - assertArrayEquals(s.values(), ss.values(), 0.0); - assertEquals(2, s.values().length); - assertEquals(2, ss.values().length); - assertEquals(4, s.colPtrs().length); - assertEquals(4, ss.colPtrs().length); - } - - @Test - public void zerosMatrixConstruction() { - Matrix z = Matrices.zeros(2, 2); - Matrix one = Matrices.ones(2, 2); - DenseMatrix dz = DenseMatrix.zeros(2, 2); - DenseMatrix done = DenseMatrix.ones(2, 2); - - assertArrayEquals(z.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0); - assertArrayEquals(dz.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0); - assertArrayEquals(one.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0); - assertArrayEquals(done.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0); - } - - @Test - public void sparseDenseConversion() { - int m = 3; - int n = 2; - double[] values = new double[]{1.0, 2.0, 4.0, 5.0}; - double[] allValues = new double[]{1.0, 2.0, 0.0, 0.0, 4.0, 5.0}; - int[] colPtrs = new int[]{0, 2, 4}; - int[] rowIndices = new int[]{0, 1, 1, 2}; - - SparseMatrix spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values); - DenseMatrix deMat1 = new DenseMatrix(m, n, allValues); - - SparseMatrix spMat2 = deMat1.toSparse(); - DenseMatrix deMat2 = spMat1.toDense(); - - assertArrayEquals(spMat1.toArray(), spMat2.toArray(), 0.0); - assertArrayEquals(deMat1.toArray(), deMat2.toArray(), 0.0); - } - - @Test - public void concatenateMatrices() { - int m = 3; - int n = 2; - - Random rng = new Random(42); - SparseMatrix spMat1 = SparseMatrix.sprand(m, n, 0.5, rng); - rng.setSeed(42); - DenseMatrix deMat1 = DenseMatrix.rand(m, n, rng); - Matrix deMat2 = Matrices.eye(3); - Matrix spMat2 = Matrices.speye(3); - Matrix deMat3 = Matrices.eye(2); - Matrix spMat3 = Matrices.speye(2); - - Matrix spHorz = Matrices.horzcat(new Matrix[]{spMat1, spMat2}); - Matrix deHorz1 = Matrices.horzcat(new Matrix[]{deMat1, deMat2}); - Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2}); - Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2}); - - assertEquals(3, deHorz1.numRows()); - assertEquals(3, deHorz2.numRows()); - assertEquals(3, deHorz3.numRows()); - assertEquals(3, spHorz.numRows()); - assertEquals(5, deHorz1.numCols()); - assertEquals(5, deHorz2.numCols()); - assertEquals(5, deHorz3.numCols()); - assertEquals(5, spHorz.numCols()); - - Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3}); - Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3}); - Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3}); - Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3}); - - assertEquals(5, deVert1.numRows()); - assertEquals(5, deVert2.numRows()); - assertEquals(5, deVert3.numRows()); - assertEquals(5, spVert.numRows()); - assertEquals(2, deVert1.numCols()); - assertEquals(2, deVert2.numCols()); - assertEquals(2, deVert3.numCols()); - assertEquals(2, spVert.numCols()); - } + @Test + public void randMatrixConstruction() { + Random rng = new Random(24); + Matrix r = Matrices.rand(3, 4, rng); + rng.setSeed(24); + DenseMatrix dr = DenseMatrix.rand(3, 4, rng); + assertArrayEquals(r.toArray(), dr.toArray(), 0.0); + + rng.setSeed(24); + Matrix rn = Matrices.randn(3, 4, rng); + rng.setSeed(24); + DenseMatrix drn = DenseMatrix.randn(3, 4, rng); + assertArrayEquals(rn.toArray(), drn.toArray(), 0.0); + + rng.setSeed(24); + Matrix s = Matrices.sprand(3, 4, 0.5, rng); + rng.setSeed(24); + SparseMatrix sr = SparseMatrix.sprand(3, 4, 0.5, rng); + assertArrayEquals(s.toArray(), sr.toArray(), 0.0); + + rng.setSeed(24); + Matrix sn = Matrices.sprandn(3, 4, 0.5, rng); + rng.setSeed(24); + SparseMatrix srn = SparseMatrix.sprandn(3, 4, 0.5, rng); + assertArrayEquals(sn.toArray(), srn.toArray(), 0.0); + } + + @Test + public void identityMatrixConstruction() { + Matrix r = Matrices.eye(2); + DenseMatrix dr = DenseMatrix.eye(2); + SparseMatrix sr = SparseMatrix.speye(2); + assertArrayEquals(r.toArray(), dr.toArray(), 0.0); + assertArrayEquals(sr.toArray(), dr.toArray(), 0.0); + assertArrayEquals(r.toArray(), new double[]{1.0, 0.0, 0.0, 1.0}, 0.0); + } + + @Test + public void diagonalMatrixConstruction() { + Vector v = Vectors.dense(1.0, 0.0, 2.0); + Vector sv = Vectors.sparse(3, new int[]{0, 2}, new double[]{1.0, 2.0}); + + Matrix m = Matrices.diag(v); + Matrix sm = Matrices.diag(sv); + DenseMatrix d = DenseMatrix.diag(v); + DenseMatrix sd = DenseMatrix.diag(sv); + SparseMatrix s = SparseMatrix.spdiag(v); + SparseMatrix ss = SparseMatrix.spdiag(sv); + + assertArrayEquals(m.toArray(), sm.toArray(), 0.0); + assertArrayEquals(d.toArray(), sm.toArray(), 0.0); + assertArrayEquals(d.toArray(), sd.toArray(), 0.0); + assertArrayEquals(sd.toArray(), s.toArray(), 0.0); + assertArrayEquals(s.toArray(), ss.toArray(), 0.0); + assertArrayEquals(s.values(), ss.values(), 0.0); + assertEquals(2, s.values().length); + assertEquals(2, ss.values().length); + assertEquals(4, s.colPtrs().length); + assertEquals(4, ss.colPtrs().length); + } + + @Test + public void zerosMatrixConstruction() { + Matrix z = Matrices.zeros(2, 2); + Matrix one = Matrices.ones(2, 2); + DenseMatrix dz = DenseMatrix.zeros(2, 2); + DenseMatrix done = DenseMatrix.ones(2, 2); + + assertArrayEquals(z.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0); + assertArrayEquals(dz.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0); + assertArrayEquals(one.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0); + assertArrayEquals(done.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0); + } + + @Test + public void sparseDenseConversion() { + int m = 3; + int n = 2; + double[] values = new double[]{1.0, 2.0, 4.0, 5.0}; + double[] allValues = new double[]{1.0, 2.0, 0.0, 0.0, 4.0, 5.0}; + int[] colPtrs = new int[]{0, 2, 4}; + int[] rowIndices = new int[]{0, 1, 1, 2}; + + SparseMatrix spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values); + DenseMatrix deMat1 = new DenseMatrix(m, n, allValues); + + SparseMatrix spMat2 = deMat1.toSparse(); + DenseMatrix deMat2 = spMat1.toDense(); + + assertArrayEquals(spMat1.toArray(), spMat2.toArray(), 0.0); + assertArrayEquals(deMat1.toArray(), deMat2.toArray(), 0.0); + } + + @Test + public void concatenateMatrices() { + int m = 3; + int n = 2; + + Random rng = new Random(42); + SparseMatrix spMat1 = SparseMatrix.sprand(m, n, 0.5, rng); + rng.setSeed(42); + DenseMatrix deMat1 = DenseMatrix.rand(m, n, rng); + Matrix deMat2 = Matrices.eye(3); + Matrix spMat2 = Matrices.speye(3); + Matrix deMat3 = Matrices.eye(2); + Matrix spMat3 = Matrices.speye(2); + + Matrix spHorz = Matrices.horzcat(new Matrix[]{spMat1, spMat2}); + Matrix deHorz1 = Matrices.horzcat(new Matrix[]{deMat1, deMat2}); + Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2}); + Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2}); + + assertEquals(3, deHorz1.numRows()); + assertEquals(3, deHorz2.numRows()); + assertEquals(3, deHorz3.numRows()); + assertEquals(3, spHorz.numRows()); + assertEquals(5, deHorz1.numCols()); + assertEquals(5, deHorz2.numCols()); + assertEquals(5, deHorz3.numCols()); + assertEquals(5, spHorz.numCols()); + + Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3}); + Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3}); + Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3}); + Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3}); + + assertEquals(5, deVert1.numRows()); + assertEquals(5, deVert2.numRows()); + assertEquals(5, deVert3.numRows()); + assertEquals(5, spVert.numRows()); + assertEquals(2, deVert1.numCols()); + assertEquals(2, deVert2.numCols()); + assertEquals(2, deVert3.numCols()); + assertEquals(2, spVert.numCols()); + } } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java index 4ba8e54..817b962 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java @@ -20,10 +20,11 @@ package org.apache.spark.mllib.linalg; import java.io.Serializable; import java.util.Arrays; +import static org.junit.Assert.assertArrayEquals; + import scala.Tuple2; import org.junit.Test; -import static org.junit.Assert.*; public class JavaVectorsSuite implements Serializable { @@ -37,8 +38,8 @@ public class JavaVectorsSuite implements Serializable { public void sparseArrayConstruction() { @SuppressWarnings("unchecked") Vector v = Vectors.sparse(3, Arrays.asList( - new Tuple2<>(0, 2.0), - new Tuple2<>(2, 3.0))); + new Tuple2<>(0, 2.0), + new Tuple2<>(2, 3.0))); assertArrayEquals(new double[]{2.0, 0.0, 3.0}, v.toArray(), 0.0); } } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java index be58691..b449108 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java @@ -20,29 +20,35 @@ package org.apache.spark.mllib.random; import java.io.Serializable; import java.util.Arrays; -import org.apache.spark.api.java.JavaRDD; -import org.junit.Assert; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.sql.SparkSession; import static org.apache.spark.mllib.random.RandomRDDs.*; public class JavaRandomRDDsSuite { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaRandomRDDsSuite"); + spark = SparkSession.builder() + .master("local") + .appName("JavaRandomRDDsSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test @@ -50,10 +56,10 @@ public class JavaRandomRDDsSuite { long m = 1000L; int p = 2; long seed = 1L; - JavaDoubleRDD rdd1 = uniformJavaRDD(sc, m); - JavaDoubleRDD rdd2 = uniformJavaRDD(sc, m, p); - JavaDoubleRDD rdd3 = uniformJavaRDD(sc, m, p, seed); - for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaDoubleRDD rdd1 = uniformJavaRDD(jsc, m); + JavaDoubleRDD rdd2 = uniformJavaRDD(jsc, m, p); + JavaDoubleRDD rdd3 = uniformJavaRDD(jsc, m, p, seed); + for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -63,10 +69,10 @@ public class JavaRandomRDDsSuite { long m = 1000L; int p = 2; long seed = 1L; - JavaDoubleRDD rdd1 = normalJavaRDD(sc, m); - JavaDoubleRDD rdd2 = normalJavaRDD(sc, m, p); - JavaDoubleRDD rdd3 = normalJavaRDD(sc, m, p, seed); - for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaDoubleRDD rdd1 = normalJavaRDD(jsc, m); + JavaDoubleRDD rdd2 = normalJavaRDD(jsc, m, p); + JavaDoubleRDD rdd3 = normalJavaRDD(jsc, m, p, seed); + for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -78,10 +84,10 @@ public class JavaRandomRDDsSuite { long m = 1000L; int p = 2; long seed = 1L; - JavaDoubleRDD rdd1 = logNormalJavaRDD(sc, mean, std, m); - JavaDoubleRDD rdd2 = logNormalJavaRDD(sc, mean, std, m, p); - JavaDoubleRDD rdd3 = logNormalJavaRDD(sc, mean, std, m, p, seed); - for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaDoubleRDD rdd1 = logNormalJavaRDD(jsc, mean, std, m); + JavaDoubleRDD rdd2 = logNormalJavaRDD(jsc, mean, std, m, p); + JavaDoubleRDD rdd3 = logNormalJavaRDD(jsc, mean, std, m, p, seed); + for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -92,10 +98,10 @@ public class JavaRandomRDDsSuite { long m = 1000L; int p = 2; long seed = 1L; - JavaDoubleRDD rdd1 = poissonJavaRDD(sc, mean, m); - JavaDoubleRDD rdd2 = poissonJavaRDD(sc, mean, m, p); - JavaDoubleRDD rdd3 = poissonJavaRDD(sc, mean, m, p, seed); - for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaDoubleRDD rdd1 = poissonJavaRDD(jsc, mean, m); + JavaDoubleRDD rdd2 = poissonJavaRDD(jsc, mean, m, p); + JavaDoubleRDD rdd3 = poissonJavaRDD(jsc, mean, m, p, seed); + for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -106,10 +112,10 @@ public class JavaRandomRDDsSuite { long m = 1000L; int p = 2; long seed = 1L; - JavaDoubleRDD rdd1 = exponentialJavaRDD(sc, mean, m); - JavaDoubleRDD rdd2 = exponentialJavaRDD(sc, mean, m, p); - JavaDoubleRDD rdd3 = exponentialJavaRDD(sc, mean, m, p, seed); - for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaDoubleRDD rdd1 = exponentialJavaRDD(jsc, mean, m); + JavaDoubleRDD rdd2 = exponentialJavaRDD(jsc, mean, m, p); + JavaDoubleRDD rdd3 = exponentialJavaRDD(jsc, mean, m, p, seed); + for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -117,14 +123,14 @@ public class JavaRandomRDDsSuite { @Test public void testGammaRDD() { double shape = 1.0; - double scale = 2.0; + double jscale = 2.0; long m = 1000L; int p = 2; long seed = 1L; - JavaDoubleRDD rdd1 = gammaJavaRDD(sc, shape, scale, m); - JavaDoubleRDD rdd2 = gammaJavaRDD(sc, shape, scale, m, p); - JavaDoubleRDD rdd3 = gammaJavaRDD(sc, shape, scale, m, p, seed); - for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaDoubleRDD rdd1 = gammaJavaRDD(jsc, shape, jscale, m); + JavaDoubleRDD rdd2 = gammaJavaRDD(jsc, shape, jscale, m, p); + JavaDoubleRDD rdd3 = gammaJavaRDD(jsc, shape, jscale, m, p, seed); + for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -137,10 +143,10 @@ public class JavaRandomRDDsSuite { int n = 10; int p = 2; long seed = 1L; - JavaRDD<Vector> rdd1 = uniformJavaVectorRDD(sc, m, n); - JavaRDD<Vector> rdd2 = uniformJavaVectorRDD(sc, m, n, p); - JavaRDD<Vector> rdd3 = uniformJavaVectorRDD(sc, m, n, p, seed); - for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaRDD<Vector> rdd1 = uniformJavaVectorRDD(jsc, m, n); + JavaRDD<Vector> rdd2 = uniformJavaVectorRDD(jsc, m, n, p); + JavaRDD<Vector> rdd3 = uniformJavaVectorRDD(jsc, m, n, p, seed); + for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -153,10 +159,10 @@ public class JavaRandomRDDsSuite { int n = 10; int p = 2; long seed = 1L; - JavaRDD<Vector> rdd1 = normalJavaVectorRDD(sc, m, n); - JavaRDD<Vector> rdd2 = normalJavaVectorRDD(sc, m, n, p); - JavaRDD<Vector> rdd3 = normalJavaVectorRDD(sc, m, n, p, seed); - for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaRDD<Vector> rdd1 = normalJavaVectorRDD(jsc, m, n); + JavaRDD<Vector> rdd2 = normalJavaVectorRDD(jsc, m, n, p); + JavaRDD<Vector> rdd3 = normalJavaVectorRDD(jsc, m, n, p, seed); + for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -171,10 +177,10 @@ public class JavaRandomRDDsSuite { int n = 10; int p = 2; long seed = 1L; - JavaRDD<Vector> rdd1 = logNormalJavaVectorRDD(sc, mean, std, m, n); - JavaRDD<Vector> rdd2 = logNormalJavaVectorRDD(sc, mean, std, m, n, p); - JavaRDD<Vector> rdd3 = logNormalJavaVectorRDD(sc, mean, std, m, n, p, seed); - for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaRDD<Vector> rdd1 = logNormalJavaVectorRDD(jsc, mean, std, m, n); + JavaRDD<Vector> rdd2 = logNormalJavaVectorRDD(jsc, mean, std, m, n, p); + JavaRDD<Vector> rdd3 = logNormalJavaVectorRDD(jsc, mean, std, m, n, p, seed); + for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -188,10 +194,10 @@ public class JavaRandomRDDsSuite { int n = 10; int p = 2; long seed = 1L; - JavaRDD<Vector> rdd1 = poissonJavaVectorRDD(sc, mean, m, n); - JavaRDD<Vector> rdd2 = poissonJavaVectorRDD(sc, mean, m, n, p); - JavaRDD<Vector> rdd3 = poissonJavaVectorRDD(sc, mean, m, n, p, seed); - for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaRDD<Vector> rdd1 = poissonJavaVectorRDD(jsc, mean, m, n); + JavaRDD<Vector> rdd2 = poissonJavaVectorRDD(jsc, mean, m, n, p); + JavaRDD<Vector> rdd3 = poissonJavaVectorRDD(jsc, mean, m, n, p, seed); + for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -205,10 +211,10 @@ public class JavaRandomRDDsSuite { int n = 10; int p = 2; long seed = 1L; - JavaRDD<Vector> rdd1 = exponentialJavaVectorRDD(sc, mean, m, n); - JavaRDD<Vector> rdd2 = exponentialJavaVectorRDD(sc, mean, m, n, p); - JavaRDD<Vector> rdd3 = exponentialJavaVectorRDD(sc, mean, m, n, p, seed); - for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaRDD<Vector> rdd1 = exponentialJavaVectorRDD(jsc, mean, m, n); + JavaRDD<Vector> rdd2 = exponentialJavaVectorRDD(jsc, mean, m, n, p); + JavaRDD<Vector> rdd3 = exponentialJavaVectorRDD(jsc, mean, m, n, p, seed); + for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -218,15 +224,15 @@ public class JavaRandomRDDsSuite { @SuppressWarnings("unchecked") public void testGammaVectorRDD() { double shape = 1.0; - double scale = 2.0; + double jscale = 2.0; long m = 100L; int n = 10; int p = 2; long seed = 1L; - JavaRDD<Vector> rdd1 = gammaJavaVectorRDD(sc, shape, scale, m, n); - JavaRDD<Vector> rdd2 = gammaJavaVectorRDD(sc, shape, scale, m, n, p); - JavaRDD<Vector> rdd3 = gammaJavaVectorRDD(sc, shape, scale, m, n, p, seed); - for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaRDD<Vector> rdd1 = gammaJavaVectorRDD(jsc, shape, jscale, m, n); + JavaRDD<Vector> rdd2 = gammaJavaVectorRDD(jsc, shape, jscale, m, n, p); + JavaRDD<Vector> rdd3 = gammaJavaVectorRDD(jsc, shape, jscale, m, n, p, seed); + for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -238,10 +244,10 @@ public class JavaRandomRDDsSuite { long seed = 1L; int numPartitions = 0; StringGenerator gen = new StringGenerator(); - JavaRDD<String> rdd1 = randomJavaRDD(sc, gen, size); - JavaRDD<String> rdd2 = randomJavaRDD(sc, gen, size, numPartitions); - JavaRDD<String> rdd3 = randomJavaRDD(sc, gen, size, numPartitions, seed); - for (JavaRDD<String> rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaRDD<String> rdd1 = randomJavaRDD(jsc, gen, size); + JavaRDD<String> rdd2 = randomJavaRDD(jsc, gen, size, numPartitions); + JavaRDD<String> rdd3 = randomJavaRDD(jsc, gen, size, numPartitions, seed); + for (JavaRDD<String> rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(size, rdd.count()); Assert.assertEquals(2, rdd.first().length()); } @@ -255,10 +261,10 @@ public class JavaRandomRDDsSuite { int n = 10; int p = 2; long seed = 1L; - JavaRDD<Vector> rdd1 = randomJavaVectorRDD(sc, generator, m, n); - JavaRDD<Vector> rdd2 = randomJavaVectorRDD(sc, generator, m, n, p); - JavaRDD<Vector> rdd3 = randomJavaVectorRDD(sc, generator, m, n, p, seed); - for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaRDD<Vector> rdd1 = randomJavaVectorRDD(jsc, generator, m, n); + JavaRDD<Vector> rdd2 = randomJavaVectorRDD(jsc, generator, m, n, p); + JavaRDD<Vector> rdd3 = randomJavaVectorRDD(jsc, generator, m, n, p, seed); + for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -271,10 +277,12 @@ class StringGenerator implements RandomDataGenerator<String>, Serializable { public String nextValue() { return "42"; } + @Override public StringGenerator copy() { return new StringGenerator(); } + @Override public void setSeed(long seed) { } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java index d0bf7f5..aa78405 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java @@ -32,40 +32,46 @@ import org.junit.Test; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SparkSession; public class JavaALSSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaALS"); + spark = SparkSession.builder() + .master("local") + .appName("JavaALS") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } private void validatePrediction( - MatrixFactorizationModel model, - int users, - int products, - double[] trueRatings, - double matchThreshold, - boolean implicitPrefs, - double[] truePrefs) { + MatrixFactorizationModel model, + int users, + int products, + double[] trueRatings, + double matchThreshold, + boolean implicitPrefs, + double[] truePrefs) { List<Tuple2<Integer, Integer>> localUsersProducts = new ArrayList<>(users * products); - for (int u=0; u < users; ++u) { - for (int p=0; p < products; ++p) { + for (int u = 0; u < users; ++u) { + for (int p = 0; p < products; ++p) { localUsersProducts.add(new Tuple2<>(u, p)); } } - JavaPairRDD<Integer, Integer> usersProducts = sc.parallelizePairs(localUsersProducts); + JavaPairRDD<Integer, Integer> usersProducts = jsc.parallelizePairs(localUsersProducts); List<Rating> predictedRatings = model.predict(usersProducts).collect(); Assert.assertEquals(users * products, predictedRatings.size()); if (!implicitPrefs) { - for (Rating r: predictedRatings) { + for (Rating r : predictedRatings) { double prediction = r.rating(); double correct = trueRatings[r.product() * users + r.user()]; Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f", @@ -76,7 +82,7 @@ public class JavaALSSuite implements Serializable { // (ref Mahout's implicit ALS tests) double sqErr = 0.0; double denom = 0.0; - for (Rating r: predictedRatings) { + for (Rating r : predictedRatings) { double prediction = r.rating(); double truePref = truePrefs[r.product() * users + r.user()]; double confidence = 1.0 + @@ -98,9 +104,9 @@ public class JavaALSSuite implements Serializable { int users = 50; int products = 100; Tuple3<List<Rating>, double[], double[]> testData = - ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false); + ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false); - JavaRDD<Rating> data = sc.parallelize(testData._1()); + JavaRDD<Rating> data = jsc.parallelize(testData._1()); MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations); validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3()); } @@ -112,9 +118,9 @@ public class JavaALSSuite implements Serializable { int users = 100; int products = 200; Tuple3<List<Rating>, double[], double[]> testData = - ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false); + ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false); - JavaRDD<Rating> data = sc.parallelize(testData._1()); + JavaRDD<Rating> data = jsc.parallelize(testData._1()); MatrixFactorizationModel model = new ALS().setRank(features) .setIterations(iterations) @@ -129,9 +135,9 @@ public class JavaALSSuite implements Serializable { int users = 80; int products = 160; Tuple3<List<Rating>, double[], double[]> testData = - ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false); + ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false); - JavaRDD<Rating> data = sc.parallelize(testData._1()); + JavaRDD<Rating> data = jsc.parallelize(testData._1()); MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations); validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3()); } @@ -143,9 +149,9 @@ public class JavaALSSuite implements Serializable { int users = 100; int products = 200; Tuple3<List<Rating>, double[], double[]> testData = - ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false); + ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false); - JavaRDD<Rating> data = sc.parallelize(testData._1()); + JavaRDD<Rating> data = jsc.parallelize(testData._1()); MatrixFactorizationModel model = new ALS().setRank(features) .setIterations(iterations) @@ -161,9 +167,9 @@ public class JavaALSSuite implements Serializable { int users = 80; int products = 160; Tuple3<List<Rating>, double[], double[]> testData = - ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, true); + ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, true); - JavaRDD<Rating> data = sc.parallelize(testData._1()); + JavaRDD<Rating> data = jsc.parallelize(testData._1()); MatrixFactorizationModel model = new ALS().setRank(features) .setIterations(iterations) .setImplicitPrefs(true) @@ -179,8 +185,8 @@ public class JavaALSSuite implements Serializable { int users = 200; int products = 50; List<Rating> testData = ALSSuite.generateRatingsAsJava( - users, products, features, 0.7, true, false)._1(); - JavaRDD<Rating> data = sc.parallelize(testData); + users, products, features, 0.7, true, false)._1(); + JavaRDD<Rating> data = jsc.parallelize(testData); MatrixFactorizationModel model = new ALS().setRank(features) .setIterations(iterations) .setImplicitPrefs(true) @@ -193,7 +199,7 @@ public class JavaALSSuite implements Serializable { private static void validateRecommendations(Rating[] recommendations, int howMany) { Assert.assertEquals(howMany, recommendations.length); for (int i = 1; i < recommendations.length; i++) { - Assert.assertTrue(recommendations[i-1].rating() >= recommendations[i].rating()); + Assert.assertTrue(recommendations[i - 1].rating() >= recommendations[i].rating()); } Assert.assertTrue(recommendations[0].rating() > 0.7); } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org