This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6a71f76f8aac [SPARK-51142][ML][CONNECT] ML protobufs clean up
6a71f76f8aac is described below

commit 6a71f76f8aacf34276fb3371c5bf242cfea4c0f7
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Tue Feb 11 09:53:54 2025 +0800

    [SPARK-51142][ML][CONNECT] ML protobufs clean up
    
    ### What changes were proposed in this pull request?
    Clean up the ML Connect protobufs: prefix the `MlOperator.OperatorType` enum values with the enum name and declare optional fields with `optional`.
    
    ### Why are the changes needed?
    To follow the guide https://github.com/apache/spark/blob/ece14704cc083f17689d2e0b9ab8e31cf71a7a2d/sql/connect/docs/adding-proto-messages.md
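    
    For illustration only (not part of this change), a minimal sketch of how the renamed enum values and the newly `optional` fields surface in the generated Python stubs; the modules and fields referenced below are the ones touched by this diff, so treat this as an assumption-laden example rather than project code:
    
    ```python
    from pyspark.sql.connect.proto import ml_common_pb2, ml_pb2

    # Enum values now carry the enum-name prefix required by the guide.
    op = ml_common_pb2.MlOperator(
        name="org.apache.spark.ml.classification.LogisticRegression",
        uid="LogisticRegression",
        type=ml_common_pb2.MlOperator.OPERATOR_TYPE_ESTIMATOR,
    )

    # Fields declared `optional` gain explicit presence, so readers should
    # check HasField() instead of relying on the proto3 default value.
    write = ml_pb2.MlCommand.Write(path="/tmp/model")
    overwrite = write.HasField("should_overwrite") and write.should_overwrite
    ```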
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    existing tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #49862 from zhengruifeng/ml_connect_protos.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/ml/connect/proto.py                 |   4 +-
 python/pyspark/ml/connect/readwrite.py             |  19 +--
 python/pyspark/ml/util.py                          |   4 +-
 python/pyspark/sql/connect/proto/ml_common_pb2.py  |  12 +-
 python/pyspark/sql/connect/proto/ml_common_pb2.pyi |  38 +++---
 python/pyspark/sql/connect/proto/ml_pb2.py         |  34 +++---
 python/pyspark/sql/connect/proto/ml_pb2.pyi        | 133 +++++++++++++++++----
 .../src/main/protobuf/spark/connect/ml.proto       |  40 ++++---
 .../main/protobuf/spark/connect/ml_common.proto    |  26 ++--
 .../apache/spark/sql/connect/ml/MLHandler.scala    |  38 +++---
 .../org/apache/spark/sql/connect/ml/MLUtils.scala  |   2 +-
 .../spark/sql/connect/ml/MLBackendSuite.scala      |   2 +-
 .../org/apache/spark/sql/connect/ml/MLHelper.scala |   8 +-
 .../org/apache/spark/sql/connect/ml/MLSuite.scala  |   4 +-
 14 files changed, 240 insertions(+), 124 deletions(-)

diff --git a/python/pyspark/ml/connect/proto.py b/python/pyspark/ml/connect/proto.py
index 3a81e74b6aec..b0e012964fc4 100644
--- a/python/pyspark/ml/connect/proto.py
+++ b/python/pyspark/ml/connect/proto.py
@@ -50,7 +50,9 @@ class TransformerRelation(LogicalPlan):
             plan.ml_relation.transform.obj_ref.CopyFrom(pb2.ObjectRef(id=self._name))
         else:
             plan.ml_relation.transform.transformer.CopyFrom(
-                pb2.MlOperator(name=self._name, uid=self._uid, type=pb2.MlOperator.TRANSFORMER)
+                pb2.MlOperator(
+                    name=self._name, uid=self._uid, type=pb2.MlOperator.OPERATOR_TYPE_TRANSFORMER
+                )
             )
 
         if self._ml_params is not None:
diff --git a/python/pyspark/ml/connect/readwrite.py b/python/pyspark/ml/connect/readwrite.py
index 584ff3237a0a..c2367282b7c4 100644
--- a/python/pyspark/ml/connect/readwrite.py
+++ b/python/pyspark/ml/connect/readwrite.py
@@ -118,13 +118,13 @@ class RemoteMLWriter(MLWriter):
         elif isinstance(instance, (JavaEstimator, JavaTransformer, JavaEvaluator)):
             operator: Union[JavaEstimator, JavaTransformer, JavaEvaluator]
             if isinstance(instance, JavaEstimator):
-                ml_type = pb2.MlOperator.ESTIMATOR
+                ml_type = pb2.MlOperator.OPERATOR_TYPE_ESTIMATOR
                 operator = cast("JavaEstimator", instance)
             elif isinstance(instance, JavaEvaluator):
-                ml_type = pb2.MlOperator.EVALUATOR
+                ml_type = pb2.MlOperator.OPERATOR_TYPE_EVALUATOR
                 operator = cast("JavaEvaluator", instance)
             else:
-                ml_type = pb2.MlOperator.TRANSFORMER
+                ml_type = pb2.MlOperator.OPERATOR_TYPE_TRANSFORMER
                 operator = cast("JavaTransformer", instance)
 
             params = serialize_ml_params(operator, session.client)
@@ -249,13 +249,13 @@ class RemoteMLReader(MLReader[RL]):
             or issubclass(clazz, JavaTransformer)
         ):
             if issubclass(clazz, JavaModel):
-                ml_type = pb2.MlOperator.MODEL
+                ml_type = pb2.MlOperator.OPERATOR_TYPE_MODEL
             elif issubclass(clazz, JavaEstimator):
-                ml_type = pb2.MlOperator.ESTIMATOR
+                ml_type = pb2.MlOperator.OPERATOR_TYPE_ESTIMATOR
             elif issubclass(clazz, JavaEvaluator):
-                ml_type = pb2.MlOperator.EVALUATOR
+                ml_type = pb2.MlOperator.OPERATOR_TYPE_EVALUATOR
             else:
-                ml_type = pb2.MlOperator.TRANSFORMER
+                ml_type = pb2.MlOperator.OPERATOR_TYPE_TRANSFORMER
 
             # to get the java corresponding qualified class name
             java_qualified_class_name = (
@@ -281,7 +281,7 @@ class RemoteMLReader(MLReader[RL]):
             py_type = _get_class()
             # It must be JavaWrapper, since we're passing the string to the _java_obj
             if issubclass(py_type, JavaWrapper):
-                if ml_type == pb2.MlOperator.MODEL:
+                if ml_type == pb2.MlOperator.OPERATOR_TYPE_MODEL:
                     session.client.add_ml_cache(result.obj_ref.id)
                     instance = py_type(result.obj_ref.id)
                 else:
@@ -358,7 +358,8 @@ class RemoteMLReader(MLReader[RL]):
             command.ml_command.read.CopyFrom(
                 pb2.MlCommand.Read(
                     operator=pb2.MlOperator(
-                        name=java_qualified_class_name, type=pb2.MlOperator.TRANSFORMER
+                        name=java_qualified_class_name,
+                        type=pb2.MlOperator.OPERATOR_TYPE_TRANSFORMER,
                     ),
                     path=path,
                 )
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 7b8ba57a1f8a..9eab45239b8f 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -136,7 +136,7 @@ def try_remote_fit(f: FuncT) -> FuncT:
             input = dataset._plan.plan(client)
             assert isinstance(self._java_obj, str)
             estimator = pb2.MlOperator(
-                name=self._java_obj, uid=self.uid, type=pb2.MlOperator.ESTIMATOR
+                name=self._java_obj, uid=self.uid, type=pb2.MlOperator.OPERATOR_TYPE_ESTIMATOR
             )
             command = pb2.Command()
             command.ml_command.fit.CopyFrom(
@@ -361,7 +361,7 @@ def try_remote_evaluate(f: FuncT) -> FuncT:
             input = dataset._plan.plan(client)
             assert isinstance(self._java_obj, str)
             evaluator = pb2.MlOperator(
-                name=self._java_obj, uid=self.uid, type=pb2.MlOperator.EVALUATOR
+                name=self._java_obj, uid=self.uid, type=pb2.MlOperator.OPERATOR_TYPE_EVALUATOR
             )
             command = pb2.Command()
             command.ml_command.evaluate.CopyFrom(
diff --git a/python/pyspark/sql/connect/proto/ml_common_pb2.py b/python/pyspark/sql/connect/proto/ml_common_pb2.py
index 43d6a512f48f..b61e1bcb205c 100644
--- a/python/pyspark/sql/connect/proto/ml_common_pb2.py
+++ b/python/pyspark/sql/connect/proto/ml_common_pb2.py
@@ -38,7 +38,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x1dspark/connect/ml_common.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xa5\x01\n\x08MlParams\x12;\n\x06params\x18\x01
 
\x03(\x0b\x32#.spark.connect.MlParams.ParamsEntryR\x06params\x1a\\\n\x0bParamsEntry\x12\x10\n\x03key\x18\x01
 \x01(\tR\x03key\x12\x37\n\x05value\x18\x02 
\x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05value:\x02\x38\x01"\xc9\x01\n\nMlOperator\x12\x12\n\x04name\x18\x01
 \x01(\tR\x04name\x12\x10\n\x03uid\x18\x02 \x01(\tR\x03uid\x12:\n\x04type [...]
+    
b'\n\x1dspark/connect/ml_common.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xa5\x01\n\x08MlParams\x12;\n\x06params\x18\x01
 
\x03(\x0b\x32#.spark.connect.MlParams.ParamsEntryR\x06params\x1a\\\n\x0bParamsEntry\x12\x10\n\x03key\x18\x01
 \x01(\tR\x03key\x12\x37\n\x05value\x18\x02 
\x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05value:\x02\x38\x01"\x90\x02\n\nMlOperator\x12\x12\n\x04name\x18\x01
 \x01(\tR\x04name\x12\x10\n\x03uid\x18\x02 \x01(\tR\x03uid\x12:\n\x04type [...]
 )
 
 _globals = globals()
@@ -58,9 +58,9 @@ if not _descriptor._USE_C_DESCRIPTORS:
     _globals["_MLPARAMS_PARAMSENTRY"]._serialized_start = 155
     _globals["_MLPARAMS_PARAMSENTRY"]._serialized_end = 247
     _globals["_MLOPERATOR"]._serialized_start = 250
-    _globals["_MLOPERATOR"]._serialized_end = 451
-    _globals["_MLOPERATOR_OPERATORTYPE"]._serialized_start = 362
-    _globals["_MLOPERATOR_OPERATORTYPE"]._serialized_end = 451
-    _globals["_OBJECTREF"]._serialized_start = 453
-    _globals["_OBJECTREF"]._serialized_end = 480
+    _globals["_MLOPERATOR"]._serialized_end = 522
+    _globals["_MLOPERATOR_OPERATORTYPE"]._serialized_start = 363
+    _globals["_MLOPERATOR_OPERATORTYPE"]._serialized_end = 522
+    _globals["_OBJECTREF"]._serialized_start = 524
+    _globals["_OBJECTREF"]._serialized_end = 551
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/ml_common_pb2.pyi b/python/pyspark/sql/connect/proto/ml_common_pb2.pyi
index f4688e94c3d5..bc540028eb08 100644
--- a/python/pyspark/sql/connect/proto/ml_common_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/ml_common_pb2.pyi
@@ -112,28 +112,36 @@ class MlOperator(google.protobuf.message.Message):
         builtins.type,
     ):  # noqa: F821
         DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
-        UNSPECIFIED: MlOperator._OperatorType.ValueType  # 0
-        ESTIMATOR: MlOperator._OperatorType.ValueType  # 1
-        TRANSFORMER: MlOperator._OperatorType.ValueType  # 2
-        EVALUATOR: MlOperator._OperatorType.ValueType  # 3
-        MODEL: MlOperator._OperatorType.ValueType  # 4
+        OPERATOR_TYPE_UNSPECIFIED: MlOperator._OperatorType.ValueType  # 0
+        OPERATOR_TYPE_ESTIMATOR: MlOperator._OperatorType.ValueType  # 1
+        """ML estimator"""
+        OPERATOR_TYPE_TRANSFORMER: MlOperator._OperatorType.ValueType  # 2
+        """ML transformer (non-model)"""
+        OPERATOR_TYPE_EVALUATOR: MlOperator._OperatorType.ValueType  # 3
+        """ML evaluator"""
+        OPERATOR_TYPE_MODEL: MlOperator._OperatorType.ValueType  # 4
+        """ML model"""
 
     class OperatorType(_OperatorType, metaclass=_OperatorTypeEnumTypeWrapper): ...
-    UNSPECIFIED: MlOperator.OperatorType.ValueType  # 0
-    ESTIMATOR: MlOperator.OperatorType.ValueType  # 1
-    TRANSFORMER: MlOperator.OperatorType.ValueType  # 2
-    EVALUATOR: MlOperator.OperatorType.ValueType  # 3
-    MODEL: MlOperator.OperatorType.ValueType  # 4
+    OPERATOR_TYPE_UNSPECIFIED: MlOperator.OperatorType.ValueType  # 0
+    OPERATOR_TYPE_ESTIMATOR: MlOperator.OperatorType.ValueType  # 1
+    """ML estimator"""
+    OPERATOR_TYPE_TRANSFORMER: MlOperator.OperatorType.ValueType  # 2
+    """ML transformer (non-model)"""
+    OPERATOR_TYPE_EVALUATOR: MlOperator.OperatorType.ValueType  # 3
+    """ML evaluator"""
+    OPERATOR_TYPE_MODEL: MlOperator.OperatorType.ValueType  # 4
+    """ML model"""
 
     NAME_FIELD_NUMBER: builtins.int
     UID_FIELD_NUMBER: builtins.int
     TYPE_FIELD_NUMBER: builtins.int
     name: builtins.str
-    """The qualified name of the ML operator."""
+    """(Required) The qualified name of the ML operator."""
     uid: builtins.str
-    """Unique id of the ML operator"""
+    """(Required) Unique id of the ML operator"""
     type: global___MlOperator.OperatorType.ValueType
-    """Represents what the ML operator is"""
+    """(Required) Represents what the ML operator is"""
     def __init__(
         self,
         *,
@@ -156,7 +164,9 @@ class ObjectRef(google.protobuf.message.Message):
 
     ID_FIELD_NUMBER: builtins.int
     id: builtins.str
-    """The ID is used to lookup the object on the server side."""
+    """(Required) The ID is used to lookup the object on the server side.
+    Note it is different from the 'uid' of a ML object.
+    """
     def __init__(
         self,
         *,
diff --git a/python/pyspark/sql/connect/proto/ml_pb2.py b/python/pyspark/sql/connect/proto/ml_pb2.py
index 8e8bc34a7a97..666cb1efdd2b 100644
--- a/python/pyspark/sql/connect/proto/ml_pb2.py
+++ b/python/pyspark/sql/connect/proto/ml_pb2.py
@@ -40,7 +40,7 @@ from pyspark.sql.connect.proto import ml_common_pb2 as spark_dot_connect_dot_ml_
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xb1\t\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01
 
\x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02
 
\x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03
 
\x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\x1
 [...]
+    
b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xfb\t\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01
 
\x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02
 
\x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03
 
\x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\x1
 [...]
 )
 
 _globals = globals()
@@ -54,21 +54,21 @@ if not _descriptor._USE_C_DESCRIPTORS:
     _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._loaded_options = None
     _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_options = b"8\001"
     _globals["_MLCOMMAND"]._serialized_start = 137
-    _globals["_MLCOMMAND"]._serialized_end = 1338
+    _globals["_MLCOMMAND"]._serialized_end = 1412
     _globals["_MLCOMMAND_FIT"]._serialized_start = 480
-    _globals["_MLCOMMAND_FIT"]._serialized_end = 642
-    _globals["_MLCOMMAND_DELETE"]._serialized_start = 644
-    _globals["_MLCOMMAND_DELETE"]._serialized_end = 703
-    _globals["_MLCOMMAND_WRITE"]._serialized_start = 706
-    _globals["_MLCOMMAND_WRITE"]._serialized_end = 1074
-    _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_start = 1008
-    _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_end = 1066
-    _globals["_MLCOMMAND_READ"]._serialized_start = 1076
-    _globals["_MLCOMMAND_READ"]._serialized_end = 1157
-    _globals["_MLCOMMAND_EVALUATE"]._serialized_start = 1160
-    _globals["_MLCOMMAND_EVALUATE"]._serialized_end = 1327
-    _globals["_MLCOMMANDRESULT"]._serialized_start = 1341
-    _globals["_MLCOMMANDRESULT"]._serialized_end = 1715
-    _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_start = 1534
-    _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 1700
+    _globals["_MLCOMMAND_FIT"]._serialized_end = 658
+    _globals["_MLCOMMAND_DELETE"]._serialized_start = 660
+    _globals["_MLCOMMAND_DELETE"]._serialized_end = 719
+    _globals["_MLCOMMAND_WRITE"]._serialized_start = 722
+    _globals["_MLCOMMAND_WRITE"]._serialized_end = 1132
+    _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_start = 1034
+    _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_end = 1092
+    _globals["_MLCOMMAND_READ"]._serialized_start = 1134
+    _globals["_MLCOMMAND_READ"]._serialized_end = 1215
+    _globals["_MLCOMMAND_EVALUATE"]._serialized_start = 1218
+    _globals["_MLCOMMAND_EVALUATE"]._serialized_end = 1401
+    _globals["_MLCOMMANDRESULT"]._serialized_start = 1415
+    _globals["_MLCOMMANDRESULT"]._serialized_end = 1818
+    _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_start = 1608
+    _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 1803
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/ml_pb2.pyi b/python/pyspark/sql/connect/proto/ml_pb2.pyi
index e8ae0be8dded..3a1e9155d71d 100644
--- a/python/pyspark/sql/connect/proto/ml_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/ml_pb2.pyi
@@ -42,6 +42,7 @@ import pyspark.sql.connect.proto.expressions_pb2
 import pyspark.sql.connect.proto.ml_common_pb2
 import pyspark.sql.connect.proto.relations_pb2
 import sys
+import typing
 
 if sys.version_info >= (3, 8):
     import typing as typing_extensions
@@ -65,13 +66,13 @@ class MlCommand(google.protobuf.message.Message):
         DATASET_FIELD_NUMBER: builtins.int
         @property
         def estimator(self) -> pyspark.sql.connect.proto.ml_common_pb2.MlOperator:
-            """Estimator information"""
+            """(Required) Estimator information (its type should be OPERATOR_TYPE_ESTIMATOR)"""
         @property
         def params(self) -> pyspark.sql.connect.proto.ml_common_pb2.MlParams:
-            """parameters of the Estimator"""
+            """(Optional) parameters of the Estimator"""
         @property
         def dataset(self) -> pyspark.sql.connect.proto.relations_pb2.Relation:
-            """the training dataset"""
+            """(Required) the training dataset"""
         def __init__(
             self,
             *,
@@ -82,15 +83,32 @@ class MlCommand(google.protobuf.message.Message):
         def HasField(
             self,
             field_name: typing_extensions.Literal[
-                "dataset", b"dataset", "estimator", b"estimator", "params", b"params"
+                "_params",
+                b"_params",
+                "dataset",
+                b"dataset",
+                "estimator",
+                b"estimator",
+                "params",
+                b"params",
             ],
         ) -> builtins.bool: ...
         def ClearField(
             self,
             field_name: typing_extensions.Literal[
-                "dataset", b"dataset", "estimator", b"estimator", "params", b"params"
+                "_params",
+                b"_params",
+                "dataset",
+                b"dataset",
+                "estimator",
+                b"estimator",
+                "params",
+                b"params",
             ],
         ) -> None: ...
+        def WhichOneof(
+            self, oneof_group: typing_extensions.Literal["_params", b"_params"]
+        ) -> typing_extensions.Literal["params"] | None: ...
 
     class Delete(google.protobuf.message.Message):
         """Command to delete the cached object which could be a model
@@ -150,16 +168,16 @@ class MlCommand(google.protobuf.message.Message):
             """The cached model"""
         @property
         def params(self) -> pyspark.sql.connect.proto.ml_common_pb2.MlParams:
-            """The parameters of operator which could be estimator/evaluator or a cached model"""
+            """(Optional) The parameters of operator which could be estimator/evaluator or a cached model"""
         path: builtins.str
-        """Save the ML instance to the path"""
+        """(Required) Save the ML instance to the path"""
         should_overwrite: builtins.bool
-        """Overwrites if the output path already exists."""
+        """(Optional) Overwrites if the output path already exists."""
         @property
         def options(
             self,
         ) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]:
-            """The options of the writer"""
+            """(Optional) The options of the writer"""
         def __init__(
             self,
             *,
@@ -167,18 +185,35 @@ class MlCommand(google.protobuf.message.Message):
             obj_ref: pyspark.sql.connect.proto.ml_common_pb2.ObjectRef | None = ...,
             params: pyspark.sql.connect.proto.ml_common_pb2.MlParams | None = ...,
             path: builtins.str = ...,
-            should_overwrite: builtins.bool = ...,
+            should_overwrite: builtins.bool | None = ...,
             options: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
         ) -> None: ...
         def HasField(
             self,
             field_name: typing_extensions.Literal[
-                "obj_ref", b"obj_ref", "operator", b"operator", "params", b"params", "type", b"type"
+                "_params",
+                b"_params",
+                "_should_overwrite",
+                b"_should_overwrite",
+                "obj_ref",
+                b"obj_ref",
+                "operator",
+                b"operator",
+                "params",
+                b"params",
+                "should_overwrite",
+                b"should_overwrite",
+                "type",
+                b"type",
             ],
         ) -> builtins.bool: ...
         def ClearField(
             self,
             field_name: typing_extensions.Literal[
+                "_params",
+                b"_params",
+                "_should_overwrite",
+                b"_should_overwrite",
                 "obj_ref",
                 b"obj_ref",
                 "operator",
@@ -195,6 +230,15 @@ class MlCommand(google.protobuf.message.Message):
                 b"type",
             ],
         ) -> None: ...
+        @typing.overload
+        def WhichOneof(
+            self, oneof_group: typing_extensions.Literal["_params", b"_params"]
+        ) -> typing_extensions.Literal["params"] | None: ...
+        @typing.overload
+        def WhichOneof(
             self, oneof_group: typing_extensions.Literal["_should_overwrite", b"_should_overwrite"]
+        ) -> typing_extensions.Literal["should_overwrite"] | None: ...
+        @typing.overload
         def WhichOneof(
             self, oneof_group: typing_extensions.Literal["type", b"type"]
         ) -> typing_extensions.Literal["operator", "obj_ref"] | None: ...
@@ -208,9 +252,9 @@ class MlCommand(google.protobuf.message.Message):
         PATH_FIELD_NUMBER: builtins.int
         @property
         def operator(self) -> pyspark.sql.connect.proto.ml_common_pb2.MlOperator:
-            """ML operator information"""
+            """(Required) ML operator information"""
         path: builtins.str
-        """Load the ML instance from the input path"""
+        """(Required) Load the ML instance from the input path"""
         def __init__(
             self,
             *,
@@ -234,13 +278,13 @@ class MlCommand(google.protobuf.message.Message):
         DATASET_FIELD_NUMBER: builtins.int
         @property
         def evaluator(self) -> pyspark.sql.connect.proto.ml_common_pb2.MlOperator:
-            """Evaluator information"""
+            """(Required) Evaluator information (its type should be OPERATOR_TYPE_EVALUATOR)"""
         @property
         def params(self) -> pyspark.sql.connect.proto.ml_common_pb2.MlParams:
-            """parameters of the Evaluator"""
+            """(Optional) parameters of the Evaluator"""
         @property
         def dataset(self) -> pyspark.sql.connect.proto.relations_pb2.Relation:
-            """the evaluating dataset"""
+            """(Required) the evaluating dataset"""
         def __init__(
             self,
             *,
@@ -251,15 +295,32 @@ class MlCommand(google.protobuf.message.Message):
         def HasField(
             self,
             field_name: typing_extensions.Literal[
-                "dataset", b"dataset", "evaluator", b"evaluator", "params", b"params"
+                "_params",
+                b"_params",
+                "dataset",
+                b"dataset",
+                "evaluator",
+                b"evaluator",
+                "params",
+                b"params",
             ],
         ) -> builtins.bool: ...
         def ClearField(
             self,
             field_name: typing_extensions.Literal[
-                "dataset", b"dataset", "evaluator", b"evaluator", "params", b"params"
+                "_params",
+                b"_params",
+                "dataset",
+                b"dataset",
+                "evaluator",
+                b"evaluator",
+                "params",
+                b"params",
             ],
         ) -> None: ...
+        def WhichOneof(
+            self, oneof_group: typing_extensions.Literal["_params", b"_params"]
+        ) -> typing_extensions.Literal["params"] | None: ...
 
     FIT_FIELD_NUMBER: builtins.int
     FETCH_FIELD_NUMBER: builtins.int
@@ -355,25 +416,46 @@ class MlCommandResult(google.protobuf.message.Message):
         name: builtins.str
         """Operator name"""
         uid: builtins.str
+        """(Optional) the 'uid' of a ML object
+        Note it is different from the 'id' of a cached object.
+        """
         @property
-        def params(self) -> pyspark.sql.connect.proto.ml_common_pb2.MlParams: ...
+        def params(self) -> pyspark.sql.connect.proto.ml_common_pb2.MlParams:
+            """(Optional) parameters"""
         def __init__(
             self,
             *,
             obj_ref: pyspark.sql.connect.proto.ml_common_pb2.ObjectRef | None = ...,
             name: builtins.str = ...,
-            uid: builtins.str = ...,
+            uid: builtins.str | None = ...,
             params: pyspark.sql.connect.proto.ml_common_pb2.MlParams | None = ...,
         ) -> None: ...
         def HasField(
             self,
             field_name: typing_extensions.Literal[
-                "name", b"name", "obj_ref", b"obj_ref", "params", b"params", "type", b"type"
+                "_params",
+                b"_params",
+                "_uid",
+                b"_uid",
+                "name",
+                b"name",
+                "obj_ref",
+                b"obj_ref",
+                "params",
+                b"params",
+                "type",
+                b"type",
+                "uid",
+                b"uid",
             ],
         ) -> builtins.bool: ...
         def ClearField(
             self,
             field_name: typing_extensions.Literal[
+                "_params",
+                b"_params",
+                "_uid",
+                b"_uid",
                 "name",
                 b"name",
                 "obj_ref",
@@ -386,6 +468,15 @@ class MlCommandResult(google.protobuf.message.Message):
                 b"uid",
             ],
         ) -> None: ...
+        @typing.overload
+        def WhichOneof(
+            self, oneof_group: typing_extensions.Literal["_params", b"_params"]
+        ) -> typing_extensions.Literal["params"] | None: ...
+        @typing.overload
+        def WhichOneof(
+            self, oneof_group: typing_extensions.Literal["_uid", b"_uid"]
+        ) -> typing_extensions.Literal["uid"] | None: ...
+        @typing.overload
         def WhichOneof(
             self, oneof_group: typing_extensions.Literal["type", b"type"]
         ) -> typing_extensions.Literal["obj_ref", "name"] | None: ...
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/ml.proto b/sql/connect/common/src/main/protobuf/spark/connect/ml.proto
index 20a5cafebb36..6e469bb9027e 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/ml.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/ml.proto
@@ -40,11 +40,11 @@ message MlCommand {
 
   // Command for estimator.fit(dataset)
   message Fit {
-    // Estimator information
+    // (Required) Estimator information (its type should be OPERATOR_TYPE_ESTIMATOR)
     MlOperator estimator = 1;
-    // parameters of the Estimator
-    MlParams params = 2;
-    // the training dataset
+    // (Optional) parameters of the Estimator
+    optional MlParams params = 2;
+    // (Required) the training dataset
     Relation dataset = 3;
   }
 
@@ -63,31 +63,31 @@ message MlCommand {
       // The cached model
       ObjectRef obj_ref = 2;
     }
-    // The parameters of operator which could be estimator/evaluator or a cached model
-    MlParams params = 3;
-    // Save the ML instance to the path
+    // (Optional) The parameters of operator which could be estimator/evaluator or a cached model
+    optional MlParams params = 3;
+    // (Required) Save the ML instance to the path
     string path = 4;
-    // Overwrites if the output path already exists.
-    bool should_overwrite = 5;
-    // The options of the writer
+    // (Optional) Overwrites if the output path already exists.
+    optional bool should_overwrite = 5;
+    // (Optional) The options of the writer
     map<string, string> options = 6;
   }
 
   // Command to load ML operator.
   message Read {
-    // ML operator information
+    // (Required) ML operator information
     MlOperator operator = 1;
-    // Load the ML instance from the input path
+    // (Required) Load the ML instance from the input path
     string path = 2;
   }
 
   // Command for evaluator.evaluate(dataset)
   message Evaluate {
-    // Evaluator information
+    // (Required) Evaluator information (its type should be OPERATOR_TYPE_EVALUATOR)
     MlOperator evaluator = 1;
-    // parameters of the Evaluator
-    MlParams params = 2;
-    // the evaluating dataset
+    // (Optional) parameters of the Evaluator
+    optional MlParams params = 2;
+    // (Required) the evaluating dataset
     Relation dataset = 3;
   }
 }
@@ -111,8 +111,10 @@ message MlCommandResult {
       // Operator name
       string name = 2;
     }
-    string uid = 3;
-    MlParams params = 4;
+    // (Optional) the 'uid' of a ML object
+    // Note it is different from the 'id' of a cached object.
+    optional string uid = 3;
+    // (Optional) parameters
+    optional MlParams params = 4;
   }
-
 }
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/ml_common.proto b/sql/connect/common/src/main/protobuf/spark/connect/ml_common.proto
index 48b5fa8135cc..06ca4e5db697 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/ml_common.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/ml_common.proto
@@ -33,24 +33,32 @@ message MlParams {
 
 // MLOperator represents the ML operators like (Estimator, Transformer or Evaluator)
 message MlOperator {
-  // The qualified name of the ML operator.
+  // (Required) The qualified name of the ML operator.
   string name = 1;
-  // Unique id of the ML operator
+
+  // (Required) Unique id of the ML operator
   string uid = 2;
-  // Represents what the ML operator is
+
+  // (Required) Represents what the ML operator is
   OperatorType type = 3;
+
   enum OperatorType {
-    UNSPECIFIED = 0;
-    ESTIMATOR = 1;
-    TRANSFORMER = 2;
-    EVALUATOR = 3;
-    MODEL = 4;
+    OPERATOR_TYPE_UNSPECIFIED = 0;
+    // ML estimator
+    OPERATOR_TYPE_ESTIMATOR = 1;
+    // ML transformer (non-model)
+    OPERATOR_TYPE_TRANSFORMER = 2;
+    // ML evaluator
+    OPERATOR_TYPE_EVALUATOR = 3;
+    // ML model
+    OPERATOR_TYPE_MODEL = 4;
   }
 }
 
 // Represents a reference to the cached object which could be a model
 // or summary evaluated by a model
 message ObjectRef {
-  // The ID is used to lookup the object on the server side.
+  // (Required) The ID is used to lookup the object on the server side.
+  // Note it is different from the 'uid' of a ML object.
   string id = 1;
 }
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
index d4ef1eee5c24..08080c099200 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
@@ -114,7 +114,7 @@ private[connect] object MLHandler extends Logging {
       case proto.MlCommand.CommandCase.FIT =>
         val fitCmd = mlCommand.getFit
         val estimatorProto = fitCmd.getEstimator
-        assert(estimatorProto.getType == proto.MlOperator.OperatorType.ESTIMATOR)
+        assert(estimatorProto.getType == proto.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR)
 
         val dataset = MLUtils.parseRelationProto(fitCmd.getDataset, sessionHolder)
         val estimator =
@@ -197,21 +197,21 @@ private[connect] object MLHandler extends Logging {
             val params = Some(writer.getParams)
 
             operatorType match {
-              case proto.MlOperator.OperatorType.ESTIMATOR =>
+              case proto.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR =>
                 val estimator = MLUtils.getEstimator(sessionHolder, writer.getOperator, params)
                 estimator match {
                   case writable: MLWritable => MLUtils.write(writable, mlCommand.getWrite)
                   case other => throw MlUnsupportedException(s"Estimator $other is not writable")
                 }
 
-              case proto.MlOperator.OperatorType.EVALUATOR =>
+              case proto.MlOperator.OperatorType.OPERATOR_TYPE_EVALUATOR =>
                 val evaluator = MLUtils.getEvaluator(sessionHolder, writer.getOperator, params)
                 evaluator match {
                   case writable: MLWritable => MLUtils.write(writable, mlCommand.getWrite)
                   case other => throw MlUnsupportedException(s"Evaluator $other is not writable")
                 }
 
-              case proto.MlOperator.OperatorType.TRANSFORMER =>
+              case proto.MlOperator.OperatorType.OPERATOR_TYPE_TRANSFORMER =>
                 val transformer =
                   MLUtils.getTransformer(sessionHolder, writer.getOperator, params)
                 transformer match {
@@ -232,7 +232,7 @@ private[connect] object MLHandler extends Logging {
         val name = operator.getName
         val path = mlCommand.getRead.getPath
 
-        if (operator.getType == proto.MlOperator.OperatorType.MODEL) {
+        if (operator.getType == proto.MlOperator.OperatorType.OPERATOR_TYPE_MODEL) {
           val model = MLUtils.loadTransformer(sessionHolder, name, path)
           val id = mlCache.register(model)
           return proto.MlCommandResult
@@ -244,18 +244,21 @@ private[connect] object MLHandler extends Logging {
                 .setUid(model.uid)
                 .setParams(Serializer.serializeParams(model)))
             .build()
-
         }
 
-        val mlOperator = if (operator.getType == proto.MlOperator.OperatorType.ESTIMATOR) {
-          MLUtils.loadEstimator(sessionHolder, name, path).asInstanceOf[Params]
-        } else if (operator.getType == proto.MlOperator.OperatorType.EVALUATOR) {
-          MLUtils.loadEvaluator(sessionHolder, name, path).asInstanceOf[Params]
-        } else if (operator.getType == proto.MlOperator.OperatorType.TRANSFORMER) {
-          MLUtils.loadTransformer(sessionHolder, name, path).asInstanceOf[Params]
-        } else {
-          throw MlUnsupportedException(s"${operator.getType} read not supported")
-        }
+        val mlOperator =
+          if (operator.getType ==
+              proto.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR) {
+            MLUtils.loadEstimator(sessionHolder, name, path).asInstanceOf[Params]
+          } else if (operator.getType ==
+              proto.MlOperator.OperatorType.OPERATOR_TYPE_EVALUATOR) {
+            MLUtils.loadEvaluator(sessionHolder, name, path).asInstanceOf[Params]
+          } else if (operator.getType ==
+              proto.MlOperator.OperatorType.OPERATOR_TYPE_TRANSFORMER) {
+            MLUtils.loadTransformer(sessionHolder, name, path).asInstanceOf[Params]
+          } else {
+            throw MlUnsupportedException(s"${operator.getType} read not supported")
+          }
 
         proto.MlCommandResult
           .newBuilder()
@@ -270,7 +273,7 @@ private[connect] object MLHandler extends Logging {
       case proto.MlCommand.CommandCase.EVALUATE =>
         val evalCmd = mlCommand.getEvaluate
         val evalProto = evalCmd.getEvaluator
-        assert(evalProto.getType == proto.MlOperator.OperatorType.EVALUATOR)
+        assert(evalProto.getType == proto.MlOperator.OperatorType.OPERATOR_TYPE_EVALUATOR)
 
         val dataset = MLUtils.parseRelationProto(evalCmd.getDataset, sessionHolder)
         val evaluator =
@@ -295,7 +298,7 @@ private[connect] object MLHandler extends Logging {
             val transformProto = relation.getTransform
             assert(
               transformProto.getTransformer.getType ==
-                proto.MlOperator.OperatorType.TRANSFORMER)
+                proto.MlOperator.OperatorType.OPERATOR_TYPE_TRANSFORMER)
             val dataset = MLUtils.parseRelationProto(transformProto.getInput, sessionHolder)
             val transformer = MLUtils.getTransformer(sessionHolder, transformProto)
             transformer.transform(dataset)
@@ -323,5 +326,4 @@ private[connect] object MLHandler extends Logging {
       case other => throw MlUnsupportedException(s"$other not supported")
     }
   }
-
 }
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index c999772b7d82..3647fa3d9dae 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -693,7 +693,7 @@ private[ml] object MLUtils {
   }
 
   def write(instance: MLWritable, writeProto: proto.MlCommand.Write): Unit = {
-    val writer = if (writeProto.getShouldOverwrite) {
+    val writer = if (writeProto.hasShouldOverwrite && writeProto.getShouldOverwrite) {
       instance.write.overwrite()
     } else {
       instance.write
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLBackendSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLBackendSuite.scala
index 5b2b5e6dd793..f7788fb3cd1a 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLBackendSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLBackendSuite.scala
@@ -42,7 +42,7 @@ class MLBackendSuite extends MLHelper {
       .newBuilder()
       .setName(name)
       .setUid(name)
-      .setType(proto.MlOperator.OperatorType.ESTIMATOR)
+      .setType(proto.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR)
   }
 
   private def getMaxIterBuilder: proto.MlParams.Builder = {
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala
index 5a447189d870..5939b673501b 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala
@@ -98,7 +98,7 @@ trait MLHelper extends SparkFunSuite with SparkConnectPlanTest {
       .newBuilder()
       .setName("org.apache.spark.ml.classification.LogisticRegression")
       .setUid("LogisticRegression")
-      .setType(proto.MlOperator.OperatorType.ESTIMATOR)
+      .setType(proto.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR)
 
   def getMaxIter: proto.MlParams.Builder =
     proto.MlParams
@@ -110,7 +110,7 @@ trait MLHelper extends SparkFunSuite with SparkConnectPlanTest {
       .newBuilder()
       .setName("org.apache.spark.ml.evaluation.RegressionEvaluator")
       .setUid("RegressionEvaluator")
-      .setType(proto.MlOperator.OperatorType.EVALUATOR)
+      .setType(proto.MlOperator.OperatorType.OPERATOR_TYPE_EVALUATOR)
 
   def getMetricName: proto.MlParams.Builder =
     proto.MlParams
@@ -149,7 +149,7 @@ trait MLHelper extends SparkFunSuite with SparkConnectPlanTest {
       .newBuilder()
       .setUid("vec")
       .setName("org.apache.spark.ml.feature.VectorAssembler")
-      .setType(proto.MlOperator.OperatorType.TRANSFORMER)
+      .setType(proto.MlOperator.OperatorType.OPERATOR_TYPE_TRANSFORMER)
 
   def getVectorAssemblerParams: proto.MlParams.Builder =
     proto.MlParams
@@ -220,7 +220,7 @@ trait MLHelper extends SparkFunSuite with SparkConnectPlanTest {
               proto.MlOperator
                 .newBuilder()
                 .setName(clsName)
-                .setType(proto.MlOperator.OperatorType.MODEL))
+                .setType(proto.MlOperator.OperatorType.OPERATOR_TYPE_MODEL))
             .setPath(path))
         .build()
 
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
index cc24a2a67439..0d0fbc4b1b7b 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
@@ -250,7 +250,7 @@ class MLSuite extends MLHelper {
                 .newBuilder()
                 .setName("org.apache.spark.ml.NotExistingML")
                 .setUid("FakedUid")
-                .setType(proto.MlOperator.OperatorType.ESTIMATOR)))
+                .setType(proto.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR)))
         .build()
       MLHandler.handleMlCommand(sessionHolder, command)
     }
@@ -280,7 +280,7 @@ class MLSuite extends MLHelper {
            .setOperator(proto.MlOperator
              .newBuilder()
              .setName("org.apache.spark.sql.connect.ml.NotImplementingMLReadble")
-              .setType(proto.MlOperator.OperatorType.ESTIMATOR))
+              .setType(proto.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR))
             .setPath("/tmp/fake"))
         .build()
       MLHandler.handleMlCommand(sessionHolder, readCmd)

