Repository: spark Updated Branches: refs/heads/master b00972259 -> 774398045
[SPARK-21087][ML] CrossValidator, TrainValidationSplit expose sub models after fitting: Scala ## What changes were proposed in this pull request? We add a parameter that controls whether to collect the full model list when CrossValidator/TrainValidationSplit is fitting (the default is NOT to collect, so that the change cannot cause OOM). - Add a method in CrossValidatorModel/TrainValidationSplitModel that allows the user to get the model list. - CrossValidatorModelWriter adds an "option" that allows the user to control whether to persist the model list to disk (persisted by default). - Note: when persisting the model list, indices are used as the sub-model paths. ## How was this patch tested? Test cases added. Author: WeichenXu <weichen...@databricks.com> Closes #19208 from WeichenXu123/expose-model-list.
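A minimal usage sketch of the Scala API introduced by this patch. The estimator, the grid values, and the "training" DataFrame below are illustrative assumptions, not part of the commit:

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

val lr = new LogisticRegression()
val grid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.01, 0.1))
  .build()
val cv = new CrossValidator()
  .setEstimator(lr)
  .setEstimatorParamMaps(grid)
  .setEvaluator(new BinaryClassificationEvaluator())
  .setNumFolds(3)
  .setCollectSubModels(true)  // new expert param; defaults to false to avoid driver OOM

// "training" is an assumed DataFrame with "features" and "label" columns.
val cvModel = cv.fit(training)

// Sub-models are indexed as subModels(foldIndex)(paramIndex), where paramIndex
// follows the ordering of estimatorParamMaps.
val foldZeroModels = cvModel.subModels(0)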
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/77439804 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/77439804 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/77439804 Branch: refs/heads/master Commit: 774398045b7b0cde4afb3f3c1a19ad491cf71ed1 Parents: b009722 Author: WeichenXu <weichen...@databricks.com> Authored: Tue Nov 14 16:48:26 2017 -0800 Committer: Joseph K. Bradley <jos...@databricks.com> Committed: Tue Nov 14 16:48:26 2017 -0800 ---------------------------------------------------------------------- .../ml/param/shared/SharedParamsCodeGen.scala | 8 +- .../spark/ml/param/shared/sharedParams.scala | 17 +++ .../apache/spark/ml/tuning/CrossValidator.scala | 137 +++++++++++++++++-- .../spark/ml/tuning/TrainValidationSplit.scala | 128 +++++++++++++++-- .../org/apache/spark/ml/util/ReadWrite.scala | 19 +++ .../spark/ml/tuning/CrossValidatorSuite.scala | 54 +++++++- .../ml/tuning/TrainValidationSplitSuite.scala | 48 ++++++- project/MimaExcludes.scala | 6 +- 8 files changed, 388 insertions(+), 29 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/77439804/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 20a1db8..c540629 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -83,7 +83,13 @@ private[shared] object SharedParamsCodeGen { "all instance weights as 1.0"), ParamDesc[String]("solver", "the solver algorithm for optimization", finalFields = false), ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"), - isValid = "ParamValidators.gtEq(2)", isExpertParam = true)) + isValid = "ParamValidators.gtEq(2)", isExpertParam = true), + ParamDesc[Boolean]("collectSubModels", "If set to false, then only the single best " + "sub-model will be available after fitting. If set to true, then all sub-models will be " + "available. Warning: For large models, collecting all sub-models can cause OOMs on the " + "Spark driver.", Some("false"), isExpertParam = true) + ) val code = genSharedParams(params) val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala" http://git-wip-us.apache.org/repos/asf/spark/blob/77439804/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 0d5fb28..34aa38a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -468,4 +468,21 @@ trait HasAggregationDepth extends Params { /** @group expertGetParam */ final def getAggregationDepth: Int = $(aggregationDepth) } + +/** + * Trait for shared param collectSubModels (default: false). + */ +private[ml] trait HasCollectSubModels extends Params { + + /** + * Param for whether to collect a list of sub-models trained during tuning. + * @group expertParam + */ + final val collectSubModels: BooleanParam = new BooleanParam(this, "collectSubModels", "whether to collect a list of sub-models trained during tuning") + + setDefault(collectSubModels, false) + + /** @group expertGetParam */ + final def getCollectSubModels: Boolean = $(collectSubModels) +} // scalastyle:on http://git-wip-us.apache.org/repos/asf/spark/blob/77439804/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 7c81cb9..1682ca9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.tuning -import java.util.{List => JList} +import java.util.{List => JList, Locale} import scala.collection.JavaConverters._ import scala.concurrent.Future @@ -31,7 +31,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} -import org.apache.spark.ml.param.shared.HasParallelism +import org.apache.spark.ml.param.shared.{HasCollectSubModels, HasParallelism} import org.apache.spark.ml.util._ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.{DataFrame, Dataset} @@ -67,7 +67,8 @@ private[ml] trait CrossValidatorParams extends ValidatorParams { @Since("1.2.0") class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) extends Estimator[CrossValidatorModel] - with CrossValidatorParams with HasParallelism with MLWritable with Logging { + with CrossValidatorParams with HasParallelism with HasCollectSubModels + with MLWritable with Logging { @Since("1.2.0") def this() = this(Identifiable.randomUID("cv")) @@ -101,6 +102,21 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) @Since("2.3.0") def setParallelism(value: Int): this.type = set(parallelism, value) + /** + * Whether to collect submodels when fitting. If set, we can get submodels from + * the returned model.
+ * + * Note: If this param is set, then when you save the returned model, you can set the option + * "persistSubModels" to "true" before saving, in order to save these submodels. + * See the documentation of + * {@link org.apache.spark.ml.tuning.CrossValidatorModel.CrossValidatorModelWriter} + * for more information. + * + * @group expertSetParam + */ + @Since("2.3.0") + def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value) + @Since("2.0.0") override def fit(dataset: Dataset[_]): CrossValidatorModel = { val schema = dataset.schema @@ -117,6 +133,12 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) instr.logParams(numFolds, seed, parallelism) logTuningParams(instr) + val collectSubModelsParam = $(collectSubModels) + + var subModels: Option[Array[Array[Model[_]]]] = if (collectSubModelsParam) { + Some(Array.fill($(numFolds))(Array.fill[Model[_]](epm.length)(null))) + } else None + // Compute metrics for each model over each split val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed)) val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) => @@ -125,10 +147,14 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) logDebug(s"Train split $splitIndex with multiple sets of parameters.") // Fit models in a Future for training in parallel - val modelFutures = epm.map { paramMap => + val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => Future[Model[_]] { - val model = est.fit(trainingDataset, paramMap) - model.asInstanceOf[Model[_]] + val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] + + if (collectSubModelsParam) { + subModels.get(splitIndex)(paramIndex) = model + } + model } (executionContext) } @@ -160,7 +186,8 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) logInfo(s"Best cross-validation metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] instr.logSuccess(bestModel) - copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this)) + copyValues(new CrossValidatorModel(uid, bestModel, metrics) + .setSubModels(subModels).setParent(this)) } @Since("1.4.0") @@ -244,6 +271,31 @@ class CrossValidatorModel private[ml] ( this(uid, bestModel, avgMetrics.asScala.toArray) } + private var _subModels: Option[Array[Array[Model[_]]]] = None + + private[tuning] def setSubModels(subModels: Option[Array[Array[Model[_]]]]) + : CrossValidatorModel = { + _subModels = subModels + this + } + + /** + * @return submodels represented in a two-dimensional array. The index of the outer array is + * the fold index, and the index of the inner array corresponds to the ordering of + * estimatorParamMaps + * @throws IllegalArgumentException if subModels are not available. To retrieve subModels, + * make sure to set collectSubModels to true before fitting.
+ */ + @Since("2.3.0") + def subModels: Array[Array[Model[_]]] = { + require(_subModels.isDefined, "subModels not available, To retrieve subModels, make sure " + + "to set collectSubModels to true before fitting.") + _subModels.get + } + + @Since("2.3.0") + def hasSubModels: Boolean = _subModels.isDefined + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) @@ -260,34 +312,76 @@ class CrossValidatorModel private[ml] ( val copied = new CrossValidatorModel( uid, bestModel.copy(extra).asInstanceOf[Model[_]], - avgMetrics.clone()) + avgMetrics.clone() + ).setSubModels(CrossValidatorModel.copySubModels(_subModels)) copyValues(copied, extra).setParent(parent) } @Since("1.6.0") - override def write: MLWriter = new CrossValidatorModel.CrossValidatorModelWriter(this) + override def write: CrossValidatorModel.CrossValidatorModelWriter = { + new CrossValidatorModel.CrossValidatorModelWriter(this) + } } @Since("1.6.0") object CrossValidatorModel extends MLReadable[CrossValidatorModel] { + private[CrossValidatorModel] def copySubModels(subModels: Option[Array[Array[Model[_]]]]) + : Option[Array[Array[Model[_]]]] = { + subModels.map(_.map(_.map(_.copy(ParamMap.empty).asInstanceOf[Model[_]]))) + } + @Since("1.6.0") override def read: MLReader[CrossValidatorModel] = new CrossValidatorModelReader @Since("1.6.0") override def load(path: String): CrossValidatorModel = super.load(path) - private[CrossValidatorModel] - class CrossValidatorModelWriter(instance: CrossValidatorModel) extends MLWriter { + /** + * Writer for CrossValidatorModel. + * @param instance CrossValidatorModel instance used to construct the writer + * + * CrossValidatorModelWriter supports an option "persistSubModels", with possible values + * "true" or "false". If you set the collectSubModels Param before fitting, then you can + * set "persistSubModels" to "true" in order to persist the subModels. By default, + * "persistSubModels" will be "true" when subModels are available and "false" otherwise. + * If subModels are not available, then setting "persistSubModels" to "true" will cause + * an exception. + */ + @Since("2.3.0") + final class CrossValidatorModelWriter private[tuning] ( + instance: CrossValidatorModel) extends MLWriter { ValidatorParams.validateParams(instance) override protected def saveImpl(path: String): Unit = { + val persistSubModelsParam = optionMap.getOrElse("persistsubmodels", + if (instance.hasSubModels) "true" else "false") + + require(Array("true", "false").contains(persistSubModelsParam.toLowerCase(Locale.ROOT)), + s"persistSubModels option value ${persistSubModelsParam} is invalid, the possible " + + "values are \"true\" or \"false\"") + val persistSubModels = persistSubModelsParam.toBoolean + import org.json4s.JsonDSL._ - val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq + val extraMetadata = ("avgMetrics" -> instance.avgMetrics.toSeq) ~ + ("persistSubModels" -> persistSubModels) ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata)) val bestModelPath = new Path(path, "bestModel").toString instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) + if (persistSubModels) { + require(instance.hasSubModels, "When persisting tuning models, you can only set " + + "persistSubModels to true if the tuning was done with collectSubModels set to true. 
" + + "To save the sub-models, try rerunning fitting with collectSubModels set to true.") + val subModelsPath = new Path(path, "subModels") + for (splitIndex <- 0 until instance.getNumFolds) { + val splitPath = new Path(subModelsPath, s"fold${splitIndex.toString}") + for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) { + val modelPath = new Path(splitPath, paramIndex.toString).toString + instance.subModels(splitIndex)(paramIndex).asInstanceOf[MLWritable].save(modelPath) + } + } + } } } @@ -301,11 +395,30 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { val (metadata, estimator, evaluator, estimatorParamMaps) = ValidatorParams.loadImpl(path, sc, className) + val numFolds = (metadata.params \ "numFolds").extract[Int] val bestModelPath = new Path(path, "bestModel").toString val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray + val persistSubModels = (metadata.metadata \ "persistSubModels") + .extractOrElse[Boolean](false) + + val subModels: Option[Array[Array[Model[_]]]] = if (persistSubModels) { + val subModelsPath = new Path(path, "subModels") + val _subModels = Array.fill(numFolds)(Array.fill[Model[_]]( + estimatorParamMaps.length)(null)) + for (splitIndex <- 0 until numFolds) { + val splitPath = new Path(subModelsPath, s"fold${splitIndex.toString}") + for (paramIndex <- 0 until estimatorParamMaps.length) { + val modelPath = new Path(splitPath, paramIndex.toString).toString + _subModels(splitIndex)(paramIndex) = + DefaultParamsReader.loadParamsInstance(modelPath, sc) + } + } + Some(_subModels) + } else None val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics) + .setSubModels(subModels) model.set(model.estimator, estimator) .set(model.evaluator, evaluator) .set(model.estimatorParamMaps, estimatorParamMaps) http://git-wip-us.apache.org/repos/asf/spark/blob/77439804/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 6e3ad40..c73bd18 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.tuning -import java.io.IOException -import java.util.{List => JList} +import java.util.{List => JList, Locale} import scala.collection.JavaConverters._ import scala.concurrent.Future @@ -33,7 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} -import org.apache.spark.ml.param.shared.HasParallelism +import org.apache.spark.ml.param.shared.{HasCollectSubModels, HasParallelism} import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType @@ -67,7 +66,8 @@ private[ml] trait TrainValidationSplitParams extends ValidatorParams { @Since("1.5.0") class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: String) extends Estimator[TrainValidationSplitModel] - with TrainValidationSplitParams with HasParallelism with MLWritable with Logging { + with TrainValidationSplitParams with 
HasParallelism with HasCollectSubModels + with MLWritable with Logging { @Since("1.5.0") def this() = this(Identifiable.randomUID("tvs")) @@ -101,6 +101,20 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St @Since("2.3.0") def setParallelism(value: Int): this.type = set(parallelism, value) + /** + * Whether to collect submodels when fitting. If set, we can get submodels from + * the returned model. + * + * Note: If this param is set, then when you save the returned model, you can set the option + * "persistSubModels" to "true" before saving, in order to save these submodels. + * See the documentation of + * {@link org.apache.spark.ml.tuning.TrainValidationSplitModel.TrainValidationSplitModelWriter} + * for more information. + * + * @group expertSetParam + */ + @Since("2.3.0") + def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value) + @Since("2.0.0") override def fit(dataset: Dataset[_]): TrainValidationSplitModel = { val schema = dataset.schema @@ -121,12 +135,22 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St trainingDataset.cache() validationDataset.cache() + val collectSubModelsParam = $(collectSubModels) + + var subModels: Option[Array[Model[_]]] = if (collectSubModelsParam) { + Some(Array.fill[Model[_]](epm.length)(null)) + } else None + // Fit models in a Future for training in parallel logDebug(s"Train split with multiple sets of parameters.") - val modelFutures = epm.map { paramMap => + val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => Future[Model[_]] { - val model = est.fit(trainingDataset, paramMap) - model.asInstanceOf[Model[_]] + val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] + + if (collectSubModelsParam) { + subModels.get(paramIndex) = model + } + model } (executionContext) } @@ -158,7 +182,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St logInfo(s"Best train validation split metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] instr.logSuccess(bestModel) - copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this)) + copyValues(new TrainValidationSplitModel(uid, bestModel, metrics) + .setSubModels(subModels).setParent(this)) } @Since("1.5.0") @@ -238,6 +263,30 @@ class TrainValidationSplitModel private[ml] ( this(uid, bestModel, validationMetrics.asScala.toArray) } + private var _subModels: Option[Array[Model[_]]] = None + + private[tuning] def setSubModels(subModels: Option[Array[Model[_]]]) + : TrainValidationSplitModel = { + _subModels = subModels + this + } + + /** + * @return submodels represented in an array. The index of the array corresponds to the + * ordering of estimatorParamMaps + * @throws IllegalArgumentException if subModels are not available. To retrieve subModels, + * make sure to set collectSubModels to true before fitting.
+ */ + @Since("2.3.0") + def subModels: Array[Model[_]] = { + require(_subModels.isDefined, "subModels not available, To retrieve subModels, make sure " + + "to set collectSubModels to true before fitting.") + _subModels.get + } + + @Since("2.3.0") + def hasSubModels: Boolean = _subModels.isDefined + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) @@ -254,34 +303,73 @@ class TrainValidationSplitModel private[ml] ( val copied = new TrainValidationSplitModel ( uid, bestModel.copy(extra).asInstanceOf[Model[_]], - validationMetrics.clone()) + validationMetrics.clone() + ).setSubModels(TrainValidationSplitModel.copySubModels(_subModels)) copyValues(copied, extra).setParent(parent) } @Since("2.0.0") - override def write: MLWriter = new TrainValidationSplitModel.TrainValidationSplitModelWriter(this) + override def write: TrainValidationSplitModel.TrainValidationSplitModelWriter = { + new TrainValidationSplitModel.TrainValidationSplitModelWriter(this) + } } @Since("2.0.0") object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { + private[TrainValidationSplitModel] def copySubModels(subModels: Option[Array[Model[_]]]) + : Option[Array[Model[_]]] = { + subModels.map(_.map(_.copy(ParamMap.empty).asInstanceOf[Model[_]])) + } + @Since("2.0.0") override def read: MLReader[TrainValidationSplitModel] = new TrainValidationSplitModelReader @Since("2.0.0") override def load(path: String): TrainValidationSplitModel = super.load(path) - private[TrainValidationSplitModel] - class TrainValidationSplitModelWriter(instance: TrainValidationSplitModel) extends MLWriter { + /** + * Writer for TrainValidationSplitModel. + * @param instance TrainValidationSplitModel instance used to construct the writer + * + * TrainValidationSplitModel supports an option "persistSubModels", with possible values + * "true" or "false". If you set the collectSubModels Param before fitting, then you can + * set "persistSubModels" to "true" in order to persist the subModels. By default, + * "persistSubModels" will be "true" when subModels are available and "false" otherwise. + * If subModels are not available, then setting "persistSubModels" to "true" will cause + * an exception. + */ + @Since("2.3.0") + final class TrainValidationSplitModelWriter private[tuning] ( + instance: TrainValidationSplitModel) extends MLWriter { ValidatorParams.validateParams(instance) override protected def saveImpl(path: String): Unit = { + val persistSubModelsParam = optionMap.getOrElse("persistsubmodels", + if (instance.hasSubModels) "true" else "false") + + require(Array("true", "false").contains(persistSubModelsParam.toLowerCase(Locale.ROOT)), + s"persistSubModels option value ${persistSubModelsParam} is invalid, the possible " + + "values are \"true\" or \"false\"") + val persistSubModels = persistSubModelsParam.toBoolean + import org.json4s.JsonDSL._ - val extraMetadata = "validationMetrics" -> instance.validationMetrics.toSeq + val extraMetadata = ("validationMetrics" -> instance.validationMetrics.toSeq) ~ + ("persistSubModels" -> persistSubModels) ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata)) val bestModelPath = new Path(path, "bestModel").toString instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) + if (persistSubModels) { + require(instance.hasSubModels, "When persisting tuning models, you can only set " + + "persistSubModels to true if the tuning was done with collectSubModels set to true. 
" + + "To save the sub-models, try rerunning fitting with collectSubModels set to true.") + val subModelsPath = new Path(path, "subModels") + for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) { + val modelPath = new Path(subModelsPath, paramIndex.toString).toString + instance.subModels(paramIndex).asInstanceOf[MLWritable].save(modelPath) + } + } } } @@ -298,8 +386,22 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { val bestModelPath = new Path(path, "bestModel").toString val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray + val persistSubModels = (metadata.metadata \ "persistSubModels") + .extractOrElse[Boolean](false) + + val subModels: Option[Array[Model[_]]] = if (persistSubModels) { + val subModelsPath = new Path(path, "subModels") + val _subModels = Array.fill[Model[_]](estimatorParamMaps.length)(null) + for (paramIndex <- 0 until estimatorParamMaps.length) { + val modelPath = new Path(subModelsPath, paramIndex.toString).toString + _subModels(paramIndex) = + DefaultParamsReader.loadParamsInstance(modelPath, sc) + } + Some(_subModels) + } else None val model = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics) + .setSubModels(subModels) model.set(model.estimator, estimator) .set(model.evaluator, evaluator) .set(model.estimatorParamMaps, estimatorParamMaps) http://git-wip-us.apache.org/repos/asf/spark/blob/77439804/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 7188da3..a616907 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -18,6 +18,9 @@ package org.apache.spark.ml.util import java.io.IOException +import java.util.Locale + +import scala.collection.mutable import org.apache.hadoop.fs.Path import org.json4s._ @@ -108,6 +111,22 @@ abstract class MLWriter extends BaseReadWrite with Logging { protected def saveImpl(path: String): Unit /** + * Map to store extra options for this writer. + */ + protected val optionMap: mutable.Map[String, String] = new mutable.HashMap[String, String]() + + /** + * Adds an option to the underlying MLWriter. See the documentation for the specific model's + * writer for possible options. The option name (key) is case-insensitive. + */ + @Since("2.3.0") + def option(key: String, value: String): this.type = { + require(key != null && !key.isEmpty) + optionMap.put(key.toLowerCase(Locale.ROOT), value) + this + } + + /** * Overwrites if the output path already exists. 
*/ @Since("1.6.0") http://git-wip-us.apache.org/repos/asf/spark/blob/77439804/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 853eeb3..15dade2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.tuning +import java.io.File + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.{Estimator, Model, Pipeline} import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, OneVsRest} @@ -27,7 +29,7 @@ import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType @@ -161,6 +163,7 @@ class CrossValidatorSuite .setEstimatorParamMaps(paramMaps) .setSeed(42L) .setParallelism(2) + .setCollectSubModels(true) val cv2 = testDefaultReadWrite(cv, testParams = false) @@ -168,6 +171,7 @@ class CrossValidatorSuite assert(cv.getNumFolds === cv2.getNumFolds) assert(cv.getSeed === cv2.getSeed) assert(cv.getParallelism === cv2.getParallelism) + assert(cv.getCollectSubModels === cv2.getCollectSubModels) assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator] @@ -187,6 +191,54 @@ class CrossValidatorSuite .compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) } + test("CrossValidator expose sub models") { + val lr = new LogisticRegression + val lrParamMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.001, 1000.0)) + .addGrid(lr.maxIter, Array(0, 3)) + .build() + val eval = new BinaryClassificationEvaluator + val numFolds = 3 + val subPath = new File(tempDir, "testCrossValidatorSubModels") + + val cv = new CrossValidator() + .setEstimator(lr) + .setEstimatorParamMaps(lrParamMaps) + .setEvaluator(eval) + .setNumFolds(numFolds) + .setParallelism(1) + .setCollectSubModels(true) + + val cvModel = cv.fit(dataset) + + assert(cvModel.hasSubModels && cvModel.subModels.length == numFolds) + cvModel.subModels.foreach(array => assert(array.length == lrParamMaps.length)) + + // Test the default value for option "persistSubModel" to be "true" + val savingPathWithSubModels = new File(subPath, "cvModel3").getPath + cvModel.save(savingPathWithSubModels) + val cvModel3 = CrossValidatorModel.load(savingPathWithSubModels) + assert(cvModel3.hasSubModels && cvModel3.subModels.length == numFolds) + cvModel3.subModels.foreach(array => assert(array.length == lrParamMaps.length)) + + val savingPathWithoutSubModels = new File(subPath, "cvModel2").getPath + cvModel.write.option("persistSubModels", "false").save(savingPathWithoutSubModels) + val cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels) + assert(!cvModel2.hasSubModels) + + for (i <- 0 until numFolds) { + for (j <- 0 until lrParamMaps.length) { + 
assert(cvModel.subModels(i)(j).asInstanceOf[LogisticRegressionModel].uid === + cvModel3.subModels(i)(j).asInstanceOf[LogisticRegressionModel].uid) + } + } + + val savingPathTestingIllegalParam = new File(subPath, "cvModel4").getPath + intercept[IllegalArgumentException] { + cvModel2.write.option("persistSubModels", "true").save(savingPathTestingIllegalParam) + } + } + test("read/write: CrossValidator with nested estimator") { val ova = new OneVsRest().setClassifier(new LogisticRegression) val evaluator = new MulticlassClassificationEvaluator() http://git-wip-us.apache.org/repos/asf/spark/blob/77439804/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index f8d9c66..9024342 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.tuning +import java.io.File + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, OneVsRest} @@ -26,7 +28,7 @@ import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType @@ -161,12 +163,14 @@ class TrainValidationSplitSuite .setEstimatorParamMaps(paramMaps) .setSeed(42L) .setParallelism(2) + .setCollectSubModels(true) val tvs2 = testDefaultReadWrite(tvs, testParams = false) assert(tvs.getTrainRatio === tvs2.getTrainRatio) assert(tvs.getSeed === tvs2.getSeed) assert(tvs.getParallelism === tvs2.getParallelism) + assert(tvs.getCollectSubModels === tvs2.getCollectSubModels) ValidatorParamsSuiteHelpers .compareParamMaps(tvs.getEstimatorParamMaps, tvs2.getEstimatorParamMaps) @@ -181,6 +185,48 @@ class TrainValidationSplitSuite } } + test("TrainValidationSplit expose sub models") { + val lr = new LogisticRegression + val lrParamMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.001, 1000.0)) + .addGrid(lr.maxIter, Array(0, 3)) + .build() + val eval = new BinaryClassificationEvaluator + val subPath = new File(tempDir, "testTrainValidationSplitSubModels") + + val tvs = new TrainValidationSplit() + .setEstimator(lr) + .setEstimatorParamMaps(lrParamMaps) + .setEvaluator(eval) + .setParallelism(1) + .setCollectSubModels(true) + + val tvsModel = tvs.fit(dataset) + + assert(tvsModel.hasSubModels && tvsModel.subModels.length == lrParamMaps.length) + + // Test that the default value for option "persistSubModels" is "true" + val savingPathWithSubModels = new File(subPath, "tvsModel3").getPath + tvsModel.save(savingPathWithSubModels) + val tvsModel3 = TrainValidationSplitModel.load(savingPathWithSubModels) + assert(tvsModel3.hasSubModels && tvsModel3.subModels.length == lrParamMaps.length) + + val savingPathWithoutSubModels = new File(subPath, "tvsModel2").getPath +
tvsModel.write.option("persistSubModels", "false").save(savingPathWithoutSubModels) + val tvsModel2 = TrainValidationSplitModel.load(savingPathWithoutSubModels) + assert(!tvsModel2.hasSubModels) + + for (i <- 0 until lrParamMaps.length) { + assert(tvsModel.subModels(i).asInstanceOf[LogisticRegressionModel].uid === + tvsModel3.subModels(i).asInstanceOf[LogisticRegressionModel].uid) + } + + val savingPathTestingIllegalParam = new File(subPath, "tvsModel4").getPath + intercept[IllegalArgumentException] { + tvsModel2.write.option("persistSubModels", "true").save(savingPathTestingIllegalParam) + } + } + test("read/write: TrainValidationSplit with nested estimator") { val ova = new OneVsRest() .setClassifier(new LogisticRegression) http://git-wip-us.apache.org/repos/asf/spark/blob/77439804/project/MimaExcludes.scala ---------------------------------------------------------------------- diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 7f18b40..915c7e2 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -78,7 +78,11 @@ object MimaExcludes { // [SPARK-14280] Support Scala 2.12 ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.transformWith"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.transform") + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.transform"), + + // [SPARK-21087] CrossValidator, TrainValidationSplit expose sub models after fitting: Scala + ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.CrossValidatorModel$CrossValidatorModelWriter"), + ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel$TrainValidationSplitModelWriter") ) // Exclude rules for 2.2.x
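For completeness, a short persistence sketch using the writer option added by this patch, continuing the fit example near the top of this message (the save paths are illustrative assumptions):

import org.apache.spark.ml.tuning.CrossValidatorModel

// Persist without sub-models: the writer option key is case-insensitive, and
// the option defaults to "true" whenever sub-models were collected.
cvModel.write.option("persistSubModels", "false").save("/tmp/cvModelNoSubModels")

// Persist with sub-models (the default here, since they were collected), then reload;
// persisted sub-models are loaded back along with the best model.
cvModel.save("/tmp/cvModelWithSubModels")
val reloaded = CrossValidatorModel.load("/tmp/cvModelWithSubModels")
assert(reloaded.hasSubModels)

// Setting persistSubModels to "true" on a model without collected sub-models
// throws IllegalArgumentException, as exercised in the test suites above.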