This is an automated email from the ASF dual-hosted git repository.

weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 06bf544973f [SPARK-42896][SQL][PYTHON] Make `mapInPandas` / `mapInArrow` support barrier mode execution
06bf544973f is described below

commit 06bf544973f4e221c569487473fbe3268543ebb7
Author: Weichen Xu <weichen...@databricks.com>
AuthorDate: Mon Mar 27 09:39:48 2023 +0800

    [SPARK-42896][SQL][PYTHON] Make `mapInPandas` / `mapInArrow` support barrier mode execution
    
    ### What changes were proposed in this pull request?
    
    Make mapInPandas / mapInArrow support barrier mode execution
    
    ### Why are the changes needed?
    
    This is a preparation PR for supporting mapInPandas / mapInArrow barrier
    execution in Spark Connect mode. The feature is required by machine
    learning use cases.
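
    For illustration, a minimal sketch of how the new flag is meant to be used
    from Python (not part of this patch; `filter_func` is a hypothetical user
    function mirroring the docstring examples below):

        from pyspark.sql import SparkSession

        spark = SparkSession.builder.getOrCreate()
        df = spark.createDataFrame([(1, 21)], ("id", "age"))

        def filter_func(iterator):
            # Each element is a pandas.DataFrame batch of rows.
            for pdf in iterator:
                yield pdf[pdf.id == 1]

        # isBarrier=True runs the mapInPandas stage in barrier mode, launching
        # all Python UDF workers of the stage concurrently.
        df.mapInPandas(filter_func, df.schema, isBarrier=True).show()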
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Closes #40520 from WeichenXu123/barrier-udf.
    
    Authored-by: Weichen Xu <weichen...@databricks.com>
    Signed-off-by: Weichen Xu <weichen...@databricks.com>
---
 .../sql/connect/planner/SparkConnectPlanner.scala  |  6 +++--
 python/pyspark/sql/pandas/map_ops.py               | 26 ++++++++++++++++++----
 .../catalyst/analysis/DeduplicateRelations.scala   |  4 ++--
 .../plans/logical/pythonLogicalOperators.scala     |  6 +++--
 .../sql/catalyst/analysis/AnalysisSuite.scala      |  3 ++-
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 10 +++++----
 .../spark/sql/execution/SparkStrategies.scala      |  8 +++----
 .../sql/execution/python/MapInBatchExec.scala      | 10 ++++++++-
 .../sql/execution/python/MapInPandasExec.scala     |  3 ++-
 .../execution/python/PythonMapInArrowExec.scala    |  3 ++-
 10 files changed, 57 insertions(+), 22 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 13052ec9b01..e7911ccdf11 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -489,12 +489,14 @@ class SparkConnectPlanner(val session: SparkSession) {
         logical.MapInPandas(
           pythonUdf,
           pythonUdf.dataType.asInstanceOf[StructType].toAttributes,
-          transformRelation(rel.getInput))
+          transformRelation(rel.getInput),
+          false)
       case PythonEvalType.SQL_MAP_ARROW_ITER_UDF =>
         logical.PythonMapInArrow(
           pythonUdf,
           pythonUdf.dataType.asInstanceOf[StructType].toAttributes,
-          transformRelation(rel.getInput))
+          transformRelation(rel.getInput),
+          false)
       case _ =>
        throw InvalidPlanInput(s"Function with EvalType: ${pythonUdf.evalType} is not supported")
     }
diff --git a/python/pyspark/sql/pandas/map_ops.py b/python/pyspark/sql/pandas/map_ops.py
index 47b17578ae1..a4c0c94844b 100644
--- a/python/pyspark/sql/pandas/map_ops.py
+++ b/python/pyspark/sql/pandas/map_ops.py
@@ -32,7 +32,7 @@ class PandasMapOpsMixin:
     """
 
     def mapInPandas(
-        self, func: "PandasMapIterFunction", schema: Union[StructType, str]
+        self, func: "PandasMapIterFunction", schema: Union[StructType, str], isBarrier: bool = False
     ) -> "DataFrame":
         """
        Maps an iterator of batches in the current :class:`DataFrame` using a Python native
@@ -60,6 +60,7 @@ class PandasMapOpsMixin:
         schema : :class:`pyspark.sql.types.DataType` or str
             the return type of the `func` in PySpark. The value can be either a
            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
+        isBarrier : Use barrier mode execution if True.
 
         Examples
         --------
@@ -74,6 +75,14 @@ class PandasMapOpsMixin:
         +---+---+
         |  1| 21|
         +---+---+
+        >>> # Set isBarrier=True to force the "mapInPandas" stage to run in barrier mode;
+        >>> # this ensures all Python UDF workers in the stage are launched concurrently.
+        >>> df.mapInPandas(filter_func, df.schema, isBarrier=True).show()  # doctest: +SKIP
+        +---+---+
+        | id|age|
+        +---+---+
+        |  1| 21|
+        +---+---+
 
         Notes
         -----
@@ -93,11 +102,11 @@ class PandasMapOpsMixin:
            func, returnType=schema, functionType=PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
         )  # type: ignore[call-overload]
         udf_column = udf(*[self[col] for col in self.columns])
-        jdf = self._jdf.mapInPandas(udf_column._jc.expr())
+        jdf = self._jdf.mapInPandas(udf_column._jc.expr(), isBarrier)
         return DataFrame(jdf, self.sparkSession)
 
     def mapInArrow(
-        self, func: "ArrowMapIterFunction", schema: Union[StructType, str]
+        self, func: "ArrowMapIterFunction", schema: Union[StructType, str], isBarrier: bool = False
     ) -> "DataFrame":
         """
        Maps an iterator of batches in the current :class:`DataFrame` using a Python native
@@ -122,6 +131,7 @@ class PandasMapOpsMixin:
         schema : :class:`pyspark.sql.types.DataType` or str
             the return type of the `func` in PySpark. The value can be either a
            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
+        isBarrier : Use barrier mode execution if True.
 
         Examples
         --------
@@ -137,6 +147,14 @@ class PandasMapOpsMixin:
         +---+---+
         |  1| 21|
         +---+---+
+        >>> # Set isBarrier=True to force the "mapInArrow" stage to run in barrier mode;
+        >>> # this ensures all Python UDF workers in the stage are launched concurrently.
+        >>> df.mapInArrow(filter_func, df.schema, isBarrier=True).show()  # doctest: +SKIP
+        +---+---+
+        | id|age|
+        +---+---+
+        |  1| 21|
+        +---+---+
 
         Notes
         -----
@@ -157,7 +175,7 @@ class PandasMapOpsMixin:
            func, returnType=schema, functionType=PythonEvalType.SQL_MAP_ARROW_ITER_UDF
         )  # type: ignore[call-overload]
         udf_column = udf(*[self[col] for col in self.columns])
-        jdf = self._jdf.pythonMapInArrow(udf_column._jc.expr())
+        jdf = self._jdf.pythonMapInArrow(udf_column._jc.expr(), isBarrier)
         return DataFrame(jdf, self.sparkSession)
 
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
index 909ec908020..1d4c7d3f9f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
@@ -233,13 +233,13 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
         newVersion.copyTagsFrom(oldVersion)
         Seq((oldVersion, newVersion))
 
-      case oldVersion @ MapInPandas(_, output, _)
+      case oldVersion @ MapInPandas(_, output, _, _)
         if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
         val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
         newVersion.copyTagsFrom(oldVersion)
         Seq((oldVersion, newVersion))
 
-      case oldVersion @ PythonMapInArrow(_, output, _)
+      case oldVersion @ PythonMapInArrow(_, output, _, _)
         if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
         val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
         newVersion.copyTagsFrom(oldVersion)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
index 1ce6808be60..fe5eee481be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
@@ -53,7 +53,8 @@ case class FlatMapGroupsInPandas(
 case class MapInPandas(
     functionExpr: Expression,
     output: Seq[Attribute],
-    child: LogicalPlan) extends UnaryNode {
+    child: LogicalPlan,
+    isBarrier: Boolean) extends UnaryNode {
 
   override val producedAttributes = AttributeSet(output)
 
@@ -68,7 +69,8 @@ case class MapInPandas(
 case class PythonMapInArrow(
     functionExpr: Expression,
     output: Seq[Attribute],
-    child: LogicalPlan) extends UnaryNode {
+    child: LogicalPlan,
+    isBarrier: Boolean) extends UnaryNode {
 
   override val producedAttributes = AttributeSet(output)
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 54ea4086c9b..68ac6fe378a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -691,7 +691,8 @@ class AnalysisSuite extends AnalysisTest with Matchers {
     val mapInPandas = MapInPandas(
       pythonUdf,
       output,
-      project)
+      project,
+      false)
     val left = SubqueryAlias("temp0", mapInPandas)
     val right = SubqueryAlias("temp1", mapInPandas)
     val join = Join(left, right, Inner, None, JoinHint.NONE)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 57da3b5af60..7981e3badef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3283,13 +3283,14 @@ class Dataset[T] private[sql](
   * This function uses Apache Arrow as serialization format between Java executors and Python
    * workers.
    */
-  private[sql] def mapInPandas(func: PythonUDF): DataFrame = {
+  private[sql] def mapInPandas(func: PythonUDF, isBarrier: Boolean = false): DataFrame = {
     Dataset.ofRows(
       sparkSession,
       MapInPandas(
         func,
         func.dataType.asInstanceOf[StructType].toAttributes,
-        logicalPlan))
+        logicalPlan,
+        isBarrier))
   }
 
   /**
@@ -3297,13 +3298,14 @@ class Dataset[T] private[sql](
   * defines a transformation: `iter(pyarrow.RecordBatch)` -> `iter(pyarrow.RecordBatch)`.
   * Each partition is each iterator consisting of `pyarrow.RecordBatch`s as batches.
    */
-  private[sql] def pythonMapInArrow(func: PythonUDF): DataFrame = {
+  private[sql] def pythonMapInArrow(func: PythonUDF, isBarrier: Boolean = false): DataFrame = {
     Dataset.ofRows(
       sparkSession,
       PythonMapInArrow(
         func,
         func.dataType.asInstanceOf[StructType].toAttributes,
-        logicalPlan))
+        logicalPlan,
+        isBarrier))
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index ddf1213cfed..972376220f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -806,10 +806,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.python.FlatMapCoGroupsInPandasExec(
           f.leftAttributes, f.rightAttributes,
           func, output, planLater(left), planLater(right)) :: Nil
-      case logical.MapInPandas(func, output, child) =>
-        execution.python.MapInPandasExec(func, output, planLater(child)) :: Nil
-      case logical.PythonMapInArrow(func, output, child) =>
-        execution.python.PythonMapInArrowExec(func, output, planLater(child)) :: Nil
+      case logical.MapInPandas(func, output, child, isBarrier) =>
+        execution.python.MapInPandasExec(func, output, planLater(child), isBarrier) :: Nil
+      case logical.PythonMapInArrow(func, output, child, isBarrier) =>
+        execution.python.PythonMapInArrowExec(func, output, planLater(child), isBarrier) :: Nil
       case logical.AttachDistributedSequence(attr, child) =>
        execution.python.AttachDistributedSequenceExec(attr, planLater(child)) :: Nil
       case logical.MapElements(f, _, _, objAttr, child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
index 6021233f685..0fe3acb14e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
@@ -41,6 +41,8 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
   protected val func: Expression
   protected val pythonEvalType: Int
 
+  protected val isBarrier: Boolean
+
   private val pythonFunction = func.asInstanceOf[PythonUDF].func
 
   override def producedAttributes: AttributeSet = AttributeSet(output)
@@ -50,7 +52,7 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
   override def outputPartitioning: Partitioning = child.outputPartitioning
 
   override protected def doExecute(): RDD[InternalRow] = {
-    child.execute().mapPartitionsInternal { inputIter =>
+    def mapper(inputIter: Iterator[InternalRow]): Iterator[InternalRow] = {
       // Single function with one struct.
       val argOffsets = Array(Array(0))
       val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction)))
@@ -90,5 +92,11 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
         flattenedBatch.rowIterator.asScala
       }.map(unsafeProj)
     }
+
+    if (isBarrier) {
+      child.execute().barrier().mapPartitions(mapper)
+    } else {
+      child.execute().mapPartitionsInternal(mapper)
+    }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala
index 7a711b5da16..cfd97b6f497 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala
@@ -28,7 +28,8 @@ import org.apache.spark.sql.execution.SparkPlan
 case class MapInPandasExec(
     func: Expression,
     output: Seq[Attribute],
-    child: SparkPlan)
+    child: SparkPlan,
+    override val isBarrier: Boolean)
   extends MapInBatchExec {
 
  override protected val pythonEvalType: Int = PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonMapInArrowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonMapInArrowExec.scala
index e3c185301a1..e5a457035c6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonMapInArrowExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonMapInArrowExec.scala
@@ -28,7 +28,8 @@ import org.apache.spark.sql.execution.SparkPlan
 case class PythonMapInArrowExec(
     func: Expression,
     output: Seq[Attribute],
-    child: SparkPlan)
+    child: SparkPlan,
+    override val isBarrier: Boolean)
   extends MapInBatchExec {
 
  override protected val pythonEvalType: Int = PythonEvalType.SQL_MAP_ARROW_ITER_UDF
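
For readers unfamiliar with barrier execution: the isBarrier branch added to
MapInBatchExec.doExecute above dispatches to the long-standing RDD barrier API.
A rough PySpark analogue of that behavior (an illustrative sketch, not code
from this commit):

    from pyspark.sql import SparkSession
    from pyspark.taskcontext import BarrierTaskContext

    spark = SparkSession.builder.master("local[4]").getOrCreate()
    rdd = spark.sparkContext.parallelize(range(8), 4)

    def mapper(iterator):
        # In a barrier stage all tasks start together and may coordinate,
        # which is what distributed ML training setups rely on.
        ctx = BarrierTaskContext.get()
        ctx.barrier()  # wait until every task in the stage reaches this point
        yield sum(iterator)

    # rdd.barrier().mapPartitions(...) is the RDD-level counterpart of
    # child.execute().barrier().mapPartitions(mapper) in MapInBatchExec.
    print(rdd.barrier().mapPartitions(mapper).collect())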

