This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 0cbb7f229 [GLUTEN-5696] Add preprojection support for ArrowEvalPythonExec (#5697)
0cbb7f229 is described below

commit 0cbb7f2297a940047dfb788caac17cb2ad540356
Author: Yan Ma <yan...@intel.com>
AuthorDate: Thu May 16 15:09:44 2024 +0800

    [GLUTEN-5696] Add preprojection support for ArrowEvalPythonExec (#5697)
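
    The rewrite pulls non-attribute UDF arguments (for example `b * 2`) into a
    ProjectExec placed below ColumnarArrowEvalPythonExec, so the exec itself only
    receives plain column references. A minimal sketch of the targeted pattern,
    reusing the pyarrowTestUDF from the test suite in this patch (the pre-project
    attribute name below is illustrative, not the exact generated alias):

        val df = base.withColumn("p_b", pyarrowTestUDF(base("b") * 2))
        // roughly rewritten to:
        //   ArrowEvalPythonExec(udfs = [pyarrowTestUDF(_pre_0)])
        //     +- ProjectExec([a, b, (b * 2) AS _pre_0])
        //        +- <scan of base>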
---
 .../backendsapi/velox/VeloxSparkPlanExecApi.scala  |  8 +-
 .../python/ColumnarArrowEvalPythonExec.scala       | 97 ++++++++++++++++++++--
 .../python/ArrowEvalPythonExecSuite.scala          | 43 +++++++++-
 .../gluten/backendsapi/SparkPlanExecApi.scala      |  4 +
 .../columnar/rewrite/PullOutPreProject.scala       |  5 ++
 .../rewrite/RewriteSparkPlanRulesManager.scala     |  2 +
 6 files changed, 149 insertions(+), 10 deletions(-)

diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index 4d41ed0c0..33ce1ee72 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -31,7 +31,7 @@ import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode
 import org.apache.gluten.vectorized.{ColumnarBatchSerializer, ColumnarBatchSerializeResult}
 
 import org.apache.spark.{ShuffleDependency, SparkException}
-import org.apache.spark.api.python.ColumnarArrowEvalPythonExec
+import org.apache.spark.api.python.{ColumnarArrowEvalPythonExec, PullOutArrowEvalPythonPreProjectHelper}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{GenShuffleWriterParameters, GlutenShuffleWriterWrapper}
@@ -53,6 +53,7 @@ import org.apache.spark.sql.execution.datasources.FileFormat
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BuildSideRelation, HashedRelationBroadcastMode}
 import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
 import org.apache.spark.sql.execution.utils.ExecUtil
 import org.apache.spark.sql.expression.{UDFExpression, UDFResolver, UserDefinedAggregateFunction}
 import org.apache.spark.sql.internal.SQLConf
@@ -846,6 +847,11 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
     PullOutGenerateProjectHelper.pullOutPostProject(generate)
   }
 
+  override def genPreProjectForArrowEvalPythonExec(
+      arrowEvalPythonExec: ArrowEvalPythonExec): SparkPlan = {
+    PullOutArrowEvalPythonPreProjectHelper.pullOutPreProject(arrowEvalPythonExec)
+  }
+
   override def maybeCollapseTakeOrderedAndProject(plan: SparkPlan): SparkPlan = {
     // This to-top-n optimization assumes exchange operators were already placed in input plan.
     plan.transformUp {
diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/python/ColumnarArrowEvalPythonExec.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/python/ColumnarArrowEvalPythonExec.scala
index 77ef1c642..d3112c974 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/execution/python/ColumnarArrowEvalPythonExec.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/execution/python/ColumnarArrowEvalPythonExec.scala
@@ -17,17 +17,18 @@
 package org.apache.spark.api.python
 
 import org.apache.gluten.columnarbatch.ColumnarBatches
+import org.apache.gluten.exception.GlutenException
 import org.apache.gluten.extension.GlutenPlan
 import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
-import org.apache.gluten.utils.Iterators
+import org.apache.gluten.utils.{Iterators, PullOutProjectHelper}
 import org.apache.gluten.vectorized.ArrowWritableColumnVector
 
 import org.apache.spark.{ContextAwareIterator, SparkEnv, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.python.{BasePythonRunnerShim, EvalPythonExec, PythonUDFRunner}
+import org.apache.spark.sql.execution.{ProjectExec, SparkPlan}
+import org.apache.spark.sql.execution.python.{ArrowEvalPythonExec, BasePythonRunnerShim, EvalPythonExec, PythonUDFRunner}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StructField, StructType}
 import org.apache.spark.sql.utils.{SparkArrowUtil, SparkSchemaUtil, SparkVectorUtil}
@@ -41,6 +42,7 @@ import java.io.{DataInputStream, DataOutputStream}
 import java.net.Socket
 import java.util.concurrent.atomic.AtomicBoolean
 
+import scala.collection.{mutable, Seq}
 import scala.collection.mutable.ArrayBuffer
 
 class ColumnarArrowPythonRunner(
@@ -207,7 +209,6 @@ case class ColumnarArrowEvalPythonExec(
   extends EvalPythonExec
   with GlutenPlan {
   override def supportsColumnar: Boolean = true
-  // TODO: add additional projection support by pre-project
   // FIXME: incorrect metrics updater
 
   override protected def evaluate(
@@ -221,6 +222,7 @@ case class ColumnarArrowEvalPythonExec(
   }
 
   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
+
   private def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = {
     val timeZoneConf = Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone)
     val pandasColsByName = Seq(
@@ -231,6 +233,7 @@ case class ColumnarArrowEvalPythonExec(
         conf.arrowSafeTypeConversion.toString)
     Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck: _*)
   }
+
   private val pythonRunnerConf = getPythonRunnerConfMap(conf)
 
   protected def evaluateColumnar(
@@ -279,16 +282,29 @@ case class ColumnarArrowEvalPythonExec(
       iter =>
         val context = TaskContext.get()
         val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip
-        // flatten all the arguments
+        // Only the columns referred to by the UDFs are written to the Python worker,
+        // so we need to record their corresponding offsets.
         val allInputs = new ArrayBuffer[Expression]
         val dataTypes = new ArrayBuffer[DataType]
+        val originalOffsets = new ArrayBuffer[Int]
         val argOffsets = inputs.map {
           input =>
             input.map {
               e =>
-                if (allInputs.exists(_.semanticEquals(e))) {
+                if (!e.isInstanceOf[AttributeReference]) {
+                  throw new GlutenException(
+                    "ColumnarArrowEvalPythonExec should only have [AttributeReference] inputs.")
+                } else if (allInputs.exists(_.semanticEquals(e))) {
                   allInputs.indexWhere(_.semanticEquals(e))
                 } else {
+                  var offset: Int = -1
+                  offset = child.output.indexWhere(
+                    _.exprId.equals(e.asInstanceOf[AttributeReference].exprId))
+                  if (offset == -1) {
+                    throw new GlutenException(
+                      "ColumnarArrowEvalPythonExec can't find the referred input column.")
+                  }
+                  originalOffsets += offset
                   allInputs += e
                   dataTypes += e.dataType
                   allInputs.length - 1
@@ -299,15 +315,21 @@ case class ColumnarArrowEvalPythonExec(
           case (dt, i) =>
             StructField(s"_$i", dt)
         }.toSeq)
+
         val contextAwareIterator = new ContextAwareIterator(context, iter)
         val inputCbCache = new ArrayBuffer[ColumnarBatch]()
         val inputBatchIter = contextAwareIterator.map {
           inputCb =>
             ColumnarBatches.ensureLoaded(ArrowBufferAllocators.contextInstance, inputCb)
-            // 0. cache input for later merge
             ColumnarBatches.retain(inputCb)
+            // 0. cache input for later merge
             inputCbCache += inputCb
-            inputCb
+            // We only need to pass the referred columns' data to the Python worker for evaluation.
+            var colsForEval = new ArrayBuffer[ColumnVector]()
+            for (i <- originalOffsets) {
+              colsForEval += inputCb.column(i)
+            }
+            new ColumnarBatch(colsForEval.toArray, inputCb.numRows())
         }
 
         val outputColumnarBatchIterator =
@@ -335,6 +357,65 @@ case class ColumnarArrowEvalPythonExec(
           .create()
     }
   }
+
   override protected def withNewChildInternal(newChild: SparkPlan): ColumnarArrowEvalPythonExec =
     copy(udfs, resultAttrs, newChild)
 }
+
+object PullOutArrowEvalPythonPreProjectHelper extends PullOutProjectHelper {
+  private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
+    udf.children match {
+      case Seq(u: PythonUDF) =>
+        val (chained, children) = collectFunctions(u)
+        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
+      case children =>
+        (ChainedPythonFunctions(Seq(udf.func)), udf.children)
+    }
+  }
+
+  private def rewriteUDF(
+      udf: PythonUDF,
+      expressionMap: mutable.HashMap[Expression, NamedExpression]): PythonUDF = {
+    udf.children match {
+      case Seq(u: PythonUDF) =>
+        udf
+          .withNewChildren(udf.children.toIndexedSeq.map {
+            func => rewriteUDF(func.asInstanceOf[PythonUDF], expressionMap)
+          })
+          .asInstanceOf[PythonUDF]
+      case children =>
+        val newUDFChildren = udf.children.map {
+          case literal: Literal => literal
+          case other => replaceExpressionWithAttribute(other, expressionMap)
+        }
+        udf.withNewChildren(newUDFChildren).asInstanceOf[PythonUDF]
+    }
+  }
+
+  def pullOutPreProject(arrowEvalPythonExec: ArrowEvalPythonExec): SparkPlan = {
+    // pull out preproject
+    val (_, inputs) = arrowEvalPythonExec.udfs.map(collectFunctions).unzip
+    val expressionMap = new mutable.HashMap[Expression, NamedExpression]()
+    // flatten all the arguments
+    val allInputs = new ArrayBuffer[Expression]
+    for (input <- inputs) {
+      input.map {
+        e =>
+          if (!allInputs.exists(_.semanticEquals(e))) {
+            allInputs += e
+            replaceExpressionWithAttribute(e, expressionMap)
+          }
+      }
+    }
+    if (!expressionMap.isEmpty) {
+      // Need preproject.
+      val preProject = ProjectExec(
+        eliminateProjectList(arrowEvalPythonExec.child.outputSet, expressionMap.values.toSeq),
+        arrowEvalPythonExec.child)
+      val newUDFs = arrowEvalPythonExec.udfs.map(f => rewriteUDF(f, expressionMap))
+      arrowEvalPythonExec.copy(udfs = newUDFs, child = preProject)
+    } else {
+      arrowEvalPythonExec
+    }
+  }
+}
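
A rough usage sketch (not part of the patch) of how a backend is expected to route
ArrowEvalPythonExec through the new helper, mirroring the Velox override above; the
method and variable names in the sketch are assumptions:

    import org.apache.spark.api.python.PullOutArrowEvalPythonPreProjectHelper
    import org.apache.spark.sql.execution.SparkPlan
    import org.apache.spark.sql.execution.python.ArrowEvalPythonExec

    // Pull derived UDF arguments into a ProjectExec child and rewrite the UDFs
    // to reference the projected attributes; if every argument is already an
    // AttributeReference, the original node is returned unchanged.
    def preProjectArrowEvalPython(plan: SparkPlan): SparkPlan = plan match {
      case evalPython: ArrowEvalPythonExec =>
        PullOutArrowEvalPythonPreProjectHelper.pullOutPreProject(evalPython)
      case other => other
    }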
diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/python/ArrowEvalPythonExecSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/python/ArrowEvalPythonExecSuite.scala
index 2193448b4..1c3e33262 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/python/ArrowEvalPythonExecSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/python/ArrowEvalPythonExecSuite.scala
@@ -39,7 +39,7 @@ class ArrowEvalPythonExecSuite extends WholeStageTransformerSuite {
       .set("spark.executor.cores", "1")
   }
 
-  test("arrow_udf test") {
+  test("arrow_udf test: without projection") {
     lazy val base =
       Seq(("1", 1), ("1", 2), ("2", 1), ("2", 2), ("3", 1), ("3", 2), ("0", 
1), ("3", 0))
         .toDF("a", "b")
@@ -58,4 +58,45 @@ class ArrowEvalPythonExecSuite extends WholeStageTransformerSuite {
     checkSparkOperatorMatch[ColumnarArrowEvalPythonExec](df2)
     checkAnswer(df2, expected)
   }
+
+  test("arrow_udf test: with unrelated projection") {
+    lazy val base =
+      Seq(("1", 1), ("1", 2), ("2", 1), ("2", 2), ("3", 1), ("3", 2), ("0", 1), ("3", 0))
+        .toDF("a", "b")
+    lazy val expected = Seq(
+      ("1", 1, "1", 2),
+      ("1", 2, "1", 4),
+      ("2", 1, "2", 2),
+      ("2", 2, "2", 4),
+      ("3", 1, "3", 2),
+      ("3", 2, "3", 4),
+      ("0", 1, "0", 2),
+      ("3", 0, "3", 0)
+    ).toDF("a", "b", "p_a", "d_b")
+
+    val df = base.withColumn("p_a", pyarrowTestUDF(base("a"))).withColumn("d_b", base("b") * 2)
+    checkSparkOperatorMatch[ColumnarArrowEvalPythonExec](df)
+    checkAnswer(df, expected)
+  }
+
+  test("arrow_udf test: with preprojection") {
+    lazy val base =
+      Seq(("1", 1), ("1", 2), ("2", 1), ("2", 2), ("3", 1), ("3", 2), ("0", 1), ("3", 0))
+        .toDF("a", "b")
+    lazy val expected = Seq(
+      ("1", 1, 2, "1", 2),
+      ("1", 2, 4, "1", 4),
+      ("2", 1, 2, "2", 2),
+      ("2", 2, 4, "2", 4),
+      ("3", 1, 2, "3", 2),
+      ("3", 2, 4, "3", 4),
+      ("0", 1, 2, "0", 2),
+      ("3", 0, 0, "3", 0)
+    ).toDF("a", "b", "d_b", "p_a", "p_b")
+    val df = base
+      .withColumn("d_b", base("b") * 2)
+      .withColumn("p_a", pyarrowTestUDF(base("a")))
+      .withColumn("p_b", pyarrowTestUDF(base("b") * 2))
+    checkAnswer(df, expected)
+  }
 }
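
The "with preprojection" test above only checks results. A hypothetical extra
assertion, mirroring the operator checks used in the first two tests (not part of
this patch):

    // Hypothetical: also assert that ColumnarArrowEvalPythonExec (now fed by a
    // pulled-out ProjectExec) is the operator chosen for the pre-projected query.
    checkSparkOperatorMatch[ColumnarArrowEvalPythonExec](df)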
diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index c2c733070..8df74bb88 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -43,6 +43,7 @@ import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, FileScan}
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.joins.BuildSideRelation
 import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
 import org.apache.spark.sql.hive.HiveTableScanExecTransformer
 import org.apache.spark.sql.types.{LongType, NullType, StructType}
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -745,6 +746,9 @@ trait SparkPlanExecApi {
 
   def genPostProjectForGenerate(generate: GenerateExec): SparkPlan
 
+  def genPreProjectForArrowEvalPythonExec(arrowEvalPythonExec: ArrowEvalPythonExec): SparkPlan =
+    arrowEvalPythonExec
+
   def maybeCollapseTakeOrderedAndProject(plan: SparkPlan): SparkPlan = plan
 
   def outputNativeColumnarSparkCompatibleData(plan: SparkPlan): Boolean = false
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala
index 64d4f2736..50dc55423 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Partial}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, TypedAggregateExpression}
+import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
 import org.apache.spark.sql.execution.window.{WindowExec, WindowGroupLimitExecShim}
 
 import scala.collection.mutable
@@ -226,6 +227,10 @@ object PullOutPreProject extends RewriteSingleNode with PullOutProjectHelper {
     case generate: GenerateExec =>
       BackendsApiManager.getSparkPlanExecApiInstance.genPreProjectForGenerate(generate)
 
+    case arrowEvalPythonExec: ArrowEvalPythonExec =>
+      BackendsApiManager.getSparkPlanExecApiInstance.genPreProjectForArrowEvalPythonExec(
+        arrowEvalPythonExec)
+
     case _ => plan
   }
 }
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala
index 5fd728eca..ac663314b 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
 import org.apache.spark.sql.execution.joins.BaseJoinExec
+import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
 import org.apache.spark.sql.execution.window.WindowExec
 
 case class RewrittenNodeWall(originalChild: SparkPlan) extends LeafExecNode {
@@ -60,6 +61,7 @@ class RewriteSparkPlanRulesManager private (rewriteRules: Seq[RewriteSingleNode]
         case _: ExpandExec => true
         case _: GenerateExec => true
         case plan if SparkShimLoader.getSparkShims.isWindowGroupLimitExec(plan) => true
+        case _: ArrowEvalPythonExec => true
         case _ => false
       }
     }

