Repository: spark Updated Branches: refs/heads/branch-1.4 343d3bfaf -> 893b3103f
http://git-wip-us.apache.org/repos/asf/spark/blob/893b3103/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java index 0cc36c8..a82b86d 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java @@ -23,14 +23,15 @@ import java.util.List; import org.junit.After; import org.junit.Before; import org.junit.Test; +import static org.junit.Assert.assertEquals; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import static org.apache.spark.mllib.classification.LogisticRegressionSuite - .generateLogisticInputAsList; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite + .generateLogisticInputAsList; public class JavaLinearRegressionSuite implements Serializable { @@ -65,8 +66,8 @@ public class JavaLinearRegressionSuite implements Serializable { DataFrame predictions = jsql.sql("SELECT label, prediction FROM prediction"); predictions.collect(); // Check defaults - assert(model.getFeaturesCol().equals("features")); - assert(model.getPredictionCol().equals("prediction")); + assertEquals("features", model.getFeaturesCol()); + assertEquals("prediction", model.getPredictionCol()); } @Test @@ -76,14 +77,16 @@ public class JavaLinearRegressionSuite implements Serializable { .setMaxIter(10) .setRegParam(1.0); LinearRegressionModel model = lr.fit(dataset); - assert(model.fittingParamMap().apply(lr.maxIter()).equals(10)); - assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0)); + LinearRegression parent = model.parent(); + assertEquals(10, parent.getMaxIter()); + assertEquals(1.0, parent.getRegParam(), 0.0); // Call fit() with new params, and check as many params as we can. LinearRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred")); - assert(model2.fittingParamMap().apply(lr.maxIter()).equals(5)); - assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1)); - assert(model2.getPredictionCol().equals("thePred")); + LinearRegression parent2 = model2.parent(); + assertEquals(5, parent2.getMaxIter()); + assertEquals(0.1, parent2.getRegParam(), 0.0); + assertEquals("thePred", model2.getPredictionCol()); } } http://git-wip-us.apache.org/repos/asf/spark/blob/893b3103/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java index 0bb6b48..08eeca5 100644 --- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java @@ -68,8 +68,8 @@ public class JavaCrossValidatorSuite implements Serializable { .setEvaluator(eval) .setNumFolds(3); CrossValidatorModel cvModel = cv.fit(dataset); - ParamMap bestParamMap = cvModel.bestModel().fittingParamMap(); - Assert.assertEquals(0.001, bestParamMap.apply(lr.regParam())); - Assert.assertEquals(10, bestParamMap.apply(lr.maxIter())); + LogisticRegression parent = (LogisticRegression) cvModel.bestModel().parent(); + Assert.assertEquals(0.001, parent.getRegParam(), 0.0); + Assert.assertEquals(10, parent.getMaxIter()); } } http://git-wip-us.apache.org/repos/asf/spark/blob/893b3103/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 2f175fb..2b04a30 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -42,30 +42,32 @@ class PipelineSuite extends FunSuite { val dataset3 = mock[DataFrame] val dataset4 = mock[DataFrame] - when(estimator0.fit(meq(dataset0), any[ParamMap]())).thenReturn(model0) - when(model0.transform(meq(dataset0), any[ParamMap]())).thenReturn(dataset1) + when(estimator0.copy(any[ParamMap])).thenReturn(estimator0) + when(model0.copy(any[ParamMap])).thenReturn(model0) + when(transformer1.copy(any[ParamMap])).thenReturn(transformer1) + when(estimator2.copy(any[ParamMap])).thenReturn(estimator2) + when(model2.copy(any[ParamMap])).thenReturn(model2) + when(transformer3.copy(any[ParamMap])).thenReturn(transformer3) + + when(estimator0.fit(meq(dataset0))).thenReturn(model0) + when(model0.transform(meq(dataset0))).thenReturn(dataset1) when(model0.parent).thenReturn(estimator0) - when(transformer1.transform(meq(dataset1), any[ParamMap])).thenReturn(dataset2) - when(estimator2.fit(meq(dataset2), any[ParamMap]())).thenReturn(model2) - when(model2.transform(meq(dataset2), any[ParamMap]())).thenReturn(dataset3) + when(transformer1.transform(meq(dataset1))).thenReturn(dataset2) + when(estimator2.fit(meq(dataset2))).thenReturn(model2) + when(model2.transform(meq(dataset2))).thenReturn(dataset3) when(model2.parent).thenReturn(estimator2) - when(transformer3.transform(meq(dataset3), any[ParamMap]())).thenReturn(dataset4) + when(transformer3.transform(meq(dataset3))).thenReturn(dataset4) val pipeline = new Pipeline() .setStages(Array(estimator0, transformer1, estimator2, transformer3)) val pipelineModel = pipeline.fit(dataset0) - assert(pipelineModel.stages.size === 4) + assert(pipelineModel.stages.length === 4) assert(pipelineModel.stages(0).eq(model0)) assert(pipelineModel.stages(1).eq(transformer1)) assert(pipelineModel.stages(2).eq(model2)) assert(pipelineModel.stages(3).eq(transformer3)) - assert(pipelineModel.getModel(estimator0).eq(model0)) - assert(pipelineModel.getModel(estimator2).eq(model2)) - intercept[NoSuchElementException] { - pipelineModel.getModel(mock[Estimator[MyModel]]) - } val output = pipelineModel.transform(dataset0) assert(output.eq(dataset4)) } http://git-wip-us.apache.org/repos/asf/spark/blob/893b3103/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 9b31ade..03af4ec 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -267,8 +267,8 @@ private[ml] object DecisionTreeClassifierSuite extends FunSuite { val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) val newTree = dt.fit(newData) // Use parent, fittingParamMap from newTree since these are not checked anyways. - val oldTreeAsNew = DecisionTreeClassificationModel.fromOld(oldTree, newTree.parent, - newTree.fittingParamMap, categoricalFeatures) + val oldTreeAsNew = DecisionTreeClassificationModel.fromOld( + oldTree, newTree.parent, categoricalFeatures) TreeTests.checkEqual(oldTreeAsNew, newTree) } } http://git-wip-us.apache.org/repos/asf/spark/blob/893b3103/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index e6ccc2c..16c758b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -129,8 +129,8 @@ private object GBTClassifierSuite { val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2) val newModel = gbt.fit(newData) // Use parent, fittingParamMap from newTree since these are not checked anyways. - val oldModelAsNew = GBTClassificationModel.fromOld(oldModel, newModel.parent, - newModel.fittingParamMap, categoricalFeatures) + val oldModelAsNew = GBTClassificationModel.fromOld( + oldModel, newModel.parent, categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) } } http://git-wip-us.apache.org/repos/asf/spark/blob/893b3103/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 35d8c2e..6dd1fdf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -74,9 +74,10 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { .setThreshold(0.6) .setProbabilityCol("myProbability") val model = lr.fit(dataset) - assert(model.fittingParamMap.get(lr.maxIter) === Some(10)) - assert(model.fittingParamMap.get(lr.regParam) === Some(1.0)) - assert(model.fittingParamMap.get(lr.threshold) === Some(0.6)) + val parent = model.parent + assert(parent.getMaxIter === 10) + assert(parent.getRegParam === 1.0) + assert(parent.getThreshold === 0.6) assert(model.getThreshold === 0.6) // Modify model params, and check that the params worked. @@ -99,9 +100,10 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { // Call fit() with new params, and check as many params as we can. val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.threshold -> 0.4, lr.probabilityCol -> "theProb") - assert(model2.fittingParamMap.get(lr.maxIter).get === 5) - assert(model2.fittingParamMap.get(lr.regParam).get === 0.1) - assert(model2.fittingParamMap.get(lr.threshold).get === 0.4) + val parent2 = model2.parent + assert(parent2.getMaxIter === 5) + assert(parent2.getRegParam === 0.1) + assert(parent2.getThreshold === 0.4) assert(model2.getThreshold === 0.4) assert(model2.getProbabilityCol == "theProb") } @@ -117,7 +119,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { val results = model.transform(dataset) // Compare rawPrediction with probability - results.select("rawPrediction", "probability").collect().map { + results.select("rawPrediction", "probability").collect().foreach { case Row(raw: Vector, prob: Vector) => assert(raw.size === 2) assert(prob.size === 2) @@ -127,7 +129,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { } // Compare prediction with probability - results.select("prediction", "probability").collect().map { + results.select("prediction", "probability").collect().foreach { case Row(pred: Double, prob: Vector) => val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2 assert(pred == predFromProb) http://git-wip-us.apache.org/repos/asf/spark/blob/893b3103/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index ed41a96..c41def9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -159,8 +159,8 @@ private object RandomForestClassifierSuite { val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) val newModel = rf.fit(newData) // Use parent, fittingParamMap from newTree since these are not checked anyways. - val oldModelAsNew = RandomForestClassificationModel.fromOld(oldModel, newModel.parent, - newModel.fittingParamMap, categoricalFeatures) + val oldModelAsNew = RandomForestClassificationModel.fromOld( + oldModel, newModel.parent, categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) } } http://git-wip-us.apache.org/repos/asf/spark/blob/893b3103/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index f885260..6056e7d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -122,19 +122,21 @@ class ParamsSuite extends FunSuite { assert(solver.getParam("inputCol").eq(inputCol)) assert(solver.getParam("maxIter").eq(maxIter)) + assert(solver.hasParam("inputCol")) + assert(!solver.hasParam("abc")) intercept[NoSuchElementException] { solver.getParam("abc") } intercept[IllegalArgumentException] { - solver.validate() + solver.validateParams() } - solver.validate(ParamMap(inputCol -> "input")) + solver.validateParams(ParamMap(inputCol -> "input")) solver.setInputCol("input") assert(solver.isSet(inputCol)) assert(solver.isDefined(inputCol)) assert(solver.getInputCol === "input") - solver.validate() + solver.validateParams() intercept[IllegalArgumentException] { ParamMap(maxIter -> -10) } @@ -144,6 +146,11 @@ class ParamsSuite extends FunSuite { solver.clearMaxIter() assert(!solver.isSet(maxIter)) + + val copied = solver.copy(ParamMap(solver.maxIter -> 50)) + assert(copied.uid !== solver.uid) + assert(copied.getInputCol === solver.getInputCol) + assert(copied.getMaxIter === 50) } test("ParamValidate") { http://git-wip-us.apache.org/repos/asf/spark/blob/893b3103/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala index 6f9c9cb..dc16073 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -23,15 +23,19 @@ import org.apache.spark.ml.param.shared.{HasInputCol, HasMaxIter} class TestParams extends Params with HasMaxIter with HasInputCol { def setMaxIter(value: Int): this.type = { set(maxIter, value); this } + def setInputCol(value: String): this.type = { set(inputCol, value); this } setDefault(maxIter -> 10) - override def validate(paramMap: ParamMap): Unit = { - val m = extractParamMap(paramMap) - // Note: maxIter is validated when it is set. - require(m.contains(inputCol)) + def clearMaxIter(): this.type = clear(maxIter) + + override def validateParams(): Unit = { + super.validateParams() + require(isDefined(inputCol)) } - def clearMaxIter(): this.type = clear(maxIter) + override def copy(extra: ParamMap): TestParams = { + super.copy(extra).asInstanceOf[TestParams] + } } http://git-wip-us.apache.org/repos/asf/spark/blob/893b3103/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index c87a171..5aa81b4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -84,8 +84,8 @@ private[ml] object DecisionTreeRegressorSuite extends FunSuite { val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) val newTree = dt.fit(newData) // Use parent, fittingParamMap from newTree since these are not checked anyways. - val oldTreeAsNew = DecisionTreeRegressionModel.fromOld(oldTree, newTree.parent, - newTree.fittingParamMap, categoricalFeatures) + val oldTreeAsNew = DecisionTreeRegressionModel.fromOld( + oldTree, newTree.parent, categoricalFeatures) TreeTests.checkEqual(oldTreeAsNew, newTree) } } http://git-wip-us.apache.org/repos/asf/spark/blob/893b3103/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 4aec369..25b36ab 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -130,8 +130,7 @@ private object GBTRegressorSuite { val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) val newModel = gbt.fit(newData) // Use parent, fittingParamMap from newTree since these are not checked anyways. - val oldModelAsNew = GBTRegressionModel.fromOld(oldModel, newModel.parent, - newModel.fittingParamMap, categoricalFeatures) + val oldModelAsNew = GBTRegressionModel.fromOld(oldModel, newModel.parent, categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) } } http://git-wip-us.apache.org/repos/asf/spark/blob/893b3103/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index c6dc1cc..45f09f4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -115,8 +115,8 @@ private object RandomForestRegressorSuite extends FunSuite { val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) val newModel = rf.fit(newData) // Use parent, fittingParamMap from newTree since these are not checked anyways. - val oldModelAsNew = RandomForestRegressionModel.fromOld(oldModel, newModel.parent, - newModel.fittingParamMap, categoricalFeatures) + val oldModelAsNew = RandomForestRegressionModel.fromOld( + oldModel, newModel.parent, categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) } } http://git-wip-us.apache.org/repos/asf/spark/blob/893b3103/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 761ea82..05313d4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -49,8 +49,8 @@ class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext { .setEvaluator(eval) .setNumFolds(3) val cvModel = cv.fit(dataset) - val bestParamMap = cvModel.bestModel.fittingParamMap - assert(bestParamMap(lr.regParam) === 0.001) - assert(bestParamMap(lr.maxIter) === 10) + val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] + assert(parent.getRegParam === 0.001) + assert(parent.getMaxIter === 10) } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org