This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.1 by this push:
new 705a3a2f9f81 [SPARK-54153][PYTHON] Support profiling iterator based
Python UDFs
705a3a2f9f81 is described below
commit 705a3a2f9f81824fb983111d90d59f3c48e427a0
Author: Adam Binford <[email protected]>
AuthorDate: Sat Nov 8 07:45:14 2025 -0800
[SPARK-54153][PYTHON] Support profiling iterator based Python UDFs
### What changes were proposed in this pull request?
Updates the v2 Spark-session based Python UDF profiler to support profiling
iterator based UDFs.
```python
from collections.abc import Iterator
from pstats import SortKey
import pyarrow as pa
df = spark.range(100000)
def map_func(iter: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
    for batch in iter:
        yield pa.RecordBatch.from_arrays([pa.compute.add(batch.column("id"), 10)], ["id"])
spark.conf.set('spark.sql.pyspark.udf.profiler', 'perf')
df.mapInArrow(map_func, df.schema).collect()
spark.conf.set('spark.sql.pyspark.udf.profiler', 'memory')
df.mapInArrow(map_func, df.schema).collect()
for stats in spark.profile.profiler_collector._perf_profile_results.values():
    stats.sort_stats(SortKey.CUMULATIVE).print_stats(20)
spark.profile.show(type="memory")
```
```
1395288 function calls (1359888 primitive calls) in 2.850 seconds
Ordered by: cumulative time
List reduced from 1546 to 20 due to restriction <20>
ncalls tottime percall cumtime percall filename:lineno(function)
416 0.008 0.000 5.901 0.014 __init__.py:1(<module>)
424/24 0.000 0.000 2.850 0.119 {built-in method
builtins.next}
24 0.001 0.000 2.850 0.119 test.py:11(map_func)
16 0.002 0.000 2.646 0.165 compute.py:244(wrapper)
2752/24 0.016 0.000 2.642 0.110 <frozen
importlib._bootstrap>:1349(_find_and_load)
2752/24 0.013 0.000 2.641 0.110 <frozen
importlib._bootstrap>:1304(_find_and_load_unlocked)
64 0.002 0.000 2.618 0.041 api.py:1(<module>)
2704/24 0.009 0.000 2.612 0.109 <frozen
importlib._bootstrap>:911(_load_unlocked)
2264/24 0.005 0.000 2.611 0.109 <frozen
importlib._bootstrap_external>:993(exec_module)
6336/48 0.004 0.000 2.591 0.054 <frozen
importlib._bootstrap>:480(_call_with_frames_removed)
2400/24 0.023 0.000 2.591 0.108 {built-in method
builtins.exec}
24 0.002 0.000 1.927 0.080 generic.py:1(<module>)
520/320 0.002 0.000 1.429 0.004 {built-in method
builtins.__import__}
4312/2896 0.006 0.000 1.190 0.000 <frozen
importlib._bootstrap>:1390(_handle_fromlist)
8 0.001 0.000 1.014 0.127 frame.py:1(<module>)
4392/4312 0.069 0.000 0.953 0.000 {built-in method
builtins.__build_class__}
2264 0.030 0.000 0.562 0.000 <frozen
importlib._bootstrap_external>:1066(get_code)
16 0.001 0.000 0.551 0.034 indexing.py:1(<module>)
24 0.001 0.000 0.451 0.019 datetimes.py:1(<module>)
16 0.001 0.000 0.401 0.025 datetimelike.py:1(<module>)
============================================================
Profile of UDF<id=3>
============================================================
Filename: /data/projects/spark/python/test.py
Line # Mem usage Increment Occurrences Line Contents
=============================================================
11 1212.1 MiB 1212.1 MiB 8 def map_func(iter:
Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
12 1212.1 MiB -0.2 MiB 24 for batch in iter:
13 1212.1 MiB -0.2 MiB 32 yield
pa.RecordBatch.from_arrays([pa.compute.add(batch.column("id"), 10)], ["id"])
```
### Why are the changes needed?
To add valuable profiling support to all types of UDFs.
### Does this PR introduce _any_ user-facing change?
Yes, iterator based Python UDFs can now be profiled with the SQL config
based profiler.
### How was this patch tested?
Updated the unit tests that previously verified this was unsupported so they
now verify that profiling iterator-based UDFs is supported.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #52853 from Kimahriman/udf-iter-profiler.
Authored-by: Adam Binford <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
(cherry picked from commit 6a369f953b0d134ed5acabe731005b234484b7ca)
Signed-off-by: Dongjoon Hyun <[email protected]>
---
python/pyspark/sql/tests/test_udf_profiler.py | 88 ++++++++++++-----
python/pyspark/tests/test_memory_profiler.py | 16 ++--
python/pyspark/worker.py | 106 +++++++++++++++------
.../v2/python/UserDefinedPythonDataSource.scala | 3 +-
.../python/MapInBatchEvaluatorFactory.scala | 5 +-
.../sql/execution/python/MapInBatchExec.scala | 3 +-
6 files changed, 157 insertions(+), 64 deletions(-)
diff --git a/python/pyspark/sql/tests/test_udf_profiler.py
b/python/pyspark/sql/tests/test_udf_profiler.py
index 4e8f722c22cb..37f4a70fabd2 100644
--- a/python/pyspark/sql/tests/test_udf_profiler.py
+++ b/python/pyspark/sql/tests/test_udf_profiler.py
@@ -28,6 +28,7 @@ from typing import Iterator, cast
from pyspark import SparkConf
from pyspark.errors import PySparkValueError
from pyspark.sql import SparkSession
+from pyspark.sql.datasource import DataSource, DataSourceReader
from pyspark.sql.functions import col, arrow_udf, pandas_udf, udf
from pyspark.sql.window import Window
from pyspark.profiler import UDFBasicProfiler
@@ -325,59 +326,47 @@ class UDFProfiler2TestsMixin:
not have_pandas or not have_pyarrow,
cast(str, pandas_requirement_message or pyarrow_requirement_message),
)
- def test_perf_profiler_pandas_udf_iterator_not_supported(self):
+ def test_perf_profiler_pandas_udf_iterator(self):
import pandas as pd
@pandas_udf("long")
- def add1(x):
- return x + 1
-
- @pandas_udf("long")
- def add2(iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
+ def add(iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
for s in iter:
- yield s + 2
+ yield s + 1
with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
- df = self.spark.range(10, numPartitions=2).select(
- add1("id"), add2("id"), add1("id"), add2(col("id") + 1)
- )
+ df = self.spark.range(10, numPartitions=2).select(add("id"))
df.collect()
self.assertEqual(1, len(self.profile_results),
str(self.profile_results.keys()))
for id in self.profile_results:
- self.assert_udf_profile_present(udf_id=id,
expected_line_count_prefix=2)
+ self.assert_udf_profile_present(udf_id=id,
expected_line_count_prefix=4)
@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
- def test_perf_profiler_arrow_udf_iterator_not_supported(self):
+ def test_perf_profiler_arrow_udf_iterator(self):
import pyarrow as pa
@arrow_udf("long")
- def add1(x):
- return pa.compute.add(x, 1)
-
- @arrow_udf("long")
- def add2(iter: Iterator[pa.Array]) -> Iterator[pa.Array]:
+ def add(iter: Iterator[pa.Array]) -> Iterator[pa.Array]:
for s in iter:
- yield pa.compute.add(s, 2)
+ yield pa.compute.add(s, 1)
with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
- df = self.spark.range(10, numPartitions=2).select(
- add1("id"), add2("id"), add1("id"), add2(col("id") + 1)
- )
+ df = self.spark.range(10, numPartitions=2).select(add("id"))
df.collect()
self.assertEqual(1, len(self.profile_results),
str(self.profile_results.keys()))
for id in self.profile_results:
- self.assert_udf_profile_present(udf_id=id,
expected_line_count_prefix=2)
+ self.assert_udf_profile_present(udf_id=id,
expected_line_count_prefix=4)
@unittest.skipIf(
not have_pandas or not have_pyarrow,
cast(str, pandas_requirement_message or pyarrow_requirement_message),
)
- def test_perf_profiler_map_in_pandas_not_supported(self):
- df = self.spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
+ def test_perf_profiler_map_in_pandas(self):
+ df = self.spark.createDataFrame([(1, 21), (2, 30)], ("id",
"age")).repartition(1)
def filter_func(iterator):
for pdf in iterator:
@@ -386,7 +375,28 @@ class UDFProfiler2TestsMixin:
with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
df.mapInPandas(filter_func, df.schema).show()
- self.assertEqual(0, len(self.profile_results),
str(self.profile_results.keys()))
+ self.assertEqual(1, len(self.profile_results),
str(self.profile_results.keys()))
+
+ for id in self.profile_results:
+ self.assert_udf_profile_present(udf_id=id,
expected_line_count_prefix=2)
+
+ @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
+ def test_perf_profiler_map_in_arrow(self):
+ import pyarrow as pa
+
+ df = self.spark.createDataFrame([(1, 21), (2, 30)], ("id",
"age")).repartition(1)
+
+ def map_func(iterator: Iterator[pa.RecordBatch]) ->
Iterator[pa.RecordBatch]:
+ for batch in iterator:
+ yield pa.RecordBatch.from_arrays(
+ [batch.column("id"), pa.compute.add(batch.column("age"),
1)], ["id", "age"]
+ )
+
+ with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
+ df.mapInArrow(map_func, df.schema).show()
+
+ for id in self.profile_results:
+ self.assert_udf_profile_present(udf_id=id,
expected_line_count_prefix=2)
@unittest.skipIf(
not have_pandas or not have_pyarrow,
@@ -575,6 +585,34 @@ class UDFProfiler2TestsMixin:
for id in self.profile_results:
self.assert_udf_profile_present(udf_id=id,
expected_line_count_prefix=2)
+ def test_perf_profiler_data_source(self):
+ class TestDataSourceReader(DataSourceReader):
+ def __init__(self, schema):
+ self.schema = schema
+
+ def partitions(self):
+ raise NotImplementedError
+
+ def read(self, partition):
+ yield from ((1,), (2,), (3,))
+
+ class TestDataSource(DataSource):
+ def schema(self):
+ return "id long"
+
+ def reader(self, schema) -> "DataSourceReader":
+ return TestDataSourceReader(schema)
+
+ self.spark.dataSource.register(TestDataSource)
+
+ with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
+ self.spark.read.format("TestDataSource").load().collect()
+
+ self.assertEqual(1, len(self.profile_results),
str(self.profile_results.keys()))
+
+ for id in self.profile_results:
+ self.assert_udf_profile_present(udf_id=id,
expected_line_count_prefix=4)
+
def test_perf_profiler_render(self):
with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
_do_computation(self.spark)
diff --git a/python/pyspark/tests/test_memory_profiler.py
b/python/pyspark/tests/test_memory_profiler.py
index df9d63c5260f..1909358aa2bc 100644
--- a/python/pyspark/tests/test_memory_profiler.py
+++ b/python/pyspark/tests/test_memory_profiler.py
@@ -341,12 +341,13 @@ class MemoryProfiler2TestsMixin:
not have_pandas or not have_pyarrow,
cast(str, pandas_requirement_message or pyarrow_requirement_message),
)
- def test_memory_profiler_pandas_udf_iterator_not_supported(self):
+ def test_memory_profiler_pandas_udf_iterator(self):
import pandas as pd
@pandas_udf("long")
- def add1(x):
- return x + 1
+ def add1(iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
+ for s in iter:
+ yield s + 1
@pandas_udf("long")
def add2(iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
@@ -359,7 +360,7 @@ class MemoryProfiler2TestsMixin:
)
df.collect()
- self.assertEqual(1, len(self.profile_results),
str(self.profile_results.keys()))
+ self.assertEqual(3, len(self.profile_results),
str(self.profile_results.keys()))
for id in self.profile_results:
self.assert_udf_memory_profile_present(udf_id=id)
@@ -368,7 +369,7 @@ class MemoryProfiler2TestsMixin:
not have_pandas or not have_pyarrow,
cast(str, pandas_requirement_message or pyarrow_requirement_message),
)
- def test_memory_profiler_map_in_pandas_not_supported(self):
+ def test_memory_profiler_map_in_pandas(self):
df = self.spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
def filter_func(iterator):
@@ -378,7 +379,10 @@ class MemoryProfiler2TestsMixin:
with self.sql_conf({"spark.sql.pyspark.udf.profiler": "memory"}):
df.mapInPandas(filter_func, df.schema).show()
- self.assertEqual(0, len(self.profile_results),
str(self.profile_results.keys()))
+ self.assertEqual(1, len(self.profile_results),
str(self.profile_results.keys()))
+
+ for id in self.profile_results:
+ self.assert_udf_memory_profile_present(udf_id=id)
@unittest.skipIf(
not have_pandas or not have_pyarrow,
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 09c6a40a33db..6e34b041665a 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -1158,17 +1158,19 @@ def wrap_kwargs_support(f, args_offsets,
kwargs_offsets):
return f, args_offsets
-def _supports_profiler(eval_type: int) -> bool:
- return eval_type not in (
+def _is_iter_based(eval_type: int) -> bool:
+ return eval_type in (
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
+ PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF,
+ PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF,
)
-def wrap_perf_profiler(f, result_id):
+def wrap_perf_profiler(f, eval_type, result_id):
import cProfile
import pstats
@@ -1178,38 +1180,89 @@ def wrap_perf_profiler(f, result_id):
SpecialAccumulatorIds.SQL_UDF_PROFIER, None, ProfileResultsParam
)
- def profiling_func(*args, **kwargs):
- with cProfile.Profile() as pr:
- ret = f(*args, **kwargs)
- st = pstats.Stats(pr)
- st.stream = None # make it picklable
- st.strip_dirs()
+ if _is_iter_based(eval_type):
+
+ def profiling_func(*args, **kwargs):
+ iterator = iter(f(*args, **kwargs))
+ pr = cProfile.Profile()
+ while True:
+ try:
+ with pr:
+ item = next(iterator)
+ yield item
+ except StopIteration:
+ break
+
+ st = pstats.Stats(pr)
+ st.stream = None # make it picklable
+ st.strip_dirs()
+
+ accumulator.add({result_id: (st, None)})
- accumulator.add({result_id: (st, None)})
+ else:
+
+ def profiling_func(*args, **kwargs):
+ with cProfile.Profile() as pr:
+ ret = f(*args, **kwargs)
+ st = pstats.Stats(pr)
+ st.stream = None # make it picklable
+ st.strip_dirs()
- return ret
+ accumulator.add({result_id: (st, None)})
+
+ return ret
return profiling_func
-def wrap_memory_profiler(f, result_id):
+def wrap_memory_profiler(f, eval_type, result_id):
from pyspark.sql.profiler import ProfileResultsParam
from pyspark.profiler import UDFLineProfilerV2
+ if not has_memory_profiler:
+ return f
+
accumulator = _deserialize_accumulator(
SpecialAccumulatorIds.SQL_UDF_PROFIER, None, ProfileResultsParam
)
- def profiling_func(*args, **kwargs):
- profiler = UDFLineProfilerV2()
+ if _is_iter_based(eval_type):
- wrapped = profiler(f)
- ret = wrapped(*args, **kwargs)
- codemap_dict = {
- filename: list(line_iterator) for filename, line_iterator in
profiler.code_map.items()
- }
- accumulator.add({result_id: (None, codemap_dict)})
- return ret
+ def profiling_func(*args, **kwargs):
+ profiler = UDFLineProfilerV2()
+ profiler.add_function(f)
+
+ iterator = iter(f(*args, **kwargs))
+
+ while True:
+ try:
+ with profiler:
+ item = next(iterator)
+ yield item
+ except StopIteration:
+ break
+
+ codemap_dict = {
+ filename: list(line_iterator)
+ for filename, line_iterator in profiler.code_map.items()
+ }
+ accumulator.add({result_id: (None, codemap_dict)})
+
+ else:
+
+ def profiling_func(*args, **kwargs):
+ profiler = UDFLineProfilerV2()
+ profiler.add_function(f)
+
+ with profiler:
+ ret = f(*args, **kwargs)
+
+ codemap_dict = {
+ filename: list(line_iterator)
+ for filename, line_iterator in profiler.code_map.items()
+ }
+ accumulator.add({result_id: (None, codemap_dict)})
+ return ret
return profiling_func
@@ -1254,17 +1307,12 @@ def read_single_udf(pickleSer, infile, eval_type,
runner_conf, udf_index, profil
if profiler == "perf":
result_id = read_long(infile)
- if _supports_profiler(eval_type):
- profiling_func = wrap_perf_profiler(chained_func, result_id)
- else:
- profiling_func = chained_func
+ profiling_func = wrap_perf_profiler(chained_func, eval_type, result_id)
elif profiler == "memory":
result_id = read_long(infile)
- if _supports_profiler(eval_type) and has_memory_profiler:
- profiling_func = wrap_memory_profiler(chained_func, result_id)
- else:
- profiling_func = chained_func
+
+ profiling_func = wrap_memory_profiler(chained_func, eval_type,
result_id)
else:
profiling_func = chained_func
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
index c147030037cd..47e64a5b4041 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
@@ -172,7 +172,8 @@ case class UserDefinedPythonDataSource(dataSourceCls:
PythonFunction) {
pythonRunnerConf,
metrics,
jobArtifactUUID,
- sessionUUID)
+ sessionUUID,
+ conf.pythonUDFProfiler)
}
def createPythonMetrics(): Array[CustomMetric] = {
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
index 4e78b3035a7e..51909df26a56 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
@@ -41,7 +41,8 @@ class MapInBatchEvaluatorFactory(
pythonRunnerConf: Map[String, String],
val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
- sessionUUID: Option[String])
+ sessionUUID: Option[String],
+ profiler: Option[String])
extends PartitionEvaluatorFactory[InternalRow, InternalRow] {
override def createEvaluator(): PartitionEvaluator[InternalRow, InternalRow]
=
@@ -74,7 +75,7 @@ class MapInBatchEvaluatorFactory(
pythonMetrics,
jobArtifactUUID,
sessionUUID,
- None) with BatchedPythonArrowInput
+ profiler) with BatchedPythonArrowInput
val columnarBatchIter = pyRunner.compute(batchIter,
context.partitionId(), context)
val unsafeProj = UnsafeProjection.create(output, output)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
index 1d03c0cf7603..c4f090674e7c 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
@@ -70,7 +70,8 @@ trait MapInBatchExec extends UnaryExecNode with
PythonSQLMetrics {
pythonRunnerConf,
pythonMetrics,
jobArtifactUUID,
- sessionUUID)
+ sessionUUID,
+ conf.pythonUDFProfiler)
val rdd = if (isBarrier) {
val rddBarrier = child.execute().barrier()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]