This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 71c67b079087 [SPARK-53641][DOCS] Add PARTITION BY support in Arrow Python UDTF docs
71c67b079087 is described below
commit 71c67b079087d249a526ee45c9944965661a6ca5
Author: Allison Wang <[email protected]>
AuthorDate: Mon Sep 22 10:07:13 2025 +0800
[SPARK-53641][DOCS] Add PARTITION BY support in Arrow Python UDTF docs
### What changes were proposed in this pull request?
This PR adds PARTITION BY and ORDER BY semantics in Arrow Python UDTF docs.
It also adds more examples for Arrow Python UDTFs.
### Why are the changes needed?
To improve documentation.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
manual tests
### Was this patch authored or co-authored using generative AI tooling?
Yes
Closes #52392 from allisonwang-db/spark-53641-udtf-partition-by-docs.
Authored-by: Allison Wang <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../docs/source/tutorial/sql/arrow_python_udtf.rst | 418 +++++++++++++++++++++
python/pyspark/sql/tests/arrow/test_arrow_udtf.py | 93 ++++-
2 files changed, 495 insertions(+), 16 deletions(-)
diff --git a/python/docs/source/tutorial/sql/arrow_python_udtf.rst b/python/docs/source/tutorial/sql/arrow_python_udtf.rst
index 3e933fea722f..a5f87ad9161c 100644
--- a/python/docs/source/tutorial/sql/arrow_python_udtf.rst
+++ b/python/docs/source/tutorial/sql/arrow_python_udtf.rst
@@ -227,6 +227,21 @@ Here's how to use these UDTFs in DataFrame:
df = spark.range(10).selectExpr("id", "cast(id as string) as value")
MyArrowPythonUDTF(df.asTable()).show()
+ # Result:
+ # +---+
+ # | c1|
+ # +---+
+ # | 0|
+ # | 1|
+ # | 2|
+ # | 3|
+ # | 4|
+ # | 5|
+ # | 6|
+ # | 7|
+ # | 8|
+ # | 9|
+ # +---+
# Register the UDTF
spark.udtf.register("my_arrow_udtf", MyArrowPythonUDTF)
@@ -235,3 +250,406 @@ Here's how to use these UDTFs in DataFrame:
df = spark.sql("""
SELECT * FROM my_arrow_udtf(TABLE(SELECT id, cast(id as string) as value FROM range(10)))
""")
+
+
+TABLE Argument
+--------------
+
+Arrow UDTFs can take a TABLE argument. When your UDTF receives a TABLE argument,
+its ``eval`` method is called with a ``pyarrow.RecordBatch`` containing the input
+table’s columns, and any additional scalar/struct expressions are passed as
+``pyarrow.Array`` values.
+
+Key points:
+- The TABLE argument is a single ``pa.RecordBatch``; access columns by name or index.
+- Scalar arguments (including structs) are ``pa.Array`` values, not ``RecordBatch``.
+- Named and positional arguments are both supported in SQL.
+
+Example (DataFrame API):
+
+.. code-block:: python
+
+ import pyarrow as pa
+ import pyarrow.compute as pc
+ from typing import Iterator, Optional
+ from pyspark.sql.functions import arrow_udtf, SkipRestOfInputTableException
+
+ @arrow_udtf(returnType="value int")
+ class EchoTable:
+ def eval(self, batch: pa.RecordBatch) -> Iterator[pa.Table]:
+ # Return the input column named "value" as-is
+ yield pa.table({"value": batch.column("value")})
+
+ df = spark.range(5).selectExpr("id as value")
+ EchoTable(df.asTable()).show()
+
+ # Result:
+ # +-----+
+ # |value|
+ # +-----+
+ # | 0|
+ # | 1|
+ # | 2|
+ # | 3|
+ # | 4|
+ # +-----+
+
+Example (SQL): TABLE plus a scalar threshold
+
+.. code-block:: python
+
+ import pyarrow as pa
+ import pyarrow.compute as pc
+ from typing import Iterator
+ from pyspark.sql.functions import arrow_udtf
+
+    # Keep rows with value > threshold; works with SQL using TABLE + scalar argument
+ @arrow_udtf(returnType="partition_key int, value int")
+ class ThresholdFilter:
+        def eval(self, batch: pa.RecordBatch, threshold: pa.Array) -> Iterator[pa.Table]:
+ tbl = pa.table(batch)
+ thr = int(threshold.cast(pa.int64())[0].as_py())
+ mask = pc.greater(tbl["value"], thr)
+ yield tbl.filter(mask)
+
+ spark.udtf.register("threshold_filter", ThresholdFilter)
+    spark.createDataFrame([(1, 10), (1, 30), (2, 5)], "partition_key int, value int").createOrReplaceTempView("v")
+
+ spark.sql(
+ """
+ SELECT *
+ FROM threshold_filter(
+ TABLE(v),
+ 10
+ )
+ ORDER BY partition_key, value
+ """
+ ).show()
+
+ # Result:
+ # +-------------+-----+
+ # |partition_key|value|
+ # +-------------+-----+
+ # | 1| 30|
+ # +-------------+-----+
+
+
+PARTITION BY and ORDER BY
+-------------------------
+
+Arrow UDTFs support ``TABLE(...) PARTITION BY ... ORDER BY ...``. Think of it as
+“process rows group by group, and in a specific order within each group”.
+
+Semantics:
+
+- PARTITION BY groups rows by the given keys; your UDTF runs for each group independently.
+- ORDER BY controls the row order within each group as seen by ``eval``.
+- ``eval`` may be called multiple times per group; accumulate state and typically emit the group's result in ``terminate``.
+
+Example: Aggregation per key with terminate
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+PARTITION BY is especially useful for per-group aggregation. ``eval`` may be called
+multiple times for the same group as rows arrive in batches, so keep running totals in
+the UDTF instance and emit the final row in ``terminate``.
+
+.. code-block:: python
+
+ import pyarrow as pa
+ import pyarrow.compute as pc
+ from typing import Iterator
+ from pyspark.sql.functions import arrow_udtf
+
+ @arrow_udtf(returnType="user_id int, total_amount int, rows int")
+ class SumPerUser:
+ def __init__(self):
+ self._user = None
+ self._sum = 0
+ self._count = 0
+
+ def eval(self, batch: pa.RecordBatch) -> Iterator[pa.Table]:
+ tbl = pa.table(batch)
+ # All rows in this batch belong to the same user within a partition
+ self._user = pc.unique(tbl["user_id"]).to_pylist()[0]
+ self._sum += pc.sum(tbl["amount"]).as_py()
+ self._count += tbl.num_rows
+ return iter(()) # emit once in terminate
+
+ def terminate(self) -> Iterator[pa.Table]:
+ if self._user is not None:
+ yield pa.table({
+ "user_id": pa.array([self._user], pa.int32()),
+ "total_amount": pa.array([self._sum], pa.int32()),
+ "rows": pa.array([self._count], pa.int32()),
+ })
+
+ spark.udtf.register("sum_per_user", SumPerUser)
+ spark.createDataFrame(
+ [(1, 10), (2, 5), (1, 20), (2, 15), (3, 7)],
+ "user_id int, amount int",
+ ).createOrReplaceTempView("purchases")
+
+ spark.sql(
+ """
+ SELECT *
+ FROM sum_per_user(
+ TABLE(purchases)
+ PARTITION BY user_id
+ )
+ ORDER BY user_id
+ """
+ ).show()
+
+ # Result:
+ # +-------+------------+----+
+ # |user_id|total_amount|rows|
+ # +-------+------------+----+
+ # | 1| 30| 2|
+ # | 2| 20| 2|
+ # | 3| 7| 1|
+ # +-------+------------+----+
+
+Why terminate? ``eval`` may run multiple times per group if the input is split into
+several batches. Emitting the aggregated row in ``terminate`` guarantees exactly one
+output row per group after all its rows have been processed.
+
+Example: Top reviews per product using ORDER BY
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. code-block:: python
+
+ import pyarrow as pa
+ import pyarrow.compute as pc
+ from typing import Iterator, Optional
+ from pyspark.sql.functions import arrow_udtf, SkipRestOfInputTableException
+
+    @arrow_udtf(returnType="product_id int, review_id int, rating int, review string")
+ class TopReviewsPerProduct:
+ TOP_K = 3
+
+ def __init__(self):
+ self._product = None
+ self._seen = 0
+ self._batches: list[pa.Table] = []
+ self._top: Optional[pa.Table] = None
+
+ def eval(self, batch: pa.RecordBatch) -> Iterator[pa.Table]:
+ tbl = pa.table(batch)
+ if tbl.num_rows == 0:
+ return iter(())
+
+ products = pc.unique(tbl["product_id"]).to_pylist()
+            assert len(products) == 1, f"Expected one product per batch, saw {products}"
+ product = products[0]
+
+ if self._product is None:
+ self._product = product
+ else:
+                assert self._product == product, f"Mixed products {self._product} and {product}"
+
+ self._batches.append(tbl)
+ self._seen += tbl.num_rows
+
+ if self._seen >= self.TOP_K and self._top is None:
+ combined = pa.concat_tables(self._batches)
+ self._top = combined.slice(0, self.TOP_K)
+ raise SkipRestOfInputTableException(
+                    f"Collected top {self.TOP_K} reviews for product {self._product}"
+ )
+
+ return iter(())
+
+ def terminate(self) -> Iterator[pa.Table]:
+ if self._product is None:
+ return iter(())
+
+ if self._top is None:
+                combined = pa.concat_tables(self._batches) if self._batches else pa.table({})
+ limit = min(self.TOP_K, self._seen)
+ self._top = combined.slice(0, limit)
+
+ yield self._top
+
+ spark.udtf.register("top_reviews_per_product", TopReviewsPerProduct)
+ spark.createDataFrame(
+ [
+ (101, 1, 5, "Amazing battery life"),
+ (101, 2, 5, "Still great after a month"),
+ (101, 3, 4, "Solid build"),
+ (101, 4, 3, "Average sound"),
+ (202, 5, 5, "My go-to lens"),
+ (202, 6, 4, "Sharp and bright"),
+ (202, 7, 4, "Great value"),
+ ],
+ "product_id int, review_id int, rating int, review string",
+ ).createOrReplaceTempView("reviews")
+
+ spark.sql(
+ """
+ SELECT *
+ FROM top_reviews_per_product(
+ TABLE(reviews)
+ PARTITION BY (product_id)
+ ORDER BY (rating DESC, review_id)
+ )
+ ORDER BY product_id, rating DESC, review_id
+ """
+ ).show()
+
+ # Result:
+ # +----------+---------+------+--------------------------+
+ # |product_id|review_id|rating|review |
+ # +----------+---------+------+--------------------------+
+ # | 101| 1| 5|Amazing battery life |
+ # | 101| 2| 5|Still great after a month |
+ # | 101| 3| 4|Solid build |
+ # | 202| 5| 5|My go-to lens |
+ # | 202| 6| 4|Sharp and bright |
+ # | 202| 7| 4|Great value |
+ # +----------+---------+------+--------------------------+
+
+
+Best Practices
+--------------
+
+- Stream work from :py:meth:`eval` when possible (see the sketch below). Yielding one
+  ``pa.Table`` per Arrow batch keeps memory bounded and shortens feedback loops;
+  reserve :py:meth:`terminate` for true per-partition operations.
+- Keep per-partition state tiny and reset it promptly. If you only need the first *N*
+  rows, raise :py:class:`~pyspark.sql.functions.SkipRestOfInputTableException` after
+  collecting them so Spark skips the rest of the partition.
+- Guard external calls with short timeouts and operate on the current batch instead
+  of deferring to ``terminate``; this avoids giant buffers and keeps retries narrow.
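+
+A minimal sketch of the streaming pattern from the first bullet; the class name and
+the ``value`` column are illustrative, not part of any API:
+
+.. code-block:: python
+
+    import pyarrow as pa
+    import pyarrow.compute as pc
+    from typing import Iterator
+    from pyspark.sql.functions import arrow_udtf
+
+    @arrow_udtf(returnType="value int, doubled int")
+    class DoublePerBatch:
+        def eval(self, batch: pa.RecordBatch) -> Iterator[pa.Table]:
+            vals = pa.table(batch)["value"].cast(pa.int32())
+            # Yield one result table per incoming Arrow batch instead of
+            # buffering; no state survives between batches, so memory stays
+            # bounded regardless of input size.
+            yield pa.table({
+                "value": vals,
+                "doubled": pc.multiply(vals, 2).cast(pa.int32()),
+            })
+
+    df = spark.range(3).selectExpr("cast(id as int) as value")
+    DoublePerBatch(df.asTable()).show()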
+
+
+When to use Arrow UDTFs vs Other UDTFs
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+- Prefer ``arrow_udtf`` when the logic is naturally vectorised, you can stay within
+  Python, and the input/output schema is Arrow-friendly. You gain batch-friendly
+  performance and native interoperability with PySpark DataFrames.
+- Stick with the classic (row-based) Python UDTF when you only need simple per-row
+  expansion (see the sketch below), or when your logic depends on Python objects
+  that Arrow cannot represent cleanly.
+- Use SQL UDTFs if the functionality is performance critical and the logic can be
+  represented in SQL.
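+
+For contrast, a minimal sketch of the row-based form; ``SplitWords`` is illustrative
+(``udtf`` is the row-based decorator from ``pyspark.sql.functions``):
+
+.. code-block:: python
+
+    from pyspark.sql.functions import lit, udtf
+
+    @udtf(returnType="word string")
+    class SplitWords:
+        def eval(self, text: str):
+            # Called once per input row; yields one tuple per output row.
+            for word in (text or "").split():
+                yield (word,)
+
+    SplitWords(lit("spark is fast")).show()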
+
+
+More Examples
+-------------
+
+Example: Simple anomaly detection per device
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Compute simple per-device stats and return them from ``terminate``; this pattern is
+useful for anomaly detection workflows that first summarize distributions by key.
+
+.. code-block:: python
+
+ import pyarrow as pa
+ import pyarrow.compute as pc
+ from typing import Iterator
+ from pyspark.sql.functions import arrow_udtf
+
+    @arrow_udtf(returnType="device_id int, count int, mean double, stddev double, max_value int")
+ class DeviceStats:
+ def __init__(self):
+ self._device = None
+ self._count = 0
+ self._sum = 0.0
+ self._sumsq = 0.0
+ self._max = None
+
+ def eval(self, batch: pa.RecordBatch) -> Iterator[pa.Table]:
+ tbl = pa.table(batch)
+ self._device = pc.unique(tbl["device_id"]).to_pylist()[0]
+ vals = tbl["reading"].cast(pa.float64())
+ self._count += len(vals)
+ self._sum += pc.sum(vals).as_py() or 0.0
+ self._sumsq += pc.sum(pc.multiply(vals, vals)).as_py() or 0.0
+ cur_max = pc.max(vals).as_py()
+            self._max = cur_max if self._max is None else max(self._max, cur_max)
+ return iter(())
+
+ def terminate(self) -> Iterator[pa.Table]:
+ if self._device is not None and self._count > 0:
+ mean = self._sum / self._count
+ var = max(self._sumsq / self._count - mean * mean, 0.0)
+ std = var ** 0.5
+ # Round to 2 decimal places for display
+ mean_rounded = round(mean, 2)
+ std_rounded = round(std, 2)
+ yield pa.table({
+ "device_id": pa.array([self._device], pa.int32()),
+ "count": pa.array([self._count], pa.int32()),
+ "mean": pa.array([mean_rounded], pa.float64()),
+ "stddev": pa.array([std_rounded], pa.float64()),
+ "max_value": pa.array([int(self._max)], pa.int32()),
+ })
+
+ spark.udtf.register("device_stats", DeviceStats)
+ spark.createDataFrame(
+ [(1, 10), (1, 12), (1, 100), (2, 5), (2, 7)],
+ "device_id int, reading int",
+ ).createOrReplaceTempView("readings")
+
+ spark.sql(
+ """
+ SELECT *
+ FROM device_stats(
+ TABLE(readings)
+ PARTITION BY device_id
+ )
+ ORDER BY device_id
+ """
+ ).show()
+
+ # Result:
+ # +---------+-----+-----+------+---------+
+ # |device_id|count| mean|stddev|max_value|
+ # +---------+-----+-----+------+---------+
+ # | 1| 3|40.67| 41.96| 100|
+ # | 2| 2| 6.0| 1.0| 7|
+ # +---------+-----+-----+------+---------+
+
+
+Example: Arrow UDTFs as RDD map-style transforms
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Arrow UDTFs can replace many ``RDD.map``/``flatMap``-style transforms with better
+performance and first-class SQL integration. Instead of mapping row-by-row in Python,
+you work on Arrow batches and return a table.
+
+Example: tokenize text into words (flatMap-like)
+
+.. code-block:: python
+
+ import pyarrow as pa
+ import pyarrow.compute as pc
+ from typing import Iterator
+ from pyspark.sql.functions import arrow_udtf
+
+ @arrow_udtf(returnType="doc_id int, word string")
+ class Tokenize:
+ def eval(self, batch: pa.RecordBatch) -> Iterator[pa.Table]:
+ tbl = pa.table(batch)
+ # Split on whitespace; build flat arrays for (doc_id, word)
+ doc_ids: list[int] = []
+ words: list[str] = []
+            for doc_id, text in zip(tbl["doc_id"].to_pylist(), tbl["text"].to_pylist()):
+ for w in (text or "").split():
+ doc_ids.append(doc_id)
+ words.append(w)
+ if doc_ids:
+                yield pa.table({"doc_id": pa.array(doc_ids, pa.int32()), "word": pa.array(words)})
+
+ spark.udtf.register("tokenize", Tokenize)
+    spark.createDataFrame([(1, "spark is fast"), (2, "arrow udtf")], "doc_id int, text string").createOrReplaceTempView("docs")
+    spark.sql("SELECT * FROM tokenize(TABLE(docs)) ORDER BY doc_id, word").show()
+
+ # Result:
+ # +------+-----+
+ # |doc_id| word|
+ # +------+-----+
+ # | 1| fast|
+ # | 1| is|
+ # | 1|spark|
+ # | 2|arrow|
+ # | 2| udtf|
+ # +------+-----+
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udtf.py b/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
index 08d09f82dbe5..006084e88cd6 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
@@ -15,7 +15,7 @@
# limitations under the License.
#
import unittest
-from typing import Iterator
+from typing import Iterator, Optional
from pyspark.errors import PySparkAttributeError
from pyspark.errors import PythonException
@@ -996,35 +996,96 @@ class ArrowUDTFTestsMixin:
assertDataFrameEqual(result_df, expected_df)
def test_arrow_udtf_partition_by_all_columns(self):
- @arrow_udtf(returnType="rows int")
- class CountRowsUDTF:
+ from pyspark.sql.functions import SkipRestOfInputTableException
+
+        @arrow_udtf(returnType="product_id int, review_id int, rating int, review string")
+ class TopReviewsPerProduct:
+ TOP_K = 3
+
def __init__(self):
- self._rows = 0
+ self._product = None
+ self._seen = 0
+ self._batches: list[pa.Table] = []
+ self._top_table: Optional[pa.Table] = None
            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+ import pyarrow.compute as pc
+
table = pa.table(table_data)
- self._rows += table.num_rows
+ if table.num_rows == 0:
+ return iter(())
+
+ products = pc.unique(table["product_id"]).to_pylist()
+                assert len(products) == 1, f"Expected one product, saw {products}"
+ product = products[0]
+
+ if self._product is None:
+ self._product = product
+ else:
+ assert (
+ self._product == product
+                    ), f"Mixed products {self._product} and {product} in partition"
+
+ self._batches.append(table)
+ self._seen += table.num_rows
+
+ if self._seen >= self.TOP_K and self._top_table is None:
+ combined = pa.concat_tables(self._batches)
+ self._top_table = combined.slice(0, self.TOP_K)
+ raise SkipRestOfInputTableException(
+                        f"Top {self.TOP_K} reviews ready for product {self._product}"
+ )
+
return iter(())
def terminate(self) -> Iterator["pa.Table"]:
-                result_table = pa.table({"rows": pa.array([self._rows], type=pa.int32())})
- yield result_table
-
-        df = self.spark.createDataFrame([(1, 10), (1, 20), (2, 30)], "partition_key int, value int")
- self.spark.udtf.register("count_rows_udtf", CountRowsUDTF)
- df.createOrReplaceTempView("partition_all_columns")
+ if self._product is None:
+ return iter(())
+
+ if self._top_table is None:
+                    combined = pa.concat_tables(self._batches) if self._batches else pa.table({})
+ limit = min(self.TOP_K, self._seen)
+ self._top_table = combined.slice(0, limit)
+
+ yield self._top_table
+
+ review_data = [
+ (101, 1, 5, "Amazing battery life"),
+ (101, 2, 5, "Still great after a month"),
+ (101, 3, 4, "Solid build"),
+ (101, 4, 3, "Average sound"),
+ (202, 5, 5, "My go-to lens"),
+ (202, 6, 4, "Sharp and bright"),
+ (202, 7, 4, "Great value"),
+ ]
+ df = self.spark.createDataFrame(
+            review_data, "product_id int, review_id int, rating int, review string"
+ )
+ self.spark.udtf.register("top_reviews_udtf", TopReviewsPerProduct)
+ df.createOrReplaceTempView("product_reviews")
result_df = self.spark.sql(
"""
- SELECT * FROM count_rows_udtf(
- TABLE(partition_all_columns)
- PARTITION BY (partition_key, value)
+ SELECT * FROM top_reviews_udtf(
+ TABLE(product_reviews)
+ PARTITION BY (product_id)
+ ORDER BY (rating DESC, review_id)
)
- ORDER BY rows
+ ORDER BY product_id, rating DESC, review_id
"""
)
-        expected_df = self.spark.createDataFrame([(1,), (1,), (1,)], "rows int")
+ expected_df = self.spark.createDataFrame(
+ [
+ (101, 1, 5, "Amazing battery life"),
+ (101, 2, 5, "Still great after a month"),
+ (101, 3, 4, "Solid build"),
+ (202, 5, 5, "My go-to lens"),
+ (202, 6, 4, "Sharp and bright"),
+ (202, 7, 4, "Great value"),
+ ],
+ "product_id int, review_id int, rating int, review string",
+ )
assertDataFrameEqual(result_df, expected_df)
    def test_arrow_udtf_partition_by_single_partition_multiple_input_partitions(self):
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]