This is an automated email from the ASF dual-hosted git repository.

weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b50690a89a81 [SPARK-52057][ML][CONNECT] Collect Tree size limit warning messages to client
b50690a89a81 is described below

commit b50690a89a8176ce48c7287f1b94f8e7841df0f8
Author: Weichen Xu <[email protected]>
AuthorDate: Mon May 12 10:08:35 2025 +0800

    [SPARK-52057][ML][CONNECT] Collect Tree size limit warning messages to client
    
    ### What changes were proposed in this pull request?
    
    Collect Tree size limit warning messages to client.
    During tree model training, the model size will be capped by the threshold,
    but the warning information is only printed in the Spark driver.
    
    We need to send the warning message to the Spark Connect client and print it
    in the client logs.

    ### Why are the changes needed?
    
    To show the warning message to the user.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #50848 from WeichenXu123/SPARK-52057.
    
    Authored-by: Weichen Xu <[email protected]>
    Signed-off-by: Weichen Xu <[email protected]>
---
 .../main/scala/org/apache/spark/ml/Estimator.scala   | 10 ++++++++++
 .../spark/ml/tree/impl/GradientBoostedTrees.scala    | 17 ++++++++++-------
 .../org/apache/spark/ml/tree/impl/RandomForest.scala | 12 +++++++++---
 python/pyspark/ml/util.py                            |  5 +++++
 python/pyspark/sql/connect/proto/ml_pb2.py           |  6 +++---
 python/pyspark/sql/connect/proto/ml_pb2.pyi          | 16 ++++++++++++++++
 .../common/src/main/protobuf/spark/connect/ml.proto  |  2 ++
 .../org/apache/spark/sql/connect/ml/MLHandler.scala  | 20 +++++++++++++++-----
 8 files changed, 70 insertions(+), 18 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala 
b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
index ead68b290fe4..0e1f64cc7b63 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.ml
 
 import scala.annotation.varargs
+import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.param.{ParamMap, ParamPair}
@@ -103,3 +104,12 @@ abstract class Estimator[M <: Model[M]] extends 
PipelineStage {
     throw new UnsupportedOperationException
   }
 }
+
+
+object EstimatorUtils {
+  // This warningMessagesBuffer is for collecting warning messages during 
`estimator.fit`
+  // execution in Spark Connect server.
+  private[spark] val warningMessagesBuffer = new 
java.lang.ThreadLocal[ArrayBuffer[String]]() {
+    override def initialValue: ArrayBuffer[String] = null
+  }
+}
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
index 18273abbf0c7..d1cad44a15c8 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.tree.impl
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.TIMER
+import org.apache.spark.ml.EstimatorUtils
 import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
@@ -275,7 +276,6 @@ private[spark] object GradientBoostedTrees extends Logging {
 
   // This member is only for testing code.
   private[spark] var lastEarlyStoppedModelSize: Long = 0
-  private[spark] val modelSizeHistory = new 
scala.collection.mutable.ArrayBuffer[Long]()
 
   /**
    * Internal method for performing regression using trees as base learners.
@@ -413,8 +413,6 @@ private[spark] object GradientBoostedTrees extends Logging {
 
     var m = 1
     var earlyStop = false
-    modelSizeHistory.clear()
-    modelSizeHistory.append(accTreeSize)
     if (
         earlyStopModelSizeThresholdInBytes > 0
         && accTreeSize > earlyStopModelSizeThresholdInBytes
@@ -493,7 +491,6 @@ private[spark] object GradientBoostedTrees extends Logging {
         }
       }
       if (!earlyStop) {
-        modelSizeHistory.append(accTreeSize)
         if (
             earlyStopModelSizeThresholdInBytes > 0
             && accTreeSize > earlyStopModelSizeThresholdInBytes
@@ -526,9 +523,15 @@ private[spark] object GradientBoostedTrees extends Logging 
{
       //  - validation error increases
       //  - the accumulated size of trees exceeds the value of 
`earlyStopModelSizeThresholdInBytes`
       if (accTreeSize > earlyStopModelSizeThresholdInBytes) {
-        logWarning(
-          "The boosting tree training stops early because the model size 
exceeds threshold."
-        )
+        val warningMessage = "The boosting tree training stops early because 
the GBT accumulated " +
+          "tree models size " +
+          s"($accTreeSize bytes) exceeds threshold " +
+          s"($earlyStopModelSizeThresholdInBytes bytes)."
+        logWarning(warningMessage)
+        val msgBuffer = EstimatorUtils.warningMessagesBuffer.get()
+        if (msgBuffer != null) {
+          msgBuffer.append(warningMessage)
+        }
       }
       (baseLearners.slice(0, validM), baseLearnerWeights.slice(0, validM))
     } else {
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 79d0964b7eae..118d6d7a063a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -23,6 +23,7 @@ import scala.util.Random
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.{MAX_MEMORY_SIZE, MEMORY_SIZE, 
NUM_CLASSES, NUM_EXAMPLES, NUM_FEATURES, NUM_NODES, NUM_WEIGHTED_EXAMPLES, 
TIMER}
+import org.apache.spark.ml.EstimatorUtils
 import org.apache.spark.ml.classification.DecisionTreeClassificationModel
 import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.impl.Utils
@@ -228,9 +229,14 @@ private[spark] object RandomForest extends Logging with 
Serializable {
         val estimatedSize = SizeEstimator.estimate(nodes)
         if (estimatedSize > earlyStopModelSizeThresholdInBytes){
           earlyStop = true
-          logWarning(
-            "The random forest training stops early because the model size 
exceeds threshold."
-          )
+          val warningMessage = "The random forest training stops early because 
the model size " +
+            s"($estimatedSize bytes) exceeds threshold " +
+            s"($earlyStopModelSizeThresholdInBytes bytes)."
+          logWarning(warningMessage)
+          val msgBuffer = EstimatorUtils.warningMessagesBuffer.get()
+          if (msgBuffer != null) {
+            msgBuffer.append(warningMessage)
+          }
           lastEarlyStoppedModelSize = estimatedSize
         }
       }
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 0aac706ec25f..b86178a97c38 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -16,6 +16,7 @@
 #
 
 import json
+import logging
 import os
 import threading
 import time
@@ -68,6 +69,8 @@ FuncT = TypeVar("FuncT", bound=Callable[..., Any])
 
 ML_CONNECT_HELPER_ID = "______ML_CONNECT_HELPER______"
 
+_logger = logging.getLogger("pyspark.ml.util")
+
 
 def try_remote_intermediate_result(f: FuncT) -> FuncT:
     """Mark the function/property that returns the intermediate result of the 
remote call.
@@ -197,6 +200,8 @@ def try_remote_fit(f: FuncT) -> FuncT:
             )
             (_, properties, _) = client.execute_command(command)
             model_info = deserialize(properties)
+            if warning_msg := getattr(model_info, "warning_message", None):
+                _logger.warning(warning_msg)
             remote_model_ref = RemoteModelRef(model_info.obj_ref.id)
             model = self._create_model(remote_model_ref)
             if model.__class__.__name__ not in ["Bucketizer"]:
diff --git a/python/pyspark/sql/connect/proto/ml_pb2.py 
b/python/pyspark/sql/connect/proto/ml_pb2.py
index 31fa3dd5d0ec..46fc82131a9e 100644
--- a/python/pyspark/sql/connect/proto/ml_pb2.py
+++ b/python/pyspark/sql/connect/proto/ml_pb2.py
@@ -40,7 +40,7 @@ from pyspark.sql.connect.proto import ml_common_pb2 as 
spark_dot_connect_dot_ml_
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xb2\x0b\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01
 
\x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02
 
\x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03
 
\x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\
 [...]
+    
b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xb2\x0b\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01
 
\x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02
 
\x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03
 
\x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\
 [...]
 )
 
 _globals = globals()
@@ -72,7 +72,7 @@ if not _descriptor._USE_C_DESCRIPTORS:
     _globals["_MLCOMMAND_EVALUATE"]._serialized_start = 1401
     _globals["_MLCOMMAND_EVALUATE"]._serialized_end = 1584
     _globals["_MLCOMMANDRESULT"]._serialized_start = 1598
-    _globals["_MLCOMMANDRESULT"]._serialized_end = 2001
+    _globals["_MLCOMMANDRESULT"]._serialized_end = 2067
     _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_start = 1791
-    _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 1986
+    _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 2052
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/ml_pb2.pyi 
b/python/pyspark/sql/connect/proto/ml_pb2.pyi
index 9f6f4c1516d8..88cc6cb625de 100644
--- a/python/pyspark/sql/connect/proto/ml_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/ml_pb2.pyi
@@ -449,6 +449,7 @@ class MlCommandResult(google.protobuf.message.Message):
         NAME_FIELD_NUMBER: builtins.int
         UID_FIELD_NUMBER: builtins.int
         PARAMS_FIELD_NUMBER: builtins.int
+        WARNING_MESSAGE_FIELD_NUMBER: builtins.int
         @property
         def obj_ref(self) -> pyspark.sql.connect.proto.ml_common_pb2.ObjectRef:
             """The cached object which could be a model or summary evaluated 
by a model"""
@@ -461,6 +462,8 @@ class MlCommandResult(google.protobuf.message.Message):
         @property
         def params(self) -> pyspark.sql.connect.proto.ml_common_pb2.MlParams:
             """(Optional) parameters"""
+        warning_message: builtins.str
+        """(Optional) warning message generated during the ML command 
execution"""
         def __init__(
             self,
             *,
@@ -468,6 +471,7 @@ class MlCommandResult(google.protobuf.message.Message):
             name: builtins.str = ...,
             uid: builtins.str | None = ...,
             params: pyspark.sql.connect.proto.ml_common_pb2.MlParams | None = 
...,
+            warning_message: builtins.str | None = ...,
         ) -> None: ...
         def HasField(
             self,
@@ -476,6 +480,8 @@ class MlCommandResult(google.protobuf.message.Message):
                 b"_params",
                 "_uid",
                 b"_uid",
+                "_warning_message",
+                b"_warning_message",
                 "name",
                 b"name",
                 "obj_ref",
@@ -486,6 +492,8 @@ class MlCommandResult(google.protobuf.message.Message):
                 b"type",
                 "uid",
                 b"uid",
+                "warning_message",
+                b"warning_message",
             ],
         ) -> builtins.bool: ...
         def ClearField(
@@ -495,6 +503,8 @@ class MlCommandResult(google.protobuf.message.Message):
                 b"_params",
                 "_uid",
                 b"_uid",
+                "_warning_message",
+                b"_warning_message",
                 "name",
                 b"name",
                 "obj_ref",
@@ -505,6 +515,8 @@ class MlCommandResult(google.protobuf.message.Message):
                 b"type",
                 "uid",
                 b"uid",
+                "warning_message",
+                b"warning_message",
             ],
         ) -> None: ...
         @typing.overload
@@ -516,6 +528,10 @@ class MlCommandResult(google.protobuf.message.Message):
             self, oneof_group: typing_extensions.Literal["_uid", b"_uid"]
         ) -> typing_extensions.Literal["uid"] | None: ...
         @typing.overload
+        def WhichOneof(
+            self, oneof_group: typing_extensions.Literal["_warning_message", 
b"_warning_message"]
+        ) -> typing_extensions.Literal["warning_message"] | None: ...
+        @typing.overload
         def WhichOneof(
             self, oneof_group: typing_extensions.Literal["type", b"type"]
         ) -> typing_extensions.Literal["obj_ref", "name"] | None: ...
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/ml.proto 
b/sql/connect/common/src/main/protobuf/spark/connect/ml.proto
index 22c3ca7e6e90..b66c0a186df3 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/ml.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/ml.proto
@@ -124,5 +124,7 @@ message MlCommandResult {
     optional string uid = 3;
     // (Optional) parameters
     optional MlParams params = 4;
+    // (Optional) warning message generated during the ML command execution
+    optional string warning_message = 5;
   }
 }
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
index 204c874060cc..e5d16a8e783b 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
@@ -22,7 +22,7 @@ import scala.jdk.CollectionConverters.CollectionHasAsScala
 
 import org.apache.spark.connect.proto
 import org.apache.spark.internal.Logging
-import org.apache.spark.ml.Model
+import org.apache.spark.ml.{EstimatorUtils, Model}
 import org.apache.spark.ml.param.{ParamMap, Params}
 import org.apache.spark.ml.tree.TreeConfig
 import org.apache.spark.ml.util.{MLWritable, Summary}
@@ -163,14 +163,24 @@ private[connect] object MLHandler extends Logging {
                 "if Spark Connect model cache offloading is enabled.")
           }
         }
+
+        EstimatorUtils.warningMessagesBuffer.set(new 
mutable.ArrayBuffer[String]())
         val model = estimator.fit(dataset).asInstanceOf[Model[_]]
         val id = mlCache.register(model)
+
+        val fitWarningMessage = if 
(EstimatorUtils.warningMessagesBuffer.get().length > 0) {
+          EstimatorUtils.warningMessagesBuffer.get().mkString("\n")
+        } else { null }
+        EstimatorUtils.warningMessagesBuffer.set(null)
+        val opInfo = proto.MlCommandResult.MlOperatorInfo
+          .newBuilder()
+          .setObjRef(proto.ObjectRef.newBuilder().setId(id))
+        if (fitWarningMessage != null) {
+          opInfo.setWarningMessage(fitWarningMessage)
+        }
         proto.MlCommandResult
           .newBuilder()
-          .setOperatorInfo(
-            proto.MlCommandResult.MlOperatorInfo
-              .newBuilder()
-              .setObjRef(proto.ObjectRef.newBuilder().setId(id)))
+          .setOperatorInfo(opInfo)
           .build()
 
       case proto.MlCommand.CommandCase.FETCH =>


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to