[ https://issues.apache.org/jira/browse/SPARK-53615?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]
Ruifeng Zheng resolved SPARK-53615.
-----------------------------------
Fix Version/s: 4.2.0
Resolution: Fixed
Issue resolved by pull request 53035
[https://github.com/apache/spark/pull/53035]
> Introduce iterator API for arrow grouped agg UDF
> ------------------------------------------------
>
> Key: SPARK-53615
> URL: https://issues.apache.org/jira/browse/SPARK-53615
> Project: Spark
> Issue Type: Sub-task
> Components: Connect, PySpark
> Affects Versions: 4.1.0
> Reporter: Ruifeng Zheng
> Assignee: Yicong Huang
> Priority: Major
> Labels: pull-request-available
> Fix For: 4.2.0
>
>
> For a single column, the UDF receives an iterator of pa.Array batches and returns a single value:
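> {code:python}
> from typing import Iterator
>
> import pyarrow as pa
> import pyarrow.compute as pc
> from pyspark.sql.functions import arrow_udf
>
> @arrow_udf("double")
> def arrow_mean(it: Iterator[pa.Array]) -> float:
>     total = 0.0  # avoid shadowing the builtin sum()
>     cnt = 0
>     for v in it:  # each element is one Arrow batch of the group's values
>         assert isinstance(v, pa.Array)
>         total += pc.sum(v).as_py()
>         cnt += len(v)
>
>     return total / cnt
> {code}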
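To see the accumulation pattern outside Spark, the sketch below runs the same per-batch logic on a plain Python iterator. It is an illustration only, under stated assumptions: plain lists stand in for the pa.Array batches Spark would pass, and the helpers mean_of_batches and batched are hypothetical, not part of the PySpark API. Because only a running sum and count are kept, the result does not depend on how the group is split into batches.

```python
from typing import Iterator, List


def mean_of_batches(it: Iterator[List[float]]) -> float:
    """Hypothetical stand-in for arrow_mean: each list plays the role of one pa.Array batch."""
    total = 0.0
    cnt = 0
    for batch in it:
        # Update the running state one batch at a time.
        total += sum(batch)
        cnt += len(batch)
    return total / cnt


def batched(values: List[float], size: int) -> Iterator[List[float]]:
    """Hypothetical helper: split a list into consecutive chunks of at most `size` items."""
    for start in range(0, len(values), size):
        yield values[start:start + size]


values = [1.0, 2.0, 3.0, 4.0, 5.0]
print(mean_of_batches(batched(values, 2)))  # 3.0
print(mean_of_batches(batched(values, 5)))  # 3.0
```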
>
> For multiple columns, each element of the iterator is a tuple of pa.Array batches, one per input column:
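> {code:python}
> from typing import Iterator, Tuple
>
> import numpy as np
> import pyarrow as pa
> import pyarrow.compute as pc
> from pyspark.sql.functions import arrow_udf
>
> @arrow_udf("double")
> def arrow_weighted_mean(it: Iterator[Tuple[pa.Array, pa.Array]]) -> float:
>     weighted_sum = 0.0
>     weight = 0.0
>     for v, w in it:  # one batch per column, aligned row by row
>         assert isinstance(v, pa.Array)
>         assert isinstance(w, pa.Array)
>         v_np = v.to_numpy(zero_copy_only=False)
>         w_np = w.to_numpy(zero_copy_only=False)
>         weighted_sum += float(np.dot(v_np, w_np))
>         weight += pc.sum(w).as_py()  # convert the Arrow scalar to a Python float
>
>     return weighted_sum / weight
> {code}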
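The multi-column case can be sketched the same way without Spark. Again this is an illustration under stated assumptions: weighted_mean_of_batches is a hypothetical helper, and pairs of plain lists stand in for the (pa.Array, pa.Array) tuples Spark would pass. Only the running weighted sum and total weight are carried across batches.

```python
from typing import Iterator, List, Tuple


def weighted_mean_of_batches(it: Iterator[Tuple[List[float], List[float]]]) -> float:
    """Hypothetical stand-in for arrow_weighted_mean: each tuple is one (values, weights) batch."""
    weighted_sum = 0.0
    weight = 0.0
    for v, w in it:
        # The two columns of a batch are aligned row by row, so lengths must match.
        assert len(v) == len(w)
        weighted_sum += sum(x * y for x, y in zip(v, w))  # dot product of this batch
        weight += sum(w)
    return weighted_sum / weight


batches = [([1.0, 2.0], [1.0, 1.0]), ([3.0], [2.0])]
print(weighted_mean_of_batches(iter(batches)))  # 2.25
```

In Spark, a grouped aggregate UDF like this would typically be applied per group, for example df.groupBy("id").agg(arrow_weighted_mean(df["v"], df["w"])), assuming a DataFrame df with hypothetical columns id, v and w.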