Repository: spark
Updated Branches:
  refs/heads/master b00972259 -> 774398045


[SPARK-21087][ML] CrossValidator, TrainValidationSplit expose sub models after 
fitting: Scala

## What changes were proposed in this pull request?

We add a parameter that controls whether to collect the full sub-model list while 
CrossValidator/TrainValidationSplit is fitting (disabled by default, so the change 
cannot cause OOMs for existing users).

- Add a method to CrossValidatorModel/TrainValidationSplitModel that lets users 
retrieve the sub-model list.

- CrossValidatorModelWriter gains an "option" API that lets users control whether 
to persist the sub-model list to disk (persisted by default when sub-models are 
available).

- Note: when persisting the sub-model list, indices (the fold index and/or the 
estimatorParamMaps index) are used as the sub-model paths. A usage sketch follows 
this list.
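
A minimal usage sketch (`dataset` and the grid values here are placeholders; the 
calls mirror the new test cases below):

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

val lr = new LogisticRegression
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.001, 1000.0))
  .build()

val cv = new CrossValidator()
  .setEstimator(lr)
  .setEstimatorParamMaps(paramGrid)
  .setEvaluator(new BinaryClassificationEvaluator)
  .setNumFolds(3)
  .setCollectSubModels(true)  // opt in; defaults to false to avoid driver OOMs

val cvModel = cv.fit(dataset) // dataset: a DataFrame assumed to be in scope

// subModels(foldIndex)(paramIndex); the inner index follows the
// estimatorParamMaps ordering
val firstFoldModels = cvModel.subModels(0)
```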

## How was this patch tested?

Test cases added.
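
Continuing the sketch above, the persistence round-trip the new tests exercise 
(paths are hypothetical):

```scala
import org.apache.spark.ml.tuning.CrossValidatorModel

// Sub-models are persisted by default when they were collected during fitting
cvModel.save("/tmp/cvModelWithSubModels")
assert(CrossValidatorModel.load("/tmp/cvModelWithSubModels").hasSubModels)

// Opt out of persisting sub-models via the new writer option
cvModel.write.option("persistSubModels", "false").save("/tmp/cvModelOnly")
assert(!CrossValidatorModel.load("/tmp/cvModelOnly").hasSubModels)
```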

Author: WeichenXu <weichen...@databricks.com>

Closes #19208 from WeichenXu123/expose-model-list.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/77439804
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/77439804
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/77439804

Branch: refs/heads/master
Commit: 774398045b7b0cde4afb3f3c1a19ad491cf71ed1
Parents: b009722
Author: WeichenXu <weichen...@databricks.com>
Authored: Tue Nov 14 16:48:26 2017 -0800
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Tue Nov 14 16:48:26 2017 -0800

----------------------------------------------------------------------
 .../ml/param/shared/SharedParamsCodeGen.scala   |   8 +-
 .../spark/ml/param/shared/sharedParams.scala    |  17 +++
 .../apache/spark/ml/tuning/CrossValidator.scala | 137 +++++++++++++++++--
 .../spark/ml/tuning/TrainValidationSplit.scala  | 128 +++++++++++++++--
 .../org/apache/spark/ml/util/ReadWrite.scala    |  19 +++
 .../spark/ml/tuning/CrossValidatorSuite.scala   |  54 +++++++-
 .../ml/tuning/TrainValidationSplitSuite.scala   |  48 ++++++-
 project/MimaExcludes.scala                      |   6 +-
 8 files changed, 388 insertions(+), 29 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/77439804/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 20a1db8..c540629 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -83,7 +83,13 @@ private[shared] object SharedParamsCodeGen {
         "all instance weights as 1.0"),
       ParamDesc[String]("solver", "the solver algorithm for optimization", 
finalFields = false),
       ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate 
(>= 2)", Some("2"),
-        isValid = "ParamValidators.gtEq(2)", isExpertParam = true))
+        isValid = "ParamValidators.gtEq(2)", isExpertParam = true),
+      ParamDesc[Boolean]("collectSubModels", "If set to false, then only the 
single best " +
+        "sub-model will be available after fitting. If set to true, then all 
sub-models will be " +
+        "available. Warning: For large models, collecting all sub-models can 
cause OOMs on the " +
+        "Spark driver.",
+        Some("false"), isExpertParam = true)
+    )
 
     val code = genSharedParams(params)
     val file = 
"src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala"

http://git-wip-us.apache.org/repos/asf/spark/blob/77439804/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala 
b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 0d5fb28..34aa38a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -468,4 +468,21 @@ trait HasAggregationDepth extends Params {
   /** @group expertGetParam */
   final def getAggregationDepth: Int = $(aggregationDepth)
 }
+
+/**
+ * Trait for shared param collectSubModels (default: false).
+ */
+private[ml] trait HasCollectSubModels extends Params {
+
+  /**
+   * Param for whether to collect a list of sub-models trained during tuning.
+   * @group expertParam
+   */
+  final val collectSubModels: BooleanParam = new BooleanParam(this, 
"collectSubModels", "whether to collect a list of sub-models trained during 
tuning")
+
+  setDefault(collectSubModels, false)
+
+  /** @group expertGetParam */
+  final def getCollectSubModels: Boolean = $(collectSubModels)
+}
 // scalastyle:on

http://git-wip-us.apache.org/repos/asf/spark/blob/77439804/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 7c81cb9..1682ca9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.ml.tuning
 
-import java.util.{List => JList}
+import java.util.{List => JList, Locale}
 
 import scala.collection.JavaConverters._
 import scala.concurrent.Future
@@ -31,7 +31,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
-import org.apache.spark.ml.param.shared.HasParallelism
+import org.apache.spark.ml.param.shared.{HasCollectSubModels, HasParallelism}
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.sql.{DataFrame, Dataset}
@@ -67,7 +67,8 @@ private[ml] trait CrossValidatorParams extends 
ValidatorParams {
 @Since("1.2.0")
 class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
   extends Estimator[CrossValidatorModel]
-  with CrossValidatorParams with HasParallelism with MLWritable with Logging {
+  with CrossValidatorParams with HasParallelism with HasCollectSubModels
+  with MLWritable with Logging {
 
   @Since("1.2.0")
   def this() = this(Identifiable.randomUID("cv"))
@@ -101,6 +102,21 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") 
override val uid: String)
   @Since("2.3.0")
   def setParallelism(value: Int): this.type = set(parallelism, value)
 
+  /**
+   * Whether to collect sub-models when fitting. If set, we can get sub-models from
+   * the returned model.
+   *
+   * Note: If this param is set, then when you save the returned model, you can set
+   * the option "persistSubModels" to "true" before saving, in order to persist these
+   * sub-models. See the documentation of
+   * {@link org.apache.spark.ml.tuning.CrossValidatorModel.CrossValidatorModelWriter}
+   * for more information.
+   *
+   * @group expertSetParam
+   */
+  @Since("2.3.0")
+  def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value)
+
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): CrossValidatorModel = {
     val schema = dataset.schema
@@ -117,6 +133,12 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") 
override val uid: String)
     instr.logParams(numFolds, seed, parallelism)
     logTuningParams(instr)
 
+    val collectSubModelsParam = $(collectSubModels)
+
+    var subModels: Option[Array[Array[Model[_]]]] = if (collectSubModelsParam) 
{
+      Some(Array.fill($(numFolds))(Array.fill[Model[_]](epm.length)(null)))
+    } else None
+
     // Compute metrics for each model over each split
     val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed))
     val metrics = splits.zipWithIndex.map { case ((training, validation), 
splitIndex) =>
@@ -125,10 +147,14 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") 
override val uid: String)
       logDebug(s"Train split $splitIndex with multiple sets of parameters.")
 
       // Fit models in a Future for training in parallel
-      val modelFutures = epm.map { paramMap =>
+      val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
         Future[Model[_]] {
-          val model = est.fit(trainingDataset, paramMap)
-          model.asInstanceOf[Model[_]]
+          val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]]
+
+          if (collectSubModelsParam) {
+            subModels.get(splitIndex)(paramIndex) = model
+          }
+          model
         } (executionContext)
       }
 
@@ -160,7 +186,8 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") 
override val uid: String)
     logInfo(s"Best cross-validation metric: $bestMetric.")
     val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
     instr.logSuccess(bestModel)
-    copyValues(new CrossValidatorModel(uid, bestModel, 
metrics).setParent(this))
+    copyValues(new CrossValidatorModel(uid, bestModel, metrics)
+      .setSubModels(subModels).setParent(this))
   }
 
   @Since("1.4.0")
@@ -244,6 +271,31 @@ class CrossValidatorModel private[ml] (
     this(uid, bestModel, avgMetrics.asScala.toArray)
   }
 
+  private var _subModels: Option[Array[Array[Model[_]]]] = None
+
+  private[tuning] def setSubModels(subModels: Option[Array[Array[Model[_]]]])
+    : CrossValidatorModel = {
+    _subModels = subModels
+    this
+  }
+
+  /**
+   * @return sub-models represented as a two-dimensional array. The outer array index
+   *         is the fold index, and the inner array index corresponds to the ordering
+   *         of estimatorParamMaps.
+   * @throws IllegalArgumentException if subModels are not available. To retrieve
+   *         subModels, make sure to set collectSubModels to true before fitting.
+   */
+  @Since("2.3.0")
+  def subModels: Array[Array[Model[_]]] = {
+    require(_subModels.isDefined, "subModels not available. To retrieve subModels, " +
+      "make sure to set collectSubModels to true before fitting.")
+    _subModels.get
+  }
+
+  @Since("2.3.0")
+  def hasSubModels: Boolean = _subModels.isDefined
+
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
@@ -260,34 +312,76 @@ class CrossValidatorModel private[ml] (
     val copied = new CrossValidatorModel(
       uid,
       bestModel.copy(extra).asInstanceOf[Model[_]],
-      avgMetrics.clone())
+      avgMetrics.clone()
+    ).setSubModels(CrossValidatorModel.copySubModels(_subModels))
     copyValues(copied, extra).setParent(parent)
   }
 
   @Since("1.6.0")
-  override def write: MLWriter = new 
CrossValidatorModel.CrossValidatorModelWriter(this)
+  override def write: CrossValidatorModel.CrossValidatorModelWriter = {
+    new CrossValidatorModel.CrossValidatorModelWriter(this)
+  }
 }
 
 @Since("1.6.0")
 object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
 
+  private[CrossValidatorModel] def copySubModels(subModels: 
Option[Array[Array[Model[_]]]])
+    : Option[Array[Array[Model[_]]]] = {
+    subModels.map(_.map(_.map(_.copy(ParamMap.empty).asInstanceOf[Model[_]])))
+  }
+
   @Since("1.6.0")
   override def read: MLReader[CrossValidatorModel] = new 
CrossValidatorModelReader
 
   @Since("1.6.0")
   override def load(path: String): CrossValidatorModel = super.load(path)
 
-  private[CrossValidatorModel]
-  class CrossValidatorModelWriter(instance: CrossValidatorModel) extends 
MLWriter {
+  /**
+   * Writer for CrossValidatorModel.
+   *
+   * CrossValidatorModelWriter supports an option "persistSubModels", with possible
+   * values "true" or "false". If you set the collectSubModels Param before fitting,
+   * then you can set "persistSubModels" to "true" in order to persist the sub-models.
+   * By default, "persistSubModels" will be "true" when sub-models are available and
+   * "false" otherwise. If sub-models are not available, then setting "persistSubModels"
+   * to "true" will cause an exception.
+   *
+   * @param instance CrossValidatorModel instance used to construct the writer
+   */
+  @Since("2.3.0")
+  final class CrossValidatorModelWriter private[tuning] (
+      instance: CrossValidatorModel) extends MLWriter {
 
     ValidatorParams.validateParams(instance)
 
     override protected def saveImpl(path: String): Unit = {
+      val persistSubModelsParam = optionMap.getOrElse("persistsubmodels",
+        if (instance.hasSubModels) "true" else "false")
+
+      require(Array("true", 
"false").contains(persistSubModelsParam.toLowerCase(Locale.ROOT)),
+        s"persistSubModels option value ${persistSubModelsParam} is invalid, 
the possible " +
+        "values are \"true\" or \"false\"")
+      val persistSubModels = persistSubModelsParam.toBoolean
+
       import org.json4s.JsonDSL._
-      val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq
+      val extraMetadata = ("avgMetrics" -> instance.avgMetrics.toSeq) ~
+        ("persistSubModels" -> persistSubModels)
       ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata))
       val bestModelPath = new Path(path, "bestModel").toString
       instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
+      if (persistSubModels) {
+        require(instance.hasSubModels, "When persisting tuning models, you can 
only set " +
+          "persistSubModels to true if the tuning was done with 
collectSubModels set to true. " +
+          "To save the sub-models, try rerunning fitting with collectSubModels 
set to true.")
+        val subModelsPath = new Path(path, "subModels")
+        for (splitIndex <- 0 until instance.getNumFolds) {
+          val splitPath = new Path(subModelsPath, 
s"fold${splitIndex.toString}")
+          for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) {
+            val modelPath = new Path(splitPath, paramIndex.toString).toString
+            
instance.subModels(splitIndex)(paramIndex).asInstanceOf[MLWritable].save(modelPath)
+          }
+        }
+      }
     }
   }
 
@@ -301,11 +395,30 @@ object CrossValidatorModel extends 
MLReadable[CrossValidatorModel] {
 
       val (metadata, estimator, evaluator, estimatorParamMaps) =
         ValidatorParams.loadImpl(path, sc, className)
+      val numFolds = (metadata.params \ "numFolds").extract[Int]
       val bestModelPath = new Path(path, "bestModel").toString
       val bestModel = 
DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
       val avgMetrics = (metadata.metadata \ 
"avgMetrics").extract[Seq[Double]].toArray
+      val persistSubModels = (metadata.metadata \ "persistSubModels")
+        .extractOrElse[Boolean](false)
+
+      val subModels: Option[Array[Array[Model[_]]]] = if (persistSubModels) {
+        val subModelsPath = new Path(path, "subModels")
+        val _subModels = Array.fill(numFolds)(Array.fill[Model[_]](
+          estimatorParamMaps.length)(null))
+        for (splitIndex <- 0 until numFolds) {
+          val splitPath = new Path(subModelsPath, 
s"fold${splitIndex.toString}")
+          for (paramIndex <- 0 until estimatorParamMaps.length) {
+            val modelPath = new Path(splitPath, paramIndex.toString).toString
+            _subModels(splitIndex)(paramIndex) =
+              DefaultParamsReader.loadParamsInstance(modelPath, sc)
+          }
+        }
+        Some(_subModels)
+      } else None
 
       val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics)
+        .setSubModels(subModels)
       model.set(model.estimator, estimator)
         .set(model.evaluator, evaluator)
         .set(model.estimatorParamMaps, estimatorParamMaps)
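
For reference, the writer and reader loops above imply this on-disk layout when
persistSubModels is enabled (a sketch; <path> stands for the model's save path):

    <path>/bestModel/
    <path>/subModels/fold0/0/    <- fold index, then estimatorParamMaps index
    <path>/subModels/fold0/1/
    <path>/subModels/fold1/0/
    ...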

http://git-wip-us.apache.org/repos/asf/spark/blob/77439804/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 6e3ad40..c73bd18 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -17,8 +17,7 @@
 
 package org.apache.spark.ml.tuning
 
-import java.io.IOException
-import java.util.{List => JList}
+import java.util.{List => JList, Locale}
 
 import scala.collection.JavaConverters._
 import scala.concurrent.Future
@@ -33,7 +32,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators}
-import org.apache.spark.ml.param.shared.HasParallelism
+import org.apache.spark.ml.param.shared.{HasCollectSubModels, HasParallelism}
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.types.StructType
@@ -67,7 +66,8 @@ private[ml] trait TrainValidationSplitParams extends 
ValidatorParams {
 @Since("1.5.0")
 class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: 
String)
   extends Estimator[TrainValidationSplitModel]
-  with TrainValidationSplitParams with HasParallelism with MLWritable with 
Logging {
+  with TrainValidationSplitParams with HasParallelism with HasCollectSubModels
+  with MLWritable with Logging {
 
   @Since("1.5.0")
   def this() = this(Identifiable.randomUID("tvs"))
@@ -101,6 +101,20 @@ class TrainValidationSplit @Since("1.5.0") 
(@Since("1.5.0") override val uid: St
   @Since("2.3.0")
   def setParallelism(value: Int): this.type = set(parallelism, value)
 
+  /**
+   * Whether to collect sub-models when fitting. If set, we can get sub-models from
+   * the returned model.
+   *
+   * Note: If this param is set, then when you save the returned model, you can set
+   * the option "persistSubModels" to "true" before saving, in order to persist these
+   * sub-models. See the documentation of
+   * {@link org.apache.spark.ml.tuning.TrainValidationSplitModel.TrainValidationSplitModelWriter}
+   * for more information.
+   *
+   * @group expertSetParam
+   */
+  @Since("2.3.0")
+  def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value)
+
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): TrainValidationSplitModel = {
     val schema = dataset.schema
@@ -121,12 +135,22 @@ class TrainValidationSplit @Since("1.5.0") 
(@Since("1.5.0") override val uid: St
     trainingDataset.cache()
     validationDataset.cache()
 
+    val collectSubModelsParam = $(collectSubModels)
+
+    var subModels: Option[Array[Model[_]]] = if (collectSubModelsParam) {
+      Some(Array.fill[Model[_]](epm.length)(null))
+    } else None
+
     // Fit models in a Future for training in parallel
     logDebug(s"Train split with multiple sets of parameters.")
-    val modelFutures = epm.map { paramMap =>
+    val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
       Future[Model[_]] {
-        val model = est.fit(trainingDataset, paramMap)
-        model.asInstanceOf[Model[_]]
+        val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]]
+
+        if (collectSubModelsParam) {
+          subModels.get(paramIndex) = model
+        }
+        model
       } (executionContext)
     }
 
@@ -158,7 +182,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") 
override val uid: St
     logInfo(s"Best train validation split metric: $bestMetric.")
     val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
     instr.logSuccess(bestModel)
-    copyValues(new TrainValidationSplitModel(uid, bestModel, 
metrics).setParent(this))
+    copyValues(new TrainValidationSplitModel(uid, bestModel, metrics)
+      .setSubModels(subModels).setParent(this))
   }
 
   @Since("1.5.0")
@@ -238,6 +263,30 @@ class TrainValidationSplitModel private[ml] (
     this(uid, bestModel, validationMetrics.asScala.toArray)
   }
 
+  private var _subModels: Option[Array[Model[_]]] = None
+
+  private[tuning] def setSubModels(subModels: Option[Array[Model[_]]])
+    : TrainValidationSplitModel = {
+    _subModels = subModels
+    this
+  }
+
+  /**
+   * @return sub-models represented as an array. The array index corresponds to the
+   *         ordering of estimatorParamMaps.
+   * @throws IllegalArgumentException if subModels are not available. To retrieve
+   *         subModels, make sure to set collectSubModels to true before fitting.
+   */
+  @Since("2.3.0")
+  def subModels: Array[Model[_]] = {
+    require(_subModels.isDefined, "subModels not available. To retrieve subModels, " +
+      "make sure to set collectSubModels to true before fitting.")
+    _subModels.get
+  }
+
+  @Since("2.3.0")
+  def hasSubModels: Boolean = _subModels.isDefined
+
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
@@ -254,34 +303,73 @@ class TrainValidationSplitModel private[ml] (
     val copied = new TrainValidationSplitModel (
       uid,
       bestModel.copy(extra).asInstanceOf[Model[_]],
-      validationMetrics.clone())
+      validationMetrics.clone()
+    ).setSubModels(TrainValidationSplitModel.copySubModels(_subModels))
     copyValues(copied, extra).setParent(parent)
   }
 
   @Since("2.0.0")
-  override def write: MLWriter = new 
TrainValidationSplitModel.TrainValidationSplitModelWriter(this)
+  override def write: 
TrainValidationSplitModel.TrainValidationSplitModelWriter = {
+    new TrainValidationSplitModel.TrainValidationSplitModelWriter(this)
+  }
 }
 
 @Since("2.0.0")
 object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] 
{
 
+  private[TrainValidationSplitModel] def copySubModels(subModels: 
Option[Array[Model[_]]])
+    : Option[Array[Model[_]]] = {
+    subModels.map(_.map(_.copy(ParamMap.empty).asInstanceOf[Model[_]]))
+  }
+
   @Since("2.0.0")
   override def read: MLReader[TrainValidationSplitModel] = new 
TrainValidationSplitModelReader
 
   @Since("2.0.0")
   override def load(path: String): TrainValidationSplitModel = super.load(path)
 
-  private[TrainValidationSplitModel]
-  class TrainValidationSplitModelWriter(instance: TrainValidationSplitModel) 
extends MLWriter {
+  /**
+   * Writer for TrainValidationSplitModel.
+   *
+   * TrainValidationSplitModelWriter supports an option "persistSubModels", with possible
+   * values "true" or "false". If you set the collectSubModels Param before fitting,
+   * then you can set "persistSubModels" to "true" in order to persist the sub-models.
+   * By default, "persistSubModels" will be "true" when sub-models are available and
+   * "false" otherwise. If sub-models are not available, then setting "persistSubModels"
+   * to "true" will cause an exception.
+   *
+   * @param instance TrainValidationSplitModel instance used to construct the writer
+   */
+  @Since("2.3.0")
+  final class TrainValidationSplitModelWriter private[tuning] (
+      instance: TrainValidationSplitModel) extends MLWriter {
 
     ValidatorParams.validateParams(instance)
 
     override protected def saveImpl(path: String): Unit = {
+      val persistSubModelsParam = optionMap.getOrElse("persistsubmodels",
+        if (instance.hasSubModels) "true" else "false")
+
+      require(Array("true", 
"false").contains(persistSubModelsParam.toLowerCase(Locale.ROOT)),
+        s"persistSubModels option value ${persistSubModelsParam} is invalid, 
the possible " +
+        "values are \"true\" or \"false\"")
+      val persistSubModels = persistSubModelsParam.toBoolean
+
       import org.json4s.JsonDSL._
-      val extraMetadata = "validationMetrics" -> 
instance.validationMetrics.toSeq
+      val extraMetadata = ("validationMetrics" -> 
instance.validationMetrics.toSeq) ~
+        ("persistSubModels" -> persistSubModels)
       ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata))
       val bestModelPath = new Path(path, "bestModel").toString
       instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
+      if (persistSubModels) {
+        require(instance.hasSubModels, "When persisting tuning models, you can 
only set " +
+          "persistSubModels to true if the tuning was done with 
collectSubModels set to true. " +
+          "To save the sub-models, try rerunning fitting with collectSubModels 
set to true.")
+        val subModelsPath = new Path(path, "subModels")
+        for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) {
+          val modelPath = new Path(subModelsPath, paramIndex.toString).toString
+          
instance.subModels(paramIndex).asInstanceOf[MLWritable].save(modelPath)
+        }
+      }
     }
   }
 
@@ -298,8 +386,22 @@ object TrainValidationSplitModel extends 
MLReadable[TrainValidationSplitModel] {
       val bestModelPath = new Path(path, "bestModel").toString
       val bestModel = 
DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
       val validationMetrics = (metadata.metadata \ 
"validationMetrics").extract[Seq[Double]].toArray
+      val persistSubModels = (metadata.metadata \ "persistSubModels")
+        .extractOrElse[Boolean](false)
+
+      val subModels: Option[Array[Model[_]]] = if (persistSubModels) {
+        val subModelsPath = new Path(path, "subModels")
+        val _subModels = Array.fill[Model[_]](estimatorParamMaps.length)(null)
+        for (paramIndex <- 0 until estimatorParamMaps.length) {
+          val modelPath = new Path(subModelsPath, paramIndex.toString).toString
+          _subModels(paramIndex) =
+            DefaultParamsReader.loadParamsInstance(modelPath, sc)
+        }
+        Some(_subModels)
+      } else None
 
       val model = new TrainValidationSplitModel(metadata.uid, bestModel, 
validationMetrics)
+        .setSubModels(subModels)
       model.set(model.estimator, estimator)
         .set(model.evaluator, evaluator)
         .set(model.estimatorParamMaps, estimatorParamMaps)
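
Unlike CrossValidatorModel, whose subModels form a two-dimensional array, a
TrainValidationSplitModel exposes a flat array indexed by the estimatorParamMaps
ordering. A minimal access sketch (tvs and dataset are assumed to be in scope,
with collectSubModels set to true):

    val tvsModel = tvs.fit(dataset)
    val modelForFirstParamMap = tvsModel.subModels(0)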

http://git-wip-us.apache.org/repos/asf/spark/blob/77439804/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala 
b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index 7188da3..a616907 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -18,6 +18,9 @@
 package org.apache.spark.ml.util
 
 import java.io.IOException
+import java.util.Locale
+
+import scala.collection.mutable
 
 import org.apache.hadoop.fs.Path
 import org.json4s._
@@ -108,6 +111,22 @@ abstract class MLWriter extends BaseReadWrite with Logging 
{
   protected def saveImpl(path: String): Unit
 
   /**
+   * Map to store extra options for this writer.
+   */
+  protected val optionMap: mutable.Map[String, String] = new 
mutable.HashMap[String, String]()
+
+  /**
+   * Adds an option to the underlying MLWriter. See the documentation for the 
specific model's
+   * writer for possible options. The option name (key) is case-insensitive.
+   */
+  @Since("2.3.0")
+  def option(key: String, value: String): this.type = {
+    require(key != null && !key.isEmpty)
+    optionMap.put(key.toLowerCase(Locale.ROOT), value)
+    this
+  }
+
+  /**
    * Overwrites if the output path already exists.
    */
   @Since("1.6.0")

http://git-wip-us.apache.org/repos/asf/spark/blob/77439804/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 853eeb3..15dade2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.ml.tuning
 
+import java.io.File
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.{Estimator, Model, Pipeline}
 import org.apache.spark.ml.classification.{LogisticRegression, 
LogisticRegressionModel, OneVsRest}
@@ -27,7 +29,7 @@ import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.param.shared.HasInputCol
 import org.apache.spark.ml.regression.LinearRegression
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, 
MLTestingUtils}
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.types.StructType
@@ -161,6 +163,7 @@ class CrossValidatorSuite
       .setEstimatorParamMaps(paramMaps)
       .setSeed(42L)
       .setParallelism(2)
+      .setCollectSubModels(true)
 
     val cv2 = testDefaultReadWrite(cv, testParams = false)
 
@@ -168,6 +171,7 @@ class CrossValidatorSuite
     assert(cv.getNumFolds === cv2.getNumFolds)
     assert(cv.getSeed === cv2.getSeed)
     assert(cv.getParallelism === cv2.getParallelism)
+    assert(cv.getCollectSubModels === cv2.getCollectSubModels)
 
     assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
     val evaluator2 = 
cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator]
@@ -187,6 +191,54 @@ class CrossValidatorSuite
       .compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps)
   }
 
+  test("CrossValidator expose sub models") {
+    val lr = new LogisticRegression
+    val lrParamMaps = new ParamGridBuilder()
+      .addGrid(lr.regParam, Array(0.001, 1000.0))
+      .addGrid(lr.maxIter, Array(0, 3))
+      .build()
+    val eval = new BinaryClassificationEvaluator
+    val numFolds = 3
+    val subPath = new File(tempDir, "testCrossValidatorSubModels")
+
+    val cv = new CrossValidator()
+      .setEstimator(lr)
+      .setEstimatorParamMaps(lrParamMaps)
+      .setEvaluator(eval)
+      .setNumFolds(numFolds)
+      .setParallelism(1)
+      .setCollectSubModels(true)
+
+    val cvModel = cv.fit(dataset)
+
+    assert(cvModel.hasSubModels && cvModel.subModels.length == numFolds)
+    cvModel.subModels.foreach(array => assert(array.length == 
lrParamMaps.length))
+
+    // Test that the default value of option "persistSubModels" is "true" when sub-models exist
+    val savingPathWithSubModels = new File(subPath, "cvModel3").getPath
+    cvModel.save(savingPathWithSubModels)
+    val cvModel3 = CrossValidatorModel.load(savingPathWithSubModels)
+    assert(cvModel3.hasSubModels && cvModel3.subModels.length == numFolds)
+    cvModel3.subModels.foreach(array => assert(array.length == 
lrParamMaps.length))
+
+    val savingPathWithoutSubModels = new File(subPath, "cvModel2").getPath
+    cvModel.write.option("persistSubModels", 
"false").save(savingPathWithoutSubModels)
+    val cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels)
+    assert(!cvModel2.hasSubModels)
+
+    for (i <- 0 until numFolds) {
+      for (j <- 0 until lrParamMaps.length) {
+        
assert(cvModel.subModels(i)(j).asInstanceOf[LogisticRegressionModel].uid ===
+          cvModel3.subModels(i)(j).asInstanceOf[LogisticRegressionModel].uid)
+      }
+    }
+
+    val savingPathTestingIllegalParam = new File(subPath, "cvModel4").getPath
+    intercept[IllegalArgumentException] {
+      cvModel2.write.option("persistSubModels", 
"true").save(savingPathTestingIllegalParam)
+    }
+  }
+
   test("read/write: CrossValidator with nested estimator") {
     val ova = new OneVsRest().setClassifier(new LogisticRegression)
     val evaluator = new MulticlassClassificationEvaluator()

http://git-wip-us.apache.org/repos/asf/spark/blob/77439804/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index f8d9c66..9024342 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.ml.tuning
 
+import java.io.File
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.classification.{LogisticRegression, 
LogisticRegressionModel, OneVsRest}
@@ -26,7 +28,7 @@ import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.param.shared.HasInputCol
 import org.apache.spark.ml.regression.LinearRegression
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, 
MLTestingUtils}
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.types.StructType
@@ -161,12 +163,14 @@ class TrainValidationSplitSuite
       .setEstimatorParamMaps(paramMaps)
       .setSeed(42L)
       .setParallelism(2)
+      .setCollectSubModels(true)
 
     val tvs2 = testDefaultReadWrite(tvs, testParams = false)
 
     assert(tvs.getTrainRatio === tvs2.getTrainRatio)
     assert(tvs.getSeed === tvs2.getSeed)
     assert(tvs.getParallelism === tvs2.getParallelism)
+    assert(tvs.getCollectSubModels === tvs2.getCollectSubModels)
 
     ValidatorParamsSuiteHelpers
       .compareParamMaps(tvs.getEstimatorParamMaps, tvs2.getEstimatorParamMaps)
@@ -181,6 +185,48 @@ class TrainValidationSplitSuite
     }
   }
 
+  test("TrainValidationSplit expose sub models") {
+    val lr = new LogisticRegression
+    val lrParamMaps = new ParamGridBuilder()
+      .addGrid(lr.regParam, Array(0.001, 1000.0))
+      .addGrid(lr.maxIter, Array(0, 3))
+      .build()
+    val eval = new BinaryClassificationEvaluator
+    val subPath = new File(tempDir, "testTrainValidationSplitSubModels")
+
+    val tvs = new TrainValidationSplit()
+      .setEstimator(lr)
+      .setEstimatorParamMaps(lrParamMaps)
+      .setEvaluator(eval)
+      .setParallelism(1)
+      .setCollectSubModels(true)
+
+    val tvsModel = tvs.fit(dataset)
+
+    assert(tvsModel.hasSubModels && tvsModel.subModels.length == 
lrParamMaps.length)
+
+    // Test that the default value of option "persistSubModels" is "true" when sub-models exist
+    val savingPathWithSubModels = new File(subPath, "tvsModel3").getPath
+    tvsModel.save(savingPathWithSubModels)
+    val tvsModel3 = TrainValidationSplitModel.load(savingPathWithSubModels)
+    assert(tvsModel3.hasSubModels && tvsModel3.subModels.length == 
lrParamMaps.length)
+
+    val savingPathWithoutSubModels = new File(subPath, "tvsModel2").getPath
+    tvsModel.write.option("persistSubModels", 
"false").save(savingPathWithoutSubModels)
+    val tvsModel2 = TrainValidationSplitModel.load(savingPathWithoutSubModels)
+    assert(!tvsModel2.hasSubModels)
+
+    for (i <- 0 until lrParamMaps.length) {
+      assert(tvsModel.subModels(i).asInstanceOf[LogisticRegressionModel].uid 
===
+        tvsModel3.subModels(i).asInstanceOf[LogisticRegressionModel].uid)
+    }
+
+    val savingPathTestingIllegalParam = new File(subPath, "tvsModel4").getPath
+    intercept[IllegalArgumentException] {
+      tvsModel2.write.option("persistSubModels", 
"true").save(savingPathTestingIllegalParam)
+    }
+  }
+
   test("read/write: TrainValidationSplit with nested estimator") {
     val ova = new OneVsRest()
       .setClassifier(new LogisticRegression)

http://git-wip-us.apache.org/repos/asf/spark/blob/77439804/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 7f18b40..915c7e2 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -78,7 +78,11 @@ object MimaExcludes {
 
     // [SPARK-14280] Support Scala 2.12
     
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.transformWith"),
-    
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.transform")
+    
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.transform"),
+
+    // [SPARK-21087] CrossValidator, TrainValidationSplit expose sub models 
after fitting: Scala
+    
ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.CrossValidatorModel$CrossValidatorModelWriter"),
+    
ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel$TrainValidationSplitModelWriter")
   )
 
   // Exclude rules for 2.2.x

