This is an automated email from the ASF dual-hosted git repository.

ulyssesyou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 3ee108f97 [CORE] Pullout pre-project for ExpandExec (#5066)
3ee108f97 is described below

commit 3ee108f97d103364facfde371bdd2af2d5013d7e
Author: Joey <joey....@alibaba-inc.com>
AuthorDate: Thu Mar 21 19:13:26 2024 +0800

    [CORE] Pullout pre-project for ExpandExec (#5066)
---
 .../execution/ExpandExecTransformer.scala          | 126 ++++-----------------
 .../extension/columnar/PullOutPreProject.scala     |  12 +-
 .../columnar/RewriteSparkPlanRulesManager.scala    |   1 +
 3 files changed, 35 insertions(+), 104 deletions(-)

diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/ExpandExecTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/ExpandExecTransformer.scala
index 4d547f771..daa195b68 100644
--- a/gluten-core/src/main/scala/io/glutenproject/execution/ExpandExecTransformer.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/execution/ExpandExecTransformer.scala
@@ -17,12 +17,12 @@
 package io.glutenproject.execution
 
 import io.glutenproject.backendsapi.BackendsApiManager
-import io.glutenproject.expression.{ConverterUtils, ExpressionConverter, 
LiteralTransformer}
+import io.glutenproject.expression.{ConverterUtils, ExpressionConverter}
 import io.glutenproject.extension.ValidationResult
 import io.glutenproject.metrics.MetricsUpdater
 import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode}
 import io.glutenproject.substrait.SubstraitContext
-import io.glutenproject.substrait.expression.{ExpressionBuilder, 
ExpressionNode}
+import io.glutenproject.substrait.expression.ExpressionNode
 import io.glutenproject.substrait.extensions.ExtensionBuilder
 import io.glutenproject.substrait.rel.{RelBuilder, RelNode}
 
@@ -32,9 +32,6 @@ import org.apache.spark.sql.execution._
 
 import java.util.{ArrayList => JArrayList, List => JList}
 
-import scala.collection.JavaConverters._
-import scala.collection.mutable.ArrayBuffer
-
 case class ExpandExecTransformer(
     projections: Seq[Seq[Expression]],
     output: Seq[Attribute],
@@ -66,110 +63,33 @@ case class ExpandExecTransformer(
       input: RelNode,
       validation: Boolean): RelNode = {
     val args = context.registeredFunction
-    def needsPreProjection(projections: Seq[Seq[Expression]]): Boolean = {
-      projections
-        .exists(set => set.exists(p => !p.isInstanceOf[Attribute] && 
!p.isInstanceOf[Literal]))
-    }
-    if (needsPreProjection(projections)) {
-      // if there is not literal and attribute expression in project sets, add 
a project op
-      // to calculate them before expand op.
-      val preExprs = ArrayBuffer.empty[Expression]
-      val selectionMaps = ArrayBuffer.empty[Seq[Int]]
-      var preExprIndex = 0
-      for (i <- projections.indices) {
-        val selections = ArrayBuffer.empty[Int]
-        for (j <- projections(i).indices) {
-          val proj = projections(i)(j)
-          if (!proj.isInstanceOf[Literal]) {
-            val exprIdx = preExprs.indexWhere(expr => 
expr.semanticEquals(proj))
-            if (exprIdx != -1) {
-              selections += exprIdx
-            } else {
-              preExprs += proj
-              selections += preExprIndex
-              preExprIndex = preExprIndex + 1
-            }
-          } else {
-            selections += -1
-          }
-        }
-        selectionMaps += selections
-      }
-      // make project
-      val preExprNodes = preExprs
-        .map(
-          ExpressionConverter
-            .replaceWithExpressionTransformer(_, originalInputAttributes)
-            .doTransform(args))
-        .asJava
-
-      val emitStartIndex = originalInputAttributes.size
-      val inputRel = if (!validation) {
-        RelBuilder.makeProjectRel(input, preExprNodes, context, operatorId, 
emitStartIndex)
-      } else {
-        // Use a extension node to send the input types through Substrait plan 
for a validation.
-        val inputTypeNodeList = new java.util.ArrayList[TypeNode]()
-        for (attr <- originalInputAttributes) {
-          inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, 
attr.nullable))
-        }
-        val extensionNode = ExtensionBuilder.makeAdvancedExtension(
-          BackendsApiManager.getTransformerApiInstance.packPBMessage(
-            TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
-        RelBuilder.makeProjectRel(
-          input,
-          preExprNodes,
-          extensionNode,
-          context,
-          operatorId,
-          emitStartIndex)
-      }
-
-      // make expand
-      val projectSetExprNodes = new JArrayList[JList[ExpressionNode]]()
-      for (i <- projections.indices) {
+    val projectSetExprNodes = new JArrayList[JList[ExpressionNode]]()
+    projections.foreach {
+      projectSet =>
         val projectExprNodes = new JArrayList[ExpressionNode]()
-        for (j <- projections(i).indices) {
-          val projectExprNode = projections(i)(j) match {
-            case l: Literal =>
-              LiteralTransformer(l).doTransform(args)
-            case _ =>
-              ExpressionBuilder.makeSelection(selectionMaps(i)(j))
-          }
-
-          projectExprNodes.add(projectExprNode)
+        projectSet.foreach {
+          project =>
+            val projectExprNode = ExpressionConverter
+              .replaceWithExpressionTransformer(project, 
originalInputAttributes)
+              .doTransform(args)
+            projectExprNodes.add(projectExprNode)
         }
         projectSetExprNodes.add(projectExprNodes)
-      }
-      RelBuilder.makeExpandRel(inputRel, projectSetExprNodes, context, 
operatorId)
+    }
+
+    if (!validation) {
+      RelBuilder.makeExpandRel(input, projectSetExprNodes, context, operatorId)
     } else {
-      val projectSetExprNodes = new JArrayList[JList[ExpressionNode]]()
-      projections.foreach {
-        projectSet =>
-          val projectExprNodes = new JArrayList[ExpressionNode]()
-          projectSet.foreach {
-            project =>
-              val projectExprNode = ExpressionConverter
-                .replaceWithExpressionTransformer(project, 
originalInputAttributes)
-                .doTransform(args)
-              projectExprNodes.add(projectExprNode)
-          }
-          projectSetExprNodes.add(projectExprNodes)
+      // Use a extension node to send the input types through Substrait plan 
for a validation.
+      val inputTypeNodeList = new java.util.ArrayList[TypeNode]()
+      for (attr <- originalInputAttributes) {
+        inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, 
attr.nullable))
       }
 
-      if (!validation) {
-        RelBuilder.makeExpandRel(input, projectSetExprNodes, context, 
operatorId)
-      } else {
-        // Use a extension node to send the input types through Substrait plan 
for a validation.
-        val inputTypeNodeList = new java.util.ArrayList[TypeNode]()
-        for (attr <- originalInputAttributes) {
-          inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, 
attr.nullable))
-        }
-
-        val extensionNode = ExtensionBuilder.makeAdvancedExtension(
-          BackendsApiManager.getTransformerApiInstance.packPBMessage(
-            TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
-        RelBuilder.makeExpandRel(input, projectSetExprNodes, extensionNode, 
context, operatorId)
-      }
+      val extensionNode = ExtensionBuilder.makeAdvancedExtension(
+        BackendsApiManager.getTransformerApiInstance.packPBMessage(
+          TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
+      RelBuilder.makeExpandRel(input, projectSetExprNodes, extensionNode, 
context, operatorId)
     }
   }
 
diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPreProject.scala b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPreProject.scala
index 5bf70597c..440f609de 100644
--- a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPreProject.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPreProject.scala
@@ -21,7 +21,7 @@ import io.glutenproject.utils.PullOutProjectHelper
 import org.apache.spark.sql.catalyst.expressions._
 import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
Complete, Partial}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, 
TakeOrderedAndProjectExec}
+import org.apache.spark.sql.execution.{ExpandExec, ProjectExec, SortExec, 
SparkPlan, TakeOrderedAndProjectExec}
 import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, 
TypedAggregateExpression}
 import org.apache.spark.sql.execution.window.WindowExec
 
@@ -74,6 +74,7 @@ object PullOutPreProject extends Rule[SparkPlan] with PullOutProjectHelper {
             }
           case _ => false
         }.isDefined)
+      case expand: ExpandExec => 
expand.projections.flatten.exists(isNotAttributeAndLiteral)
       case _ => false
     }
   }
@@ -179,6 +180,15 @@ object PullOutPreProject extends Rule[SparkPlan] with PullOutProjectHelper {
 
       ProjectExec(window.output, newWindow)
 
+    case expand: ExpandExec if needsPreProject(expand) =>
+      val expressionMap = new mutable.HashMap[Expression, NamedExpression]()
+      val newProjections =
+        expand.projections.map(_.map(replaceExpressionWithAttribute(_, 
expressionMap)))
+      expand.copy(
+        projections = newProjections,
+        child = ProjectExec(
+          eliminateProjectList(expand.child.outputSet, 
expressionMap.values.toSeq),
+          expand.child))
     case _ => plan
   }
 }
diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/RewriteSparkPlanRulesManager.scala b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/RewriteSparkPlanRulesManager.scala
index 892e5eeef..8f3f01f95 100644
--- a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/RewriteSparkPlanRulesManager.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/RewriteSparkPlanRulesManager.scala
@@ -53,6 +53,7 @@ class RewriteSparkPlanRulesManager(rewriteRules: Seq[Rule[SparkPlan]]) extends R
         case _: WindowExec => true
         case _: FilterExec => true
         case _: FileSourceScanExec => true
+        case _: ExpandExec => true
         case _ => false
       }
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@gluten.apache.org
For additional commands, e-mail: commits-h...@gluten.apache.org

Reply via email to