This is an automated email from the ASF dual-hosted git repository.

xinrong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c619a402451  [SPARK-42263][CONNECT][PYTHON] Implement `spark.catalog.registerFunction`
c619a402451 is described below

commit c619a402451df9ae5b305e5a48eb244c9ffd2eb6
Author: Xinrong Meng <xinr...@apache.org>
AuthorDate: Tue Feb 14 13:56:12 2023 +0800

[SPARK-42263][CONNECT][PYTHON] Implement `spark.catalog.registerFunction`

### What changes were proposed in this pull request?
Implement `spark.catalog.registerFunction`.

### Why are the changes needed?
To reach parity with vanilla PySpark.

### Does this PR introduce _any_ user-facing change?
Yes. `spark.catalog.registerFunction` is supported, as shown below.

```py
>>> @udf
... def f():
...     return 'hi'
...
>>> spark.catalog.registerFunction('HI', f)
<function f at 0x7fcdd8341dc0>
>>> spark.sql("SELECT HI()").collect()
[Row(HI()='hi')]
```
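For comparison, a plain Python function can also be registered when an explicit return type is supplied; `spark.udf.register`, which this method delegates to, defaults the return type to string for plain functions. A minimal sketch mirroring the `strlen` test updated in this patch, assuming a live `spark` session (output shown is what vanilla PySpark prints):

```py
>>> from pyspark.sql.types import IntegerType
>>> spark.catalog.registerFunction('strlen', lambda s: len(s), IntegerType())
<function <lambda> at ...>
>>> spark.sql("SELECT strlen('spark')").collect()
[Row(strlen(spark)=5)]
```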
### How was this patch tested?
Unit tests.

Closes #39984 from xinrong-meng/catalog_register.

Authored-by: Xinrong Meng <xinr...@apache.org>
Signed-off-by: Xinrong Meng <xinr...@apache.org>
---
 python/pyspark/sql/catalog.py                      |  3 ++
 python/pyspark/sql/connect/catalog.py              | 13 ++++--
 python/pyspark/sql/connect/udf.py                  |  2 +-
 .../sql/tests/connect/test_connect_basic.py        |  7 ----
 .../pyspark/sql/tests/connect/test_parity_udf.py   | 49 +++-------------------
 python/pyspark/sql/tests/test_udf.py               |  4 +-
 6 files changed, 21 insertions(+), 57 deletions(-)

diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py
index a7f3e761f3f..c83d02d4cb3 100644
--- a/python/pyspark/sql/catalog.py
+++ b/python/pyspark/sql/catalog.py
@@ -924,6 +924,9 @@ class Catalog:
 
         .. deprecated:: 2.3.0
             Use :func:`spark.udf.register` instead.
+
+        .. versionchanged:: 3.4.0
+            Support Spark Connect.
         """
         warnings.warn("Deprecated in 2.3.0. Use spark.udf.register instead.", FutureWarning)
         return self._sparkSession.udf.register(name, f, returnType)

diff --git a/python/pyspark/sql/connect/catalog.py b/python/pyspark/sql/connect/catalog.py
index b7ea44e831e..233fb904529 100644
--- a/python/pyspark/sql/connect/catalog.py
+++ b/python/pyspark/sql/connect/catalog.py
@@ -18,8 +18,9 @@ from pyspark.sql.connect import check_dependencies
 
 check_dependencies(__name__, __file__)
 
-from typing import Any, List, Optional, TYPE_CHECKING
+from typing import Any, Callable, List, Optional, TYPE_CHECKING
 
+import warnings
 import pandas as pd
 
 from pyspark.sql.types import StructType
@@ -36,6 +37,7 @@ from pyspark.sql.connect import plan
 
 if TYPE_CHECKING:
     from pyspark.sql.connect.session import SparkSession
+    from pyspark.sql.connect._typing import DataTypeOrString, UserDefinedFunctionLike
 
 
 class Catalog:
@@ -306,8 +308,13 @@ class Catalog:
 
     refreshByPath.__doc__ = PySparkCatalog.refreshByPath.__doc__
 
-    def registerFunction(self, *args: Any, **kwargs: Any) -> None:
-        raise NotImplementedError("registerFunction() is not implemented.")
+    def registerFunction(
+        self, name: str, f: Callable[..., Any], returnType: Optional["DataTypeOrString"] = None
+    ) -> "UserDefinedFunctionLike":
+        warnings.warn("Deprecated in 2.3.0. Use spark.udf.register instead.", FutureWarning)
+        return self._sparkSession.udf.register(name, f, returnType)
+
+    registerFunction.__doc__ = PySparkCatalog.registerFunction.__doc__
 
 
 Catalog.__doc__ = PySparkCatalog.__doc__

diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py
index 39c31e85992..bef5a99a65b 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -212,7 +212,7 @@ class UDFRegistration:
             )
             return_udf = f
             self.sparkSession._client.register_udf(
-                f, f.returnType, name, f.evalType, f.deterministic
+                f.func, f.returnType, name, f.evalType, f.deterministic
             )
         else:
             if returnType is None:
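The one-line change in `connect/udf.py` above is the substantive fix: when an existing `UserDefinedFunction` is re-registered, the client must serialize the underlying Python callable (`f.func`) rather than the UDF wrapper itself, while the wrapper's metadata travels alongside it. A rough illustration of the attributes involved, using the classic `pyspark.sql.functions` API (not part of the patch):

```py
from pyspark.sql.functions import udf

@udf
def f():
    return "hi"

# The wrapper exposes the metadata that register_udf() forwards...
print(f.returnType, f.evalType, f.deterministic)
# ...while f.func is the plain Python function that actually gets pickled.
print(f.func)
```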
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 9e9341c9a2a..8bfffee1ac1 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -2805,13 +2805,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             with self.assertRaises(NotImplementedError):
                 getattr(self.connect, f)()
 
-    def test_unsupported_catalog_functions(self):
-        # SPARK-41939: Disable unsupported functions.
-
-        for f in ("registerFunction",):
-            with self.assertRaises(NotImplementedError):
-                getattr(self.connect.catalog, f)()
-
     def test_unsupported_io_functions(self):
         # SPARK-41964: Disable unsupported functions.
         df = self.connect.createDataFrame([(x, f"{x}") for x in range(100)], ["id", "name"])

diff --git a/python/pyspark/sql/tests/connect/test_parity_udf.py b/python/pyspark/sql/tests/connect/test_parity_udf.py
index 160f06d37f7..293f4b0f41a 100644
--- a/python/pyspark/sql/tests/connect/test_parity_udf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_udf.py
@@ -85,61 +85,24 @@ class UDFParityTests(BaseUDFTestsMixin, ReusedConnectTestCase):
     def test_udf_registration_return_type_not_none(self):
         super().test_udf_registration_return_type_not_none()
 
-    # TODO(SPARK-42263): implement `spark.catalog.registerFunction`
-    @unittest.skip("Fails in Spark Connect, should enable.")
+    @unittest.skip("Spark Connect doesn't support RDD but the test depends on it.")
     def test_worker_original_stdin_closed(self):
         super().test_worker_original_stdin_closed()
 
-    # TODO(SPARK-42263): implement `spark.catalog.registerFunction`
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_chained_udf(self):
-        super().test_chained_udf()
-
-    # TODO(SPARK-42263): implement `spark.catalog.registerFunction`
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_udf_without_arguments(self):
-        super().test_udf_without_arguments()
-
-    # TODO(SPARK-42263): implement `spark.catalog.registerFunction`
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_multiple_udfs(self):
-        super().test_multiple_udfs()
-
-    # TODO(SPARK-42263): implement `spark.catalog.registerFunction`
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_nondeterministic_udf2(self):
-        super().test_nondeterministic_udf2()
-
-    # TODO(SPARK-42263): implement `spark.catalog.registerFunction`
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_single_udf_with_repeated_argument(self):
-        super().test_single_udf_with_repeated_argument()
-
-    # TODO(SPARK-42263): implement `spark.catalog.registerFunction`
-    @unittest.skip("Fails in Spark Connect, should enable.")
+    @unittest.skip("Spark Connect does not support SQLContext but the test depends on it.")
     def test_udf(self):
-        super().test_df()
+        super().test_udf()
 
-    # TODO(SPARK-42263): implement `spark.catalog.registerFunction`
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_udf2(self):
-        super().test_df2()
-
-    # TODO(SPARK-42263): implement `spark.catalog.registerFunction`
+    # TODO(SPARK-42247): implement `UserDefinedFunction.returnType`
     @unittest.skip("Fails in Spark Connect, should enable.")
     def test_udf3(self):
-        super().test_df3()
+        super().test_udf3()
 
-    # TODO(SPARK-42263): implement `spark.catalog.registerFunction`
+    # TODO(SPARK-42247): implement `UserDefinedFunction.returnType`
     @unittest.skip("Fails in Spark Connect, should enable.")
     def test_udf_registration_return_type_none(self):
         super().test_udf_registration_return_type_none()
 
-    # TODO(SPARK-42263): implement `spark.catalog.registerFunction`
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_udf_with_array_type(self):
-        super().test_udf_with_array_type()
-
     # TODO(SPARK-42210): implement `spark.udf`
     @unittest.skip("Fails in Spark Connect, should enable.")
     def test_non_existed_udaf(self):

diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index e7bb00f3034..0f93babbd6c 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -83,9 +83,7 @@ class BaseUDFTestsMixin(object):
 
     def test_udf2(self):
         with self.tempView("test"):
             self.spark.catalog.registerFunction("strlen", lambda string: len(string), IntegerType())
-            self.spark.createDataFrame(
-                self.sc.parallelize([Row(a="test")])
-            ).createOrReplaceTempView("test")
+            self.spark.createDataFrame([("test",)], ["a"]).createOrReplaceTempView("test")
             [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
             self.assertEqual(4, res[0])
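One caveat when reading the patch: `registerFunction` remains deprecated, and on Spark Connect it now emits the same `FutureWarning` as vanilla PySpark, so new code should go through `spark.udf.register` instead. A minimal sketch of the non-deprecated equivalent, again assuming a live `spark` session:

```py
>>> from pyspark.sql.functions import udf
>>> @udf
... def hello():
...     return 'hi'
...
>>> _ = spark.udf.register('HELLO', hello)
>>> spark.sql("SELECT HELLO()").collect()
[Row(HELLO()='hi')]
```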
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org