This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6a668cde87c3 [SPARK-51150][ML] Explicitly pass the session in meta algorithm writers
6a668cde87c3 is described below

commit 6a668cde87c397871718fa38abb182e23af5dd49
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Tue Feb 11 12:11:16 2025 +0800

    [SPARK-51150][ML] Explicitly pass the session in meta algorithm writers
    
    ### What changes were proposed in this pull request?
    Explicitly pass the session to avoid recreating it
    
    ### Why are the changes needed?
    The overhead of getting or creating a session is non-trivial; in the writer of a meta algorithm, we should explicitly pass the session to the writers of the sub-models.
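
    For illustration only (not part of this patch): a minimal, hypothetical sketch of the
    pattern, in which a meta-algorithm writer forwards its own session to each sub-model's
    writer instead of letting every save() look a session up (or create one) again. The
    MetaModelWriter class and its subModels parameter are made up for this sketch; only
    MLWriter, MLWritable, write.session(...) and save(...) come from the actual API.

        import org.apache.hadoop.fs.Path
        import org.apache.spark.ml.util.{MLWritable, MLWriter}

        // Hypothetical writer owning a set of sub-models (illustration only).
        class MetaModelWriter(subModels: Array[MLWritable]) extends MLWriter {
          override protected def saveImpl(path: String): Unit = {
            subModels.zipWithIndex.foreach { case (model, idx) =>
              val modelPath = new Path(path, s"model_$idx").toString
              // Reuse this writer's SparkSession instead of letting each
              // sub-model's save() resolve a session again.
              model.write.session(sparkSession).save(modelPath)
            }
          }
        }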
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    existing tests should cover this change
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #49871 from zhengruifeng/ml_pass_session.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala              | 2 +-
 .../main/scala/org/apache/spark/ml/classification/OneVsRest.scala    | 2 +-
 mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala | 5 +++--
 .../main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala | 5 +++--
 4 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 807648545fc6..518e00af306c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -266,7 +266,7 @@ object Pipeline extends MLReadable[Pipeline] {
       // Save stages
       val stagesDir = new Path(path, "stages").toString
       stages.zipWithIndex.foreach { case (stage, idx) =>
-        val writer = stage.asInstanceOf[MLWritable].write
+        val writer = stage.asInstanceOf[MLWritable].write.session(spark)
         val stagePath = getStagePath(stage.uid, idx, stages.length, stagesDir)
         instr.withSaveInstanceEvent(writer, stagePath)(writer.save(stagePath))
       }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 0f7b6485c770..fa543cc0dd82 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -284,7 +284,7 @@ object OneVsRestModel extends MLReadable[OneVsRestModel] {
       OneVsRestParams.saveImpl(path, instance, sparkSession, Some(extraJson))
       instance.models.map(_.asInstanceOf[MLWritable]).zipWithIndex.foreach { case (model, idx) =>
         val modelPath = new Path(path, s"model_$idx").toString
-        model.save(modelPath)
+        model.write.session(sparkSession).save(modelPath)
       }
     }
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 5953afb7ba78..d023c8990e76 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -405,7 +405,7 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
         ("persistSubModels" -> persistSubModels)
       ValidatorParams.saveImpl(path, instance, sparkSession, Some(extraMetadata))
       val bestModelPath = new Path(path, "bestModel").toString
-      instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
+      instance.bestModel.asInstanceOf[MLWritable].write.session(sparkSession).save(bestModelPath)
       if (persistSubModels) {
         require(instance.hasSubModels, "When persisting tuning models, you can only set " +
           "persistSubModels to true if the tuning was done with collectSubModels set to true. " +
@@ -415,7 +415,8 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
           val splitPath = new Path(subModelsPath, s"fold${splitIndex.toString}")
           for (paramIndex <- instance.getEstimatorParamMaps.indices) {
             val modelPath = new Path(splitPath, paramIndex.toString).toString
-            instance.subModels(splitIndex)(paramIndex).asInstanceOf[MLWritable].save(modelPath)
+            instance.subModels(splitIndex)(paramIndex).asInstanceOf[MLWritable]
+              .write.session(sparkSession).save(modelPath)
           }
         }
       }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index baf14f11c424..ebfcac2e4952 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -370,7 +370,7 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
         ("persistSubModels" -> persistSubModels)
       ValidatorParams.saveImpl(path, instance, sparkSession, Some(extraMetadata))
       val bestModelPath = new Path(path, "bestModel").toString
-      instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
+      instance.bestModel.asInstanceOf[MLWritable].write.session(sparkSession).save(bestModelPath)
       if (persistSubModels) {
         require(instance.hasSubModels, "When persisting tuning models, you can only set " +
           "persistSubModels to true if the tuning was done with collectSubModels set to true. " +
@@ -378,7 +378,8 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
         val subModelsPath = new Path(path, "subModels")
         for (paramIndex <- instance.getEstimatorParamMaps.indices) {
           val modelPath = new Path(subModelsPath, paramIndex.toString).toString
-          instance.subModels(paramIndex).asInstanceOf[MLWritable].save(modelPath)
+          instance.subModels(paramIndex).asInstanceOf[MLWritable]
+            .write.session(sparkSession).save(modelPath)
         }
       }
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
