This is an automated email from the ASF dual-hosted git repository.

weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 01918bb9017 [SPARK-43983][PYTHON][ML][CONNECT] Implement cross validator estimator
01918bb9017 is described below

commit 01918bb90170c13abd6c0f0f5c47f5d9bcc02adc
Author: Weichen Xu <weichen...@databricks.com>
AuthorDate: Tue Jul 11 07:00:19 2023 +0800

    [SPARK-43983][PYTHON][ML][CONNECT] Implement cross validator estimator
    
    ### What changes were proposed in this pull request?
    
    Implement cross validator estimator for spark connect.
    
    ### Why are the changes needed?
    
    Distributed ML on spark connect project.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes.
    
    New classes `pyspark.ml.connect.tuning.CrossValidator` and `pyspark.ml.connect.tuning.CrossValidatorModel` are added.
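    
    A minimal usage sketch of the new API (the DataFrame `train_dataset` and its
    `features`/`label` columns are placeholders, and the estimator and param grid simply
    mirror the new unit tests):
    
        from pyspark.ml.connect.classification import LogisticRegression
        from pyspark.ml.connect.evaluation import BinaryClassificationEvaluator
        from pyspark.ml.connect.tuning import CrossValidator
        from pyspark.ml.tuning import ParamGridBuilder
    
        lr = LogisticRegression(featuresCol="features", labelCol="label")
        grid = ParamGridBuilder().addGrid(lr.maxIter, [2, 200]).build()
        cv = CrossValidator(
            estimator=lr,
            estimatorParamMaps=grid,
            evaluator=BinaryClassificationEvaluator(),
            numFolds=3,
            parallelism=2,
        )
        cv_model = cv.fit(train_dataset)  # fit on a Spark DataFrame you provide
        predictions = cv_model.transform(train_dataset)  # delegates to cv_model.bestModel
        print(cv_model.avgMetrics, cv_model.stdMetrics)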
    
    ### How was this patch tested?
    
    Unit tests.
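    
    The new `test_legacy_mode_tuning.py` suite also exercises local persistence of the tuning
    classes; a rough sketch of that round trip (the paths and the `cv` / `cv_model` objects
    from the sketch above are placeholders):
    
        loaded_cv = CrossValidator.loadFromLocal("/tmp/cv")  # after cv.saveToLocal("/tmp/cv")
        cv_model.saveToLocal("/tmp/cv_model")
        loaded_model = CrossValidatorModel.loadFromLocal("/tmp/cv_model")
        assert loaded_model.getEstimatorParamMaps() == cv_model.getEstimatorParamMaps()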
    
    Closes #41881 from WeichenXu123/SPARK-43983-cross-val.
    
    Authored-by: Weichen Xu <weichen...@databricks.com>
    Signed-off-by: Weichen Xu <weichen...@databricks.com>
---
 dev/sparktestsupport/modules.py                    |   2 +
 python/pyspark/ml/connect/__init__.py              |   2 +
 python/pyspark/ml/connect/base.py                  |  18 +-
 python/pyspark/ml/connect/evaluation.py            |  83 ++-
 python/pyspark/ml/connect/io_utils.py              |  76 ++-
 python/pyspark/ml/connect/pipeline.py              |  47 +-
 python/pyspark/ml/connect/tuning.py                | 566 +++++++++++++++++++++
 .../ml/tests/connect/test_connect_tuning.py        |  45 ++
 .../tests/connect/test_legacy_mode_evaluation.py   |  31 ++
 .../ml/tests/connect/test_legacy_mode_tuning.py    | 267 ++++++++++
 10 files changed, 1080 insertions(+), 57 deletions(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 439ae40a0f8..2090546512f 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -620,6 +620,7 @@ pyspark_ml = Module(
         "pyspark.ml.tests.connect.test_legacy_mode_feature",
         "pyspark.ml.tests.connect.test_legacy_mode_classification",
         "pyspark.ml.tests.connect.test_legacy_mode_pipeline",
+        "pyspark.ml.tests.connect.test_legacy_mode_tuning",
     ],
     excluded_python_implementations=[
         "PyPy"  # Skip these tests under PyPy since they require numpy and it 
isn't available there
@@ -866,6 +867,7 @@ pyspark_connect = Module(
         "pyspark.ml.tests.connect.test_connect_feature",
         "pyspark.ml.tests.connect.test_connect_classification",
         "pyspark.ml.tests.connect.test_connect_pipeline",
+        "pyspark.ml.tests.connect.test_connect_tuning",
     ],
     excluded_python_implementations=[
         "PyPy"  # Skip these tests under PyPy since they require numpy, 
pandas, and pyarrow and
diff --git a/python/pyspark/ml/connect/__init__.py b/python/pyspark/ml/connect/__init__.py
index 2e048355d74..2ee152f6a38 100644
--- a/python/pyspark/ml/connect/__init__.py
+++ b/python/pyspark/ml/connect/__init__.py
@@ -26,6 +26,7 @@ from pyspark.ml.connect.base import (
 from pyspark.ml.connect import (
     feature,
     evaluation,
+    tuning,
 )
 
 from pyspark.ml.connect.pipeline import Pipeline, PipelineModel
@@ -39,4 +40,5 @@ __all__ = [
     "evaluation",
     "Pipeline",
     "PipelineModel",
+    "tuning",
 ]
diff --git a/python/pyspark/ml/connect/base.py b/python/pyspark/ml/connect/base.py
index f86b1e928c2..f8ce0cb6962 100644
--- a/python/pyspark/ml/connect/base.py
+++ b/python/pyspark/ml/connect/base.py
@@ -146,7 +146,9 @@ class Transformer(Params, metaclass=ABCMeta):
         """
         raise NotImplementedError()
 
-    def transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]:
+    def transform(
+        self, dataset: Union[DataFrame, pd.DataFrame], params: Optional["ParamMap"] = None
+    ) -> Union[DataFrame, pd.DataFrame]:
         """
         Transforms the input dataset.
         The dataset can be either pandas dataframe or spark dataframe,
@@ -163,12 +165,24 @@ class Transformer(Params, metaclass=ABCMeta):
         dataset : :py:class:`pyspark.sql.DataFrame` or py:class:`pandas.DataFrame`
             input dataset.
 
+        params : dict, optional
+            an optional param map that overrides embedded params.
+
         Returns
         -------
         :py:class:`pyspark.sql.DataFrame` or py:class:`pandas.DataFrame`
             transformed dataset, the type of output dataframe is consistent with
             input dataframe.
         """
+        if params is None:
+            params = dict()
+        if isinstance(params, dict):
+            if params:
+                return self.copy(params)._transform(dataset)
+            else:
+                return self._transform(dataset)
+
+    def _transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]:
         input_cols = self._input_columns()
         transform_fn = self._get_transform_fn()
         output_cols = self._output_columns()
@@ -249,7 +263,7 @@ class Evaluator(Params, metaclass=ABCMeta):
         (True, default) or minimized (False).
         A given evaluator may support multiple metrics which may be maximized or minimized.
         """
-        return True
+        raise NotImplementedError()
 
 
 @inherit_doc
diff --git a/python/pyspark/ml/connect/evaluation.py b/python/pyspark/ml/connect/evaluation.py
index 0606c7cad7d..88e4e5ab006 100644
--- a/python/pyspark/ml/connect/evaluation.py
+++ b/python/pyspark/ml/connect/evaluation.py
@@ -15,10 +15,10 @@
 # limitations under the License.
 #
 import numpy as np
-
 import pandas as pd
 from typing import Any, Union, List, Tuple
 
+from pyspark import keyword_only
 from pyspark.ml.param import Param, Params, TypeConverters
 from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasProbabilityCol
 from pyspark.ml.connect.base import Evaluator
@@ -36,6 +36,14 @@ class _TorchMetricEvaluator(Evaluator):
         typeConverter=TypeConverters.toString,
     )
 
+    def getMetricName(self) -> str:
+        """
+        Gets the value of metricName or its default value.
+
+        .. versionadded:: 3.5.0
+        """
+        return self.getOrDefault(self.metricName)
+
     def _get_torch_metric(self) -> Any:
         raise NotImplementedError()
 
@@ -68,15 +76,36 @@ class _TorchMetricEvaluator(Evaluator):
         )
 
 
+def _get_rmse_torchmetric() -> Any:
+    import torch
+    import torcheval.metrics as torchmetrics
+
+    class _RootMeanSquaredError(torchmetrics.MeanSquaredError):
+        def compute(self: Any) -> torch.Tensor:
+            return torch.sqrt(super().compute())
+
+    return _RootMeanSquaredError()
+
+
 class RegressionEvaluator(_TorchMetricEvaluator, HasLabelCol, HasPredictionCol, ParamsReadWrite):
     """
     Evaluator for Regression, which expects input columns prediction and label.
-    Supported metrics are 'mse' and 'r2'.
+    Supported metrics are 'rmse', 'mse' and 'r2'.
 
     .. versionadded:: 3.5.0
     """
 
-    def __init__(self, metricName: str, labelCol: str, predictionCol: str) -> None:
+    @keyword_only
+    def __init__(
+        self,
+        *,
+        metricName: str = "rmse",
+        labelCol: str = "label",
+        predictionCol: str = "prediction",
+    ) -> None:
+        """
+        __init__(self, *, metricName='rmse', labelCol='label', predictionCol='prediction') -> None:
+        """
         super().__init__()
         self._set(metricName=metricName, labelCol=labelCol, predictionCol=predictionCol)
 
@@ -89,6 +118,8 @@ class RegressionEvaluator(_TorchMetricEvaluator, HasLabelCol, HasPredictionCol,
             return torchmetrics.MeanSquaredError()
         if metric_name == "r2":
             return torchmetrics.R2Score()
+        if metric_name == "rmse":
+            return _get_rmse_torchmetric()
 
         raise ValueError(f"Unsupported regressor evaluator metric name: 
{metric_name}")
 
@@ -102,6 +133,12 @@ class RegressionEvaluator(_TorchMetricEvaluator, HasLabelCol, HasPredictionCol,
         labels_tensor = torch.tensor(dataset[self.getLabelCol()].values)
         return preds_tensor, labels_tensor
 
+    def isLargerBetter(self) -> bool:
+        if self.getOrDefault(self.metricName) == "r2":
+            return True
+
+        return False
+
 
 class BinaryClassificationEvaluator(
     _TorchMetricEvaluator, HasLabelCol, HasProbabilityCol, ParamsReadWrite
@@ -113,7 +150,23 @@ class BinaryClassificationEvaluator(
     .. versionadded:: 3.5.0
     """
 
-    def __init__(self, metricName: str, labelCol: str, probabilityCol: str) -> None:
+    @keyword_only
+    def __init__(
+        self,
+        *,
+        metricName: str = "areaUnderROC",
+        labelCol: str = "label",
+        probabilityCol: str = "probability",
+    ) -> None:
+        """
+        __init__(
+            self,
+            *,
+            metricName='areaUnderROC',
+            labelCol='label',
+            probabilityCol='probability'
+        ) -> None:
+        """
         super().__init__()
         self._set(metricName=metricName, labelCol=labelCol, probabilityCol=probabilityCol)
 
@@ -142,6 +195,9 @@ class BinaryClassificationEvaluator(
         labels_tensor = torch.tensor(dataset[self.getLabelCol()].values)
         return preds_tensor, labels_tensor
 
+    def isLargerBetter(self) -> bool:
+        return True
+
 
 class MulticlassClassificationEvaluator(
     _TorchMetricEvaluator, HasLabelCol, HasPredictionCol, ParamsReadWrite
@@ -153,7 +209,21 @@ class MulticlassClassificationEvaluator(
     .. versionadded:: 3.5.0
     """
 
-    def __init__(self, metricName: str, labelCol: str, predictionCol: str) -> None:
+    def __init__(
+        self,
+        metricName: str = "accuracy",
+        labelCol: str = "label",
+        predictionCol: str = "prediction",
+    ) -> None:
+        """
+        __init__(
+            self,
+            *,
+            metricName='accuracy',
+            labelCol='label',
+            predictionCol='prediction'
+        ) -> None:
+        """
         super().__init__()
         self._set(metricName=metricName, labelCol=labelCol, predictionCol=predictionCol)
 
@@ -178,3 +248,6 @@ class MulticlassClassificationEvaluator(
         preds_tensor = torch.tensor(dataset[self.getPredictionCol()].values)
         labels_tensor = torch.tensor(dataset[self.getLabelCol()].values)
         return preds_tensor, labels_tensor
+
+    def isLargerBetter(self) -> bool:
+        return True
diff --git a/python/pyspark/ml/connect/io_utils.py b/python/pyspark/ml/connect/io_utils.py
index 7c3025849da..9a963086aaf 100644
--- a/python/pyspark/ml/connect/io_utils.py
+++ b/python/pyspark/ml/connect/io_utils.py
@@ -130,12 +130,7 @@ class ParamsReadWrite(Params):
         pass
 
     def _save_to_local(self, path: str) -> None:
-        metadata = self._get_metadata_to_save()
-        if isinstance(self, CoreModelReadWrite):
-            core_model_path = self._get_core_model_filename()
-            self._save_core_model(os.path.join(path, core_model_path))
-            metadata["core_model_path"] = core_model_path
-
+        metadata = self._save_to_node_path(path, [])
         with open(os.path.join(path, _META_DATA_FILE_NAME), "w") as fp:
             json.dump(metadata, fp)
 
@@ -158,7 +153,7 @@ class ParamsReadWrite(Params):
         self._save_to_local(path)
 
     @classmethod
-    def _load_from_metadata(cls, metadata: Dict[str, Any]) -> "Params":
+    def _load_metadata(cls, metadata: Dict[str, Any]) -> "Params":
         if "type" not in metadata or metadata["type"] != "spark_connect":
             raise RuntimeError(
                 "The saved data is not saved by ML algorithm implemented in 
'pyspark.ml.connect' "
@@ -184,18 +179,25 @@ class ParamsReadWrite(Params):
         return instance
 
     @classmethod
-    def _load_from_local(cls, path: str) -> "Params":
-        with open(os.path.join(path, _META_DATA_FILE_NAME), "r") as fp:
-            metadata = json.load(fp)
-
-        instance = cls._load_from_metadata(metadata)
+    def _load_instance_from_metadata(cls, metadata: Dict[str, Any], path: str) -> Any:
+        instance = cls._load_metadata(metadata)
 
         if isinstance(instance, CoreModelReadWrite):
             core_model_path = metadata["core_model_path"]
             instance._load_core_model(os.path.join(path, core_model_path))
 
+        if isinstance(instance, MetaAlgorithmReadWrite):
+            instance._load_meta_algorithm(path, metadata)
+
         return instance
 
+    @classmethod
+    def _load_from_local(cls, path: str) -> "Params":
+        with open(os.path.join(path, _META_DATA_FILE_NAME), "r") as fp:
+            metadata = json.load(fp)
+
+        return cls._load_instance_from_metadata(metadata, path)
+
     @classmethod
     def loadFromLocal(cls, path: str) -> "Params":
         """
@@ -205,6 +207,21 @@ class ParamsReadWrite(Params):
         """
         return cls._load_from_local(path)
 
+    def _save_to_node_path(self, root_path: str, node_path: List[str]) -> Any:
+        """
+        Save the instance to provided node path, and return the node metadata.
+        """
+        if isinstance(self, MetaAlgorithmReadWrite):
+            metadata = self._save_meta_algorithm(root_path, node_path)
+        else:
+            metadata = self._get_metadata_to_save()
+            if isinstance(self, CoreModelReadWrite):
+                core_model_path = ".".join(node_path + [self._get_core_model_filename()])
+                self._save_core_model(os.path.join(root_path, core_model_path))
+                metadata["core_model_path"] = core_model_path
+
+        return metadata
+
     def save(self, path: str, *, overwrite: bool = False) -> None:
         """
         Save Estimator / Transformer / Model / Evaluator to provided cloud storage path.
@@ -283,23 +300,36 @@ class MetaAlgorithmReadWrite(ParamsReadWrite):
     Meta-algorithm such as pipeline and cross validator must implement this interface.
     """
 
+    def _get_child_stages(self) -> List[Any]:
+        raise NotImplementedError()
+
     def _save_meta_algorithm(self, root_path: str, node_path: List[str]) -> Dict[str, Any]:
         raise NotImplementedError()
 
     def _load_meta_algorithm(self, root_path: str, node_metadata: Dict[str, Any]) -> None:
         raise NotImplementedError()
 
-    def _save_to_local(self, path: str) -> None:
-        metadata = self._save_meta_algorithm(path, [])
-        with open(os.path.join(path, _META_DATA_FILE_NAME), "w") as fp:
-            json.dump(metadata, fp)
+    @staticmethod
+    def _get_all_nested_stages(instance: Any) -> List[Any]:
+        if isinstance(instance, MetaAlgorithmReadWrite):
+            child_stages = instance._get_child_stages()
+        else:
+            child_stages = []
 
-    @classmethod
-    def _load_from_local(cls, path: str) -> Any:
-        with open(os.path.join(path, _META_DATA_FILE_NAME), "r") as fp:
-            metadata = json.load(fp)
+        nested_stages = []
+        for stage in child_stages:
+            nested_stages.extend(MetaAlgorithmReadWrite._get_all_nested_stages(stage))
 
-        instance = cls._load_from_metadata(metadata)
-        instance._load_meta_algorithm(path, metadata)  # type: ignore[attr-defined]
+        return [instance] + nested_stages
 
-        return instance
+    @staticmethod
+    def get_uid_map(instance: Any) -> Dict[str, Any]:
+        all_nested_stages = MetaAlgorithmReadWrite._get_all_nested_stages(instance)
+        uid_map = {stage.uid: stage for stage in all_nested_stages}
+        if len(all_nested_stages) != len(uid_map):
+            raise RuntimeError(
+                f"{instance.__class__.__module__}.{instance.__class__.__name__}"
+                f" is a compound estimator with stages with duplicate "
+                f"UIDs. List of UIDs: {list(uid_map.keys())}."
+            )
+        return uid_map
diff --git a/python/pyspark/ml/connect/pipeline.py b/python/pyspark/ml/connect/pipeline.py
index 90eca01d378..64232db1f09 100644
--- a/python/pyspark/ml/connect/pipeline.py
+++ b/python/pyspark/ml/connect/pipeline.py
@@ -14,14 +14,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import os
-
 import pandas as pd
 from typing import Any, Dict, List, Optional, Union, cast, TYPE_CHECKING
 
 from pyspark import keyword_only, since
 from pyspark.ml.connect.base import Estimator, Model, Transformer
-from pyspark.ml.connect.io_utils import ParamsReadWrite, MetaAlgorithmReadWrite, CoreModelReadWrite
+from pyspark.ml.connect.io_utils import (
+    ParamsReadWrite,
+    MetaAlgorithmReadWrite,
+)
 from pyspark.ml.param import Param, Params
 from pyspark.ml.common import inherit_doc
 from pyspark.sql.dataframe import DataFrame
@@ -32,6 +33,14 @@ if TYPE_CHECKING:
 
 
 class _PipelineReadWrite(MetaAlgorithmReadWrite):
+    def _get_child_stages(self) -> List[Any]:
+        if isinstance(self, Pipeline):
+            return list(self.getStages())
+        elif isinstance(self, PipelineModel):
+            return list(self.stages)
+        else:
+            raise ValueError(f"Unknown type {self.__class__}")
+
     def _get_skip_saving_params(self) -> List[str]:
         """
         Returns params to be skipped when saving metadata.
@@ -47,36 +56,20 @@ class _PipelineReadWrite(MetaAlgorithmReadWrite):
         elif isinstance(self, PipelineModel):
             stages = self.stages
         else:
-            raise ValueError()
+            raise ValueError(f"Unknown type {self.__class__}")
 
         for stage_index, stage in enumerate(stages):
-            stage_name = f"pipeline_stage_{stage_index}"
-            node_path.append(stage_name)
-            if isinstance(stage, MetaAlgorithmReadWrite):
-                stage_metadata = stage._save_meta_algorithm(root_path, node_path)
-            else:
-                stage_metadata = stage._get_metadata_to_save()  # type: ignore[attr-defined]
-                if isinstance(stage, CoreModelReadWrite):
-                    core_model_path = ".".join(node_path + [stage._get_core_model_filename()])
-                    stage._save_core_model(os.path.join(root_path, core_model_path))
-                    stage_metadata["core_model_path"] = core_model_path
-
+            stage_node_path = node_path + [f"pipeline_stage_{stage_index}"]
+            stage_metadata = stage._save_to_node_path(  # type: ignore[attr-defined]
+                root_path, stage_node_path
+            )
             metadata["stages"].append(stage_metadata)
-            node_path.pop()
         return metadata
 
     def _load_meta_algorithm(self, root_path: str, node_metadata: Dict[str, Any]) -> None:
         stages = []
         for stage_meta in node_metadata["stages"]:
-            stage = ParamsReadWrite._load_from_metadata(stage_meta)
-
-            if isinstance(stage, MetaAlgorithmReadWrite):
-                stage._load_meta_algorithm(root_path, stage_meta)
-
-            if isinstance(stage, CoreModelReadWrite):
-                core_model_path = stage_meta["core_model_path"]
-                stage._load_core_model(os.path.join(root_path, core_model_path))
-
+            stage = ParamsReadWrite._load_instance_from_metadata(stage_meta, root_path)
             stages.append(stage)
 
         if isinstance(self, Pipeline):
@@ -221,9 +214,9 @@ class PipelineModel(Model, _PipelineReadWrite):
         super(PipelineModel, self).__init__()
         self.stages = stages  # type: ignore[assignment]
 
-    def transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]:
+    def _transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]:
         for t in self.stages:
-            dataset = t.transform(dataset)  # type: ignore[attr-defined]
+            dataset = t.transform(dataset)
         return dataset
 
     def copy(self, extra: Optional["ParamMap"] = None) -> "PipelineModel":
diff --git a/python/pyspark/ml/connect/tuning.py b/python/pyspark/ml/connect/tuning.py
new file mode 100644
index 00000000000..2c34f4d57ff
--- /dev/null
+++ b/python/pyspark/ml/connect/tuning.py
@@ -0,0 +1,566 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from multiprocessing.pool import ThreadPool
+
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+    cast,
+    TYPE_CHECKING,
+)
+
+import numpy as np
+import pandas as pd
+
+from pyspark import keyword_only, since, inheritable_thread_target
+from pyspark.ml.connect import Estimator, Model
+from pyspark.ml.connect.base import Evaluator
+from pyspark.ml.connect.io_utils import (
+    MetaAlgorithmReadWrite,
+    ParamsReadWrite,
+)
+from pyspark.ml.param import Params, Param, TypeConverters
+from pyspark.ml.param.shared import HasParallelism, HasSeed
+from pyspark.sql.functions import col, lit, rand, UserDefinedFunction
+from pyspark.sql.types import BooleanType
+from pyspark.sql.dataframe import DataFrame
+from pyspark.sql import SparkSession
+
+from pyspark.sql.utils import is_remote
+
+
+if TYPE_CHECKING:
+    from pyspark.ml._typing import ParamMap
+
+
+class _ValidatorParams(HasSeed):
+    """
+    Common params for TrainValidationSplit and CrossValidator.
+    """
+
+    estimator: Param[Estimator] = Param(
+        Params._dummy(), "estimator", "estimator to be cross-validated"
+    )
+    estimatorParamMaps: Param[List["ParamMap"]] = Param(
+        Params._dummy(), "estimatorParamMaps", "estimator param maps"
+    )
+    evaluator: Param[Evaluator] = Param(
+        Params._dummy(),
+        "evaluator",
+        "evaluator used to select hyper-parameters that maximize the validator 
metric",
+    )
+
+    @since("2.0.0")
+    def getEstimator(self) -> Estimator:
+        """
+        Gets the value of estimator or its default value.
+        """
+        return self.getOrDefault(self.estimator)
+
+    @since("2.0.0")
+    def getEstimatorParamMaps(self) -> List["ParamMap"]:
+        """
+        Gets the value of estimatorParamMaps or its default value.
+        """
+        return self.getOrDefault(self.estimatorParamMaps)
+
+    @since("2.0.0")
+    def getEvaluator(self) -> Evaluator:
+        """
+        Gets the value of evaluator or its default value.
+        """
+        return self.getOrDefault(self.evaluator)
+
+
+class _CrossValidatorParams(_ValidatorParams):
+    """
+    Params for :py:class:`CrossValidator` and :py:class:`CrossValidatorModel`.
+
+    .. versionadded:: 3.5.0
+    """
+
+    numFolds: Param[int] = Param(
+        Params._dummy(),
+        "numFolds",
+        "number of folds for cross validation",
+        typeConverter=TypeConverters.toInt,
+    )
+
+    foldCol: Param[str] = Param(
+        Params._dummy(),
+        "foldCol",
+        "Param for the column name of user "
+        + "specified fold number. Once this is specified, 
:py:class:`CrossValidator` "
+        + "won't do random k-fold split. Note that this column should be 
integer type "
+        + "with range [0, numFolds) and Spark will throw exception on 
out-of-range "
+        + "fold numbers.",
+        typeConverter=TypeConverters.toString,
+    )
+
+    def __init__(self, *args: Any):
+        super(_CrossValidatorParams, self).__init__(*args)
+        self._setDefault(numFolds=3, foldCol="")
+
+    @since("1.4.0")
+    def getNumFolds(self) -> int:
+        """
+        Gets the value of numFolds or its default value.
+        """
+        return self.getOrDefault(self.numFolds)
+
+    @since("3.1.0")
+    def getFoldCol(self) -> str:
+        """
+        Gets the value of foldCol or its default value.
+        """
+        return self.getOrDefault(self.foldCol)
+
+
+def _parallelFitTasks(
+    estimator: Estimator,
+    train: DataFrame,
+    evaluator: Evaluator,
+    validation: DataFrame,
+    epm: Sequence["ParamMap"],
+) -> List[Callable[[], Tuple[int, float]]]:
+    """
+    Creates a list of callables which can be called from different threads to fit and evaluate
+    an estimator in parallel. Each callable returns an `(index, metric)` pair.
+
+    Parameters
+    ----------
+    estimator : :py:class:`pyspark.ml.connect.base.Estimator`
+        The estimator to be fit.
+    train : :py:class:`pyspark.sql.DataFrame`
+        DataFrame, training data set, used for fitting.
+    evaluator : :py:class:`pyspark.ml.connect.base.Evaluator`
+        Evaluator used to compute `metric`.
+    validation : :py:class:`pyspark.sql.DataFrame`
+        DataFrame, validation data set, used for evaluation.
+    epm : :py:class:`collections.abc.Sequence`
+        Sequence of ParamMap, param maps to be used during fitting & evaluation.
+
+    Returns
+    -------
+    list
+        A list of callables; each callable returns an `(index, metric)` tuple, where
+        `index` is an index into `epm` and `metric` is the associated metric value.
+    """
+
+    active_session = SparkSession.getActiveSession()
+
+    if active_session is None:
+        raise RuntimeError(
+            "An active SparkSession is required for running cross valiator fit 
tasks."
+        )
+
+    def get_single_task(index: int, param_map: Any) -> Callable[[], Tuple[int, float]]:
+        def single_task() -> Tuple[int, float]:
+            # The active session is a thread-local variable; it is not set in a background
+            # thread, so the following line sets it to the main thread's active session.
+            active_session._jvm.SparkSession.setActiveSession(  # type: ignore[union-attr]
+                active_session._jsparkSession  # type: ignore[union-attr]
+            )
+
+            model = estimator.fit(train, param_map)
+            metric = evaluator.evaluate(
+                model.transform(validation, param_map)  # type: ignore[union-attr]
+            )
+            return index, metric
+
+        return single_task
+
+    return [get_single_task(index, param_map) for index, param_map in enumerate(epm)]
+
+
+class _CrossValidatorReadWrite(MetaAlgorithmReadWrite):
+    def _get_skip_saving_params(self) -> List[str]:
+        """
+        Returns params to be skipped when saving metadata.
+        """
+        return ["estimator", "estimatorParamMaps", "evaluator"]
+
+    def _save_meta_algorithm(self, root_path: str, node_path: List[str]) -> Dict[str, Any]:
+        metadata = self._get_metadata_to_save()
+        metadata[
+            "estimator"
+        ] = self.getEstimator()._save_to_node_path(  # type: ignore[attr-defined]
+            root_path, node_path + ["crossvalidator_estimator"]
+        )
+        metadata[
+            "evaluator"
+        ] = self.getEvaluator()._save_to_node_path(  # type: ignore[attr-defined]
+            root_path, node_path + ["crossvalidator_evaluator"]
+        )
+        metadata["estimator_param_maps"] = [
+            [
+                {"parent": param.parent, "name": param.name, "value": value}
+                for param, value in param_map.items()
+            ]
+            for param_map in self.getEstimatorParamMaps()  # type: ignore[attr-defined]
+        ]
+
+        if isinstance(self, CrossValidatorModel):
+            metadata["avg_metrics"] = self.avgMetrics
+            metadata["std_metrics"] = self.stdMetrics
+
+            metadata["best_model"] = self.bestModel._save_to_node_path(
+                root_path, node_path + ["crossvalidator_best_model"]
+            )
+        return metadata
+
+    def _load_meta_algorithm(self, root_path: str, node_metadata: Dict[str, Any]) -> None:
+        estimator = ParamsReadWrite._load_instance_from_metadata(
+            node_metadata["estimator"], root_path
+        )
+        self.set(self.estimator, estimator)  # type: ignore[attr-defined]
+
+        evaluator = ParamsReadWrite._load_instance_from_metadata(
+            node_metadata["evaluator"], root_path
+        )
+        self.set(self.evaluator, evaluator)  # type: ignore[attr-defined]
+
+        json_epm = node_metadata["estimator_param_maps"]
+
+        uid_to_instances = MetaAlgorithmReadWrite.get_uid_map(estimator)
+
+        epm = []
+        for json_param_map in json_epm:
+            param_map = {}
+            for json_param in json_param_map:
+                est = uid_to_instances[json_param["parent"]]
+                param = getattr(est, json_param["name"])
+                value = json_param["value"]
+                param_map[param] = value
+            epm.append(param_map)
+
+        self.set(self.estimatorParamMaps, epm)  # type: ignore[attr-defined]
+
+        if isinstance(self, CrossValidatorModel):
+            self.avgMetrics = node_metadata["avg_metrics"]
+            self.stdMetrics = node_metadata["std_metrics"]
+
+            self.bestModel = ParamsReadWrite._load_instance_from_metadata(
+                node_metadata["best_model"], root_path
+            )
+
+
+class CrossValidator(
+    Estimator["CrossValidatorModel"],
+    _CrossValidatorParams,
+    HasParallelism,
+    _CrossValidatorReadWrite,
+):
+    """
+
+    K-fold cross validation performs model selection by splitting the dataset into a set of
+    non-overlapping, randomly partitioned folds which are used as separate training and test
+    datasets; e.g., with k=3 folds, K-fold cross validation will generate 3 (training, test)
+    dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. Each
+    fold is used as the test set exactly once.
+
+    .. versionadded:: 3.5.0
+    """
+
+    _input_kwargs: Dict[str, Any]
+
+    @keyword_only
+    def __init__(
+        self,
+        *,
+        estimator: Optional[Estimator] = None,
+        estimatorParamMaps: Optional[List["ParamMap"]] = None,
+        evaluator: Optional[Evaluator] = None,
+        numFolds: int = 3,
+        seed: Optional[int] = None,
+        parallelism: int = 1,
+        foldCol: str = "",
+    ) -> None:
+        """
+        __init__(self, \\*, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
+                 seed=None, parallelism=1, foldCol="")
+        """
+        super(CrossValidator, self).__init__()
+        self._setDefault(parallelism=1)
+        kwargs = self._input_kwargs
+        self._set(**kwargs)
+
+    @keyword_only
+    @since("3.5.0")
+    def setParams(
+        self,
+        *,
+        estimator: Optional[Estimator] = None,
+        estimatorParamMaps: Optional[List["ParamMap"]] = None,
+        evaluator: Optional[Evaluator] = None,
+        numFolds: int = 3,
+        seed: Optional[int] = None,
+        parallelism: int = 1,
+        foldCol: str = "",
+    ) -> "CrossValidator":
+        """
+        setParams(self, \\*, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
+                  seed=None, parallelism=1, foldCol=""):
+        Sets params for cross validator.
+        """
+        kwargs = self._input_kwargs
+        return self._set(**kwargs)
+
+    @since("3.5.0")
+    def setEstimator(self, value: Estimator) -> "CrossValidator":
+        """
+        Sets the value of :py:attr:`estimator`.
+        """
+        return self._set(estimator=value)
+
+    @since("3.5.0")
+    def setEstimatorParamMaps(self, value: List["ParamMap"]) -> "CrossValidator":
+        """
+        Sets the value of :py:attr:`estimatorParamMaps`.
+        """
+        return self._set(estimatorParamMaps=value)
+
+    @since("3.5.0")
+    def setEvaluator(self, value: Evaluator) -> "CrossValidator":
+        """
+        Sets the value of :py:attr:`evaluator`.
+        """
+        return self._set(evaluator=value)
+
+    @since("3.5.0")
+    def setNumFolds(self, value: int) -> "CrossValidator":
+        """
+        Sets the value of :py:attr:`numFolds`.
+        """
+        return self._set(numFolds=value)
+
+    @since("3.5.0")
+    def setFoldCol(self, value: str) -> "CrossValidator":
+        """
+        Sets the value of :py:attr:`foldCol`.
+        """
+        return self._set(foldCol=value)
+
+    def setSeed(self, value: int) -> "CrossValidator":
+        """
+        Sets the value of :py:attr:`seed`.
+        """
+        return self._set(seed=value)
+
+    def setParallelism(self, value: int) -> "CrossValidator":
+        """
+        Sets the value of :py:attr:`parallelism`.
+        """
+        return self._set(parallelism=value)
+
+    def setCollectSubModels(self, value: bool) -> "CrossValidator":
+        """
+        Sets the value of :py:attr:`collectSubModels`.
+        """
+        return self._set(collectSubModels=value)
+
+    @staticmethod
+    def _gen_avg_and_std_metrics(
+        metrics_all: List[List[float]],
+    ) -> Tuple[List[float], List[float]]:
+        avg_metrics = np.mean(metrics_all, axis=0)
+        std_metrics = np.std(metrics_all, axis=0)
+        return list(avg_metrics), list(std_metrics)
+
+    def _fit(self, dataset: Union[pd.DataFrame, DataFrame]) -> "CrossValidatorModel":
+        if isinstance(dataset, pd.DataFrame):
+            # TODO: support pandas dataframe fitting
+            raise NotImplementedError("Fitting pandas dataframe is not 
supported yet.")
+
+        est = self.getOrDefault(self.estimator)
+        epm = self.getOrDefault(self.estimatorParamMaps)
+        numModels = len(epm)
+        eva = self.getOrDefault(self.evaluator)
+        nFolds = self.getOrDefault(self.numFolds)
+        metrics_all = [[0.0] * numModels for i in range(nFolds)]
+
+        pool = ThreadPool(processes=min(self.getParallelism(), numModels))
+
+        datasets = self._kFold(dataset)
+        for i in range(nFolds):
+            validation = datasets[i][1].cache()
+            train = datasets[i][0].cache()
+
+            tasks = _parallelFitTasks(est, train, eva, validation, epm)
+            if not is_remote():
+                tasks = list(map(inheritable_thread_target, tasks))
+
+            for j, metric in pool.imap_unordered(lambda f: f(), tasks):
+                metrics_all[i][j] = metric
+
+            validation.unpersist()
+            train.unpersist()
+
+        metrics, std_metrics = CrossValidator._gen_avg_and_std_metrics(metrics_all)
+
+        if eva.isLargerBetter():
+            bestIndex = np.argmax(metrics)
+        else:
+            bestIndex = np.argmin(metrics)
+        bestModel = cast(Model, est.fit(dataset, epm[bestIndex]))
+        cv_model = self._copyValues(
+            CrossValidatorModel(
+                bestModel,
+                avgMetrics=metrics,
+                stdMetrics=std_metrics,
+            )
+        )
+        cv_model._resetUid(self.uid)
+        return cv_model
+
+    def _kFold(self, dataset: DataFrame) -> List[Tuple[DataFrame, DataFrame]]:
+        nFolds = self.getOrDefault(self.numFolds)
+        foldCol = self.getOrDefault(self.foldCol)
+
+        datasets = []
+        if not foldCol:
+            # Do random k-fold split.
+            seed = self.getOrDefault(self.seed)
+            h = 1.0 / nFolds
+            randCol = self.uid + "_rand"
+            df = dataset.select("*", rand(seed).alias(randCol))
+            for i in range(nFolds):
+                validateLB = i * h
+                validateUB = (i + 1) * h
+                condition = (df[randCol] >= validateLB) & (df[randCol] < validateUB)
+                validation = df.filter(condition)
+                train = df.filter(~condition)
+                datasets.append((train, validation))
+        else:
+            # Use user-specified fold numbers.
+            def checker(foldNum: int) -> bool:
+                if foldNum < 0 or foldNum >= nFolds:
+                    raise ValueError(
+                        "Fold number must be in range [0, %s), but got %s." % 
(nFolds, foldNum)
+                    )
+                return True
+
+            checker_udf = UserDefinedFunction(checker, BooleanType())
+            for i in range(nFolds):
+                training = dataset.filter(checker_udf(dataset[foldCol]) & (col(foldCol) != lit(i)))
+                validation = dataset.filter(
+                    checker_udf(dataset[foldCol]) & (col(foldCol) == lit(i))
+                )
+                if training.rdd.getNumPartitions() == 0 or len(training.take(1)) == 0:
+                    raise ValueError("The training data at fold %s is empty." % i)
+                if validation.rdd.getNumPartitions() == 0 or len(validation.take(1)) == 0:
+                    raise ValueError("The validation data at fold %s is empty." % i)
+                datasets.append((training, validation))
+
+        return datasets
+
+    def copy(self, extra: Optional["ParamMap"] = None) -> "CrossValidator":
+        """
+        Creates a copy of this instance with a randomly generated uid
+        and some extra params. This copy creates a deep copy of
+        the embedded paramMap, and copies the embedded and extra parameters over.
+
+
+        .. versionadded:: 3.5.0
+
+        Parameters
+        ----------
+        extra : dict, optional
+            Extra parameters to copy to the new instance
+
+        Returns
+        -------
+        :py:class:`CrossValidator`
+            Copy of this instance
+        """
+        if extra is None:
+            extra = dict()
+        newCV = Params.copy(self, extra)
+        if self.isSet(self.estimator):
+            newCV.setEstimator(self.getEstimator().copy(extra))
+        # estimatorParamMaps remain the same
+        if self.isSet(self.evaluator):
+            newCV.setEvaluator(self.getEvaluator().copy(extra))
+        return newCV
+
+
+class CrossValidatorModel(Model, _CrossValidatorParams, _CrossValidatorReadWrite):
+    """
+    CrossValidatorModel contains the model with the highest average cross-validation
+    metric across folds and uses this model to transform input data. CrossValidatorModel
+    also tracks the metrics for each param map evaluated.
+
+    .. versionadded:: 3.5.0
+    """
+
+    def __init__(
+        self,
+        bestModel: Optional[Model] = None,
+        avgMetrics: Optional[List[float]] = None,
+        stdMetrics: Optional[List[float]] = None,
+    ) -> None:
+        super(CrossValidatorModel, self).__init__()
+        #: best model from cross validation
+        self.bestModel = bestModel
+        #: Average cross-validation metrics for each paramMap in
+        #: CrossValidator.estimatorParamMaps, in the corresponding order.
+        self.avgMetrics = avgMetrics or []
+        #: standard deviation of metrics for each paramMap in
+        #: CrossValidator.estimatorParamMaps, in the corresponding order.
+        self.stdMetrics = stdMetrics or []
+
+    def _transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]:
+        return self.bestModel.transform(dataset)
+
+    def copy(self, extra: Optional["ParamMap"] = None) -> "CrossValidatorModel":
+        """
+        Creates a copy of this instance with a randomly generated uid
+        and some extra params. This copies the underlying bestModel,
+        creates a deep copy of the embedded paramMap, and
+        copies the embedded and extra parameters over.
+        It does not copy the extra Params into the subModels.
+
+        .. versionadded:: 3.5.0
+
+        Parameters
+        ----------
+        extra : dict, optional
+            Extra parameters to copy to the new instance
+
+        Returns
+        -------
+        :py:class:`CrossValidatorModel`
+            Copy of this instance
+        """
+        if extra is None:
+            extra = dict()
+        bestModel = self.bestModel.copy(extra)
+        avgMetrics = list(self.avgMetrics)
+        stdMetrics = list(self.stdMetrics)
+
+        return self._copyValues(
+            CrossValidatorModel(bestModel, avgMetrics=avgMetrics, stdMetrics=stdMetrics),
+            extra=extra,
+        )
diff --git a/python/pyspark/ml/tests/connect/test_connect_tuning.py b/python/pyspark/ml/tests/connect/test_connect_tuning.py
new file mode 100644
index 00000000000..18673d4b26b
--- /dev/null
+++ b/python/pyspark/ml/tests/connect/test_connect_tuning.py
@@ -0,0 +1,45 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from pyspark.sql import SparkSession
+from pyspark.ml.tests.connect.test_legacy_mode_tuning import CrossValidatorTestsMixin
+
+
+class CrossValidatorTestsOnConnect(CrossValidatorTestsMixin, unittest.TestCase):
+    def setUp(self) -> None:
+        self.spark = (
+            SparkSession.builder.remote("local[2]")
+            .config("spark.connect.copyFromLocalToFs.allowDestLocal", "true")
+            .getOrCreate()
+        )
+
+    def tearDown(self) -> None:
+        self.spark.stop()
+
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.connect.test_connect_tuning import *  # noqa: F401,F403
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py b/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py
index 51c3bb26db8..9ff26c1f450 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py
@@ -17,6 +17,7 @@
 #
 import unittest
 import numpy as np
+import tempfile
 
 from pyspark.ml.connect.evaluation import (
     RegressionEvaluator,
@@ -58,6 +59,18 @@ class EvaluationTestsMixin:
         np.testing.assert_almost_equal(mse, expected_mse)
         np.testing.assert_almost_equal(mse_local, expected_mse)
 
+        rmse_evaluator = RegressionEvaluator(
+            metricName="rmse",
+            labelCol="label",
+            predictionCol="prediction",
+        )
+
+        expected_rmse = 0.6683312709480042
+        rmse = rmse_evaluator.evaluate(df1)
+        rmse_local = rmse_evaluator.evaluate(local_df1)
+        np.testing.assert_almost_equal(rmse, expected_rmse)
+        np.testing.assert_almost_equal(rmse_local, expected_rmse)
+
         r2_evaluator = RegressionEvaluator(
             metricName="r2",
             labelCol="label",
@@ -70,6 +83,12 @@ class EvaluationTestsMixin:
         np.testing.assert_almost_equal(r2, expected_r2)
         np.testing.assert_almost_equal(r2_local, expected_r2)
 
+        # Test save / load
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            r2_evaluator.saveToLocal(f"{tmp_dir}/ev")
+            loaded_evaluator = RegressionEvaluator.loadFromLocal(f"{tmp_dir}/ev")
+            assert loaded_evaluator.getMetricName() == "r2"
+
     def test_binary_classifier_evaluator(self):
         df1 = self.spark.createDataFrame(
             [
@@ -110,6 +129,12 @@ class EvaluationTestsMixin:
             np.testing.assert_almost_equal(auprc, expected_auprc, decimal=2)
             np.testing.assert_almost_equal(auprc_local, expected_auprc, decimal=2)
 
+        # Test save / load
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            auprc_evaluator.saveToLocal(f"{tmp_dir}/ev")
+            loaded_evaluator = RegressionEvaluator.loadFromLocal(f"{tmp_dir}/ev")
+            assert loaded_evaluator.getMetricName() == "areaUnderPR"
+
     def test_multiclass_classifier_evaluator(self):
         df1 = self.spark.createDataFrame(
             [
@@ -141,6 +166,12 @@ class EvaluationTestsMixin:
         np.testing.assert_almost_equal(accuracy, expected_accuracy, decimal=2)
         np.testing.assert_almost_equal(accuracy_local, expected_accuracy, decimal=2)
 
+        # Test save / load
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            accuracy_evaluator.saveToLocal(f"{tmp_dir}/ev")
+            loaded_evaluator = RegressionEvaluator.loadFromLocal(f"{tmp_dir}/ev")
+            assert loaded_evaluator.getMetricName() == "accuracy"
+
 
 @unittest.skipIf(not have_torcheval, "torcheval is required")
 class EvaluationTests(EvaluationTestsMixin, unittest.TestCase):
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py b/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py
new file mode 100644
index 00000000000..d6c813533d6
--- /dev/null
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py
@@ -0,0 +1,267 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tempfile
+import unittest
+import numpy as np
+import pandas as pd
+from pyspark.ml.param import Param, Params
+from pyspark.ml.connect import Model, Estimator
+from pyspark.ml.connect.feature import StandardScaler
+from pyspark.ml.connect.classification import LogisticRegression as LORV2
+from pyspark.ml.connect.pipeline import Pipeline
+from pyspark.ml.connect.tuning import CrossValidator, CrossValidatorModel
+from pyspark.ml.connect.evaluation import BinaryClassificationEvaluator, RegressionEvaluator
+from pyspark.ml.tuning import ParamGridBuilder
+from pyspark.sql import SparkSession
+from pyspark.sql.functions import rand
+
+from sklearn.datasets import load_breast_cancer
+
+
+have_torch = True
+try:
+    import torch  # noqa: F401
+except ImportError:
+    have_torch = False
+
+
+class HasInducedError(Params):
+    def __init__(self):
+        super(HasInducedError, self).__init__()
+        self.inducedError = Param(
+            self, "inducedError", "Uniformly-distributed error added to 
feature"
+        )
+
+    def getInducedError(self):
+        return self.getOrDefault(self.inducedError)
+
+
+class InducedErrorModel(Model, HasInducedError):
+    def __init__(self):
+        super(InducedErrorModel, self).__init__()
+
+    def _transform(self, dataset):
+        return dataset.withColumn(
+            "prediction", dataset.feature + (rand(0) * self.getInducedError())
+        )
+
+
+class InducedErrorEstimator(Estimator, HasInducedError):
+    def __init__(self, inducedError=1.0):
+        super(InducedErrorEstimator, self).__init__()
+        self._set(inducedError=inducedError)
+
+    def _fit(self, dataset):
+        model = InducedErrorModel()
+        self._copyValues(model)
+        return model
+
+
+class CrossValidatorTestsMixin:
+    def test_gen_avg_and_std_metrics(self):
+        metrics_all = [
+            [1.0, 3.0, 2.0, 4.0],
+            [3.0, 2.0, 2.0, 4.0],
+            [3.0, 2.5, 2.1, 8.0],
+        ]
+        avg_metrics, std_metrics = CrossValidator._gen_avg_and_std_metrics(metrics_all)
+        assert np.allclose(avg_metrics, [2.33333333, 2.5, 2.03333333, 5.33333333])
+        assert np.allclose(std_metrics, [0.94280904, 0.40824829, 0.04714045, 1.88561808])
+        assert isinstance(avg_metrics, list)
+        assert isinstance(std_metrics, list)
+
+    def test_copy(self):
+        dataset = self.spark.createDataFrame(
+            [(10, 10.0), (50, 50.0), (100, 100.0), (500, 500.0)] * 10, ["feature", "label"]
+        )
+
+        iee = InducedErrorEstimator()
+        evaluator = RegressionEvaluator(metricName="rmse")
+
+        grid = ParamGridBuilder().addGrid(iee.inducedError, [100.0, 0.0, 10000.0]).build()
+        cv = CrossValidator(
+            estimator=iee,
+            estimatorParamMaps=grid,
+            evaluator=evaluator,
+            numFolds=2,
+        )
+        cvCopied = cv.copy()
+        for param in [
+            lambda x: x.getEstimator().uid,
+            # SPARK-32092: CrossValidator.copy() needs to copy all existing params
+            lambda x: x.getNumFolds(),
+            lambda x: x.getFoldCol(),
+            lambda x: x.getParallelism(),
+            lambda x: x.getSeed(),
+        ]:
+            self.assertEqual(param(cv), param(cvCopied))
+
+        cvModel = cv.fit(dataset)
+        cvModelCopied = cvModel.copy()
+        for index in range(len(cvModel.avgMetrics)):
+            self.assertTrue(
+                abs(cvModel.avgMetrics[index] - cvModelCopied.avgMetrics[index]) < 0.0001
+            )
+        self.assertTrue(np.allclose(cvModel.stdMetrics, cvModelCopied.stdMetrics))
+        # SPARK-32092: CrossValidatorModel.copy() needs to copy all existing params
+        for param in [lambda x: x.getNumFolds(), lambda x: x.getFoldCol(), lambda x: x.getSeed()]:
+            self.assertEqual(param(cvModel), param(cvModelCopied))
+
+        cvModel.avgMetrics[0] = "foo"
+        self.assertNotEqual(
+            cvModelCopied.avgMetrics[0],
+            "foo",
+            "Changing the original avgMetrics should not affect the copied 
model",
+        )
+        cvModel.stdMetrics[0] = "foo"
+        self.assertNotEqual(
+            cvModelCopied.stdMetrics[0],
+            "foo",
+            "Changing the original stdMetrics should not affect the copied 
model",
+        )
+
+    def test_fit_minimize_metric(self):
+        dataset = self.spark.createDataFrame(
+            [(10, 10.0), (50, 50.0), (100, 100.0), (500, 500.0)] * 10, ["feature", "label"]
+        )
+
+        iee = InducedErrorEstimator()
+        evaluator = RegressionEvaluator(metricName="rmse")
+
+        grid = ParamGridBuilder().addGrid(iee.inducedError, [100.0, 0.0, 10000.0]).build()
+        cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator)
+        cvModel = cv.fit(dataset)
+        bestModel = cvModel.bestModel
+        bestModelMetric = evaluator.evaluate(bestModel.transform(dataset))
+
+        self.assertEqual(
+            0.0, bestModel.getOrDefault("inducedError"), "Best model should 
have zero induced error"
+        )
+        self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0")
+
+    def test_fit_maximize_metric(self):
+        dataset = self.spark.createDataFrame(
+            [(10, 10.0), (50, 50.0), (100, 100.0), (500, 500.0)] * 10, ["feature", "label"]
+        )
+
+        iee = InducedErrorEstimator()
+        evaluator = RegressionEvaluator(metricName="r2")
+
+        grid = ParamGridBuilder().addGrid(iee.inducedError, [100.0, 0.0, 10000.0]).build()
+        cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator)
+        cvModel = cv.fit(dataset)
+        bestModel = cvModel.bestModel
+        bestModelMetric = evaluator.evaluate(bestModel.transform(dataset))
+
+        self.assertEqual(
+            0.0, bestModel.getOrDefault("inducedError"), "Best model should 
have zero induced error"
+        )
+        self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")
+
+    @staticmethod
+    def _check_result(result_dataframe, expected_predictions, expected_probabilities=None):
+        np.testing.assert_array_equal(list(result_dataframe.prediction), expected_predictions)
+        if "probability" in result_dataframe.columns:
+            np.testing.assert_allclose(
+                list(result_dataframe.probability),
+                expected_probabilities,
+                rtol=1e-1,
+            )
+
+    def test_crossvalidator_on_pipeline(self):
+        sk_dataset = load_breast_cancer()
+
+        train_dataset = self.spark.createDataFrame(
+            zip(sk_dataset.data.tolist(), [int(t) for t in sk_dataset.target]),
+            schema="features: array<double>, label: long",
+        )
+
+        scaler = StandardScaler(inputCol="features", 
outputCol="scaled_features")
+        lorv2 = LORV2(numTrainWorkers=2, featuresCol="scaled_features")
+        pipeline = Pipeline(stages=[scaler, lorv2])
+
+        grid2 = ParamGridBuilder().addGrid(lorv2.maxIter, [2, 200]).build()
+        cv = CrossValidator(
+            estimator=pipeline,
+            estimatorParamMaps=grid2,
+            parallelism=2,
+            evaluator=BinaryClassificationEvaluator(),
+        )
+        cv_model = cv.fit(train_dataset)
+        transformed_result = (
+            cv_model.transform(train_dataset).select("prediction", 
"probability").toPandas()
+        )
+        expected_transformed_result = (
+            cv_model.bestModel.transform(train_dataset)
+            .select("prediction", "probability")
+            .toPandas()
+        )
+        pd.testing.assert_frame_equal(transformed_result, expected_transformed_result)
+
+        assert cv_model.bestModel.stages[1].getMaxIter() == 200
+
+        # The trial at index 1 should have a better metric value
+        # because it sets a higher `maxIter` param.
+        assert cv_model.avgMetrics[1] > cv_model.avgMetrics[0]
+
+        def _verify_cv_saved_params(instance, loaded_instance):
+            assert instance.getEstimator().uid == loaded_instance.getEstimator().uid
+            assert instance.getEvaluator().uid == loaded_instance.getEvaluator().uid
+            assert instance.getEstimatorParamMaps() == loaded_instance.getEstimatorParamMaps()
+
+        # Test save / load
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            cv.saveToLocal(f"{tmp_dir}/cv")
+            loaded_cv = CrossValidator.loadFromLocal(f"{tmp_dir}/cv")
+
+            _verify_cv_saved_params(cv, loaded_cv)
+
+            cv_model.saveToLocal(f"{tmp_dir}/cv_model")
+            loaded_cv_model = CrossValidatorModel.loadFromLocal(f"{tmp_dir}/cv_model")
+
+            _verify_cv_saved_params(cv_model, loaded_cv_model)
+
+            assert cv_model.uid == loaded_cv_model.uid
+            assert cv_model.bestModel.uid == loaded_cv_model.bestModel.uid
+            assert cv_model.bestModel.stages[0].uid == loaded_cv_model.bestModel.stages[0].uid
+            assert cv_model.bestModel.stages[1].uid == loaded_cv_model.bestModel.stages[1].uid
+            assert loaded_cv_model.bestModel.stages[1].getMaxIter() == 200
+
+            np.testing.assert_allclose(cv_model.avgMetrics, loaded_cv_model.avgMetrics)
+            np.testing.assert_allclose(cv_model.stdMetrics, loaded_cv_model.stdMetrics)
+
+
+class CrossValidatorTests(CrossValidatorTestsMixin, unittest.TestCase):
+    def setUp(self) -> None:
+        self.spark = SparkSession.builder.master("local[2]").getOrCreate()
+
+    def tearDown(self) -> None:
+        self.spark.stop()
+
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.connect.test_legacy_mode_tuning import *  # noqa: F401,F403
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)

