This is an automated email from the ASF dual-hosted git repository.

HyukjinKwon pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.x by this push:
     new 9b3f96be9e8e [SPARK-56781][PYTHON] Refactor SQL_GROUPED_AGG_PANDAS_UDF
9b3f96be9e8e is described below

commit 9b3f96be9e8ea26fac4f97817cf95bc3ed42b3f8
Author: Yicong Huang <[email protected]>
AuthorDate: Wed May 20 14:45:43 2026 -0700

    [SPARK-56781][PYTHON] Refactor SQL_GROUPED_AGG_PANDAS_UDF
    
    ### What changes were proposed in this pull request?
    
    Refactor `SQL_GROUPED_AGG_PANDAS_UDF` to use `ArrowStreamGroupSerializer` 
as a pure I/O layer, moving per-group pandas conversion and UDF invocation into 
`read_udfs()` in `worker.py`. The custom `ArrowStreamAggPandasUDFSerializer` is 
no longer used for this eval type (still used by 
`SQL_GROUPED_AGG_PANDAS_ITER_UDF` and `SQL_WINDOW_AGG_PANDAS_UDF`).
    
    **Side effect: 9-39% faster.** The refactor eliminates redundant per-batch 
work in the old path:
    
    | Per-group work    | Old path                     | New path                          |
    |-------------------|------------------------------|-----------------------------------|
    | `to_pandas()`     | **N** times (once per batch) | **1** time (on combined table)    |
    | `pd.concat`       | **(N-1) x num_cols** times   | **0**                             |
    | Arrow merge       | implicit via `pd.concat`     | `combine_chunks()` (Arrow concat) |
    
    ### Why are the changes needed?
    
    Part of SPARK-55388.
    
    ASV benchmark (`GroupedAggPandasUDFTimeBench`, min of 3 samples):
    
    ```text
    Scenario         UDF                Base       HEAD     Delta
    few_groups_sm    sum_udf            48.2ms     43.3ms   -10.2%
    few_groups_sm    mean_multi_udf     55.6ms     46.4ms   -16.5%
    few_groups_lg    sum_udf            88.5ms     71.2ms   -19.5%
    few_groups_lg    mean_multi_udf     92.0ms     83.8ms    -8.9%
    many_groups_sm   sum_udf          1816.8ms   1610.1ms   -11.4%
    many_groups_sm   mean_multi_udf   2208.4ms   1742.4ms   -21.1%
    many_groups_lg   sum_udf           640.3ms    405.8ms   -36.6%
    many_groups_lg   mean_multi_udf    797.4ms    485.2ms   -39.2%
    wide_cols        sum_udf           552.6ms    397.4ms   -28.1%
    wide_cols        mean_multi_udf    561.6ms    448.9ms   -20.1%
    ```
    
    Gains scale with group count: more groups means more per-group `to_pandas`
    / `concat` overhead in the old path, all of which the new path eliminates.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests. No behavior change.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #55808 from Yicong-Huang/SPARK-56781.
    
    Authored-by: Yicong Huang <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
    (cherry picked from commit 64cff79cc2809c471274c77f3b033fd35d219d1e)
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/worker.py | 73 +++++++++++++++++++++++++++++++-----------------
 1 file changed, 48 insertions(+), 25 deletions(-)

diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index ad4742bfb6c7..ccd81a169095 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -770,21 +770,6 @@ def wrap_grouped_map_pandas_udf_with_state(f, return_type, 
runner_conf):
     return lambda k, v, s: [(wrapped(k, v, s), return_type)]
 
 
-def wrap_grouped_agg_pandas_udf(f, args_offsets, kwargs_offsets, return_type, 
runner_conf):
-    func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, 
kwargs_offsets)
-
-    def wrapped(*series):
-        import pandas as pd
-
-        result = func(*series)
-        return pd.Series([result])
-
-    return (
-        args_kwargs_offsets,
-        lambda *a: (wrapped(*a), return_type),
-    )
-
-
 def wrap_grouped_agg_pandas_iter_udf(f, args_offsets, kwargs_offsets, 
return_type, runner_conf):
     func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, 
kwargs_offsets)
 
@@ -1075,11 +1060,8 @@ def read_single_udf(pickleSer, udf_info, eval_type, 
runner_conf, udf_index):
     elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF:
         argspec = inspect.getfullargspec(chained_func)  # signature was lost 
when wrapping it
         return func, args_offsets, return_type, len(argspec.args)
-    elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
-        return wrap_grouped_agg_pandas_udf(
-            func, args_offsets, kwargs_offsets, return_type, runner_conf
-        )
     elif eval_type in (
+        PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
         PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
         PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF,
     ):
@@ -2283,6 +2265,7 @@ def read_udfs(pickleSer, udf_info_list, eval_type, 
runner_conf, eval_conf):
             PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF,
             PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
             PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF,
+            PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
             PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
             PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF,
         ):
@@ -2290,7 +2273,6 @@ def read_udfs(pickleSer, udf_info_list, eval_type, 
runner_conf, eval_conf):
         elif eval_type == PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF:
             ser = ArrowStreamGroupSerializer(write_start_stream=True)
         elif eval_type in (
-            PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
             PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF,
             PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
         ):
@@ -2616,6 +2598,50 @@ def read_udfs(pickleSer, udf_info_list, eval_type, 
runner_conf, eval_conf):
         # profiling is not supported for UDF
         return grouped_func, None, ser, ser
 
+    if eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
+        import pyarrow as pa
+        import pandas as pd
+
+        col_names = ["_%d" % i for i in range(len(udfs))]
+        output_schema = StructType(
+            [StructField(name, rt) for name, (_, _, _, rt) in zip(col_names, 
udfs)]
+        )
+
+        def grouped_func(
+            split_index: int, data: Iterator["GroupedBatch"]
+        ) -> Iterator[pa.RecordBatch]:
+            for group in data:
+                batch_list = list(group)
+                if not batch_list:
+                    continue
+                table = pa.Table.from_batches(batch_list).combine_chunks()
+                all_series = ArrowBatchTransformer.to_pandas(
+                    table,
+                    timezone=runner_conf.timezone,
+                    prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype,
+                )
+                results = [
+                    udf_func(
+                        *[all_series[o] for o in args_offsets],
+                        **{k: all_series[v] for k, v in 
kwargs_offsets.items()},
+                    )
+                    for udf_func, args_offsets, kwargs_offsets, _ in udfs
+                ]
+                result_series = [pd.Series([r]) for r in results]
+                yield PandasToArrowConversion.convert(
+                    result_series,
+                    output_schema,
+                    timezone=runner_conf.timezone,
+                    safecheck=runner_conf.safecheck,
+                    arrow_cast=True,
+                    prefers_large_types=runner_conf.use_large_var_types,
+                    assign_cols_by_name=runner_conf.assign_cols_by_name,
+                    
int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
+                )
+
+        # profiling is not supported for UDF
+        return grouped_func, None, ser, ser
+
     if eval_type == PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF:
         import pyarrow as pa
 
@@ -3523,13 +3549,10 @@ def read_udfs(pickleSer, udf_info_list, eval_type, 
runner_conf, eval_conf):
                 )
             return f(series_iter)
 
-    elif eval_type in (
-        PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
-        PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
-    ):
+    elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
         import pandas as pd
 
-        # For SQL_GROUPED_AGG_PANDAS_UDF and SQL_WINDOW_AGG_PANDAS_UDF,
+        # For SQL_WINDOW_AGG_PANDAS_UDF,
         # convert iterator of batch tuples to concatenated pandas Series
         def mapper(batch_iter):
             # batch_iter is Iterator[Tuple[pd.Series, ...]] where each tuple 
represents one batch


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to