[ https://issues.apache.org/jira/browse/SPARK-53615?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]

Ruifeng Zheng resolved SPARK-53615.
-----------------------------------
    Fix Version/s: 4.2.0
       Resolution: Fixed

Issue resolved by pull request 53035
[https://github.com/apache/spark/pull/53035]

> Introduce iterator API for arrow grouped agg UDF
> ------------------------------------------------
>
>                 Key: SPARK-53615
>                 URL: https://issues.apache.org/jira/browse/SPARK-53615
>             Project: Spark
>          Issue Type: Sub-task
>          Components: Connect, PySpark
>    Affects Versions: 4.1.0
>            Reporter: Ruifeng Zheng
>            Assignee: Yicong Huang
>            Priority: Major
>              Labels: pull-request-available
>             Fix For: 4.2.0
>
>
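> For a single column:
> {code:python}
> from typing import Iterator
>
> import pyarrow as pa
> import pyarrow.compute as pc
> from pyspark.sql.functions import arrow_udf
>
> @arrow_udf("double")
> def arrow_mean(it: Iterator[pa.Array]) -> float:
>     # Running totals across all Arrow batches of the group.
>     total = 0.0
>     cnt = 0
>     for v in it:
>         assert isinstance(v, pa.Array)
>         # pc.sum returns null for an empty or all-null batch, so fall back to 0.0.
>         total += pc.sum(v).as_py() or 0.0
>         # Count only non-null values so the mean ignores nulls.
>         cnt += len(v) - v.null_count
>     return total / cnt
> {code}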
>  
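> For multiple columns:
> {code:python}
> from typing import Iterator, Tuple
>
> import numpy as np
> import pyarrow as pa
> import pyarrow.compute as pc
> from pyspark.sql.functions import arrow_udf
>
> @arrow_udf("double")
> def arrow_weighted_mean(it: Iterator[Tuple[pa.Array, pa.Array]]) -> float:
>     weighted_sum = 0.0
>     weight = 0.0
>     for v, w in it:
>         assert isinstance(v, pa.Array)
>         assert isinstance(w, pa.Array)
>         # Convert explicitly to NumPy; zero_copy_only=False allows a copy when needed.
>         weighted_sum += float(np.dot(v.to_numpy(zero_copy_only=False),
>                                      w.to_numpy(zero_copy_only=False)))
>         # .as_py() turns the pyarrow scalar into a Python float (None for an empty batch).
>         weight += pc.sum(w).as_py() or 0.0
>     return weighted_sum / weight
> {code}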
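A minimal, Spark-free sketch of the accumulation pattern the iterator API enables: the UDF consumes one group batch by batch and keeps only running totals, so the result should not depend on how the group is split into batches. Plain Python lists stand in for pa.Array batches, and the chunked helper is hypothetical; it does not model how Spark actually slices a group into Arrow batches.

```python
from typing import Iterable, Iterator, List, Sequence, Tuple


def chunked(values: Sequence[float], size: int) -> Iterator[List[float]]:
    """Hypothetical helper: yield consecutive slices as stand-ins for Arrow batches."""
    for start in range(0, len(values), size):
        yield list(values[start:start + size])


def streaming_mean(batches: Iterable[List[float]]) -> float:
    # Only two running totals are kept; each batch is dropped once consumed.
    total = 0.0
    count = 0
    for batch in batches:
        total += sum(batch)
        count += len(batch)
    return total / count


def streaming_weighted_mean(
    batches: Iterable[Tuple[List[float], List[float]]]
) -> float:
    # Same idea with two aligned columns: values and weights.
    weighted_sum = 0.0
    weight = 0.0
    for values, weights in batches:
        weighted_sum += sum(v * w for v, w in zip(values, weights))
        weight += sum(weights)
    return weighted_sum / weight


values = [1.0, 2.0, 3.0, 4.0, 5.0]
weights = [1.0, 1.0, 2.0, 2.0, 4.0]

# Different batch sizes should yield the same aggregate.
for size in (1, 2, 5):
    mean = streaming_mean(chunked(values, size))
    wmean = streaming_weighted_mean(zip(chunked(values, size), chunked(weights, size)))
    print(size, mean, wmean)
```

The point of the design is that memory use per group stays bounded by one batch plus a few scalars, rather than the whole group being materialized before aggregation.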



--
This message was sent by Atlassian Jira
(v8.20.10#820010)