This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 3c7ef7d6135  [SPARK-42800][CONNECT][PYTHON][ML] Implement ml function `{array_to_vector, vector_to_array}`
3c7ef7d6135 is described below

commit 3c7ef7d6135a33448e9b08902f4b5582ae2d60c4
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu Mar 16 15:03:56 2023 +0800

    [SPARK-42800][CONNECT][PYTHON][ML] Implement ml function `{array_to_vector, vector_to_array}`

    ### What changes were proposed in this pull request?
    Implement the ml functions `{array_to_vector, vector_to_array}` for Spark Connect.

    ### Why are the changes needed?
    Function parity between the Spark Connect client and regular PySpark.

    ### Does this PR introduce _any_ user-facing change?
    Yes, new functions.

    ### How was this patch tested?
    Added unit tests and checked manually:

    ```
    (spark_dev) ➜  spark git:(connect_ml_functions) ✗ bin/pyspark --remote "local[*]"
    Python 3.9.16 (main, Mar  8 2023, 04:29:24)
    Type 'copyright', 'credits' or 'license' for more information
    IPython 8.11.0 -- An enhanced Interactive Python. Type '?' for help.
    Setting default log level to "WARN".
    To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
    23/03/15 11:56:27 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
    Welcome to
          ____              __
         / __/__  ___ _____/ /__
        _\ \/ _ \/ _ `/ __/ '_/
       /__ / .__/\_,_/_/ /_/\_\   version 3.5.0.dev0
          /_/

    Using Python version 3.9.16 (main, Mar  8 2023 04:29:24)
    Client connected to the Spark Connect server at localhost
    SparkSession available as 'spark'.

    In [1]: query = """
       ...:     SELECT * FROM VALUES
       ...:     (1, 4, ARRAY(1.0, 2.0, 3.0)),
       ...:     (1, 2, ARRAY(-1.0, -2.0, -3.0))
       ...:     AS tab(a, b, c)
       ...:     """

    In [2]: cdf = spark.sql(query)

    In [3]: from pyspark.ml.connect import functions as CF

    In [4]: cdf1 = cdf.select("a", CF.array_to_vector(cdf.c).alias("d"))

    In [5]: cdf1.show()
    +---+----------------+
    |  a|               d|
    +---+----------------+
    |  1|   [1.0,2.0,3.0]|
    |  1|[-1.0,-2.0,-3.0]|
    +---+----------------+

    In [6]: cdf1.schema
    Out[6]: StructType([StructField('a', IntegerType(), False), StructField('d', VectorUDT(), True)])

    In [7]: cdf1.select(CF.vector_to_array(cdf1.d))
    Out[7]: DataFrame[UDF(d): array<double>]

    In [8]: cdf1.select(CF.vector_to_array(cdf1.d)).show()
    +------------------+
    |            UDF(d)|
    +------------------+
    |   [1.0, 2.0, 3.0]|
    |[-1.0, -2.0, -3.0]|
    +------------------+

    In [9]: cdf1.select(CF.vector_to_array(cdf1.d)).schema
    Out[9]: StructType([StructField('UDF(d)', ArrayType(DoubleType(), False), False)])
    ```

    Closes #40432 from zhengruifeng/connect_ml_functions.
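For comparison, the same round trip works through the classic (non-Connect) `pyspark.ml.functions` API. A minimal sketch — not part of this patch; the session setup, data, and column names are illustrative:

```
from pyspark.sql import SparkSession
from pyspark.ml.functions import array_to_vector, vector_to_array

spark = SparkSession.builder.master("local[*]").getOrCreate()
df = spark.createDataFrame(
    [(1, [1.0, 2.0, 3.0]), (2, [-1.0, -2.0, -3.0])], ["a", "c"]
)

vec = df.select("a", array_to_vector(df.c).alias("d"))          # array<double> -> vector
arr = vec.select(vector_to_array(vec.d, "float32").alias("e"))  # vector -> array<float>
arr.show()
```

With the `try_remote_functions` decorator this patch adds to `pyspark/ml/util.py`, the same snippet also works unchanged against a `--remote` session, since each call is forwarded to `pyspark.ml.connect.functions`.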
Authored-by: Ruifeng Zheng <ruife...@apache.org>
Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../sql/connect/planner/SparkConnectPlanner.scala  |  42 ++++++++
 dev/sparktestsupport/modules.py                    |  92 +++++++++--------
 .../main/scala/org/apache/spark/ml/functions.scala |   6 +-
 python/pyspark/ml/connect/__init__.py              |  18 ++++
 python/pyspark/ml/connect/functions.py             |  76 ++++++++++++++
 python/pyspark/ml/functions.py                     |   9 ++
 .../ml/tests/connect/test_connect_function.py      | 113 +++++++++++++++++++++
 python/pyspark/ml/util.py                          |  36 ++++++-
 8 files changed, 346 insertions(+), 46 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index a057bd8d6c1..20db252c057 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -30,6 +30,7 @@ import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.{ExecutePlanResponse, SqlCommand}
 import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
 import org.apache.spark.connect.proto.Parse.ParseFormat
+import org.apache.spark.ml.{functions => MLFunctions}
 import org.apache.spark.sql.{Column, Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier}
 import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, ParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
@@ -1187,10 +1188,51 @@ class SparkConnectPlanner(val session: SparkSession) {
           None
         }

+      // ML-specific functions
+      case "vector_to_array" if fun.getArgumentsCount == 2 =>
+        val expr = transformExpression(fun.getArguments(0))
+        val dtype = transformExpression(fun.getArguments(1)) match {
+          case Literal(s, StringType) if s != null => s.toString
+          case other =>
+            throw InvalidPlanInput(
+              s"dtype in vector_to_array should be a literal string, but got $other")
+        }
+        dtype match {
+          case "float64" =>
+            Some(transformUnregisteredUDF(MLFunctions.vectorToArrayUdf, Seq(expr)))
+          case "float32" =>
+            Some(transformUnregisteredUDF(MLFunctions.vectorToArrayFloatUdf, Seq(expr)))
+          case other =>
+            throw InvalidPlanInput(s"Unsupported dtype: $other. Valid values: float64, float32.")
+        }
+
+      case "array_to_vector" if fun.getArgumentsCount == 1 =>
+        val expr = transformExpression(fun.getArguments(0))
+        Some(transformUnregisteredUDF(MLFunctions.arrayToVectorUdf, Seq(expr)))
+
       case _ => None
     }
   }

+  /**
+   * There are some built-in yet unregistered UDFs, for example, 'ml.function.array_to_vector'.
+   * This method converts them to ScalaUDF expressions.
+   */
+  private def transformUnregisteredUDF(
+      fun: org.apache.spark.sql.expressions.UserDefinedFunction,
+      exprs: Seq[Expression]): ScalaUDF = {
+    val f = fun.asInstanceOf[org.apache.spark.sql.expressions.SparkUserDefinedFunction]
+    ScalaUDF(
+      function = f.f,
+      dataType = f.dataType,
+      children = exprs,
+      inputEncoders = f.inputEncoders,
+      outputEncoder = f.outputEncoder,
+      udfName = f.name,
+      nullable = f.nullable,
+      udfDeterministic = f.deterministic)
+  }
+
   private def transformAlias(alias: proto.Expression.Alias): NamedExpression = {
     if (alias.getNameCount == 1) {
       val metadata = if (alias.hasMetadata() && alias.getMetadata.nonEmpty) {
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index bbc5cedf9f8..5379f883815 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -513,47 +513,6 @@ pyspark_sql = Module(
     ],
 )

-pyspark_connect = Module(
-    name="pyspark-connect",
-    dependencies=[pyspark_sql, connect],
-    source_file_regexes=["python/pyspark/sql/connect"],
-    python_test_goals=[
-        # doctests
-        "pyspark.sql.connect.catalog",
-        "pyspark.sql.connect.conf",
-        "pyspark.sql.connect.group",
-        "pyspark.sql.connect.session",
-        "pyspark.sql.connect.window",
-        "pyspark.sql.connect.column",
-        "pyspark.sql.connect.readwriter",
-        "pyspark.sql.connect.dataframe",
-        "pyspark.sql.connect.functions",
-        # unittests
-        "pyspark.sql.tests.connect.test_client",
-        "pyspark.sql.tests.connect.test_connect_plan",
-        "pyspark.sql.tests.connect.test_connect_basic",
-        "pyspark.sql.tests.connect.test_connect_function",
-        "pyspark.sql.tests.connect.test_connect_column",
-        "pyspark.sql.tests.connect.test_parity_datasources",
-        "pyspark.sql.tests.connect.test_parity_catalog",
-        "pyspark.sql.tests.connect.test_parity_conf",
-        "pyspark.sql.tests.connect.test_parity_serde",
-        "pyspark.sql.tests.connect.test_parity_functions",
-        "pyspark.sql.tests.connect.test_parity_group",
-        "pyspark.sql.tests.connect.test_parity_dataframe",
-        "pyspark.sql.tests.connect.test_parity_types",
-        "pyspark.sql.tests.connect.test_parity_column",
-        "pyspark.sql.tests.connect.test_parity_readwriter",
-        "pyspark.sql.tests.connect.test_parity_udf",
-        "pyspark.sql.tests.connect.test_parity_pandas_udf",
-        "pyspark.sql.tests.connect.test_parity_pandas_map",
-        "pyspark.sql.tests.connect.test_parity_arrow_map",
-    ],
-    excluded_python_implementations=[
-        "PyPy"  # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
-        # they aren't available there
-    ],
-)

 pyspark_resource = Module(
     name="pyspark-resource",
@@ -781,6 +740,57 @@ pyspark_pandas_slow = Module(
     ],
 )

+
+pyspark_connect = Module(
+    name="pyspark-connect",
+    dependencies=[pyspark_sql, pyspark_ml, connect],
+    source_file_regexes=[
+        "python/pyspark/sql/connect",
+        "python/pyspark/ml/connect",
+    ],
+    python_test_goals=[
+        # sql doctests
+        "pyspark.sql.connect.catalog",
+        "pyspark.sql.connect.conf",
+        "pyspark.sql.connect.group",
+        "pyspark.sql.connect.session",
+        "pyspark.sql.connect.window",
+        "pyspark.sql.connect.column",
+        "pyspark.sql.connect.readwriter",
+        "pyspark.sql.connect.dataframe",
+        "pyspark.sql.connect.functions",
+        # sql unittests
+        "pyspark.sql.tests.connect.test_client",
+        "pyspark.sql.tests.connect.test_connect_plan",
+        "pyspark.sql.tests.connect.test_connect_basic",
+        "pyspark.sql.tests.connect.test_connect_function",
+        "pyspark.sql.tests.connect.test_connect_column",
+        "pyspark.sql.tests.connect.test_parity_datasources",
+        "pyspark.sql.tests.connect.test_parity_catalog",
+        "pyspark.sql.tests.connect.test_parity_conf",
+        "pyspark.sql.tests.connect.test_parity_serde",
+        "pyspark.sql.tests.connect.test_parity_functions",
+        "pyspark.sql.tests.connect.test_parity_group",
+        "pyspark.sql.tests.connect.test_parity_dataframe",
+        "pyspark.sql.tests.connect.test_parity_types",
+        "pyspark.sql.tests.connect.test_parity_column",
+        "pyspark.sql.tests.connect.test_parity_readwriter",
+        "pyspark.sql.tests.connect.test_parity_udf",
+        "pyspark.sql.tests.connect.test_parity_pandas_udf",
+        "pyspark.sql.tests.connect.test_parity_pandas_map",
+        "pyspark.sql.tests.connect.test_parity_arrow_map",
+        # ml doctests
+        "pyspark.ml.connect.functions",
+        # ml unittests
+        "pyspark.ml.tests.connect.test_connect_function",
+    ],
+    excluded_python_implementations=[
+        "PyPy"  # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
+        # they aren't available there
+    ],
+)
+
+
 pyspark_errors = Module(
     name="pyspark-errors",
     dependencies=[],
diff --git a/mllib/core/src/main/scala/org/apache/spark/ml/functions.scala b/mllib/core/src/main/scala/org/apache/spark/ml/functions.scala
index 2bd7233f3ac..040eb0bbe95 100644
--- a/mllib/core/src/main/scala/org/apache/spark/ml/functions.scala
+++ b/mllib/core/src/main/scala/org/apache/spark/ml/functions.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.functions.udf
 @Since("3.0.0")
 object functions {
 // scalastyle:on
-  private val vectorToArrayUdf = udf { vec: Any =>
+  private[spark] val vectorToArrayUdf = udf { vec: Any =>
     vec match {
       case v: Vector => v.toArray
       case v: OldVector => v.toArray
@@ -38,7 +38,7 @@ object functions {
     }
   }.asNonNullable()

-  private val vectorToArrayFloatUdf = udf { vec: Any =>
+  private[spark] val vectorToArrayFloatUdf = udf { vec: Any =>
     vec match {
       case v: SparseVector =>
         val data = new Array[Float](v.size)
@@ -72,7 +72,7 @@ object functions {
     }
   }

-  private val arrayToVectorUdf = udf { array: Seq[Double] =>
+  private[spark] val arrayToVectorUdf = udf { array: Seq[Double] =>
     Vectors.dense(array.toArray)
   }

diff --git a/python/pyspark/ml/connect/__init__.py b/python/pyspark/ml/connect/__init__.py
new file mode 100644
index 00000000000..7612e0caa28
--- /dev/null
+++ b/python/pyspark/ml/connect/__init__.py
@@ -0,0 +1,18 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Spark Connect Python Client - ML module"""
diff --git a/python/pyspark/ml/connect/functions.py b/python/pyspark/ml/connect/functions.py
new file mode 100644
index 00000000000..ab7e3ab3c9a
--- /dev/null
+++ b/python/pyspark/ml/connect/functions.py
@@ -0,0 +1,76 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from pyspark.sql.connect.utils import check_dependencies
+
+check_dependencies(__name__)
+
+from pyspark.ml import functions as PyMLFunctions
+
+from pyspark.sql.connect.column import Column
+from pyspark.sql.connect.functions import _invoke_function, _to_col, lit
+
+
+def vector_to_array(col: Column, dtype: str = "float64") -> Column:
+    return _invoke_function("vector_to_array", _to_col(col), lit(dtype))
+
+
+vector_to_array.__doc__ = PyMLFunctions.vector_to_array.__doc__
+
+
+def array_to_vector(col: Column) -> Column:
+    return _invoke_function("array_to_vector", _to_col(col))
+
+
+array_to_vector.__doc__ = PyMLFunctions.array_to_vector.__doc__
+
+
+def _test() -> None:
+    import sys
+    import doctest
+    from pyspark.sql import SparkSession as PySparkSession
+    import pyspark.ml.connect.functions
+
+    globs = pyspark.ml.connect.functions.__dict__.copy()
+
+    # TODO: split vector_to_array doctest since it includes .mllib vectors
+    del pyspark.ml.connect.functions.vector_to_array.__doc__
+
+    # TODO: spark.createDataFrame should support UDT
+    del pyspark.ml.connect.functions.array_to_vector.__doc__
+
+    globs["spark"] = (
+        PySparkSession.builder.appName("ml.connect.functions tests")
+        .remote("local[4]")
+        .getOrCreate()
+    )
+
+    (failure_count, test_count) = doctest.testmod(
+        pyspark.ml.connect.functions,
+        globs=globs,
+        optionflags=doctest.ELLIPSIS
+        | doctest.NORMALIZE_WHITESPACE
+        | doctest.IGNORE_EXCEPTION_DETAIL,
+    )
+
+    globs["spark"].stop()
+
+    if failure_count:
+        sys.exit(-1)
+
+
+if __name__ == "__main__":
+    _test()
diff --git a/python/pyspark/ml/functions.py b/python/pyspark/ml/functions.py
index 9a4b381b3bc..c335934be8d 100644
--- a/python/pyspark/ml/functions.py
+++ b/python/pyspark/ml/functions.py
@@ -35,6 +35,7 @@ from pyspark.sql.types import (
     StringType,
     StructType,
 )
+from pyspark.ml.util import try_remote_functions

 from typing import Any, Callable, Iterator, List, Mapping, TYPE_CHECKING, Tuple, Union, Optional

 if TYPE_CHECKING:
@@ -60,12 +61,16 @@ PredictBatchFunction = Callable[
 ]


+@try_remote_functions
 def vector_to_array(col: Column, dtype: str = "float64") -> Column:
     """
     Converts a column of MLlib sparse/dense vectors into a column of dense arrays.

     .. versionadded:: 3.0.0

+    .. versionchanged:: 3.5.0
+        Supports Spark Connect.
+
     Parameters
     ----------
     col : :py:class:`pyspark.sql.Column` or str
@@ -112,6 +117,7 @@ def vector_to_array(col: Column, dtype: str = "float64") -> Column:
     )


+@try_remote_functions
 def array_to_vector(col: Column) -> Column:
     """
     Converts a column of array of numeric type into a column of pyspark.ml.linalg.DenseVector
@@ -119,6 +125,9 @@ def array_to_vector(col: Column) -> Column:
     instances

     .. versionadded:: 3.1.0

+    .. versionchanged:: 3.5.0
+        Supports Spark Connect.
+
     Parameters
     ----------
     col : :py:class:`pyspark.sql.Column` or str
diff --git a/python/pyspark/ml/tests/connect/test_connect_function.py b/python/pyspark/ml/tests/connect/test_connect_function.py
new file mode 100644
index 00000000000..7da3d3f1add
--- /dev/null
+++ b/python/pyspark/ml/tests/connect/test_connect_function.py
@@ -0,0 +1,113 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+import unittest
+
+from pyspark.sql import SparkSession as PySparkSession
+from pyspark.sql.dataframe import DataFrame as SDF
+from pyspark.ml import functions as SF
+
+from pyspark.testing.sqlutils import SQLTestUtils
+from pyspark.testing.connectutils import (
+    should_test_connect,
+    ReusedConnectTestCase,
+)
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
+
+if should_test_connect:
+    from pyspark.sql.connect.dataframe import DataFrame as CDF
+    from pyspark.ml.connect import functions as CF
+
+
+class SparkConnectMLFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, SQLTestUtils):
+    """These test cases compare the output of the ml functions between the
+    Spark Connect client and regular PySpark."""
+
+    @classmethod
+    def setUpClass(cls):
+        super(SparkConnectMLFunctionTests, cls).setUpClass()
+        # Disable the shared namespace so that pyspark.sql.functions etc. point to the
+        # regular PySpark libraries.
+        os.environ["PYSPARK_NO_NAMESPACE_SHARE"] = "1"
+        cls.connect = cls.spark  # Swap the Spark Connect session and the regular PySpark session.
+        cls.spark = PySparkSession._instantiatedSession
+        assert cls.spark is not None
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.spark = cls.connect  # Stopping Spark Connect closes the session in JVM at the server.
+        super(SparkConnectMLFunctionTests, cls).tearDownClass()
+        del os.environ["PYSPARK_NO_NAMESPACE_SHARE"]
+
+    def compare_by_show(self, df1, df2, n: int = 20, truncate: int = 20):
+        assert isinstance(df1, (SDF, CDF))
+        if isinstance(df1, SDF):
+            str1 = df1._jdf.showString(n, truncate, False)
+        else:
+            str1 = df1._show_string(n, truncate, False)
+
+        assert isinstance(df2, (SDF, CDF))
+        if isinstance(df2, SDF):
+            str2 = df2._jdf.showString(n, truncate, False)
+        else:
+            str2 = df2._show_string(n, truncate, False)
+
+        self.assertEqual(str1, str2)
+
+    def test_array_vector_conversion(self):
+        query = """
+            SELECT * FROM VALUES
+            (1, 4, ARRAY(1.0, 2.0, 3.0)),
+            (1, 2, ARRAY(-1.0, -2.0, -3.0))
+            AS tab(a, b, c)
+            """
+
+        cdf = self.connect.sql(query)
+        sdf = self.spark.sql(query)
+
+        self.compare_by_show(
+            cdf.select(cdf.b, CF.array_to_vector(cdf.c)),
+            sdf.select(sdf.b, SF.array_to_vector(sdf.c)),
+        )
+
+        cdf1 = cdf.select("a", CF.array_to_vector(cdf.c).alias("d"))
+        sdf1 = sdf.select("a", SF.array_to_vector(sdf.c).alias("d"))
+
+        self.compare_by_show(
+            cdf1.select(CF.vector_to_array(cdf1.d)),
+            sdf1.select(SF.vector_to_array(sdf1.d)),
+        )
+        self.compare_by_show(
+            cdf1.select(CF.vector_to_array(cdf1.d, "float32")),
+            sdf1.select(SF.vector_to_array(sdf1.d, "float32")),
+        )
+        self.compare_by_show(
+            cdf1.select(CF.vector_to_array(cdf1.d, "float64")),
+            sdf1.select(SF.vector_to_array(sdf1.d, "float64")),
+        )
+
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.connect.test_connect_function import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 67aa2124b22..5d1f89cbc13 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -19,13 +19,27 @@ import json
 import os
 import time
 import uuid
-
-from typing import Any, Dict, Generic, List, Optional, Sequence, Type, TypeVar, cast, TYPE_CHECKING
+import functools
+
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generic,
+    List,
+    Optional,
+    Sequence,
+    Type,
+    TypeVar,
+    cast,
+    TYPE_CHECKING,
+)

 from pyspark import SparkContext, since
 from pyspark.ml.common import inherit_doc
 from pyspark.sql import SparkSession
+from pyspark.sql.utils import is_remote
 from pyspark.util import VersionUtils

 if TYPE_CHECKING:
@@ -41,6 +55,8 @@ JW = TypeVar("JW", bound="JavaMLWriter")
 RL = TypeVar("RL", bound="MLReadable")
 JR = TypeVar("JR", bound="JavaMLReader")

+FuncT = TypeVar("FuncT", bound=Callable[..., Any])
+

 def _jvm() -> "JavaGateway":
     """
@@ -715,3 +731,19 @@ class MetaAlgorithmReadWrite:
             f"UIDs. List of UIDs: {list(uidMap.keys())}."
         )
         return uidMap
+
+
+def try_remote_functions(f: FuncT) -> FuncT:
+    """Mark API supported from Spark Connect."""
+
+    @functools.wraps(f)
+    def wrapped(*args: Any, **kwargs: Any) -> Any:
+
+        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
+            from pyspark.ml.connect import functions
+
+            return getattr(functions, f.__name__)(*args, **kwargs)
+        else:
+            return f(*args, **kwargs)
+
+    return cast(FuncT, wrapped)
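A note on the mechanism: `try_remote_functions` above is what lets one set of `pyspark.ml.functions` entry points serve both backends — under a remote (Connect) session it forwards each call to the same-named function in `pyspark.ml.connect.functions`, and otherwise it runs the original JVM-backed implementation; the `PYSPARK_NO_NAMESPACE_SHARE` check exists so the parity tests above can pin each side to one backend. A self-contained sketch of the pattern, with stand-in names rather than Spark's API:

```
import functools
from typing import Any, Callable, TypeVar, cast

FuncT = TypeVar("FuncT", bound=Callable[..., Any])

# Stand-ins for pyspark.sql.utils.is_remote and the pyspark.ml.connect.functions
# module; illustrative only.
REMOTE_IMPLS = {"add_one": lambda x: x + 1}


def is_remote() -> bool:
    return False  # imagine this checks for an active Spark Connect session


def try_remote(f: FuncT) -> FuncT:
    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if is_remote():
            # Dispatch to the same-named implementation in the remote module.
            return REMOTE_IMPLS[f.__name__](*args, **kwargs)
        return f(*args, **kwargs)

    return cast(FuncT, wrapped)


@try_remote
def add_one(x: int) -> int:
    return x + 1


assert add_one(41) == 42  # same call site, whichever backend is active
```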