This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 18af4bc3c [VL] RAS: Include rewrite rules used by
RewriteSparkPlanRulesManager in EnumeratedTransform (#5575)
18af4bc3c is described below
commit 18af4bc3ce4c3e685ad63c869880f8b63d48dc1c
Author: Hongze Zhang <[email protected]>
AuthorDate: Tue May 7 14:49:03 2024 +0800
[VL] RAS: Include rewrite rules used by RewriteSparkPlanRulesManager in
EnumeratedTransform (#5575)
---
.../clickhouse/CHSparkPlanExecApi.scala | 3 +-
.../execution/CHHashAggregateExecTransformer.scala | 6 +-
.../backendsapi/velox/VeloxSparkPlanExecApi.scala | 3 +-
.../execution/HashAggregateExecTransformer.scala | 14 ++--
.../gluten/backendsapi/SparkPlanExecApi.scala | 1 -
.../HashAggregateExecBaseTransformer.scala | 30 +++++--
.../org/apache/gluten/extension/RewriteIn.scala | 7 +-
.../extension/columnar/MiscColumnarRules.scala | 18 ++--
...ormSingleNode.scala => OffloadSingleNode.scala} | 95 ++++++++--------------
.../extension/columnar/TransformHintRule.scala | 33 +-------
.../columnar/enumerated/ConditionedRule.scala | 51 ------------
.../columnar/enumerated/EnumeratedApplier.scala | 5 +-
.../columnar/enumerated/EnumeratedTransform.scala | 56 ++++---------
.../columnar/enumerated/PushFilterToScan.scala | 27 +++---
.../extension/columnar/enumerated/RasOffload.scala | 84 +++++++++++++++++++
...ntAggregate.scala => RasOffloadAggregate.scala} | 35 ++------
...mplementFilter.scala => RasOffloadFilter.scala} | 16 ++--
.../{FilterRemoveRule.scala => RemoveFilter.scala} | 2 +-
.../columnar/heuristic/HeuristicApplier.scala | 1 +
.../{ => rewrite}/PullOutPostProject.scala | 14 ++--
.../columnar/{ => rewrite}/PullOutPreProject.scala | 9 +-
.../{ => columnar/rewrite}/RewriteCollect.scala | 9 +-
.../{ => rewrite}/RewriteMultiChildrenCount.scala | 7 +-
.../columnar/rewrite/RewriteSingleNode.scala | 48 +++++++++++
.../RewriteSparkPlanRulesManager.scala | 17 ++--
.../RewriteTypedImperativeAggregate.scala | 7 +-
.../gluten/planner/cost/GlutenCostModel.scala | 6 +-
.../GlutenFormatWriterInjectsBase.scala | 3 +-
28 files changed, 295 insertions(+), 312 deletions(-)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index 64090af28..a9a12a3ea 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -204,10 +204,9 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
/** Generate HashAggregateExecPullOutHelper */
override def genHashAggregateExecPullOutHelper(
- groupingExpressions: Seq[NamedExpression],
aggregateExpressions: Seq[AggregateExpression],
aggregateAttributes: Seq[Attribute]): HashAggregateExecPullOutBaseHelper
=
- CHHashAggregateExecPullOutHelper(groupingExpressions,
aggregateExpressions, aggregateAttributes)
+ CHHashAggregateExecPullOutHelper(aggregateExpressions, aggregateAttributes)
/**
* If there are expressions (not field reference) in the partitioning's
children, add a projection
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
index 82c492f4c..d4f2f9eb3 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
@@ -411,13 +411,9 @@ case class CHHashAggregateExecTransformer(
}
case class CHHashAggregateExecPullOutHelper(
- groupingExpressions: Seq[NamedExpression],
aggregateExpressions: Seq[AggregateExpression],
aggregateAttributes: Seq[Attribute])
- extends HashAggregateExecPullOutBaseHelper(
- groupingExpressions,
- aggregateExpressions,
- aggregateAttributes) {
+ extends HashAggregateExecPullOutBaseHelper {
/** This method calculates the output attributes of Aggregation. */
override protected def getAttrForAggregateExprs: List[Attribute] = {
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index a55aa1817..0a9f3ef65 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -330,10 +330,9 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
/** Generate HashAggregateExecPullOutHelper */
override def genHashAggregateExecPullOutHelper(
- groupingExpressions: Seq[NamedExpression],
aggregateExpressions: Seq[AggregateExpression],
aggregateAttributes: Seq[Attribute]): HashAggregateExecPullOutBaseHelper
=
- HashAggregateExecPullOutHelper(groupingExpressions, aggregateExpressions,
aggregateAttributes)
+ HashAggregateExecPullOutHelper(aggregateExpressions, aggregateAttributes)
override def genColumnarShuffleExchange(
shuffle: ShuffleExchangeExec,
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
index 0a9904206..f0a7ea180 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
@@ -20,7 +20,7 @@ import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.expression._
import org.apache.gluten.expression.ConverterUtils.FunctionConfig
-import org.apache.gluten.extension.columnar.RewriteTypedImperativeAggregate
+import
org.apache.gluten.extension.columnar.rewrite.RewriteTypedImperativeAggregate
import org.apache.gluten.substrait.`type`.{TypeBuilder, TypeNode}
import org.apache.gluten.substrait.{AggregationParams, SubstraitContext}
import org.apache.gluten.substrait.expression.{AggregateFunctionNode,
ExpressionBuilder, ExpressionNode, ScalarFunctionNode}
@@ -60,6 +60,12 @@ abstract class HashAggregateExecTransformer(
resultExpressions,
child) {
+ override def output: Seq[Attribute] = {
+ // TODO: We should have a check to make sure the returned schema actually
matches the output
+ // data, since "resultExpressions" is not actually in use by Velox.
+ super.output
+ }
+
override def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].doTransform(context)
@@ -793,13 +799,9 @@ case class FlushableHashAggregateExecTransformer(
}
case class HashAggregateExecPullOutHelper(
- groupingExpressions: Seq[NamedExpression],
aggregateExpressions: Seq[AggregateExpression],
aggregateAttributes: Seq[Attribute])
- extends HashAggregateExecPullOutBaseHelper(
- groupingExpressions,
- aggregateExpressions,
- aggregateAttributes) {
+ extends HashAggregateExecPullOutBaseHelper {
/** This method calculates the output attributes of Aggregation. */
override protected def getAttrForAggregateExprs: List[Attribute] = {
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index cfa1a4e53..f5e08a05d 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -102,7 +102,6 @@ trait SparkPlanExecApi {
/** Generate HashAggregateExecPullOutHelper */
def genHashAggregateExecPullOutHelper(
- groupingExpressions: Seq[NamedExpression],
aggregateExpressions: Seq[AggregateExpression],
aggregateAttributes: Seq[Attribute]): HashAggregateExecPullOutBaseHelper
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala
index baf88c727..49a9ee1e8 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala
@@ -172,12 +172,32 @@ abstract class HashAggregateExecBaseTransformer(
validation: Boolean = false): RelNode
}
-abstract class HashAggregateExecPullOutBaseHelper(
- groupingExpressions: Seq[NamedExpression],
- aggregateExpressions: Seq[AggregateExpression],
- aggregateAttributes: Seq[Attribute]) {
+object HashAggregateExecBaseTransformer {
+
+ private def getInitialInputBufferOffset(agg: BaseAggregateExec): Int = agg
match {
+ case a: HashAggregateExec => a.initialInputBufferOffset
+ case a: ObjectHashAggregateExec => a.initialInputBufferOffset
+ case a: SortAggregateExec => a.initialInputBufferOffset
+ }
+
+ def from(agg: BaseAggregateExec)(
+ childConverter: SparkPlan => SparkPlan = p => p):
HashAggregateExecBaseTransformer = {
+ BackendsApiManager.getSparkPlanExecApiInstance
+ .genHashAggregateExecTransformer(
+ agg.requiredChildDistributionExpressions,
+ agg.groupingExpressions,
+ agg.aggregateExpressions,
+ agg.aggregateAttributes,
+ getInitialInputBufferOffset(agg),
+ agg.resultExpressions,
+ childConverter(agg.child)
+ )
+ }
+}
+
+trait HashAggregateExecPullOutBaseHelper {
// The direct outputs of Aggregation.
- lazy val allAggregateResultAttributes: List[Attribute] =
+ def allAggregateResultAttributes(groupingExpressions: Seq[NamedExpression]):
List[Attribute] =
groupingExpressions.map(ConverterUtils.getAttrFromExpr(_)).toList :::
getAttrForAggregateExprs
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/RewriteIn.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/RewriteIn.scala
index b508f3eff..565b9bb19 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/RewriteIn.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/RewriteIn.scala
@@ -16,8 +16,9 @@
*/
package org.apache.gluten.extension
+import org.apache.gluten.extension.columnar.rewrite.RewriteSingleNode
+
import org.apache.spark.sql.catalyst.expressions.{EqualTo, Expression, In, Or}
-import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{FileSourceScanExec, FilterExec,
SparkPlan}
import org.apache.spark.sql.types.StructType
@@ -32,7 +33,7 @@ import org.apache.spark.sql.types.StructType
*
* TODO: Remove this rule once Velox supports the case where the list option in
`In` is not a literal.
*/
-object RewriteIn extends Rule[SparkPlan] {
+object RewriteIn extends RewriteSingleNode {
private def shouldRewrite(e: Expression): Boolean = {
e match {
@@ -58,7 +59,7 @@ object RewriteIn extends Rule[SparkPlan] {
}
}
- override def apply(plan: SparkPlan): SparkPlan = {
+ override def rewrite(plan: SparkPlan): SparkPlan = {
plan match {
// TODO: Support datasource v2
case scan: FileSourceScanExec if
scan.dataFilters.exists(_.find(shouldRewrite).isDefined) =>
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala
index 02a466b6a..068f62e49 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala
@@ -30,12 +30,12 @@ object MiscColumnarRules {
object TransformPreOverrides {
def apply(): TransformPreOverrides = {
TransformPreOverrides(
- List(TransformFilter()),
+ List(OffloadFilter()),
List(
- TransformOthers(),
- TransformAggregate(),
- TransformExchange(),
- TransformJoin()
+ OffloadOthers(),
+ OffloadAggregate(),
+ OffloadExchange(),
+ OffloadJoin()
)
)
}
@@ -43,17 +43,17 @@ object MiscColumnarRules {
// This rule will conduct the conversion from Spark plan to the plan
transformer.
case class TransformPreOverrides(
- topDownRules: Seq[TransformSingleNode],
- bottomUpRules: Seq[TransformSingleNode])
+ topDownRules: Seq[OffloadSingleNode],
+ bottomUpRules: Seq[OffloadSingleNode])
extends Rule[SparkPlan]
with LogLevelUtil {
@transient private val planChangeLogger = new PlanChangeLogger[SparkPlan]()
def apply(plan: SparkPlan): SparkPlan = {
val plan0 =
- topDownRules.foldLeft(plan)((p, rule) => p.transformDown { case p =>
rule.impl(p) })
+ topDownRules.foldLeft(plan)((p, rule) => p.transformDown { case p =>
rule.offload(p) })
val plan1 =
- bottomUpRules.foldLeft(plan0)((p, rule) => p.transformUp { case p =>
rule.impl(p) })
+ bottomUpRules.foldLeft(plan0)((p, rule) => p.transformUp { case p =>
rule.offload(p) })
planChangeLogger.logRule(ruleName, plan, plan1)
plan1
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformSingleNode.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala
similarity index 88%
rename from
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformSingleNode.scala
rename to
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala
index 760929bbd..84a2ec5c6 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformSingleNode.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala
@@ -40,13 +40,20 @@ import
org.apache.spark.sql.execution.python.{ArrowEvalPythonExec, BatchEvalPyth
import org.apache.spark.sql.execution.window.{WindowExec,
WindowGroupLimitExecShim}
import org.apache.spark.sql.hive.HiveTableScanExecTransformer
-sealed trait TransformSingleNode extends Logging {
- def impl(plan: SparkPlan): SparkPlan
+/**
+ * Converts a vanilla Spark plan node into a Gluten plan node. A Gluten plan is
supposed to be executed
+ * natively, and the internals of its execution are subject to the backend's
implementation.
+ *
+ * Note: Only the current plan node is supposed to be open to modification. Do
not access or modify
+ * the child nodes. Tree-walking is done by the caller of this trait.
+ */
+sealed trait OffloadSingleNode extends Logging {
+ def offload(plan: SparkPlan): SparkPlan
}
// Aggregation transformation.
-case class TransformAggregate() extends TransformSingleNode with LogLevelUtil {
- override def impl(plan: SparkPlan): SparkPlan = plan match {
+case class OffloadAggregate() extends OffloadSingleNode with LogLevelUtil {
+ override def offload(plan: SparkPlan): SparkPlan = plan match {
case plan if TransformHints.isNotTransformable(plan) =>
plan
case agg: HashAggregateExec =>
@@ -69,19 +76,6 @@ case class TransformAggregate() extends TransformSingleNode
with LogLevelUtil {
val aggChild = plan.child
- def transformHashAggregate(): GlutenPlan = {
- BackendsApiManager.getSparkPlanExecApiInstance
- .genHashAggregateExecTransformer(
- plan.requiredChildDistributionExpressions,
- plan.groupingExpressions,
- plan.aggregateExpressions,
- plan.aggregateAttributes,
- plan.initialInputBufferOffset,
- plan.resultExpressions,
- aggChild
- )
- }
-
// If child's output is empty, fallback or offload both the child and
aggregation.
if (
aggChild.output.isEmpty && BackendsApiManager.getSettings
@@ -91,9 +85,9 @@ case class TransformAggregate() extends TransformSingleNode
with LogLevelUtil {
case _: TransformSupport =>
// If the child is transformable, transform aggregation as well.
logDebug(s"Columnar Processing for ${plan.getClass} is currently
supported.")
- transformHashAggregate()
+ HashAggregateExecBaseTransformer.from(plan)()
case p: SparkPlan if PlanUtil.isGlutenTableCache(p) =>
- transformHashAggregate()
+ HashAggregateExecBaseTransformer.from(plan)()
case _ =>
// If the child is not transformable, do not transform the agg.
TransformHints.tagNotTransformable(plan, "child output schema is
empty")
@@ -101,14 +95,14 @@ case class TransformAggregate() extends
TransformSingleNode with LogLevelUtil {
}
} else {
logDebug(s"Columnar Processing for ${plan.getClass} is currently
supported.")
- transformHashAggregate()
+ HashAggregateExecBaseTransformer.from(plan)()
}
}
}
// Exchange transformation.
-case class TransformExchange() extends TransformSingleNode with LogLevelUtil {
- override def impl(plan: SparkPlan): SparkPlan = plan match {
+case class OffloadExchange() extends OffloadSingleNode with LogLevelUtil {
+ override def offload(plan: SparkPlan): SparkPlan = plan match {
case plan if TransformHints.isNotTransformable(plan) =>
plan
case plan: ShuffleExchangeExec =>
@@ -131,10 +125,10 @@ case class TransformExchange() extends
TransformSingleNode with LogLevelUtil {
}
// Join transformation.
-case class TransformJoin() extends TransformSingleNode with LogLevelUtil {
- import TransformJoin._
+case class OffloadJoin() extends OffloadSingleNode with LogLevelUtil {
+ import OffloadJoin._
- override def impl(plan: SparkPlan): SparkPlan = {
+ override def offload(plan: SparkPlan): SparkPlan = {
if (TransformHints.isNotTransformable(plan)) {
logDebug(s"Columnar Processing for ${plan.getClass} is under row guard.")
plan match {
@@ -223,7 +217,7 @@ case class TransformJoin() extends TransformSingleNode with
LogLevelUtil {
}
-object TransformJoin {
+object OffloadJoin {
private def getSparkSupportedBuildSide(plan: ShuffledHashJoinExec):
BuildSide = {
plan.joinType match {
case LeftOuter | LeftSemi => BuildRight
@@ -238,11 +232,11 @@ object TransformJoin {
}
// Filter transformation.
-case class TransformFilter() extends TransformSingleNode with LogLevelUtil {
- import TransformOthers._
+case class OffloadFilter() extends OffloadSingleNode with LogLevelUtil {
+ import OffloadOthers._
private val replace = new ReplaceSingleNode()
- override def impl(plan: SparkPlan): SparkPlan = plan match {
+ override def offload(plan: SparkPlan): SparkPlan = plan match {
case filter: FilterExec =>
genFilterExec(filter)
case other => other
@@ -286,14 +280,14 @@ case class TransformFilter() extends TransformSingleNode
with LogLevelUtil {
}
// Other transformations.
-case class TransformOthers() extends TransformSingleNode with LogLevelUtil {
- import TransformOthers._
+case class OffloadOthers() extends OffloadSingleNode with LogLevelUtil {
+ import OffloadOthers._
private val replace = new ReplaceSingleNode()
- override def impl(plan: SparkPlan): SparkPlan = replace.doReplace(plan)
+ override def offload(plan: SparkPlan): SparkPlan = replace.doReplace(plan)
}
-object TransformOthers {
+object OffloadOthers {
// Utility to replace single node within transformed Gluten node.
// Children will be preserved as they are as children of the output node.
//
@@ -333,35 +327,16 @@ object TransformOthers {
ProjectExecTransformer(plan.projectList, columnarChild)
case plan: SortAggregateExec =>
logDebug(s"Columnar Processing for ${plan.getClass} is currently
supported.")
- BackendsApiManager.getSparkPlanExecApiInstance
- .genHashAggregateExecTransformer(
- plan.requiredChildDistributionExpressions,
- plan.groupingExpressions,
- plan.aggregateExpressions,
- plan.aggregateAttributes,
- plan.initialInputBufferOffset,
- plan.resultExpressions,
- plan.child match {
- case sort: SortExecTransformer if !sort.global =>
- sort.child
- case sort: SortExec if !sort.global =>
- sort.child
- case _ => plan.child
- }
- )
+ HashAggregateExecBaseTransformer.from(plan) {
+ case sort: SortExecTransformer if !sort.global =>
+ sort.child
+ case sort: SortExec if !sort.global =>
+ sort.child
+ case other => other
+ }
case plan: ObjectHashAggregateExec =>
- val child = plan.child
logDebug(s"Columnar Processing for ${plan.getClass} is currently
supported.")
- BackendsApiManager.getSparkPlanExecApiInstance
- .genHashAggregateExecTransformer(
- plan.requiredChildDistributionExpressions,
- plan.groupingExpressions,
- plan.aggregateExpressions,
- plan.aggregateAttributes,
- plan.initialInputBufferOffset,
- plan.resultExpressions,
- child
- )
+ HashAggregateExecBaseTransformer.from(plan)()
case plan: UnionExec =>
val children = plan.children
logDebug(s"Columnar Processing for ${plan.getClass} is currently
supported.")
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformHintRule.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformHintRule.scala
index ea934425f..3c3d23ccc 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformHintRule.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformHintRule.scala
@@ -355,40 +355,13 @@ case class AddTransformHintRule() extends Rule[SparkPlan]
{
.genFilterExecTransformer(plan.condition, plan.child)
transformer.doValidate().tagOnFallback(plan)
case plan: HashAggregateExec =>
- val transformer = BackendsApiManager.getSparkPlanExecApiInstance
- .genHashAggregateExecTransformer(
- plan.requiredChildDistributionExpressions,
- plan.groupingExpressions,
- plan.aggregateExpressions,
- plan.aggregateAttributes,
- plan.initialInputBufferOffset,
- plan.resultExpressions,
- plan.child
- )
+ val transformer = HashAggregateExecBaseTransformer.from(plan)()
transformer.doValidate().tagOnFallback(plan)
case plan: SortAggregateExec =>
- val transformer = BackendsApiManager.getSparkPlanExecApiInstance
- .genHashAggregateExecTransformer(
- plan.requiredChildDistributionExpressions,
- plan.groupingExpressions,
- plan.aggregateExpressions,
- plan.aggregateAttributes,
- plan.initialInputBufferOffset,
- plan.resultExpressions,
- plan.child
- )
+ val transformer = HashAggregateExecBaseTransformer.from(plan)()
transformer.doValidate().tagOnFallback(plan)
case plan: ObjectHashAggregateExec =>
- val transformer = BackendsApiManager.getSparkPlanExecApiInstance
- .genHashAggregateExecTransformer(
- plan.requiredChildDistributionExpressions,
- plan.groupingExpressions,
- plan.aggregateExpressions,
- plan.aggregateAttributes,
- plan.initialInputBufferOffset,
- plan.resultExpressions,
- plan.child
- )
+ val transformer = HashAggregateExecBaseTransformer.from(plan)()
transformer.doValidate().tagOnFallback(plan)
case plan: UnionExec =>
val transformer = ColumnarUnionExec(plan.children)
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ConditionedRule.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ConditionedRule.scala
deleted file mode 100644
index 33d99f5f7..000000000
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ConditionedRule.scala
+++ /dev/null
@@ -1,51 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.gluten.extension.columnar.enumerated
-
-import org.apache.gluten.extension.columnar.validator.Validator
-import org.apache.gluten.ras.rule.{RasRule, Shape}
-
-import org.apache.spark.sql.execution.SparkPlan
-
-object ConditionedRule {
- trait PreCondition {
- def apply(node: SparkPlan): Boolean
- }
-
- object PreCondition {
- implicit class FromValidator(validator: Validator) extends PreCondition {
- override def apply(node: SparkPlan): Boolean = {
- validator.validate(node) match {
- case Validator.Passed => true
- case Validator.Failed(reason) => false
- }
- }
- }
- }
-
- def wrap(rule: RasRule[SparkPlan], cond: ConditionedRule.PreCondition):
RasRule[SparkPlan] = {
- new RasRule[SparkPlan] {
- override def shift(node: SparkPlan): Iterable[SparkPlan] = {
- val out = List(node)
- .filter(cond.apply)
- .flatMap(rule.shift)
- out
- }
- override def shape(): Shape[SparkPlan] = rule.shape()
- }
- }
-}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala
index dfc2d474f..92d64abf3 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala
@@ -118,10 +118,7 @@ class EnumeratedApplier(session: SparkSession)
(_: SparkSession) => FallbackEmptySchemaRelation()
) :::
BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarValidationRules()
:::
- List(
- (spark: SparkSession) => MergeTwoPhasesHashBaseAggregate(spark),
- (_: SparkSession) => RewriteSparkPlanRulesManager()
- ) :::
+ List((spark: SparkSession) => MergeTwoPhasesHashBaseAggregate(spark)) :::
List(
(session: SparkSession) => EnumeratedTransform(session,
outputsColumnar),
(_: SparkSession) => RemoveTransitions
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
index 973020438..dc34bc1af 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
@@ -16,13 +16,10 @@
*/
package org.apache.gluten.extension.columnar.enumerated
-import org.apache.gluten.extension.GlutenPlan
-import org.apache.gluten.extension.columnar.{TransformExchange, TransformJoin,
TransformOthers, TransformSingleNode}
-import org.apache.gluten.extension.columnar.validator.{Validator, Validators}
+import org.apache.gluten.extension.columnar.{OffloadExchange, OffloadJoin,
OffloadOthers, OffloadSingleNode}
import org.apache.gluten.planner.GlutenOptimization
import org.apache.gluten.planner.property.Conventions
import org.apache.gluten.ras.property.PropertySet
-import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}
import org.apache.gluten.utils.LogLevelUtil
import org.apache.spark.sql.SparkSession
@@ -34,31 +31,22 @@ case class EnumeratedTransform(session: SparkSession,
outputsColumnar: Boolean)
with LogLevelUtil {
import EnumeratedTransform._
- private val validator = Validators
- .builder()
- .fallbackByHint()
- .fallbackIfScanOnly()
- .fallbackComplexExpressions()
- .fallbackByBackendSettings()
- .fallbackByUserOptions()
- .build()
-
private val rules = List(
- PushFilterToScan,
- FilterRemoveRule
+ new PushFilterToScan(RasOffload.validator),
+ RemoveFilter
)
// TODO: Should obey ReplaceSingleNode#applyScanNotTransformable to select
// (vanilla) scan with cheaper sub-query plan through cost model.
- private val implRules = List(
- AsRasImplement(TransformOthers()),
- AsRasImplement(TransformExchange()),
- AsRasImplement(TransformJoin()),
- ImplementAggregate,
- ImplementFilter
- ).map(_.withValidator(validator))
+ private val offloadRules = List(
+ new AsRasOffload(OffloadOthers()),
+ new AsRasOffload(OffloadExchange()),
+ new AsRasOffload(OffloadJoin()),
+ RasOffloadAggregate,
+ RasOffloadFilter
+ )
- private val optimization = GlutenOptimization(rules ++ implRules)
+ private val optimization = GlutenOptimization(rules ++ offloadRules)
private val reqConvention = Conventions.ANY
private val altConventions =
@@ -75,24 +63,12 @@ case class EnumeratedTransform(session: SparkSession,
outputsColumnar: Boolean)
}
object EnumeratedTransform {
- private case class AsRasImplement(delegate: TransformSingleNode) extends
RasRule[SparkPlan] {
- override def shift(node: SparkPlan): Iterable[SparkPlan] = {
- val out = delegate.impl(node)
- out match {
- case t: GlutenPlan if !t.doValidate().isValid =>
- List.empty
- case other =>
- List(other)
- }
- }
-
- override def shape(): Shape[SparkPlan] = Shapes.fixedHeight(1)
- }
- // TODO: Currently not in use. Prepared for future development.
- implicit private class RasRuleImplicits(rasRule: RasRule[SparkPlan]) {
- def withValidator(v: Validator): RasRule[SparkPlan] = {
- ConditionedRule.wrap(rasRule, v)
+ /** Accepts an [[OffloadSingleNode]] rule to convert it into a RAS offload
rule. */
+ private class AsRasOffload(delegate: OffloadSingleNode) extends RasOffload {
+ override protected def offload(node: SparkPlan): SparkPlan = {
+ val out = delegate.offload(node)
+ out
}
}
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/PushFilterToScan.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/PushFilterToScan.scala
index f04f572c1..7306b734a 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/PushFilterToScan.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/PushFilterToScan.scala
@@ -17,28 +17,31 @@
package org.apache.gluten.extension.columnar.enumerated
import org.apache.gluten.execution.{FilterHandler, TransformSupport}
-import org.apache.gluten.extension.columnar.TransformHints
+import org.apache.gluten.extension.columnar.validator.Validator
import org.apache.gluten.ras.path.Pattern._
import org.apache.gluten.ras.path.Pattern.Matchers._
import org.apache.gluten.ras.rule.{RasRule, Shape}
import org.apache.gluten.ras.rule.Shapes._
-import org.apache.spark.sql.execution.{ColumnarToRowExec,
ColumnarToRowTransition, FileSourceScanExec, FilterExec, SparkPlan}
+import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
-object PushFilterToScan extends RasRule[SparkPlan] {
+// TODO: Match on Vanilla filter + Gluten scan.
+class PushFilterToScan(validator: Validator) extends RasRule[SparkPlan] {
override def shift(node: SparkPlan): Iterable[SparkPlan] = node match {
case FilterAndScan(filter, scan) =>
- if (!TransformHints.isTransformable(scan)) {
- return List.empty
- }
- val newScan =
- FilterHandler.pushFilterToScan(filter.condition, scan)
- newScan match {
- case ts: TransformSupport if ts.doValidate().isValid =>
- List(filter.withNewChildren(List(ts)))
- case _ =>
+ validator.validate(scan) match {
+ case Validator.Failed(reason) =>
List.empty
+ case Validator.Passed =>
+ val newScan =
+ FilterHandler.pushFilterToScan(filter.condition, scan)
+ newScan match {
+ case ts: TransformSupport if ts.doValidate().isValid =>
+ List(filter.withNewChildren(List(ts)))
+ case _ =>
+ List.empty
+ }
}
case _ =>
List.empty
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala
new file mode 100644
index 000000000..57e093bde
--- /dev/null
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension.columnar.enumerated
+
+import org.apache.gluten.extension.GlutenPlan
+import org.apache.gluten.extension.columnar.rewrite.RewriteSingleNode
+import org.apache.gluten.extension.columnar.validator.{Validator, Validators}
+import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}
+
+import org.apache.spark.sql.execution.SparkPlan
+
+trait RasOffload extends RasRule[SparkPlan] {
+ import RasOffload._
+
+ final override def shift(node: SparkPlan): Iterable[SparkPlan] = {
+ // 0. If the node is already offloaded, return fast.
+ if (node.isInstanceOf[GlutenPlan]) {
+ return List.empty
+ }
+
+ // 1. Rewrite the node to form that native library supports.
+ val rewritten = rewrites.foldLeft(node) {
+ case (node, rewrite) =>
+ node.transformUp {
+ case p =>
+ val out = rewrite.rewrite(p)
+ out
+ }
+ }
+
+ // 2. Walk the rewritten tree.
+ val offloaded = rewritten.transformUp {
+ case from =>
+ // 3. Validate current node. If passed, offload it.
+ validator.validate(from) match {
+ case Validator.Passed =>
+ offload(from) match {
+ case t: GlutenPlan if !t.doValidate().isValid =>
+ // 4. If native validation fails on the offloaded node, return the
+ // original one.
+ from
+ case other =>
+ other
+ }
+ case Validator.Failed(reason) =>
+ from
+ }
+ }
+
+ // 5. Return the final tree.
+ List(offloaded)
+ }
+
+ protected def offload(node: SparkPlan): SparkPlan
+
+ final override def shape(): Shape[SparkPlan] = Shapes.fixedHeight(1)
+}
+
+object RasOffload {
+ val validator = Validators
+ .builder()
+ .fallbackByHint()
+ .fallbackIfScanOnly()
+ .fallbackComplexExpressions()
+ .fallbackByBackendSettings()
+ .fallbackByUserOptions()
+ .build()
+
+ private val rewrites = RewriteSingleNode.allRules()
+}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementAggregate.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadAggregate.scala
similarity index 50%
rename from
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementAggregate.scala
rename to
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadAggregate.scala
index 8c51ca4fd..e48545ae9 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementAggregate.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadAggregate.scala
@@ -16,39 +16,16 @@
*/
package org.apache.gluten.extension.columnar.enumerated
-import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.execution.HashAggregateExecBaseTransformer
-import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
-object ImplementAggregate extends RasRule[SparkPlan] {
- override def shift(node: SparkPlan): Iterable[SparkPlan] = node match {
- case agg: HashAggregateExec => shiftAgg(agg)
- case _ => List.empty
+object RasOffloadAggregate extends RasOffload {
+ override protected def offload(node: SparkPlan): SparkPlan = node match {
+ case agg: HashAggregateExec =>
+ val out = HashAggregateExecBaseTransformer.from(agg)()
+ out
+ case other => other
}
-
- private def shiftAgg(agg: HashAggregateExec): Iterable[SparkPlan] = {
- val transformer = implement(agg)
- if (!transformer.doValidate().isValid) {
- return List.empty
- }
- List(transformer)
- }
-
- private def implement(agg: HashAggregateExec):
HashAggregateExecBaseTransformer = {
- BackendsApiManager.getSparkPlanExecApiInstance
- .genHashAggregateExecTransformer(
- agg.requiredChildDistributionExpressions,
- agg.groupingExpressions,
- agg.aggregateExpressions,
- agg.aggregateAttributes,
- agg.initialInputBufferOffset,
- agg.resultExpressions,
- agg.child
- )
- }
-
- override def shape(): Shape[SparkPlan] = Shapes.fixedHeight(1)
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementFilter.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadFilter.scala
similarity index 75%
rename from
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementFilter.scala
rename to
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadFilter.scala
index 33121e7f1..030d05d47 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementFilter.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadFilter.scala
@@ -17,22 +17,16 @@
package org.apache.gluten.extension.columnar.enumerated
import org.apache.gluten.backendsapi.BackendsApiManager
-import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}
import org.apache.spark.sql.execution.{FilterExec, SparkPlan}
-object ImplementFilter extends RasRule[SparkPlan] {
- override def shift(node: SparkPlan): Iterable[SparkPlan] = node match {
+object RasOffloadFilter extends RasOffload {
+ override protected def offload(node: SparkPlan): SparkPlan = node match {
case FilterExec(condition, child) =>
val out = BackendsApiManager.getSparkPlanExecApiInstance
.genFilterExecTransformer(condition, child)
- if (!out.doValidate().isValid) {
- List.empty
- } else {
- List(out)
- }
- case _ =>
- List.empty
+ out
+ case other =>
+ other
}
- override def shape(): Shape[SparkPlan] = Shapes.fixedHeight(1)
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/FilterRemoveRule.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveFilter.scala
similarity index 97%
rename from
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/FilterRemoveRule.scala
rename to
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveFilter.scala
index 52b5be981..c9f4b27bf 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/FilterRemoveRule.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveFilter.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.SparkPlan
// because the pushed filter is not considered in the model. Removing the
filter will make
// optimizer choose a single scan as the winner sub-plan since a single scan's
cost is lower than
// filter + scan.
-object FilterRemoveRule extends RasRule[SparkPlan] {
+object RemoveFilter extends RasRule[SparkPlan] {
override def shift(node: SparkPlan): Iterable[SparkPlan] = {
val filter = node.asInstanceOf[FilterExecTransformerBase]
if (filter.isNoop()) {
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala
index d33cda2e6..0e905ced1 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala
@@ -20,6 +20,7 @@ import org.apache.gluten.GlutenConfig
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.extension.columnar._
import
org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow,
RemoveTopmostColumnarToRow, TransformPostOverrides, TransformPreOverrides}
+import
org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager
import org.apache.gluten.extension.columnar.util.AdaptiveContext
import org.apache.gluten.metrics.GlutenTimeMetric
import org.apache.gluten.utils.{LogLevelUtil, PhysicalPlanSelector}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/PullOutPostProject.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPostProject.scala
similarity index 92%
rename from
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/PullOutPostProject.scala
rename to
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPostProject.scala
index dc2e6423c..1b5467144 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/PullOutPostProject.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPostProject.scala
@@ -14,13 +14,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.gluten.extension.columnar
+package org.apache.gluten.extension.columnar.rewrite
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.utils.PullOutProjectHelper
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute,
NamedExpression, WindowExpression}
-import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{GenerateExec, ProjectExec, SparkPlan}
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.execution.window.WindowExec
@@ -33,17 +32,17 @@ import scala.collection.mutable.ArrayBuffer
* the output of Spark, ensuring that the output data of the native plan can
match the Spark plan
* when a fallback occurs.
*/
-object PullOutPostProject extends Rule[SparkPlan] with PullOutProjectHelper {
+object PullOutPostProject extends RewriteSingleNode with PullOutProjectHelper {
private def needsPostProjection(plan: SparkPlan): Boolean = {
plan match {
case agg: BaseAggregateExec =>
val pullOutHelper =
BackendsApiManager.getSparkPlanExecApiInstance.genHashAggregateExecPullOutHelper(
- agg.groupingExpressions,
agg.aggregateExpressions,
agg.aggregateAttributes)
- val allAggregateResultAttributes =
pullOutHelper.allAggregateResultAttributes
+ val allAggregateResultAttributes =
+ pullOutHelper.allAggregateResultAttributes(agg.groupingExpressions)
// If the result expressions has different size with output attribute,
// post-projection is needed.
agg.resultExpressions.size != allAggregateResultAttributes.size ||
@@ -72,14 +71,13 @@ object PullOutPostProject extends Rule[SparkPlan] with
PullOutProjectHelper {
}
}
- override def apply(plan: SparkPlan): SparkPlan = plan match {
+ override def rewrite(plan: SparkPlan): SparkPlan = plan match {
case agg: BaseAggregateExec if supportedAggregate(agg) &&
needsPostProjection(agg) =>
val pullOutHelper =
BackendsApiManager.getSparkPlanExecApiInstance.genHashAggregateExecPullOutHelper(
- agg.groupingExpressions,
agg.aggregateExpressions,
agg.aggregateAttributes)
- val newResultExpressions = pullOutHelper.allAggregateResultAttributes
+ val newResultExpressions =
pullOutHelper.allAggregateResultAttributes(agg.groupingExpressions)
val newAgg = copyBaseAggregateExec(agg)(newResultExpressions =
newResultExpressions)
ProjectExec(agg.resultExpressions, newAgg)
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/PullOutPreProject.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala
similarity index 96%
rename from
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/PullOutPreProject.scala
rename to
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala
index 48a9a7687..64d4f2736 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/PullOutPreProject.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala
@@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.gluten.extension.columnar
+package org.apache.gluten.extension.columnar.rewrite
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.sql.shims.SparkShimLoader
@@ -22,8 +22,7 @@ import org.apache.gluten.utils.PullOutProjectHelper
import org.apache.spark.sql.catalyst.expressions._
import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
Complete, Partial}
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{ExpandExec, GenerateExec, ProjectExec,
SortExec, SparkPlan, TakeOrderedAndProjectExec}
+import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec,
TypedAggregateExpression}
import org.apache.spark.sql.execution.window.{WindowExec,
WindowGroupLimitExecShim}
@@ -36,7 +35,7 @@ import scala.collection.mutable
* to transform the SparkPlan at the physical plan level, constructing a
SparkPlan that supports
* execution by the native engine.
*/
-object PullOutPreProject extends Rule[SparkPlan] with PullOutProjectHelper {
+object PullOutPreProject extends RewriteSingleNode with PullOutProjectHelper {
private def needsPreProject(plan: SparkPlan): Boolean = {
plan match {
@@ -118,7 +117,7 @@ object PullOutPreProject extends Rule[SparkPlan] with
PullOutProjectHelper {
}
}
- override def apply(plan: SparkPlan): SparkPlan = plan match {
+ override def rewrite(plan: SparkPlan): SparkPlan = plan match {
case sort: SortExec if needsPreProject(sort) =>
val expressionMap = new mutable.HashMap[Expression, NamedExpression]()
val newSortOrder = getNewSortOrder(sort.sortOrder, expressionMap)
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/RewriteCollect.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteCollect.scala
similarity index 93%
rename from
gluten-core/src/main/scala/org/apache/gluten/extension/RewriteCollect.scala
rename to
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteCollect.scala
index 3b6710857..74d493de5 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/RewriteCollect.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteCollect.scala
@@ -14,14 +14,13 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.gluten.extension
+package org.apache.gluten.extension.columnar.rewrite
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.utils.PullOutProjectHelper
import org.apache.spark.sql.catalyst.expressions.{And, Attribute,
AttributeSet, If, IsNotNull, IsNull, Literal, NamedExpression}
-import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
CollectSet, Complete, Final, Partial}
-import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.types.ArrayType
@@ -36,7 +35,7 @@ import scala.collection.mutable.ArrayBuffer
*
* TODO: remove this rule once Velox compatible with vanilla Spark.
*/
-object RewriteCollect extends Rule[SparkPlan] with PullOutProjectHelper {
+object RewriteCollect extends RewriteSingleNode with PullOutProjectHelper {
private lazy val shouldRewriteCollect =
BackendsApiManager.getSettings.shouldRewriteCollect()
@@ -121,7 +120,7 @@ object RewriteCollect extends Rule[SparkPlan] with
PullOutProjectHelper {
(newAggregateAttributes, newResultExpressions)
}
- override def apply(plan: SparkPlan): SparkPlan = {
+ override def rewrite(plan: SparkPlan): SparkPlan = {
if (!shouldRewriteCollect) {
return plan
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RewriteMultiChildrenCount.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteMultiChildrenCount.scala
similarity index 93%
rename from
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RewriteMultiChildrenCount.scala
rename to
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteMultiChildrenCount.scala
index 9657c127d..b395d961a 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RewriteMultiChildrenCount.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteMultiChildrenCount.scala
@@ -14,14 +14,13 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.gluten.extension.columnar
+package org.apache.gluten.extension.columnar.rewrite
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.utils.PullOutProjectHelper
import org.apache.spark.sql.catalyst.expressions.{If, IsNull, Literal, Or}
import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
Count, Partial}
-import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.types.IntegerType
@@ -46,7 +45,7 @@ import org.apache.spark.sql.types.IntegerType
*
* TODO: Remove this rule when Velox support multi-children Count
*/
-object RewriteMultiChildrenCount extends Rule[SparkPlan] with
PullOutProjectHelper {
+object RewriteMultiChildrenCount extends RewriteSingleNode with
PullOutProjectHelper {
private lazy val shouldRewriteCount =
BackendsApiManager.getSettings.shouldRewriteCount()
private def extractCountForRewrite(aggExpr: AggregateExpression):
Option[Count] = {
@@ -92,7 +91,7 @@ object RewriteMultiChildrenCount extends Rule[SparkPlan] with
PullOutProjectHelp
}
}
- override def apply(plan: SparkPlan): SparkPlan = {
+ override def rewrite(plan: SparkPlan): SparkPlan = {
if (!shouldRewriteCount) {
return plan
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSingleNode.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSingleNode.scala
new file mode 100644
index 000000000..73bc8b967
--- /dev/null
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSingleNode.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension.columnar.rewrite
+
+import org.apache.gluten.extension.RewriteIn
+
+import org.apache.spark.sql.execution.SparkPlan
+
+/**
+ * Rewrites a plan node from vanilla Spark into its alternative representation.
+ *
+ * Gluten's planner will pick one that is considered the best executable plan
between input plan and
+ * the output plan.
+ *
+ * Note: Only the current plan node is supposed to be open to modification. Do
not access or modify
+ * the children node. Tree-walking is done by caller of this trait.
+ *
+ * TODO: Ideally for such API we'd better to allow multiple alternative
outputs.
+ */
+trait RewriteSingleNode {
+ def rewrite(plan: SparkPlan): SparkPlan
+}
+
+object RewriteSingleNode {
+ def allRules(): Seq[RewriteSingleNode] = {
+ Seq(
+ RewriteIn,
+ RewriteMultiChildrenCount,
+ RewriteCollect,
+ RewriteTypedImperativeAggregate,
+ PullOutPreProject,
+ PullOutPostProject)
+ }
+}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RewriteSparkPlanRulesManager.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala
similarity index 91%
rename from
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RewriteSparkPlanRulesManager.scala
rename to
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala
index 6070613c1..5fd728eca 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RewriteSparkPlanRulesManager.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala
@@ -14,9 +14,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.gluten.extension.columnar
+package org.apache.gluten.extension.columnar.rewrite
-import org.apache.gluten.extension.{RewriteCollect, RewriteIn}
+import org.apache.gluten.extension.columnar.{AddTransformHintRule,
TransformHint, TransformHints}
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.spark.rdd.RDD
@@ -44,7 +44,7 @@ case class RewrittenNodeWall(originalChild: SparkPlan)
extends LeafExecNode {
*
 * Note that, this rule does not touch and tag these operators who does not need to rewrite.
*/
-class RewriteSparkPlanRulesManager private (rewriteRules: Seq[Rule[SparkPlan]])
+class RewriteSparkPlanRulesManager private (rewriteRules:
Seq[RewriteSingleNode])
extends Rule[SparkPlan] {
private def mayNeedRewrite(plan: SparkPlan): Boolean = {
@@ -83,7 +83,7 @@ class RewriteSparkPlanRulesManager private (rewriteRules:
Seq[Rule[SparkPlan]])
// Some rewrite rules may generate new parent plan node, we should
use transform to
// rewrite the original plan. For example, PullOutPreProject and
PullOutPostProject
// will generate post-project plan node.
- plan.transformUp { case p => rule.apply(p) }
+ plan.transformUp { case p => rule.rewrite(p) }
}
(rewrittenPlan, None)
} catch {
@@ -133,13 +133,6 @@ class RewriteSparkPlanRulesManager private (rewriteRules:
Seq[Rule[SparkPlan]])
object RewriteSparkPlanRulesManager {
def apply(): Rule[SparkPlan] = {
- val rewriteRules = Seq(
- RewriteIn,
- RewriteMultiChildrenCount,
- RewriteCollect,
- RewriteTypedImperativeAggregate,
- PullOutPreProject,
- PullOutPostProject)
- new RewriteSparkPlanRulesManager(rewriteRules)
+ new RewriteSparkPlanRulesManager(RewriteSingleNode.allRules())
}
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RewriteTypedImperativeAggregate.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteTypedImperativeAggregate.scala
similarity index 91%
rename from
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RewriteTypedImperativeAggregate.scala
rename to
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteTypedImperativeAggregate.scala
index df5341373..971a87923 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RewriteTypedImperativeAggregate.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteTypedImperativeAggregate.scala
@@ -14,18 +14,17 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.gluten.extension.columnar
+package org.apache.gluten.extension.columnar.rewrite
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.utils.PullOutProjectHelper
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
-object RewriteTypedImperativeAggregate extends Rule[SparkPlan] with
PullOutProjectHelper {
+object RewriteTypedImperativeAggregate extends RewriteSingleNode with
PullOutProjectHelper {
private lazy val shouldRewriteTypedImperativeAggregate =
BackendsApiManager.getSettings.shouldRewriteTypedImperativeAggregate()
@@ -40,7 +39,7 @@ object RewriteTypedImperativeAggregate extends
Rule[SparkPlan] with PullOutProje
}
}
- override def apply(plan: SparkPlan): SparkPlan = {
+ override def rewrite(plan: SparkPlan): SparkPlan = {
if (!shouldRewriteTypedImperativeAggregate) {
return plan
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/planner/cost/GlutenCostModel.scala
b/gluten-core/src/main/scala/org/apache/gluten/planner/cost/GlutenCostModel.scala
index a5b66df46..2920c0a39 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/planner/cost/GlutenCostModel.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/planner/cost/GlutenCostModel.scala
@@ -16,7 +16,7 @@
*/
package org.apache.gluten.planner.cost
-import org.apache.gluten.extension.columnar.{ColumnarTransitions,
TransformJoin}
+import org.apache.gluten.extension.columnar.{ColumnarTransitions, OffloadJoin}
import org.apache.gluten.planner.plan.GlutenPlanModel.GroupLeafExec
import org.apache.gluten.ras.{Cost, CostModel}
import org.apache.gluten.utils.PlanUtil
@@ -57,7 +57,9 @@ object GlutenCostModel {
// A very rough estimation as of now.
private def selfLongCostOf(node: SparkPlan): Long = {
node match {
- case p: ShuffledHashJoinExec if !TransformJoin.isLegal(p) =>
+ case p: ShuffledHashJoinExec if !OffloadJoin.isLegal(p) =>
+ // To exclude the rewritten intermediate plan that is not executable
+ // by vanilla Spark and was generated by strategy "JoinSelectionOverrides"
infLongCost
case ColumnarToRowExec(child) => 3L
case RowToColumnarExec(child) => 3L
diff --git
a/gluten-core/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenFormatWriterInjectsBase.scala
b/gluten-core/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenFormatWriterInjectsBase.scala
index 7308703e7..fbdbeadba 100644
---
a/gluten-core/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenFormatWriterInjectsBase.scala
+++
b/gluten-core/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenFormatWriterInjectsBase.scala
@@ -18,8 +18,9 @@ package org.apache.spark.sql.execution.datasources
import org.apache.gluten.execution.{ProjectExecTransformer,
SortExecTransformer, TransformSupport, WholeStageTransformer}
import org.apache.gluten.execution.datasource.GlutenFormatWriterInjects
-import org.apache.gluten.extension.columnar.{AddTransformHintRule,
RewriteSparkPlanRulesManager}
+import org.apache.gluten.extension.columnar.AddTransformHintRule
import
org.apache.gluten.extension.columnar.MiscColumnarRules.TransformPreOverrides
+import
org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]