This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new af90952169c0 [SPARK-51214][ML][PYTHON][CONNECT] Don't eagerly remove the cached models for `fit_transform`
af90952169c0 is described below

commit af90952169c04da4093ab1ef4a7c97aa68033ad1
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Sat Feb 15 08:40:44 2025 +0800

    [SPARK-51214][ML][PYTHON][CONNECT] Don't eagerly remove the cached models for `fit_transform`
    
    ### What changes were proposed in this pull request?
    Don't eagerly remove the cached models for `fit_transform`:
    1. Still keep the `Delete` ML command protobuf, but no longer call it from `__del__` on the Python client side;
    2. Build the ML cache with Guava's `CacheBuilder` using soft references, and specify a maximum size and timeout (see the sketch below).
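    
    For illustration, a minimal sketch of the new server-side cache construction, mirroring the `MLCache.scala` change in the diff below (soft values, at most 100 entries, entries evicted after 60 minutes without access):
    ```
    import java.util.concurrent.{ConcurrentMap, TimeUnit}

    import com.google.common.cache.CacheBuilder

    // Soft-valued cache: entries may be reclaimed by the GC under memory
    // pressure, and are also evicted by size and by idle time.
    val cachedModel: ConcurrentMap[String, Object] = CacheBuilder
      .newBuilder()
      .softValues()
      .maximumSize(100)
      .expireAfterAccess(60, TimeUnit.MINUTES)
      .build[String, Object]()
      .asMap()
    ```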
    
    ### Why are the changes needed?
    A common ML pipeline pattern is `fit_transform`:
    ```
    def fit_transform(df):
        model = estimator.fit(df)
        return model.transform(df)
    
    df2 = fit_transform(df)
    df2.count()
    ```
    
    The existing implementation eagerly deletes the intermediate model from the ML cache right after `fit_transform` returns, and thus causes an NPE:
    
    ```
    pyspark.errors.exceptions.connect.SparkConnectGrpcException: (java.lang.NullPointerException) Cannot invoke "org.apache.spark.ml.Model.copy(org.apache.spark.ml.param.ParamMap)" because "model" is null
    
    JVM stacktrace:
    java.lang.NullPointerException
            at org.apache.spark.sql.connect.ml.ModelAttributeHelper.transform(MLHandler.scala:68)
            at org.apache.spark.sql.connect.ml.MLHandler$.transformMLRelation(MLHandler.scala:313)
            at org.apache.spark.sql.connect.planner.SparkConnectPlanner.$anonfun$transformRelation$1(SparkConnectPlanner.scala:231)
            at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$usePlanCache$3(SessionHolder.scala:477)
            at scala.Option.getOrElse(Option.scala:201)
            at org.apache.spark.sql.connect.service.SessionHolder.usePlanCache(SessionHolder.scala:476)
            at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformRelation(SparkConnectPlanner.scala:147)
            at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformRelation(SparkConnectPlanner.scala:133)
            at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformRelationalGroupedAggregate(SparkConnectPlanner.scala:2318)
            at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformAggregate(SparkConnectPlanner.scala:2299)
            at org.apache.spark.sql.connect.planner.SparkConnectPlanner.$anonfun$transformRelation$1(SparkConnectPlanner.scala:165)
            at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$usePlanCache$3(SessionHolder.scala:477)
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    yes
    
    ### How was this patch tested?
    added tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #49948 from zhengruifeng/ml_connect_del.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
    (cherry picked from commit 09b93bd657aaf52cd64a040d1cb7ccede1616d7e)
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../src/main/resources/error/error-conditions.json |  5 +++
 python/pyspark/ml/tests/test_pipeline.py           | 20 +++++++++++
 python/pyspark/ml/util.py                          | 40 +++++++++++++++-------
 .../org/apache/spark/sql/connect/ml/MLCache.scala  | 21 ++++++++++--
 .../apache/spark/sql/connect/ml/MLException.scala  |  6 ++++
 .../apache/spark/sql/connect/ml/MLHandler.scala    | 11 +++++-
 .../org/apache/spark/sql/connect/ml/MLSuite.scala  | 29 ++++++++++++++++
 7 files changed, 116 insertions(+), 16 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 39b5ddd0adff..b427774c8cc7 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -768,6 +768,11 @@
           "<attribute> in <className> is not allowed to be accessed."
         ]
       },
+      "CACHE_INVALID" : {
+        "message" : [
+          "Cannot retrieve <objectName> from the ML cache. It is probably 
because the entry has been evicted."
+        ]
+      },
       "UNSUPPORTED_EXCEPTION" : {
         "message" : [
           "<message>"
diff --git a/python/pyspark/ml/tests/test_pipeline.py b/python/pyspark/ml/tests/test_pipeline.py
index 8318f3bb71c9..ced1cda1948a 100644
--- a/python/pyspark/ml/tests/test_pipeline.py
+++ b/python/pyspark/ml/tests/test_pipeline.py
@@ -18,6 +18,7 @@
 import tempfile
 import unittest
 
+from pyspark.sql import Row
 from pyspark.ml.pipeline import Pipeline, PipelineModel
 from pyspark.ml.feature import (
     VectorAssembler,
@@ -26,6 +27,7 @@ from pyspark.ml.feature import (
     MinMaxScaler,
     MinMaxScalerModel,
 )
+from pyspark.ml.linalg import Vectors
 from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel
 from pyspark.ml.clustering import KMeans, KMeansModel
 from pyspark.testing.mlutils import MockDataset, MockEstimator, MockTransformer
@@ -172,6 +174,24 @@ class PipelineTestsMixin:
             self.assertEqual(str(model), str(model2))
             self.assertEqual(str(model.stages), str(model2.stages))
 
+    def test_model_gc(self):
+        spark = self.spark
+        df = spark.createDataFrame(
+            [
+                Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])),
+                Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])),
+                Row(label=1.0, weight=1.0, features=Vectors.dense([1.0, 0.0])),
+            ]
+        )
+
+        def fit_transform(df):
+            lr = LogisticRegression(maxIter=1, regParam=0.01, weightCol="weight")
+            model = lr.fit(df)
+            return model.transform(df)
+
+        output = fit_transform(df)
+        self.assertEqual(output.count(), 3)
+
 
 class PipelineTests(PipelineTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 9eab45239b8f..67921d312d37 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -249,6 +249,21 @@ def try_remote_call(f: FuncT) -> FuncT:
     return cast(FuncT, wrapped)
 
 
+# delete the object from the ml cache eagerly
+def del_remote_cache(ref_id: str) -> None:
+    if ref_id is not None and "." not in ref_id:
+        try:
+            from pyspark.sql.connect.session import SparkSession
+
+            session = SparkSession.getActiveSession()
+            if session is not None:
+                session.client.remove_ml_cache(ref_id)
+                return
+        except Exception:
+            # SparkSession's down.
+            return
+
+
 def try_remote_del(f: FuncT) -> FuncT:
     """Mark the function/property to delete a model on the server side."""
 
@@ -261,18 +276,19 @@ def try_remote_del(f: FuncT) -> FuncT:
 
         if in_remote:
             # Delete the model if possible
-            model_id = self._java_obj
-            if model_id is not None and "." not in model_id:
-                try:
-                    from pyspark.sql.connect.session import SparkSession
-
-                    session = SparkSession.getActiveSession()
-                    if session is not None:
-                        session.client.remove_ml_cache(model_id)
-                        return
-                except Exception:
-                    # SparkSession's down.
-                    return
+            # model_id = self._java_obj
+            # del_remote_cache(model_id)
+            #
+            # Above codes delete the model from the ml cache eagerly, and may cause
+            # NPE in the server side in the case of 'fit_transform':
+            #
+            # def fit_transform(df):
+            #     model = estimator.fit(df)
+            #     return model.transform(df)
+            #
+            # output = fit_transform(df)
+            # output.show()
+            return
         else:
             return f(self)
 
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
index beb06065d04a..e8d858502072 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
@@ -17,7 +17,9 @@
 package org.apache.spark.sql.connect.ml
 
 import java.util.UUID
-import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.{ConcurrentMap, TimeUnit}
+
+import com.google.common.cache.CacheBuilder
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.util.ConnectHelper
@@ -29,8 +31,13 @@ private[connect] class MLCache extends Logging {
   private val helper = new ConnectHelper()
   private val helperID = "______ML_CONNECT_HELPER______"
 
-  private val cachedModel: ConcurrentHashMap[String, Object] =
-    new ConcurrentHashMap[String, Object]()
+  private val cachedModel: ConcurrentMap[String, Object] = CacheBuilder
+    .newBuilder()
+    .softValues()
+    .maximumSize(MLCache.MAX_CACHED_ITEMS)
+    .expireAfterAccess(MLCache.CACHE_TIMEOUT_MINUTE, TimeUnit.MINUTES)
+    .build[String, Object]()
+    .asMap()
 
   /**
    * Cache an object into a map of MLCache, and return its key
@@ -76,3 +83,11 @@ private[connect] class MLCache extends Logging {
     cachedModel.clear()
   }
 }
+
+private[connect] object MLCache {
+  // The maximum number of distinct items in the cache.
+  private val MAX_CACHED_ITEMS = 100
+
+  // The maximum time for an item to stay in the cache.
+  private val CACHE_TIMEOUT_MINUTE = 60
+}
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLException.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLException.scala
index 7700eccf6553..d1a7f232edf7 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLException.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLException.scala
@@ -30,3 +30,9 @@ private[spark] case class MLAttributeNotAllowedException(className: String, attr
       errorClass = "CONNECT_ML.ATTRIBUTE_NOT_ALLOWED",
       messageParameters = Map("className" -> className, "attribute" -> attribute),
       cause = null)
+
+private[spark] case class MLCacheInvalidException(objectName: String)
+    extends SparkException(
+      errorClass = "CONNECT_ML.CACHE_INVALID",
+      messageParameters = Map("objectName" -> objectName),
+      cause = null)
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
index 08080c099200..9a9e156f91cd 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
@@ -41,7 +41,13 @@ private class AttributeHelper(
     val sessionHolder: SessionHolder,
     val objRef: String,
     val methods: Array[Method]) {
-  protected lazy val instance = sessionHolder.mlCache.get(objRef)
+  protected lazy val instance = {
+    val obj = sessionHolder.mlCache.get(objRef)
+    if (obj == null) {
+      throw MLCacheInvalidException(s"object $objRef")
+    }
+    obj
+  }
   // Get the attribute by reflection
   def getAttribute: Any = {
     assert(methods.length >= 1)
@@ -181,6 +187,9 @@ private[connect] object MLHandler extends Logging {
           case proto.MlCommand.Write.TypeCase.OBJ_REF => // save a model
             val objId = mlCommand.getWrite.getObjRef.getId
             val model = mlCache.get(objId).asInstanceOf[Model[_]]
+            if (model == null) {
+              throw MLCacheInvalidException(s"model $objId")
+            }
             val copiedModel = model.copy(ParamMap.empty).asInstanceOf[Model[_]]
             MLUtils.setInstanceParams(copiedModel, mlCommand.getWrite.getParams)
 
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
index 0d0fbc4b1b7b..76ce34a67e74 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
@@ -256,6 +256,35 @@ class MLSuite extends MLHelper {
     }
   }
 
+  test("Exception: cannot retrieve object") {
+    val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+    val modelId = trainLogisticRegressionModel(sessionHolder)
+
+    // Fetch summary attribute
+    val accuracyCommand = proto.MlCommand
+      .newBuilder()
+      .setFetch(
+        proto.Fetch
+          .newBuilder()
+          .setObjRef(proto.ObjectRef.newBuilder().setId(modelId))
+          .addMethods(proto.Fetch.Method.newBuilder().setMethod("summary"))
+          .addMethods(proto.Fetch.Method.newBuilder().setMethod("accuracy")))
+      .build()
+
+    // Successfully fetch summary.accuracy from the cached model
+    MLHandler.handleMlCommand(sessionHolder, accuracyCommand)
+
+    // Remove the model from cache
+    sessionHolder.mlCache.clear()
+
+    // No longer able to retrieve the model from cache
+    val e = intercept[MLCacheInvalidException] {
+      MLHandler.handleMlCommand(sessionHolder, accuracyCommand)
+    }
+    val msg = e.getMessage
+    assert(msg.contains(s"$modelId from the ML cache"))
+  }
+
   test("access the attribute which is not in allowed list") {
     val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
     val modelId = trainLogisticRegressionModel(sessionHolder)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
