This is an automated email from the ASF dual-hosted git repository.

xinrong pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new ea0f50b18d6 [SPARK-42263][CONNECT][PYTHON] Implement `spark.catalog.registerFunction`
ea0f50b18d6 is described below

commit ea0f50b18d6230fd5c5362b84f3dabc045635883
Author: Xinrong Meng <xinr...@apache.org>
AuthorDate: Tue Feb 14 13:56:12 2023 +0800

    [SPARK-42263][CONNECT][PYTHON] Implement `spark.catalog.registerFunction`
    
    ### What changes were proposed in this pull request?
    Implement `spark.catalog.registerFunction`.
    
    ### Why are the changes needed?
    To reach parity with vanilla PySpark.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. `spark.catalog.registerFunction` is supported, as shown below.
    
    ```py
    >>> @udf
    ... def f():
    ...   return 'hi'
    ...
    >>> spark.catalog.registerFunction('HI', f)
    <function f at 0x7fcdd8341dc0>
    >>> spark.sql("SELECT HI()").collect()
    [Row(HI()='hi')]
    
    ```
    
    ### How was this patch tested?
    Unit tests.
    
    Closes #39984 from xinrong-meng/catalog_register.
    
    Authored-by: Xinrong Meng <xinr...@apache.org>
    Signed-off-by: Xinrong Meng <xinr...@apache.org>
    (cherry picked from commit c619a402451df9ae5b305e5a48eb244c9ffd2eb6)
    Signed-off-by: Xinrong Meng <xinr...@apache.org>
---
 python/pyspark/sql/catalog.py                      |  3 ++
 python/pyspark/sql/connect/catalog.py              | 13 ++++--
 python/pyspark/sql/connect/udf.py                  |  2 +-
 .../sql/tests/connect/test_connect_basic.py        |  7 ----
 .../pyspark/sql/tests/connect/test_parity_udf.py   | 49 +++-------------------
 python/pyspark/sql/tests/test_udf.py               |  4 +-
 6 files changed, 21 insertions(+), 57 deletions(-)

diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py
index a7f3e761f3f..c83d02d4cb3 100644
--- a/python/pyspark/sql/catalog.py
+++ b/python/pyspark/sql/catalog.py
@@ -924,6 +924,9 @@ class Catalog:
 
         .. deprecated:: 2.3.0
             Use :func:`spark.udf.register` instead.
+
+        .. versionchanged:: 3.4.0
+            Support Spark Connect.
         """
         warnings.warn("Deprecated in 2.3.0. Use spark.udf.register instead.", FutureWarning)
         return self._sparkSession.udf.register(name, f, returnType)
diff --git a/python/pyspark/sql/connect/catalog.py b/python/pyspark/sql/connect/catalog.py
index b7ea44e831e..233fb904529 100644
--- a/python/pyspark/sql/connect/catalog.py
+++ b/python/pyspark/sql/connect/catalog.py
@@ -18,8 +18,9 @@ from pyspark.sql.connect import check_dependencies
 
 check_dependencies(__name__, __file__)
 
-from typing import Any, List, Optional, TYPE_CHECKING
+from typing import Any, Callable, List, Optional, TYPE_CHECKING
 
+import warnings
 import pandas as pd
 
 from pyspark.sql.types import StructType
@@ -36,6 +37,7 @@ from pyspark.sql.connect import plan
 
 if TYPE_CHECKING:
     from pyspark.sql.connect.session import SparkSession
+    from pyspark.sql.connect._typing import DataTypeOrString, UserDefinedFunctionLike
 
 
 class Catalog:
@@ -306,8 +308,13 @@ class Catalog:
 
     refreshByPath.__doc__ = PySparkCatalog.refreshByPath.__doc__
 
-    def registerFunction(self, *args: Any, **kwargs: Any) -> None:
-        raise NotImplementedError("registerFunction() is not implemented.")
+    def registerFunction(
+        self, name: str, f: Callable[..., Any], returnType: Optional["DataTypeOrString"] = None
+    ) -> "UserDefinedFunctionLike":
+        warnings.warn("Deprecated in 2.3.0. Use spark.udf.register instead.", FutureWarning)
+        return self._sparkSession.udf.register(name, f, returnType)
+
+    registerFunction.__doc__ = PySparkCatalog.registerFunction.__doc__
 
 
 Catalog.__doc__ = PySparkCatalog.__doc__
diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py
index 39c31e85992..bef5a99a65b 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -212,7 +212,7 @@ class UDFRegistration:
                 )
             return_udf = f
             self.sparkSession._client.register_udf(
-                f, f.returnType, name, f.evalType, f.deterministic
+                f.func, f.returnType, name, f.evalType, f.deterministic
             )
         else:
             if returnType is None:
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 9e9341c9a2a..8bfffee1ac1 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -2805,13 +2805,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             with self.assertRaises(NotImplementedError):
                 getattr(self.connect, f)()
 
-    def test_unsupported_catalog_functions(self):
-        # SPARK-41939: Disable unsupported functions.
-
-        for f in ("registerFunction",):
-            with self.assertRaises(NotImplementedError):
-                getattr(self.connect.catalog, f)()
-
     def test_unsupported_io_functions(self):
         # SPARK-41964: Disable unsupported functions.
         df = self.connect.createDataFrame([(x, f"{x}") for x in range(100)], ["id", "name"])
diff --git a/python/pyspark/sql/tests/connect/test_parity_udf.py b/python/pyspark/sql/tests/connect/test_parity_udf.py
index 160f06d37f7..293f4b0f41a 100644
--- a/python/pyspark/sql/tests/connect/test_parity_udf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_udf.py
@@ -85,61 +85,24 @@ class UDFParityTests(BaseUDFTestsMixin, ReusedConnectTestCase):
     def test_udf_registration_return_type_not_none(self):
         super().test_udf_registration_return_type_not_none()
 
-    # TODO(SPARK-42263): implement `spark.catalog.registerFunction`
-    @unittest.skip("Fails in Spark Connect, should enable.")
+    @unittest.skip("Spark Connect doesn't support RDD but the test depends on it.")
     def test_worker_original_stdin_closed(self):
         super().test_worker_original_stdin_closed()
 
-    # TODO(SPARK-42263): implement `spark.catalog.registerFunction`
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_chained_udf(self):
-        super().test_chained_udf()
-
-    # TODO(SPARK-42263): implement `spark.catalog.registerFunction`
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_udf_without_arguments(self):
-        super().test_udf_without_arguments()
-
-    # TODO(SPARK-42263): implement `spark.catalog.registerFunction`
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_multiple_udfs(self):
-        super().test_multiple_udfs()
-
-    # TODO(SPARK-42263): implement `spark.catalog.registerFunction`
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_nondeterministic_udf2(self):
-        super().test_nondeterministic_udf2()
-
-    # TODO(SPARK-42263): implement `spark.catalog.registerFunction`
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_single_udf_with_repeated_argument(self):
-        super().test_single_udf_with_repeated_argument()
-
-    # TODO(SPARK-42263): implement `spark.catalog.registerFunction`
-    @unittest.skip("Fails in Spark Connect, should enable.")
+    @unittest.skip("Spark Connect does not support SQLContext but the test depends on it.")
     def test_udf(self):
-        super().test_df()
+        super().test_udf()
 
-    # TODO(SPARK-42263): implement `spark.catalog.registerFunction`
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_udf2(self):
-        super().test_df2()
-
-    # TODO(SPARK-42263): implement `spark.catalog.registerFunction`
+    # TODO(SPARK-42247): implement `UserDefinedFunction.returnType`
     @unittest.skip("Fails in Spark Connect, should enable.")
     def test_udf3(self):
-        super().test_df3()
+        super().test_udf3()
 
-    # TODO(SPARK-42263): implement `spark.catalog.registerFunction`
+    # TODO(SPARK-42247): implement `UserDefinedFunction.returnType`
     @unittest.skip("Fails in Spark Connect, should enable.")
     def test_udf_registration_return_type_none(self):
         super().test_udf_registration_return_type_none()
 
-    # TODO(SPARK-42263): implement `spark.catalog.registerFunction`
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_udf_with_array_type(self):
-        super().test_udf_with_array_type()
-
     # TODO(SPARK-42210): implement `spark.udf`
     @unittest.skip("Fails in Spark Connect, should enable.")
     def test_non_existed_udaf(self):
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index e7bb00f3034..0f93babbd6c 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -83,9 +83,7 @@ class BaseUDFTestsMixin(object):
     def test_udf2(self):
         with self.tempView("test"):
             self.spark.catalog.registerFunction("strlen", lambda string: len(string), IntegerType())
-            self.spark.createDataFrame(
-                self.sc.parallelize([Row(a="test")])
-            ).createOrReplaceTempView("test")
+            self.spark.createDataFrame([("test",)], ["a"]).createOrReplaceTempView("test")
             [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
             self.assertEqual(4, res[0])
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to