Yicong-Huang opened a new pull request, #55530:
URL: https://github.com/apache/spark/pull/55530
### What changes were proposed in this pull request?
Refactor `ArrowBatchTransformer.enforce_schema` in
`python/pyspark/sql/conversion.py` to be the single entry point for Arrow
output schema enforcement, and migrate the three grouped/cogrouped map Arrow
UDF paths to use it.
Concretely:
1. `enforce_schema` is generalized:
   - Accepts `pa.RecordBatch` **or** `pa.Table` (returns the same container
     type).
   - Adds `reorder_by_name: bool = True` parameter.
     - `True` (default): match columns by name, reorder to target order,
       rename output to target names. Any missing/extra name raises
       `RESULT_COLUMN_NAMES_MISMATCH`.
     - `False`: match columns by position (names ignored), preserve input
       column names. Count mismatch raises `RESULT_COLUMN_SCHEMA_MISMATCH`.
   - Collects **all** missing/extra/type mismatches before raising, instead
     of raising on the first one (matches existing `verify_arrow_result`
     semantics).
   - Uses `PySparkRuntimeError` with `errorClass`
     (`RESULT_COLUMN_NAMES_MISMATCH` / `RESULT_COLUMN_TYPES_MISMATCH` /
     `RESULT_COLUMN_SCHEMA_MISMATCH`) instead of bare f-string
     `PySparkTypeError`. This gives the same friendly error format that
     `verify_arrow_result` already produced.
2. Migrate three Arrow UDF paths in `python/pyspark/worker.py` to use
   `enforce_schema(arrow_cast=False, reorder_by_name=runner_conf.assign_cols_by_name)`:
   - `wrap_cogrouped_map_arrow_udf` (`SQL_COGROUPED_MAP_ARROW_UDF`).
   - `SQL_GROUPED_MAP_ARROW_UDF` mapper.
   - `SQL_GROUPED_MAP_ARROW_ITER_UDF` mapper.

   This removes the separate "verify, then manually reorder by name" steps
   in grouped-map paths; `enforce_schema` handles both.
3. Delete the now-unused helpers `verify_arrow_table` and
`verify_arrow_batch`. The instance check (`pa.Table` / `pa.RecordBatch`) they
performed is inlined at the call site (still raises `UDF_RETURN_TYPE`).
4. The `SQL_ARROW_TABLE_UDF` path is out of scope for this PR (no benchmark
yet) and continues to use `verify_arrow_result` unchanged.
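
For reviewers who want the two matching modes side by side, here is a rough, pure-Python sketch of the semantics described in item 1. It is not the code in `conversion.py`: `match_columns`, `SchemaMismatch`, and the `(name, type)` tuples are hypothetical stand-ins for `enforce_schema`, `PySparkRuntimeError`, and Arrow fields, and it assumes the `arrow_cast=False` behavior (type differences are reported, not cast).

```python
from typing import List, Tuple

Column = Tuple[str, str]  # hypothetical (name, type) stand-in for an Arrow field


class SchemaMismatch(Exception):
    """Hypothetical stand-in for PySparkRuntimeError with an errorClass."""

    def __init__(self, error_class: str, detail: str) -> None:
        super().__init__(f"[{error_class}] {detail}")
        self.error_class = error_class


def match_columns(
    actual: List[Column], target: List[Column], reorder_by_name: bool = True
) -> List[Column]:
    """Arrange `actual` to line up with `target`, or raise listing every mismatch."""
    if reorder_by_name:
        actual_names = [name for name, _ in actual]
        target_names = [name for name, _ in target]
        missing = sorted(set(target_names) - set(actual_names))
        extra = sorted(set(actual_names) - set(target_names))
        if missing or extra:
            raise SchemaMismatch(
                "RESULT_COLUMN_NAMES_MISMATCH",
                f"missing={missing}, unexpected={extra}",
            )
        # Sketch only: duplicate column names are not handled here.
        by_name = dict(actual)
        ordered = [(name, by_name[name]) for name in target_names]
    else:
        if len(actual) != len(target):
            raise SchemaMismatch(
                "RESULT_COLUMN_SCHEMA_MISMATCH",
                f"expected {len(target)} columns, got {len(actual)}",
            )
        ordered = list(actual)  # positional mode keeps the input names

    # Collect every type mismatch before raising, rather than stopping at the first.
    bad = [
        f"{name}: expected {want}, got {got}"
        for (name, got), (_, want) in zip(ordered, target)
        if got != want
    ]
    if bad:
        raise SchemaMismatch("RESULT_COLUMN_TYPES_MISMATCH", "; ".join(bad))
    return ordered


target = [("id", "int64"), ("v", "double")]
by_name = match_columns([("v", "double"), ("id", "int64")], target)
by_position = match_columns(
    [("a", "int64"), ("b", "double")], target, reorder_by_name=False
)
```

In this sketch `by_name` comes back in target order, while `by_position` keeps the input names `a` and `b` because only position and type are compared.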
### Why are the changes needed?
Part of [SPARK-55388](https://issues.apache.org/jira/browse/SPARK-55388)
(Refactor PythonEvalType processing logic). Today, Arrow UDF output validation
is split across two places with inconsistent error formats:
- `verify_arrow_result` in `worker.py` raises friendly `errorClass` errors.
- `ArrowBatchTransformer.enforce_schema` in `conversion.py` raises bare
f-string `PySparkTypeError` errors.
Consolidating behind `enforce_schema` lets every Arrow UDF path share one
code path and one error-format convention, and drops the duplicate "verify +
reorder" work in grouped-map paths.
### Does this PR introduce _any_ user-facing change?
Yes, minor: the error messages for the `SQL_ARROW_UDTF` path (the only
pre-existing consumer of `enforce_schema` via `ArrowStreamArrowUDTFSerializer`)
change from the bare f-string format to the same friendly
`errorClass`-templated format already used by other Arrow UDFs:
- Before: `Result column 'x' does not exist in the output. Expected schema:
x: int32\ny: string, got: wrong_col: int32\nanother_wrong_col: double.`
- After: `[RESULT_COLUMN_NAMES_MISMATCH] Column names of the returned data
do not match specified schema. Missing: x, y. Unexpected: another_wrong_col,
wrong_col.`
For grouped/cogrouped map Arrow UDF paths, the error-class names and message
formats are unchanged (`verify_arrow_result` already used these).
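
To make the "After" format concrete, the snippet below rebuilds that message from the two column-name lists in the example. `names_mismatch_message` is a hypothetical helper, not a PySpark function, and sorting the names is an assumption inferred from the order shown in the message above.

```python
from typing import Iterable


def names_mismatch_message(expected: Iterable[str], actual: Iterable[str]) -> str:
    """Render a names-mismatch message in the errorClass-templated style (sketch)."""
    expected_set, actual_set = set(expected), set(actual)
    missing = sorted(expected_set - actual_set)  # in the schema, not returned
    extra = sorted(actual_set - expected_set)  # returned, not in the schema
    return (
        "[RESULT_COLUMN_NAMES_MISMATCH] Column names of the returned data "
        "do not match specified schema. "
        f"Missing: {', '.join(missing)}. Unexpected: {', '.join(extra)}."
    )


message = names_mismatch_message(["x", "y"], ["wrong_col", "another_wrong_col"])
print(message)
```

Unlike the old format, which reported only the first missing column, this style lists every missing and unexpected name in one message.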
### How was this patch tested?
**Unit tests** (`python/pyspark/sql/tests/test_conversion.py`):
- Existing `enforce_schema` tests updated to assert the new `errorClass`
(`PySparkTypeError` → `PySparkRuntimeError` with `getCondition` check).
- New tests added for: `reorder_by_name=True` reordering,
`reorder_by_name=False` positional matching, extra-column detection, positional
count mismatch, and `pa.Table` input/output.
**Existing integration tests**
(`python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py` and
`test_arrow_cogrouped_map.py`) already assert the `errorClass`-templated error
format; they continue to pass unchanged.
**Test regex** in `test_arrow_udtf.py` updated to match the new friendly
format for the two `SQL_ARROW_UDTF` error tests
(`test_arrow_udtf_error_mismatched_schema` and
`test_arrow_udtf_type_coercion_string_to_int`).
**ASV benchmarks** — ran `CogroupedMapArrowUDFTimeBench`,
`GroupedMapArrowUDFTimeBench`, and `GroupedMapArrowIterUDFTimeBench` with `-a
repeat=3` on both the baseline (upstream/master) and this PR. No regression
detected (`asv compare -f 1.05`):
```text
All benchmarks:
| Change | Before [d0edf1a7] | After [39dd966a] | Ratio | Benchmark (Parameter) |
|--------|-------------------|------------------|-------|-----------------------|
| | 139 +-3ms | 140 +-3ms | 1.00 | CogroupedMapArrowUDFTimeBench.time_worker('few_groups_lg', 'concat_udf') |
| | 91.8 +-3ms | 92.2 +-3ms | 1.00 | CogroupedMapArrowUDFTimeBench.time_worker('few_groups_lg', 'identity_udf') |
| | 307 +-6ms | 314 +-3ms | 1.02 | CogroupedMapArrowUDFTimeBench.time_worker('few_groups_lg', 'left_semi_udf') |
| | 27.2 +-2ms | 28.0 +-2ms | 1.03 | CogroupedMapArrowUDFTimeBench.time_worker('few_groups_sm', 'concat_udf') |
| | 17.3 +-1ms | 18.7 +-1ms | ~1.08 | CogroupedMapArrowUDFTimeBench.time_worker('few_groups_sm', 'identity_udf') |
| | 84.5 +-0.3ms | 85.3 +-1ms | 1.01 | CogroupedMapArrowUDFTimeBench.time_worker('few_groups_sm', 'left_semi_udf') |
| | 388 +-6ms | 386 +-4ms | 1.00 | CogroupedMapArrowUDFTimeBench.time_worker('many_groups_lg', 'concat_udf') |
| | 260 +-2ms | 269 +-0.7ms | 1.03 | CogroupedMapArrowUDFTimeBench.time_worker('many_groups_lg', 'identity_udf') |
| | 1.14 +-0.01s | 1.13 +-0.01s | 0.99 | CogroupedMapArrowUDFTimeBench.time_worker('many_groups_lg', 'left_semi_udf') |
| | 710 +-8ms | 724 +-4ms | 1.02 | CogroupedMapArrowUDFTimeBench.time_worker('many_groups_sm', 'concat_udf') |
| | 527 +-4ms | 540 +-2ms | 1.03 | CogroupedMapArrowUDFTimeBench.time_worker('many_groups_sm', 'identity_udf') |
| | 1.96 +-0.05s | 1.97 +-0.01s | 1.00 | CogroupedMapArrowUDFTimeBench.time_worker('many_groups_sm', 'left_semi_udf') |
| | 147 +-6ms | 146 +-7ms | 1.00 | CogroupedMapArrowUDFTimeBench.time_worker('multi_key', 'concat_udf') |
| | 96.9 +-0.5ms | 95.4 +-0.3ms | 0.98 | CogroupedMapArrowUDFTimeBench.time_worker('multi_key', 'identity_udf') |
| | 268 +-1ms | 266 +-2ms | 0.99 | CogroupedMapArrowUDFTimeBench.time_worker('multi_key', 'left_semi_udf') |
| | 539 +-8ms | 534 +-1ms | 0.99 | CogroupedMapArrowUDFTimeBench.time_worker('wide_values', 'concat_udf') |
| | 396 +-0.6ms | 403 +-2ms | 1.02 | CogroupedMapArrowUDFTimeBench.time_worker('wide_values', 'identity_udf') |
| | 785 +-10ms | 780 +-10ms | 0.99 | CogroupedMapArrowUDFTimeBench.time_worker('wide_values', 'left_semi_udf') |
| | 113 +-0.8ms | 115 +-3ms | 1.01 | GroupedMapArrowIterUDFTimeBench.time_worker('few_groups_lg', 'filter_udf') |
| | 72.3 +-2ms | 72.4 +-0.5ms | 1.00 | GroupedMapArrowIterUDFTimeBench.time_worker('few_groups_lg', 'identity_udf') |
| | 264 +-5ms | 265 +-5ms | 1.01 | GroupedMapArrowIterUDFTimeBench.time_worker('few_groups_lg', 'sort_udf') |
| | 19.4 +-1ms | 20.5 +-0.7ms | ~1.06 | GroupedMapArrowIterUDFTimeBench.time_worker('few_groups_sm', 'filter_udf') |
| | 13.5 +-2ms | 11.8 +-1ms | ~0.87 | GroupedMapArrowIterUDFTimeBench.time_worker('few_groups_sm', 'identity_udf') |
| | 33.8 +-1ms | 33.8 +-1ms | 1.00 | GroupedMapArrowIterUDFTimeBench.time_worker('few_groups_sm', 'sort_udf') |
| | 316 +-2ms | 312 +-3ms | 0.99 | GroupedMapArrowIterUDFTimeBench.time_worker('many_groups_lg', 'filter_udf') |
| | 205 +-2ms | 206 +-2ms | 1.00 | GroupedMapArrowIterUDFTimeBench.time_worker('many_groups_lg', 'identity_udf') |
| | 610 +-2ms | 609 +-2ms | 1.00 | GroupedMapArrowIterUDFTimeBench.time_worker('many_groups_lg', 'sort_udf') |
| | 496 +-3ms | 506 +-5ms | 1.02 | GroupedMapArrowIterUDFTimeBench.time_worker('many_groups_sm', 'filter_udf') |
| | 380 +-9ms | 384 +-3ms | 1.01 | GroupedMapArrowIterUDFTimeBench.time_worker('many_groups_sm', 'identity_udf') |
| | 595 +-9ms | 605 +-9ms | 1.02 | GroupedMapArrowIterUDFTimeBench.time_worker('many_groups_sm', 'sort_udf') |
| | 91.5 +-0.8ms | 92.8 +-0.6ms | 1.01 | GroupedMapArrowIterUDFTimeBench.time_worker('multi_key', 'filter_udf') |
| | 62.3 +-0.5ms | 63.6 +-0.8ms | 1.02 | GroupedMapArrowIterUDFTimeBench.time_worker('multi_key', 'identity_udf') |
| | 105 +-0.8ms | 110 +-5ms | ~1.05 | GroupedMapArrowIterUDFTimeBench.time_worker('multi_key', 'sort_udf') |
| | 375 +-9ms | 372 +-2ms | 0.99 | GroupedMapArrowIterUDFTimeBench.time_worker('wide_values', 'filter_udf') |
| | 269 +-2ms | 273 +-3ms | 1.02 | GroupedMapArrowIterUDFTimeBench.time_worker('wide_values', 'identity_udf') |
| | 439 +-2ms | 441 +-1ms | 1.01 | GroupedMapArrowIterUDFTimeBench.time_worker('wide_values', 'sort_udf') |
| | 113 +-0.7ms | 113 +-0.3ms | 1.00 | GroupedMapArrowUDFTimeBench.time_worker('few_groups_lg', 'filter_udf') |
| | 77.1 +-3ms | 72.4 +-2ms | ~0.94 | GroupedMapArrowUDFTimeBench.time_worker('few_groups_lg', 'identity_udf') |
| | 269 +-4ms | 263 +-0.9ms | 0.98 | GroupedMapArrowUDFTimeBench.time_worker('few_groups_lg', 'sort_udf') |
| | 20.8 +-1ms | 20.2 +-1ms | 0.97 | GroupedMapArrowUDFTimeBench.time_worker('few_groups_sm', 'filter_udf') |
| | 14.1 +-2ms | 15.6 +-0.7ms | ~1.10 | GroupedMapArrowUDFTimeBench.time_worker('few_groups_sm', 'identity_udf') |
| | 35.3 +-1ms | 36.7 +-1ms | 1.04 | GroupedMapArrowUDFTimeBench.time_worker('few_groups_sm', 'sort_udf') |
| | 342 +-4ms | 324 +-2ms | ~0.95 | GroupedMapArrowUDFTimeBench.time_worker('many_groups_lg', 'filter_udf') |
| | 212 +-1ms | 210 +-2ms | 0.99 | GroupedMapArrowUDFTimeBench.time_worker('many_groups_lg', 'identity_udf') |
| | 629 +-10ms | 613 +-4ms | 0.97 | GroupedMapArrowUDFTimeBench.time_worker('many_groups_lg', 'sort_udf') |
| | 548 +-8ms | 541 +-2ms | 0.99 | GroupedMapArrowUDFTimeBench.time_worker('many_groups_sm', 'filter_udf') |
| | 405 +-2ms | 402 +-1ms | 0.99 | GroupedMapArrowUDFTimeBench.time_worker('many_groups_sm', 'identity_udf') |
| | 619 +-7ms | 662 +-30ms | ~1.07 | GroupedMapArrowUDFTimeBench.time_worker('many_groups_sm', 'sort_udf') |
| | 96.6 +-0.3ms | 96.2 +-0.5ms | 1.00 | GroupedMapArrowUDFTimeBench.time_worker('multi_key', 'filter_udf') |
| | 65.5 +-0.2ms | 65.1 +-0.6ms | 0.99 | GroupedMapArrowUDFTimeBench.time_worker('multi_key', 'identity_udf') |
| | 108 +-0.7ms | 109 +-0.6ms | 1.01 | GroupedMapArrowUDFTimeBench.time_worker('multi_key', 'sort_udf') |
| | 377 +-0.3ms | 380 +-0.8ms | 1.01 | GroupedMapArrowUDFTimeBench.time_worker('wide_values', 'filter_udf') |
| | 272 +-2ms | 281 +-2ms | 1.03 | GroupedMapArrowUDFTimeBench.time_worker('wide_values', 'identity_udf') |
| | 444 +-2ms | 449 +-0.5ms | 1.01 | GroupedMapArrowUDFTimeBench.time_worker('wide_values', 'sort_udf') |
```
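Tilde (`~`) markers flag ratios that fall outside the threshold but that asv
does not consider statistically significant. Several of them are on
micro-benchmarks (around 20ms or less), where a +/-1-2ms stddev inflates the
ratio. No benchmark is flagged as a regression at the 5% threshold
(`-f 1.05`).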
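
As a quick sanity check on the ratio column, the sketch below recomputes two rows from the rounded table values and applies the 1.05 factor in both directions. `outside_factor` is a hypothetical helper written for this description; it only mirrors the factor comparison as I understand it and does not reproduce the statistical test asv uses to decide whether a change is significant.

```python
def ratio(before_ms: float, after_ms: float) -> float:
    """Ratio as shown in the table: after divided by before."""
    return after_ms / before_ms


def outside_factor(before_ms: float, after_ms: float, factor: float = 1.05) -> bool:
    """True when the change exceeds the factor in either direction (sketch)."""
    return after_ms > before_ms * factor or after_ms < before_ms / factor


# CogroupedMapArrowUDFTimeBench ('few_groups_sm', 'identity_udf'): 17.3ms vs 18.7ms.
# A 1.4ms delta on a ~17ms benchmark is already past the 5% factor.
small = (ratio(17.3, 18.7), outside_factor(17.3, 18.7))

# CogroupedMapArrowUDFTimeBench ('many_groups_sm', 'concat_udf'): 710ms vs 724ms.
# A 14ms delta on a ~710ms benchmark stays well inside it.
large = (ratio(710.0, 724.0), outside_factor(710.0, 724.0))

print(f"{small[0]:.2f} {small[1]}")
print(f"{large[0]:.2f} {large[1]}")
```

The absolute noise is similar in both rows (the reported stddev is 1ms and 4-8ms respectively), but only the short benchmark crosses the factor, which is why it shows up with a tilde.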
### Was this patch authored or co-authored using generative AI tooling?
No.