This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new f0d54412e8 [CORE] Consolidate RewriteSparkPlanRulesManager,
AddFallbackTagRule, TransformPreOverrides into a single rule (#7918)
f0d54412e8 is described below
commit f0d54412e831fc2075fdd89d8281599b61ec1844
Author: Hongze Zhang <[email protected]>
AuthorDate: Tue Nov 12 18:09:02 2024 +0800
[CORE] Consolidate RewriteSparkPlanRulesManager, AddFallbackTagRule,
TransformPreOverrides into a single rule (#7918)
---
.../gluten/backendsapi/clickhouse/CHRuleApi.scala | 21 +-
.../clickhouse/CHSparkPlanExecApi.scala | 15 +-
.../extension/FallbackBroadcastHashJoinRules.scala | 1 -
.../gluten/backendsapi/velox/VeloxRuleApi.scala | 18 +-
.../scala/org/apache/gluten/backend/Backend.scala | 4 +-
.../gluten/extension/GlutenSessionExtensions.scala | 4 +-
.../{RuleInjector.scala => Injector.scala} | 6 +-
.../org/apache/gluten/backendsapi/RuleApi.scala | 4 +-
.../gluten/backendsapi/SubstraitBackend.scala | 4 +-
.../org/apache/gluten/extension/GlutenPlan.scala | 11 +
.../columnar/EnsureLocalSortRequirements.scala | 20 +-
.../gluten/extension/columnar/FallbackRules.scala | 402 ---------------------
.../gluten/extension/columnar/FallbackTag.scala | 140 +++++++
.../extension/columnar/MiscColumnarRules.scala | 29 --
.../columnar/enumerated/EnumeratedTransform.scala | 14 +-
.../extension/columnar/enumerated/RasOffload.scala | 2 +-
.../AddFallbackTags.scala} | 227 +-----------
.../columnar/heuristic/HeuristicTransform.scala | 55 +++
.../columnar/heuristic/LegacyOffload.scala | 48 +++
.../{ => heuristic}/OffloadSingleNode.scala | 3 +-
.../extension/columnar/rewrite/RewriteJoin.scala | 2 +-
.../rewrite/RewriteSparkPlanRulesManager.scala | 13 +-
.../GlutenFormatWriterInjectsBase.scala | 14 +-
23 files changed, 335 insertions(+), 722 deletions(-)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
index ccb124b613..a1b8fbdd5b 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
@@ -19,10 +19,10 @@ package org.apache.gluten.backendsapi.clickhouse
import org.apache.gluten.backendsapi.RuleApi
import org.apache.gluten.extension._
import org.apache.gluten.extension.columnar._
-import
org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow,
RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast, TransformPreOverrides}
-import
org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager
+import
org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow,
RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast}
+import org.apache.gluten.extension.columnar.heuristic.HeuristicTransform
import org.apache.gluten.extension.columnar.transition.{InsertTransitions,
RemoveTransitions}
-import org.apache.gluten.extension.injector.{RuleInjector, SparkInjector}
+import org.apache.gluten.extension.injector.{Injector, SparkInjector}
import org.apache.gluten.extension.injector.GlutenInjector.{LegacyInjector,
RasInjector}
import org.apache.gluten.parser.{GlutenCacheFilesSqlParser,
GlutenClickhouseSqlParser}
import org.apache.gluten.sql.shims.SparkShimLoader
@@ -31,14 +31,14 @@ import
org.apache.spark.sql.catalyst.{CHAggregateFunctionRewriteRule, EqualToRew
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.delta.DeltaLogFileIndex
import org.apache.spark.sql.delta.rules.CHOptimizeMetadataOnlyDeltaQuery
-import org.apache.spark.sql.execution.{ColumnarCollapseTransformStages,
CommandResultExec, FileSourceScanExec, GlutenFallbackReporter, RDDScanExec,
SparkPlan}
+import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.v2.V2CommandExec
import org.apache.spark.util.SparkPlanRules
class CHRuleApi extends RuleApi {
import CHRuleApi._
- override def injectRules(injector: RuleInjector): Unit = {
+ override def injectRules(injector: Injector): Unit = {
injectSpark(injector.spark)
injectLegacy(injector.gluten.legacy)
injectRas(injector.gluten.ras)
@@ -65,8 +65,7 @@ private object CHRuleApi {
}
def injectLegacy(injector: LegacyInjector): Unit = {
-
- // Gluten columnar: Transform rules.
+ // Legacy: Pre-transform rules.
injector.injectTransform(_ => RemoveTransitions)
injector.injectTransform(_ => PushDownInputFileExpression.PreOffload)
injector.injectTransform(c => FallbackOnANSIMode.apply(c.session))
@@ -74,9 +73,11 @@ private object CHRuleApi {
injector.injectTransform(_ => RewriteSubqueryBroadcast())
injector.injectTransform(c => FallbackBroadcastHashJoin.apply(c.session))
injector.injectTransform(c =>
MergeTwoPhasesHashBaseAggregate.apply(c.session))
- injector.injectTransform(_ => intercept(RewriteSparkPlanRulesManager()))
- injector.injectTransform(_ => intercept(AddFallbackTagRule()))
- injector.injectTransform(_ => intercept(TransformPreOverrides()))
+
+ // Legacy: The Legacy transform rule.
+ injector.injectTransform(_ => intercept(HeuristicTransform()))
+
+ // Legacy: Post-transform rules.
injector.injectTransform(_ => RemoveNativeWriteFilesSortAndProject())
injector.injectTransform(c =>
intercept(RewriteTransformer.apply(c.session)))
injector.injectTransform(_ => PushDownFilterToScan)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index 190fcb13ea..f08f21056d 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -23,8 +23,7 @@ import org.apache.gluten.execution._
import org.apache.gluten.expression._
import org.apache.gluten.expression.ExpressionNames.MONOTONICALLY_INCREASING_ID
import org.apache.gluten.extension.ExpressionExtensionTrait
-import org.apache.gluten.extension.columnar.AddFallbackTagRule
-import
org.apache.gluten.extension.columnar.MiscColumnarRules.TransformPreOverrides
+import org.apache.gluten.extension.columnar.heuristic.HeuristicTransform
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode, WindowFunctionNode}
import org.apache.gluten.utils.{CHJoinValidateUtil, UnknownJoinStrategy}
@@ -224,9 +223,9 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with
Logging {
}
// FIXME: The operation happens inside ReplaceSingleNode().
// Caller may not know it adds project on top of the shuffle.
- val project = TransformPreOverrides().apply(
- AddFallbackTagRule().apply(
- ProjectExec(plan.child.output ++ projectExpressions, plan.child)))
+ // FIXME: HeuristicTransform is costly. Re-applying it may cause
performance issues.
+ val project =
+ HeuristicTransform()(ProjectExec(plan.child.output ++
projectExpressions, plan.child))
var newExprs = Seq[Expression]()
for (i <- exprs.indices) {
val pos = newExpressionsPosition(i)
@@ -249,9 +248,9 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with
Logging {
}
// FIXME: The operation happens inside ReplaceSingleNode().
// Caller may not know it adds project on top of the shuffle.
- val project = TransformPreOverrides().apply(
- AddFallbackTagRule().apply(
- ProjectExec(plan.child.output ++ projectExpressions, plan.child)))
+ // FIXME: HeuristicTransform is costly. Re-applying it may cause
performance issues.
+ val project =
+ HeuristicTransform()(ProjectExec(plan.child.output ++
projectExpressions, plan.child))
var newOrderings = Seq[SortOrder]()
for (i <- orderings.indices) {
val oldOrdering = orderings(i)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcastHashJoinRules.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcastHashJoinRules.scala
index 207bb0e3a4..d1116c857a 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcastHashJoinRules.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcastHashJoinRules.scala
@@ -19,7 +19,6 @@ package org.apache.gluten.extension
import org.apache.gluten.GlutenConfig
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.extension.columnar._
-import
org.apache.gluten.extension.columnar.FallbackTags.EncodeFallbackTagImplicits
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
index 3554bc5c9c..30e92a18b2 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
@@ -20,11 +20,11 @@ import org.apache.gluten.backendsapi.RuleApi
import org.apache.gluten.datasource.ArrowConvertorRule
import org.apache.gluten.extension._
import org.apache.gluten.extension.columnar._
-import
org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow,
RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast, TransformPreOverrides}
+import
org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow,
RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast}
import org.apache.gluten.extension.columnar.enumerated.EnumeratedTransform
-import
org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager
+import org.apache.gluten.extension.columnar.heuristic.HeuristicTransform
import org.apache.gluten.extension.columnar.transition.{InsertTransitions,
RemoveTransitions}
-import org.apache.gluten.extension.injector.{RuleInjector, SparkInjector}
+import org.apache.gluten.extension.injector.{Injector, SparkInjector}
import org.apache.gluten.extension.injector.GlutenInjector.{LegacyInjector,
RasInjector}
import org.apache.gluten.sql.shims.SparkShimLoader
@@ -33,7 +33,7 @@ import
org.apache.spark.sql.execution.{ColumnarCollapseTransformStages, GlutenFa
class VeloxRuleApi extends RuleApi {
import VeloxRuleApi._
- override def injectRules(injector: RuleInjector): Unit = {
+ override def injectRules(injector: Injector): Unit = {
injectSpark(injector.spark)
injectLegacy(injector.gluten.legacy)
injectRas(injector.gluten.ras)
@@ -49,7 +49,7 @@ private object VeloxRuleApi {
}
def injectLegacy(injector: LegacyInjector): Unit = {
- // Gluten columnar: Transform rules.
+ // Legacy: Pre-transform rules.
injector.injectTransform(_ => RemoveTransitions)
injector.injectTransform(_ => PushDownInputFileExpression.PreOffload)
injector.injectTransform(c => FallbackOnANSIMode.apply(c.session))
@@ -57,9 +57,11 @@ private object VeloxRuleApi {
injector.injectTransform(_ => RewriteSubqueryBroadcast())
injector.injectTransform(c =>
BloomFilterMightContainJointRewriteRule.apply(c.session))
injector.injectTransform(c => ArrowScanReplaceRule.apply(c.session))
- injector.injectTransform(_ => RewriteSparkPlanRulesManager())
- injector.injectTransform(_ => AddFallbackTagRule())
- injector.injectTransform(_ => TransformPreOverrides())
+
+ // Legacy: The Legacy transform rule.
+ injector.injectTransform(_ => HeuristicTransform())
+
+ // Legacy: Post-transform rules.
injector.injectTransform(c => PartialProjectRule.apply(c.session))
injector.injectTransform(_ => RemoveNativeWriteFilesSortAndProject())
injector.injectTransform(c => RewriteTransformer.apply(c.session))
diff --git a/gluten-core/src/main/scala/org/apache/gluten/backend/Backend.scala
b/gluten-core/src/main/scala/org/apache/gluten/backend/Backend.scala
index ef22d97e77..346181e140 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/backend/Backend.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/backend/Backend.scala
@@ -17,7 +17,7 @@
package org.apache.gluten.backend
import org.apache.gluten.extension.columnar.transition.{Convention,
ConventionFunc}
-import org.apache.gluten.extension.injector.RuleInjector
+import org.apache.gluten.extension.injector.Injector
import org.apache.spark.SparkContext
import org.apache.spark.api.plugin.PluginContext
@@ -51,7 +51,7 @@ trait Backend {
def convFuncOverride(): ConventionFunc.Override =
ConventionFunc.Override.Empty
/** Query planner rules. */
- def injectRules(injector: RuleInjector): Unit
+ def injectRules(injector: Injector): Unit
}
object Backend {
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala
index d5afc6b7e7..addcad8dd0 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala
@@ -18,7 +18,7 @@ package org.apache.gluten.extension
import org.apache.gluten.GlutenConfig
import org.apache.gluten.backend.Backend
-import org.apache.gluten.extension.injector.RuleInjector
+import org.apache.gluten.extension.injector.Injector
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSessionExtensions
@@ -28,7 +28,7 @@ private[gluten] class GlutenSessionExtensions
with Logging {
import GlutenSessionExtensions._
override def apply(exts: SparkSessionExtensions): Unit = {
- val injector = new RuleInjector(exts)
+ val injector = new Injector(exts)
injector.control.disableOn {
session =>
val glutenEnabledGlobally = session.conf
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/Injector.scala
similarity index 89%
rename from
gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala
rename to
gluten-core/src/main/scala/org/apache/gluten/extension/injector/Injector.scala
index c497a24a07..81b5239a8a 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/Injector.scala
@@ -18,8 +18,8 @@ package org.apache.gluten.extension.injector
import org.apache.spark.sql.SparkSessionExtensions
-/** Injector used to inject query planner rules into Spark and Gluten. */
-class RuleInjector(extensions: SparkSessionExtensions) {
+/** Injector used to inject extensible components into Spark and Gluten. */
+class Injector(extensions: SparkSessionExtensions) {
val control = new InjectorControl()
val spark: SparkInjector = new SparkInjector(control, extensions)
val gluten: GlutenInjector = new GlutenInjector(control)
@@ -31,4 +31,4 @@ class RuleInjector(extensions: SparkSessionExtensions) {
}
}
-object RuleInjector {}
+object Injector {}
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala
index 7c4c8577f4..745bbcb72d 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala
@@ -16,9 +16,9 @@
*/
package org.apache.gluten.backendsapi
-import org.apache.gluten.extension.injector.RuleInjector
+import org.apache.gluten.extension.injector.Injector
trait RuleApi {
// Injects all Spark query planner rules used by the Gluten backend.
- def injectRules(injector: RuleInjector): Unit
+ def injectRules(injector: Injector): Unit
}
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SubstraitBackend.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SubstraitBackend.scala
index d7785663d5..37be117105 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SubstraitBackend.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SubstraitBackend.scala
@@ -17,7 +17,7 @@
package org.apache.gluten.backendsapi
import org.apache.gluten.backend.Backend
-import org.apache.gluten.extension.injector.RuleInjector
+import org.apache.gluten.extension.injector.Injector
import org.apache.spark.SparkContext
import org.apache.spark.api.plugin.PluginContext
@@ -35,7 +35,7 @@ trait SubstraitBackend extends Backend {
final override def onExecutorShutdown(): Unit = {
listenerApi().onExecutorShutdown()
}
- final override def injectRules(injector: RuleInjector): Unit = {
+ final override def injectRules(injector: Injector): Unit = {
ruleApi().injectRules(injector)
}
def iteratorApi(): IteratorApi
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala
index 06d798e50f..c252bacc9a 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala
@@ -21,6 +21,7 @@ import org.apache.gluten.backend.Backend
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.expression.TransformerState
+import org.apache.gluten.extension.columnar.FallbackTags.add
import org.apache.gluten.extension.columnar.transition.Convention
import org.apache.gluten.logging.LogLevelUtil
import org.apache.gluten.substrait.SubstraitContext
@@ -28,6 +29,7 @@ import org.apache.gluten.substrait.plan.PlanBuilder
import org.apache.gluten.substrait.rel.RelNode
import org.apache.gluten.test.TestStats
+import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.execution.SparkPlan
import com.google.common.collect.Lists
@@ -50,6 +52,15 @@ object ValidationResult {
def succeeded: ValidationResult = Succeeded
def failed(reason: String): ValidationResult = Failed(reason)
+
+ implicit class EncodeFallbackTagImplicits(result: ValidationResult) {
+ def tagOnFallback(plan: TreeNode[_]): Unit = {
+ if (result.ok()) {
+ return
+ }
+ add(plan, result)
+ }
+ }
}
/** Every Gluten Operator should extend this trait. */
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/EnsureLocalSortRequirements.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/EnsureLocalSortRequirements.scala
index ff989d796f..73dc617e08 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/EnsureLocalSortRequirements.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/EnsureLocalSortRequirements.scala
@@ -16,9 +16,7 @@
*/
package org.apache.gluten.extension.columnar
-import org.apache.gluten.GlutenConfig
-import
org.apache.gluten.extension.columnar.MiscColumnarRules.TransformPreOverrides
-import
org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager
+import org.apache.gluten.extension.columnar.heuristic.HeuristicTransform
import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.catalyst.rules.Rule
@@ -33,24 +31,14 @@ import org.apache.spark.sql.execution.{SortExec, SparkPlan}
* SortAggregate with the same key. So, this rule adds local sort back if
necessary.
*/
object EnsureLocalSortRequirements extends Rule[SparkPlan] {
- private lazy val offload = TransformPreOverrides.apply()
+ private lazy val transform = HeuristicTransform()
private def addLocalSort(
originalChild: SparkPlan,
requiredOrdering: Seq[SortOrder]): SparkPlan = {
+ // FIXME: HeuristicTransform is costly. Re-applying it may cause
performance issues.
val newChild = SortExec(requiredOrdering, global = false, child =
originalChild)
- if (!GlutenConfig.getConf.enableColumnarSort) {
- FallbackTags.add(newChild, "columnar Sort is not enabled in SortExec")
- newChild
- } else {
- val rewrittenPlan = RewriteSparkPlanRulesManager.apply().apply(newChild)
- if (rewrittenPlan.eq(newChild) && FallbackTags.nonEmpty(rewrittenPlan)) {
- // The sort can not be offloaded
- rewrittenPlan
- } else {
- offload.apply(rewrittenPlan)
- }
- }
+ transform.apply(newChild)
}
override def apply(plan: SparkPlan): SparkPlan = {
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala
index 1268d36d32..432ecd1584 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala
@@ -17,147 +17,13 @@
package org.apache.gluten.extension.columnar
import org.apache.gluten.GlutenConfig
-import org.apache.gluten.backendsapi.BackendsApiManager
-import org.apache.gluten.exception.GlutenNotSupportException
-import org.apache.gluten.execution._
-import org.apache.gluten.extension.{GlutenPlan, ValidationResult}
-import
org.apache.gluten.extension.columnar.FallbackTags.EncodeFallbackTagImplicits
-import org.apache.gluten.extension.columnar.validator.{Validator, Validators}
-import org.apache.gluten.sql.shims.SparkShimLoader
-import org.apache.spark.api.python.EvalPythonExecTransformer
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.{TreeNode, TreeNodeTag}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec,
QueryStageExec}
-import org.apache.spark.sql.execution.aggregate.{HashAggregateExec,
ObjectHashAggregateExec, SortAggregateExec}
-import org.apache.spark.sql.execution.datasources.WriteFilesExec
-import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.exchange._
import org.apache.spark.sql.execution.joins._
-import org.apache.spark.sql.execution.python.{ArrowEvalPythonExec,
BatchEvalPythonExec}
-import org.apache.spark.sql.execution.window.{WindowExec,
WindowGroupLimitExecShim}
-import org.apache.spark.sql.hive.HiveTableScanExecTransformer
-
-import org.apache.commons.lang3.exception.ExceptionUtils
-
-sealed trait FallbackTag {
- val stacktrace: Option[String] =
- if (FallbackTags.DEBUG) {
- Some(ExceptionUtils.getStackTrace(new Throwable()))
- } else None
-
- def reason(): String
-}
-
-object FallbackTag {
-
- /** A tag that stores one reason text of fall back. */
- case class Appendable(override val reason: String) extends FallbackTag
-
- /**
- * A tag that stores reason text of fall back. Other reasons will be
discarded when this tag is
- * added to plan.
- */
- case class Exclusive(override val reason: String) extends FallbackTag
-
- trait Converter[T] {
- def from(obj: T): Option[FallbackTag]
- }
-
- object Converter {
- implicit def asIs[T <: FallbackTag]: Converter[T] = (tag: T) => Some(tag)
-
- implicit object FromString extends Converter[String] {
- override def from(reason: String): Option[FallbackTag] =
Some(Appendable(reason))
- }
-
- implicit object FromValidationResult extends Converter[ValidationResult] {
- override def from(result: ValidationResult): Option[FallbackTag] = {
- if (result.ok()) {
- return None
- }
- Some(Appendable(result.reason()))
- }
- }
- }
-}
-
-object FallbackTags {
- val TAG: TreeNodeTag[FallbackTag] =
- TreeNodeTag[FallbackTag]("org.apache.gluten.FallbackTag")
-
- val DEBUG = false
-
- /**
- * If true, the plan node will be guaranteed fallback to Vanilla plan node
while being
- * implemented.
- *
- * If false, the plan still has chance to be turned into "non-transformable"
in any another
- * validation rule. So user should not consider the plan "transformable"
unless all validation
- * rules are passed.
- */
- def nonEmpty(plan: SparkPlan): Boolean = {
- getOption(plan).nonEmpty
- }
-
- /**
- * If true, it implies the plan maybe transformable during validation phase
but not guaranteed,
- * since another validation rule could turn it to "non-transformable" before
implementing the plan
- * within Gluten transformers. If false, the plan node will be guaranteed
fallback to Vanilla plan
- * node while being implemented.
- */
- def maybeOffloadable(plan: SparkPlan): Boolean = !nonEmpty(plan)
-
- def add[T](plan: TreeNode[_], t: T)(implicit converter:
FallbackTag.Converter[T]): Unit = {
- val tagOption = getOption(plan)
- val newTagOption = converter.from(t)
-
- val mergedTagOption: Option[FallbackTag] =
- (tagOption ++ newTagOption).reduceOption[FallbackTag] {
- // New tag comes while the plan was already tagged, merge.
- case (_, exclusive: FallbackTag.Exclusive) =>
- exclusive
- case (exclusive: FallbackTag.Exclusive, _) =>
- exclusive
- case (l: FallbackTag.Appendable, r: FallbackTag.Appendable) =>
- FallbackTag.Appendable(s"${l.reason}; ${r.reason}")
- }
- mergedTagOption
- .foreach(mergedTag => plan.setTagValue(TAG, mergedTag))
- }
-
- def addRecursively[T](plan: TreeNode[_], t: T)(implicit
- converter: FallbackTag.Converter[T]): Unit = {
- plan.foreach {
- case _: GlutenPlan => // ignore
- case other: TreeNode[_] => add(other, t)
- }
- }
-
- def untag(plan: TreeNode[_]): Unit = {
- plan.unsetTagValue(TAG)
- }
-
- def get(plan: TreeNode[_]): FallbackTag = {
- getOption(plan).getOrElse(
- throw new IllegalStateException("Transform hint tag not set in plan: " +
plan.toString()))
- }
-
- def getOption(plan: TreeNode[_]): Option[FallbackTag] = {
- plan.getTagValue(TAG)
- }
-
- implicit class EncodeFallbackTagImplicits(result: ValidationResult) {
- def tagOnFallback(plan: TreeNode[_]): Unit = {
- if (result.ok()) {
- return
- }
- add(plan, result)
- }
- }
-}
case class FallbackOnANSIMode(session: SparkSession) extends Rule[SparkPlan] {
override def apply(plan: SparkPlan): SparkPlan = {
@@ -237,271 +103,3 @@ case class FallbackMultiCodegens(session: SparkSession)
extends Rule[SparkPlan]
} else plan
}
}
-
-// This rule will try to convert a plan into plan transformer.
-// The doValidate function will be called to check if the conversion is
supported.
-// If false is returned or any unsupported exception is thrown, a row guard
will
-// be added on the top of that plan to prevent actual conversion.
-case class AddFallbackTagRule() extends Rule[SparkPlan] {
- import AddFallbackTagRule._
- private val glutenConf: GlutenConfig = GlutenConfig.getConf
- private val validator = Validators
- .builder()
- .fallbackByHint()
- .fallbackIfScanOnlyWithFilterPushed(glutenConf.enableScanOnly)
- .fallbackComplexExpressions()
- .fallbackByBackendSettings()
- .fallbackByUserOptions()
- .fallbackByTestInjects()
- .build()
-
- def apply(plan: SparkPlan): SparkPlan = {
- plan.foreachUp { case p => addFallbackTag(p) }
- plan
- }
-
- private def addFallbackTag(plan: SparkPlan): Unit = {
- val outcome = validator.validate(plan)
- outcome match {
- case Validator.Failed(reason) =>
- FallbackTags.add(plan, reason)
- return
- case Validator.Passed =>
- }
-
- try {
- plan match {
- case plan: BatchScanExec =>
- val transformer =
- ScanTransformerFactory
- .createBatchScanTransformer(plan, validation = true)
- .asInstanceOf[BasicScanExecTransformer]
- transformer.doValidate().tagOnFallback(plan)
- case plan: FileSourceScanExec =>
- val transformer =
- ScanTransformerFactory.createFileSourceScanTransformer(plan)
- transformer.doValidate().tagOnFallback(plan)
- case plan if HiveTableScanExecTransformer.isHiveTableScan(plan) =>
- HiveTableScanExecTransformer.validate(plan).tagOnFallback(plan)
- case plan: ProjectExec =>
- val transformer = ProjectExecTransformer(plan.projectList,
plan.child)
- transformer.doValidate().tagOnFallback(plan)
- case plan: FilterExec =>
- val transformer = BackendsApiManager.getSparkPlanExecApiInstance
- .genFilterExecTransformer(plan.condition, plan.child)
- transformer.doValidate().tagOnFallback(plan)
- case plan: HashAggregateExec =>
- val transformer = HashAggregateExecBaseTransformer.from(plan)
- transformer.doValidate().tagOnFallback(plan)
- case plan: SortAggregateExec =>
- val transformer = HashAggregateExecBaseTransformer.from(plan)
- transformer.doValidate().tagOnFallback(plan)
- case plan: ObjectHashAggregateExec =>
- val transformer = HashAggregateExecBaseTransformer.from(plan)
- transformer.doValidate().tagOnFallback(plan)
- case plan: UnionExec =>
- val transformer = ColumnarUnionExec(plan.children)
- transformer.doValidate().tagOnFallback(plan)
- case plan: ExpandExec =>
- val transformer = ExpandExecTransformer(plan.projections,
plan.output, plan.child)
- transformer.doValidate().tagOnFallback(plan)
- case plan: WriteFilesExec =>
- val transformer = WriteFilesExecTransformer(
- plan.child,
- plan.fileFormat,
- plan.partitionColumns,
- plan.bucketSpec,
- plan.options,
- plan.staticPartitions)
- transformer.doValidate().tagOnFallback(plan)
- case plan: SortExec =>
- val transformer =
- SortExecTransformer(plan.sortOrder, plan.global, plan.child,
plan.testSpillFrequency)
- transformer.doValidate().tagOnFallback(plan)
- case plan: ShuffleExchangeExec =>
- val transformer = ColumnarShuffleExchangeExec(plan, plan.child,
plan.child.output)
- transformer.doValidate().tagOnFallback(plan)
- case plan: ShuffledHashJoinExec =>
- val transformer = BackendsApiManager.getSparkPlanExecApiInstance
- .genShuffledHashJoinExecTransformer(
- plan.leftKeys,
- plan.rightKeys,
- plan.joinType,
- OffloadJoin.getShjBuildSide(plan),
- plan.condition,
- plan.left,
- plan.right,
- plan.isSkewJoin)
- transformer.doValidate().tagOnFallback(plan)
- case plan: BroadcastExchangeExec =>
- val transformer = ColumnarBroadcastExchangeExec(plan.mode,
plan.child)
- transformer.doValidate().tagOnFallback(plan)
- case bhj: BroadcastHashJoinExec =>
- val transformer = BackendsApiManager.getSparkPlanExecApiInstance
- .genBroadcastHashJoinExecTransformer(
- bhj.leftKeys,
- bhj.rightKeys,
- bhj.joinType,
- bhj.buildSide,
- bhj.condition,
- bhj.left,
- bhj.right,
- isNullAwareAntiJoin = bhj.isNullAwareAntiJoin)
- transformer.doValidate().tagOnFallback(plan)
- case plan: SortMergeJoinExec =>
- val transformer = BackendsApiManager.getSparkPlanExecApiInstance
- .genSortMergeJoinExecTransformer(
- plan.leftKeys,
- plan.rightKeys,
- plan.joinType,
- plan.condition,
- plan.left,
- plan.right,
- plan.isSkewJoin)
- transformer.doValidate().tagOnFallback(plan)
- case plan: CartesianProductExec =>
- val transformer = BackendsApiManager.getSparkPlanExecApiInstance
- .genCartesianProductExecTransformer(plan.left, plan.right,
plan.condition)
- transformer.doValidate().tagOnFallback(plan)
- case plan: BroadcastNestedLoopJoinExec =>
- val transformer = BackendsApiManager.getSparkPlanExecApiInstance
- .genBroadcastNestedLoopJoinExecTransformer(
- plan.left,
- plan.right,
- plan.buildSide,
- plan.joinType,
- plan.condition)
- transformer.doValidate().tagOnFallback(plan)
- case plan: WindowExec =>
- val transformer = WindowExecTransformer(
- plan.windowExpression,
- plan.partitionSpec,
- plan.orderSpec,
- plan.child)
- transformer.doValidate().tagOnFallback(plan)
- case plan if
SparkShimLoader.getSparkShims.isWindowGroupLimitExec(plan) =>
- val windowGroupLimitPlan = SparkShimLoader.getSparkShims
- .getWindowGroupLimitExecShim(plan)
- .asInstanceOf[WindowGroupLimitExecShim]
- val transformer = WindowGroupLimitExecTransformer(
- windowGroupLimitPlan.partitionSpec,
- windowGroupLimitPlan.orderSpec,
- windowGroupLimitPlan.rankLikeFunction,
- windowGroupLimitPlan.limit,
- windowGroupLimitPlan.mode,
- windowGroupLimitPlan.child
- )
- transformer.doValidate().tagOnFallback(plan)
- case plan: CoalesceExec =>
- ColumnarCoalesceExec(plan.numPartitions, plan.child)
- .doValidate()
- .tagOnFallback(plan)
- case plan: GlobalLimitExec =>
- val (limit, offset) =
-
SparkShimLoader.getSparkShims.getLimitAndOffsetFromGlobalLimit(plan)
- val transformer = LimitExecTransformer(plan.child, offset, limit)
- transformer.doValidate().tagOnFallback(plan)
- case plan: LocalLimitExec =>
- val transformer = LimitExecTransformer(plan.child, 0L, plan.limit)
- transformer.doValidate().tagOnFallback(plan)
- case plan: GenerateExec =>
- val transformer =
BackendsApiManager.getSparkPlanExecApiInstance.genGenerateTransformer(
- plan.generator,
- plan.requiredChildOutput,
- plan.outer,
- plan.generatorOutput,
- plan.child)
- transformer.doValidate().tagOnFallback(plan)
- case plan: BatchEvalPythonExec =>
- val transformer = EvalPythonExecTransformer(plan.udfs,
plan.resultAttrs, plan.child)
- transformer.doValidate().tagOnFallback(plan)
- case plan: ArrowEvalPythonExec =>
- // When backend doesn't support ColumnarArrow or colunmnar arrow
configuration not
- // enabled, we will try offloading through EvalPythonExecTransformer
- if (
- !BackendsApiManager.getSettings.supportColumnarArrowUdf() ||
- !GlutenConfig.getConf.enableColumnarArrowUDF
- ) {
- // Both CH and Velox will try using backend's built-in functions
for calculate
- val transformer = EvalPythonExecTransformer(plan.udfs,
plan.resultAttrs, plan.child)
- transformer.doValidate().tagOnFallback(plan)
- }
- case plan: TakeOrderedAndProjectExec =>
- val (limit, offset) =
- SparkShimLoader.getSparkShims.getLimitAndOffsetFromTopK(plan)
- val transformer = TakeOrderedAndProjectExecTransformer(
- limit,
- plan.sortOrder,
- plan.projectList,
- plan.child,
- offset)
- transformer.doValidate().tagOnFallback(plan)
- case plan: SampleExec =>
- val transformer =
-
BackendsApiManager.getSparkPlanExecApiInstance.genSampleExecTransformer(
- plan.lowerBound,
- plan.upperBound,
- plan.withReplacement,
- plan.seed,
- plan.child)
- transformer.doValidate().tagOnFallback(plan)
- case _ =>
- // Currently we assume a plan to be offload-able by default.
- }
- } catch {
- case e @ (_: GlutenNotSupportException | _:
UnsupportedOperationException) =>
- FallbackTags.add(
- plan,
- s"${e.getMessage}, original Spark plan is " +
- s"${plan.getClass}(${plan.children.toList.map(_.getClass)})")
- if (!e.isInstanceOf[GlutenNotSupportException]) {
- logDebug("Just a warning. This exception perhaps needs to be
fixed.", e)
- }
- }
- }
-}
-
-object AddFallbackTagRule {
- implicit private class ValidatorBuilderImplicits(builder:
Validators.Builder) {
-
- /**
- * Fails validation on non-scan plan nodes if Gluten is running as
scan-only mode. Also, passes
- * validation on filter for the exception that filter + scan is detected.
Because filters can be
- * pushed into scan then the filter conditions will be processed only in
scan.
- */
- def fallbackIfScanOnlyWithFilterPushed(scanOnly: Boolean):
Validators.Builder = {
- builder.add(new FallbackIfScanOnlyWithFilterPushed(scanOnly))
- builder
- }
- }
-
- private class FallbackIfScanOnlyWithFilterPushed(scanOnly: Boolean) extends
Validator {
- override def validate(plan: SparkPlan): Validator.OutCome = {
- if (!scanOnly) {
- return pass()
- }
- // Scan-only mode
- plan match {
- case _: BatchScanExec => pass()
- case _: FileSourceScanExec => pass()
- case p if HiveTableScanExecTransformer.isHiveTableScan(p) => pass()
- case filter: FilterExec =>
- val childIsScan = filter.child.isInstanceOf[FileSourceScanExec] ||
- filter.child.isInstanceOf[BatchScanExec]
- if (childIsScan) {
- pass()
- } else {
- fail(filter)
- }
- case other => fail(other)
- }
- }
- }
-}
-
-case class RemoveFallbackTagRule() extends Rule[SparkPlan] {
- override def apply(plan: SparkPlan): SparkPlan = {
- plan.foreach(FallbackTags.untag)
- plan
- }
-}
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackTag.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackTag.scala
new file mode 100644
index 0000000000..c2d844c19b
--- /dev/null
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackTag.scala
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension.columnar
+
+import org.apache.gluten.extension.{GlutenPlan, ValidationResult}
+
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.{TreeNode, TreeNodeTag}
+import org.apache.spark.sql.execution.SparkPlan
+
+import org.apache.commons.lang3.exception.ExceptionUtils
+
+sealed trait FallbackTag {
+ val stacktrace: Option[String] =
+ if (FallbackTags.DEBUG) {
+ Some(ExceptionUtils.getStackTrace(new Throwable()))
+ } else None
+
+ def reason(): String
+}
+
+object FallbackTag {
+
+ /** A tag that stores one reason text of fall back. */
+ case class Appendable(override val reason: String) extends FallbackTag
+
+ /**
+ * A tag that stores reason text of fall back. Other reasons will be
discarded when this tag is
+ * added to plan.
+ */
+ case class Exclusive(override val reason: String) extends FallbackTag
+
+ trait Converter[T] {
+ def from(obj: T): Option[FallbackTag]
+ }
+
+ object Converter {
+ implicit def asIs[T <: FallbackTag]: Converter[T] = (tag: T) => Some(tag)
+
+ implicit object FromString extends Converter[String] {
+ override def from(reason: String): Option[FallbackTag] =
Some(Appendable(reason))
+ }
+
+ implicit object FromValidationResult extends Converter[ValidationResult] {
+ override def from(result: ValidationResult): Option[FallbackTag] = {
+ if (result.ok()) {
+ return None
+ }
+ Some(Appendable(result.reason()))
+ }
+ }
+ }
+}
+
+object FallbackTags {
+ val TAG: TreeNodeTag[FallbackTag] =
+ TreeNodeTag[FallbackTag]("org.apache.gluten.FallbackTag")
+
+ val DEBUG = false
+
+ /**
+ * If true, the plan node is guaranteed to fall back to the Vanilla plan node
while being
+ * implemented.
+ *
+ * If false, the plan still has a chance to be turned into "non-transformable"
by any other
+ * validation rule. So the user should not consider the plan "transformable"
unless all validation
+ * rules have passed.
+ */
+ def nonEmpty(plan: SparkPlan): Boolean = {
+ getOption(plan).nonEmpty
+ }
+
+ /**
+ * If true, it implies the plan may be transformable during the validation phase
but this is not guaranteed,
+ * since another validation rule could turn it to "non-transformable" before
implementing the plan
+ * within Gluten transformers. If false, the plan node is guaranteed to fall
back to the Vanilla plan
+ * node while being implemented.
+ */
+ def maybeOffloadable(plan: SparkPlan): Boolean = !nonEmpty(plan)
+
+ def add[T](plan: TreeNode[_], t: T)(implicit converter:
FallbackTag.Converter[T]): Unit = {
+ val tagOption = getOption(plan)
+ val newTagOption = converter.from(t)
+
+ val mergedTagOption: Option[FallbackTag] =
+ (tagOption ++ newTagOption).reduceOption[FallbackTag] {
+ // New tag comes while the plan was already tagged, merge.
+ case (_, exclusive: FallbackTag.Exclusive) =>
+ exclusive
+ case (exclusive: FallbackTag.Exclusive, _) =>
+ exclusive
+ case (l: FallbackTag.Appendable, r: FallbackTag.Appendable) =>
+ FallbackTag.Appendable(s"${l.reason}; ${r.reason}")
+ }
+ mergedTagOption
+ .foreach(mergedTag => plan.setTagValue(TAG, mergedTag))
+ }
+
+ def addRecursively[T](plan: TreeNode[_], t: T)(implicit
+ converter: FallbackTag.Converter[T]): Unit = {
+ plan.foreach {
+ case _: GlutenPlan => // ignore
+ case other: TreeNode[_] => add(other, t)
+ }
+ }
+
+ def untag(plan: TreeNode[_]): Unit = {
+ plan.unsetTagValue(TAG)
+ }
+
+ def get(plan: TreeNode[_]): FallbackTag = {
+ getOption(plan).getOrElse(
+ throw new IllegalStateException("Transform hint tag not set in plan: " +
plan.toString()))
+ }
+
+ def getOption(plan: TreeNode[_]): Option[FallbackTag] = {
+ plan.getTagValue(TAG)
+ }
+}
+
+case class RemoveFallbackTagRule() extends Rule[SparkPlan] {
+ override def apply(plan: SparkPlan): SparkPlan = {
+ plan.foreach(FallbackTags.untag)
+ plan
+ }
+}
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala
index 11b4b86508..a199b5920c 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala
@@ -17,7 +17,6 @@
package org.apache.gluten.extension.columnar
import org.apache.gluten.extension.columnar.transition.{ColumnarToRowLike,
Transitions}
-import org.apache.gluten.logging.LogLevelUtil
import org.apache.gluten.utils.PlanUtil
import org.apache.spark.sql.SparkSession
@@ -32,34 +31,6 @@ import
org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec
import org.apache.spark.sql.internal.SQLConf
object MiscColumnarRules {
- object TransformPreOverrides {
- def apply(): TransformPreOverrides = {
- TransformPreOverrides(
- List(),
- List(
- OffloadOthers(),
- OffloadExchange(),
- OffloadJoin()
- )
- )
- }
- }
-
- // This rule will conduct the conversion from Spark plan to the plan
transformer.
- case class TransformPreOverrides(
- topDownRules: Seq[OffloadSingleNode],
- bottomUpRules: Seq[OffloadSingleNode])
- extends Rule[SparkPlan]
- with LogLevelUtil {
-
- def apply(plan: SparkPlan): SparkPlan = {
- val plan0 =
- topDownRules.foldLeft(plan)((p, rule) => p.transformDown { case p =>
rule.offload(p) })
- val plan1 =
- bottomUpRules.foldLeft(plan0)((p, rule) => p.transformUp { case p =>
rule.offload(p) })
- plan1
- }
- }
// Replaces all SubqueryBroadcastExec used by sub-queries with
ColumnarSubqueryBroadcastExec.
// This prevents query execution from being failed by fallen-back
SubqueryBroadcastExec with
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
index e6d1c4859e..5abc64ec37 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
@@ -16,7 +16,7 @@
*/
package org.apache.gluten.extension.columnar.enumerated
-import org.apache.gluten.extension.columnar.{OffloadExchange, OffloadJoin,
OffloadOthers}
+import org.apache.gluten.extension.columnar.heuristic.{OffloadExchange,
OffloadJoin, OffloadOthers}
import org.apache.gluten.extension.columnar.transition.ConventionReq
import org.apache.gluten.extension.columnar.validator.{Validator, Validators}
import org.apache.gluten.logging.LogLevelUtil
@@ -38,6 +38,18 @@ import org.apache.spark.sql.execution.python.EvalPythonExec
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.hive.HiveTableScanExecTransformer
+/**
+ * Rule to offload Spark query plan to Gluten query plan using a search
algorithm and a defined cost
+ * model.
+ *
+ * The effect of this rule is similar to
+ * [[org.apache.gluten.extension.columnar.heuristic.HeuristicTransform]],
except that the 3 stages
+ * in the heuristic version, known as rewrite, validate, offload, will take
place together
+ * individually for each Spark query plan node in RAS rule
+ * [[org.apache.gluten.extension.columnar.enumerated.RasOffload]].
+ *
+ * The feature requires enabling RAS to function.
+ */
case class EnumeratedTransform(session: SparkSession, outputsColumnar: Boolean)
extends Rule[SparkPlan]
with LogLevelUtil {
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala
index 52798f712c..1e97ebfbab 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala
@@ -17,7 +17,7 @@
package org.apache.gluten.extension.columnar.enumerated
import org.apache.gluten.extension.GlutenPlan
-import org.apache.gluten.extension.columnar.OffloadSingleNode
+import org.apache.gluten.extension.columnar.heuristic.OffloadSingleNode
import org.apache.gluten.extension.columnar.rewrite.RewriteSingleNode
import org.apache.gluten.extension.columnar.validator.Validator
import org.apache.gluten.ras.path.Pattern
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/heuristic/AddFallbackTags.scala
similarity index 63%
copy from
gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala
copy to
gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/heuristic/AddFallbackTags.scala
index 1268d36d32..538132c81a 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/heuristic/AddFallbackTags.scala
@@ -14,236 +14,34 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.gluten.extension.columnar
+package org.apache.gluten.extension.columnar.heuristic
import org.apache.gluten.GlutenConfig
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.exception.GlutenNotSupportException
-import org.apache.gluten.execution._
-import org.apache.gluten.extension.{GlutenPlan, ValidationResult}
-import
org.apache.gluten.extension.columnar.FallbackTags.EncodeFallbackTagImplicits
+import org.apache.gluten.execution.{BasicScanExecTransformer,
ColumnarCoalesceExec, ColumnarUnionExec, ExpandExecTransformer,
HashAggregateExecBaseTransformer, LimitExecTransformer, ProjectExecTransformer,
ScanTransformerFactory, SortExecTransformer,
TakeOrderedAndProjectExecTransformer, WindowExecTransformer,
WindowGroupLimitExecTransformer, WriteFilesExecTransformer}
+import org.apache.gluten.extension.columnar.FallbackTags
import org.apache.gluten.extension.columnar.validator.{Validator, Validators}
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.spark.api.python.EvalPythonExecTransformer
-import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.{TreeNode, TreeNodeTag}
-import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec,
QueryStageExec}
+import org.apache.spark.sql.execution.{CoalesceExec,
ColumnarBroadcastExchangeExec, ColumnarShuffleExchangeExec, ExpandExec,
FileSourceScanExec, FilterExec, GenerateExec, GlobalLimitExec, LocalLimitExec,
ProjectExec, SampleExec, SortExec, SparkPlan, TakeOrderedAndProjectExec,
UnionExec}
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec,
ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.datasources.WriteFilesExec
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
-import org.apache.spark.sql.execution.exchange._
-import org.apache.spark.sql.execution.joins._
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec,
ShuffleExchangeExec}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec,
BroadcastNestedLoopJoinExec, CartesianProductExec, ShuffledHashJoinExec,
SortMergeJoinExec}
import org.apache.spark.sql.execution.python.{ArrowEvalPythonExec,
BatchEvalPythonExec}
import org.apache.spark.sql.execution.window.{WindowExec,
WindowGroupLimitExecShim}
import org.apache.spark.sql.hive.HiveTableScanExecTransformer
-import org.apache.commons.lang3.exception.ExceptionUtils
-
-sealed trait FallbackTag {
- val stacktrace: Option[String] =
- if (FallbackTags.DEBUG) {
- Some(ExceptionUtils.getStackTrace(new Throwable()))
- } else None
-
- def reason(): String
-}
-
-object FallbackTag {
-
- /** A tag that stores one reason text of fall back. */
- case class Appendable(override val reason: String) extends FallbackTag
-
- /**
- * A tag that stores reason text of fall back. Other reasons will be
discarded when this tag is
- * added to plan.
- */
- case class Exclusive(override val reason: String) extends FallbackTag
-
- trait Converter[T] {
- def from(obj: T): Option[FallbackTag]
- }
-
- object Converter {
- implicit def asIs[T <: FallbackTag]: Converter[T] = (tag: T) => Some(tag)
-
- implicit object FromString extends Converter[String] {
- override def from(reason: String): Option[FallbackTag] =
Some(Appendable(reason))
- }
-
- implicit object FromValidationResult extends Converter[ValidationResult] {
- override def from(result: ValidationResult): Option[FallbackTag] = {
- if (result.ok()) {
- return None
- }
- Some(Appendable(result.reason()))
- }
- }
- }
-}
-
-object FallbackTags {
- val TAG: TreeNodeTag[FallbackTag] =
- TreeNodeTag[FallbackTag]("org.apache.gluten.FallbackTag")
-
- val DEBUG = false
-
- /**
- * If true, the plan node will be guaranteed fallback to Vanilla plan node
while being
- * implemented.
- *
- * If false, the plan still has chance to be turned into "non-transformable"
in any another
- * validation rule. So user should not consider the plan "transformable"
unless all validation
- * rules are passed.
- */
- def nonEmpty(plan: SparkPlan): Boolean = {
- getOption(plan).nonEmpty
- }
-
- /**
- * If true, it implies the plan maybe transformable during validation phase
but not guaranteed,
- * since another validation rule could turn it to "non-transformable" before
implementing the plan
- * within Gluten transformers. If false, the plan node will be guaranteed
fallback to Vanilla plan
- * node while being implemented.
- */
- def maybeOffloadable(plan: SparkPlan): Boolean = !nonEmpty(plan)
-
- def add[T](plan: TreeNode[_], t: T)(implicit converter:
FallbackTag.Converter[T]): Unit = {
- val tagOption = getOption(plan)
- val newTagOption = converter.from(t)
-
- val mergedTagOption: Option[FallbackTag] =
- (tagOption ++ newTagOption).reduceOption[FallbackTag] {
- // New tag comes while the plan was already tagged, merge.
- case (_, exclusive: FallbackTag.Exclusive) =>
- exclusive
- case (exclusive: FallbackTag.Exclusive, _) =>
- exclusive
- case (l: FallbackTag.Appendable, r: FallbackTag.Appendable) =>
- FallbackTag.Appendable(s"${l.reason}; ${r.reason}")
- }
- mergedTagOption
- .foreach(mergedTag => plan.setTagValue(TAG, mergedTag))
- }
-
- def addRecursively[T](plan: TreeNode[_], t: T)(implicit
- converter: FallbackTag.Converter[T]): Unit = {
- plan.foreach {
- case _: GlutenPlan => // ignore
- case other: TreeNode[_] => add(other, t)
- }
- }
-
- def untag(plan: TreeNode[_]): Unit = {
- plan.unsetTagValue(TAG)
- }
-
- def get(plan: TreeNode[_]): FallbackTag = {
- getOption(plan).getOrElse(
- throw new IllegalStateException("Transform hint tag not set in plan: " +
plan.toString()))
- }
-
- def getOption(plan: TreeNode[_]): Option[FallbackTag] = {
- plan.getTagValue(TAG)
- }
-
- implicit class EncodeFallbackTagImplicits(result: ValidationResult) {
- def tagOnFallback(plan: TreeNode[_]): Unit = {
- if (result.ok()) {
- return
- }
- add(plan, result)
- }
- }
-}
-
-case class FallbackOnANSIMode(session: SparkSession) extends Rule[SparkPlan] {
- override def apply(plan: SparkPlan): SparkPlan = {
- if (GlutenConfig.getConf.enableAnsiMode) {
- plan.foreach(FallbackTags.add(_, "does not support ansi mode"))
- }
- plan
- }
-}
-
-case class FallbackMultiCodegens(session: SparkSession) extends
Rule[SparkPlan] {
- lazy val glutenConf: GlutenConfig = GlutenConfig.getConf
- lazy val physicalJoinOptimize = glutenConf.enablePhysicalJoinOptimize
- lazy val optimizeLevel: Integer = glutenConf.physicalJoinOptimizationThrottle
-
- def existsMultiCodegens(plan: SparkPlan, count: Int = 0): Boolean =
- plan match {
- case plan: CodegenSupport if plan.supportCodegen =>
- if ((count + 1) >= optimizeLevel) return true
- plan.children.exists(existsMultiCodegens(_, count + 1))
- case plan: ShuffledHashJoinExec =>
- if ((count + 1) >= optimizeLevel) return true
- plan.children.exists(existsMultiCodegens(_, count + 1))
- case plan: SortMergeJoinExec if
GlutenConfig.getConf.forceShuffledHashJoin =>
- if ((count + 1) >= optimizeLevel) return true
- plan.children.exists(existsMultiCodegens(_, count + 1))
- case _ => false
- }
-
- def addFallbackTag(plan: SparkPlan): SparkPlan = {
- FallbackTags.add(plan, "fallback multi codegens")
- plan
- }
-
- def supportCodegen(plan: SparkPlan): Boolean = plan match {
- case plan: CodegenSupport =>
- plan.supportCodegen
- case _ => false
- }
-
- def isAQEShuffleReadExec(plan: SparkPlan): Boolean = {
- plan match {
- case _: AQEShuffleReadExec => true
- case _ => false
- }
- }
-
- def addFallbackTagRecursive(plan: SparkPlan): SparkPlan = {
- plan match {
- case p: ShuffleExchangeExec =>
-
addFallbackTag(p.withNewChildren(p.children.map(tagOnFallbackForMultiCodegens)))
- case p: BroadcastExchangeExec =>
-
addFallbackTag(p.withNewChildren(p.children.map(tagOnFallbackForMultiCodegens)))
- case p: ShuffledHashJoinExec =>
-
addFallbackTag(p.withNewChildren(p.children.map(addFallbackTagRecursive)))
- case p if !supportCodegen(p) =>
- p.withNewChildren(p.children.map(tagOnFallbackForMultiCodegens))
- case p if isAQEShuffleReadExec(p) =>
- p.withNewChildren(p.children.map(tagOnFallbackForMultiCodegens))
- case p: QueryStageExec => p
- case p =>
addFallbackTag(p.withNewChildren(p.children.map(addFallbackTagRecursive)))
- }
- }
-
- def tagOnFallbackForMultiCodegens(plan: SparkPlan): SparkPlan = {
- plan match {
- case plan if existsMultiCodegens(plan) =>
- addFallbackTagRecursive(plan)
- case other =>
-
other.withNewChildren(other.children.map(tagOnFallbackForMultiCodegens))
- }
- }
-
- override def apply(plan: SparkPlan): SparkPlan = {
- if (physicalJoinOptimize) {
- tagOnFallbackForMultiCodegens(plan)
- } else plan
- }
-}
-
// This rule will try to convert a plan into plan transformer.
// The doValidate function will be called to check if the conversion is
supported.
// If false is returned or any unsupported exception is thrown, a row guard
will
// be added on the top of that plan to prevent actual conversion.
-case class AddFallbackTagRule() extends Rule[SparkPlan] {
- import AddFallbackTagRule._
+case class AddFallbackTags() extends Rule[SparkPlan] {
+ import AddFallbackTags._
private val glutenConf: GlutenConfig = GlutenConfig.getConf
private val validator = Validators
.builder()
@@ -461,7 +259,7 @@ case class AddFallbackTagRule() extends Rule[SparkPlan] {
}
}
-object AddFallbackTagRule {
+object AddFallbackTags {
implicit private class ValidatorBuilderImplicits(builder:
Validators.Builder) {
/**
@@ -498,10 +296,3 @@ object AddFallbackTagRule {
}
}
}
-
-case class RemoveFallbackTagRule() extends Rule[SparkPlan] {
- override def apply(plan: SparkPlan): SparkPlan = {
- plan.foreach(FallbackTags.untag)
- plan
- }
-}
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicTransform.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicTransform.scala
new file mode 100644
index 0000000000..8a026326e5
--- /dev/null
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicTransform.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension.columnar.heuristic
+
+import
org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager
+import org.apache.gluten.logging.LogLevelUtil
+
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.SparkPlan
+
+/**
+ * Rule to offload Spark query plan to Gluten query plan using programmed
heuristics.
+ *
+ * The procedure consists of 3 stages:
+ *
+ * 1. Rewrite 2. Validate 3. Offload
+ *
+ * In the rewrite stage, planner will try converting the Spark query plan to
various forms of
+ * possible alternative Spark query plans, then choose the optimal one to send
to next stage. During
+ * which, the same validation code that is about to be used in stage 2 might
be invoked early to
+ * predict on the estimate "cost" of an alternative Spark query plan.
+ *
+ * Once the plan is rewritten, query planner will call native validation code
in stage 2 to
+ * determine which part of the plan is offload-able or not, then add fallback
tags to the
+ * non-offload-able parts.
+ *
+ * In stage 3, query planner will convert the offload-able Spark plan nodes
into Gluten plan nodes.
+ */
+case class HeuristicTransform() extends Rule[SparkPlan] with LogLevelUtil {
+
+ private val validate = AddFallbackTags()
+ private val rewrite = RewriteSparkPlanRulesManager(validate)
+ private val offload = LegacyOffload()
+
+ override def apply(plan: SparkPlan): SparkPlan = {
+ Seq(rewrite, validate, offload).foldLeft(plan) {
+ case (plan, stage) =>
+ stage(plan)
+ }
+ }
+}
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/heuristic/LegacyOffload.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/heuristic/LegacyOffload.scala
new file mode 100644
index 0000000000..9cf08ff303
--- /dev/null
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/heuristic/LegacyOffload.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension.columnar.heuristic
+
+import org.apache.gluten.logging.LogLevelUtil
+
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.SparkPlan
+
+class LegacyOffload(topDownRules: Seq[OffloadSingleNode], bottomUpRules:
Seq[OffloadSingleNode])
+ extends Rule[SparkPlan]
+ with LogLevelUtil {
+
+ def apply(plan: SparkPlan): SparkPlan = {
+ val plan0 =
+ topDownRules.foldLeft(plan)((p, rule) => p.transformDown { case p =>
rule.offload(p) })
+ val plan1 =
+ bottomUpRules.foldLeft(plan0)((p, rule) => p.transformUp { case p =>
rule.offload(p) })
+ plan1
+ }
+}
+
+object LegacyOffload {
+ def apply(): LegacyOffload = {
+ new LegacyOffload(
+ List(),
+ List(
+ OffloadOthers(),
+ OffloadExchange(),
+ OffloadJoin()
+ )
+ )
+ }
+}
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/heuristic/OffloadSingleNode.scala
similarity index 99%
rename from
gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala
rename to
gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/heuristic/OffloadSingleNode.scala
index 220a45f7b3..a8c200e9be 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/heuristic/OffloadSingleNode.scala
@@ -14,13 +14,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.gluten.extension.columnar
+package org.apache.gluten.extension.columnar.heuristic
import org.apache.gluten.GlutenConfig
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.execution._
import org.apache.gluten.extension.GlutenPlan
+import org.apache.gluten.extension.columnar.FallbackTags
import org.apache.gluten.logging.LogLevelUtil
import org.apache.gluten.sql.shims.SparkShimLoader
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteJoin.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteJoin.scala
index d0cac0b29d..5fcf374397 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteJoin.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteJoin.scala
@@ -17,7 +17,7 @@
package org.apache.gluten.extension.columnar.rewrite
import org.apache.gluten.GlutenConfig
-import org.apache.gluten.extension.columnar.OffloadJoin
+import org.apache.gluten.extension.columnar.heuristic.OffloadJoin
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight,
BuildSide, JoinSelectionHelper}
import org.apache.spark.sql.catalyst.plans.logical.Join
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala
index c6d5e1cca1..3b8ca0df30 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala
@@ -16,7 +16,7 @@
*/
package org.apache.gluten.extension.columnar.rewrite
-import org.apache.gluten.extension.columnar.{AddFallbackTagRule, FallbackTag,
FallbackTags}
+import org.apache.gluten.extension.columnar.{FallbackTag, FallbackTags}
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.spark.rdd.RDD
@@ -45,7 +45,9 @@ case class RewrittenNodeWall(originalChild: SparkPlan)
extends LeafExecNode {
*
* Note that, this rule does not touch and tag these operators who does not
need to rewrite.
*/
-class RewriteSparkPlanRulesManager private (rewriteRules:
Seq[RewriteSingleNode])
+class RewriteSparkPlanRulesManager private (
+ validateRule: Rule[SparkPlan],
+ rewriteRules: Seq[RewriteSingleNode])
extends Rule[SparkPlan] {
private def mayNeedRewrite(plan: SparkPlan): Boolean = {
@@ -96,7 +98,6 @@ class RewriteSparkPlanRulesManager private (rewriteRules:
Seq[RewriteSingleNode]
}
override def apply(plan: SparkPlan): SparkPlan = {
- val addHint = AddFallbackTagRule()
plan.transformUp {
case origin if mayNeedRewrite(origin) =>
// Add a wall to avoid transforming unnecessary nodes.
@@ -114,7 +115,7 @@ class RewriteSparkPlanRulesManager private (rewriteRules: Seq[RewriteSingleNode]
          // We do not add tag and leave it to the outside `AddFallbackTagRule`.
origin
} else {
- addHint.apply(rewrittenPlan)
+ validateRule.apply(rewrittenPlan)
val tag = getFallbackTagBack(rewrittenPlan)
if (tag.isDefined) {
// If the rewritten plan is still not transformable, return the
original plan.
@@ -136,7 +137,7 @@ class RewriteSparkPlanRulesManager private (rewriteRules:
Seq[RewriteSingleNode]
}
object RewriteSparkPlanRulesManager {
-  def apply(): Rule[SparkPlan] = {
-    new RewriteSparkPlanRulesManager(RewriteSingleNode.allRules())
+  def apply(validateRule: Rule[SparkPlan]): Rule[SparkPlan] = {
+    new RewriteSparkPlanRulesManager(validateRule, RewriteSingleNode.allRules())
   }
}
diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenFormatWriterInjectsBase.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenFormatWriterInjectsBase.scala
index 9ec75aa209..11ac62ef58 100644
--- a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenFormatWriterInjectsBase.scala
+++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenFormatWriterInjectsBase.scala
@@ -18,9 +18,7 @@ package org.apache.spark.sql.execution.datasources
 import org.apache.gluten.execution.{ProjectExecTransformer, SortExecTransformer, TransformSupport, WholeStageTransformer}
import org.apache.gluten.execution.datasource.GlutenFormatWriterInjects
-import org.apache.gluten.extension.columnar.AddFallbackTagRule
-import org.apache.gluten.extension.columnar.MiscColumnarRules.TransformPreOverrides
-import org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager
+import org.apache.gluten.extension.columnar.heuristic.HeuristicTransform
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -42,12 +40,10 @@ trait GlutenFormatWriterInjectsBase extends GlutenFormatWriterInjects {
return plan.execute()
}
- val rules = List(
- RewriteSparkPlanRulesManager(),
- AddFallbackTagRule(),
- TransformPreOverrides()
- )
-    val transformed = rules.foldLeft(plan) { case (latestPlan, rule) => rule.apply(latestPlan) }
+    // FIXME: HeuristicTransform is costly. Re-applying it may cause performance issues.
+ val transform = HeuristicTransform()
+ val transformed = transform(plan)
+
if (!transformed.isInstanceOf[TransformSupport]) {
throw new IllegalStateException(
"Cannot transform the SparkPlans wrapped by FileFormatWriter, " +
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]