Repository: spark Updated Branches: refs/heads/master 84324fbcb -> 4b736dbab
http://git-wip-us.apache.org/repos/asf/spark/blob/4b736dba/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java new file mode 100644 index 0000000..4284667 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.feature.StandardScaler; +import org.apache.spark.sql.api.java.JavaSQLContext; +import org.apache.spark.sql.api.java.JavaSchemaRDD; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite + .generateLogisticInputAsList; + +/** + * Test Pipeline construction and fitting in Java. + */ +public class JavaPipelineSuite { + + private transient JavaSparkContext jsc; + private transient JavaSQLContext jsql; + private transient JavaSchemaRDD dataset; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaPipelineSuite"); + jsql = new JavaSQLContext(jsc); + JavaRDD<LabeledPoint> points = + jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2); + dataset = jsql.applySchema(points, LabeledPoint.class); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void pipeline() { + StandardScaler scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures"); + LogisticRegression lr = new LogisticRegression() + .setFeaturesCol("scaledFeatures"); + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {scaler, lr}); + PipelineModel model = pipeline.fit(dataset); + model.transform(dataset).registerTempTable("prediction"); + JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + predictions.collect(); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/4b736dba/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java new file mode 100644 index 0000000..76eb7f0 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification; + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.api.java.JavaSQLContext; +import org.apache.spark.sql.api.java.JavaSchemaRDD; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite + .generateLogisticInputAsList; + +public class JavaLogisticRegressionSuite implements Serializable { + + private transient JavaSparkContext jsc; + private transient JavaSQLContext jsql; + private transient JavaSchemaRDD dataset; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); + jsql = new JavaSQLContext(jsc); + List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42); + dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void logisticRegression() { + LogisticRegression lr = new LogisticRegression(); + LogisticRegressionModel model = lr.fit(dataset); + model.transform(dataset).registerTempTable("prediction"); + JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + predictions.collect(); + } + + @Test + public void logisticRegressionWithSetters() { + LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(1.0); + LogisticRegressionModel model = lr.fit(dataset); + model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold + .registerTempTable("prediction"); + JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + predictions.collect(); + } + + @Test + public void logisticRegressionFitWithVarargs() { + LogisticRegression lr = new LogisticRegression(); + lr.fit(dataset, lr.maxIter().w(10), lr.regParam().w(1.0)); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/4b736dba/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java new file mode 100644 index 0000000..a266ebd --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tuning; + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.api.java.JavaSQLContext; +import org.apache.spark.sql.api.java.JavaSchemaRDD; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite + .generateLogisticInputAsList; + +public class JavaCrossValidatorSuite implements Serializable { + + private transient JavaSparkContext jsc; + private transient JavaSQLContext jsql; + private transient JavaSchemaRDD dataset; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaCrossValidatorSuite"); + jsql = new JavaSQLContext(jsc); + List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42); + dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void crossValidationWithLogisticRegression() { + LogisticRegression lr = new LogisticRegression(); + ParamMap[] lrParamMaps = new ParamGridBuilder() + .addGrid(lr.regParam(), new double[] {0.001, 1000.0}) + .addGrid(lr.maxIter(), new int[] {0, 10}) + .build(); + BinaryClassificationEvaluator eval = new BinaryClassificationEvaluator(); + CrossValidator cv = new CrossValidator() + .setEstimator(lr) + .setEstimatorParamMaps(lrParamMaps) + .setEvaluator(eval) + .setNumFolds(3); + CrossValidatorModel cvModel = cv.fit(dataset); + ParamMap bestParamMap = cvModel.bestModel().fittingParamMap(); + Assert.assertEquals(0.001, bestParamMap.apply(lr.regParam())); + Assert.assertEquals(10, bestParamMap.apply(lr.maxIter())); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/4b736dba/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala new file mode 100644 index 0000000..4515084 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import org.mockito.Matchers.{any, eq => meq} +import org.mockito.Mockito.when +import org.scalatest.FunSuite +import org.scalatest.mock.MockitoSugar.mock + +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.sql.SchemaRDD + +class PipelineSuite extends FunSuite { + + abstract class MyModel extends Model[MyModel] + + test("pipeline") { + val estimator0 = mock[Estimator[MyModel]] + val model0 = mock[MyModel] + val transformer1 = mock[Transformer] + val estimator2 = mock[Estimator[MyModel]] + val model2 = mock[MyModel] + val transformer3 = mock[Transformer] + val dataset0 = mock[SchemaRDD] + val dataset1 = mock[SchemaRDD] + val dataset2 = mock[SchemaRDD] + val dataset3 = mock[SchemaRDD] + val dataset4 = mock[SchemaRDD] + + when(estimator0.fit(meq(dataset0), any[ParamMap]())).thenReturn(model0) + when(model0.transform(meq(dataset0), any[ParamMap]())).thenReturn(dataset1) + when(model0.parent).thenReturn(estimator0) + when(transformer1.transform(meq(dataset1), any[ParamMap])).thenReturn(dataset2) + when(estimator2.fit(meq(dataset2), any[ParamMap]())).thenReturn(model2) + when(model2.transform(meq(dataset2), any[ParamMap]())).thenReturn(dataset3) + when(model2.parent).thenReturn(estimator2) + when(transformer3.transform(meq(dataset3), any[ParamMap]())).thenReturn(dataset4) + + val pipeline = new Pipeline() + .setStages(Array(estimator0, transformer1, estimator2, transformer3)) + val pipelineModel = pipeline.fit(dataset0) + + assert(pipelineModel.stages.size === 4) + assert(pipelineModel.stages(0).eq(model0)) + assert(pipelineModel.stages(1).eq(transformer1)) + assert(pipelineModel.stages(2).eq(model2)) + assert(pipelineModel.stages(3).eq(transformer3)) + + assert(pipelineModel.getModel(estimator0).eq(model0)) + assert(pipelineModel.getModel(estimator2).eq(model2)) + intercept[NoSuchElementException] { + pipelineModel.getModel(mock[Estimator[MyModel]]) + } + val output = pipelineModel.transform(dataset0) + assert(output.eq(dataset4)) + } + + test("pipeline with duplicate stages") { + val estimator = mock[Estimator[MyModel]] + val pipeline = new Pipeline() + .setStages(Array(estimator, estimator)) + val dataset = mock[SchemaRDD] + intercept[IllegalArgumentException] { + pipeline.fit(dataset) + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/4b736dba/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala new file mode 100644 index 0000000..625af29 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.sql.SchemaRDD + +class LogisticRegressionSuite extends FunSuite with LocalSparkContext { + + import sqlContext._ + + val dataset: SchemaRDD = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2) + + test("logistic regression") { + val lr = new LogisticRegression + val model = lr.fit(dataset) + model.transform(dataset) + .select('label, 'prediction) + .collect() + } + + test("logistic regression with setters") { + val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(1.0) + val model = lr.fit(dataset) + model.transform(dataset, model.threshold -> 0.8) // overwrite threshold + .select('label, 'score, 'prediction) + .collect() + } + + test("logistic regression fit and transform with varargs") { + val lr = new LogisticRegression + val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0) + model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability") + .select('label, 'probability, 'prediction) + .collect() + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/4b736dba/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala new file mode 100644 index 0000000..1ce2987 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param + +import org.scalatest.FunSuite + +class ParamsSuite extends FunSuite { + + val solver = new TestParams() + import solver.{inputCol, maxIter} + + test("param") { + assert(maxIter.name === "maxIter") + assert(maxIter.doc === "max number of iterations") + assert(maxIter.defaultValue.get === 100) + assert(maxIter.parent.eq(solver)) + assert(maxIter.toString === "maxIter: max number of iterations (default: 100)") + assert(inputCol.defaultValue === None) + } + + test("param pair") { + val pair0 = maxIter -> 5 + val pair1 = maxIter.w(5) + val pair2 = ParamPair(maxIter, 5) + for (pair <- Seq(pair0, pair1, pair2)) { + assert(pair.param.eq(maxIter)) + assert(pair.value === 5) + } + } + + test("param map") { + val map0 = ParamMap.empty + + assert(!map0.contains(maxIter)) + assert(map0(maxIter) === maxIter.defaultValue.get) + map0.put(maxIter, 10) + assert(map0.contains(maxIter)) + assert(map0(maxIter) === 10) + + assert(!map0.contains(inputCol)) + intercept[NoSuchElementException] { + map0(inputCol) + } + map0.put(inputCol -> "input") + assert(map0.contains(inputCol)) + assert(map0(inputCol) === "input") + + val map1 = map0.copy + val map2 = ParamMap(maxIter -> 10, inputCol -> "input") + val map3 = new ParamMap() + .put(maxIter, 10) + .put(inputCol, "input") + val map4 = ParamMap.empty ++ map0 + val map5 = ParamMap.empty + map5 ++= map0 + + for (m <- Seq(map1, map2, map3, map4, map5)) { + assert(m.contains(maxIter)) + assert(m(maxIter) === 10) + assert(m.contains(inputCol)) + assert(m(inputCol) === "input") + } + } + + test("params") { + val params = solver.params + assert(params.size === 2) + assert(params(0).eq(inputCol), "params must be ordered by name") + assert(params(1).eq(maxIter)) + assert(solver.explainParams() === Seq(inputCol, maxIter).mkString("\n")) + assert(solver.getParam("inputCol").eq(inputCol)) + assert(solver.getParam("maxIter").eq(maxIter)) + intercept[NoSuchMethodException] { + solver.getParam("abc") + } + assert(!solver.isSet(inputCol)) + intercept[IllegalArgumentException] { + solver.validate() + } + solver.validate(ParamMap(inputCol -> "input")) + solver.setInputCol("input") + assert(solver.isSet(inputCol)) + assert(solver.getInputCol === "input") + solver.validate() + intercept[IllegalArgumentException] { + solver.validate(ParamMap(maxIter -> -10)) + } + solver.setMaxIter(-10) + intercept[IllegalArgumentException] { + solver.validate() + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/4b736dba/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala new file mode 100644 index 0000000..1a65883 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param + +/** A subclass of Params for testing. */ +class TestParams extends Params { + + val maxIter = new IntParam(this, "maxIter", "max number of iterations", Some(100)) + def setMaxIter(value: Int): this.type = { set(maxIter, value); this } + def getMaxIter: Int = get(maxIter) + + val inputCol = new Param[String](this, "inputCol", "input column name") + def setInputCol(value: String): this.type = { set(inputCol, value); this } + def getInputCol: String = get(inputCol) + + override def validate(paramMap: ParamMap) = { + val m = this.paramMap ++ paramMap + require(m(maxIter) >= 0) + require(m.contains(inputCol)) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/4b736dba/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala new file mode 100644 index 0000000..72a334a --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tuning + +import org.scalatest.FunSuite + +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator +import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.sql.SchemaRDD + +class CrossValidatorSuite extends FunSuite with LocalSparkContext { + + import sqlContext._ + + val dataset: SchemaRDD = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2) + + test("cross validation with logistic regression") { + val lr = new LogisticRegression + val lrParamMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.001, 1000.0)) + .addGrid(lr.maxIter, Array(0, 10)) + .build() + val eval = new BinaryClassificationEvaluator + val cv = new CrossValidator() + .setEstimator(lr) + .setEstimatorParamMaps(lrParamMaps) + .setEvaluator(eval) + .setNumFolds(3) + val cvModel = cv.fit(dataset) + val bestParamMap = cvModel.bestModel.fittingParamMap + assert(bestParamMap(lr.regParam) === 0.001) + assert(bestParamMap(lr.maxIter) === 10) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/4b736dba/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala new file mode 100644 index 0000000..20aa100 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tuning + +import scala.collection.mutable + +import org.scalatest.FunSuite + +import org.apache.spark.ml.param.{ParamMap, TestParams} + +class ParamGridBuilderSuite extends FunSuite { + + val solver = new TestParams() + import solver.{inputCol, maxIter} + + test("param grid builder") { + def validateGrid(maps: Array[ParamMap], expected: mutable.Set[(Int, String)]): Unit = { + assert(maps.size === expected.size) + maps.foreach { m => + val tuple = (m(maxIter), m(inputCol)) + assert(expected.contains(tuple)) + expected.remove(tuple) + } + assert(expected.isEmpty) + } + + val maps0 = new ParamGridBuilder() + .baseOn(maxIter -> 10) + .addGrid(inputCol, Array("input0", "input1")) + .build() + val expected0 = mutable.Set( + (10, "input0"), + (10, "input1")) + validateGrid(maps0, expected0) + + val maps1 = new ParamGridBuilder() + .baseOn(ParamMap(maxIter -> 5, inputCol -> "input")) // will be overwritten + .addGrid(maxIter, Array(10, 20)) + .addGrid(inputCol, Array("input0", "input1")) + .build() + val expected1 = mutable.Set( + (10, "input0"), + (20, "input0"), + (10, "input1"), + (20, "input1")) + validateGrid(maps1, expected1) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/4b736dba/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala index 7857d9e..4417d66 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala @@ -17,26 +17,17 @@ package org.apache.spark.mllib.util -import org.scalatest.Suite -import org.scalatest.BeforeAndAfterAll +import org.scalatest.{BeforeAndAfterAll, Suite} -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.SparkContext +import org.apache.spark.sql.SQLContext trait LocalSparkContext extends BeforeAndAfterAll { self: Suite => - @transient var sc: SparkContext = _ - - override def beforeAll() { - val conf = new SparkConf() - .setMaster("local") - .setAppName("test") - sc = new SparkContext(conf) - super.beforeAll() - } + @transient val sc = new SparkContext("local", "test") + @transient lazy val sqlContext = new SQLContext(sc) override def afterAll() { - if (sc != null) { - sc.stop() - } + sc.stop() super.afterAll() } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org