Yicong-Huang commented on code in PR #55532:
URL: https://github.com/apache/spark/pull/55532#discussion_r3157518317


##########
python/pyspark/worker.py:
##########
@@ -234,46 +251,56 @@ def chain(f, g):
     return lambda *a: g(f(*a))
 
 
-def verify_result(expected_type: type) -> Callable[[Any], Iterator]:
-    """
-    Create a result verifier that checks both iterability and element types.
+@overload
+def verify_return_type(result: Any, expected_type: Type[T]) -> T: ...
 
-    Returns a function that takes a UDF result, verifies it is iterable,
-    and lazily type-checks each element via map.
 
-    Parameters
-    ----------
-    expected_type : type
-        The expected Python/PyArrow type for each element
-        (e.g. pa.RecordBatch, pa.Array).
+@overload
+def verify_return_type(result: Any, expected_type: Any) -> Any: ...
+
+
+def verify_return_type(result: Any, expected_type: Any) -> Any:
     """
+    Verify a UDF return value against an expected type.
 
-    package = getattr(inspect.getmodule(expected_type), "__package__", "")
-    label: str = f"{package}.{expected_type.__name__}"
+    Returns ``result`` unchanged if ``isinstance(result, expected_type)``.
+    For ``Iterator[T]``, returns a lazy iterator that checks each element
+    against ``T`` on consumption. Raises ``PySparkTypeError`` on mismatch.
+    """
+    if get_origin(expected_type) is collections.abc.Iterator:
+        (element_type,) = get_args(expected_type)
+        package = getattr(inspect.getmodule(element_type), "__package__", "")

Review Comment:
   Yes, this is for backward compatibility. `repr(pa.Table)` is 
`<class 'pyarrow.lib.Table'>`, which would leak the internal `lib` segment 
into user-facing messages. Building the label from the module's `__package__` 
plus the type's `__name__` keeps it as `pyarrow.Table`, matching the pre-PR 
`verify_result` factory and existing tests.
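
A minimal sketch of the difference, using a standard-library class in place of `pa.Table` so it runs without PyArrow. `type_label` is a hypothetical helper name used only for illustration; it is not a function in `worker.py`.

```python
import inspect
from concurrent.futures import Future


def type_label(tp: type) -> str:
    """Build a user-facing label from the defining module's package and the type name."""
    # Hypothetical helper: mirrors the `package` / `__name__` pattern in the diff.
    package = getattr(inspect.getmodule(tp), "__package__", "")
    return f"{package}.{tp.__name__}"


# Future is defined in the private submodule concurrent.futures._base,
# which plays the same role here as the internal `lib` segment in pyarrow.lib.
print(repr(Future))        # <class 'concurrent.futures._base.Future'>
print(type_label(Future))  # concurrent.futures.Future
```

The label drops the private module segment because `__package__` names the enclosing package rather than the defining submodule.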



##########
python/pyspark/worker.py:
##########
@@ -234,46 +251,56 @@ def chain(f, g):
     return lambda *a: g(f(*a))
 
 
-def verify_result(expected_type: type) -> Callable[[Any], Iterator]:
-    """
-    Create a result verifier that checks both iterability and element types.
+@overload
+def verify_return_type(result: Any, expected_type: Type[T]) -> T: ...
 
-    Returns a function that takes a UDF result, verifies it is iterable,
-    and lazily type-checks each element via map.
 
-    Parameters
-    ----------
-    expected_type : type
-        The expected Python/PyArrow type for each element
-        (e.g. pa.RecordBatch, pa.Array).
+@overload
+def verify_return_type(result: Any, expected_type: Any) -> Any: ...
+
+
+def verify_return_type(result: Any, expected_type: Any) -> Any:
     """
+    Verify a UDF return value against an expected type.
 
-    package = getattr(inspect.getmodule(expected_type), "__package__", "")
-    label: str = f"{package}.{expected_type.__name__}"
+    Returns ``result`` unchanged if ``isinstance(result, expected_type)``.
+    For ``Iterator[T]``, returns a lazy iterator that checks each element
+    against ``T`` on consumption. Raises ``PySparkTypeError`` on mismatch.
+    """
+    if get_origin(expected_type) is collections.abc.Iterator:

Review Comment:
   Done in b3b4658: switched to `from collections.abc import Iterator` and 
dropped the now-redundant `import collections.abc`.
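
For reference, a rough sketch of how the origin check and the lazy per-element verification described in the docstring might fit together with the bare `Iterator` import. This is an illustration under assumptions, not the code in the PR: `verify_return_type_sketch` is a hypothetical name, the built-in `TypeError` stands in for `PySparkTypeError`, and the error messages are invented.

```python
from collections.abc import Iterator
from typing import Any, get_args, get_origin


def verify_return_type_sketch(result: Any, expected_type: Any) -> Any:
    """Return `result` unchanged if it matches; wrap iterators in a lazy checker."""
    if get_origin(expected_type) is Iterator:
        # Iterator[T]: unpack T and defer element checks until consumption.
        (element_type,) = get_args(expected_type)
        if not isinstance(result, Iterator):
            # Assumption: TypeError stands in for PySparkTypeError.
            raise TypeError(f"expected an iterator, got {type(result).__name__}")
        return map(lambda elem: verify_return_type_sketch(elem, element_type), result)
    if not isinstance(result, expected_type):
        raise TypeError(
            f"expected {expected_type.__name__}, got {type(result).__name__}"
        )
    return result


# Elements are checked one at a time, so the bad element only fails when reached.
checked = verify_return_type_sketch(iter([1, "x"]), Iterator[int])
print(next(checked))  # 1
```

Because `map` is lazy, the mismatch on `"x"` surfaces on the second `next(checked)` call rather than when the verifier is first invoked.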



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

