Repository: spark
Updated Branches:
  refs/heads/master 3e6a714c9 -> f180b6534


[SPARK-22060][ML] Fix CrossValidator/TrainValidationSplit param persist/load bug

## What changes were proposed in this pull request?

Currently, CrossValidator/TrainValidationSplit hard-code which params they 
persist and load, unlike other ML estimators. This causes a persistence bug 
for the newly added `parallelism` param: it is never written on save, so it 
is lost on load.

This patch refactors the related code to avoid hard-coding the persisted and 
loaded params, which also fixes the `parallelism` persistence bug.

The refactoring also matters going forward: #19208 will add more new params, 
and hard-coded param persistence/loading makes adding each new param 
needlessly error-prone.
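
For context, a minimal round trip exercising the fixed behavior could look like 
the sketch below. It is not part of the patch: it assumes a running SparkSession 
(e.g. spark-shell), since save/load need an active SparkContext, and the output 
path is illustrative.

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

val lr = new LogisticRegression()
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .build()

val cv = new CrossValidator()
  .setEstimator(lr)
  .setEvaluator(new BinaryClassificationEvaluator())
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(3)
  .setSeed(42L)
  .setParallelism(2)  // previously silently dropped by save

cv.write.overwrite().save("/tmp/cv-demo")  // illustrative path
val cv2 = CrossValidator.load("/tmp/cv-demo")
assert(cv2.getParallelism == 2)  // failed before this patch: parallelism reverted to its default
```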

## How was this patch tested?

Tests added (CrossValidatorSuite and TrainValidationSplitSuite).

Author: WeichenXu <weichen...@databricks.com>

Closes #19278 from WeichenXu123/fix-tuning-param-bug.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f180b653
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f180b653
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f180b653

Branch: refs/heads/master
Commit: f180b65343e706c60b995a3d46d0391612bda966
Parents: 3e6a714
Author: WeichenXu <weichen...@databricks.com>
Authored: Fri Sep 22 18:15:01 2017 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Fri Sep 22 18:15:01 2017 -0700

----------------------------------------------------------------------
 .../apache/spark/ml/tuning/CrossValidator.scala | 17 +++++++--------
 .../spark/ml/tuning/TrainValidationSplit.scala  | 18 ++++++++--------
 .../spark/ml/tuning/ValidatorParams.scala       | 22 +++++++-------------
 .../org/apache/spark/ml/util/ReadWrite.scala    | 20 +++++++++++++-----
 .../spark/ml/tuning/CrossValidatorSuite.scala   |  3 +++
 .../ml/tuning/TrainValidationSplitSuite.scala   |  4 +++-
 6 files changed, 46 insertions(+), 38 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f180b653/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index ce2a3a2..7c81cb9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -212,14 +212,13 @@ object CrossValidator extends MLReadable[CrossValidator] {
 
       val (metadata, estimator, evaluator, estimatorParamMaps) =
         ValidatorParams.loadImpl(path, sc, className)
-      val numFolds = (metadata.params \ "numFolds").extract[Int]
-      val seed = (metadata.params \ "seed").extract[Long]
-      new CrossValidator(metadata.uid)
+      val cv = new CrossValidator(metadata.uid)
         .setEstimator(estimator)
         .setEvaluator(evaluator)
         .setEstimatorParamMaps(estimatorParamMaps)
-        .setNumFolds(numFolds)
-        .setSeed(seed)
+      DefaultParamsReader.getAndSetParams(cv, metadata,
+        skipParams = Option(List("estimatorParamMaps")))
+      cv
     }
   }
 }
@@ -302,17 +301,17 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
 
       val (metadata, estimator, evaluator, estimatorParamMaps) =
         ValidatorParams.loadImpl(path, sc, className)
-      val numFolds = (metadata.params \ "numFolds").extract[Int]
-      val seed = (metadata.params \ "seed").extract[Long]
       val bestModelPath = new Path(path, "bestModel").toString
       val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
       val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray
+
       val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics)
       model.set(model.estimator, estimator)
         .set(model.evaluator, evaluator)
         .set(model.estimatorParamMaps, estimatorParamMaps)
-        .set(model.numFolds, numFolds)
-        .set(model.seed, seed)
+      DefaultParamsReader.getAndSetParams(model, metadata,
+        skipParams = Option(List("estimatorParamMaps")))
+      model
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f180b653/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 16db0f5..6e3ad40 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.ml.tuning
 
+import java.io.IOException
 import java.util.{List => JList}
 
 import scala.collection.JavaConverters._
@@ -207,14 +208,13 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] {
 
       val (metadata, estimator, evaluator, estimatorParamMaps) =
         ValidatorParams.loadImpl(path, sc, className)
-      val trainRatio = (metadata.params \ "trainRatio").extract[Double]
-      val seed = (metadata.params \ "seed").extract[Long]
-      new TrainValidationSplit(metadata.uid)
+      val tvs = new TrainValidationSplit(metadata.uid)
         .setEstimator(estimator)
         .setEvaluator(evaluator)
         .setEstimatorParamMaps(estimatorParamMaps)
-        .setTrainRatio(trainRatio)
-        .setSeed(seed)
+      DefaultParamsReader.getAndSetParams(tvs, metadata,
+        skipParams = Option(List("estimatorParamMaps")))
+      tvs
     }
   }
 }
@@ -295,17 +295,17 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
 
       val (metadata, estimator, evaluator, estimatorParamMaps) =
         ValidatorParams.loadImpl(path, sc, className)
-      val trainRatio = (metadata.params \ "trainRatio").extract[Double]
-      val seed = (metadata.params \ "seed").extract[Long]
       val bestModelPath = new Path(path, "bestModel").toString
       val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
       val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray
+
       val model = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics)
       model.set(model.estimator, estimator)
         .set(model.evaluator, evaluator)
         .set(model.estimatorParamMaps, estimatorParamMaps)
-        .set(model.trainRatio, trainRatio)
-        .set(model.seed, seed)
+      DefaultParamsReader.getAndSetParams(model, metadata,
+        skipParams = Option(List("estimatorParamMaps")))
+      model
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f180b653/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
index 0ab6eed..363304e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
@@ -150,20 +150,14 @@ private[ml] object ValidatorParams {
       }.toSeq
     ))
 
-    val validatorSpecificParams = instance match {
-      case cv: CrossValidatorParams =>
-        List("numFolds" -> parse(cv.numFolds.jsonEncode(cv.getNumFolds)))
-      case tvs: TrainValidationSplitParams =>
-        List("trainRatio" -> 
parse(tvs.trainRatio.jsonEncode(tvs.getTrainRatio)))
-      case _ =>
-        // This should not happen.
-        throw new NotImplementedError("ValidatorParams.saveImpl does not handle type: " +
-          instance.getClass.getCanonicalName)
-    }
-
-    val jsonParams = validatorSpecificParams ++ List(
-      "estimatorParamMaps" -> parse(estimatorParamMapsJson),
-      "seed" -> parse(instance.seed.jsonEncode(instance.getSeed)))
+    val params = instance.extractParamMap().toSeq
+    val skipParams = List("estimator", "evaluator", "estimatorParamMaps")
+    val jsonParams = render(params
+      .filter { case ParamPair(p, v) => !skipParams.contains(p.name)}
+      .map { case ParamPair(p, v) =>
+        p.name -> parse(p.jsonEncode(v))
+      }.toList ++ List("estimatorParamMaps" -> parse(estimatorParamMapsJson))
+    )
 
     DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
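
Aside (not part of the commit): the save side above now encodes every param 
generically via extractParamMap and only skip-lists the three params that need 
special handling. A standalone json4s sketch of that filter-and-encode pattern, 
using simplified stand-ins rather than Spark's Param/ParamPair classes:

```scala
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, parse, render}

// Simplified stand-in for a param that already knows its JSON encoding.
case class EncodedParam(name: String, json: String)

val params = Seq(
  EncodedParam("numFolds", "3"),
  EncodedParam("seed", "42"),
  EncodedParam("parallelism", "2"))  // now persisted like any other param

val skipParams = List("estimator", "evaluator", "estimatorParamMaps")
val jsonParams = render(
  params.filterNot(p => skipParams.contains(p.name))
    .map(p => p.name -> parse(p.json))
    .toList)

println(compact(jsonParams))  // {"numFolds":3,"seed":42,"parallelism":2}
```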
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f180b653/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index 65f142c..7188da3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -396,17 +396,27 @@ private[ml] object DefaultParamsReader {
 
   /**
    * Extract Params from metadata, and set them in the instance.
-   * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
+   * This works if all Params (except params included by `skipParams` list) implement
+   * [[org.apache.spark.ml.param.Param.jsonDecode()]].
+   *
+   * @param skipParams The params included in `skipParams` won't be set. This is useful if some
+   *                   params don't implement [[org.apache.spark.ml.param.Param.jsonDecode()]]
+   *                   and need special handling.
    * TODO: Move to [[Metadata]] method
    */
-  def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
+  def getAndSetParams(
+      instance: Params,
+      metadata: Metadata,
+      skipParams: Option[List[String]] = None): Unit = {
     implicit val format = DefaultFormats
     metadata.params match {
       case JObject(pairs) =>
         pairs.foreach { case (paramName, jsonValue) =>
-          val param = instance.getParam(paramName)
-          val value = param.jsonDecode(compact(render(jsonValue)))
-          instance.set(param, value)
+          if (skipParams == None || !skipParams.get.contains(paramName)) {
+            val param = instance.getParam(paramName)
+            val value = param.jsonDecode(compact(render(jsonValue)))
+            instance.set(param, value)
+          }
         }
       case _ =>
         throw new IllegalArgumentException(
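
Aside (not part of the commit): a self-contained sketch of the skip-list logic 
this hunk adds. Illustrative only; the real getAndSetParams is private[ml] and 
sets params on a Params instance instead of printing.

```scala
import org.json4s.{JObject, JValue}
import org.json4s.jackson.JsonMethods.parse

val metadataParams: JValue = parse("""{"numFolds":3,"seed":42,"parallelism":2}""")
val skipParams: Option[List[String]] = Some(List("estimatorParamMaps"))

metadataParams match {
  case JObject(pairs) =>
    pairs.foreach { case (paramName, jsonValue) =>
      // Same condition as the patch, written with Option.forall:
      // set the param unless it appears on the skip list.
      if (skipParams.forall(!_.contains(paramName))) {
        println(s"would set $paramName = $jsonValue")
      }
    }
  case _ =>
    throw new IllegalArgumentException("Cannot recognize JSON metadata.")
}
```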

http://git-wip-us.apache.org/repos/asf/spark/blob/f180b653/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index a8d4377..a01744f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -159,12 +159,15 @@ class CrossValidatorSuite
       .setEvaluator(evaluator)
       .setNumFolds(20)
       .setEstimatorParamMaps(paramMaps)
+      .setSeed(42L)
+      .setParallelism(2)
 
     val cv2 = testDefaultReadWrite(cv, testParams = false)
 
     assert(cv.uid === cv2.uid)
     assert(cv.getNumFolds === cv2.getNumFolds)
     assert(cv.getSeed === cv2.getSeed)
+    assert(cv.getParallelism === cv2.getParallelism)
 
     assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
     val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator]

http://git-wip-us.apache.org/repos/asf/spark/blob/f180b653/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index 7480173..2ed4fbb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressio
 import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput
 import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
 import org.apache.spark.ml.linalg.Vectors
-import org.apache.spark.ml.param.{ParamMap}
+import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.param.shared.HasInputCol
 import org.apache.spark.ml.regression.LinearRegression
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
@@ -160,11 +160,13 @@ class TrainValidationSplitSuite
       .setTrainRatio(0.5)
       .setEstimatorParamMaps(paramMaps)
       .setSeed(42L)
+      .setParallelism(2)
 
     val tvs2 = testDefaultReadWrite(tvs, testParams = false)
 
     assert(tvs.getTrainRatio === tvs2.getTrainRatio)
     assert(tvs.getSeed === tvs2.getSeed)
+    assert(tvs.getParallelism === tvs2.getParallelism)
 
     ValidatorParamsSuiteHelpers
       .compareParamMaps(tvs.getEstimatorParamMaps, tvs2.getEstimatorParamMaps)
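
End to end, the behavior this new test locks in can be reproduced from user code 
roughly as follows. A sketch reusing the `lr`, `paramGrid`, and evaluator import 
from the CrossValidator example near the top; the path is again illustrative.

```scala
import org.apache.spark.ml.tuning.TrainValidationSplit

val tvs = new TrainValidationSplit()
  .setEstimator(lr)
  .setEvaluator(new BinaryClassificationEvaluator())
  .setEstimatorParamMaps(paramGrid)
  .setTrainRatio(0.5)
  .setSeed(42L)
  .setParallelism(2)

tvs.write.overwrite().save("/tmp/tvs-demo")
val tvs2 = TrainValidationSplit.load("/tmp/tvs-demo")
assert(tvs2.getParallelism == 2)  // round-trips correctly after this patch
```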

