Repository: spark
Updated Branches:
  refs/heads/master 5a1a1075a -> e0833c595


http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
index 0cc36c8..a82b86d 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
@@ -23,14 +23,15 @@ import java.util.List;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
+import static org.junit.Assert.assertEquals;
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-import static org.apache.spark.mllib.classification.LogisticRegressionSuite
-    .generateLogisticInputAsList;
 import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.sql.DataFrame;
 import org.apache.spark.sql.SQLContext;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite
+  .generateLogisticInputAsList;
 
 
 public class JavaLinearRegressionSuite implements Serializable {
@@ -65,8 +66,8 @@ public class JavaLinearRegressionSuite implements Serializable {
     DataFrame predictions = jsql.sql("SELECT label, prediction FROM prediction");
     predictions.collect();
     // Check defaults
-    assert(model.getFeaturesCol().equals("features"));
-    assert(model.getPredictionCol().equals("prediction"));
+    assertEquals("features", model.getFeaturesCol());
+    assertEquals("prediction", model.getPredictionCol());
   }
 
   @Test
@@ -76,14 +77,16 @@ public class JavaLinearRegressionSuite implements Serializable {
         .setMaxIter(10)
         .setRegParam(1.0);
     LinearRegressionModel model = lr.fit(dataset);
-    assert(model.fittingParamMap().apply(lr.maxIter()).equals(10));
-    assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0));
+    LinearRegression parent = model.parent();
+    assertEquals(10, parent.getMaxIter());
+    assertEquals(1.0, parent.getRegParam(), 0.0);
 
     // Call fit() with new params, and check as many params as we can.
     LinearRegressionModel model2 =
         lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred"));
-    assert(model2.fittingParamMap().apply(lr.maxIter()).equals(5));
-    assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1));
-    assert(model2.getPredictionCol().equals("thePred"));
+    LinearRegression parent2 = model2.parent();
+    assertEquals(5, parent2.getMaxIter());
+    assertEquals(0.1, parent2.getRegParam(), 0.0);
+    assertEquals("thePred", model2.getPredictionCol());
   }
 }
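
The change above is the heart of this commit, repeated across the suites
that follow: fitting-time params are no longer looked up in a
fittingParamMap on the model but read from getters on the parent
estimator that produced it. A minimal sketch of the new pattern in Scala
against the same ml API (the `dataset` DataFrame and Spark context are
assumed to be in scope, as in the suite):

    import org.apache.spark.ml.regression.LinearRegression

    val lr = new LinearRegression()
      .setMaxIter(10)
      .setRegParam(1.0)
    val model = lr.fit(dataset)  // `dataset`: DataFrame of LabeledPoints

    // The model remembers the estimator that produced it; read the
    // fitting-time params from that parent instead of a ParamMap.
    val parent = model.parent
    assert(parent.getMaxIter == 10)
    assert(parent.getRegParam == 1.0)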

http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
index 0bb6b48..08eeca5 100644
--- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
@@ -68,8 +68,8 @@ public class JavaCrossValidatorSuite implements Serializable {
       .setEvaluator(eval)
       .setNumFolds(3);
     CrossValidatorModel cvModel = cv.fit(dataset);
-    ParamMap bestParamMap = cvModel.bestModel().fittingParamMap();
-    Assert.assertEquals(0.001, bestParamMap.apply(lr.regParam()));
-    Assert.assertEquals(10, bestParamMap.apply(lr.maxIter()));
+    LogisticRegression parent = (LogisticRegression) cvModel.bestModel().parent();
+    Assert.assertEquals(0.001, parent.getRegParam(), 0.0);
+    Assert.assertEquals(10, parent.getMaxIter());
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 2f175fb..2b04a30 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -42,30 +42,32 @@ class PipelineSuite extends FunSuite {
     val dataset3 = mock[DataFrame]
     val dataset4 = mock[DataFrame]
 
-    when(estimator0.fit(meq(dataset0), any[ParamMap]())).thenReturn(model0)
-    when(model0.transform(meq(dataset0), any[ParamMap]())).thenReturn(dataset1)
+    when(estimator0.copy(any[ParamMap])).thenReturn(estimator0)
+    when(model0.copy(any[ParamMap])).thenReturn(model0)
+    when(transformer1.copy(any[ParamMap])).thenReturn(transformer1)
+    when(estimator2.copy(any[ParamMap])).thenReturn(estimator2)
+    when(model2.copy(any[ParamMap])).thenReturn(model2)
+    when(transformer3.copy(any[ParamMap])).thenReturn(transformer3)
+
+    when(estimator0.fit(meq(dataset0))).thenReturn(model0)
+    when(model0.transform(meq(dataset0))).thenReturn(dataset1)
     when(model0.parent).thenReturn(estimator0)
-    when(transformer1.transform(meq(dataset1), any[ParamMap])).thenReturn(dataset2)
-    when(estimator2.fit(meq(dataset2), any[ParamMap]())).thenReturn(model2)
-    when(model2.transform(meq(dataset2), any[ParamMap]())).thenReturn(dataset3)
+    when(transformer1.transform(meq(dataset1))).thenReturn(dataset2)
+    when(estimator2.fit(meq(dataset2))).thenReturn(model2)
+    when(model2.transform(meq(dataset2))).thenReturn(dataset3)
     when(model2.parent).thenReturn(estimator2)
-    when(transformer3.transform(meq(dataset3), any[ParamMap]())).thenReturn(dataset4)
+    when(transformer3.transform(meq(dataset3))).thenReturn(dataset4)
 
     val pipeline = new Pipeline()
       .setStages(Array(estimator0, transformer1, estimator2, transformer3))
     val pipelineModel = pipeline.fit(dataset0)
 
-    assert(pipelineModel.stages.size === 4)
+    assert(pipelineModel.stages.length === 4)
     assert(pipelineModel.stages(0).eq(model0))
     assert(pipelineModel.stages(1).eq(transformer1))
     assert(pipelineModel.stages(2).eq(model2))
     assert(pipelineModel.stages(3).eq(transformer3))
 
-    assert(pipelineModel.getModel(estimator0).eq(model0))
-    assert(pipelineModel.getModel(estimator2).eq(model2))
-    intercept[NoSuchElementException] {
-      pipelineModel.getModel(mock[Estimator[MyModel]])
-    }
     val output = pipelineModel.transform(dataset0)
     assert(output.eq(dataset4))
   }
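
Context for the rewritten stubs above: with ParamMaps no longer threaded
through fit() and transform(), Pipeline.fit() copies each stage (merging
in any caller-supplied extra params) before running it, so every mocked
stage must now answer copy() as well. Returning the mock itself keeps the
identity assertions on pipelineModel.stages valid. The idiom, reusing the
suite's own mocks and matchers (Mockito, as above):

    // Each stage mock returns itself from copy(), so the stages the
    // fitted pipeline holds are the same objects the test asserts on.
    when(estimator0.copy(any[ParamMap])).thenReturn(estimator0)
    when(estimator0.fit(meq(dataset0))).thenReturn(model0)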

http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 9b31ade..03af4ec 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -267,8 +267,8 @@ private[ml] object DecisionTreeClassifierSuite extends FunSuite {
     val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
     val newTree = dt.fit(newData)
     // Use parent, fittingParamMap from newTree since these are not checked anyways.
-    val oldTreeAsNew = DecisionTreeClassificationModel.fromOld(oldTree, newTree.parent,
-      newTree.fittingParamMap, categoricalFeatures)
+    val oldTreeAsNew = DecisionTreeClassificationModel.fromOld(
+      oldTree, newTree.parent, categoricalFeatures)
     TreeTests.checkEqual(oldTreeAsNew, newTree)
   }
 }
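
This hunk and the five tree-suite hunks below are the same mechanical
change: the private[ml] fromOld conversion helpers drop their
fittingParamMap argument. Roughly, per the hunks:

    // before: fromOld(oldModel, parent, fittingParamMap, categoricalFeatures)
    // after:  fromOld(oldModel, parent, categoricalFeatures)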

http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index e6ccc2c..16c758b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -129,8 +129,8 @@ private object GBTClassifierSuite {
     val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2)
     val newModel = gbt.fit(newData)
     // Use parent, fittingParamMap from newTree since these are not checked anyways.
-    val oldModelAsNew = GBTClassificationModel.fromOld(oldModel, newModel.parent,
-      newModel.fittingParamMap, categoricalFeatures)
+    val oldModelAsNew = GBTClassificationModel.fromOld(
+      oldModel, newModel.parent, categoricalFeatures)
     TreeTests.checkEqual(oldModelAsNew, newModel)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 35d8c2e..6dd1fdf 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -74,9 +74,10 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
       .setThreshold(0.6)
       .setProbabilityCol("myProbability")
     val model = lr.fit(dataset)
-    assert(model.fittingParamMap.get(lr.maxIter) === Some(10))
-    assert(model.fittingParamMap.get(lr.regParam) === Some(1.0))
-    assert(model.fittingParamMap.get(lr.threshold) === Some(0.6))
+    val parent = model.parent
+    assert(parent.getMaxIter === 10)
+    assert(parent.getRegParam === 1.0)
+    assert(parent.getThreshold === 0.6)
     assert(model.getThreshold === 0.6)
 
     // Modify model params, and check that the params worked.
@@ -99,9 +100,10 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
     // Call fit() with new params, and check as many params as we can.
     val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.threshold -> 0.4,
       lr.probabilityCol -> "theProb")
-    assert(model2.fittingParamMap.get(lr.maxIter).get === 5)
-    assert(model2.fittingParamMap.get(lr.regParam).get === 0.1)
-    assert(model2.fittingParamMap.get(lr.threshold).get === 0.4)
+    val parent2 = model2.parent
+    assert(parent2.getMaxIter === 5)
+    assert(parent2.getRegParam === 0.1)
+    assert(parent2.getThreshold === 0.4)
     assert(model2.getThreshold === 0.4)
     assert(model2.getProbabilityCol == "theProb")
   }
@@ -117,7 +119,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
     val results = model.transform(dataset)
 
     // Compare rawPrediction with probability
-    results.select("rawPrediction", "probability").collect().map {
+    results.select("rawPrediction", "probability").collect().foreach {
       case Row(raw: Vector, prob: Vector) =>
         assert(raw.size === 2)
         assert(prob.size === 2)
@@ -127,7 +129,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
     }
 
     // Compare prediction with probability
-    results.select("prediction", "probability").collect().map {
+    results.select("prediction", "probability").collect().foreach {
       case Row(pred: Double, prob: Vector) =>
         val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
         assert(pred == predFromProb)

http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index ed41a96..c41def9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -159,8 +159,8 @@ private object RandomForestClassifierSuite {
     val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
     val newModel = rf.fit(newData)
     // Use parent, fittingParamMap from newTree since these are not checked anyways.
-    val oldModelAsNew = RandomForestClassificationModel.fromOld(oldModel, newModel.parent,
-      newModel.fittingParamMap, categoricalFeatures)
+    val oldModelAsNew = RandomForestClassificationModel.fromOld(
+      oldModel, newModel.parent, categoricalFeatures)
     TreeTests.checkEqual(oldModelAsNew, newModel)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index f885260..6056e7d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -122,19 +122,21 @@ class ParamsSuite extends FunSuite {
 
     assert(solver.getParam("inputCol").eq(inputCol))
     assert(solver.getParam("maxIter").eq(maxIter))
+    assert(solver.hasParam("inputCol"))
+    assert(!solver.hasParam("abc"))
     intercept[NoSuchElementException] {
       solver.getParam("abc")
     }
 
     intercept[IllegalArgumentException] {
-      solver.validate()
+      solver.validateParams()
     }
-    solver.validate(ParamMap(inputCol -> "input"))
+    solver.validateParams(ParamMap(inputCol -> "input"))
     solver.setInputCol("input")
     assert(solver.isSet(inputCol))
     assert(solver.isDefined(inputCol))
     assert(solver.getInputCol === "input")
-    solver.validate()
+    solver.validateParams()
     intercept[IllegalArgumentException] {
       ParamMap(maxIter -> -10)
     }
@@ -144,6 +146,11 @@ class ParamsSuite extends FunSuite {
 
     solver.clearMaxIter()
     assert(!solver.isSet(maxIter))
+
+    val copied = solver.copy(ParamMap(solver.maxIter -> 50))
+    assert(copied.uid !== solver.uid)
+    assert(copied.getInputCol === solver.getInputCol)
+    assert(copied.getMaxIter === 50)
   }
 
   test("ParamValidate") {

http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
index 6f9c9cb..dc16073 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
@@ -23,15 +23,19 @@ import org.apache.spark.ml.param.shared.{HasInputCol, HasMaxIter}
 class TestParams extends Params with HasMaxIter with HasInputCol {
 
   def setMaxIter(value: Int): this.type = { set(maxIter, value); this }
+
   def setInputCol(value: String): this.type = { set(inputCol, value); this }
 
   setDefault(maxIter -> 10)
 
-  override def validate(paramMap: ParamMap): Unit = {
-    val m = extractParamMap(paramMap)
-    // Note: maxIter is validated when it is set.
-    require(m.contains(inputCol))
+  def clearMaxIter(): this.type = clear(maxIter)
+
+  override def validateParams(): Unit = {
+    super.validateParams()
+    require(isDefined(inputCol))
   }
 
-  def clearMaxIter(): this.type = clear(maxIter)
+  override def copy(extra: ParamMap): TestParams = {
+    super.copy(extra).asInstanceOf[TestParams]
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index c87a171..5aa81b4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -84,8 +84,8 @@ private[ml] object DecisionTreeRegressorSuite extends FunSuite {
     val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
     val newTree = dt.fit(newData)
     // Use parent, fittingParamMap from newTree since these are not checked anyways.
-    val oldTreeAsNew = DecisionTreeRegressionModel.fromOld(oldTree, newTree.parent,
-      newTree.fittingParamMap, categoricalFeatures)
+    val oldTreeAsNew = DecisionTreeRegressionModel.fromOld(
+      oldTree, newTree.parent, categoricalFeatures)
     TreeTests.checkEqual(oldTreeAsNew, newTree)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 4aec369..25b36ab 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -130,8 +130,7 @@ private object GBTRegressorSuite {
     val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
     val newModel = gbt.fit(newData)
     // Use parent, fittingParamMap from newTree since these are not checked anyways.
-    val oldModelAsNew = GBTRegressionModel.fromOld(oldModel, newModel.parent,
-      newModel.fittingParamMap, categoricalFeatures)
+    val oldModelAsNew = GBTRegressionModel.fromOld(oldModel, newModel.parent, categoricalFeatures)
     TreeTests.checkEqual(oldModelAsNew, newModel)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index c6dc1cc..45f09f4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -115,8 +115,8 @@ private object RandomForestRegressorSuite extends FunSuite {
     val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
     val newModel = rf.fit(newData)
     // Use parent, fittingParamMap from newTree since these are not checked anyways.
-    val oldModelAsNew = RandomForestRegressionModel.fromOld(oldModel, newModel.parent,
-      newModel.fittingParamMap, categoricalFeatures)
+    val oldModelAsNew = RandomForestRegressionModel.fromOld(
+      oldModel, newModel.parent, categoricalFeatures)
     TreeTests.checkEqual(oldModelAsNew, newModel)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e0833c59/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 761ea82..05313d4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -49,8 +49,8 @@ class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext {
       .setEvaluator(eval)
       .setNumFolds(3)
     val cvModel = cv.fit(dataset)
-    val bestParamMap = cvModel.bestModel.fittingParamMap
-    assert(bestParamMap(lr.regParam) === 0.001)
-    assert(bestParamMap(lr.maxIter) === 10)
+    val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
+    assert(parent.getRegParam === 0.001)
+    assert(parent.getMaxIter === 10)
   }
 }
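
One nuance in the hunk above: CrossValidatorModel.bestModel is typed as
Model[_], so unlike the typed model.parent calls elsewhere in this
commit, the parent here must be cast back to the concrete estimator
before its getters are available. Sketch, with `cv`, `lr` and `dataset`
as defined in the suite:

    val cvModel = cv.fit(dataset)
    // bestModel is a Model[_]; recover the concrete estimator by cast.
    val bestLr = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
    assert(bestLr.getRegParam === 0.001)
    assert(bestLr.getMaxIter === 10)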

