This is an automated email from the ASF dual-hosted git repository.

HyukjinKwon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9e85d0684ae7 [SPARK-56973][PYTHON] Consolidate verify_pandas_result with verify_arrow_result via shared helper
9e85d0684ae7 is described below

commit 9e85d0684ae7f6f18cdf7f4a4b30ce7173ae78e0
Author: Yicong Huang <[email protected]>
AuthorDate: Fri May 22 13:29:16 2026 -0700

    [SPARK-56973][PYTHON] Consolidate verify_pandas_result with verify_arrow_result via shared helper
    
    ### What changes were proposed in this pull request?
    
    Consolidate the two UDF-result-verification paths in `worker.py`:
    
    - Extract `_verify_column_schema(actual_names, expected_names, *, 
assign_cols_by_name)` that raises `RESULT_COLUMN_NAMES_MISMATCH` (by-name) or 
`RESULT_COLUMN_SCHEMA_MISMATCH` (by-position).
    - `verify_pandas_result` now uses `verify_return_type` for the container 
check and the new helper for the schema check.
    - `verify_arrow_result` uses the same helper, keeping only its own 
`RESULT_COLUMN_TYPES_MISMATCH` check inline.
    - Fix `verify_return_type` to derive the top-level package (`pandas` 
instead of `pandas.core` for `pd.DataFrame`).
    
    ### Why are the changes needed?
    
    After SPARK-56937 added the column-count check to the arrow path, the 
pandas and arrow verifiers raise the same set of error classes but duplicate 
the name/count logic. A shared helper prevents drift as more pandas eval types 
are refactored under SPARK-55388.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No. The same error classes are raised under the same conditions, with the
    same `messageParameters`.
    
    ### How was this patch tested?
    
    Existing tests. Verified locally: `test_pandas_map`, 
`test_pandas_grouped_map`, `test_pandas_cogrouped_map`, 
`test_arrow_grouped_map`, `test_arrow_cogrouped_map`, and 
`test_udtf::LegacyUDTFArrowTests` all pass.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #56021 from Yicong-Huang/refactor/consolidate-verify-pandas-result.
    
    Authored-by: Yicong Huang <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/worker.py | 219 ++++++++++++++++++++++-------------------------
 1 file changed, 100 insertions(+), 119 deletions(-)

diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 306d4f80dbe5..c894343b752a 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -44,6 +44,9 @@ from typing import (
 T = TypeVar("T")
 
 if TYPE_CHECKING:
+    import pandas as pd
+    import pyarrow as pa
+
     from pyspark.sql.pandas._typing import GroupedBatch
 
 from pyspark.accumulators import (
@@ -256,8 +259,7 @@ def verify_return_type(result: T, expected_type: Type[T]) 
-> T:
     """
     if get_origin(expected_type) is Iterator:
         (element_type,) = get_args(expected_type)
-        package = getattr(inspect.getmodule(element_type), "__package__", "")
-        label = f"iterator of {package}.{element_type.__name__}"
+        label = f"iterator of 
{_top_level_package(element_type)}.{element_type.__name__}"
 
         if not isinstance(result, Iterator):
             raise PySparkTypeError(
@@ -279,17 +281,21 @@ def verify_return_type(result: T, expected_type: Type[T]) 
-> T:
         return map(check_element, result)  # type: ignore[return-value]
 
     if not isinstance(result, expected_type):
-        package = getattr(inspect.getmodule(expected_type), "__package__", "")
         raise PySparkTypeError(
             errorClass="UDF_RETURN_TYPE",
             messageParameters={
-                "expected": f"{package}.{expected_type.__name__}",
+                "expected": 
f"{_top_level_package(expected_type)}.{expected_type.__name__}",
                 "actual": type(result).__name__,
             },
         )
     return result
 
 
+def _top_level_package(t: type) -> str:
+    """Return the top-level package of ``t`` (``pandas`` for 
``pd.DataFrame``)."""
+    return (t.__module__ or "").split(".", 1)[0]
+
+
 def verify_result_row_count(result_length: int, expected: int) -> None:
     """Raise if the result row count doesn't match the expected input row 
count."""
     if result_length != expected:
@@ -429,64 +435,59 @@ def wrap_pandas_batch_iter_udf(f, return_type, 
runner_conf):
     )
 
 
-def verify_pandas_result(result, return_type, assign_cols_by_name, 
truncate_return_schema):
-    import pandas as pd
-
-    if isinstance(return_type, StructType):
-        if not isinstance(result, pd.DataFrame):
-            raise PySparkTypeError(
-                errorClass="UDF_RETURN_TYPE",
+def _verify_column_schema(
+    actual_names: list, expected_names: list, *, assign_cols_by_name: bool
+) -> None:
+    """Check column names (by-name) or count (by-position) match the expected 
schema."""
+    if assign_cols_by_name:
+        actual_set = set(actual_names)
+        expected_set = set(expected_names)
+        missing = sorted(expected_set.difference(actual_set))
+        extra = sorted(actual_set.difference(expected_set))
+        if missing or extra:
+            raise PySparkRuntimeError(
+                errorClass="RESULT_COLUMN_NAMES_MISMATCH",
                 messageParameters={
-                    "expected": "pandas.DataFrame",
-                    "actual": type(result).__name__,
+                    "missing": f" Missing: {', '.join(missing)}." if missing 
else "",
+                    "extra": f" Unexpected: {', '.join(extra)}." if extra else 
"",
                 },
             )
+    elif len(actual_names) != len(expected_names):
+        raise PySparkRuntimeError(
+            errorClass="RESULT_COLUMN_SCHEMA_MISMATCH",
+            messageParameters={
+                "expected": str(len(expected_names)),
+                "actual": str(len(actual_names)),
+            },
+        )
 
-        # check the schema of the result only if it is not empty or has columns
-        if not result.empty or len(result.columns) != 0:
-            # if any column name of the result is a string
-            # the column names of the result have to match the return type
-            #   see create_array in 
pyspark.sql.pandas.serializers.ArrowStreamPandasSerializer
-            field_names = set([field.name for field in return_type.fields])
-            # only the first len(field_names) result columns are considered
-            # when truncating the return schema
-            result_columns = (
-                result.columns[: len(field_names)] if truncate_return_schema 
else result.columns
-            )
-            column_names = set(result_columns)
-            if (
-                assign_cols_by_name
-                and any(isinstance(name, str) for name in result.columns)
-                and column_names != field_names
-            ):
-                missing = sorted(list(field_names.difference(column_names)))
-                missing = f" Missing: {', '.join(missing)}." if missing else ""
 
-                extra = sorted(list(column_names.difference(field_names)))
-                extra = f" Unexpected: {', '.join(extra)}." if extra else ""
+def verify_pandas_result(
+    result: Union["pd.DataFrame", "pd.Series"],
+    return_type: DataType,
+    assign_cols_by_name: bool,
+    truncate_return_schema: bool,
+) -> None:
+    import pandas as pd
 
-                raise PySparkRuntimeError(
-                    errorClass="RESULT_COLUMN_NAMES_MISMATCH",
-                    messageParameters={
-                        "missing": missing,
-                        "extra": extra,
-                    },
-                )
-            # otherwise the number of columns of result have to match the 
return type
-            elif len(result_columns) != len(return_type):
-                raise PySparkRuntimeError(
-                    errorClass="RESULT_COLUMN_SCHEMA_MISMATCH",
-                    messageParameters={
-                        "expected": str(len(return_type)),
-                        "actual": str(len(result.columns)),
-                    },
-                )
-    else:
-        if not isinstance(result, pd.Series):
-            raise PySparkTypeError(
-                errorClass="UDF_RETURN_TYPE",
-                messageParameters={"expected": "pandas.Series", "actual": 
type(result).__name__},
-            )
+    if not isinstance(return_type, StructType):
+        verify_return_type(result, pd.Series)
+        return
+
+    verify_return_type(result, pd.DataFrame)
+
+    # Skip schema check on a fully empty result (no rows and no columns).
+    if result.empty and len(result.columns) == 0:
+        return
+
+    field_names = [field.name for field in return_type.fields]
+    actual_names = (
+        list(result.columns[: len(field_names)]) if truncate_return_schema 
else list(result.columns)
+    )
+    # By-name mode only applies when the result has string column names;
+    # a numeric RangeIndex falls back to a by-position count check.
+    by_name = assign_cols_by_name and any(isinstance(n, str) for n in 
result.columns)
+    _verify_column_schema(actual_names, field_names, 
assign_cols_by_name=by_name)
 
 
 def wrap_cogrouped_map_pandas_udf(f, return_type, argspec, runner_conf):
@@ -511,73 +512,53 @@ def wrap_cogrouped_map_pandas_udf(f, return_type, 
argspec, runner_conf):
     return lambda kl, vl, kr, vr: [(wrapped(kl, vl, kr, vr), return_type)]
 
 
-def verify_arrow_result(result, assign_cols_by_name, expected_cols_and_types):
-    # the types of the fields have to be identical to return type
-    # an empty table can have no columns; if there are columns, they have to 
match
-    if result.num_columns != 0 or result.num_rows != 0:
-        # columns are either mapped by name or position
-        if assign_cols_by_name:
-            actual_cols_and_types = {
-                name: dataType for name, dataType in zip(result.schema.names, 
result.schema.types)
-            }
-            missing = sorted(
-                
list(set(expected_cols_and_types.keys()).difference(actual_cols_and_types.keys()))
-            )
-            extra = sorted(
-                
list(set(actual_cols_and_types.keys()).difference(expected_cols_and_types.keys()))
-            )
-
-            if missing or extra:
-                missing = f" Missing: {', '.join(missing)}." if missing else ""
-                extra = f" Unexpected: {', '.join(extra)}." if extra else ""
-
-                raise PySparkRuntimeError(
-                    errorClass="RESULT_COLUMN_NAMES_MISMATCH",
-                    messageParameters={
-                        "missing": missing,
-                        "extra": extra,
-                    },
-                )
+def verify_arrow_result(
+    result: Union["pa.Table", "pa.RecordBatch"],
+    assign_cols_by_name: bool,
+    expected_cols_and_types: Union[dict[str, "pa.DataType"], list[tuple[str, 
"pa.DataType"]]],
+) -> None:
+    # Skip schema check on a fully empty result (no rows and no columns).
+    if result.num_columns == 0 and result.num_rows == 0:
+        return
+
+    actual_names = list(result.schema.names)
+    actual_types = list(result.schema.types)
+    # expected_cols_and_types is a dict in by-name mode, list of (name, type) 
by position.
+    if isinstance(expected_cols_and_types, dict):
+        expected_names = list(expected_cols_and_types.keys())
+    else:
+        expected_names = [name for name, _ in expected_cols_and_types]
 
-            column_types = [
-                (name, expected_cols_and_types[name], 
actual_cols_and_types[name])
-                for name in sorted(expected_cols_and_types.keys())
-            ]
-        else:
-            actual_cols_and_types = [
-                (name, dataType) for name, dataType in 
zip(result.schema.names, result.schema.types)
-            ]
-            if len(actual_cols_and_types) != len(expected_cols_and_types):
-                raise PySparkRuntimeError(
-                    errorClass="RESULT_COLUMN_SCHEMA_MISMATCH",
-                    messageParameters={
-                        "expected": str(len(expected_cols_and_types)),
-                        "actual": str(len(actual_cols_and_types)),
-                    },
-                )
-            column_types = [
-                (expected_name, expected_type, actual_type)
-                for (expected_name, expected_type), (actual_name, actual_type) 
in zip(
-                    expected_cols_and_types, actual_cols_and_types
-                )
-            ]
+    _verify_column_schema(actual_names, expected_names, 
assign_cols_by_name=assign_cols_by_name)
 
-        type_mismatch = [
-            (name, expected, actual)
-            for name, expected, actual in column_types
-            if actual != expected
+    if isinstance(expected_cols_and_types, dict):
+        actual_by_name = dict(zip(actual_names, actual_types))
+        column_types = [
+            (name, expected_cols_and_types[name], actual_by_name[name])
+            for name in sorted(expected_cols_and_types.keys())
         ]
-
-        if type_mismatch:
-            raise PySparkRuntimeError(
-                errorClass="RESULT_COLUMN_TYPES_MISMATCH",
-                messageParameters={
-                    "mismatch": ", ".join(
-                        "column '{}' (expected {}, actual {})".format(name, 
expected, actual)
-                        for name, expected, actual in type_mismatch
-                    )
-                },
+    else:
+        column_types = [
+            (expected_name, expected_type, actual_type)
+            for (expected_name, expected_type), actual_type in zip(
+                expected_cols_and_types, actual_types
             )
+        ]
+
+    type_mismatch = [
+        (name, expected, actual) for name, expected, actual in column_types if 
actual != expected
+    ]
+
+    if type_mismatch:
+        raise PySparkRuntimeError(
+            errorClass="RESULT_COLUMN_TYPES_MISMATCH",
+            messageParameters={
+                "mismatch": ", ".join(
+                    "column '{}' (expected {}, actual {})".format(name, 
expected, actual)
+                    for name, expected, actual in type_mismatch
+                )
+            },
+        )
 
 
 def wrap_grouped_transform_with_state_pandas_udf(f, return_type, runner_conf):


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to