This is an automated email from the ASF dual-hosted git repository.

weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new cdd52963281a [SPARK-51856][ML][CONNECT] Update model size API to count distributed DataFrame size
cdd52963281a is described below

commit cdd52963281abb62792ba51491a98fa9f87f968a
Author: Weichen Xu <[email protected]>
AuthorDate: Wed Apr 23 08:21:06 2025 +0800

    [SPARK-51856][ML][CONNECT] Update model size API to count distributed DataFrame size
    
    ### What changes were proposed in this pull request?
    
    Update model size API to count distributed DataFrame size
    
    ### Why are the changes needed?
    
    The Spark Connect server uses the model size for ML cache management, so the estimate must also include distributed DataFrame data.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #50652 from WeichenXu123/get-model-ser-size-api.
    
    Lead-authored-by: Weichen Xu <[email protected]>
    Co-authored-by: WeichenXu <[email protected]>
    Signed-off-by: Weichen Xu <[email protected]>
---
 .../src/main/scala/org/apache/spark/ml/Estimator.scala |  4 ++--
 mllib/src/main/scala/org/apache/spark/ml/Model.scala   |  4 ++--
 .../scala/org/apache/spark/ml/clustering/LDA.scala     |  5 +++++
 .../main/scala/org/apache/spark/ml/fpm/FPGrowth.scala  |  5 +++++
 .../scala/org/apache/spark/ml/recommendation/ALS.scala | 12 ++++++++++++
 .../org/apache/spark/ml/recommendation/ALSSuite.scala  | 18 ++++++++++++++++++
 python/pyspark/ml/tests/test_clustering.py             |  3 +++
 python/pyspark/ml/tests/test_fpm.py                    |  4 +++-
 8 files changed, 50 insertions(+), 5 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
index 686afc115436..ead68b290fe4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -87,8 +87,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
    * Estimate an upper-bound size of the model to be fitted in bytes, based on the
    * parameters and the dataset, e.g., using $(k) and numFeatures to estimate a
    * k-means model size.
-   * 1, Only driver side memory usage is counted, distributed objects (like DataFrame,
-   * RDD, Graph, Summary) are ignored.
+   * 1, Both driver side memory usage and distributed objects size (like DataFrame,
+   * RDD, Graph, Summary) are counted.
    * 2, Lazy vals are not counted, e.g., an auxiliary object used in prediction.
    * 3, If there is no enough information to get an accurate size, try to estimate the
    * upper-bound size, e.g.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
index 7e0297515fa2..6321e5f88f74 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
@@ -49,8 +49,8 @@ abstract class Model[M <: Model[M]] extends Transformer { self =>
    * For ml connect only.
    * Estimate the size of this model in bytes.
    * This is an approximation, the real size might be different.
-   * 1, Only driver side memory usage is counted, distributed objects (like DataFrame,
-   * RDD, Graph, Summary) are ignored.
+   * 1, Both driver side memory usage and distributed objects size (like DataFrame,
+   * RDD, Graph, Summary) are counted.
    * 2, Lazy vals are not counted, e.g., an auxiliary object used in prediction.
    * 3, The default implementation uses `org.apache.spark.util.SizeEstimator.estimate`,
    *    some models override the default implementation to achieve more precise estimation.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index 3ea1c8594e1f..0c5211864385 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -805,6 +805,11 @@ class DistributedLDAModel private[ml] (
   override def toString: String = {
     s"DistributedLDAModel: uid=$uid, k=${$(k)}, numFeatures=$vocabSize"
   }
+
+  override def estimatedSize: Long = {
+    // TODO: Implement this method.
+    throw new UnsupportedOperationException
+  }
 }
 
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
index 0b75753695fd..7a932d250cee 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -322,6 +322,11 @@ class FPGrowthModel private[ml] (
   override def toString: String = {
     s"FPGrowthModel: uid=$uid, numTrainingRecords=$numTrainingRecords"
   }
+
+  override def estimatedSize: Long = {
+    // TODO: Implement this method.
+    throw new UnsupportedOperationException
+  }
 }
 
 @Since("2.2.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 95c47531720d..36255d3df0f1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -540,6 +540,11 @@ class ALSModel private[ml] (
     }
   }
 
+  override def estimatedSize: Long = {
+    val userCount = userFactors.count()
+    val itemCount = itemFactors.count()
+    (userCount + itemCount) * (rank + 1) * 4
+  }
 }
 
 @Since("1.6.0")
@@ -771,6 +776,13 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
 
   @Since("1.5.0")
   override def copy(extra: ParamMap): ALS = defaultCopy(extra)
+
+  override def estimateModelSize(dataset: Dataset[_]): Long = {
+    val userCount = dataset.select(getUserCol).distinct().count()
+    val itemCount = dataset.select(getItemCol).distinct().count()
+    val rank = getRank
+    (userCount + itemCount) * (rank + 1) * 4
+  }
 }
 
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 94abeaf0804e..4da67a92d707 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -1128,6 +1128,24 @@ class ALSStorageSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
     levels.foreach(level => assert(level == StorageLevel.MEMORY_ONLY))
     nonDefaultListener.storageLevels.foreach(level => assert(level == StorageLevel.DISK_ONLY))
   }
+
+  test("saved model size estimation") {
+    import testImplicits._
+
+    val als = new ALS().setMaxIter(1).setRank(8)
+    val estimatedDFSize = (3 + 2) * (8 + 1) * 4
+    val df = sc.parallelize(Seq(
+      (123, 1, 0.5),
+      (123, 2, 0.7),
+      (123, 3, 0.6),
+      (111, 2, 1.0),
+      (111, 1, 0.1)
+    )).toDF("item", "user", "rating")
+    assert(als.estimateModelSize(df) === estimatedDFSize)
+
+    val model = als.fit(df)
+    assert(model.estimatedSize == estimatedDFSize)
+  }
 }
 
 private class IntermediateRDDStorageListener extends SparkListener {
diff --git a/python/pyspark/ml/tests/test_clustering.py b/python/pyspark/ml/tests/test_clustering.py
index a35eaac10a7e..1b8eb73135a9 100644
--- a/python/pyspark/ml/tests/test_clustering.py
+++ b/python/pyspark/ml/tests/test_clustering.py
@@ -37,6 +37,7 @@ from pyspark.ml.clustering import (
     DistributedLDAModel,
     PowerIterationClustering,
 )
+from pyspark.sql import is_remote
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 
@@ -377,6 +378,8 @@ class ClusteringTestsMixin:
             self.assertEqual(str(model), str(model2))
 
     def test_distributed_lda(self):
+        if is_remote():
+            self.skipTest("Do not support Spark Connect.")
         spark = self.spark
         df = (
             spark.createDataFrame(
diff --git a/python/pyspark/ml/tests/test_fpm.py b/python/pyspark/ml/tests/test_fpm.py
index ea94216c9860..7b949763c398 100644
--- a/python/pyspark/ml/tests/test_fpm.py
+++ b/python/pyspark/ml/tests/test_fpm.py
@@ -18,7 +18,7 @@
 import tempfile
 import unittest
 
-from pyspark.sql import Row
+from pyspark.sql import is_remote, Row
 import pyspark.sql.functions as sf
 from pyspark.ml.fpm import (
     FPGrowth,
@@ -30,6 +30,8 @@ from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 class FPMTestsMixin:
     def test_fp_growth(self):
+        if is_remote():
+            self.skipTest("Do not support Spark Connect.")
         df = self.spark.createDataFrame(
             [
                 ["r z h k p"],


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to