This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.1 by this push:
     new 705a3a2f9f81 [SPARK-54153][PYTHON] Support profiling iterator based Python UDFs
705a3a2f9f81 is described below

commit 705a3a2f9f81824fb983111d90d59f3c48e427a0
Author: Adam Binford <[email protected]>
AuthorDate: Sat Nov 8 07:45:14 2025 -0800

    [SPARK-54153][PYTHON] Support profiling iterator based Python UDFs
    
    ### What changes were proposed in this pull request?
    
    Updates the v2 SparkSession-based Python UDF profiler to support profiling iterator-based UDFs.
    
    ```python
    from collections.abc import Iterator
    from pstats import SortKey
    import pyarrow as pa
    
    df = spark.range(100000)
    
    def map_func(iter: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
        for batch in iter:
            yield pa.RecordBatch.from_arrays([pa.compute.add(batch.column("id"), 10)], ["id"])
    
    spark.conf.set('spark.sql.pyspark.udf.profiler', 'perf')
    df.mapInArrow(map_func, df.schema).collect()
    
    spark.conf.set('spark.sql.pyspark.udf.profiler', 'memory')
    df.mapInArrow(map_func, df.schema).collect()
    
    for stats in spark.profile.profiler_collector._perf_profile_results.values():
        stats.sort_stats(SortKey.CUMULATIVE).print_stats(20)
    
    spark.profile.show(type="memory")
    ```
    
    ```
    1395288 function calls (1359888 primitive calls) in 2.850 seconds
    
       Ordered by: cumulative time
       List reduced from 1546 to 20 due to restriction <20>
    
       ncalls  tottime  percall  cumtime  percall filename:lineno(function)
          416    0.008    0.000    5.901    0.014 __init__.py:1(<module>)
       424/24    0.000    0.000    2.850    0.119 {built-in method builtins.next}
           24    0.001    0.000    2.850    0.119 test.py:11(map_func)
           16    0.002    0.000    2.646    0.165 compute.py:244(wrapper)
      2752/24    0.016    0.000    2.642    0.110 <frozen importlib._bootstrap>:1349(_find_and_load)
      2752/24    0.013    0.000    2.641    0.110 <frozen importlib._bootstrap>:1304(_find_and_load_unlocked)
           64    0.002    0.000    2.618    0.041 api.py:1(<module>)
      2704/24    0.009    0.000    2.612    0.109 <frozen importlib._bootstrap>:911(_load_unlocked)
      2264/24    0.005    0.000    2.611    0.109 <frozen importlib._bootstrap_external>:993(exec_module)
      6336/48    0.004    0.000    2.591    0.054 <frozen importlib._bootstrap>:480(_call_with_frames_removed)
      2400/24    0.023    0.000    2.591    0.108 {built-in method builtins.exec}
           24    0.002    0.000    1.927    0.080 generic.py:1(<module>)
      520/320    0.002    0.000    1.429    0.004 {built-in method builtins.__import__}
    4312/2896    0.006    0.000    1.190    0.000 <frozen importlib._bootstrap>:1390(_handle_fromlist)
            8    0.001    0.000    1.014    0.127 frame.py:1(<module>)
    4392/4312    0.069    0.000    0.953    0.000 {built-in method builtins.__build_class__}
         2264    0.030    0.000    0.562    0.000 <frozen importlib._bootstrap_external>:1066(get_code)
           16    0.001    0.000    0.551    0.034 indexing.py:1(<module>)
           24    0.001    0.000    0.451    0.019 datetimes.py:1(<module>)
           16    0.001    0.000    0.401    0.025 datetimelike.py:1(<module>)
    
    ============================================================
    Profile of UDF<id=3>
    ============================================================
    Filename: /data/projects/spark/python/test.py
    
    Line #    Mem usage    Increment  Occurrences   Line Contents
    =============================================================
        11   1212.1 MiB   1212.1 MiB           8   def map_func(iter: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
        12   1212.1 MiB     -0.2 MiB          24       for batch in iter:
        13   1212.1 MiB     -0.2 MiB          32           yield pa.RecordBatch.from_arrays([pa.compute.add(batch.column("id"), 10)], ["id"])
    ```
    
    ### Why are the changes needed?
    
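    To extend profiling support to all types of UDFs. Previously, iterator-based UDFs (iterator pandas/Arrow UDFs, `mapInPandas`, `mapInArrow`) were silently skipped by the profiler, and no profiler setting was passed to Python data source reads.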
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, iterator-based Python UDFs can now be profiled with the SQL-config-based profiler (`spark.sql.pyspark.udf.profiler`).
    
    ### How was this patch tested?
    
    Updated the unit tests that previously asserted iterator-based UDFs were not profiled so that they now verify profiles are produced. Added new tests for `mapInArrow` and Python data sources.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #52853 from Kimahriman/udf-iter-profiler.
    
    Authored-by: Adam Binford <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
    (cherry picked from commit 6a369f953b0d134ed5acabe731005b234484b7ca)
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 python/pyspark/sql/tests/test_udf_profiler.py      |  88 ++++++++++++-----
 python/pyspark/tests/test_memory_profiler.py       |  16 ++--
 python/pyspark/worker.py                           | 106 +++++++++++++++------
 .../v2/python/UserDefinedPythonDataSource.scala    |   3 +-
 .../python/MapInBatchEvaluatorFactory.scala        |   5 +-
 .../sql/execution/python/MapInBatchExec.scala      |   3 +-
 6 files changed, 157 insertions(+), 64 deletions(-)

diff --git a/python/pyspark/sql/tests/test_udf_profiler.py 
b/python/pyspark/sql/tests/test_udf_profiler.py
index 4e8f722c22cb..37f4a70fabd2 100644
--- a/python/pyspark/sql/tests/test_udf_profiler.py
+++ b/python/pyspark/sql/tests/test_udf_profiler.py
@@ -28,6 +28,7 @@ from typing import Iterator, cast
 from pyspark import SparkConf
 from pyspark.errors import PySparkValueError
 from pyspark.sql import SparkSession
+from pyspark.sql.datasource import DataSource, DataSourceReader
 from pyspark.sql.functions import col, arrow_udf, pandas_udf, udf
 from pyspark.sql.window import Window
 from pyspark.profiler import UDFBasicProfiler
@@ -325,59 +326,47 @@ class UDFProfiler2TestsMixin:
         not have_pandas or not have_pyarrow,
         cast(str, pandas_requirement_message or pyarrow_requirement_message),
     )
-    def test_perf_profiler_pandas_udf_iterator_not_supported(self):
+    def test_perf_profiler_pandas_udf_iterator(self):
         import pandas as pd
 
         @pandas_udf("long")
-        def add1(x):
-            return x + 1
-
-        @pandas_udf("long")
-        def add2(iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
+        def add(iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
             for s in iter:
-                yield s + 2
+                yield s + 1
 
         with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
-            df = self.spark.range(10, numPartitions=2).select(
-                add1("id"), add2("id"), add1("id"), add2(col("id") + 1)
-            )
+            df = self.spark.range(10, numPartitions=2).select(add("id"))
             df.collect()
 
         self.assertEqual(1, len(self.profile_results), 
str(self.profile_results.keys()))
 
         for id in self.profile_results:
-            self.assert_udf_profile_present(udf_id=id, 
expected_line_count_prefix=2)
+            self.assert_udf_profile_present(udf_id=id, 
expected_line_count_prefix=4)
 
     @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
-    def test_perf_profiler_arrow_udf_iterator_not_supported(self):
+    def test_perf_profiler_arrow_udf_iterator(self):
         import pyarrow as pa
 
         @arrow_udf("long")
-        def add1(x):
-            return pa.compute.add(x, 1)
-
-        @arrow_udf("long")
-        def add2(iter: Iterator[pa.Array]) -> Iterator[pa.Array]:
+        def add(iter: Iterator[pa.Array]) -> Iterator[pa.Array]:
             for s in iter:
-                yield pa.compute.add(s, 2)
+                yield pa.compute.add(s, 1)
 
         with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
-            df = self.spark.range(10, numPartitions=2).select(
-                add1("id"), add2("id"), add1("id"), add2(col("id") + 1)
-            )
+            df = self.spark.range(10, numPartitions=2).select(add("id"))
             df.collect()
 
         self.assertEqual(1, len(self.profile_results), 
str(self.profile_results.keys()))
 
         for id in self.profile_results:
-            self.assert_udf_profile_present(udf_id=id, 
expected_line_count_prefix=2)
+            self.assert_udf_profile_present(udf_id=id, 
expected_line_count_prefix=4)
 
     @unittest.skipIf(
         not have_pandas or not have_pyarrow,
         cast(str, pandas_requirement_message or pyarrow_requirement_message),
     )
-    def test_perf_profiler_map_in_pandas_not_supported(self):
-        df = self.spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
+    def test_perf_profiler_map_in_pandas(self):
+        df = self.spark.createDataFrame([(1, 21), (2, 30)], ("id", 
"age")).repartition(1)
 
         def filter_func(iterator):
             for pdf in iterator:
@@ -386,7 +375,28 @@ class UDFProfiler2TestsMixin:
         with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
             df.mapInPandas(filter_func, df.schema).show()
 
-        self.assertEqual(0, len(self.profile_results), 
str(self.profile_results.keys()))
+        self.assertEqual(1, len(self.profile_results), 
str(self.profile_results.keys()))
+
+        for id in self.profile_results:
+            self.assert_udf_profile_present(udf_id=id, 
expected_line_count_prefix=2)
+
+    @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
+    def test_perf_profiler_map_in_arrow(self):
+        import pyarrow as pa
+
+        df = self.spark.createDataFrame([(1, 21), (2, 30)], ("id", 
"age")).repartition(1)
+
+        def map_func(iterator: Iterator[pa.RecordBatch]) -> 
Iterator[pa.RecordBatch]:
+            for batch in iterator:
+                yield pa.RecordBatch.from_arrays(
+                    [batch.column("id"), pa.compute.add(batch.column("age"), 
1)], ["id", "age"]
+                )
+
+        with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
+            df.mapInArrow(map_func, df.schema).show()
+
+        for id in self.profile_results:
+            self.assert_udf_profile_present(udf_id=id, 
expected_line_count_prefix=2)
 
     @unittest.skipIf(
         not have_pandas or not have_pyarrow,
@@ -575,6 +585,34 @@ class UDFProfiler2TestsMixin:
         for id in self.profile_results:
             self.assert_udf_profile_present(udf_id=id, 
expected_line_count_prefix=2)
 
+    def test_perf_profiler_data_source(self):
+        class TestDataSourceReader(DataSourceReader):
+            def __init__(self, schema):
+                self.schema = schema
+
+            def partitions(self):
+                raise NotImplementedError
+
+            def read(self, partition):
+                yield from ((1,), (2,), (3,))
+
+        class TestDataSource(DataSource):
+            def schema(self):
+                return "id long"
+
+            def reader(self, schema) -> "DataSourceReader":
+                return TestDataSourceReader(schema)
+
+        self.spark.dataSource.register(TestDataSource)
+
+        with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
+            self.spark.read.format("TestDataSource").load().collect()
+
+        self.assertEqual(1, len(self.profile_results), 
str(self.profile_results.keys()))
+
+        for id in self.profile_results:
+            self.assert_udf_profile_present(udf_id=id, 
expected_line_count_prefix=4)
+
     def test_perf_profiler_render(self):
         with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
             _do_computation(self.spark)
diff --git a/python/pyspark/tests/test_memory_profiler.py 
b/python/pyspark/tests/test_memory_profiler.py
index df9d63c5260f..1909358aa2bc 100644
--- a/python/pyspark/tests/test_memory_profiler.py
+++ b/python/pyspark/tests/test_memory_profiler.py
@@ -341,12 +341,13 @@ class MemoryProfiler2TestsMixin:
         not have_pandas or not have_pyarrow,
         cast(str, pandas_requirement_message or pyarrow_requirement_message),
     )
-    def test_memory_profiler_pandas_udf_iterator_not_supported(self):
+    def test_memory_profiler_pandas_udf_iterator(self):
         import pandas as pd
 
         @pandas_udf("long")
-        def add1(x):
-            return x + 1
+        def add1(iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
+            for s in iter:
+                yield s + 1
 
         @pandas_udf("long")
         def add2(iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
@@ -359,7 +360,7 @@ class MemoryProfiler2TestsMixin:
             )
             df.collect()
 
-        self.assertEqual(1, len(self.profile_results), 
str(self.profile_results.keys()))
+        self.assertEqual(3, len(self.profile_results), 
str(self.profile_results.keys()))
 
         for id in self.profile_results:
             self.assert_udf_memory_profile_present(udf_id=id)
@@ -368,7 +369,7 @@ class MemoryProfiler2TestsMixin:
         not have_pandas or not have_pyarrow,
         cast(str, pandas_requirement_message or pyarrow_requirement_message),
     )
-    def test_memory_profiler_map_in_pandas_not_supported(self):
+    def test_memory_profiler_map_in_pandas(self):
         df = self.spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
 
         def filter_func(iterator):
@@ -378,7 +379,10 @@ class MemoryProfiler2TestsMixin:
         with self.sql_conf({"spark.sql.pyspark.udf.profiler": "memory"}):
             df.mapInPandas(filter_func, df.schema).show()
 
-        self.assertEqual(0, len(self.profile_results), 
str(self.profile_results.keys()))
+        self.assertEqual(1, len(self.profile_results), 
str(self.profile_results.keys()))
+
+        for id in self.profile_results:
+            self.assert_udf_memory_profile_present(udf_id=id)
 
     @unittest.skipIf(
         not have_pandas or not have_pyarrow,
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 09c6a40a33db..6e34b041665a 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -1158,17 +1158,19 @@ def wrap_kwargs_support(f, args_offsets, 
kwargs_offsets):
         return f, args_offsets
 
 
-def _supports_profiler(eval_type: int) -> bool:
-    return eval_type not in (
+def _is_iter_based(eval_type: int) -> bool:
+    return eval_type in (
         PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
         PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
         PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
         PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
         PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
+        PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF,
+        PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF,
     )
 
 
-def wrap_perf_profiler(f, result_id):
+def wrap_perf_profiler(f, eval_type, result_id):
     import cProfile
     import pstats
 
@@ -1178,38 +1180,89 @@ def wrap_perf_profiler(f, result_id):
         SpecialAccumulatorIds.SQL_UDF_PROFIER, None, ProfileResultsParam
     )
 
-    def profiling_func(*args, **kwargs):
-        with cProfile.Profile() as pr:
-            ret = f(*args, **kwargs)
-        st = pstats.Stats(pr)
-        st.stream = None  # make it picklable
-        st.strip_dirs()
+    if _is_iter_based(eval_type):
+
+        def profiling_func(*args, **kwargs):
+            iterator = iter(f(*args, **kwargs))
+            pr = cProfile.Profile()
+            while True:
+                try:
+                    with pr:
+                        item = next(iterator)
+                    yield item
+                except StopIteration:
+                    break
+
+            st = pstats.Stats(pr)
+            st.stream = None  # make it picklable
+            st.strip_dirs()
+
+            accumulator.add({result_id: (st, None)})
 
-        accumulator.add({result_id: (st, None)})
+    else:
+
+        def profiling_func(*args, **kwargs):
+            with cProfile.Profile() as pr:
+                ret = f(*args, **kwargs)
+            st = pstats.Stats(pr)
+            st.stream = None  # make it picklable
+            st.strip_dirs()
 
-        return ret
+            accumulator.add({result_id: (st, None)})
+
+            return ret
 
     return profiling_func
 
 
-def wrap_memory_profiler(f, result_id):
+def wrap_memory_profiler(f, eval_type, result_id):
     from pyspark.sql.profiler import ProfileResultsParam
     from pyspark.profiler import UDFLineProfilerV2
 
+    if not has_memory_profiler:
+        return f
+
     accumulator = _deserialize_accumulator(
         SpecialAccumulatorIds.SQL_UDF_PROFIER, None, ProfileResultsParam
     )
 
-    def profiling_func(*args, **kwargs):
-        profiler = UDFLineProfilerV2()
+    if _is_iter_based(eval_type):
 
-        wrapped = profiler(f)
-        ret = wrapped(*args, **kwargs)
-        codemap_dict = {
-            filename: list(line_iterator) for filename, line_iterator in 
profiler.code_map.items()
-        }
-        accumulator.add({result_id: (None, codemap_dict)})
-        return ret
+        def profiling_func(*args, **kwargs):
+            profiler = UDFLineProfilerV2()
+            profiler.add_function(f)
+
+            iterator = iter(f(*args, **kwargs))
+
+            while True:
+                try:
+                    with profiler:
+                        item = next(iterator)
+                    yield item
+                except StopIteration:
+                    break
+
+            codemap_dict = {
+                filename: list(line_iterator)
+                for filename, line_iterator in profiler.code_map.items()
+            }
+            accumulator.add({result_id: (None, codemap_dict)})
+
+    else:
+
+        def profiling_func(*args, **kwargs):
+            profiler = UDFLineProfilerV2()
+            profiler.add_function(f)
+
+            with profiler:
+                ret = f(*args, **kwargs)
+
+            codemap_dict = {
+                filename: list(line_iterator)
+                for filename, line_iterator in profiler.code_map.items()
+            }
+            accumulator.add({result_id: (None, codemap_dict)})
+            return ret
 
     return profiling_func
 
@@ -1254,17 +1307,12 @@ def read_single_udf(pickleSer, infile, eval_type, 
runner_conf, udf_index, profil
     if profiler == "perf":
         result_id = read_long(infile)
 
-        if _supports_profiler(eval_type):
-            profiling_func = wrap_perf_profiler(chained_func, result_id)
-        else:
-            profiling_func = chained_func
+        profiling_func = wrap_perf_profiler(chained_func, eval_type, result_id)
 
     elif profiler == "memory":
         result_id = read_long(infile)
-        if _supports_profiler(eval_type) and has_memory_profiler:
-            profiling_func = wrap_memory_profiler(chained_func, result_id)
-        else:
-            profiling_func = chained_func
+
+        profiling_func = wrap_memory_profiler(chained_func, eval_type, 
result_id)
     else:
         profiling_func = chained_func
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
index c147030037cd..47e64a5b4041 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
@@ -172,7 +172,8 @@ case class UserDefinedPythonDataSource(dataSourceCls: 
PythonFunction) {
       pythonRunnerConf,
       metrics,
       jobArtifactUUID,
-      sessionUUID)
+      sessionUUID,
+      conf.pythonUDFProfiler)
   }
 
   def createPythonMetrics(): Array[CustomMetric] = {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
index 4e78b3035a7e..51909df26a56 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
@@ -41,7 +41,8 @@ class MapInBatchEvaluatorFactory(
     pythonRunnerConf: Map[String, String],
     val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
-    sessionUUID: Option[String])
+    sessionUUID: Option[String],
+    profiler: Option[String])
   extends PartitionEvaluatorFactory[InternalRow, InternalRow] {
 
   override def createEvaluator(): PartitionEvaluator[InternalRow, InternalRow] 
=
@@ -74,7 +75,7 @@ class MapInBatchEvaluatorFactory(
         pythonMetrics,
         jobArtifactUUID,
         sessionUUID,
-        None) with BatchedPythonArrowInput
+        profiler) with BatchedPythonArrowInput
       val columnarBatchIter = pyRunner.compute(batchIter, 
context.partitionId(), context)
 
       val unsafeProj = UnsafeProjection.create(output, output)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
index 1d03c0cf7603..c4f090674e7c 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
@@ -70,7 +70,8 @@ trait MapInBatchExec extends UnaryExecNode with 
PythonSQLMetrics {
       pythonRunnerConf,
       pythonMetrics,
       jobArtifactUUID,
-      sessionUUID)
+      sessionUUID,
+      conf.pythonUDFProfiler)
 
     val rdd = if (isBarrier) {
       val rddBarrier = child.execute().barrier()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to