This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 6a668cde87c3 [SPARK-51150][ML] Explicitly pass the session in meta algorithm writers 6a668cde87c3 is described below commit 6a668cde87c397871718fa38abb182e23af5dd49 Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Tue Feb 11 12:11:16 2025 +0800 [SPARK-51150][ML] Explicitly pass the session in meta algorithm writers ### What changes were proposed in this pull request? Explicitly pass the session to avoid recreating it. ### Why are the changes needed? The overhead of getting/creating a session is non-trivial; in the writer of a meta algorithm, we should explicitly pass the session to the writers of the submodels. ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? existing tests should cover this change ### Was this patch authored or co-authored using generative AI tooling? no Closes #49871 from zhengruifeng/ml_pass_session. Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala | 2 +- .../main/scala/org/apache/spark/ml/classification/OneVsRest.scala | 2 +- mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala | 5 +++-- .../main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala | 5 +++-- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 807648545fc6..518e00af306c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -266,7 +266,7 @@ object Pipeline extends MLReadable[Pipeline] { // Save stages val stagesDir = new Path(path, "stages").toString stages.zipWithIndex.foreach { case (stage, idx) => - val writer = stage.asInstanceOf[MLWritable].write + val writer = stage.asInstanceOf[MLWritable].write.session(spark) val stagePath = 
getStagePath(stage.uid, idx, stages.length, stagesDir) instr.withSaveInstanceEvent(writer, stagePath)(writer.save(stagePath)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 0f7b6485c770..fa543cc0dd82 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -284,7 +284,7 @@ object OneVsRestModel extends MLReadable[OneVsRestModel] { OneVsRestParams.saveImpl(path, instance, sparkSession, Some(extraJson)) instance.models.map(_.asInstanceOf[MLWritable]).zipWithIndex.foreach { case (model, idx) => val modelPath = new Path(path, s"model_$idx").toString - model.save(modelPath) + model.write.session(sparkSession).save(modelPath) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 5953afb7ba78..d023c8990e76 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -405,7 +405,7 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { ("persistSubModels" -> persistSubModels) ValidatorParams.saveImpl(path, instance, sparkSession, Some(extraMetadata)) val bestModelPath = new Path(path, "bestModel").toString - instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) + instance.bestModel.asInstanceOf[MLWritable].write.session(sparkSession).save(bestModelPath) if (persistSubModels) { require(instance.hasSubModels, "When persisting tuning models, you can only set " + "persistSubModels to true if the tuning was done with collectSubModels set to true. 
" + @@ -415,7 +415,8 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { val splitPath = new Path(subModelsPath, s"fold${splitIndex.toString}") for (paramIndex <- instance.getEstimatorParamMaps.indices) { val modelPath = new Path(splitPath, paramIndex.toString).toString - instance.subModels(splitIndex)(paramIndex).asInstanceOf[MLWritable].save(modelPath) + instance.subModels(splitIndex)(paramIndex).asInstanceOf[MLWritable] + .write.session(sparkSession).save(modelPath) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index baf14f11c424..ebfcac2e4952 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -370,7 +370,7 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { ("persistSubModels" -> persistSubModels) ValidatorParams.saveImpl(path, instance, sparkSession, Some(extraMetadata)) val bestModelPath = new Path(path, "bestModel").toString - instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) + instance.bestModel.asInstanceOf[MLWritable].write.session(sparkSession).save(bestModelPath) if (persistSubModels) { require(instance.hasSubModels, "When persisting tuning models, you can only set " + "persistSubModels to true if the tuning was done with collectSubModels set to true. 
" + @@ -378,7 +378,8 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { val subModelsPath = new Path(path, "subModels") for (paramIndex <- instance.getEstimatorParamMaps.indices) { val modelPath = new Path(subModelsPath, paramIndex.toString).toString - instance.subModels(paramIndex).asInstanceOf[MLWritable].save(modelPath) + instance.subModels(paramIndex).asInstanceOf[MLWritable] + .write.session(sparkSession).save(modelPath) } } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org