This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new aeff6796020a [SPARK-51880][ML][PYTHON][CONNECT] Fix ML cache object python client references
aeff6796020a is described below

commit aeff6796020a52aa67f5a0111ed31b30fea8437e
Author: Weichen Xu <weichen...@databricks.com>
AuthorDate: Mon Apr 28 08:28:17 2025 +0800

    [SPARK-51880][ML][PYTHON][CONNECT] Fix ML cache object python client references
    
    ### What changes were proposed in this pull request?
    
    Fix ML cache object python client references.
    
    When a model is copied on the client side, multiple client model objects end up referring to the same server-side cached model.
    In this case we need a reference count: the server-side cached model can only be released once the reference count drops to zero.
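
    As a rough, self-contained sketch of that reference-counting idea (simplified: the
    actual implementation is the RemoteModelRef class added to python/pyspark/ml/util.py,
    which guards the count with a lock and calls del_remote_cache rather than printing):

        class RemoteModelRef:
            def __init__(self, ref_id: str) -> None:
                self.ref_id = ref_id
                self.ref_count = 1  # the model that created the ref holds the first reference

            def add_ref(self) -> None:
                # Called when a client model is copied and shares the server-side model.
                self.ref_count += 1

            def release_ref(self) -> None:
                # Called when a client model object goes away.
                self.ref_count -= 1
                if self.ref_count == 0:
                    # Stand-in for del_remote_cache(ref_id) in the real code.
                    print(f"release server-cached model {self.ref_id}")

        ref = RemoteModelRef("model-1")
        ref.add_ref()      # copy() produces a second client object sharing the ref
        ref.release_ref()  # deleting one client object keeps the cached model alive
        ref.release_ref()  # last reference released, so the cache entry can be dropped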
    
    ### Why are the changes needed?
    
    Bugfix.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #50707 from WeichenXu123/ml-ref-id-fix.
    
    Lead-authored-by: Weichen Xu <weichen...@databricks.com>
    Co-authored-by: WeichenXu <weichen...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/ml/classification.py       | 14 ++++++-
 python/pyspark/ml/connect/readwrite.py    | 11 +++--
 python/pyspark/ml/feature.py              | 37 ++++++++++-------
 python/pyspark/ml/regression.py           | 14 ++++++-
 python/pyspark/ml/tests/test_tuning.py    |  1 -
 python/pyspark/ml/util.py                 | 67 +++++++++++++++++++++++++------
 python/pyspark/ml/wrapper.py              | 16 ++++++--
 python/pyspark/sql/connect/client/core.py |  3 +-
 8 files changed, 123 insertions(+), 40 deletions(-)

diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 3f9e3fa37f72..a5fdaed0db2c 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -2306,7 +2306,12 @@ class RandomForestClassificationModel(
     def trees(self) -> List[DecisionTreeClassificationModel]:
         """Trees in this ensemble. Warning: These have null parent 
Estimators."""
         if is_remote():
-            return [DecisionTreeClassificationModel(m) for m in self._call_java("trees").split(",")]
+            from pyspark.ml.util import RemoteModelRef
+
+            return [
+                DecisionTreeClassificationModel(RemoteModelRef(m))
+                for m in self._call_java("trees").split(",")
+            ]
         return [DecisionTreeClassificationModel(m) for m in list(self._call_java("trees"))]
 
     @property
@@ -2805,7 +2810,12 @@ class GBTClassificationModel(
     def trees(self) -> List[DecisionTreeRegressionModel]:
         """Trees in this ensemble. Warning: These have null parent 
Estimators."""
         if is_remote():
-            return [DecisionTreeRegressionModel(m) for m in self._call_java("trees").split(",")]
+            from pyspark.ml.util import RemoteModelRef
+
+            return [
+                DecisionTreeRegressionModel(RemoteModelRef(m))
+                for m in self._call_java("trees").split(",")
+            ]
         return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
 
     def evaluateEachIteration(self, dataset: DataFrame) -> List[float]:
diff --git a/python/pyspark/ml/connect/readwrite.py b/python/pyspark/ml/connect/readwrite.py
index 0dc38e7275c1..ff53eb77d032 100644
--- a/python/pyspark/ml/connect/readwrite.py
+++ b/python/pyspark/ml/connect/readwrite.py
@@ -77,11 +77,13 @@ class RemoteMLWriter(MLWriter):
         # Spark Connect ML is built on scala Spark.ML, that means we're only
         # supporting JavaModel or JavaEstimator or JavaEvaluator
         if isinstance(instance, JavaModel):
+            from pyspark.ml.util import RemoteModelRef
+
             model = cast("JavaModel", instance)
             params = serialize_ml_params(model, session.client)
-            assert isinstance(model._java_obj, str)
+            assert isinstance(model._java_obj, RemoteModelRef)
             writer = pb2.MlCommand.Write(
-                obj_ref=pb2.ObjectRef(id=model._java_obj),
+                obj_ref=pb2.ObjectRef(id=model._java_obj.ref_id),
                 params=params,
                 path=path,
                 should_overwrite=shouldOverwrite,
@@ -270,9 +272,12 @@ class RemoteMLReader(MLReader[RL]):
             py_type = _get_class()
             # It must be JavaWrapper, since we're passing the string to the _java_obj
             if issubclass(py_type, JavaWrapper):
+                from pyspark.ml.util import RemoteModelRef
+
                 if ml_type == pb2.MlOperator.OPERATOR_TYPE_MODEL:
                     session.client.add_ml_cache(result.obj_ref.id)
-                    instance = py_type(result.obj_ref.id)
+                    remote_model_ref = RemoteModelRef(result.obj_ref.id)
+                    instance = py_type(remote_model_ref)
                 else:
                     instance = py_type()
                 instance._resetUid(result.uid)
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index d669fab27d50..4d1551652028 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -64,6 +64,7 @@ from pyspark.ml.wrapper import (
     _jvm,
 )
 from pyspark.ml.common import inherit_doc
+from pyspark.ml.util import RemoteModelRef
 from pyspark.sql.types import ArrayType, StringType
 from pyspark.sql.utils import is_remote
 
@@ -1224,10 +1225,12 @@ class CountVectorizerModel(
 
         if is_remote():
             model = CountVectorizerModel()
-            model._java_obj = invoke_helper_attr(
-                "countVectorizerModelFromVocabulary",
-                model.uid,
-                list(vocabulary),
+            model._java_obj = RemoteModelRef(
+                invoke_helper_attr(
+                    "countVectorizerModelFromVocabulary",
+                    model.uid,
+                    list(vocabulary),
+                )
             )
 
         else:
@@ -4843,10 +4846,12 @@ class StringIndexerModel(
         """
         if is_remote():
             model = StringIndexerModel()
-            model._java_obj = invoke_helper_attr(
-                "stringIndexerModelFromLabels",
-                model.uid,
-                (list(labels), ArrayType(StringType())),
+            model._java_obj = RemoteModelRef(
+                invoke_helper_attr(
+                    "stringIndexerModelFromLabels",
+                    model.uid,
+                    (list(labels), ArrayType(StringType())),
+                )
             )
 
         else:
@@ -4882,13 +4887,15 @@ class StringIndexerModel(
         """
         if is_remote():
             model = StringIndexerModel()
-            model._java_obj = invoke_helper_attr(
-                "stringIndexerModelFromLabelsArray",
-                model.uid,
-                (
-                    [list(labels) for labels in arrayOfLabels],
-                    ArrayType(ArrayType(StringType())),
-                ),
+            model._java_obj = RemoteModelRef(
+                invoke_helper_attr(
+                    "stringIndexerModelFromLabelsArray",
+                    model.uid,
+                    (
+                        [list(labels) for labels in arrayOfLabels],
+                        ArrayType(ArrayType(StringType())),
+                    ),
+                )
             )
 
         else:
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index a7e793142233..66d6dbd6a267 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -1614,7 +1614,12 @@ class RandomForestRegressionModel(
     def trees(self) -> List[DecisionTreeRegressionModel]:
         """Trees in this ensemble. Warning: These have null parent 
Estimators."""
         if is_remote():
-            return [DecisionTreeRegressionModel(m) for m in self._call_java("trees").split(",")]
+            from pyspark.ml.util import RemoteModelRef
+
+            return [
+                DecisionTreeRegressionModel(RemoteModelRef(m))
+                for m in self._call_java("trees").split(",")
+            ]
         return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
 
     @property
@@ -2005,7 +2010,12 @@ class GBTRegressionModel(
     def trees(self) -> List[DecisionTreeRegressionModel]:
         """Trees in this ensemble. Warning: These have null parent 
Estimators."""
         if is_remote():
-            return [DecisionTreeRegressionModel(m) for m in self._call_java("trees").split(",")]
+            from pyspark.ml.util import RemoteModelRef
+
+            return [
+                DecisionTreeRegressionModel(RemoteModelRef(m))
+                for m in self._call_java("trees").split(",")
+            ]
         return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
 
     def evaluateEachIteration(self, dataset: DataFrame, loss: str) -> List[float]:
diff --git a/python/pyspark/ml/tests/test_tuning.py b/python/pyspark/ml/tests/test_tuning.py
index 947c599b3cf2..ff9a26f71197 100644
--- a/python/pyspark/ml/tests/test_tuning.py
+++ b/python/pyspark/ml/tests/test_tuning.py
@@ -97,7 +97,6 @@ class TuningTestsMixin:
             self.assertEqual(str(tvs_model.getEstimator()), str(model2.getEstimator()))
             self.assertEqual(str(tvs_model.getEvaluator()), str(model2.getEvaluator()))
 
-    @unittest.skip("Disabled due to a Python side reference count issue in 
_parallelFitTasks.")
     def test_cross_validator(self):
         dataset = self.spark.createDataFrame(
             [
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 6abadec74e63..a5e0c847c173 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -17,6 +17,7 @@
 
 import json
 import os
+import threading
 import time
 import uuid
 import functools
@@ -75,7 +76,7 @@ def try_remote_intermediate_result(f: FuncT) -> FuncT:
     @functools.wraps(f)
     def wrapped(self: "JavaWrapper") -> Any:
         if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
-            return f"{self._java_obj}.{f.__name__}"
+            return f"{str(self._java_obj)}.{f.__name__}"
         else:
             return f(self)
 
@@ -108,13 +109,18 @@ def invoke_remote_attribute_relation(
     from pyspark.ml.connect.proto import AttributeRelation
     from pyspark.sql.connect.session import SparkSession
     from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
+    from pyspark.ml.wrapper import JavaModel
 
     session = SparkSession.getActiveSession()
     assert session is not None
 
-    assert isinstance(instance._java_obj, str)
-
-    methods, obj_ref = _extract_id_methods(instance._java_obj)
+    if isinstance(instance, JavaModel):
+        assert isinstance(instance._java_obj, RemoteModelRef)
+        object_id = instance._java_obj.ref_id
+    else:
+        # model summary
+        object_id = instance._java_obj  # type: ignore
+    methods, obj_ref = _extract_id_methods(object_id)
     methods.append(pb2.Fetch.Method(method=method, args=serialize(session.client, *args)))
     plan = AttributeRelation(obj_ref, methods)
 
@@ -139,6 +145,33 @@ def try_remote_attribute_relation(f: FuncT) -> FuncT:
     return cast(FuncT, wrapped)
 
 
+class RemoteModelRef:
+    def __init__(self, ref_id: str) -> None:
+        self._ref_id = ref_id
+        self._ref_count = 1
+        self._lock = threading.Lock()
+
+    @property
+    def ref_id(self) -> str:
+        return self._ref_id
+
+    def add_ref(self) -> None:
+        with self._lock:
+            assert self._ref_count > 0
+            self._ref_count += 1
+
+    def release_ref(self) -> None:
+        with self._lock:
+            assert self._ref_count > 0
+            self._ref_count -= 1
+            if self._ref_count == 0:
+                # Delete the model if possible
+                del_remote_cache(self.ref_id)
+
+    def __str__(self) -> str:
+        return self.ref_id
+
+
 def try_remote_fit(f: FuncT) -> FuncT:
     """Mark the function that fits a model."""
 
@@ -165,7 +198,8 @@ def try_remote_fit(f: FuncT) -> FuncT:
             (_, properties, _) = client.execute_command(command)
             model_info = deserialize(properties)
             client.add_ml_cache(model_info.obj_ref.id)
-            model = self._create_model(model_info.obj_ref.id)
+            remote_model_ref = RemoteModelRef(model_info.obj_ref.id)
+            model = self._create_model(remote_model_ref)
             if model.__class__.__name__ not in ["Bucketizer"]:
                 model._resetUid(self.uid)
             return self._copyValues(model)
@@ -192,11 +226,11 @@ def try_remote_transform_relation(f: FuncT) -> FuncT:
             if isinstance(self, Model):
                 from pyspark.ml.connect.proto import TransformerRelation
 
-                assert isinstance(self._java_obj, str)
+                assert isinstance(self._java_obj, RemoteModelRef)
                 params = serialize_ml_params(self, session.client)
                 plan = TransformerRelation(
                     child=dataset._plan,
-                    name=self._java_obj,
+                    name=self._java_obj.ref_id,
                     ml_params=params,
                     is_model=True,
                 )
@@ -246,11 +280,20 @@ def try_remote_call(f: FuncT) -> FuncT:
             from pyspark.sql.connect.session import SparkSession
             from pyspark.ml.connect.util import _extract_id_methods
             from pyspark.ml.connect.serialize import serialize, deserialize
+            from pyspark.ml.wrapper import JavaModel
 
             session = SparkSession.getActiveSession()
             assert session is not None
-            assert isinstance(self._java_obj, str)
-            methods, obj_ref = _extract_id_methods(self._java_obj)
+            if self._java_obj == ML_CONNECT_HELPER_ID:
+                obj_id = ML_CONNECT_HELPER_ID
+            else:
+                if isinstance(self, JavaModel):
+                    assert isinstance(self._java_obj, RemoteModelRef)
+                    obj_id = self._java_obj.ref_id
+                else:
+                    # model summary
+                    obj_id = self._java_obj  # type: ignore
+            methods, obj_ref = _extract_id_methods(obj_id)
             methods.append(pb2.Fetch.Method(method=name, args=serialize(session.client, *args)))
             command = pb2.Command()
             command.ml_command.fetch.CopyFrom(
@@ -301,10 +344,8 @@ def try_remote_del(f: FuncT) -> FuncT:
         except Exception:
             return
 
-        if in_remote:
-            # Delete the model if possible
-            model_id = self._java_obj
-            del_remote_cache(cast(str, model_id))
+        if in_remote and isinstance(self._java_obj, RemoteModelRef):
+            self._java_obj.release_ref()
             return
         else:
             return f(self)
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index f88045e718a5..b8d86e9eab3b 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -356,9 +356,15 @@ class JavaParams(JavaWrapper, Params, metaclass=ABCMeta):
         if extra is None:
             extra = dict()
         that = super(JavaParams, self).copy(extra)
-        if self._java_obj is not None and not isinstance(self._java_obj, str):
-            that._java_obj = self._java_obj.copy(self._empty_java_param_map())
-            that._transfer_params_to_java()
+        if self._java_obj is not None:
+            from pyspark.ml.util import RemoteModelRef
+
+            if isinstance(self._java_obj, RemoteModelRef):
+                that._java_obj = self._java_obj
+                self._java_obj.add_ref()
+            elif not isinstance(self._java_obj, str):
+                that._java_obj = self._java_obj.copy(self._empty_java_param_map())
+                that._transfer_params_to_java()
         return that
 
     @try_remote_intercept
@@ -452,6 +458,10 @@ class JavaModel(JavaTransformer, Model, metaclass=ABCMeta):
         other ML classes).
         """
         super(JavaModel, self).__init__(java_model)
+        if is_remote() and java_model is not None:
+            from pyspark.ml.util import RemoteModelRef
+
+            assert isinstance(java_model, RemoteModelRef)
         if java_model is not None and not is_remote():
             # SPARK-10931: This is a temporary fix to allow models to own params
             # from estimators. Eventually, these params should be in models through
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index db7f5a135fb0..ca9bdd9b6f0c 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -1981,9 +1981,10 @@ class SparkConnectClient(object):
         self.thread_local.ml_caches.add(cache_id)
 
     def remove_ml_cache(self, cache_id: str) -> None:
+        deleted = self._delete_ml_cache([cache_id])
+        # TODO: Fix the code: change thread-local `ml_caches` to global `ml_caches`.
         if hasattr(self.thread_local, "ml_caches"):
             if cache_id in self.thread_local.ml_caches:
-                deleted = self._delete_ml_cache([cache_id])
                 for obj_id in deleted:
                     self.thread_local.ml_caches.remove(obj_id)
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
