This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new af90952169c0  [SPARK-51214][ML][PYTHON][CONNECT] Don't eagerly remove the cached models for `fit_transform`

af90952169c0 is described below

commit af90952169c04da4093ab1ef4a7c97aa68033ad1
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Sat Feb 15 08:40:44 2025 +0800

[SPARK-51214][ML][PYTHON][CONNECT] Don't eagerly remove the cached models for `fit_transform`

### What changes were proposed in this pull request?

Don't eagerly remove the cached models for `fit_transform`:

1. Keep the `Delete` ML command protobuf, but no longer call it from `__del__` on the Python client side.
2. Build the ML cache with Guava's `CacheBuilder`, using soft references plus an explicit maximum size and idle timeout (see the sketch following this message).

### Why are the changes needed?

A common ML pipeline pattern is `fit_transform`:

```
def fit_transform(df):
    model = estimator.fit(df)
    return model.transform(df)

df2 = fit_transform(df)
df2.count()
```

The existing implementation eagerly deletes the intermediate model from the ML cache right after `fit_transform` returns (when the client-side model object is garbage collected), so a later action on the result, such as `df2.count()`, can no longer resolve the model on the server and fails with an NPE:

```
pyspark.errors.exceptions.connect.SparkConnectGrpcException: (java.lang.NullPointerException) Cannot invoke "org.apache.spark.ml.Model.copy(org.apache.spark.ml.param.ParamMap)" because "model" is null

JVM stacktrace:
java.lang.NullPointerException
    at org.apache.spark.sql.connect.ml.ModelAttributeHelper.transform(MLHandler.scala:68)
    at org.apache.spark.sql.connect.ml.MLHandler$.transformMLRelation(MLHandler.scala:313)
    at org.apache.spark.sql.connect.planner.SparkConnectPlanner.$anonfun$transformRelation$1(SparkConnectPlanner.scala:231)
    at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$usePlanCache$3(SessionHolder.scala:477)
    at scala.Option.getOrElse(Option.scala:201)
    at org.apache.spark.sql.connect.service.SessionHolder.usePlanCache(SessionHolder.scala:476)
    at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformRelation(SparkConnectPlanner.scala:147)
    at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformRelation(SparkConnectPlanner.scala:133)
    at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformRelationalGroupedAggregate(SparkConnectPlanner.scala:2318)
    at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformAggregate(SparkConnectPlanner.scala:2299)
    at org.apache.spark.sql.connect.planner.SparkConnectPlanner.$anonfun$transformRelation$1(SparkConnectPlanner.scala:165)
    at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$usePlanCache$3(SessionHolder.scala:477)
```

### Does this PR introduce _any_ user-facing change?

Yes.

### How was this patch tested?

Added tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #49948 from zhengruifeng/ml_connect_del.
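For context on change 2 above, here is a minimal, self-contained sketch of the Guava cache semantics the new `MLCache` relies on. The object name `CacheEvictionSketch` is made up for illustration; the builder chain and the 100-item / 60-minute bounds mirror the `MLCache.scala` diff below, and the explicit `remove` merely simulates an eviction:

```
import java.util.concurrent.TimeUnit

import com.google.common.cache.CacheBuilder

object CacheEvictionSketch {
  def main(args: Array[String]): Unit = {
    // Same configuration as the new MLCache: soft values let the JVM
    // reclaim cached models under memory pressure, while the size and
    // idle-time bounds cap how long unused entries linger.
    val cache = CacheBuilder
      .newBuilder()
      .softValues()
      .maximumSize(100)
      .expireAfterAccess(60, TimeUnit.MINUTES)
      .build[String, Object]()
      .asMap()

    cache.put("model-1", new Object())
    println(cache.get("model-1") != null) // true while the entry is live

    // Once an entry is evicted (size limit, idle timeout, or GC of the
    // soft reference), get() returns null rather than throwing.
    cache.remove("model-1") // simulate eviction for this sketch
    println(cache.get("model-1") == null) // true
  }
}
```

Because a lookup on an evicted entry yields null, the server-side handlers in this patch translate that null into the new `CONNECT_ML.CACHE_INVALID` error instead of letting a `NullPointerException` escape.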
Authored-by: Ruifeng Zheng <ruife...@apache.org>
Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
(cherry picked from commit 09b93bd657aaf52cd64a040d1cb7ccede1616d7e)
Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../src/main/resources/error/error-conditions.json |  5 +++
 python/pyspark/ml/tests/test_pipeline.py           | 20 +++++++++++
 python/pyspark/ml/util.py                          | 40 +++++++++++++++-------
 .../org/apache/spark/sql/connect/ml/MLCache.scala  | 21 ++++++++++--
 .../apache/spark/sql/connect/ml/MLException.scala  |  6 ++++
 .../apache/spark/sql/connect/ml/MLHandler.scala    | 11 +++++-
 .../org/apache/spark/sql/connect/ml/MLSuite.scala  | 29 ++++++++++++++++
 7 files changed, 116 insertions(+), 16 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 39b5ddd0adff..b427774c8cc7 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -768,6 +768,11 @@
         "<attribute> in <className> is not allowed to be accessed."
       ]
     },
+    "CACHE_INVALID" : {
+      "message" : [
+        "Cannot retrieve <objectName> from the ML cache. It is probably because the entry has been evicted."
+      ]
+    },
     "UNSUPPORTED_EXCEPTION" : {
       "message" : [
         "<message>"
diff --git a/python/pyspark/ml/tests/test_pipeline.py b/python/pyspark/ml/tests/test_pipeline.py
index 8318f3bb71c9..ced1cda1948a 100644
--- a/python/pyspark/ml/tests/test_pipeline.py
+++ b/python/pyspark/ml/tests/test_pipeline.py
@@ -18,6 +18,7 @@
 import tempfile
 import unittest
 
+from pyspark.sql import Row
 from pyspark.ml.pipeline import Pipeline, PipelineModel
 from pyspark.ml.feature import (
     VectorAssembler,
@@ -26,6 +27,7 @@ from pyspark.ml.feature import (
     MinMaxScaler,
     MinMaxScalerModel,
 )
+from pyspark.ml.linalg import Vectors
 from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel
 from pyspark.ml.clustering import KMeans, KMeansModel
 from pyspark.testing.mlutils import MockDataset, MockEstimator, MockTransformer
@@ -172,6 +174,24 @@ class PipelineTestsMixin:
         self.assertEqual(str(model), str(model2))
         self.assertEqual(str(model.stages), str(model2.stages))
 
+    def test_model_gc(self):
+        spark = self.spark
+        df = spark.createDataFrame(
+            [
+                Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])),
+                Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])),
+                Row(label=1.0, weight=1.0, features=Vectors.dense([1.0, 0.0])),
+            ]
+        )
+
+        def fit_transform(df):
+            lr = LogisticRegression(maxIter=1, regParam=0.01, weightCol="weight")
+            model = lr.fit(df)
+            return model.transform(df)
+
+        output = fit_transform(df)
+        self.assertEqual(output.count(), 3)
+
 
 class PipelineTests(PipelineTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 9eab45239b8f..67921d312d37 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -249,6 +249,21 @@ def try_remote_call(f: FuncT) -> FuncT:
     return cast(FuncT, wrapped)
 
 
+# delete the object from the ml cache eagerly
+def del_remote_cache(ref_id: str) -> None:
+    if ref_id is not None and "." not in ref_id:
+        try:
+            from pyspark.sql.connect.session import SparkSession
+
+            session = SparkSession.getActiveSession()
+            if session is not None:
+                session.client.remove_ml_cache(ref_id)
+                return
+        except Exception:
+            # SparkSession's down.
+            return
+
+
 def try_remote_del(f: FuncT) -> FuncT:
     """Mark the function/property to delete a model on the server side."""
 
@@ -261,18 +276,19 @@ def try_remote_del(f: FuncT) -> FuncT:
 
         if in_remote:
             # Delete the model if possible
-            model_id = self._java_obj
-            if model_id is not None and "." not in model_id:
-                try:
-                    from pyspark.sql.connect.session import SparkSession
-
-                    session = SparkSession.getActiveSession()
-                    if session is not None:
-                        session.client.remove_ml_cache(model_id)
-                        return
-                except Exception:
-                    # SparkSession's down.
-                    return
+            # model_id = self._java_obj
+            # del_remote_cache(model_id)
+            #
+            # Above codes delete the model from the ml cache eagerly, and may cause
+            # NPE in the server side in the case of 'fit_transform':
+            #
+            # def fit_transform(df):
+            #     model = estimator.fit(df)
+            #     return model.transform(df)
+            #
+            # output = fit_transform(df)
+            # output.show()
+            return
         else:
             return f(self)
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
index beb06065d04a..e8d858502072 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
@@ -17,7 +17,9 @@
 package org.apache.spark.sql.connect.ml
 
 import java.util.UUID
-import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.{ConcurrentMap, TimeUnit}
+
+import com.google.common.cache.CacheBuilder
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.util.ConnectHelper
@@ -29,8 +31,13 @@ private[connect] class MLCache extends Logging {
   private val helper = new ConnectHelper()
   private val helperID = "______ML_CONNECT_HELPER______"
 
-  private val cachedModel: ConcurrentHashMap[String, Object] =
-    new ConcurrentHashMap[String, Object]()
+  private val cachedModel: ConcurrentMap[String, Object] = CacheBuilder
+    .newBuilder()
+    .softValues()
+    .maximumSize(MLCache.MAX_CACHED_ITEMS)
+    .expireAfterAccess(MLCache.CACHE_TIMEOUT_MINUTE, TimeUnit.MINUTES)
+    .build[String, Object]()
+    .asMap()
 
   /**
    * Cache an object into a map of MLCache, and return its key
@@ -76,3 +83,11 @@ private[connect] class MLCache extends Logging {
     cachedModel.clear()
   }
 }
+
+private[connect] object MLCache {
+  // The maximum number of distinct items in the cache.
+  private val MAX_CACHED_ITEMS = 100
+
+  // The maximum time for an item to stay in the cache.
+  private val CACHE_TIMEOUT_MINUTE = 60
+}
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLException.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLException.scala
index 7700eccf6553..d1a7f232edf7 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLException.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLException.scala
@@ -30,3 +30,9 @@ private[spark] case class MLAttributeNotAllowedException(className: String, attr
     errorClass = "CONNECT_ML.ATTRIBUTE_NOT_ALLOWED",
     messageParameters = Map("className" -> className, "attribute" -> attribute),
     cause = null)
+
+private[spark] case class MLCacheInvalidException(objectName: String)
+    extends SparkException(
+      errorClass = "CONNECT_ML.CACHE_INVALID",
+      messageParameters = Map("objectName" -> objectName),
+      cause = null)
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
index 08080c099200..9a9e156f91cd 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
@@ -41,7 +41,13 @@ private class AttributeHelper(
     val sessionHolder: SessionHolder,
     val objRef: String,
     val methods: Array[Method]) {
-  protected lazy val instance = sessionHolder.mlCache.get(objRef)
+  protected lazy val instance = {
+    val obj = sessionHolder.mlCache.get(objRef)
+    if (obj == null) {
+      throw MLCacheInvalidException(s"object $objRef")
+    }
+    obj
+  }
   // Get the attribute by reflection
   def getAttribute: Any = {
     assert(methods.length >= 1)
@@ -181,6 +187,9 @@ private[connect] object MLHandler extends Logging {
       case proto.MlCommand.Write.TypeCase.OBJ_REF => // save a model
         val objId = mlCommand.getWrite.getObjRef.getId
         val model = mlCache.get(objId).asInstanceOf[Model[_]]
+        if (model == null) {
+          throw MLCacheInvalidException(s"model $objId")
+        }
         val copiedModel = model.copy(ParamMap.empty).asInstanceOf[Model[_]]
         MLUtils.setInstanceParams(copiedModel, mlCommand.getWrite.getParams)
 
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
index 0d0fbc4b1b7b..76ce34a67e74 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
@@ -256,6 +256,35 @@ class MLSuite extends MLHelper {
     }
   }
 
+  test("Exception: cannot retrieve object") {
+    val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+    val modelId = trainLogisticRegressionModel(sessionHolder)
+
+    // Fetch summary attribute
+    val accuracyCommand = proto.MlCommand
+      .newBuilder()
+      .setFetch(
+        proto.Fetch
+          .newBuilder()
+          .setObjRef(proto.ObjectRef.newBuilder().setId(modelId))
+          .addMethods(proto.Fetch.Method.newBuilder().setMethod("summary"))
+          .addMethods(proto.Fetch.Method.newBuilder().setMethod("accuracy")))
+      .build()
+
+    // Successfully fetch summary.accuracy from the cached model
+    MLHandler.handleMlCommand(sessionHolder, accuracyCommand)
+
+    // Remove the model from cache
+    sessionHolder.mlCache.clear()
+
+    // No longer able to retrieve the model from cache
+    val e = intercept[MLCacheInvalidException] {
+      MLHandler.handleMlCommand(sessionHolder, accuracyCommand)
+    }
+    val msg = e.getMessage
+    assert(msg.contains(s"$modelId from the ML cache"))
+  }
+
   test("access the attribute which is not in allowed list") {
     val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
     val modelId = trainLogisticRegressionModel(sessionHolder)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org