This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 099a59c9f3cf [SPARK-51080][ML][PYTHON][CONNECT] Fix save/load for `PowerIterationClustering`
099a59c9f3cf is described below

commit 099a59c9f3cfebabe43c1deddb824805bb67c79e
Author:     Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Sat Feb 8 08:43:27 2025 +0800

    [SPARK-51080][ML][PYTHON][CONNECT] Fix save/load for `PowerIterationClustering`

    ### What changes were proposed in this pull request?
    Fix save/load for `PowerIterationClustering`.

    ### Why are the changes needed?
    For feature parity between Spark Connect and classic Spark ML.

    ### Does this PR introduce _any_ user-facing change?
    Yes, `PowerIterationClustering` can now be saved and loaded on Spark Connect.

    ### How was this patch tested?
    A new unit test.

    ### Was this patch authored or co-authored using generative AI tooling?
    No.

    Closes #49849 from zhengruifeng/ml_connect_pic_save_load.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../services/org.apache.spark.ml.Transformer       |  1 +
 .../ml/clustering/PowerIterationClustering.scala   | 30 ++++++++++++++++
 python/pyspark/ml/connect/readwrite.py             | 40 ++++++++++++++++++++++
 python/pyspark/ml/tests/test_clustering.py         | 11 +++++-
 4 files changed, 81 insertions(+), 1 deletion(-)

diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
index fc6a8166442a..4c0a68753b55 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
@@ -65,6 +65,7 @@ org.apache.spark.ml.clustering.BisectingKMeansModel
 org.apache.spark.ml.clustering.GaussianMixtureModel
 org.apache.spark.ml.clustering.DistributedLDAModel
 org.apache.spark.ml.clustering.LocalLDAModel
+org.apache.spark.ml.clustering.PowerIterationClusteringWrapper
 
 # recommendation
 org.apache.spark.ml.recommendation.ALSModel
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
index 8b2ee955d6a5..d14fc7fcc875 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.ml.clustering
 
 import org.apache.spark.annotation.Since
+import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
@@ -191,3 +192,32 @@ object PowerIterationClustering extends DefaultParamsReadable[PowerIterationClus
   @Since("2.4.0")
   override def load(path: String): PowerIterationClustering = super.load(path)
 }
+
+private[spark] class PowerIterationClusteringWrapper(override val uid: String)
+  extends Transformer with PowerIterationClusteringParams with DefaultParamsWritable {
+
+  def this() = this(Identifiable.randomUID("PowerIterationClusteringWrapper"))
+
+  override def transform(dataset: Dataset[_]): DataFrame =
+    throw new UnsupportedOperationException("transform not supported")
+
+  override def transformSchema(schema: StructType): StructType =
+    throw new UnsupportedOperationException("transformSchema not supported")
+
+  override def copy(extra: ParamMap): PowerIterationClusteringWrapper = defaultCopy(extra)
+
+  override def write: MLWriter = new MLWriter {
+    override protected def saveImpl(path: String): Unit = {
+      new PowerIterationClustering(uid).copy(paramMap).save(path)
+    }
+  }
+}
+
+private[spark] object PowerIterationClusteringWrapper
+  extends DefaultParamsReadable[PowerIterationClusteringWrapper] {
+
+  override def load(path: String): PowerIterationClusteringWrapper = {
+    val pic = PowerIterationClustering.load(path)
+    new PowerIterationClusteringWrapper(pic.uid).copy(pic.paramMap)
+  }
+}
diff --git a/python/pyspark/ml/connect/readwrite.py b/python/pyspark/ml/connect/readwrite.py
index 5c2b850d51b3..584ff3237a0a 100644
--- a/python/pyspark/ml/connect/readwrite.py
+++ b/python/pyspark/ml/connect/readwrite.py
@@ -96,6 +96,7 @@ class RemoteMLWriter(MLWriter):
         from pyspark.ml.pipeline import Pipeline, PipelineModel
         from pyspark.ml.tuning import CrossValidator, TrainValidationSplit
         from pyspark.ml.classification import OneVsRest, OneVsRestModel
+        from pyspark.ml.clustering import PowerIterationClustering
 
         # Spark Connect ML is built on scala Spark.ML, that means we're only
         # supporting JavaModel or JavaEstimator or JavaEvaluator
@@ -188,6 +189,21 @@ class RemoteMLWriter(MLWriter):
                 ovrm_writer = OneVsRestModelWriter(instance)
                 ovrm_writer.session(session)  # type: ignore[arg-type]
                 ovrm_writer.save(path)
+
+        elif isinstance(instance, PowerIterationClustering):
+            transformer = JavaTransformer(
+                "org.apache.spark.ml.clustering.PowerIterationClusteringWrapper"
+            )
+            transformer._resetUid(instance.uid)
+            transformer._paramMap = instance._paramMap
+            RemoteMLWriter.saveInstance(
+                transformer,  # type: ignore[arg-type]
+                path,
+                session,
+                shouldOverwrite,
+                optionMap,
+            )
+
         else:
             raise NotImplementedError(f"Unsupported write for {instance.__class__}")
 
@@ -224,6 +240,7 @@ class RemoteMLReader(MLReader[RL]):
         from pyspark.ml.pipeline import Pipeline, PipelineModel
         from pyspark.ml.tuning import CrossValidator, TrainValidationSplit
         from pyspark.ml.classification import OneVsRest, OneVsRestModel
+        from pyspark.ml.clustering import PowerIterationClustering
 
         if (
             issubclass(clazz, JavaModel)
@@ -332,5 +349,28 @@ class RemoteMLReader(MLReader[RL]):
             ovrm_reader.session(session)
             return ovrm_reader.load(path)
 
+        elif issubclass(clazz, PowerIterationClustering):
+            java_qualified_class_name = (
+                "org.apache.spark.ml.clustering.PowerIterationClusteringWrapper"
+            )
+
+            command = pb2.Command()
+            command.ml_command.read.CopyFrom(
+                pb2.MlCommand.Read(
+                    operator=pb2.MlOperator(
+                        name=java_qualified_class_name, type=pb2.MlOperator.TRANSFORMER
+                    ),
+                    path=path,
+                )
+            )
+            (_, properties, _) = session.client.execute_command(command)
+            result = deserialize(properties)
+
+            instance = PowerIterationClustering()
+            instance._resetUid(result.uid)
+            params = {k: deserialize_param(v) for k, v in result.params.params.items()}
+            instance._set(**params)
+            return instance  # type: ignore[return-value]
+
         else:
             raise RuntimeError(f"Unsupported read for {clazz}")
diff --git a/python/pyspark/ml/tests/test_clustering.py b/python/pyspark/ml/tests/test_clustering.py
index 136a166d2218..8ec1fcc48ca6 100644
--- a/python/pyspark/ml/tests/test_clustering.py
+++ b/python/pyspark/ml/tests/test_clustering.py
@@ -462,7 +462,6 @@ class ClusteringTestsMixin:
             model2 = DistributedLDAModel.load(d)
             self.assertEqual(str(model), str(model2))
 
-    # TODO(SPARK-51080): Fix save/load for PowerIterationClustering
     def test_power_iteration_clustering(self):
         spark = self.spark
 
@@ -496,6 +495,16 @@ class ClusteringTestsMixin:
         self.assertEqual(assignments.columns, ["id", "cluster"])
         self.assertEqual(assignments.count(), 6)
 
+        # save & load
+        with tempfile.TemporaryDirectory(prefix="power_iteration_clustering") as d:
+            pic.write().overwrite().save(d)
+            pic2 = PowerIterationClustering.load(d)
+            self.assertEqual(str(pic), str(pic2))
+            self.assertEqual(pic.uid, pic2.uid)
+            self.assertEqual(pic.getK(), pic2.getK())
+            self.assertEqual(pic.getMaxIter(), pic2.getMaxIter())
+            self.assertEqual(pic.getWeightCol(), pic2.getWeightCol())
+
 
 class ClusteringTests(ClusteringTestsMixin, unittest.TestCase):
     def setUp(self) -> None:

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
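
[Editor's note] For readers following along, the user-facing round-trip this
patch enables looks roughly like the sketch below. This is not part of the
patch: it assumes a running Spark Connect server, and the connect URL
"sc://localhost", the parameter values, and the temp-directory prefix are
all illustrative. Only APIs exercised by the new test are used.

    # Minimal sketch of the PowerIterationClustering save/load round-trip
    # on Spark Connect. Assumptions: a Connect server reachable at
    # "sc://localhost"; parameter values and directory prefix are made up.
    import tempfile

    from pyspark.ml.clustering import PowerIterationClustering
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

    pic = PowerIterationClustering(k=2, weightCol="weight")
    pic.setMaxIter(40)

    with tempfile.TemporaryDirectory(prefix="pic_save_load") as d:
        # Before this commit, the Connect writer fell through to its
        # NotImplementedError branch for this class; it now routes the
        # estimator through PowerIterationClusteringWrapper on the server.
        pic.write().overwrite().save(d)
        pic2 = PowerIterationClustering.load(d)
        assert pic2.uid == pic.uid
        assert pic2.getK() == 2
        assert pic2.getMaxIter() == 40
        assert pic2.getWeightCol() == "weight"

The wrapper indirection appears to exist because Connect ML persistence
dispatches on classes registered in the META-INF services file for
org.apache.spark.ml.Transformer, while PowerIterationClustering is not a
Transformer; the no-op wrapper satisfies the registry and delegates its
actual save/load to PowerIterationClustering itself.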