This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8c7a6fc81cc [SPARK-43886][PYTHON] Accept generics tuple as typing hints of Pandas UDF
8c7a6fc81cc is described below

commit 8c7a6fc81cceaba3d9c428baec336639b0d91205
Author: Xinrong Meng <xinr...@apache.org>
AuthorDate: Wed May 31 09:19:21 2023 +0900

    [SPARK-43886][PYTHON] Accept generics tuple as typing hints of Pandas UDF
    
    ### What changes were proposed in this pull request?
    Accept the PEP 585 built-in generic `tuple` (e.g. `tuple[pd.Series, pd.DataFrame]`) as a type hint in Pandas UDFs, in addition to `typing.Tuple`.
    
    ### Why are the changes needed?
    Adapt to [PEP 585](https://peps.python.org/pep-0585/), which lets standard collection types such as `tuple` be used as generics starting with Python 3.9.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. The built-in generic `tuple` is now accepted in the type hints of Pandas UDFs.
    
    FROM
    ```py
    >>> @pandas_udf("long")
    ... def multiply(iterator: Iterator[tuple[pd.Series, pd.DataFrame]]) -> Iterator[pd.Series]:
    ...   for s1, df in iterator:
    ...     yield s1 * df.v
    ...
    Traceback (most recent call last):
    ...
        raise PySparkNotImplementedError(
    pyspark.errors.exceptions.base.PySparkNotImplementedError: [UNSUPPORTED_SIGNATURE] Unsupported signature: (iterator: Iterator[tuple[pandas.core.series.Series, pandas.core.frame.DataFrame]]) -> Iterator[pandas.core.series.Series].
    ```
    
    TO
    ```py
    >>> @pandas_udf("long")
    ... def multiply(iterator: Iterator[tuple[pd.Series, pd.DataFrame]]) -> Iterator[pd.Series]:
    ...   for s1, df in iterator:
    ...     yield s1 * df.v
    ...
    >>> multiply._unwrapped.evalType
    204  # SQL_SCALAR_PANDAS_ITER_UDF
    ```
    
    ### How was this patch tested?
    Unit tests.
    
    Closes #41388 from xinrong-meng/tuple.
    
    Authored-by: Xinrong Meng <xinr...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/pandas/typehints.py             |  2 +-
 .../sql/tests/pandas/test_pandas_udf_typehints.py  | 24 ++++++++++++++++++++++
 2 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/pandas/typehints.py 
b/python/pyspark/sql/pandas/typehints.py
index 29ac81af944..f0c13e66a63 100644
--- a/python/pyspark/sql/pandas/typehints.py
+++ b/python/pyspark/sql/pandas/typehints.py
@@ -145,7 +145,7 @@ def check_tuple_annotation(
     # Tuple has _name but other types have __name__
     # Check if the name is Tuple first. After that, check the generic types.
     name = getattr(annotation, "_name", getattr(annotation, "__name__", None))
-    return name == "Tuple" and (
+    return name in ("Tuple", "tuple") and (
         parameter_check_func is None or all(map(parameter_check_func, 
annotation.__args__))
     )
 
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py 
b/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py
index 3cdf83e2d06..bfb874ffe53 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import sys
 import unittest
 from inspect import signature
 from typing import Union, Iterator, Tuple, cast, get_type_hints
@@ -113,6 +114,29 @@ class PandasUDFTypeHintsTests(ReusedSQLTestCase):
             infer_eval_type(signature(func), get_type_hints(func)), 
PandasUDFType.SCALAR_ITER
         )
 
+    @unittest.skipIf(sys.version_info < (3, 9), "Type hinting generics require 
Python 3.9.")
+    def test_type_annotation_tuple_generics(self):
+        def func(iter: Iterator[tuple[pd.DataFrame, pd.Series]]) -> 
Iterator[pd.DataFrame]:
+            pass
+
+        self.assertEqual(
+            infer_eval_type(signature(func), get_type_hints(func)), 
PandasUDFType.SCALAR_ITER
+        )
+
+        def func(iter: Iterator[tuple[pd.DataFrame, ...]]) -> 
Iterator[pd.Series]:
+            pass
+
+        self.assertEqual(
+            infer_eval_type(signature(func), get_type_hints(func)), 
PandasUDFType.SCALAR_ITER
+        )
+
+        def func(iter: Iterator[tuple[Union[pd.DataFrame, pd.Series], ...]]) 
-> Iterator[pd.Series]:
+            pass
+
+        self.assertEqual(
+            infer_eval_type(signature(func), get_type_hints(func)), 
PandasUDFType.SCALAR_ITER
+        )
+
     def test_type_annotation_group_agg(self):
         def func(col: pd.Series) -> str:
             pass


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to