gaogaotiantian commented on code in PR #55532:
URL: https://github.com/apache/spark/pull/55532#discussion_r3151257687
##########
python/pyspark/worker.py:
##########
@@ -234,46 +251,56 @@ def chain(f, g):
return lambda *a: g(f(*a))
-def verify_result(expected_type: type) -> Callable[[Any], Iterator]:
- """
- Create a result verifier that checks both iterability and element types.
+@overload
Review Comment:
I was wondering if we could add an overload for the iterator type; that would
put some constraint on `result`, right?
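The snippet below is a minimal sketch of what an iterator overload could look like. It assumes Python 3.9+ and the `collections.abc.Iterator` spelling. The `@overload` stubs only guide static checkers. The runtime branch mirrors the `get_origin` check in the diff.

A few parts are assumptions rather than what the PR does:
- `TypeError` stands in for `PySparkTypeError`.
- Rejecting non-iterators up front is a choice made for this sketch.
- Whether a given type checker actually selects the first overload for an `Iterator[X]` argument should be verified, not assumed.

```python
import collections.abc
from collections.abc import Iterator
from typing import Any, Type, TypeVar, get_args, get_origin, overload

T = TypeVar("T")


# Listed first so that Iterator[X] arguments resolve to Iterator[X] results
# (assumption: the checker picks this overload for such arguments).
@overload
def verify_return_type(result: Any, expected_type: Type[Iterator[T]]) -> Iterator[T]: ...


@overload
def verify_return_type(result: Any, expected_type: Type[T]) -> T: ...


def verify_return_type(result: Any, expected_type: Any) -> Any:
    # TypeError is a stand-in for PySparkTypeError in this sketch.
    if get_origin(expected_type) is collections.abc.Iterator:
        (element_type,) = get_args(expected_type)
        if not isinstance(result, collections.abc.Iterator):
            raise TypeError(f"expected an iterator, got {type(result).__name__}")

        def check(item: Any) -> Any:
            # Runs lazily, once per element, as the caller consumes the iterator.
            if not isinstance(item, element_type):
                raise TypeError(
                    f"expected {element_type.__name__}, got {type(item).__name__}"
                )
            return item

        return map(check, result)

    if not isinstance(result, expected_type):
        raise TypeError(
            f"expected {expected_type.__name__}, got {type(result).__name__}"
        )
    return result
```

With this sketch, `list(verify_return_type(iter([1, 2]), Iterator[int]))` yields `[1, 2]`. An iterator that later yields a `str` raises only when that element is consumed.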
##########
python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py:
##########
@@ -148,8 +148,8 @@ def stats_iter(
with self.assertRaisesRegex(
PythonException,
- "Return type of the user-defined function should be pyarrow.RecordBatch, but is "
- + "tuple",
+ "Return type of the user-defined function should be iterator of "
Review Comment:
We should start moving away from matching exception messages exactly. Exact
matches cause unnecessary trouble whenever we make minor changes to the
exception text. We don't have to fix all tests immediately, but we can do it
gradually as we touch them, checking for a few important keywords with a regex
instead.
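Here is a small sketch of the keyword-style assertion, using plain `unittest`. `hypothetical_check` is an invented stand-in for the worker-side check. `TypeError` replaces the `PythonException` that the real Spark test expects. The regex pins only the facts that matter (return type, expected element type, actual type), so rewording the rest of the sentence would not break it.

```python
import unittest


def hypothetical_check(result: object) -> None:
    # Hypothetical stand-in for the worker-side check; the exact wording here is
    # the kind of detail that may change between releases.
    raise TypeError(
        "Return type of the user-defined function should be iterator of "
        f"pyarrow.RecordBatch, but is {type(result).__name__}"
    )


class ReturnTypeErrorTest(unittest.TestCase):
    def test_wrong_return_type(self) -> None:
        # Assert only on the key facts (what is wrong, what was expected,
        # what was received) instead of the full sentence.
        with self.assertRaisesRegex(TypeError, r"Return type.*RecordBatch.*tuple"):
            hypothetical_check((1, 2))
```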
##########
python/pyspark/worker.py:
##########
@@ -234,46 +251,56 @@ def chain(f, g):
return lambda *a: g(f(*a))
-def verify_result(expected_type: type) -> Callable[[Any], Iterator]:
- """
- Create a result verifier that checks both iterability and element types.
+@overload
+def verify_return_type(result: Any, expected_type: Type[T]) -> T: ...
- Returns a function that takes a UDF result, verifies it is iterable,
- and lazily type-checks each element via map.
- Parameters
- ----------
- expected_type : type
- The expected Python/PyArrow type for each element
- (e.g. pa.RecordBatch, pa.Array).
+@overload
+def verify_return_type(result: Any, expected_type: Any) -> Any: ...
+
+
+def verify_return_type(result: Any, expected_type: Any) -> Any:
"""
+ Verify a UDF return value against an expected type.
- package = getattr(inspect.getmodule(expected_type), "__package__", "")
- label: str = f"{package}.{expected_type.__name__}"
+ Returns ``result`` unchanged if ``isinstance(result, expected_type)``.
+ For ``Iterator[T]``, returns a lazy iterator that checks each element
+ against ``T`` on consumption. Raises ``PySparkTypeError`` on mismatch.
+ """
+ if get_origin(expected_type) is collections.abc.Iterator:
+ (element_type,) = get_args(expected_type)
+ package = getattr(inspect.getmodule(element_type), "__package__", "")
Review Comment:
If I remember correctly, we do this (the whole "label" thing) to keep the tests
backward compatible, right? Here we build a custom format, but the type itself
normally already has a `repr`.
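To make the difference concrete, the sketch below puts the removed `label` logic next to the built-in `repr`. It uses a stdlib class, `json.JSONDecoder`, as a stand-in so the snippet runs without PyArrow. That class is defined in the `json.decoder` submodule.

The label keeps only the top-level package, while `repr` exposes the defining submodule. Presumably the same split applies to `pa.RecordBatch`, which would explain why the existing test message reads `pyarrow.RecordBatch`.

```python
import inspect
import json


def label(expected_type: type) -> str:
    # Same formatting as the removed lines in the diff: "<package>.<ClassName>".
    package = getattr(inspect.getmodule(expected_type), "__package__", "")
    return f"{package}.{expected_type.__name__}"


print(label(json.JSONDecoder))  # json.JSONDecoder
print(repr(json.JSONDecoder))   # <class 'json.decoder.JSONDecoder'>
```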
##########
python/pyspark/worker.py:
##########
@@ -234,46 +251,56 @@ def chain(f, g):
return lambda *a: g(f(*a))
-def verify_result(expected_type: type) -> Callable[[Any], Iterator]:
- """
- Create a result verifier that checks both iterability and element types.
+@overload
+def verify_return_type(result: Any, expected_type: Type[T]) -> T: ...
- Returns a function that takes a UDF result, verifies it is iterable,
- and lazily type-checks each element via map.
- Parameters
- ----------
- expected_type : type
- The expected Python/PyArrow type for each element
- (e.g. pa.RecordBatch, pa.Array).
+@overload
+def verify_return_type(result: Any, expected_type: Any) -> Any: ...
+
+
+def verify_return_type(result: Any, expected_type: Any) -> Any:
"""
+ Verify a UDF return value against an expected type.
- package = getattr(inspect.getmodule(expected_type), "__package__", "")
- label: str = f"{package}.{expected_type.__name__}"
+ Returns ``result`` unchanged if ``isinstance(result, expected_type)``.
+ For ``Iterator[T]``, returns a lazy iterator that checks each element
+ against ``T`` on consumption. Raises ``PySparkTypeError`` on mismatch.
+ """
+ if get_origin(expected_type) is collections.abc.Iterator:
Review Comment:
I double-checked this. We should use `collections.abc.Iterator` for all the
type hints; that is, instead of `from typing import Iterator`, we should use
`from collections.abc import Iterator`. It works for both runtime and static
type checking; think of it like `dict` vs. `Dict`. This has been supported since
Python 3.9. It is the correct choice because `collections.abc.Iterator` is the
actual ABC that all iterators use. We don't need to change any other code, just
the import.
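The following is a quick runtime check of this point, assuming Python 3.9+ (PEP 585). Both spellings resolve to the same origin, so the `get_origin(expected_type) is collections.abc.Iterator` branch in the diff behaves the same whichever import is used. The bare ABC is also what `isinstance` checks against.

```python
import collections.abc
import typing
from collections.abc import Iterator  # the import the comment recommends
from typing import get_args, get_origin

# Both spellings report the same runtime origin.
print(get_origin(Iterator[int]) is collections.abc.Iterator)         # True
print(get_origin(typing.Iterator[int]) is collections.abc.Iterator)  # True
print(get_args(Iterator[int]))                                       # (<class 'int'>,)

# The ABC is what real iterators satisfy at runtime.
print(isinstance(iter([1, 2]), Iterator))  # True
print(isinstance([1, 2], Iterator))        # False: a list is iterable, not an iterator
```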
##########
python/pyspark/worker.py:
##########
@@ -234,46 +251,56 @@ def chain(f, g):
return lambda *a: g(f(*a))
-def verify_result(expected_type: type) -> Callable[[Any], Iterator]:
- """
- Create a result verifier that checks both iterability and element types.
+@overload
+def verify_return_type(result: Any, expected_type: Type[T]) -> T: ...
- Returns a function that takes a UDF result, verifies it is iterable,
- and lazily type-checks each element via map.
- Parameters
- ----------
- expected_type : type
- The expected Python/PyArrow type for each element
- (e.g. pa.RecordBatch, pa.Array).
+@overload
+def verify_return_type(result: Any, expected_type: Any) -> Any: ...
Review Comment:
Why use `Any` at all? Is there any case that `Type[T]` can't cover?
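This is one possible runtime reason for a separate branch, not necessarily the author's rationale. A subscripted `Iterator[int]` is a generic-alias object rather than a plain class, and `isinstance` rejects it. So the `Iterator[T]` case cannot simply go through `isinstance(result, expected_type)`.

Whether static checkers accept such an alias for a `Type[T]` parameter is a separate question, and the answer varies by checker. `isinstance_raises` is a hypothetical helper for illustration only.

```python
import types
from collections.abc import Iterator

alias = Iterator[int]

# At runtime the subscripted form is a generic alias object, not a class.
print(type(alias) is types.GenericAlias)  # True


def isinstance_raises(obj: object, cls: object) -> bool:
    """Hypothetical helper: return True if isinstance(obj, cls) raises TypeError."""
    try:
        isinstance(obj, cls)  # type: ignore[arg-type]
    except TypeError:
        return True
    return False


print(isinstance_raises(iter([1]), alias))     # True: parameterized generics are rejected
print(isinstance_raises(iter([1]), Iterator))  # False: the bare ABC works
```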
--
This is an automated message from the Apache Git Service.