Yicong-Huang opened a new pull request, #55530:
URL: https://github.com/apache/spark/pull/55530

   ### What changes were proposed in this pull request?
   
   Refactor `ArrowBatchTransformer.enforce_schema` in 
`python/pyspark/sql/conversion.py` to be the single entry point for Arrow 
output schema enforcement, and migrate the three grouped/cogrouped map Arrow 
UDF paths to use it.
   
   Concretely:
   
   1. `enforce_schema` is generalized:
      - Accepts `pa.RecordBatch` **or** `pa.Table` (returns the same container 
type).
      - Adds `reorder_by_name: bool = True` parameter.
        - `True` (default): match columns by name, reorder to target order, 
rename output to target names. Any missing/extra name raises 
`RESULT_COLUMN_NAMES_MISMATCH`.
        - `False`: match columns by position (names ignored), preserve input 
column names. Count mismatch raises `RESULT_COLUMN_SCHEMA_MISMATCH`.
      - Collects **all** missing/extra/type mismatches before raising, instead 
of raising on the first one (matches existing `verify_arrow_result` semantics).
      - Uses `PySparkRuntimeError` with `errorClass` 
(`RESULT_COLUMN_NAMES_MISMATCH` / `RESULT_COLUMN_TYPES_MISMATCH` / 
`RESULT_COLUMN_SCHEMA_MISMATCH`) instead of bare f-string `PySparkTypeError`. 
This gives the same friendly error format that `verify_arrow_result` already 
produced.
   
   2. Migrate three Arrow UDF paths in `python/pyspark/worker.py` to use 
`enforce_schema(arrow_cast=False, 
reorder_by_name=runner_conf.assign_cols_by_name)`:
      - `wrap_cogrouped_map_arrow_udf` (`SQL_COGROUPED_MAP_ARROW_UDF`).
      - `SQL_GROUPED_MAP_ARROW_UDF` mapper.
      - `SQL_GROUPED_MAP_ARROW_ITER_UDF` mapper.
   
      This removes the separate "verify, then manually reorder by name" steps 
in grouped-map paths; `enforce_schema` handles both.
   
   3. Delete the now-unused helpers `verify_arrow_table` and 
`verify_arrow_batch`. The instance check (`pa.Table` / `pa.RecordBatch`) they 
performed is now inlined at the call site (still raises `UDF_RETURN_TYPE`).
   
   4. The `SQL_ARROW_TABLE_UDF` path is out of scope for this PR (no benchmark 
yet) and still uses `verify_arrow_result` as before, unchanged.
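
The two matching modes described in item 1 can be sketched in plain Python. This is only an illustrative model, not the code in `conversion.py`: schemas are represented as lists of `(name, type)` tuples instead of Arrow schemas, and `SchemaMismatch`, `enforce_schema_sketch`, and the detail strings are hypothetical names and formats chosen for the example. Types are compared strictly, mirroring the `arrow_cast=False` setting used by the migrated paths.

```python
from typing import List, Sequence, Tuple

Column = Tuple[str, str]  # (column name, type name); stand-in for an Arrow field


class SchemaMismatch(Exception):
    """Hypothetical stand-in for PySparkRuntimeError carrying an error class."""

    def __init__(self, error_class: str, detail: str) -> None:
        super().__init__(f"[{error_class}] {detail}")
        self.error_class = error_class


def enforce_schema_sketch(
    actual: Sequence[Column],
    target: Sequence[Column],
    reorder_by_name: bool = True,
) -> List[Column]:
    """Return the columns of `actual` arranged to satisfy `target`."""
    if reorder_by_name:
        actual_names = {name for name, _ in actual}
        target_names = [name for name, _ in target]
        # Collect every missing and extra name before raising.
        missing = sorted(set(target_names) - actual_names)
        extra = sorted(actual_names - set(target_names))
        if missing or extra:
            raise SchemaMismatch(
                "RESULT_COLUMN_NAMES_MISMATCH",
                f"missing={missing}, unexpected={extra}",
            )
        by_name = dict(actual)
        # Reorder to the target order; output carries the target names.
        ordered = [(name, by_name[name]) for name in target_names]
    else:
        # Positional matching: names are ignored, only the count must agree.
        if len(actual) != len(target):
            raise SchemaMismatch(
                "RESULT_COLUMN_SCHEMA_MISMATCH",
                f"expected {len(target)} columns, got {len(actual)}",
            )
        ordered = list(actual)  # input column names are preserved

    # Collect every type mismatch before raising (no casting in this sketch).
    mismatches = [
        f"{name}: expected {want}, got {have}"
        for (name, have), (_, want) in zip(ordered, target)
        if have != want
    ]
    if mismatches:
        raise SchemaMismatch("RESULT_COLUMN_TYPES_MISMATCH", "; ".join(mismatches))
    return ordered


target = [("x", "int32"), ("y", "string")]
by_name = enforce_schema_sketch([("y", "string"), ("x", "int32")], target)
by_position = enforce_schema_sketch(
    [("a", "int32"), ("b", "string")], target, reorder_by_name=False
)
```

In this model, `by_name` comes back in target order, while `by_position` keeps the input names `a` and `b` because only position and type are checked.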
   
   ### Why are the changes needed?
   
   Part of [SPARK-55388](https://issues.apache.org/jira/browse/SPARK-55388) 
(Refactor PythonEvalType processing logic). Today, Arrow UDF output validation 
is split across two places with inconsistent error formats:
   
   - `verify_arrow_result` in `worker.py` raises friendly `errorClass` errors.
   - `ArrowBatchTransformer.enforce_schema` in `conversion.py` raises bare 
f-string `PySparkTypeError` errors.
   
   Consolidating behind `enforce_schema` lets every Arrow UDF path share one 
code path and one error-format convention, and drops the duplicate "verify + 
reorder" work in grouped-map paths.
   
   ### Does this PR introduce _any_ user-facing change?
   
   Yes, minor: the error messages for the `SQL_ARROW_UDTF` path (the only 
pre-existing consumer of `enforce_schema` via `ArrowStreamArrowUDTFSerializer`) 
change from the bare f-string format to the same friendly 
`errorClass`-templated format already used by other Arrow UDFs:
   
   - Before: `Result column 'x' does not exist in the output. Expected schema: 
x: int32\ny: string, got: wrong_col: int32\nanother_wrong_col: double.`
   - After: `[RESULT_COLUMN_NAMES_MISMATCH] Column names of the returned data 
do not match specified schema. Missing: x, y. Unexpected: another_wrong_col, 
wrong_col.`
   
   For grouped/cogrouped map Arrow UDF paths, the error-class names and message 
formats are unchanged (`verify_arrow_result` already used these).
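
To make the "After" format concrete, the sketch below rebuilds that message from the expected and returned column names. The helper `names_mismatch_message` is hypothetical and is not how PySpark renders its error templates. Sorting the names alphabetically is an assumption inferred from the example above, where `another_wrong_col` is listed before `wrong_col` even though the returned data had them in the opposite order.

```python
from typing import Iterable


def names_mismatch_message(expected: Iterable[str], actual: Iterable[str]) -> str:
    """Build a RESULT_COLUMN_NAMES_MISMATCH-style message (illustrative only)."""
    expected_set, actual_set = set(expected), set(actual)
    missing = sorted(expected_set - actual_set)      # in target, not returned
    unexpected = sorted(actual_set - expected_set)   # returned, not in target
    return (
        "[RESULT_COLUMN_NAMES_MISMATCH] Column names of the returned data "
        "do not match specified schema. "
        f"Missing: {', '.join(missing)}. Unexpected: {', '.join(unexpected)}."
    )


message = names_mismatch_message(["x", "y"], ["wrong_col", "another_wrong_col"])
print(message)
```

Unlike the old format, which reported only the first column that failed, the message names every missing and unexpected column at once. The real template may render differently when one of the two lists is empty; this sketch does not model that case.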
   
   ### How was this patch tested?
   
   **Unit tests** (`python/pyspark/sql/tests/test_conversion.py`):
   - Existing `enforce_schema` tests updated to assert the new `errorClass` 
(`PySparkTypeError` → `PySparkRuntimeError` with `getCondition` check).
   - New tests added for: `reorder_by_name=True` reordering, 
`reorder_by_name=False` positional matching, extra-column detection, positional 
count mismatch, and `pa.Table` input/output.
   
   **Existing integration tests** 
(`python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py` and 
`test_arrow_cogrouped_map.py`) already assert the `errorClass`-templated error 
format; they continue to pass unchanged.
   
   **Test regex** in `test_arrow_udtf.py` updated to match the new friendly 
format for the two `SQL_ARROW_UDTF` error tests 
(`test_arrow_udtf_error_mismatched_schema` and 
`test_arrow_udtf_type_coercion_string_to_int`).
   
   **ASV benchmarks** — ran `CogroupedMapArrowUDFTimeBench`, 
`GroupedMapArrowUDFTimeBench`, and `GroupedMapArrowIterUDFTimeBench` with `-a 
repeat=3` on both the baseline (upstream/master) and this PR. No regression 
detected (`asv compare -f 1.05`):
   
   ```text
   All benchmarks:
   
   | Change | Before [d0edf1a7] | After [39dd966a] | Ratio | Benchmark (Parameter) |
   |--------|-------------------|------------------|-------|-----------------------------------------------------------------------------------------------|
   |        | 139 +-3ms         | 140 +-3ms        | 1.00  | CogroupedMapArrowUDFTimeBench.time_worker('few_groups_lg', 'concat_udf') |
   |        | 91.8 +-3ms        | 92.2 +-3ms       | 1.00  | CogroupedMapArrowUDFTimeBench.time_worker('few_groups_lg', 'identity_udf') |
   |        | 307 +-6ms         | 314 +-3ms        | 1.02  | CogroupedMapArrowUDFTimeBench.time_worker('few_groups_lg', 'left_semi_udf') |
   |        | 27.2 +-2ms        | 28.0 +-2ms       | 1.03  | CogroupedMapArrowUDFTimeBench.time_worker('few_groups_sm', 'concat_udf') |
   |        | 17.3 +-1ms        | 18.7 +-1ms       | ~1.08 | CogroupedMapArrowUDFTimeBench.time_worker('few_groups_sm', 'identity_udf') |
   |        | 84.5 +-0.3ms      | 85.3 +-1ms       | 1.01  | CogroupedMapArrowUDFTimeBench.time_worker('few_groups_sm', 'left_semi_udf') |
   |        | 388 +-6ms         | 386 +-4ms        | 1.00  | CogroupedMapArrowUDFTimeBench.time_worker('many_groups_lg', 'concat_udf') |
   |        | 260 +-2ms         | 269 +-0.7ms      | 1.03  | CogroupedMapArrowUDFTimeBench.time_worker('many_groups_lg', 'identity_udf') |
   |        | 1.14 +-0.01s      | 1.13 +-0.01s     | 0.99  | CogroupedMapArrowUDFTimeBench.time_worker('many_groups_lg', 'left_semi_udf') |
   |        | 710 +-8ms         | 724 +-4ms        | 1.02  | CogroupedMapArrowUDFTimeBench.time_worker('many_groups_sm', 'concat_udf') |
   |        | 527 +-4ms         | 540 +-2ms        | 1.03  | CogroupedMapArrowUDFTimeBench.time_worker('many_groups_sm', 'identity_udf') |
   |        | 1.96 +-0.05s      | 1.97 +-0.01s     | 1.00  | CogroupedMapArrowUDFTimeBench.time_worker('many_groups_sm', 'left_semi_udf') |
   |        | 147 +-6ms         | 146 +-7ms        | 1.00  | CogroupedMapArrowUDFTimeBench.time_worker('multi_key', 'concat_udf') |
   |        | 96.9 +-0.5ms      | 95.4 +-0.3ms     | 0.98  | CogroupedMapArrowUDFTimeBench.time_worker('multi_key', 'identity_udf') |
   |        | 268 +-1ms         | 266 +-2ms        | 0.99  | CogroupedMapArrowUDFTimeBench.time_worker('multi_key', 'left_semi_udf') |
   |        | 539 +-8ms         | 534 +-1ms        | 0.99  | CogroupedMapArrowUDFTimeBench.time_worker('wide_values', 'concat_udf') |
   |        | 396 +-0.6ms       | 403 +-2ms        | 1.02  | CogroupedMapArrowUDFTimeBench.time_worker('wide_values', 'identity_udf') |
   |        | 785 +-10ms        | 780 +-10ms       | 0.99  | CogroupedMapArrowUDFTimeBench.time_worker('wide_values', 'left_semi_udf') |
   |        | 113 +-0.8ms       | 115 +-3ms        | 1.01  | GroupedMapArrowIterUDFTimeBench.time_worker('few_groups_lg', 'filter_udf') |
   |        | 72.3 +-2ms        | 72.4 +-0.5ms     | 1.00  | GroupedMapArrowIterUDFTimeBench.time_worker('few_groups_lg', 'identity_udf') |
   |        | 264 +-5ms         | 265 +-5ms        | 1.01  | GroupedMapArrowIterUDFTimeBench.time_worker('few_groups_lg', 'sort_udf') |
   |        | 19.4 +-1ms        | 20.5 +-0.7ms     | ~1.06 | GroupedMapArrowIterUDFTimeBench.time_worker('few_groups_sm', 'filter_udf') |
   |        | 13.5 +-2ms        | 11.8 +-1ms       | ~0.87 | GroupedMapArrowIterUDFTimeBench.time_worker('few_groups_sm', 'identity_udf') |
   |        | 33.8 +-1ms        | 33.8 +-1ms       | 1.00  | GroupedMapArrowIterUDFTimeBench.time_worker('few_groups_sm', 'sort_udf') |
   |        | 316 +-2ms         | 312 +-3ms        | 0.99  | GroupedMapArrowIterUDFTimeBench.time_worker('many_groups_lg', 'filter_udf') |
   |        | 205 +-2ms         | 206 +-2ms        | 1.00  | GroupedMapArrowIterUDFTimeBench.time_worker('many_groups_lg', 'identity_udf') |
   |        | 610 +-2ms         | 609 +-2ms        | 1.00  | GroupedMapArrowIterUDFTimeBench.time_worker('many_groups_lg', 'sort_udf') |
   |        | 496 +-3ms         | 506 +-5ms        | 1.02  | GroupedMapArrowIterUDFTimeBench.time_worker('many_groups_sm', 'filter_udf') |
   |        | 380 +-9ms         | 384 +-3ms        | 1.01  | GroupedMapArrowIterUDFTimeBench.time_worker('many_groups_sm', 'identity_udf') |
   |        | 595 +-9ms         | 605 +-9ms        | 1.02  | GroupedMapArrowIterUDFTimeBench.time_worker('many_groups_sm', 'sort_udf') |
   |        | 91.5 +-0.8ms      | 92.8 +-0.6ms     | 1.01  | GroupedMapArrowIterUDFTimeBench.time_worker('multi_key', 'filter_udf') |
   |        | 62.3 +-0.5ms      | 63.6 +-0.8ms     | 1.02  | GroupedMapArrowIterUDFTimeBench.time_worker('multi_key', 'identity_udf') |
   |        | 105 +-0.8ms       | 110 +-5ms        | ~1.05 | GroupedMapArrowIterUDFTimeBench.time_worker('multi_key', 'sort_udf') |
   |        | 375 +-9ms         | 372 +-2ms        | 0.99  | GroupedMapArrowIterUDFTimeBench.time_worker('wide_values', 'filter_udf') |
   |        | 269 +-2ms         | 273 +-3ms        | 1.02  | GroupedMapArrowIterUDFTimeBench.time_worker('wide_values', 'identity_udf') |
   |        | 439 +-2ms         | 441 +-1ms        | 1.01  | GroupedMapArrowIterUDFTimeBench.time_worker('wide_values', 'sort_udf') |
   |        | 113 +-0.7ms       | 113 +-0.3ms      | 1.00  | GroupedMapArrowUDFTimeBench.time_worker('few_groups_lg', 'filter_udf') |
   |        | 77.1 +-3ms        | 72.4 +-2ms       | ~0.94 | GroupedMapArrowUDFTimeBench.time_worker('few_groups_lg', 'identity_udf') |
   |        | 269 +-4ms         | 263 +-0.9ms      | 0.98  | GroupedMapArrowUDFTimeBench.time_worker('few_groups_lg', 'sort_udf') |
   |        | 20.8 +-1ms        | 20.2 +-1ms       | 0.97  | GroupedMapArrowUDFTimeBench.time_worker('few_groups_sm', 'filter_udf') |
   |        | 14.1 +-2ms        | 15.6 +-0.7ms     | ~1.10 | GroupedMapArrowUDFTimeBench.time_worker('few_groups_sm', 'identity_udf') |
   |        | 35.3 +-1ms        | 36.7 +-1ms       | 1.04  | GroupedMapArrowUDFTimeBench.time_worker('few_groups_sm', 'sort_udf') |
   |        | 342 +-4ms         | 324 +-2ms        | ~0.95 | GroupedMapArrowUDFTimeBench.time_worker('many_groups_lg', 'filter_udf') |
   |        | 212 +-1ms         | 210 +-2ms        | 0.99  | GroupedMapArrowUDFTimeBench.time_worker('many_groups_lg', 'identity_udf') |
   |        | 629 +-10ms        | 613 +-4ms        | 0.97  | GroupedMapArrowUDFTimeBench.time_worker('many_groups_lg', 'sort_udf') |
   |        | 548 +-8ms         | 541 +-2ms        | 0.99  | GroupedMapArrowUDFTimeBench.time_worker('many_groups_sm', 'filter_udf') |
   |        | 405 +-2ms         | 402 +-1ms        | 0.99  | GroupedMapArrowUDFTimeBench.time_worker('many_groups_sm', 'identity_udf') |
   |        | 619 +-7ms         | 662 +-30ms       | ~1.07 | GroupedMapArrowUDFTimeBench.time_worker('many_groups_sm', 'sort_udf') |
   |        | 96.6 +-0.3ms      | 96.2 +-0.5ms     | 1.00  | GroupedMapArrowUDFTimeBench.time_worker('multi_key', 'filter_udf') |
   |        | 65.5 +-0.2ms      | 65.1 +-0.6ms     | 0.99  | GroupedMapArrowUDFTimeBench.time_worker('multi_key', 'identity_udf') |
   |        | 108 +-0.7ms       | 109 +-0.6ms      | 1.01  | GroupedMapArrowUDFTimeBench.time_worker('multi_key', 'sort_udf') |
   |        | 377 +-0.3ms       | 380 +-0.8ms      | 1.01  | GroupedMapArrowUDFTimeBench.time_worker('wide_values', 'filter_udf') |
   |        | 272 +-2ms         | 281 +-2ms        | 1.03  | GroupedMapArrowUDFTimeBench.time_worker('wide_values', 'identity_udf') |
   |        | 444 +-2ms         | 449 +-0.5ms      | 1.01  | GroupedMapArrowUDFTimeBench.time_worker('wide_values', 'sort_udf') |
   ```
   
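   Several tilde (`~`) markers are on micro-benchmarks (around 20ms or less), 
where +/-1-2ms stddev inflates the ratio; the largest one on a longer benchmark 
(`many_groups_sm` / `sort_udf`, ~1.07) has a +-30ms stddev on the PR side. No 
benchmark is flagged as a regression at the 5% threshold (`-f 1.05`).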
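
For reference, the Ratio column is simply After divided by Before. The sketch below recomputes it for a few rows from the rounded values shown in the table and checks whether the result falls outside the 1.05 factor in either direction. The helper names are hypothetical, and this is only the arithmetic, not asv's decision logic, which also takes measurement noise into account. Because the table values are rounded, recomputed ratios can differ slightly from the ones asv printed.

```python
def ratio(before: float, after: float) -> float:
    """After / Before, both in the same unit (milliseconds here)."""
    return after / before


def outside_factor(before: float, after: float, factor: float = 1.05) -> bool:
    """True if the ratio is above `factor` or below 1/factor."""
    r = ratio(before, after)
    return r > factor or r < 1.0 / factor


# (before_ms, after_ms) taken from the rounded table values above.
rows = {
    "Cogrouped few_groups_lg/concat_udf": (139.0, 140.0),
    "Cogrouped few_groups_sm/identity_udf": (17.3, 18.7),
    "GroupedIter few_groups_sm/identity_udf": (13.5, 11.8),
}

for name, (before, after) in rows.items():
    print(f"{name}: ratio={ratio(before, after):.2f}, "
          f"outside 1.05 factor={outside_factor(before, after)}")
```

The second and third rows are both outside the factor, one slower and one faster, which matches their `~` markers in the table; a 1.4ms shift on a 17ms benchmark is already an 8% change.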
   
   ### Was this patch authored or co-authored using generative AI tooling?
   
   No.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

