This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 099a59c9f3cf [SPARK-51080][ML][PYTHON][CONNECT] Fix save/load for `PowerIterationClustering`
099a59c9f3cf is described below

commit 099a59c9f3cfebabe43c1deddb824805bb67c79e
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Sat Feb 8 08:43:27 2025 +0800

    [SPARK-51080][ML][PYTHON][CONNECT] Fix save/load for `PowerIterationClustering`
    
    ### What changes were proposed in this pull request?
    Fix save/load for `PowerIterationClustering`
    
    ### Why are the changes needed?
    For feature parity: `PowerIterationClustering` save/load previously did not work under Spark Connect.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, `PowerIterationClustering.save`/`.load` now work under Spark Connect.
    
    ### How was this patch tested?
    New test: a save/load round-trip added to `test_power_iteration_clustering`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #49849 from zhengruifeng/ml_connect_pic_save_load.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../services/org.apache.spark.ml.Transformer       |  1 +
 .../ml/clustering/PowerIterationClustering.scala   | 30 ++++++++++++++++
 python/pyspark/ml/connect/readwrite.py             | 40 ++++++++++++++++++++++
 python/pyspark/ml/tests/test_clustering.py         | 11 +++++-
 4 files changed, 81 insertions(+), 1 deletion(-)
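
For illustration, a minimal PySpark sketch of the round-trip this commit
enables (a sketch, not part of the patch: it assumes an active Spark Connect
session, and the param values here are arbitrary):

    import tempfile
    from pyspark.ml.clustering import PowerIterationClustering

    pic = PowerIterationClustering(k=2, maxIter=40, weightCol="weight")
    with tempfile.TemporaryDirectory(prefix="pic_demo") as d:
        # Under Connect, save() now routes through the new Scala-side
        # PowerIterationClusteringWrapper instead of failing.
        pic.write().overwrite().save(d)
        pic2 = PowerIterationClustering.load(d)
        assert pic.uid == pic2.uid and pic.getK() == pic2.getK()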

diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
index fc6a8166442a..4c0a68753b55 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
@@ -65,6 +65,7 @@ org.apache.spark.ml.clustering.BisectingKMeansModel
 org.apache.spark.ml.clustering.GaussianMixtureModel
 org.apache.spark.ml.clustering.DistributedLDAModel
 org.apache.spark.ml.clustering.LocalLDAModel
+org.apache.spark.ml.clustering.PowerIterationClusteringWrapper
 
 # recommendation
 org.apache.spark.ml.recommendation.ALSModel
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
index 8b2ee955d6a5..d14fc7fcc875 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.ml.clustering
 
 import org.apache.spark.annotation.Since
+import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
@@ -191,3 +192,32 @@ object PowerIterationClustering extends DefaultParamsReadable[PowerIterationClus
   @Since("2.4.0")
   override def load(path: String): PowerIterationClustering = super.load(path)
 }
+
+private[spark] class PowerIterationClusteringWrapper(override val uid: String)
+  extends Transformer with PowerIterationClusteringParams with DefaultParamsWritable {
+
+  def this() = this(Identifiable.randomUID("PowerIterationClusteringWrapper"))
+
+  override def transform(dataset: Dataset[_]): DataFrame =
+    throw new UnsupportedOperationException("transform not supported")
+
+  override def transformSchema(schema: StructType): StructType =
+    throw new UnsupportedOperationException("transformSchema not supported")
+
+  override def copy(extra: ParamMap): PowerIterationClusteringWrapper = defaultCopy(extra)
+
+  override def write: MLWriter = new MLWriter {
+    override protected def saveImpl(path: String): Unit = {
+      new PowerIterationClustering(uid).copy(paramMap).save(path)
+    }
+  }
+}
+
+private[spark] object PowerIterationClusteringWrapper
+  extends DefaultParamsReadable[PowerIterationClusteringWrapper] {
+
+  override def load(path: String): PowerIterationClusteringWrapper = {
+    val pic = PowerIterationClustering.load(path)
+    new PowerIterationClusteringWrapper(pic.uid).copy(pic.paramMap)
+  }
+}
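
Note on the wrapper above: the Scala PowerIterationClustering is neither an
Estimator nor a Transformer (it exposes assignClusters rather than transform),
while the Connect server discovers ML operators through the
org.apache.spark.ml.Transformer service file extended in this patch. The
wrapper subclasses Transformer purely for that registration: transform and
transformSchema deliberately throw, and persistence is delegated to the real
class. One inference from saveImpl (not something this patch asserts): since
the wrapper writes via PowerIterationClustering.save itself, the on-disk
layout is the plain DefaultParams format, so a sketch like the following
should load it back regardless of which side wrote it:

    # Hypothetical path; loadable because saveImpl above delegates to
    # PowerIterationClustering.save, i.e. the standard on-disk format.
    pic2 = PowerIterationClustering.load("/tmp/pic_demo_path")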
diff --git a/python/pyspark/ml/connect/readwrite.py b/python/pyspark/ml/connect/readwrite.py
index 5c2b850d51b3..584ff3237a0a 100644
--- a/python/pyspark/ml/connect/readwrite.py
+++ b/python/pyspark/ml/connect/readwrite.py
@@ -96,6 +96,7 @@ class RemoteMLWriter(MLWriter):
         from pyspark.ml.pipeline import Pipeline, PipelineModel
         from pyspark.ml.tuning import CrossValidator, TrainValidationSplit
         from pyspark.ml.classification import OneVsRest, OneVsRestModel
+        from pyspark.ml.clustering import PowerIterationClustering
 
         # Spark Connect ML is built on scala Spark.ML, that means we're only
         # supporting JavaModel or JavaEstimator or JavaEvaluator
@@ -188,6 +189,21 @@ class RemoteMLWriter(MLWriter):
             ovrm_writer = OneVsRestModelWriter(instance)
             ovrm_writer.session(session)  # type: ignore[arg-type]
             ovrm_writer.save(path)
+
+        elif isinstance(instance, PowerIterationClustering):
+            transformer = JavaTransformer(
+                "org.apache.spark.ml.clustering.PowerIterationClusteringWrapper"
+            )
+            transformer._resetUid(instance.uid)
+            transformer._paramMap = instance._paramMap
+            RemoteMLWriter.saveInstance(
+                transformer,  # type: ignore[arg-type]
+                path,
+                session,
+                shouldOverwrite,
+                optionMap,
+            )
+
         else:
            raise NotImplementedError(f"Unsupported write for {instance.__class__}")
 
@@ -224,6 +240,7 @@ class RemoteMLReader(MLReader[RL]):
         from pyspark.ml.pipeline import Pipeline, PipelineModel
         from pyspark.ml.tuning import CrossValidator, TrainValidationSplit
         from pyspark.ml.classification import OneVsRest, OneVsRestModel
+        from pyspark.ml.clustering import PowerIterationClustering
 
         if (
             issubclass(clazz, JavaModel)
@@ -332,5 +349,28 @@ class RemoteMLReader(MLReader[RL]):
             ovrm_reader.session(session)
             return ovrm_reader.load(path)
 
+        elif issubclass(clazz, PowerIterationClustering):
+            java_qualified_class_name = (
+                "org.apache.spark.ml.clustering.PowerIterationClusteringWrapper"
+            )
+
+            command = pb2.Command()
+            command.ml_command.read.CopyFrom(
+                pb2.MlCommand.Read(
+                    operator=pb2.MlOperator(
+                        name=java_qualified_class_name, type=pb2.MlOperator.TRANSFORMER
+                    ),
+                    path=path,
+                )
+            )
+            (_, properties, _) = session.client.execute_command(command)
+            result = deserialize(properties)
+
+            instance = PowerIterationClustering()
+            instance._resetUid(result.uid)
+            params = {k: deserialize_param(v) for k, v in result.params.params.items()}
+            instance._set(**params)
+            return instance  # type: ignore[return-value]
+
         else:
             raise RuntimeError(f"Unsupported read for {clazz}")
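
For reference, the reconstruction at the end of the reader above boils down to
the following (a sketch: the uid string is hypothetical, and a plain dict of
kwargs stands in for the deserialized proto params):

    from pyspark.ml.clustering import PowerIterationClustering

    instance = PowerIterationClustering()
    instance._resetUid("PowerIterationClustering_abc123")  # result.uid in the diff
    instance._set(k=2, maxIter=40)  # mirrors instance._set(**params)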
diff --git a/python/pyspark/ml/tests/test_clustering.py b/python/pyspark/ml/tests/test_clustering.py
index 136a166d2218..8ec1fcc48ca6 100644
--- a/python/pyspark/ml/tests/test_clustering.py
+++ b/python/pyspark/ml/tests/test_clustering.py
@@ -462,7 +462,6 @@ class ClusteringTestsMixin:
             model2 = DistributedLDAModel.load(d)
             self.assertEqual(str(model), str(model2))
 
-    # TODO(SPARK-51080): Fix save/load for PowerIterationClustering
     def test_power_iteration_clustering(self):
         spark = self.spark
 
@@ -496,6 +495,16 @@ class ClusteringTestsMixin:
         self.assertEqual(assignments.columns, ["id", "cluster"])
         self.assertEqual(assignments.count(), 6)
 
+        # save & load
+        with tempfile.TemporaryDirectory(prefix="power_iteration_clustering") as d:
+            pic.write().overwrite().save(d)
+            pic2 = PowerIterationClustering.load(d)
+            self.assertEqual(str(pic), str(pic2))
+            self.assertEqual(pic.uid, pic2.uid)
+            self.assertEqual(pic.getK(), pic2.getK())
+            self.assertEqual(pic.getMaxIter(), pic2.getMaxIter())
+            self.assertEqual(pic.getWeightCol(), pic2.getWeightCol())
+
 
 class ClusteringTests(ClusteringTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
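
Usage note (not part of the patch): the updated suite should be runnable
locally through Spark's standard test runner, e.g.

    python/run-tests --testnames 'pyspark.ml.tests.test_clustering'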

