This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3c7ef7d6135 [SPARK-42800][CONNECT][PYTHON][ML] Implement ml function `{array_to_vector, vector_to_array}`
3c7ef7d6135 is described below

commit 3c7ef7d6135a33448e9b08902f4b5582ae2d60c4
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu Mar 16 15:03:56 2023 +0800

    [SPARK-42800][CONNECT][PYTHON][ML] Implement ml function `{array_to_vector, vector_to_array}`
    
    ### What changes were proposed in this pull request?
    Implement the ML functions `{array_to_vector, vector_to_array}` for Spark Connect.
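
    A minimal usage sketch (not part of the original description), assuming a Spark Connect session bound to `spark` (e.g. started with `--remote`); it mirrors the manual check below:
    ```
    from pyspark.sql.functions import array, lit
    from pyspark.ml.functions import array_to_vector, vector_to_array

    # Build a small DataFrame with an array<double> column.
    df = spark.range(1).select(array(lit(1.0), lit(2.0), lit(3.0)).alias("c"))

    vec = df.select(array_to_vector(df.c).alias("d"))    # array<double> -> ml VectorUDT
    arr = vec.select(vector_to_array(vec.d, "float64"))  # VectorUDT -> array<double>
    ```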
    
    ### Why are the changes needed?
    Function parity between the Spark Connect Python client and vanilla PySpark.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, new user-facing functions.
    
    ### How was this patch tested?
    Added unit tests and manual checks:
    
    ```
    (spark_dev) ➜  spark git:(connect_ml_functions) ✗ bin/pyspark --remote "local[*]"
    Python 3.9.16 (main, Mar  8 2023, 04:29:24)
    Type 'copyright', 'credits' or 'license' for more information
    IPython 8.11.0 -- An enhanced Interactive Python. Type '?' for help.
    Setting default log level to "WARN".
    To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
    23/03/15 11:56:27 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
    Welcome to
          ____              __
         / __/__  ___ _____/ /__
        _\ \/ _ \/ _ `/ __/  '_/
       /__ / .__/\_,_/_/ /_/\_\   version 3.5.0.dev0
          /_/
    
    Using Python version 3.9.16 (main, Mar  8 2023 04:29:24)
    Client connected to the Spark Connect server at localhost
    SparkSession available as 'spark'.
    
    In [1]:
    
    In [1]:         query = """
       ...:             SELECT * FROM VALUES
       ...:             (1, 4, ARRAY(1.0, 2.0, 3.0)),
       ...:             (1, 2, ARRAY(-1.0, -2.0, -3.0))
       ...:             AS tab(a, b, c)
       ...:             """
    
    In [2]: cdf = spark.sql(query)
    
    In [3]: from pyspark.ml.connect import functions as CF
    
    In [4]: cdf1 = cdf.select("a", CF.array_to_vector(cdf.c).alias("d"))
    
    In [5]: cdf1.show()
    +---+----------------+
    |  a|               d|
    +---+----------------+
    |  1|   [1.0,2.0,3.0]|
    |  1|[-1.0,-2.0,-3.0]|
    +---+----------------+
    
    In [6]: cdf1.schema
    Out[6]: StructType([StructField('a', IntegerType(), False), StructField('d', VectorUDT(), True)])
    
    In [7]: cdf1.select(CF.vector_to_array(cdf1.d))
    Out[7]: DataFrame[UDF(d): array<double>]
    
    In [8]: cdf1.select(CF.vector_to_array(cdf1.d)).show()
    +------------------+
    |            UDF(d)|
    +------------------+
    |   [1.0, 2.0, 3.0]|
    |[-1.0, -2.0, -3.0]|
    +------------------+
    
    In [9]: cdf1.select(CF.vector_to_array(cdf1.d)).schema
    Out[9]: StructType([StructField('UDF(d)', ArrayType(DoubleType(), False), False)])
    
    ```
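
    The new unit test also covers the `float32` dtype; a hedged continuation of the session above (not taken from the original log):
    ```
    cdf1.select(CF.vector_to_array(cdf1.d, "float32")).show()
    # The server plans this through MLFunctions.vectorToArrayFloatUdf, so the
    # result column is array<float>; the rendered values match the float64 case above.
    ```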
    
    Closes #40432 from zhengruifeng/connect_ml_functions.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../sql/connect/planner/SparkConnectPlanner.scala  |  42 ++++++++
 dev/sparktestsupport/modules.py                    |  92 +++++++++--------
 .../main/scala/org/apache/spark/ml/functions.scala |   6 +-
 python/pyspark/ml/connect/__init__.py              |  18 ++++
 python/pyspark/ml/connect/functions.py             |  76 ++++++++++++++
 python/pyspark/ml/functions.py                     |   9 ++
 .../ml/tests/connect/test_connect_function.py      | 113 +++++++++++++++++++++
 python/pyspark/ml/util.py                          |  36 ++++++-
 8 files changed, 346 insertions(+), 46 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index a057bd8d6c1..20db252c057 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -30,6 +30,7 @@ import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.{ExecutePlanResponse, SqlCommand}
 import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
 import org.apache.spark.connect.proto.Parse.ParseFormat
+import org.apache.spark.ml.{functions => MLFunctions}
 import org.apache.spark.sql.{Column, Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier}
 import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, ParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
@@ -1187,10 +1188,51 @@ class SparkConnectPlanner(val session: SparkSession) {
           None
         }
 
+      // ML-specific functions
+      case "vector_to_array" if fun.getArgumentsCount == 2 =>
+        val expr = transformExpression(fun.getArguments(0))
+        val dtype = transformExpression(fun.getArguments(1)) match {
+          case Literal(s, StringType) if s != null => s.toString
+          case other =>
+            throw InvalidPlanInput(
+              s"dtype in vector_to_array should be a literal string, but got $other")
+        }
+        dtype match {
+          case "float64" =>
+            Some(transformUnregisteredUDF(MLFunctions.vectorToArrayUdf, Seq(expr)))
+          case "float32" =>
+            Some(transformUnregisteredUDF(MLFunctions.vectorToArrayFloatUdf, Seq(expr)))
+          case other =>
+            throw InvalidPlanInput(s"Unsupported dtype: $other. Valid values: float64, float32.")
+        }
+
+      case "array_to_vector" if fun.getArgumentsCount == 1 =>
+        val expr = transformExpression(fun.getArguments(0))
+        Some(transformUnregisteredUDF(MLFunctions.arrayToVectorUdf, Seq(expr)))
+
       case _ => None
     }
   }
 
+  /**
+   * There are some built-in yet not registered UDFs, for example, 'ml.function.array_to_vector'.
+   * This method is to convert them to ScalaUDF expressions.
+   */
+  private def transformUnregisteredUDF(
+      fun: org.apache.spark.sql.expressions.UserDefinedFunction,
+      exprs: Seq[Expression]): ScalaUDF = {
+    val f = fun.asInstanceOf[org.apache.spark.sql.expressions.SparkUserDefinedFunction]
+    ScalaUDF(
+      function = f.f,
+      dataType = f.dataType,
+      children = exprs,
+      inputEncoders = f.inputEncoders,
+      outputEncoder = f.outputEncoder,
+      udfName = f.name,
+      nullable = f.nullable,
+      udfDeterministic = f.deterministic)
+  }
+
   private def transformAlias(alias: proto.Expression.Alias): NamedExpression = {
     if (alias.getNameCount == 1) {
       val metadata = if (alias.hasMetadata() && alias.getMetadata.nonEmpty) {
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index bbc5cedf9f8..5379f883815 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -513,47 +513,6 @@ pyspark_sql = Module(
     ],
 )
 
-pyspark_connect = Module(
-    name="pyspark-connect",
-    dependencies=[pyspark_sql, connect],
-    source_file_regexes=["python/pyspark/sql/connect"],
-    python_test_goals=[
-        # doctests
-        "pyspark.sql.connect.catalog",
-        "pyspark.sql.connect.conf",
-        "pyspark.sql.connect.group",
-        "pyspark.sql.connect.session",
-        "pyspark.sql.connect.window",
-        "pyspark.sql.connect.column",
-        "pyspark.sql.connect.readwriter",
-        "pyspark.sql.connect.dataframe",
-        "pyspark.sql.connect.functions",
-        # unittests
-        "pyspark.sql.tests.connect.test_client",
-        "pyspark.sql.tests.connect.test_connect_plan",
-        "pyspark.sql.tests.connect.test_connect_basic",
-        "pyspark.sql.tests.connect.test_connect_function",
-        "pyspark.sql.tests.connect.test_connect_column",
-        "pyspark.sql.tests.connect.test_parity_datasources",
-        "pyspark.sql.tests.connect.test_parity_catalog",
-        "pyspark.sql.tests.connect.test_parity_conf",
-        "pyspark.sql.tests.connect.test_parity_serde",
-        "pyspark.sql.tests.connect.test_parity_functions",
-        "pyspark.sql.tests.connect.test_parity_group",
-        "pyspark.sql.tests.connect.test_parity_dataframe",
-        "pyspark.sql.tests.connect.test_parity_types",
-        "pyspark.sql.tests.connect.test_parity_column",
-        "pyspark.sql.tests.connect.test_parity_readwriter",
-        "pyspark.sql.tests.connect.test_parity_udf",
-        "pyspark.sql.tests.connect.test_parity_pandas_udf",
-        "pyspark.sql.tests.connect.test_parity_pandas_map",
-        "pyspark.sql.tests.connect.test_parity_arrow_map",
-    ],
-    excluded_python_implementations=[
-        "PyPy"  # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
-        # they aren't available there
-    ],
-)
 
 pyspark_resource = Module(
     name="pyspark-resource",
@@ -781,6 +740,57 @@ pyspark_pandas_slow = Module(
     ],
 )
 
+
+pyspark_connect = Module(
+    name="pyspark-connect",
+    dependencies=[pyspark_sql, pyspark_ml, connect],
+    source_file_regexes=[
+        "python/pyspark/sql/connect",
+        "python/pyspark/ml/connect",
+    ],
+    python_test_goals=[
+        # sql doctests
+        "pyspark.sql.connect.catalog",
+        "pyspark.sql.connect.conf",
+        "pyspark.sql.connect.group",
+        "pyspark.sql.connect.session",
+        "pyspark.sql.connect.window",
+        "pyspark.sql.connect.column",
+        "pyspark.sql.connect.readwriter",
+        "pyspark.sql.connect.dataframe",
+        "pyspark.sql.connect.functions",
+        # sql unittests
+        "pyspark.sql.tests.connect.test_client",
+        "pyspark.sql.tests.connect.test_connect_plan",
+        "pyspark.sql.tests.connect.test_connect_basic",
+        "pyspark.sql.tests.connect.test_connect_function",
+        "pyspark.sql.tests.connect.test_connect_column",
+        "pyspark.sql.tests.connect.test_parity_datasources",
+        "pyspark.sql.tests.connect.test_parity_catalog",
+        "pyspark.sql.tests.connect.test_parity_conf",
+        "pyspark.sql.tests.connect.test_parity_serde",
+        "pyspark.sql.tests.connect.test_parity_functions",
+        "pyspark.sql.tests.connect.test_parity_group",
+        "pyspark.sql.tests.connect.test_parity_dataframe",
+        "pyspark.sql.tests.connect.test_parity_types",
+        "pyspark.sql.tests.connect.test_parity_column",
+        "pyspark.sql.tests.connect.test_parity_readwriter",
+        "pyspark.sql.tests.connect.test_parity_udf",
+        "pyspark.sql.tests.connect.test_parity_pandas_udf",
+        "pyspark.sql.tests.connect.test_parity_pandas_map",
+        "pyspark.sql.tests.connect.test_parity_arrow_map",
+        # ml doctests
+        "pyspark.ml.connect.functions",
+        # ml unittests
+        "pyspark.ml.tests.connect.test_connect_function",
+    ],
+    excluded_python_implementations=[
+        "PyPy"  # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
+        # they aren't available there
+    ],
+)
+
+
 pyspark_errors = Module(
     name="pyspark-errors",
     dependencies=[],
diff --git a/mllib/core/src/main/scala/org/apache/spark/ml/functions.scala b/mllib/core/src/main/scala/org/apache/spark/ml/functions.scala
index 2bd7233f3ac..040eb0bbe95 100644
--- a/mllib/core/src/main/scala/org/apache/spark/ml/functions.scala
+++ b/mllib/core/src/main/scala/org/apache/spark/ml/functions.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.functions.udf
 @Since("3.0.0")
 object functions {
 // scalastyle:on
-  private val vectorToArrayUdf = udf { vec: Any =>
+  private[spark] val vectorToArrayUdf = udf { vec: Any =>
     vec match {
       case v: Vector => v.toArray
       case v: OldVector => v.toArray
@@ -38,7 +38,7 @@ object functions {
     }
   }.asNonNullable()
 
-  private val vectorToArrayFloatUdf = udf { vec: Any =>
+  private[spark] val vectorToArrayFloatUdf = udf { vec: Any =>
     vec match {
       case v: SparseVector =>
         val data = new Array[Float](v.size)
@@ -72,7 +72,7 @@ object functions {
     }
   }
 
-  private val arrayToVectorUdf = udf { array: Seq[Double] =>
+  private[spark] val arrayToVectorUdf = udf { array: Seq[Double] =>
     Vectors.dense(array.toArray)
   }
 
diff --git a/python/pyspark/ml/connect/__init__.py b/python/pyspark/ml/connect/__init__.py
new file mode 100644
index 00000000000..7612e0caa28
--- /dev/null
+++ b/python/pyspark/ml/connect/__init__.py
@@ -0,0 +1,18 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Spark Connect Python Client - ML module"""
diff --git a/python/pyspark/ml/connect/functions.py b/python/pyspark/ml/connect/functions.py
new file mode 100644
index 00000000000..ab7e3ab3c9a
--- /dev/null
+++ b/python/pyspark/ml/connect/functions.py
@@ -0,0 +1,76 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from pyspark.sql.connect.utils import check_dependencies
+
+check_dependencies(__name__)
+
+from pyspark.ml import functions as PyMLFunctions
+
+from pyspark.sql.connect.column import Column
+from pyspark.sql.connect.functions import _invoke_function, _to_col, lit
+
+
+def vector_to_array(col: Column, dtype: str = "float64") -> Column:
+    return _invoke_function("vector_to_array", _to_col(col), lit(dtype))
+
+
+vector_to_array.__doc__ = PyMLFunctions.vector_to_array.__doc__
+
+
+def array_to_vector(col: Column) -> Column:
+    return _invoke_function("array_to_vector", _to_col(col))
+
+
+array_to_vector.__doc__ = PyMLFunctions.array_to_vector.__doc__
+
+
+def _test() -> None:
+    import sys
+    import doctest
+    from pyspark.sql import SparkSession as PySparkSession
+    import pyspark.ml.connect.functions
+
+    globs = pyspark.ml.connect.functions.__dict__.copy()
+
+    # TODO: split vector_to_array doctest since it includes .mllib vectors
+    del pyspark.ml.connect.functions.vector_to_array.__doc__
+
+    # TODO: spark.createDataFrame should support UDT
+    del pyspark.ml.connect.functions.array_to_vector.__doc__
+
+    globs["spark"] = (
+        PySparkSession.builder.appName("ml.connect.functions tests")
+        .remote("local[4]")
+        .getOrCreate()
+    )
+
+    (failure_count, test_count) = doctest.testmod(
+        pyspark.ml.connect.functions,
+        globs=globs,
+        optionflags=doctest.ELLIPSIS
+        | doctest.NORMALIZE_WHITESPACE
+        | doctest.IGNORE_EXCEPTION_DETAIL,
+    )
+
+    globs["spark"].stop()
+
+    if failure_count:
+        sys.exit(-1)
+
+
+if __name__ == "__main__":
+    _test()
diff --git a/python/pyspark/ml/functions.py b/python/pyspark/ml/functions.py
index 9a4b381b3bc..c335934be8d 100644
--- a/python/pyspark/ml/functions.py
+++ b/python/pyspark/ml/functions.py
@@ -35,6 +35,7 @@ from pyspark.sql.types import (
     StringType,
     StructType,
 )
+from pyspark.ml.util import try_remote_functions
 from typing import Any, Callable, Iterator, List, Mapping, TYPE_CHECKING, Tuple, Union, Optional
 
 if TYPE_CHECKING:
@@ -60,12 +61,16 @@ PredictBatchFunction = Callable[
 ]
 
 
+@try_remote_functions
 def vector_to_array(col: Column, dtype: str = "float64") -> Column:
     """
     Converts a column of MLlib sparse/dense vectors into a column of dense arrays.
 
     .. versionadded:: 3.0.0
 
+    .. versionchanged:: 3.5.0
+        Supports Spark Connect.
+
     Parameters
     ----------
     col : :py:class:`pyspark.sql.Column` or str
@@ -112,6 +117,7 @@ def vector_to_array(col: Column, dtype: str = "float64") -> Column:
     )
 
 
+@try_remote_functions
 def array_to_vector(col: Column) -> Column:
     """
     Converts a column of array of numeric type into a column of pyspark.ml.linalg.DenseVector
@@ -119,6 +125,9 @@ def array_to_vector(col: Column) -> Column:
 
     .. versionadded:: 3.1.0
 
+    .. versionchanged:: 3.5.0
+        Supports Spark Connect.
+
     Parameters
     ----------
     col : :py:class:`pyspark.sql.Column` or str
diff --git a/python/pyspark/ml/tests/connect/test_connect_function.py b/python/pyspark/ml/tests/connect/test_connect_function.py
new file mode 100644
index 00000000000..7da3d3f1add
--- /dev/null
+++ b/python/pyspark/ml/tests/connect/test_connect_function.py
@@ -0,0 +1,113 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+import unittest
+
+from pyspark.sql import SparkSession as PySparkSession
+from pyspark.sql.dataframe import DataFrame as SDF
+from pyspark.ml import functions as SF
+
+from pyspark.testing.sqlutils import SQLTestUtils
+from pyspark.testing.connectutils import (
+    should_test_connect,
+    ReusedConnectTestCase,
+)
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
+
+if should_test_connect:
+    from pyspark.sql.connect.dataframe import DataFrame as CDF
+    from pyspark.ml.connect import functions as CF
+
+
+class SparkConnectMLFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, SQLTestUtils):
+    """These test cases exercise the interface to the proto plan
+    generation but do not call Spark."""
+
+    @classmethod
+    def setUpClass(cls):
+        super(SparkConnectMLFunctionTests, cls).setUpClass()
+        # Disable the shared namespace so pyspark.sql.functions, etc. point to the regular
+        # PySpark libraries.
+        os.environ["PYSPARK_NO_NAMESPACE_SHARE"] = "1"
+        cls.connect = cls.spark  # Switch Spark Connect session and regular PySpark session.
+        cls.spark = PySparkSession._instantiatedSession
+        assert cls.spark is not None
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.spark = cls.connect  # Stopping Spark Connect closes the session in JVM at the server.
+        super(SparkConnectMLFunctionTests, cls).tearDownClass()
+        del os.environ["PYSPARK_NO_NAMESPACE_SHARE"]
+
+    def compare_by_show(self, df1, df2, n: int = 20, truncate: int = 20):
+        assert isinstance(df1, (SDF, CDF))
+        if isinstance(df1, SDF):
+            str1 = df1._jdf.showString(n, truncate, False)
+        else:
+            str1 = df1._show_string(n, truncate, False)
+
+        assert isinstance(df2, (SDF, CDF))
+        if isinstance(df2, SDF):
+            str2 = df2._jdf.showString(n, truncate, False)
+        else:
+            str2 = df2._show_string(n, truncate, False)
+
+        self.assertEqual(str1, str2)
+
+    def test_array_vector_conversion(self):
+        query = """
+            SELECT * FROM VALUES
+            (1, 4, ARRAY(1.0, 2.0, 3.0)),
+            (1, 2, ARRAY(-1.0, -2.0, -3.0))
+            AS tab(a, b, c)
+            """
+
+        cdf = self.connect.sql(query)
+        sdf = self.spark.sql(query)
+
+        self.compare_by_show(
+            cdf.select(cdf.b, CF.array_to_vector(cdf.c)),
+            sdf.select(sdf.b, SF.array_to_vector(sdf.c)),
+        )
+
+        cdf1 = cdf.select("a", CF.array_to_vector(cdf.c).alias("d"))
+        sdf1 = sdf.select("a", SF.array_to_vector(sdf.c).alias("d"))
+
+        self.compare_by_show(
+            cdf1.select(CF.vector_to_array(cdf1.d)),
+            sdf1.select(SF.vector_to_array(sdf1.d)),
+        )
+        self.compare_by_show(
+            cdf1.select(CF.vector_to_array(cdf1.d, "float32")),
+            sdf1.select(SF.vector_to_array(sdf1.d, "float32")),
+        )
+        self.compare_by_show(
+            cdf1.select(CF.vector_to_array(cdf1.d, "float64")),
+            sdf1.select(SF.vector_to_array(sdf1.d, "float64")),
+        )
+
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.connect.test_connect_function import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 67aa2124b22..5d1f89cbc13 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -19,13 +19,27 @@ import json
 import os
 import time
 import uuid
-
-from typing import Any, Dict, Generic, List, Optional, Sequence, Type, TypeVar, cast, TYPE_CHECKING
+import functools
+
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generic,
+    List,
+    Optional,
+    Sequence,
+    Type,
+    TypeVar,
+    cast,
+    TYPE_CHECKING,
+)
 
 
 from pyspark import SparkContext, since
 from pyspark.ml.common import inherit_doc
 from pyspark.sql import SparkSession
+from pyspark.sql.utils import is_remote
 from pyspark.util import VersionUtils
 
 if TYPE_CHECKING:
@@ -41,6 +55,8 @@ JW = TypeVar("JW", bound="JavaMLWriter")
 RL = TypeVar("RL", bound="MLReadable")
 JR = TypeVar("JR", bound="JavaMLReader")
 
+FuncT = TypeVar("FuncT", bound=Callable[..., Any])
+
 
 def _jvm() -> "JavaGateway":
     """
@@ -715,3 +731,19 @@ class MetaAlgorithmReadWrite:
                 f"UIDs. List of UIDs: {list(uidMap.keys())}."
             )
         return uidMap
+
+
+def try_remote_functions(f: FuncT) -> FuncT:
+    """Mark API supported from Spark Connect."""
+
+    @functools.wraps(f)
+    def wrapped(*args: Any, **kwargs: Any) -> Any:
+
+        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
+            from pyspark.ml.connect import functions
+
+            return getattr(functions, f.__name__)(*args, **kwargs)
+        else:
+            return f(*args, **kwargs)
+
+    return cast(FuncT, wrapped)
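
For context, a hedged usage sketch (not part of the commit) of the `try_remote_functions` dispatch added in python/pyspark/ml/util.py: when `is_remote()` is true and `PYSPARK_NO_NAMESPACE_SHARE` is unset, the vanilla `pyspark.ml.functions` entry points delegate to `pyspark.ml.connect.functions`, so the same imports work against a Connect session:

```
from pyspark.sql import SparkSession
from pyspark.ml.functions import array_to_vector  # decorated with @try_remote_functions

# Hypothetical local Connect endpoint; any .remote(...) URL makes is_remote() true.
spark = SparkSession.builder.remote("local[*]").getOrCreate()

df = spark.sql("SELECT ARRAY(1.0, 2.0, 3.0) AS c")
# Dispatches to pyspark.ml.connect.functions.array_to_vector, which the server
# plans as MLFunctions.arrayToVectorUdf.
df.select(array_to_vector(df.c).alias("d")).show()
```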


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
