Github user WeichenXu123 commented on a diff in the pull request: https://github.com/apache/spark/pull/19208#discussion_r141357565 --- Diff: mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala --- @@ -276,12 +315,32 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { ValidatorParams.validateParams(instance) + protected var shouldPersistSubModels: Boolean = false + + /** + * Set option for persist sub models. + */ + @Since("2.3.0") + def persistSubModels(persist: Boolean): this.type = { + shouldPersistSubModels = persist + this + } + override protected def saveImpl(path: String): Unit = { import org.json4s.JsonDSL._ - val extraMetadata = "validationMetrics" -> instance.validationMetrics.toSeq + val extraMetadata = ("validationMetrics" -> instance.validationMetrics.toSeq) ~ + ("shouldPersistSubModels" -> shouldPersistSubModels) ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata)) val bestModelPath = new Path(path, "bestModel").toString instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) + if (shouldPersistSubModels) { + require(instance.subModels != null, "Cannot get sub models to persist.") + val subModelsPath = new Path(path, "subModels") + for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) { + val modelPath = new Path(subModelsPath, paramIndex.toString).toString + instance.subModels(paramIndex).asInstanceOf[MLWritable].save(modelPath) --- End diff -- Ah, that's a good point. But currently the model-saving code does not have any exception handling. E.g., when saving with overwrite, if the save fails, it does not restore the old directory. I think these things can be done in separate PRs. cc @jkbradley What's your opinion?
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org