This is an automated email from the ASF dual-hosted git repository.
lgbo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 3eb30a5e16 [GLUTEN-8168] Add pre-projections for join condition (#8185)
3eb30a5e16 is described below
commit 3eb30a5e16d79bc84e18f2cffb691dd5662ce696
Author: lgbo <[email protected]>
AuthorDate: Thu Dec 12 14:13:40 2024 +0800
[GLUTEN-8168] Add pre-projections for join condition (#8185)
* refactor
* add new rule
* fixed
---
.../gluten/backendsapi/clickhouse/CHBackend.scala | 7 +
.../gluten/backendsapi/clickhouse/CHRuleApi.scala | 1 +
.../extension/AddPreProjectionForHashJoin.scala | 202 +++++++++++++++++++++
.../execution/GlutenClickHouseJoinSuite.scala | 36 ++++
.../GlutenClickHouseTPCHBucketSuite.scala | 2 +-
.../metrics/GlutenClickHouseTPCHMetricsSuite.scala | 1 +
.../GlutenClickHouseTPCHParquetBucketSuite.scala | 1 +
.../Parser/RelParsers/JoinRelParser.cpp | 127 +++++--------
8 files changed, 293 insertions(+), 84 deletions(-)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
index 0e79e6ca93..1aaf801860 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
@@ -350,6 +350,13 @@ object CHBackendSettings extends BackendSettingsApi with
Logging {
)
}
+ def enablePreProjectionForJoinConditions(): Boolean = {
+ SparkEnv.get.conf.getBoolean(
+ CHConf.runtimeConfig("enable_pre_projection_for_join_conditions"),
+ defaultValue = true
+ )
+ }
+
// If the partition keys are high cardinality, the aggregation method is
slower.
def enableConvertWindowGroupLimitToAggregate(): Boolean = {
SparkEnv.get.conf.getBoolean(
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
index fc5d1df918..40e5353618 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
@@ -111,6 +111,7 @@ object CHRuleApi {
c.session)))
injector.injectPostTransform(c =>
InsertTransitions.create(c.outputsColumnar, CHBatch))
injector.injectPostTransform(c => RemoveDuplicatedColumns.apply(c.session))
+ injector.injectPostTransform(c =>
AddPreProjectionForHashJoin.apply(c.session))
// Gluten columnar: Fallback policies.
injector.injectFallbackPolicy(
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/AddPreProjectionForHashJoin.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/AddPreProjectionForHashJoin.scala
new file mode 100644
index 0000000000..a0475fe39d
--- /dev/null
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/AddPreProjectionForHashJoin.scala
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension
+
+import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings
+import org.apache.gluten.execution._
+import org.apache.gluten.utils._
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.NamedExpression
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution._
+
+import scala.collection.mutable
+
+/*
+ * It will add a projection before hash join for
+ * 1. The join keys which are not attribute references.
+ * 2. If the join could be rewritten into multiple join conditions, also replace
all the join keys with
+ * attribute references.
+ *
+ * JoinUtils.createPreProjectionIfNeeded already adds a pre-projection for join
keys, but it doesn't
+ * handle the post filter.
+ */
+case class AddPreProjectionForHashJoin(session: SparkSession)
+ extends Rule[SparkPlan]
+ with PullOutProjectHelper
+ with Logging {
+ override def apply(plan: SparkPlan): SparkPlan = {
+ if (!CHBackendSettings.enablePreProjectionForJoinConditions) {
+ return plan
+ }
+
+ plan.transformUp {
+ case hashJoin: CHShuffledHashJoinExecTransformer =>
+ val leftReplacedExpressions = new mutable.HashMap[Expression,
NamedExpression]
+ val rightReplacedExpressions = new mutable.HashMap[Expression,
NamedExpression]
+
+ val newLeftKeys = hashJoin.leftKeys.map {
+ case e => replaceExpressionWithAttribute(e, leftReplacedExpressions,
false, false)
+ }
+
+ val newRightKeys = hashJoin.rightKeys.map {
+ case e => replaceExpressionWithAttribute(e,
rightReplacedExpressions, false, false)
+ }
+
+ val newCondition = replaceExpressionInCondition(
+ hashJoin.condition,
+ hashJoin.left,
+ leftReplacedExpressions,
+ hashJoin.right,
+ rightReplacedExpressions)
+ val leftProjectExprs =
+ eliminateProjectList(hashJoin.left.outputSet,
leftReplacedExpressions.values.toSeq)
+ val rightProjectExprs =
+ eliminateProjectList(hashJoin.right.outputSet,
rightReplacedExpressions.values.toSeq)
+ val newHashJoin = hashJoin.copy(
+ leftKeys = newLeftKeys,
+ rightKeys = newRightKeys,
+ condition = newCondition,
+ left = if (leftReplacedExpressions.size > 0) {
+ ProjectExecTransformer(leftProjectExprs, hashJoin.left)
+ } else { hashJoin.left },
+ right = if (rightReplacedExpressions.size > 0) {
+ ProjectExecTransformer(rightProjectExprs, hashJoin.right)
+ } else { hashJoin.right }
+ )
+ if (leftReplacedExpressions.size > 0 || rightReplacedExpressions.size
> 0) {
+ ProjectExecTransformer(hashJoin.output, newHashJoin)
+ } else {
+ newHashJoin
+ }
+ }
+ }
+
+ private def replaceExpressionInCondition(
+ condition: Option[Expression],
+ leftPlan: SparkPlan,
+ leftReplacedExpressions: mutable.HashMap[Expression, NamedExpression],
+ rightPlan: SparkPlan,
+ rightReplacedExpressions: mutable.HashMap[Expression, NamedExpression])
+ : Option[Expression] = {
+ if (!condition.isDefined) {
+ return condition
+ }
+
+ def replaceExpression(
+ e: Expression,
+ exprMap: mutable.HashMap[Expression, NamedExpression]): Expression = {
+ e match {
+ case or @ Or(left, right) =>
+ Or(replaceExpression(left, exprMap), replaceExpression(right,
exprMap))
+ case and @ And(left, right) =>
+ And(replaceExpression(left, exprMap), replaceExpression(right,
exprMap))
+ case equalTo @ EqualTo(left, right) =>
+ EqualTo(replaceExpression(left, exprMap), replaceExpression(right,
exprMap))
+ case _ =>
+ exprMap.getOrElseUpdate(e.canonicalized, null) match {
+ case null => e
+ case ne: NamedExpression => ne.toAttribute
+ }
+ }
+ }
+
+ val leftExprs = new mutable.ArrayBuffer[Expression]
+ val rightExps = new mutable.ArrayBuffer[Expression]
+ if (isMultipleOrEqualsCondition(condition.get, leftPlan, leftExprs,
rightPlan, rightExps)) {
+ leftExprs.foreach {
+ e => replaceExpressionWithAttribute(e, leftReplacedExpressions, false,
false)
+ }
+ rightExps.foreach {
+ e => replaceExpressionWithAttribute(e, rightReplacedExpressions,
false, false)
+ }
+ Some(replaceExpression(condition.get, leftReplacedExpressions ++
rightReplacedExpressions))
+ } else {
+ condition
+ }
+ }
+
+ private def isMultipleOrEqualsCondition(
+ e: Expression,
+ leftPlan: SparkPlan,
+ leftExpressions: mutable.ArrayBuffer[Expression],
+ rightPlan: SparkPlan,
+ rightExpressions: mutable.ArrayBuffer[Expression]): Boolean = {
+ def splitIntoOrExpressions(e: Expression, result:
mutable.ArrayBuffer[Expression]): Unit = {
+ e match {
+ case Or(left, right) =>
+ splitIntoOrExpressions(left, result)
+ splitIntoOrExpressions(right, result)
+ case _ =>
+ result += e
+ }
+ }
+ def splitIntoAndExpressions(e: Expression, result:
mutable.ArrayBuffer[Expression]): Unit = {
+ e match {
+ case And(left, right) =>
+ splitIntoAndExpressions(left, result)
+ splitIntoAndExpressions(right, result)
+ case _ =>
+ result += e
+ }
+ }
+
+ val leftOutputSet = leftPlan.outputSet
+ val rightOutputSet = rightPlan.outputSet
+ val orExpressions = new mutable.ArrayBuffer[Expression]()
+ splitIntoOrExpressions(e, orExpressions)
+ orExpressions.foreach {
+ orExpression =>
+ val andExpressions = new mutable.ArrayBuffer[Expression]()
+ splitIntoAndExpressions(orExpression, andExpressions)
+ andExpressions.foreach {
+ e =>
+ if (!e.isInstanceOf[EqualTo]) {
+ return false;
+ }
+ val equalExpr = e.asInstanceOf[EqualTo]
+ val leftPos = if
(equalExpr.left.references.subsetOf(leftOutputSet)) {
+ leftExpressions += equalExpr.left
+ 0
+ } else if (equalExpr.left.references.subsetOf(rightOutputSet)) {
+ rightExpressions += equalExpr.left
+ 1
+ } else {
+ return false
+ }
+ val rightPos = if
(equalExpr.right.references.subsetOf(leftOutputSet)) {
+ leftExpressions += equalExpr.right
+ 0
+ } else if (equalExpr.right.references.subsetOf(rightOutputSet)) {
+ rightExpressions += equalExpr.right
+ 1
+ } else {
+ return false
+ }
+
+ // they should come from different sides
+ if (leftPos == rightPos) {
+ return false
+ }
+ }
+ }
+ true
+ }
+}
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseJoinSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseJoinSuite.scala
index ce4ae09c66..1a276a26b2 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseJoinSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseJoinSuite.scala
@@ -105,4 +105,40 @@ class GlutenClickHouseJoinSuite extends
GlutenClickHouseWholeStageTransformerSui
}
}
+ test("GLUTEN-8168 eliminate non-attribute expressions in join keys and
condition") {
+ sql("create table tj1 (a int, b int, c int, d int) using parquet")
+ sql("create table tj2 (a int, b int, c int, d int) using parquet")
+ sql("insert into tj1 values (1, 2, 3, 4), (2, 2, 4, 5), (3, 4, 5, 4), (4,
5, 3, 7)")
+ sql("insert into tj2 values (1, 2, 3, 4), (2, 2, 4, 5), (3, 4, 5, 4), (4,
5, 3, 7)")
+ compareResultsAgainstVanillaSpark(
+ """
+ |SELECT t1.*, t2.*
+ |FROM tj1 t1 LEFT JOIN tj2 t2
+ |ON t1.a = t2.a AND (t1.b = t2.b or t1.c + 1 = t2.c) ORDER BY t1.a,
t2.a
+ |""".stripMargin,
+ true,
+ { _ => }
+ )
+ compareResultsAgainstVanillaSpark(
+ """
+ |SELECT t1.*, t2.*
+ |FROM tj1 t1 LEFT JOIN tj2 t2
+ |ON t1.a = t2.a AND (t1.b = t2.b or t1.c = t2.c) ORDER BY t1.a, t2.a
+ |""".stripMargin,
+ true,
+ { _ => }
+ )
+ compareResultsAgainstVanillaSpark(
+ """
+ |SELECT t1.*, t2.*
+ |FROM tj1 t1 LEFT JOIN tj2 t2
+ |ON t1.a = t2.a + 1 AND (t1.b + 1 = t2.b or t1.c = t2.c + 1) ORDER BY
t1.a, t2.a
+ |""".stripMargin,
+ true,
+ { _ => }
+ )
+ sql("drop table if exists tj1")
+ sql("drop table if exists tj2")
+ }
+
}
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHBucketSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHBucketSuite.scala
index 90e09e75f1..f0bfe64d6d 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHBucketSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHBucketSuite.scala
@@ -34,7 +34,6 @@ import scala.collection.mutable
class GlutenClickHouseTPCHBucketSuite
extends GlutenClickHouseTPCHAbstractSuite
with AdaptiveSparkPlanHelper {
-
override protected val tablesPath: String = basePath + "/tpch-data-ch"
override protected val tpchQueries: String = rootPath +
"queries/tpch-queries-ch"
override protected val queriesResults: String = rootPath +
"bucket-queries-output"
@@ -48,6 +47,7 @@ class GlutenClickHouseTPCHBucketSuite
.set("spark.sql.autoBroadcastJoinThreshold", "-1") // for test bucket
join
.set("spark.sql.adaptive.enabled", "true")
.set("spark.gluten.sql.columnar.backend.ch.shuffle.hash.algorithm",
"sparkMurmurHash3_32")
+ .setCHConfig("enable_pre_projection_for_join_conditions", "false")
.setCHConfig("enable_grace_aggregate_spill_test", "true")
}
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala
index 015da0dfae..dcb3514581 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala
@@ -51,6 +51,7 @@ class GlutenClickHouseTPCHMetricsSuite extends
GlutenClickHouseTPCHAbstractSuite
.set("spark.sql.autoBroadcastJoinThreshold", "10MB")
.setCHConfig("logger.level", "error")
.setCHSettings("input_format_parquet_max_block_size",
parquetMaxBlockSize)
+ .setCHConfig("enable_pre_projection_for_join_conditions", "false")
.setCHConfig("enable_streaming_aggregating", true)
}
// scalastyle:on line.size.limit
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHParquetBucketSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHParquetBucketSuite.scala
index 7a927bf23a..3eace1ab49 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHParquetBucketSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHParquetBucketSuite.scala
@@ -54,6 +54,7 @@ class GlutenClickHouseTPCHParquetBucketSuite
.set("spark.sql.autoBroadcastJoinThreshold", "-1") // for test bucket
join
.set("spark.sql.adaptive.enabled", "true")
.set("spark.gluten.sql.columnar.backend.ch.shuffle.hash.algorithm",
"sparkMurmurHash3_32")
+ .setCHConfig("enable_pre_projection_for_join_conditions", "false")
.setCHConfig("enable_grace_aggregate_spill_test", "true")
}
diff --git a/cpp-ch/local-engine/Parser/RelParsers/JoinRelParser.cpp
b/cpp-ch/local-engine/Parser/RelParsers/JoinRelParser.cpp
index 6a5f9bc937..3ffbbf4171 100644
--- a/cpp-ch/local-engine/Parser/RelParsers/JoinRelParser.cpp
+++ b/cpp-ch/local-engine/Parser/RelParsers/JoinRelParser.cpp
@@ -641,13 +641,9 @@ bool JoinRelParser::couldRewriteToMultiJoinOnClauses(
const DB::Block & left_header,
const DB::Block & right_header)
{
- /// There is only one join clause
if (!join_rel.has_post_join_filter())
return false;
-
const auto & filter_expr = join_rel.post_join_filter();
- std::list<const substrait::Expression *> expression_stack;
- expression_stack.push_back(&filter_expr);
auto check_function = [&](const String function_name_, const
substrait::Expression & e)
{
@@ -657,15 +653,42 @@ bool JoinRelParser::couldRewriteToMultiJoinOnClauses(
return function_name.has_value() && *function_name == function_name_;
};
+ std::function<void(std::vector<const substrait::Expression *> &, const
substrait::Expression &)> dfs_visit_or_expr
+ = [&](std::vector<const substrait::Expression *> & or_exprs, const
substrait::Expression & e) -> void
+ {
+ if (!check_function("or", e))
+ {
+ or_exprs.push_back(&e);
+ return;
+ }
+ const auto & args = e.scalar_function().arguments();
+ dfs_visit_or_expr(or_exprs, args[0].value());
+ dfs_visit_or_expr(or_exprs, args[1].value());
+ };
+
+ std::function<void(std::vector<const substrait::Expression *> &, const
substrait::Expression &)> dfs_visit_and_expr
+ = [&](std::vector<const substrait::Expression *> & and_exprs, const
substrait::Expression & e) -> void
+ {
+ if (!check_function("and", e))
+ {
+ and_exprs.push_back(&e);
+ return;
+ }
+ const auto & args = e.scalar_function().arguments();
+ dfs_visit_and_expr(and_exprs, args[0].value());
+ dfs_visit_and_expr(and_exprs, args[1].value());
+ };
+
auto get_field_ref = [](const substrait::Expression & e) ->
std::optional<Int32>
{
if (e.has_selection() && e.selection().has_direct_reference() &&
e.selection().direct_reference().has_struct_field())
return
std::optional<Int32>(e.selection().direct_reference().struct_field().field());
return {};
};
-
- auto parse_join_keys = [&](const substrait::Expression & e) ->
std::optional<std::pair<String, String>>
+ auto visit_equal_expr = [&](const substrait::Expression & e) ->
std::optional<std::pair<String, String>>
{
+ if (!check_function("equals", e))
+ return {};
const auto & args = e.scalar_function().arguments();
auto l_field_ref = get_field_ref(args[0].value());
auto r_field_ref = get_field_ref(args[1].value());
@@ -683,91 +706,29 @@ bool JoinRelParser::couldRewriteToMultiJoinOnClauses(
return {};
};
- auto parse_and_expression = [&](const substrait::Expression & e,
DB::TableJoin::JoinOnClause & join_on_clause)
- {
- std::vector<const substrait::Expression *> and_expression_stack;
- and_expression_stack.push_back(&e);
- while (!and_expression_stack.empty())
- {
- const auto & current_expr = *(and_expression_stack.back());
- and_expression_stack.pop_back();
- if (check_function("and", current_expr))
- {
- for (const auto & arg :
current_expr.scalar_function().arguments())
- and_expression_stack.push_back(&arg.value());
- }
- else if (check_function("equals", current_expr))
- {
- auto optional_keys = parse_join_keys(current_expr);
- if (!optional_keys)
- {
- LOG_DEBUG(getLogger("JoinRelParser"), "Not equal
comparison for keys from both tables");
- return false;
- }
- join_on_clause.addKey(optional_keys->first,
optional_keys->second, false);
- }
- else
- {
- LOG_DEBUG(getLogger("JoinRelParser"), "And or equals function
is expected");
- return false;
- }
- }
- return true;
- };
-
- while (!expression_stack.empty())
- {
- const auto & current_expr = *(expression_stack.back());
- expression_stack.pop_back();
- if (!check_function("or", current_expr))
- {
- LOG_DEBUG(getLogger("JoinRelParser"), "Not an or expression");
- return false;
- }
-
- auto get_current_join_on_clause = [&]()
- {
- DB::TableJoin::JoinOnClause new_clause = prefix_clause;
- clauses.push_back(new_clause);
- return &clauses.back();
- };
- const auto & args = current_expr.scalar_function().arguments();
- for (const auto & arg : args)
+ std::vector<const substrait::Expression *> or_exprs;
+ dfs_visit_or_expr(or_exprs, filter_expr);
+ if (or_exprs.empty())
+ return false;
+ for (const auto * or_expr : or_exprs)
+ {
+ DB::TableJoin::JoinOnClause new_clause = prefix_clause;
+ clauses.push_back(new_clause);
+ auto & current_clause = clauses.back();
+ std::vector<const substrait::Expression *> and_exprs;
+ dfs_visit_and_expr(and_exprs, *or_expr);
+ for (const auto * and_expr : and_exprs)
{
- if (check_function("equals", arg.value()))
- {
- auto optional_keys = parse_join_keys(arg.value());
- if (!optional_keys)
- {
- LOG_DEBUG(getLogger("JoinRelParser"), "Not equal
comparison for keys from both tables");
- return false;
- }
- get_current_join_on_clause()->addKey(optional_keys->first,
optional_keys->second, false);
- }
- else if (check_function("and", arg.value()))
- {
- if (!parse_and_expression(arg.value(),
*get_current_join_on_clause()))
- {
- LOG_DEBUG(getLogger("JoinRelParser"), "Parse and
expression failed");
- return false;
- }
- }
- else if (check_function("or", arg.value()))
- {
- expression_stack.push_back(&arg.value());
- }
- else
- {
- LOG_DEBUG(getLogger("JoinRelParser"), "Unknow function");
+ auto join_keys = visit_equal_expr(*and_expr);
+ if (!join_keys)
return false;
- }
+ current_clause.addKey(join_keys->first, join_keys->second, false);
}
}
return true;
}
-
DB::QueryPlanPtr JoinRelParser::buildMultiOnClauseHashJoin(
std::shared_ptr<DB::TableJoin> table_join,
DB::QueryPlanPtr left_plan,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]