This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d8703dd1c7bb [SPARK-46686][PYTHON][CONNECT] Basic support of 
SparkSession based Python UDF profiler
d8703dd1c7bb is described below

commit d8703dd1c7bbdac24dbcbf3e3eeb9c54595d53c3
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Thu Jan 18 08:59:27 2024 +0900

    [SPARK-46686][PYTHON][CONNECT] Basic support of SparkSession based Python 
UDF profiler
    
    ### What changes were proposed in this pull request?
    
    This PR adds basic support for a SparkSession-based Python UDF profiler.
    
    To enable the profiler, set the SQL conf `spark.sql.pyspark.udf.profiler`:
    
    - `"perf"`: enable cProfiler
    - `"memory"`: enable memory-profiler (TODO: 
[SPARK-46687](https://issues.apache.org/jira/browse/SPARK-46687))
    
    ```py
    from pyspark.sql.functions import *
    
    spark.conf.set("spark.sql.pyspark.udf.profiler", "perf")  # enable cProfiler
    
    udf("string")
    def f(x):
          return str(x)
    
    df = spark.range(10).select(f(col("id")))
    df.collect()
    
    @pandas_udf("string")
    def g(x):
        return x.astype("string")
    
    df = spark.range(10).select(g(col("id")))
    
    spark.conf.unset("spark.sql.pyspark.udf.profiler")  # disable
    
    df.collect()  # won't be profiled

    spark.showPerfProfiles()  # shows the results only for the first collect.
    ```
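
    As a usage sketch mirroring the new tests in `test_udf_profiler.py`, the results can also be shown per UDF id; `_profiler_collector` and `_perf_profile_results` are internal APIs and are used here only for illustration:

    ```py
    # Show the perf profile of each profiled UDF individually.
    # `_profiler_collector` / `_perf_profile_results` are internal; see test_udf_profiler.py.
    for udf_id in spark._profiler_collector._perf_profile_results:
        spark.showPerfProfiles(udf_id)
    ```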
    
    ### Why are the changes needed?
    
    The existing UDF profilers are SparkContext-based, so they can't support Spark Connect.
    
    We should introduce SparkSession-based profilers so that Spark Connect can be supported as well.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, SparkSession-based UDF profilers will be available.
    
    ### How was this patch tested?
    
    Added related tests, tested manually, and ran the existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #44697 from ueshin/issues/SPARK-46686/profiler.
    
    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../connect/execution/ExecuteThreadRunner.scala    |  22 +-
 .../sql/connect/planner/SparkConnectPlanner.scala  |   8 +-
 .../spark/sql/connect/service/SessionHolder.scala  |  10 +
 .../org/apache/spark/api/python/PythonRDD.scala    |  11 +-
 .../org/apache/spark/api/python/PythonRunner.scala |   5 +-
 .../spark/api/python/PythonWorkerUtils.scala       |   4 +-
 dev/sparktestsupport/modules.py                    |   1 +
 python/pyspark/accumulators.py                     |   4 +
 python/pyspark/profiler.py                         |   3 +-
 python/pyspark/sql/_typing.pyi                     |   5 +
 python/pyspark/sql/connect/client/core.py          |  13 +-
 python/pyspark/sql/connect/profiler.py             |  41 ++++
 python/pyspark/sql/connect/session.py              |  10 +
 python/pyspark/sql/profiler.py                     | 176 ++++++++++++++++
 python/pyspark/sql/session.py                      |   8 +
 .../sql/tests/connect/test_parity_udf_profiler.py  |  59 ++++++
 python/pyspark/sql/tests/test_udf_profiler.py      | 232 +++++++++++++++++++--
 python/pyspark/worker.py                           |  97 ++++++++-
 .../org/apache/spark/sql/internal/SQLConf.scala    |  13 ++
 .../v2/python/UserDefinedPythonDataSource.scala    |   2 +-
 .../execution/python/AggregateInPandasExec.scala   |  12 +-
 .../ApplyInPandasWithStatePythonRunner.scala       |   8 +-
 .../sql/execution/python/ArrowEvalPythonExec.scala |  11 +-
 .../sql/execution/python/ArrowPythonRunner.scala   |  20 +-
 .../sql/execution/python/BatchEvalPythonExec.scala |  12 +-
 .../execution/python/BatchEvalPythonUDTFExec.scala |   2 +-
 .../python/CoGroupedArrowPythonRunner.scala        |  11 +-
 .../python/EvalPythonEvaluatorFactory.scala        |  11 +-
 .../python/FlatMapCoGroupsInBatchExec.scala        |   9 +-
 .../python/FlatMapGroupsInBatchExec.scala          |   9 +-
 .../FlatMapGroupsInPandasWithStateExec.scala       |   6 +-
 .../python/MapInBatchEvaluatorFactory.scala        |   7 +-
 .../sql/execution/python/MapInBatchExec.scala      |   5 +-
 .../sql/execution/python/PythonUDFRunner.scala     |  52 +++--
 .../python/WindowInPandasEvaluatorFactory.scala    |  14 +-
 .../sql/execution/python/WindowInPandasExec.scala  |   3 +-
 36 files changed, 797 insertions(+), 119 deletions(-)

diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
index 0ecdc4bdef96..41146e4ef688 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.connect.execution
 
+import scala.jdk.CollectionConverters._
 import scala.util.control.NonFatal
 
 import com.google.protobuf.Message
@@ -185,19 +186,34 @@ private[connect] class ExecuteThreadRunner(executeHolder: 
ExecuteHolder) extends
             s"${executeHolder.request.getPlan.getOpTypeCase} not supported.")
       }
 
-      if (executeHolder.observations.nonEmpty) {
-        val observedMetrics = executeHolder.observations.map { case (name, 
observation) =>
+      val observedMetrics: Map[String, Seq[(Option[String], Any)]] = {
+        executeHolder.observations.map { case (name, observation) =>
           val values = observation.getOrEmpty.map { case (key, value) =>
             (Some(key), value)
           }.toSeq
           name -> values
         }.toMap
+      }
+      val accumulatedInPython: Map[String, Seq[(Option[String], Any)]] = {
+        executeHolder.sessionHolder.pythonAccumulator.flatMap { accumulator =>
+          accumulator.synchronized {
+            val value = accumulator.value.asScala.toSeq
+            if (value.nonEmpty) {
+              accumulator.reset()
+              Some("__python_accumulator__" -> value.map(value => (None, 
value)))
+            } else {
+              None
+            }
+          }
+        }.toMap
+      }
+      if (observedMetrics.nonEmpty || accumulatedInPython.nonEmpty) {
         executeHolder.responseObserver.onNext(
           SparkConnectPlanExecution
             .createObservedMetricsResponse(
               executeHolder.sessionHolder.sessionId,
               executeHolder.sessionHolder.serverSessionId,
-              observedMetrics))
+              observedMetrics ++ accumulatedInPython))
       }
 
       lock.synchronized {
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 25c78413170e..5be79a090a38 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -972,8 +972,8 @@ class SparkConnectPlanner(
       pythonVer = fun.getPythonVer,
       // Empty broadcast variables
       broadcastVars = Lists.newArrayList(),
-      // Null accumulator
-      accumulator = null)
+      // Accumulator if available
+      accumulator = sessionHolder.pythonAccumulator.orNull)
   }
 
   private def transformCachedRemoteRelation(rel: proto.CachedRemoteRelation): 
LogicalPlan = {
@@ -1680,8 +1680,8 @@ class SparkConnectPlanner(
       pythonVer = fun.getPythonVer,
       // Empty broadcast variables
       broadcastVars = Lists.newArrayList(),
-      // Null accumulator
-      accumulator = null)
+      // Accumulator if available
+      accumulator = sessionHolder.pythonAccumulator.orNull)
   }
 
   /**
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
index 427b1a50588c..ef79cdcce8ff 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -24,11 +24,13 @@ import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable
 import scala.jdk.CollectionConverters._
+import scala.util.Try
 
 import com.google.common.base.Ticker
 import com.google.common.cache.CacheBuilder
 
 import org.apache.spark.{SparkException, SparkSQLException}
+import org.apache.spark.api.python.PythonFunction.PythonAccumulator
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.SparkSession
@@ -371,6 +373,14 @@ case class SessionHolder(userId: String, sessionId: 
String, session: SparkSessio
   private[connect] def listListenerIds(): Seq[String] = {
     listenerCache.keySet().asScala.toSeq
   }
+
+  /**
+   * An accumulator for Python executors.
+   *
+   * The accumulated results will be sent to the Python client via 
observed_metrics message.
+   */
+  private[connect] val pythonAccumulator: Option[PythonAccumulator] =
+    Try(session.sparkContext.collectionAccumulator[Array[Byte]]).toOption
 }
 
 object SessionHolder {
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 1d0c905164ad..5aa080b5fb29 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -34,6 +34,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => 
NewInputFormat, OutputFormat
 
 import org.apache.spark._
 import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
+import org.apache.spark.api.python.PythonFunction.PythonAccumulator
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.input.PortableDataStream
 import org.apache.spark.internal.Logging
@@ -83,7 +84,11 @@ private[spark] trait PythonFunction {
   def pythonExec: String
   def pythonVer: String
   def broadcastVars: JList[Broadcast[PythonBroadcast]]
-  def accumulator: PythonAccumulatorV2
+  def accumulator: PythonAccumulator
+}
+
+private[spark] object PythonFunction {
+  type PythonAccumulator = CollectionAccumulator[Array[Byte]]
 }
 
 /**
@@ -96,7 +101,7 @@ private[spark] case class SimplePythonFunction(
     pythonExec: String,
     pythonVer: String,
     broadcastVars: JList[Broadcast[PythonBroadcast]],
-    accumulator: PythonAccumulatorV2) extends PythonFunction {
+    accumulator: PythonAccumulator) extends PythonFunction {
 
   def this(
       command: Array[Byte],
@@ -105,7 +110,7 @@ private[spark] case class SimplePythonFunction(
       pythonExec: String,
       pythonVer: String,
       broadcastVars: JList[Broadcast[PythonBroadcast]],
-      accumulator: PythonAccumulatorV2) = {
+      accumulator: PythonAccumulator) = {
     this(command.toImmutableArraySeq,
       envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, 
accumulator)
   }
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 148f80540d96..17cb0c5a55dd 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -30,6 +30,7 @@ import scala.jdk.CollectionConverters._
 import scala.util.control.NonFatal
 
 import org.apache.spark._
+import org.apache.spark.api.python.PythonFunction.PythonAccumulator
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES, Python}
 import org.apache.spark.internal.config.Python._
@@ -146,10 +147,10 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     }.getOrElse("pyspark.worker")
 
   // TODO: support accumulator in multiple UDF
-  protected val accumulator: PythonAccumulatorV2 = 
funcs.head.funcs.head.accumulator
+  protected val accumulator: PythonAccumulator = 
funcs.head.funcs.head.accumulator
 
   // Python accumulator is always set in production except in tests. See 
SPARK-27893
-  private val maybeAccumulator: Option[PythonAccumulatorV2] = 
Option(accumulator)
+  private val maybeAccumulator: Option[PythonAccumulator] = Option(accumulator)
 
   // Expose a ServerSocket to support method calls via socket from Python 
side. Only relevant for
   // for tasks that are a part of barrier stage, refer [[BarrierTaskContext]] 
for details.
diff --git 
a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala
index 3d0687553bb3..ae3614445be6 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala
@@ -21,6 +21,7 @@ import java.io.{DataInputStream, DataOutputStream, File}
 import java.nio.charset.StandardCharsets
 
 import org.apache.spark.{SparkEnv, SparkFiles}
+import org.apache.spark.api.python.PythonFunction.PythonAccumulator
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 
@@ -186,7 +187,8 @@ private[spark] object PythonWorkerUtils extends Logging {
    * The updates are sent by `worker_util.send_accumulator_updates`.
    */
   def receiveAccumulatorUpdates(
-      maybeAccumulator: Option[PythonAccumulatorV2], dataIn: DataInputStream): 
Unit = {
+      maybeAccumulator: Option[PythonAccumulator],
+      dataIn: DataInputStream): Unit = {
     val numAccumulatorUpdates = dataIn.readInt()
     (1 to numAccumulatorUpdates).foreach { _ =>
       val update = readBytes(dataIn)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index ad164b1a8636..0500bf38ea8e 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -998,6 +998,7 @@ pyspark_connect = Module(
         "pyspark.sql.tests.connect.test_parity_column",
         "pyspark.sql.tests.connect.test_parity_readwriter",
         "pyspark.sql.tests.connect.test_parity_udf",
+        "pyspark.sql.tests.connect.test_parity_udf_profiler",
         "pyspark.sql.tests.connect.test_parity_udtf",
         "pyspark.sql.tests.connect.test_parity_pandas_udf",
         "pyspark.sql.tests.connect.test_parity_pandas_map",
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index a95bd9debfc2..4f61a9fbd9f7 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -57,6 +57,10 @@ def _deserialize_accumulator(
         return accum
 
 
+class SpecialAccumulatorIds:
+    SQL_UDF_PROFIER = -1
+
+
 class Accumulator(Generic[T]):
 
     """
diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py
index b7ea6a19063b..b5f1bc4d714d 100644
--- a/python/pyspark/profiler.py
+++ b/python/pyspark/profiler.py
@@ -422,8 +422,9 @@ class MemoryProfiler(Profiler):
         """Return the collected memory profiles"""
         return cast(CodeMapDict, self._accumulator.value)
 
+    @staticmethod
     def _show_results(
-        self, code_map: CodeMapDict, stream: Optional[Any] = None, precision: 
int = 1
+        code_map: CodeMapDict, stream: Optional[Any] = None, precision: int = 1
     ) -> None:
         if stream is None:
             stream = sys.stdout
diff --git a/python/pyspark/sql/_typing.pyi b/python/pyspark/sql/_typing.pyi
index cee44c4aa069..b696eea7293f 100644
--- a/python/pyspark/sql/_typing.pyi
+++ b/python/pyspark/sql/_typing.pyi
@@ -19,6 +19,7 @@
 from typing import (
     Any,
     Callable,
+    Dict,
     List,
     Optional,
     Tuple,
@@ -29,8 +30,10 @@ from typing_extensions import Literal, Protocol
 
 import datetime
 import decimal
+import pstats
 
 from pyspark._typing import PrimitiveType
+from pyspark.profiler import CodeMapDict
 import pyspark.sql.types
 from pyspark.sql.column import Column
 
@@ -79,3 +82,5 @@ class UserDefinedFunctionLike(Protocol):
     def returnType(self) -> pyspark.sql.types.DataType: ...
     def __call__(self, *args: ColumnOrName) -> Column: ...
     def asNondeterministic(self) -> UserDefinedFunctionLike: ...
+
+ProfileResults = Dict[int, Tuple[Optional[pstats.Stats], 
Optional[CodeMapDict]]]
diff --git a/python/pyspark/sql/connect/client/core.py 
b/python/pyspark/sql/connect/client/core.py
index 10235fd7d6c4..c1c046e93708 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -54,11 +54,13 @@ import grpc
 from google.protobuf import text_format, any_pb2
 from google.rpc import error_details_pb2
 
+from pyspark.accumulators import SpecialAccumulatorIds
 from pyspark.loose_version import LooseVersion
 from pyspark.version import __version__
 from pyspark.resource.information import ResourceInformation
 from pyspark.sql.connect.client.artifact import ArtifactManager
 from pyspark.sql.connect.client.logging import logger
+from pyspark.sql.connect.profiler import ConnectProfilerCollector
 from pyspark.sql.connect.client.reattach import 
ExecutePlanResponseReattachableIterator
 from pyspark.sql.connect.client.retries import RetryPolicy, Retrying, 
DefaultPolicy
 from pyspark.sql.connect.conversion import storage_level_to_proto, 
proto_to_storage_level
@@ -636,6 +638,8 @@ class SparkConnectClient(object):
         # be updated on the first response received.
         self._server_session_id: Optional[str] = None
 
+        self._profiler_collector = ConnectProfilerCollector()
+
     def _retrying(self) -> "Retrying":
         return Retrying(self._retry_policies)
 
@@ -1169,7 +1173,14 @@ class SparkConnectClient(object):
             if b.observed_metrics:
                 logger.debug("Received observed metric batch.")
                 for observed_metrics in 
self._build_observed_metrics(b.observed_metrics):
-                    if observed_metrics.name in observations:
+                    if observed_metrics.name == "__python_accumulator__":
+                        from pyspark.worker_util import pickleSer
+
+                        for metric in observed_metrics.metrics:
+                            (aid, update) = 
pickleSer.loads(LiteralExpression._to_value(metric))
+                            if aid == SpecialAccumulatorIds.SQL_UDF_PROFIER:
+                                self._profiler_collector._update(update)
+                    elif observed_metrics.name in observations:
                         observation_result = 
observations[observed_metrics.name]._result
                         assert observation_result is not None
                         observation_result.update(
diff --git a/python/pyspark/sql/connect/profiler.py 
b/python/pyspark/sql/connect/profiler.py
new file mode 100644
index 000000000000..b8825cf5678e
--- /dev/null
+++ b/python/pyspark/sql/connect/profiler.py
@@ -0,0 +1,41 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import TYPE_CHECKING
+
+from pyspark.sql.profiler import ProfilerCollector, ProfileResultsParam
+
+if TYPE_CHECKING:
+    from pyspark.sql._typing import ProfileResults
+
+
+class ConnectProfilerCollector(ProfilerCollector):
+    """
+    ProfilerCollector for Spark Connect.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._value = ProfileResultsParam.zero(None)
+
+    @property
+    def _profile_results(self) -> "ProfileResults":
+        with self._lock:
+            return self._value if self._value is not None else {}
+
+    def _update(self, update: "ProfileResults") -> None:
+        with self._lock:
+            self._value = 
ProfileResultsParam.addInPlace(self._profile_results, update)
diff --git a/python/pyspark/sql/connect/session.py 
b/python/pyspark/sql/connect/session.py
index a27e6fa4b729..5cbcb4ab5c35 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -62,6 +62,7 @@ from pyspark.sql.connect.plan import (
     CachedRelation,
     CachedRemoteRelation,
 )
+from pyspark.sql.connect.profiler import ProfilerCollector
 from pyspark.sql.connect.readwriter import DataFrameReader
 from pyspark.sql.connect.streaming.readwriter import DataStreamReader
 from pyspark.sql.connect.streaming.query import StreamingQueryManager
@@ -919,6 +920,15 @@ class SparkSession:
     def session_id(self) -> str:
         return self._session_id
 
+    @property
+    def _profiler_collector(self) -> ProfilerCollector:
+        return self._client._profiler_collector
+
+    def showPerfProfiles(self, id: Optional[int] = None) -> None:
+        self._profiler_collector.show_perf_profiles(id)
+
+    showPerfProfiles.__doc__ = PySparkSession.showPerfProfiles.__doc__
+
 
 SparkSession.__doc__ = PySparkSession.__doc__
 
diff --git a/python/pyspark/sql/profiler.py b/python/pyspark/sql/profiler.py
new file mode 100644
index 000000000000..565752197238
--- /dev/null
+++ b/python/pyspark/sql/profiler.py
@@ -0,0 +1,176 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from abc import ABC, abstractmethod
+import pstats
+from threading import RLock
+from typing import Dict, Optional, TYPE_CHECKING
+
+from pyspark.accumulators import (
+    Accumulator,
+    AccumulatorParam,
+    SpecialAccumulatorIds,
+    _accumulatorRegistry,
+)
+from pyspark.profiler import CodeMapDict, MemoryProfiler, MemUsageParam, 
PStatsParam
+
+if TYPE_CHECKING:
+    from pyspark.sql._typing import ProfileResults
+
+
+class _ProfileResultsParam(AccumulatorParam[Optional["ProfileResults"]]):
+    """
+    AccumulatorParam for profilers.
+    """
+
+    @staticmethod
+    def zero(value: Optional["ProfileResults"]) -> Optional["ProfileResults"]:
+        return value
+
+    @staticmethod
+    def addInPlace(
+        value1: Optional["ProfileResults"], value2: Optional["ProfileResults"]
+    ) -> Optional["ProfileResults"]:
+        if value1 is None or len(value1) == 0:
+            value1 = {}
+        if value2 is None or len(value2) == 0:
+            value2 = {}
+
+        value = value1.copy()
+        for key, (perf, mem, *_) in value2.items():
+            if key in value1:
+                orig_perf, orig_mem, *_ = value1[key]
+            else:
+                orig_perf, orig_mem = (PStatsParam.zero(None), 
MemUsageParam.zero(None))
+            value[key] = (
+                PStatsParam.addInPlace(orig_perf, perf),
+                MemUsageParam.addInPlace(orig_mem, mem),
+            )
+        return value
+
+
+ProfileResultsParam = _ProfileResultsParam()
+
+
+class ProfilerCollector(ABC):
+    """
+    A base class of profiler collectors for session based profilers.
+
+    This supports cProfiler and memory-profiler enabled by setting a SQL config
+    `spark.sql.pyspark.udf.profiler` to "perf" or "memory".
+    """
+
+    def __init__(self) -> None:
+        self._lock = RLock()
+
+    def show_perf_profiles(self, id: Optional[int] = None) -> None:
+        """
+        Show the perf profile results.
+
+        .. versionadded:: 4.0.0
+
+        Parameters
+        ----------
+        id : int, optional
+            A UDF ID to be shown. If not specified, all the results will be 
shown.
+        """
+        with self._lock:
+            stats = self._perf_profile_results
+
+        def show(id: int) -> None:
+            s = stats.get(id)
+            if s is not None:
+                print("=" * 60)
+                print(f"Profile of UDF<id={id}>")
+                print("=" * 60)
+                s.sort_stats("time", "cumulative").print_stats()
+
+        if id is not None:
+            show(id)
+        else:
+            for id in sorted(stats.keys()):
+                show(id)
+
+    @property
+    def _perf_profile_results(self) -> Dict[int, pstats.Stats]:
+        with self._lock:
+            return {
+                result_id: perf
+                for result_id, (perf, _, *_) in self._profile_results.items()
+                if perf is not None
+            }
+
+    def show_memory_profiles(self, id: Optional[int] = None) -> None:
+        """
+        Show the memory profile results.
+
+        .. versionadded:: 4.0.0
+
+        Parameters
+        ----------
+        id : int, optional
+            A UDF ID to be shown. If not specified, all the results will be 
shown.
+        """
+        with self._lock:
+            code_map = self._memory_profile_results
+
+        def show(id: int) -> None:
+            cm = code_map.get(id)
+            if cm is not None:
+                print("=" * 60)
+                print(f"Profile of UDF<id={id}>")
+                print("=" * 60)
+                MemoryProfiler._show_results(cm)
+
+        if id is not None:
+            show(id)
+        else:
+            for id in sorted(code_map.keys()):
+                show(id)
+
+    @property
+    def _memory_profile_results(self) -> Dict[int, CodeMapDict]:
+        with self._lock:
+            return {
+                result_id: mem
+                for result_id, (_, mem, *_) in self._profile_results.items()
+                if mem is not None
+            }
+
+    @property
+    @abstractmethod
+    def _profile_results(self) -> "ProfileResults":
+        """
+        Get the profile results.
+        """
+        ...
+
+
+class AccumulatorProfilerCollector(ProfilerCollector):
+    def __init__(self) -> None:
+        super().__init__()
+        if SpecialAccumulatorIds.SQL_UDF_PROFIER in _accumulatorRegistry:
+            self._accumulator = 
_accumulatorRegistry[SpecialAccumulatorIds.SQL_UDF_PROFIER]
+        else:
+            self._accumulator = Accumulator(
+                SpecialAccumulatorIds.SQL_UDF_PROFIER, None, 
ProfileResultsParam
+            )
+
+    @property
+    def _profile_results(self) -> "ProfileResults":
+        with self._lock:
+            value = self._accumulator.value
+            return value if value is not None else {}
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 10b56d006dcd..fef834b9f0a0 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -47,6 +47,7 @@ from pyspark.sql.conf import RuntimeConfig
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.functions import lit
 from pyspark.sql.pandas.conversion import SparkConversionMixin
+from pyspark.sql.profiler import AccumulatorProfilerCollector, 
ProfilerCollector
 from pyspark.sql.readwriter import DataFrameReader
 from pyspark.sql.sql_formatter import SQLStringFormatter
 from pyspark.sql.streaming import DataStreamReader
@@ -623,6 +624,8 @@ class SparkSession(SparkConversionMixin):
             self._jvm.SparkSession.setDefaultSession(self._jsparkSession)
             self._jvm.SparkSession.setActiveSession(self._jsparkSession)
 
+        self._profiler_collector = AccumulatorProfilerCollector()
+
     def _repr_html_(self) -> str:
         return """
             <div>
@@ -2110,6 +2113,11 @@ class SparkSession(SparkConversionMixin):
             message_parameters={"feature": "SparkSession.clearTags"},
         )
 
+    def showPerfProfiles(self, id: Optional[int] = None) -> None:
+        self._profiler_collector.show_perf_profiles(id)
+
+    showPerfProfiles.__doc__ = ProfilerCollector.show_perf_profiles.__doc__
+
 
 def _test() -> None:
     import os
diff --git a/python/pyspark/sql/tests/connect/test_parity_udf_profiler.py 
b/python/pyspark/sql/tests/connect/test_parity_udf_profiler.py
new file mode 100644
index 000000000000..463d92441094
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_udf_profiler.py
@@ -0,0 +1,59 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import inspect
+import os
+import unittest
+
+from pyspark.sql.tests.test_udf_profiler import UDFProfiler2TestsMixin, 
_do_computation
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class UDFProfilerParityTests(UDFProfiler2TestsMixin, ReusedConnectTestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        self.spark._profiler_collector._value = None
+
+    def test_perf_profiler_udf_multiple_actions(self):
+        def action(df):
+            df.collect()
+            df.show()
+
+        with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
+            _do_computation(self.spark, action=action)
+
+        self.assertEqual(6, len(self.profile_results), 
str(list(self.profile_results)))
+
+        for id in self.profile_results:
+            with self.trap_stdout() as io:
+                self.spark.showPerfProfiles(id)
+
+            self.assertIn(f"Profile of UDF<id={id}>", io.getvalue())
+            self.assertRegex(
+                io.getvalue(), 
f"10.*{os.path.basename(inspect.getfile(_do_computation))}"
+            )
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.connect.test_parity_udf_profiler import *  # noqa: 
F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_udf_profiler.py 
b/python/pyspark/sql/tests/test_udf_profiler.py
index 776d5da88bb2..d365523e456c 100644
--- a/python/pyspark/sql/tests/test_udf_profiler.py
+++ b/python/pyspark/sql/tests/test_udf_profiler.py
@@ -15,18 +15,40 @@
 # limitations under the License.
 #
 
+from contextlib import contextmanager
+import inspect
 import tempfile
 import unittest
 import os
 import sys
 import warnings
 from io import StringIO
-from typing import Iterator
+from typing import Iterator, cast
 
 from pyspark import SparkConf
 from pyspark.sql import SparkSession
-from pyspark.sql.functions import udf, pandas_udf
+from pyspark.sql.functions import col, pandas_udf, udf
 from pyspark.profiler import UDFBasicProfiler
+from pyspark.testing.sqlutils import (
+    ReusedSQLTestCase,
+    have_pandas,
+    have_pyarrow,
+    pandas_requirement_message,
+    pyarrow_requirement_message,
+)
+
+
+def _do_computation(spark, *, action=lambda df: df.collect(), use_arrow=False):
+    @udf("long", useArrow=use_arrow)
+    def add1(x):
+        return x + 1
+
+    @udf("long", useArrow=use_arrow)
+    def add2(x):
+        return x + 2
+
+    df = spark.range(10).select(add1("id"), add2("id"), add1("id"), 
add2(col("id") + 1))
+    action(df)
 
 
 class UDFProfilerTests(unittest.TestCase):
@@ -47,10 +69,10 @@ class UDFProfilerTests(unittest.TestCase):
         sys.path = self._old_sys_path
 
     def test_udf_profiler(self):
-        self.do_computation()
+        _do_computation(self.spark)
 
         profilers = self.sc.profiler_collector.profilers
-        self.assertEqual(3, len(profilers))
+        self.assertEqual(4, len(profilers))
 
         old_stdout = sys.stdout
         try:
@@ -62,7 +84,7 @@ class UDFProfilerTests(unittest.TestCase):
         d = tempfile.gettempdir()
         self.sc.dump_profiles(d)
 
-        for i, udf_name in enumerate(["add1", "add2", "add1"]):
+        for i, udf_name in enumerate(["add1", "add2", "add1", "add2"]):
             id, profiler, _ = profilers[i]
             with self.subTest(id=id, udf_name=udf_name):
                 stats = profiler.stats()
@@ -81,28 +103,16 @@ class UDFProfilerTests(unittest.TestCase):
 
         self.sc.profiler_collector.udf_profiler_cls = TestCustomProfiler
 
-        self.do_computation()
+        _do_computation(self.spark)
 
         profilers = self.sc.profiler_collector.profilers
-        self.assertEqual(3, len(profilers))
+        self.assertEqual(4, len(profilers))
         _, profiler, _ = profilers[0]
         self.assertTrue(isinstance(profiler, TestCustomProfiler))
 
         self.sc.show_profiles()
         self.assertEqual("Custom formatting", profiler.result)
 
-    def do_computation(self):
-        @udf
-        def add1(x):
-            return x + 1
-
-        @udf
-        def add2(x):
-            return x + 2
-
-        df = self.spark.range(10)
-        df.select(add1("id"), add2("id"), add1("id")).collect()
-
     # Unsupported
     def exec_pandas_udf_iter_to_iter(self):
         import pandas as pd
@@ -145,6 +155,190 @@ class UDFProfilerTests(unittest.TestCase):
             )
 
 
+class UDFProfiler2TestsMixin:
+    @contextmanager
+    def trap_stdout(self):
+        old_stdout = sys.stdout
+        sys.stdout = io = StringIO()
+        try:
+            yield io
+        finally:
+            sys.stdout = old_stdout
+
+    @property
+    def profile_results(self):
+        return self.spark._profiler_collector._perf_profile_results
+
+    def test_perf_profiler_udf(self):
+        _do_computation(self.spark)
+
+        # Without the conf enabled, no profile results are collected.
+        self.assertEqual(0, len(self.profile_results), 
str(list(self.profile_results)))
+
+        with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
+            _do_computation(self.spark)
+
+        self.assertEqual(3, len(self.profile_results), 
str(list(self.profile_results)))
+
+        with self.trap_stdout() as io_all:
+            self.spark.showPerfProfiles()
+
+        for id in self.profile_results:
+            self.assertIn(f"Profile of UDF<id={id}>", io_all.getvalue())
+
+            with self.trap_stdout() as io:
+                self.spark.showPerfProfiles(id)
+
+            self.assertIn(f"Profile of UDF<id={id}>", io.getvalue())
+            self.assertRegex(
+                io.getvalue(), 
f"10.*{os.path.basename(inspect.getfile(_do_computation))}"
+            )
+
+    @unittest.skipIf(
+        not have_pandas or not have_pyarrow,
+        cast(str, pandas_requirement_message or pyarrow_requirement_message),
+    )
+    def test_perf_profiler_udf_with_arrow(self):
+        with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
+            _do_computation(self.spark, use_arrow=True)
+
+        self.assertEqual(3, len(self.profile_results), 
str(list(self.profile_results)))
+
+        for id in self.profile_results:
+            with self.trap_stdout() as io:
+                self.spark.showPerfProfiles(id)
+
+            self.assertIn(f"Profile of UDF<id={id}>", io.getvalue())
+            self.assertRegex(
+                io.getvalue(), 
f"10.*{os.path.basename(inspect.getfile(_do_computation))}"
+            )
+
+    def test_perf_profiler_udf_multiple_actions(self):
+        def action(df):
+            df.collect()
+            df.show()
+
+        with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
+            _do_computation(self.spark, action=action)
+
+        self.assertEqual(3, len(self.profile_results), 
str(list(self.profile_results)))
+
+        for id in self.profile_results:
+            with self.trap_stdout() as io:
+                self.spark.showPerfProfiles(id)
+
+            self.assertIn(f"Profile of UDF<id={id}>", io.getvalue())
+            self.assertRegex(
+                io.getvalue(), 
f"20.*{os.path.basename(inspect.getfile(_do_computation))}"
+            )
+
+    def test_perf_profiler_udf_registered(self):
+        @udf("long")
+        def add1(x):
+            return x + 1
+
+        self.spark.udf.register("add1", add1)
+
+        with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
+            self.spark.sql("SELECT id, add1(id) add1 FROM range(10)").collect()
+
+        self.assertEqual(1, len(self.profile_results), 
str(self.profile_results.keys()))
+
+        for id in self.profile_results:
+            with self.trap_stdout() as io:
+                self.spark.showPerfProfiles(id)
+
+            self.assertIn(f"Profile of UDF<id={id}>", io.getvalue())
+            self.assertRegex(
+                io.getvalue(), 
f"10.*{os.path.basename(inspect.getfile(_do_computation))}"
+            )
+
+    @unittest.skipIf(
+        not have_pandas or not have_pyarrow,
+        cast(str, pandas_requirement_message or pyarrow_requirement_message),
+    )
+    def test_perf_profiler_pandas_udf(self):
+        @pandas_udf("long")
+        def add1(x):
+            return x + 1
+
+        @pandas_udf("long")
+        def add2(x):
+            return x + 2
+
+        with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
+            df = self.spark.range(10, numPartitions=2).select(
+                add1("id"), add2("id"), add1("id"), add2(col("id") + 1)
+            )
+            df.collect()
+
+        self.assertEqual(3, len(self.profile_results), 
str(self.profile_results.keys()))
+
+        for id in self.profile_results:
+            with self.trap_stdout() as io:
+                self.spark.showPerfProfiles(id)
+
+            self.assertIn(f"Profile of UDF<id={id}>", io.getvalue())
+            self.assertRegex(
+                io.getvalue(), 
f"2.*{os.path.basename(inspect.getfile(_do_computation))}"
+            )
+
+    @unittest.skipIf(
+        not have_pandas or not have_pyarrow,
+        cast(str, pandas_requirement_message or pyarrow_requirement_message),
+    )
+    def test_perf_profiler_pandas_udf_iterator_not_supported(self):
+        import pandas as pd
+
+        @pandas_udf("long")
+        def add1(x):
+            return x + 1
+
+        @pandas_udf("long")
+        def add2(iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
+            for s in iter:
+                yield s + 2
+
+        with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
+            df = self.spark.range(10, numPartitions=2).select(
+                add1("id"), add2("id"), add1("id"), add2(col("id") + 1)
+            )
+            df.collect()
+
+        self.assertEqual(1, len(self.profile_results), 
str(self.profile_results.keys()))
+
+        for id in self.profile_results:
+            with self.trap_stdout() as io:
+                self.spark.showPerfProfiles(id)
+
+            self.assertIn(f"Profile of UDF<id={id}>", io.getvalue())
+            self.assertRegex(
+                io.getvalue(), 
f"2.*{os.path.basename(inspect.getfile(_do_computation))}"
+            )
+
+    @unittest.skipIf(
+        not have_pandas or not have_pyarrow,
+        cast(str, pandas_requirement_message or pyarrow_requirement_message),
+    )
+    def test_perf_profiler_map_in_pandas_not_supported(self):
+        df = self.spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
+
+        def filter_func(iterator):
+            for pdf in iterator:
+                yield pdf[pdf.id == 1]
+
+        with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
+            df.mapInPandas(filter_func, df.schema).show()
+
+        self.assertEqual(0, len(self.profile_results), 
str(self.profile_results.keys()))
+
+
+class UDFProfiler2Tests(UDFProfiler2TestsMixin, ReusedSQLTestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        self.spark._profiler_collector._accumulator._value = None
+
+
 if __name__ == "__main__":
     from pyspark.sql.tests.test_udf_profiler import *  # noqa: F401
 
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 158e3ae62bb7..3e3592a8ffa2 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -27,7 +27,11 @@ import json
 from typing import Any, Callable, Iterable, Iterator, Optional
 import faulthandler
 
-from pyspark.accumulators import _accumulatorRegistry
+from pyspark.accumulators import (
+    SpecialAccumulatorIds,
+    _accumulatorRegistry,
+    _deserialize_accumulator,
+)
 from pyspark.java_gateway import local_connect_and_auth
 from pyspark.taskcontext import BarrierTaskContext, TaskContext
 from pyspark.resource import ResourceInformation
@@ -688,7 +692,40 @@ def wrap_kwargs_support(f, args_offsets, kwargs_offsets):
         return f, args_offsets
 
 
-def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
+def _supports_profiler(eval_type: int) -> bool:
+    return eval_type not in (
+        PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
+        PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
+        PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
+        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
+    )
+
+
+def wrap_perf_profiler(f, result_id):
+    import cProfile
+    import pstats
+
+    from pyspark.sql.profiler import ProfileResultsParam
+
+    accumulator = _deserialize_accumulator(
+        SpecialAccumulatorIds.SQL_UDF_PROFIER, None, ProfileResultsParam
+    )
+
+    def profiling_func(*args, **kwargs):
+        pr = cProfile.Profile()
+        ret = pr.runcall(f, *args, **kwargs)
+        st = pstats.Stats(pr)
+        st.stream = None  # make it picklable
+        st.strip_dirs()
+
+        accumulator.add({result_id: (st, None)})
+
+        return ret
+
+    return profiling_func
+
+
+def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, 
profiler):
     num_arg = read_int(infile)
 
     if eval_type in (
@@ -721,15 +758,31 @@ def read_single_udf(pickleSer, infile, eval_type, 
runner_conf, udf_index):
         else:
             chained_func = chain(chained_func, f)
 
+    if profiler == "perf":
+        result_id = read_long(infile)
+
+        if _supports_profiler(eval_type):
+            profiling_func = wrap_perf_profiler(chained_func, result_id)
+        else:
+            profiling_func = chained_func
+
+    elif profiler == "memory":
+        # TODO(SPARK-46687): Implement memory profiler
+        result_id = read_long(infile)
+        profiling_func = chained_func
+
+    else:
+        profiling_func = chained_func
+
     if eval_type in (
         PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
         PythonEvalType.SQL_ARROW_BATCHED_UDF,
     ):
-        func = chained_func
+        func = profiling_func
     else:
         # make sure StopIteration's raised in the user code are not ignored
         # when they are processed in a for loop, raise them as RuntimeError's 
instead
-        func = fail_on_stopiteration(chained_func)
+        func = fail_on_stopiteration(profiling_func)
 
     # the last returnType will be the return type of UDF
     if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
@@ -1403,6 +1456,12 @@ def read_udfs(pickleSer, infile, eval_type):
     else:
         ser = BatchedSerializer(CPickleSerializer(), 100)
 
+    is_profiling = read_bool(infile)
+    if is_profiling:
+        profiler = utf8_deserializer.loads(infile)
+    else:
+        profiler = None
+
     num_udfs = read_int(infile)
 
     is_scalar_iter = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
@@ -1417,7 +1476,9 @@ def read_udfs(pickleSer, infile, eval_type):
         if is_map_arrow_iter:
             assert num_udfs == 1, "One MAP_ARROW_ITER UDF expected here."
 
-        arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, 
runner_conf, udf_index=0)
+        arg_offsets, udf = read_single_udf(
+            pickleSer, infile, eval_type, runner_conf, udf_index=0, 
profiler=profiler
+        )
 
         def func(_, iterator):
             num_input_rows = 0
@@ -1507,7 +1568,9 @@ def read_udfs(pickleSer, infile, eval_type):
 
         # See FlatMapGroupsInPandasExec for how arg_offsets are used to
         # distinguish between grouping attributes and data attributes
-        arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, 
runner_conf, udf_index=0)
+        arg_offsets, f = read_single_udf(
+            pickleSer, infile, eval_type, runner_conf, udf_index=0, 
profiler=profiler
+        )
         parsed_offsets = extract_key_value_indexes(arg_offsets)
 
         # Create function like this:
@@ -1526,7 +1589,9 @@ def read_udfs(pickleSer, infile, eval_type):
 
         # See FlatMapGroupsInPandasExec for how arg_offsets are used to
         # distinguish between grouping attributes and data attributes
-        arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, 
runner_conf, udf_index=0)
+        arg_offsets, f = read_single_udf(
+            pickleSer, infile, eval_type, runner_conf, udf_index=0, 
profiler=profiler
+        )
         parsed_offsets = extract_key_value_indexes(arg_offsets)
 
         def batch_from_offset(batch, offsets):
@@ -1550,7 +1615,9 @@ def read_udfs(pickleSer, infile, eval_type):
 
         # See FlatMapGroupsInPandas(WithState)Exec for how arg_offsets are 
used to
         # distinguish between grouping attributes and data attributes
-        arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, 
runner_conf, udf_index=0)
+        arg_offsets, f = read_single_udf(
+            pickleSer, infile, eval_type, runner_conf, udf_index=0, 
profiler=profiler
+        )
         parsed_offsets = extract_key_value_indexes(arg_offsets)
 
         def mapper(a):
@@ -1584,7 +1651,9 @@ def read_udfs(pickleSer, infile, eval_type):
         # We assume there is only one UDF here because cogrouped map doesn't
         # support combining multiple UDFs.
         assert num_udfs == 1
-        arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, 
runner_conf, udf_index=0)
+        arg_offsets, f = read_single_udf(
+            pickleSer, infile, eval_type, runner_conf, udf_index=0, 
profiler=profiler
+        )
 
         parsed_offsets = extract_key_value_indexes(arg_offsets)
 
@@ -1601,7 +1670,9 @@ def read_udfs(pickleSer, infile, eval_type):
         # We assume there is only one UDF here because cogrouped map doesn't
         # support combining multiple UDFs.
         assert num_udfs == 1
-        arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, 
runner_conf, udf_index=0)
+        arg_offsets, f = read_single_udf(
+            pickleSer, infile, eval_type, runner_conf, udf_index=0, 
profiler=profiler
+        )
 
         parsed_offsets = extract_key_value_indexes(arg_offsets)
 
@@ -1624,7 +1695,11 @@ def read_udfs(pickleSer, infile, eval_type):
     else:
         udfs = []
         for i in range(num_udfs):
-            udfs.append(read_single_udf(pickleSer, infile, eval_type, 
runner_conf, udf_index=i))
+            udfs.append(
+                read_single_udf(
+                    pickleSer, infile, eval_type, runner_conf, udf_index=i, 
profiler=profiler
+                )
+            )
 
         def mapper(a):
             result = tuple(f(*[a[o] for o in arg_offsets]) for arg_offsets, f 
in udfs)
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 743a2e20c885..eb5233bfb123 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2927,6 +2927,17 @@ object SQLConf {
       // show full stacktrace in tests but hide in production by default.
       .createWithDefault(Utils.isTesting)
 
+  val PYTHON_UDF_PROFILER =
+    buildConf("spark.sql.pyspark.udf.profiler")
+      .doc("Configure the Python/Pandas UDF profiler by enabling or disabling 
it " +
+        "with the option to choose between \"perf\" and \"memory\" types, " +
+        "or unsetting the config disables the profiler. This is disabled by 
default.")
+      .version("4.0.0")
+      .stringConf
+      .transform(_.toLowerCase(Locale.ROOT))
+      .checkValues(Set("perf", "memory"))
+      .createOptional
+
   val PYTHON_UDF_WORKER_FAULTHANLDER_ENABLED =
     buildConf("spark.sql.execution.pyspark.udf.faulthandler.enabled")
       .doc(
@@ -5296,6 +5307,8 @@ class SQLConf extends Serializable with Logging with 
SqlApiConf {
 
   def pysparkJVMStacktraceEnabled: Boolean = 
getConf(PYSPARK_JVM_STACKTRACE_ENABLED)
 
+  def pythonUDFProfiler: Option[String] = getConf(PYTHON_UDF_PROFILER)
+
   def pythonUDFWorkerFaulthandlerEnabled: Boolean = 
getConf(PYTHON_UDF_WORKER_FAULTHANLDER_ENABLED)
 
   def arrowSparkREnabled: Boolean = getConf(ARROW_SPARKR_EXECUTION_ENABLED)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
index f11a63429d78..0e5f359ee76f 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
@@ -128,7 +128,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: 
PythonFunction) {
     val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
     new MapInBatchEvaluatorFactory(
       toAttributes(outputSchema),
-      Seq(ChainedPythonFunctions(Seq(pythonUDF.func))),
+      Seq((ChainedPythonFunctions(Seq(pythonUDF.func)), 
pythonUDF.resultId.id)),
       inputSchema,
       conf.arrowMaxRecordsPerBatch,
       pythonEvalType,
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
index 7e349b665f35..876373177447 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
@@ -85,15 +85,15 @@ case class AggregateInPandasExec(
   }
 
   private def collectFunctions(
-      udf: PythonFuncExpression): (ChainedPythonFunctions, Seq[Expression]) = {
+      udf: PythonFuncExpression): ((ChainedPythonFunctions, Long), 
Seq[Expression]) = {
     udf.children match {
       case Seq(u: PythonFuncExpression) =>
-        val (chained, children) = collectFunctions(u)
-        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
+        val ((chained, _), children) = collectFunctions(u)
+        ((ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), 
udf.resultId.id), children)
       case children =>
         // There should not be any other UDFs, or the children can't be 
evaluated directly.
         
assert(children.forall(!_.exists(_.isInstanceOf[PythonFuncExpression])))
-        (ChainedPythonFunctions(Seq(udf.func)), udf.children)
+        ((ChainedPythonFunctions(Seq(udf.func)), udf.resultId.id), 
udf.children)
     }
   }
 
@@ -180,7 +180,9 @@ case class AggregateInPandasExec(
         largeVarTypes,
         pythonRunnerConf,
         pythonMetrics,
-        jobArtifactUUID).compute(projectedRowIter, context.partitionId(), 
context)
+        jobArtifactUUID,
+        None) // TODO(SPARK-46688): Support profiling on AggregateInPandasExec
+        .compute(projectedRowIter, context.partitionId(), context)
 
       val joinedAttributes =
         groupingExpressions.map(_.toAttribute) ++ 
aggExpressions.map(_.resultAttribute)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
index 8795374b2a72..8eeb919d0baf 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
@@ -51,7 +51,7 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, 
ColumnarBatch}
  * and output along with data, which requires different struct on Arrow 
RecordBatch.
  */
 class ApplyInPandasWithStatePythonRunner(
-    funcs: Seq[ChainedPythonFunctions],
+    funcs: Seq[(ChainedPythonFunctions, Long)],
     evalType: Int,
     argOffsets: Array[Array[Int]],
     inputSchema: StructType,
@@ -63,13 +63,13 @@ class ApplyInPandasWithStatePythonRunner(
     stateValueSchema: StructType,
     override val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String])
-  extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets, 
jobArtifactUUID)
+  extends BasePythonRunner[InType, OutType](funcs.map(_._1), evalType, 
argOffsets, jobArtifactUUID)
   with PythonArrowInput[InType]
   with PythonArrowOutput[OutType] {
 
   override val pythonExec: String =
     SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
-      funcs.head.funcs.head.pythonExec)
+      funcs.head._1.funcs.head.pythonExec)
 
   override val faultHandlerEnabled: Boolean = 
SQLConf.get.pythonUDFWorkerFaulthandlerEnabled
 
@@ -108,7 +108,7 @@ class ApplyInPandasWithStatePythonRunner(
   private val stateRowDeserializer = stateEncoder.createDeserializer()
 
   override protected def writeUDF(dataOut: DataOutputStream): Unit = {
-    PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
+    PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, None)
   }
 
   /**
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index a6937b7bf89c..da4c5bff3445 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -77,7 +77,8 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], 
resultAttrs: Seq[Attribute]
       conf.arrowUseLargeVarTypes,
       ArrowPythonRunner.getPythonRunnerConfMap(conf),
       pythonMetrics,
-      jobArtifactUUID)
+      jobArtifactUUID,
+      conf.pythonUDFProfiler)
   }
 
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
@@ -94,11 +95,12 @@ class ArrowEvalPythonEvaluatorFactory(
     largeVarTypes: Boolean,
     pythonRunnerConf: Map[String, String],
     pythonMetrics: Map[String, SQLMetric],
-    jobArtifactUUID: Option[String])
+    jobArtifactUUID: Option[String],
+    profiler: Option[String])
   extends EvalPythonEvaluatorFactory(childOutput, udfs, output) {
 
   override def evaluate(
-      funcs: Seq[ChainedPythonFunctions],
+      funcs: Seq[(ChainedPythonFunctions, Long)],
       argMetas: Array[Array[ArgumentMetadata]],
       iter: Iterator[InternalRow],
       schema: StructType,
@@ -118,7 +120,8 @@ class ArrowEvalPythonEvaluatorFactory(
       largeVarTypes,
       pythonRunnerConf,
       pythonMetrics,
-      jobArtifactUUID).compute(batchIter, context.partitionId(), context)
+      jobArtifactUUID,
+      profiler).compute(batchIter, context.partitionId(), context)
 
     columnarBatchIter.flatMap { batch =>
       val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType())
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 33933b64bbaf..a555d660ea1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
 abstract class BaseArrowPythonRunner(
-    funcs: Seq[ChainedPythonFunctions],
+    funcs: Seq[(ChainedPythonFunctions, Long)],
     evalType: Int,
     argOffsets: Array[Array[Int]],
     _schema: StructType,
@@ -38,13 +38,13 @@ abstract class BaseArrowPythonRunner(
     override val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String])
   extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](
-    funcs, evalType, argOffsets, jobArtifactUUID)
+    funcs.map(_._1), evalType, argOffsets, jobArtifactUUID)
   with BasicPythonArrowInput
   with BasicPythonArrowOutput {
 
   override val pythonExec: String =
     SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
-      funcs.head.funcs.head.pythonExec)
+      funcs.head._1.funcs.head.pythonExec)
 
   override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled
 
@@ -67,7 +67,7 @@ abstract class BaseArrowPythonRunner(
  * Similar to `PythonUDFRunner`, but exchange data with Python worker via Arrow stream.
  */
 class ArrowPythonRunner(
-    funcs: Seq[ChainedPythonFunctions],
+    funcs: Seq[(ChainedPythonFunctions, Long)],
     evalType: Int,
     argOffsets: Array[Array[Int]],
     _schema: StructType,
@@ -75,13 +75,14 @@ class ArrowPythonRunner(
     largeVarTypes: Boolean,
     workerConf: Map[String, String],
     pythonMetrics: Map[String, SQLMetric],
-    jobArtifactUUID: Option[String])
+    jobArtifactUUID: Option[String],
+    profiler: Option[String])
   extends BaseArrowPythonRunner(
     funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes, workerConf,
     pythonMetrics, jobArtifactUUID) {
 
   override protected def writeUDF(dataOut: DataOutputStream): Unit =
-    PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
+    PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, profiler)
 }
 
 /**
@@ -89,7 +90,7 @@ class ArrowPythonRunner(
  * via Arrow stream.
  */
 class ArrowPythonWithNamedArgumentRunner(
-    funcs: Seq[ChainedPythonFunctions],
+    funcs: Seq[(ChainedPythonFunctions, Long)],
     evalType: Int,
     argMetas: Array[Array[ArgumentMetadata]],
     _schema: StructType,
@@ -97,13 +98,14 @@ class ArrowPythonWithNamedArgumentRunner(
     largeVarTypes: Boolean,
     workerConf: Map[String, String],
     pythonMetrics: Map[String, SQLMetric],
-    jobArtifactUUID: Option[String])
+    jobArtifactUUID: Option[String],
+    profiler: Option[String])
   extends BaseArrowPythonRunner(
     funcs, evalType, argMetas.map(_.map(_.offset)), _schema, _timeZoneId, largeVarTypes, workerConf,
     pythonMetrics, jobArtifactUUID) {
 
   override protected def writeUDF(dataOut: DataOutputStream): Unit =
-    PythonUDFRunner.writeUDFs(dataOut, funcs, argMetas)
+    PythonUDFRunner.writeUDFs(dataOut, funcs, argMetas, profiler)
 }
 
 object ArrowPythonRunner {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index 04d71c6c0153..e6958392cad4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -44,7 +44,8 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
       udfs,
       output,
       pythonMetrics,
-      jobArtifactUUID)
+      jobArtifactUUID,
+      conf.pythonUDFProfiler)
   }
 
   override protected def withNewChildInternal(newChild: SparkPlan): BatchEvalPythonExec =
@@ -56,11 +57,12 @@ class BatchEvalPythonEvaluatorFactory(
     udfs: Seq[PythonUDF],
     output: Seq[Attribute],
     pythonMetrics: Map[String, SQLMetric],
-    jobArtifactUUID: Option[String])
-    extends EvalPythonEvaluatorFactory(childOutput, udfs, output) {
+    jobArtifactUUID: Option[String],
+    profiler: Option[String])
+  extends EvalPythonEvaluatorFactory(childOutput, udfs, output) {
 
   override def evaluate(
-      funcs: Seq[ChainedPythonFunctions],
+      funcs: Seq[(ChainedPythonFunctions, Long)],
       argMetas: Array[Array[ArgumentMetadata]],
       iter: Iterator[InternalRow],
       schema: StructType,
@@ -73,7 +75,7 @@ class BatchEvalPythonEvaluatorFactory(
     // Output iterator for results from Python.
     val outputIterator =
       new PythonUDFWithNamedArgumentsRunner(
-        funcs, PythonEvalType.SQL_BATCHED_UDF, argMetas, pythonMetrics, jobArtifactUUID)
+        funcs, PythonEvalType.SQL_BATCHED_UDF, argMetas, pythonMetrics, jobArtifactUUID, profiler)
       .compute(inputIterator, context.partitionId(), context)
 
     val unpickle = new Unpickler
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
index e6b19910296e..9eebd4ea7e79 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
@@ -98,7 +98,7 @@ class PythonUDTFRunner(
     pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String])
   extends BasePythonUDFRunner(
-    Seq(ChainedPythonFunctions(Seq(udtf.func))),
+    Seq((ChainedPythonFunctions(Seq(udtf.func)), udtf.resultId.id)),
     PythonEvalType.SQL_TABLE_UDF, Array(argMetas.map(_.offset)), pythonMetrics, jobArtifactUUID) {
 
   override protected def writeUDF(dataOut: DataOutputStream): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
index 7e1c8c2ffc13..5670cad67e7b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
@@ -39,7 +39,7 @@ import org.apache.spark.util.Utils
  * groups them in Python, and receive it back in JVM as batches of single DataFrame.
  */
 class CoGroupedArrowPythonRunner(
-    funcs: Seq[ChainedPythonFunctions],
+    funcs: Seq[(ChainedPythonFunctions, Long)],
     evalType: Int,
     argOffsets: Array[Array[Int]],
     leftSchema: StructType,
@@ -47,15 +47,16 @@ class CoGroupedArrowPythonRunner(
     timeZoneId: String,
     conf: Map[String, String],
     override val pythonMetrics: Map[String, SQLMetric],
-    jobArtifactUUID: Option[String])
+    jobArtifactUUID: Option[String],
+    profiler: Option[String])
   extends BasePythonRunner[
     (Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch](
-    funcs, evalType, argOffsets, jobArtifactUUID)
+    funcs.map(_._1), evalType, argOffsets, jobArtifactUUID)
   with BasicPythonArrowOutput {
 
   override val pythonExec: String =
     SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
-      funcs.head.funcs.head.pythonExec)
+      funcs.head._1.funcs.head.pythonExec)
 
   override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled
 
@@ -79,7 +80,7 @@ class CoGroupedArrowPythonRunner(
           PythonRDD.writeUTF(v, dataOut)
         }
 
-        PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
+        PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, profiler)
       }
 
       override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala
index d5142f58eab4..34f9be0aa633 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala
@@ -36,7 +36,7 @@ abstract class EvalPythonEvaluatorFactory(
   extends PartitionEvaluatorFactory[InternalRow, InternalRow] {
 
   protected def evaluate(
-      funcs: Seq[ChainedPythonFunctions],
+      funcs: Seq[(ChainedPythonFunctions, Long)],
       argMetas: Array[Array[ArgumentMetadata]],
       iter: Iterator[InternalRow],
       schema: StructType,
@@ -47,15 +47,16 @@ abstract class EvalPythonEvaluatorFactory(
 
   private class EvalPythonPartitionEvaluator
       extends PartitionEvaluator[InternalRow, InternalRow] {
-    private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
+    private def collectFunctions(
+        udf: PythonUDF): ((ChainedPythonFunctions, Long), Seq[Expression]) = {
       udf.children match {
         case Seq(u: PythonUDF) =>
-          val (chained, children) = collectFunctions(u)
-          (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
+          val ((chained, _), children) = collectFunctions(u)
+          ((ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), udf.resultId.id), children)
         case children =>
           // There should not be any other UDFs, or the children can't be evaluated directly.
           assert(children.forall(!_.exists(_.isInstanceOf[PythonUDF])))
-          (ChainedPythonFunctions(Seq(udf.func)), udf.children)
+          ((ChainedPythonFunctions(Seq(udf.func)), udf.resultId.id), udf.children)
       }
     }
     override def eval(
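
(The `collectFunctions` change above flattens a chain of nested Python UDFs into one `ChainedPythonFunctions` while keeping only the outermost UDF's result id, which is what the profiler later keys on. A simplified model of that behavior, using hypothetical stand-in types rather than Catalyst expressions:)

```scala
// FakeUDF / FakeChained are hypothetical stand-ins; only the flattening and
// result-id selection of collectFunctions are being modeled here.
object CollectFunctionsSketch {
  case class FakeUDF(name: String, resultId: Long, child: Option[FakeUDF])
  case class FakeChained(names: Seq[String])

  def collectFunctions(udf: FakeUDF): (FakeChained, Long) = udf.child match {
    case Some(inner) =>
      val (chained, _) = collectFunctions(inner) // inner result ids are dropped
      (FakeChained(chained.names :+ udf.name), udf.resultId)
    case None =>
      (FakeChained(Seq(udf.name)), udf.resultId)
  }

  def main(args: Array[String]): Unit = {
    // f(g(col)): the functions come out ordered (g, f), tagged with f's result id.
    val nested = FakeUDF("f", resultId = 2L,
      child = Some(FakeUDF("g", resultId = 1L, child = None)))
    println(collectFunctions(nested)) // (FakeChained(List(g, f)),2)
  }
}
```
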
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala
index 97aa1495670f..bc6f9859ec28 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala
@@ -43,8 +43,10 @@ trait FlatMapCoGroupsInBatchExec extends SparkPlan with BinaryExecNode with Pyth
 
   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
   private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
-  private val pandasFunction = func.asInstanceOf[PythonUDF].func
-  private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
+  private val pythonUDF = func.asInstanceOf[PythonUDF]
+  private val pandasFunction = pythonUDF.func
+  private val chainedFunc =
+    Seq((ChainedPythonFunctions(Seq(pandasFunction)), pythonUDF.resultId.id))
 
   override def producedAttributes: AttributeSet = AttributeSet(output)
 
@@ -84,7 +86,8 @@ trait FlatMapCoGroupsInBatchExec extends SparkPlan with BinaryExecNode with Pyth
           sessionLocalTimeZone,
           pythonRunnerConf,
           pythonMetrics,
-          jobArtifactUUID)
+          jobArtifactUUID,
+          None) // TODO(SPARK-46690): Support profiling on FlatMapCoGroupsInBatchExec
 
         executePython(data, output, runner)
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInBatchExec.scala
index facf7bc49c5a..580ef46e842d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInBatchExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInBatchExec.scala
@@ -43,8 +43,10 @@ trait FlatMapGroupsInBatchExec extends SparkPlan with UnaryExecNode with PythonS
   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
   private val largeVarTypes = conf.arrowUseLargeVarTypes
   private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
-  private val pythonFunction = func.asInstanceOf[PythonUDF].func
-  private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction)))
+  private val pythonUDF = func.asInstanceOf[PythonUDF]
+  private val pythonFunction = pythonUDF.func
+  private val chainedFunc =
+    Seq((ChainedPythonFunctions(Seq(pythonFunction)), pythonUDF.resultId.id))
   private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
 
   override def producedAttributes: AttributeSet = AttributeSet(output)
@@ -89,7 +91,8 @@ trait FlatMapGroupsInBatchExec extends SparkPlan with UnaryExecNode with PythonS
         largeVarTypes,
         pythonRunnerConf,
         pythonMetrics,
-        jobArtifactUUID)
+        jobArtifactUUID,
+        None) // TODO(SPARK-46689): Support profiling on FlatMapGroupsInBatchExec
 
       executePython(data, output, runner)
     }}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
index 105c5ca6493e..850ee016e363 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
@@ -82,8 +82,10 @@ case class FlatMapGroupsInPandasWithStateExec(
   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
   private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
 
-  private val pythonFunction = functionExpr.asInstanceOf[PythonUDF].func
-  private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction)))
+  private val pythonUDF = functionExpr.asInstanceOf[PythonUDF]
+  private val pythonFunction = pythonUDF.func
+  private val chainedFunc =
+    Seq((ChainedPythonFunctions(Seq(pythonFunction)), pythonUDF.resultId.id))
   private lazy val (dedupAttributes, argOffsets) = resolveArgOffsets(
     groupingAttributes ++ child.output, groupingAttributes)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
index 29dc6e0aa541..8d65fe558937 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
 
 class MapInBatchEvaluatorFactory(
     output: Seq[Attribute],
-    chainedFunc: Seq[ChainedPythonFunctions],
+    chainedFunc: Seq[(ChainedPythonFunctions, Long)],
     outputTypes: StructType,
     batchSize: Int,
     pythonEvalType: Int,
@@ -38,7 +38,7 @@ class MapInBatchEvaluatorFactory(
     pythonRunnerConf: Map[String, String],
     val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String])
-    extends PartitionEvaluatorFactory[InternalRow, InternalRow] {
+  extends PartitionEvaluatorFactory[InternalRow, InternalRow] {
 
   override def createEvaluator(): PartitionEvaluator[InternalRow, InternalRow] =
     new MapInBatchEvaluator
@@ -70,7 +70,8 @@ class MapInBatchEvaluatorFactory(
         largeVarTypes,
         pythonRunnerConf,
         pythonMetrics,
-        jobArtifactUUID).compute(batchIter, context.partitionId(), context)
+        jobArtifactUUID,
+        None).compute(batchIter, context.partitionId(), context)
 
       val unsafeProj = UnsafeProjection.create(output, output)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
index 8db389f02667..346a3a2ca354 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
@@ -46,8 +46,9 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
 
   override protected def doExecute(): RDD[InternalRow] = {
     val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
-    val pythonFunction = func.asInstanceOf[PythonUDF].func
-    val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction)))
+    val pythonUDF = func.asInstanceOf[PythonUDF]
+    val pythonFunction = pythonUDF.func
+    val chainedFunc = Seq((ChainedPythonFunctions(Seq(pythonFunction)), pythonUDF.resultId.id))
     val evaluatorFactory = new MapInBatchEvaluatorFactory(
       output,
       chainedFunc,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
index 167a96ed41c7..bbe9fbfc748d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
@@ -30,17 +30,17 @@ import org.apache.spark.sql.internal.SQLConf
  * A helper class to run Python UDFs in Spark.
  */
 abstract class BasePythonUDFRunner(
-    funcs: Seq[ChainedPythonFunctions],
+    funcs: Seq[(ChainedPythonFunctions, Long)],
     evalType: Int,
     argOffsets: Array[Array[Int]],
     pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String])
   extends BasePythonRunner[Array[Byte], Array[Byte]](
-    funcs, evalType, argOffsets, jobArtifactUUID) {
+    funcs.map(_._1), evalType, argOffsets, jobArtifactUUID) {
 
   override val pythonExec: String =
     SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
-      funcs.head.funcs.head.pythonExec)
+      funcs.head._1.funcs.head.pythonExec)
 
   override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
 
@@ -112,29 +112,31 @@ abstract class BasePythonUDFRunner(
 }
 
 class PythonUDFRunner(
-    funcs: Seq[ChainedPythonFunctions],
+    funcs: Seq[(ChainedPythonFunctions, Long)],
     evalType: Int,
     argOffsets: Array[Array[Int]],
     pythonMetrics: Map[String, SQLMetric],
-    jobArtifactUUID: Option[String])
+    jobArtifactUUID: Option[String],
+    profiler: Option[String])
   extends BasePythonUDFRunner(funcs, evalType, argOffsets, pythonMetrics, jobArtifactUUID) {
 
   override protected def writeUDF(dataOut: DataOutputStream): Unit = {
-    PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
+    PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, profiler)
   }
 }
 
 class PythonUDFWithNamedArgumentsRunner(
-    funcs: Seq[ChainedPythonFunctions],
+    funcs: Seq[(ChainedPythonFunctions, Long)],
     evalType: Int,
     argMetas: Array[Array[ArgumentMetadata]],
     pythonMetrics: Map[String, SQLMetric],
-    jobArtifactUUID: Option[String])
+    jobArtifactUUID: Option[String],
+    profiler: Option[String])
   extends BasePythonUDFRunner(
     funcs, evalType, argMetas.map(_.map(_.offset)), pythonMetrics, jobArtifactUUID) {
 
   override protected def writeUDF(dataOut: DataOutputStream): Unit = {
-    PythonUDFRunner.writeUDFs(dataOut, funcs, argMetas)
+    PythonUDFRunner.writeUDFs(dataOut, funcs, argMetas, profiler)
   }
 }
 
@@ -142,10 +144,17 @@ object PythonUDFRunner {
 
   def writeUDFs(
       dataOut: DataOutputStream,
-      funcs: Seq[ChainedPythonFunctions],
-      argOffsets: Array[Array[Int]]): Unit = {
+      funcs: Seq[(ChainedPythonFunctions, Long)],
+      argOffsets: Array[Array[Int]],
+      profiler: Option[String]): Unit = {
+    profiler match {
+      case Some(p) =>
+        dataOut.writeBoolean(true)
+        PythonWorkerUtils.writeUTF(p, dataOut)
+      case _ => dataOut.writeBoolean(false)
+    }
     dataOut.writeInt(funcs.length)
-    funcs.zip(argOffsets).foreach { case (chained, offsets) =>
+    funcs.zip(argOffsets).foreach { case ((chained, resultId), offsets) =>
       dataOut.writeInt(offsets.length)
       offsets.foreach { offset =>
         dataOut.writeInt(offset)
@@ -154,15 +163,25 @@ object PythonUDFRunner {
       chained.funcs.foreach { f =>
         PythonWorkerUtils.writePythonFunction(f, dataOut)
       }
+      if (profiler.isDefined) {
+        dataOut.writeLong(resultId)
+      }
     }
   }
 
   def writeUDFs(
       dataOut: DataOutputStream,
-      funcs: Seq[ChainedPythonFunctions],
-      argMetas: Array[Array[ArgumentMetadata]]): Unit = {
+      funcs: Seq[(ChainedPythonFunctions, Long)],
+      argMetas: Array[Array[ArgumentMetadata]],
+      profiler: Option[String]): Unit = {
+    profiler match {
+      case Some(p) =>
+        dataOut.writeBoolean(true)
+        PythonWorkerUtils.writeUTF(p, dataOut)
+      case _ => dataOut.writeBoolean(false)
+    }
     dataOut.writeInt(funcs.length)
-    funcs.zip(argMetas).foreach { case (chained, metas) =>
+    funcs.zip(argMetas).foreach { case ((chained, resultId), metas) =>
       dataOut.writeInt(metas.length)
       metas.foreach {
         case ArgumentMetadata(offset, name) =>
@@ -179,6 +198,9 @@ object PythonUDFRunner {
       chained.funcs.foreach { f =>
         PythonWorkerUtils.writePythonFunction(f, dataOut)
       }
+      if (profiler.isDefined) {
+        dataOut.writeLong(resultId)
+      }
     }
   }
 }
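
(The two `writeUDFs` overloads above define the extra framing now sent to the Python worker: a boolean profiler flag, the profiler name when enabled, and, per UDF, its result id after the chained function payloads. A rough, self-contained sketch of that framing, with stand-in types and byte-array placeholders for the pickled functions; the length-prefixed encoding assumed for `PythonWorkerUtils.writeUTF`/`writePythonFunction` is an approximation, not the authoritative format:)

```scala
import java.io.{ByteArrayOutputStream, DataOutputStream}

object WriteUDFsSketch {
  // Stand-in for ChainedPythonFunctions: raw byte payloads in place of the
  // pickled commands that PythonWorkerUtils.writePythonFunction would emit.
  case class FakeChained(funcPayloads: Seq[Array[Byte]])

  def writeUDFs(
      dataOut: DataOutputStream,
      funcs: Seq[(FakeChained, Long)],
      argOffsets: Array[Array[Int]],
      profiler: Option[String]): Unit = {
    // 1. Profiler flag, then the profiler name ("perf" or "memory") when set.
    profiler match {
      case Some(p) =>
        dataOut.writeBoolean(true)
        val bytes = p.getBytes("UTF-8")
        dataOut.writeInt(bytes.length) // length-prefixed UTF-8 (assumed encoding)
        dataOut.write(bytes)
      case None =>
        dataOut.writeBoolean(false)
    }
    // 2. Number of (chained) UDFs.
    dataOut.writeInt(funcs.length)
    funcs.zip(argOffsets).foreach { case ((chained, resultId), offsets) =>
      // 3. Argument offsets for this UDF.
      dataOut.writeInt(offsets.length)
      offsets.foreach(o => dataOut.writeInt(o))
      // 4. The chained function payloads (length-prefixed here for the sketch).
      dataOut.writeInt(chained.funcPayloads.length)
      chained.funcPayloads.foreach { payload =>
        dataOut.writeInt(payload.length)
        dataOut.write(payload)
      }
      // 5. The result id, written only when profiling, so the Python worker
      //    can report profile results keyed per UDF.
      if (profiler.isDefined) {
        dataOut.writeLong(resultId)
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val buf = new ByteArrayOutputStream()
    writeUDFs(new DataOutputStream(buf),
      Seq((FakeChained(Seq(Array[Byte](1, 2, 3))), 42L)),
      Array(Array(0)), Some("perf"))
    println(s"wrote ${buf.size()} bytes")
  }
}
```
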
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
index 12d484b12dac..e7fc9c7391af 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
@@ -43,7 +43,8 @@ class WindowInPandasEvaluatorFactory(
     val orderSpec: Seq[SortOrder],
     val childOutput: Seq[Attribute],
     val spillSize: SQLMetric,
-    pythonMetrics: Map[String, SQLMetric])
+    pythonMetrics: Map[String, SQLMetric],
+    profiler: Option[String])
   extends PartitionEvaluatorFactory[InternalRow, InternalRow] with WindowEvaluatorFactoryBase {
 
   /**
@@ -69,15 +70,15 @@ class WindowInPandasEvaluatorFactory(
   private val windowBoundTypeConf = "pandas_window_bound_types"
 
   private def collectFunctions(
-      udf: PythonFuncExpression): (ChainedPythonFunctions, Seq[Expression]) = {
+      udf: PythonFuncExpression): ((ChainedPythonFunctions, Long), Seq[Expression]) = {
     udf.children match {
       case Seq(u: PythonFuncExpression) =>
-        val (chained, children) = collectFunctions(u)
-        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
+        val ((chained, _), children) = collectFunctions(u)
+        ((ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), udf.resultId.id), children)
       case children =>
         // There should not be any other UDFs, or the children can't be evaluated directly.
         assert(children.forall(!_.exists(_.isInstanceOf[PythonFuncExpression])))
-        (ChainedPythonFunctions(Seq(udf.func)), udf.children)
+        ((ChainedPythonFunctions(Seq(udf.func)), udf.resultId.id), udf.children)
     }
   }
 
@@ -368,7 +369,8 @@ class WindowInPandasEvaluatorFactory(
         largeVarTypes,
         pythonRunnerConf,
         pythonMetrics,
-        jobArtifactUUID).compute(pythonInput, context.partitionId(), context)
+        jobArtifactUUID,
+        profiler).compute(pythonInput, context.partitionId(), context)
 
       val joined = new JoinedRow
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
index ee0044162b9a..c0a38eadbe64 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
@@ -87,7 +87,8 @@ case class WindowInPandasExec(
         orderSpec,
         child.output,
         longMetric("spillSize"),
-        pythonMetrics)
+        pythonMetrics,
+        None) // TODO(SPARK-46691): Support profiling on WindowInPandasExec
 
     // Start processing.
     if (conf.usePartitionEvaluator) {

