This is an automated email from the ASF dual-hosted git repository.
weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new b50690a89a81 [SPARK-52057][ML][CONNECT] Collect Tree size limit
warning messages to client
b50690a89a81 is described below
commit b50690a89a8176ce48c7287f1b94f8e7841df0f8
Author: Weichen Xu <[email protected]>
AuthorDate: Mon May 12 10:08:35 2025 +0800
[SPARK-52057][ML][CONNECT] Collect Tree size limit warning messages to
client
### What changes were proposed in this pull request?
Collect Tree size limit warning messages to client
During tree model training, the model size will be capped by the threshold,
but the warning information is only printed in the Spark driver.
We need to send the warning message to the Spark Connect client and print it in
the client logs.
### Why are the changes needed?
To show the warning message to the user.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Unit tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #50848 from WeichenXu123/SPARK-52057.
Authored-by: Weichen Xu <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>
---
.../main/scala/org/apache/spark/ml/Estimator.scala | 10 ++++++++++
.../spark/ml/tree/impl/GradientBoostedTrees.scala | 17 ++++++++++-------
.../org/apache/spark/ml/tree/impl/RandomForest.scala | 12 +++++++++---
python/pyspark/ml/util.py | 5 +++++
python/pyspark/sql/connect/proto/ml_pb2.py | 6 +++---
python/pyspark/sql/connect/proto/ml_pb2.pyi | 16 ++++++++++++++++
.../common/src/main/protobuf/spark/connect/ml.proto | 2 ++
.../org/apache/spark/sql/connect/ml/MLHandler.scala | 20 +++++++++++++++-----
8 files changed, 70 insertions(+), 18 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
index ead68b290fe4..0e1f64cc7b63 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml
import scala.annotation.varargs
+import scala.collection.mutable.ArrayBuffer
import org.apache.spark.annotation.Since
import org.apache.spark.ml.param.{ParamMap, ParamPair}
@@ -103,3 +104,12 @@ abstract class Estimator[M <: Model[M]] extends
PipelineStage {
throw new UnsupportedOperationException
}
}
+
+
+object EstimatorUtils {
+ // This warningMessagesBuffer is for collecting warning messages during
`estimator.fit`
+ // execution in Spark Connect server.
+ private[spark] val warningMessagesBuffer = new
java.lang.ThreadLocal[ArrayBuffer[String]]() {
+ override def initialValue: ArrayBuffer[String] = null
+ }
+}
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
index 18273abbf0c7..d1cad44a15c8 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.tree.impl
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.TIMER
+import org.apache.spark.ml.EstimatorUtils
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
@@ -275,7 +276,6 @@ private[spark] object GradientBoostedTrees extends Logging {
// This member is only for testing code.
private[spark] var lastEarlyStoppedModelSize: Long = 0
- private[spark] val modelSizeHistory = new
scala.collection.mutable.ArrayBuffer[Long]()
/**
* Internal method for performing regression using trees as base learners.
@@ -413,8 +413,6 @@ private[spark] object GradientBoostedTrees extends Logging {
var m = 1
var earlyStop = false
- modelSizeHistory.clear()
- modelSizeHistory.append(accTreeSize)
if (
earlyStopModelSizeThresholdInBytes > 0
&& accTreeSize > earlyStopModelSizeThresholdInBytes
@@ -493,7 +491,6 @@ private[spark] object GradientBoostedTrees extends Logging {
}
}
if (!earlyStop) {
- modelSizeHistory.append(accTreeSize)
if (
earlyStopModelSizeThresholdInBytes > 0
&& accTreeSize > earlyStopModelSizeThresholdInBytes
@@ -526,9 +523,15 @@ private[spark] object GradientBoostedTrees extends Logging
{
// - validation error increases
// - the accumulated size of trees exceeds the value of
`earlyStopModelSizeThresholdInBytes`
if (accTreeSize > earlyStopModelSizeThresholdInBytes) {
- logWarning(
- "The boosting tree training stops early because the model size
exceeds threshold."
- )
+ val warningMessage = "The boosting tree training stops early because
the GBT accumulated " +
+ "tree models size " +
+ s"($accTreeSize bytes) exceeds threshold " +
+ s"($earlyStopModelSizeThresholdInBytes bytes)."
+ logWarning(warningMessage)
+ val msgBuffer = EstimatorUtils.warningMessagesBuffer.get()
+ if (msgBuffer != null) {
+ msgBuffer.append(warningMessage)
+ }
}
(baseLearners.slice(0, validM), baseLearnerWeights.slice(0, validM))
} else {
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 79d0964b7eae..118d6d7a063a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -23,6 +23,7 @@ import scala.util.Random
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.{MAX_MEMORY_SIZE, MEMORY_SIZE,
NUM_CLASSES, NUM_EXAMPLES, NUM_FEATURES, NUM_NODES, NUM_WEIGHTED_EXAMPLES,
TIMER}
+import org.apache.spark.ml.EstimatorUtils
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.impl.Utils
@@ -228,9 +229,14 @@ private[spark] object RandomForest extends Logging with
Serializable {
val estimatedSize = SizeEstimator.estimate(nodes)
if (estimatedSize > earlyStopModelSizeThresholdInBytes){
earlyStop = true
- logWarning(
- "The random forest training stops early because the model size
exceeds threshold."
- )
+ val warningMessage = "The random forest training stops early because
the model size " +
+ s"($estimatedSize bytes) exceeds threshold " +
+ s"($earlyStopModelSizeThresholdInBytes bytes)."
+ logWarning(warningMessage)
+ val msgBuffer = EstimatorUtils.warningMessagesBuffer.get()
+ if (msgBuffer != null) {
+ msgBuffer.append(warningMessage)
+ }
lastEarlyStoppedModelSize = estimatedSize
}
}
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 0aac706ec25f..b86178a97c38 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -16,6 +16,7 @@
#
import json
+import logging
import os
import threading
import time
@@ -68,6 +69,8 @@ FuncT = TypeVar("FuncT", bound=Callable[..., Any])
ML_CONNECT_HELPER_ID = "______ML_CONNECT_HELPER______"
+_logger = logging.getLogger("pyspark.ml.util")
+
def try_remote_intermediate_result(f: FuncT) -> FuncT:
"""Mark the function/property that returns the intermediate result of the
remote call.
@@ -197,6 +200,8 @@ def try_remote_fit(f: FuncT) -> FuncT:
)
(_, properties, _) = client.execute_command(command)
model_info = deserialize(properties)
+ if warning_msg := getattr(model_info, "warning_message", None):
+ _logger.warning(warning_msg)
remote_model_ref = RemoteModelRef(model_info.obj_ref.id)
model = self._create_model(remote_model_ref)
if model.__class__.__name__ not in ["Bucketizer"]:
diff --git a/python/pyspark/sql/connect/proto/ml_pb2.py
b/python/pyspark/sql/connect/proto/ml_pb2.py
index 31fa3dd5d0ec..46fc82131a9e 100644
--- a/python/pyspark/sql/connect/proto/ml_pb2.py
+++ b/python/pyspark/sql/connect/proto/ml_pb2.py
@@ -40,7 +40,7 @@ from pyspark.sql.connect.proto import ml_common_pb2 as
spark_dot_connect_dot_ml_
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xb2\x0b\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01
\x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02
\x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03
\x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\
[...]
+
b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xb2\x0b\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01
\x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02
\x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03
\x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\
[...]
)
_globals = globals()
@@ -72,7 +72,7 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals["_MLCOMMAND_EVALUATE"]._serialized_start = 1401
_globals["_MLCOMMAND_EVALUATE"]._serialized_end = 1584
_globals["_MLCOMMANDRESULT"]._serialized_start = 1598
- _globals["_MLCOMMANDRESULT"]._serialized_end = 2001
+ _globals["_MLCOMMANDRESULT"]._serialized_end = 2067
_globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_start = 1791
- _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 1986
+ _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 2052
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/ml_pb2.pyi
b/python/pyspark/sql/connect/proto/ml_pb2.pyi
index 9f6f4c1516d8..88cc6cb625de 100644
--- a/python/pyspark/sql/connect/proto/ml_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/ml_pb2.pyi
@@ -449,6 +449,7 @@ class MlCommandResult(google.protobuf.message.Message):
NAME_FIELD_NUMBER: builtins.int
UID_FIELD_NUMBER: builtins.int
PARAMS_FIELD_NUMBER: builtins.int
+ WARNING_MESSAGE_FIELD_NUMBER: builtins.int
@property
def obj_ref(self) -> pyspark.sql.connect.proto.ml_common_pb2.ObjectRef:
"""The cached object which could be a model or summary evaluated
by a model"""
@@ -461,6 +462,8 @@ class MlCommandResult(google.protobuf.message.Message):
@property
def params(self) -> pyspark.sql.connect.proto.ml_common_pb2.MlParams:
"""(Optional) parameters"""
+ warning_message: builtins.str
+ """(Optional) warning message generated during the ML command
execution"""
def __init__(
self,
*,
@@ -468,6 +471,7 @@ class MlCommandResult(google.protobuf.message.Message):
name: builtins.str = ...,
uid: builtins.str | None = ...,
params: pyspark.sql.connect.proto.ml_common_pb2.MlParams | None =
...,
+ warning_message: builtins.str | None = ...,
) -> None: ...
def HasField(
self,
@@ -476,6 +480,8 @@ class MlCommandResult(google.protobuf.message.Message):
b"_params",
"_uid",
b"_uid",
+ "_warning_message",
+ b"_warning_message",
"name",
b"name",
"obj_ref",
@@ -486,6 +492,8 @@ class MlCommandResult(google.protobuf.message.Message):
b"type",
"uid",
b"uid",
+ "warning_message",
+ b"warning_message",
],
) -> builtins.bool: ...
def ClearField(
@@ -495,6 +503,8 @@ class MlCommandResult(google.protobuf.message.Message):
b"_params",
"_uid",
b"_uid",
+ "_warning_message",
+ b"_warning_message",
"name",
b"name",
"obj_ref",
@@ -505,6 +515,8 @@ class MlCommandResult(google.protobuf.message.Message):
b"type",
"uid",
b"uid",
+ "warning_message",
+ b"warning_message",
],
) -> None: ...
@typing.overload
@@ -516,6 +528,10 @@ class MlCommandResult(google.protobuf.message.Message):
self, oneof_group: typing_extensions.Literal["_uid", b"_uid"]
) -> typing_extensions.Literal["uid"] | None: ...
@typing.overload
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_warning_message",
b"_warning_message"]
+ ) -> typing_extensions.Literal["warning_message"] | None: ...
+ @typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["type", b"type"]
) -> typing_extensions.Literal["obj_ref", "name"] | None: ...
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/ml.proto
b/sql/connect/common/src/main/protobuf/spark/connect/ml.proto
index 22c3ca7e6e90..b66c0a186df3 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/ml.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/ml.proto
@@ -124,5 +124,7 @@ message MlCommandResult {
optional string uid = 3;
// (Optional) parameters
optional MlParams params = 4;
+ // (Optional) warning message generated during the ML command execution
+ optional string warning_message = 5;
}
}
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
index 204c874060cc..e5d16a8e783b 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
@@ -22,7 +22,7 @@ import scala.jdk.CollectionConverters.CollectionHasAsScala
import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
-import org.apache.spark.ml.Model
+import org.apache.spark.ml.{EstimatorUtils, Model}
import org.apache.spark.ml.param.{ParamMap, Params}
import org.apache.spark.ml.tree.TreeConfig
import org.apache.spark.ml.util.{MLWritable, Summary}
@@ -163,14 +163,24 @@ private[connect] object MLHandler extends Logging {
"if Spark Connect model cache offloading is enabled.")
}
}
+
+ EstimatorUtils.warningMessagesBuffer.set(new
mutable.ArrayBuffer[String]())
val model = estimator.fit(dataset).asInstanceOf[Model[_]]
val id = mlCache.register(model)
+
+ val fitWarningMessage = if
(EstimatorUtils.warningMessagesBuffer.get().length > 0) {
+ EstimatorUtils.warningMessagesBuffer.get().mkString("\n")
+ } else { null }
+ EstimatorUtils.warningMessagesBuffer.set(null)
+ val opInfo = proto.MlCommandResult.MlOperatorInfo
+ .newBuilder()
+ .setObjRef(proto.ObjectRef.newBuilder().setId(id))
+ if (fitWarningMessage != null) {
+ opInfo.setWarningMessage(fitWarningMessage)
+ }
proto.MlCommandResult
.newBuilder()
- .setOperatorInfo(
- proto.MlCommandResult.MlOperatorInfo
- .newBuilder()
- .setObjRef(proto.ObjectRef.newBuilder().setId(id)))
+ .setOperatorInfo(opInfo)
.build()
case proto.MlCommand.CommandCase.FETCH =>
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]