This is an automated email from the ASF dual-hosted git repository.
HyukjinKwon pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.x by this push:
new 78a4c5474b5b [SPARK-56973][PYTHON] Consolidate verify_pandas_result
with verify_arrow_result via shared helper
78a4c5474b5b is described below
commit 78a4c5474b5b3ef8043c973c5153cdb8079d0a79
Author: Yicong Huang <[email protected]>
AuthorDate: Fri May 22 13:29:16 2026 -0700
[SPARK-56973][PYTHON] Consolidate verify_pandas_result with
verify_arrow_result via shared helper
### What changes were proposed in this pull request?
Consolidate the two UDF-result-verification paths in `worker.py`:
- Extract `_verify_column_schema(actual_names, expected_names, *,
assign_cols_by_name)` that raises `RESULT_COLUMN_NAMES_MISMATCH` (by-name) or
`RESULT_COLUMN_SCHEMA_MISMATCH` (by-position).
- `verify_pandas_result` now uses `verify_return_type` for the container-type
check (`pd.DataFrame` for struct return types, `pd.Series` otherwise) and the
new helper for the schema check.
- `verify_arrow_result` uses the same helper, keeping only its own
`RESULT_COLUMN_TYPES_MISMATCH` check inline.
- Fix `verify_return_type` to use the type's top-level package in the
expected-type label of its `UDF_RETURN_TYPE` error (`pandas` instead of
`pandas.core` for `pd.DataFrame`).
### Why are the changes needed?
After SPARK-56937 added the column-count check to the arrow path, the
pandas and arrow verifiers raise the same set of error classes but duplicate
the name/count logic. A shared helper prevents drift as more pandas eval types
are refactored under SPARK-55388.
### Does this PR introduce _any_ user-facing change?
No. The same error classes are raised under the same conditions, with the same
`messageParameters`.
### How was this patch tested?
Covered by existing tests. Verified locally that `test_pandas_map`,
`test_pandas_grouped_map`, `test_pandas_cogrouped_map`,
`test_arrow_grouped_map`, `test_arrow_cogrouped_map`, and
`test_udtf::LegacyUDTFArrowTests` all pass.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #56021 from Yicong-Huang/refactor/consolidate-verify-pandas-result.
Authored-by: Yicong Huang <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
(cherry picked from commit 9e85d0684ae7f6f18cdf7f4a4b30ce7173ae78e0)
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/worker.py | 219 ++++++++++++++++++++++-------------------------
1 file changed, 100 insertions(+), 119 deletions(-)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index f6f1913edc8e..95980a6842ba 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -44,6 +44,9 @@ from typing import (
T = TypeVar("T")
if TYPE_CHECKING:
+ import pandas as pd
+ import pyarrow as pa
+
from pyspark.sql.pandas._typing import GroupedBatch
from pyspark.accumulators import (
@@ -256,8 +259,7 @@ def verify_return_type(result: T, expected_type: Type[T])
-> T:
"""
if get_origin(expected_type) is Iterator:
(element_type,) = get_args(expected_type)
- package = getattr(inspect.getmodule(element_type), "__package__", "")
- label = f"iterator of {package}.{element_type.__name__}"
+ label = f"iterator of
{_top_level_package(element_type)}.{element_type.__name__}"
if not isinstance(result, Iterator):
raise PySparkTypeError(
@@ -279,17 +281,21 @@ def verify_return_type(result: T, expected_type: Type[T])
-> T:
return map(check_element, result) # type: ignore[return-value]
if not isinstance(result, expected_type):
- package = getattr(inspect.getmodule(expected_type), "__package__", "")
raise PySparkTypeError(
errorClass="UDF_RETURN_TYPE",
messageParameters={
- "expected": f"{package}.{expected_type.__name__}",
+ "expected":
f"{_top_level_package(expected_type)}.{expected_type.__name__}",
"actual": type(result).__name__,
},
)
return result
+def _top_level_package(t: type) -> str:
+ """Return the top-level package of ``t`` (``pandas`` for
``pd.DataFrame``)."""
+ return (t.__module__ or "").split(".", 1)[0]
+
+
def verify_result_row_count(result_length: int, expected: int) -> None:
"""Raise if the result row count doesn't match the expected input row
count."""
if result_length != expected:
@@ -465,64 +471,59 @@ def wrap_pandas_batch_iter_udf(f, return_type,
runner_conf):
)
-def verify_pandas_result(result, return_type, assign_cols_by_name,
truncate_return_schema):
- import pandas as pd
-
- if isinstance(return_type, StructType):
- if not isinstance(result, pd.DataFrame):
- raise PySparkTypeError(
- errorClass="UDF_RETURN_TYPE",
+def _verify_column_schema(
+ actual_names: list, expected_names: list, *, assign_cols_by_name: bool
+) -> None:
+ """Check column names (by-name) or count (by-position) match the expected
schema."""
+ if assign_cols_by_name:
+ actual_set = set(actual_names)
+ expected_set = set(expected_names)
+ missing = sorted(expected_set.difference(actual_set))
+ extra = sorted(actual_set.difference(expected_set))
+ if missing or extra:
+ raise PySparkRuntimeError(
+ errorClass="RESULT_COLUMN_NAMES_MISMATCH",
messageParameters={
- "expected": "pandas.DataFrame",
- "actual": type(result).__name__,
+ "missing": f" Missing: {', '.join(missing)}." if missing
else "",
+ "extra": f" Unexpected: {', '.join(extra)}." if extra else
"",
},
)
+ elif len(actual_names) != len(expected_names):
+ raise PySparkRuntimeError(
+ errorClass="RESULT_COLUMN_SCHEMA_MISMATCH",
+ messageParameters={
+ "expected": str(len(expected_names)),
+ "actual": str(len(actual_names)),
+ },
+ )
- # check the schema of the result only if it is not empty or has columns
- if not result.empty or len(result.columns) != 0:
- # if any column name of the result is a string
- # the column names of the result have to match the return type
- # see create_array in
pyspark.sql.pandas.serializers.ArrowStreamPandasSerializer
- field_names = set([field.name for field in return_type.fields])
- # only the first len(field_names) result columns are considered
- # when truncating the return schema
- result_columns = (
- result.columns[: len(field_names)] if truncate_return_schema
else result.columns
- )
- column_names = set(result_columns)
- if (
- assign_cols_by_name
- and any(isinstance(name, str) for name in result.columns)
- and column_names != field_names
- ):
- missing = sorted(list(field_names.difference(column_names)))
- missing = f" Missing: {', '.join(missing)}." if missing else ""
- extra = sorted(list(column_names.difference(field_names)))
- extra = f" Unexpected: {', '.join(extra)}." if extra else ""
+def verify_pandas_result(
+ result: Union["pd.DataFrame", "pd.Series"],
+ return_type: DataType,
+ assign_cols_by_name: bool,
+ truncate_return_schema: bool,
+) -> None:
+ import pandas as pd
- raise PySparkRuntimeError(
- errorClass="RESULT_COLUMN_NAMES_MISMATCH",
- messageParameters={
- "missing": missing,
- "extra": extra,
- },
- )
- # otherwise the number of columns of result have to match the
return type
- elif len(result_columns) != len(return_type):
- raise PySparkRuntimeError(
- errorClass="RESULT_COLUMN_SCHEMA_MISMATCH",
- messageParameters={
- "expected": str(len(return_type)),
- "actual": str(len(result.columns)),
- },
- )
- else:
- if not isinstance(result, pd.Series):
- raise PySparkTypeError(
- errorClass="UDF_RETURN_TYPE",
- messageParameters={"expected": "pandas.Series", "actual":
type(result).__name__},
- )
+ if not isinstance(return_type, StructType):
+ verify_return_type(result, pd.Series)
+ return
+
+ verify_return_type(result, pd.DataFrame)
+
+ # Skip schema check on a fully empty result (no rows and no columns).
+ if result.empty and len(result.columns) == 0:
+ return
+
+ field_names = [field.name for field in return_type.fields]
+ actual_names = (
+ list(result.columns[: len(field_names)]) if truncate_return_schema
else list(result.columns)
+ )
+ # By-name mode only applies when the result has string column names;
+ # a numeric RangeIndex falls back to a by-position count check.
+ by_name = assign_cols_by_name and any(isinstance(n, str) for n in
result.columns)
+ _verify_column_schema(actual_names, field_names,
assign_cols_by_name=by_name)
def wrap_cogrouped_map_pandas_udf(f, return_type, argspec, runner_conf):
@@ -547,73 +548,53 @@ def wrap_cogrouped_map_pandas_udf(f, return_type,
argspec, runner_conf):
return lambda kl, vl, kr, vr: [(wrapped(kl, vl, kr, vr), return_type)]
-def verify_arrow_result(result, assign_cols_by_name, expected_cols_and_types):
- # the types of the fields have to be identical to return type
- # an empty table can have no columns; if there are columns, they have to
match
- if result.num_columns != 0 or result.num_rows != 0:
- # columns are either mapped by name or position
- if assign_cols_by_name:
- actual_cols_and_types = {
- name: dataType for name, dataType in zip(result.schema.names,
result.schema.types)
- }
- missing = sorted(
-
list(set(expected_cols_and_types.keys()).difference(actual_cols_and_types.keys()))
- )
- extra = sorted(
-
list(set(actual_cols_and_types.keys()).difference(expected_cols_and_types.keys()))
- )
-
- if missing or extra:
- missing = f" Missing: {', '.join(missing)}." if missing else ""
- extra = f" Unexpected: {', '.join(extra)}." if extra else ""
-
- raise PySparkRuntimeError(
- errorClass="RESULT_COLUMN_NAMES_MISMATCH",
- messageParameters={
- "missing": missing,
- "extra": extra,
- },
- )
+def verify_arrow_result(
+ result: Union["pa.Table", "pa.RecordBatch"],
+ assign_cols_by_name: bool,
+ expected_cols_and_types: Union[dict[str, "pa.DataType"], list[tuple[str,
"pa.DataType"]]],
+) -> None:
+ # Skip schema check on a fully empty result (no rows and no columns).
+ if result.num_columns == 0 and result.num_rows == 0:
+ return
+
+ actual_names = list(result.schema.names)
+ actual_types = list(result.schema.types)
+ # expected_cols_and_types is a dict in by-name mode, list of (name, type)
by position.
+ if isinstance(expected_cols_and_types, dict):
+ expected_names = list(expected_cols_and_types.keys())
+ else:
+ expected_names = [name for name, _ in expected_cols_and_types]
- column_types = [
- (name, expected_cols_and_types[name],
actual_cols_and_types[name])
- for name in sorted(expected_cols_and_types.keys())
- ]
- else:
- actual_cols_and_types = [
- (name, dataType) for name, dataType in
zip(result.schema.names, result.schema.types)
- ]
- if len(actual_cols_and_types) != len(expected_cols_and_types):
- raise PySparkRuntimeError(
- errorClass="RESULT_COLUMN_SCHEMA_MISMATCH",
- messageParameters={
- "expected": str(len(expected_cols_and_types)),
- "actual": str(len(actual_cols_and_types)),
- },
- )
- column_types = [
- (expected_name, expected_type, actual_type)
- for (expected_name, expected_type), (actual_name, actual_type)
in zip(
- expected_cols_and_types, actual_cols_and_types
- )
- ]
+ _verify_column_schema(actual_names, expected_names,
assign_cols_by_name=assign_cols_by_name)
- type_mismatch = [
- (name, expected, actual)
- for name, expected, actual in column_types
- if actual != expected
+ if isinstance(expected_cols_and_types, dict):
+ actual_by_name = dict(zip(actual_names, actual_types))
+ column_types = [
+ (name, expected_cols_and_types[name], actual_by_name[name])
+ for name in sorted(expected_cols_and_types.keys())
]
-
- if type_mismatch:
- raise PySparkRuntimeError(
- errorClass="RESULT_COLUMN_TYPES_MISMATCH",
- messageParameters={
- "mismatch": ", ".join(
- "column '{}' (expected {}, actual {})".format(name,
expected, actual)
- for name, expected, actual in type_mismatch
- )
- },
+ else:
+ column_types = [
+ (expected_name, expected_type, actual_type)
+ for (expected_name, expected_type), actual_type in zip(
+ expected_cols_and_types, actual_types
)
+ ]
+
+ type_mismatch = [
+ (name, expected, actual) for name, expected, actual in column_types if
actual != expected
+ ]
+
+ if type_mismatch:
+ raise PySparkRuntimeError(
+ errorClass="RESULT_COLUMN_TYPES_MISMATCH",
+ messageParameters={
+ "mismatch": ", ".join(
+ "column '{}' (expected {}, actual {})".format(name,
expected, actual)
+ for name, expected, actual in type_mismatch
+ )
+ },
+ )
def wrap_grouped_transform_with_state_pandas_udf(f, return_type, runner_conf):
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]