This is an automated email from the ASF dual-hosted git repository.
lgbo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 2614db9f3e fix (#9293)
2614db9f3e is described below
commit 2614db9f3e190053e0e46fc3a823ee6ade01a37a
Author: lgbo <[email protected]>
AuthorDate: Fri Apr 11 09:12:55 2025 +0800
fix (#9293)
---
.../gluten/backendsapi/clickhouse/CHBackend.scala | 2 +
.../clickhouse/CHSparkPlanExecApi.scala | 3 +-
.../execution/CHHashJoinExecTransformer.scala | 5 +-
.../EliminateDeduplicateAggregateWithAnyJoin.scala | 43 ++++++++++---
.../extension/JoinAggregateToAggregateUnion.scala | 4 +-
.../RewriteSortMergeJoinToHashJoinRule.scala | 3 +-
.../execution/GlutenEliminateJoinSuite.scala | 72 ++++++++++++++++++++++
7 files changed, 116 insertions(+), 16 deletions(-)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
index dcdb5dcc5d..1addfe23ea 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
@@ -163,6 +163,8 @@ object CHBackendSettings extends BackendSettingsApi with
Logging {
CHConfig.prefixOf("enable.coalesce.project.union")
val GLUTEN_JOIN_AGGREGATE_TO_AGGREGATE_UNION: String =
CHConfig.prefixOf("join.aggregate.to.aggregate.union")
+ val GLUTEN_ELIMINATE_DEDUPLICATE_AGGREGATE_WITH_ANY_JOIN: String =
+ CHConfig.prefixOf("eliminate_deduplicate_aggregate_with_any_join")
def affinityMode: String = {
SparkEnv.get.conf
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index dcf19204d2..c77e7bc31b 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -315,7 +315,8 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with
Logging {
condition,
left,
right,
- isSkewJoin)
+ isSkewJoin,
+ false)
}
/** Generate BroadcastHashJoinExecTransformer. */
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
index 21eab86da5..cbf3a3b6ea 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
@@ -90,7 +90,8 @@ case class CHShuffledHashJoinExecTransformer(
condition: Option[Expression],
left: SparkPlan,
right: SparkPlan,
- isSkewJoin: Boolean)
+ isSkewJoin: Boolean,
+ isAnyJoin: Boolean)
extends ShuffledHashJoinExecTransformerBase(
leftKeys,
rightKeys,
@@ -100,8 +101,6 @@ case class CHShuffledHashJoinExecTransformer(
left,
right,
isSkewJoin) {
- // `any join` is used to accelerate the case when the right table is the
aggregate result.
- var isAnyJoin = false
override protected def withNewChildrenInternal(
newLeft: SparkPlan,
newRight: SparkPlan): CHShuffledHashJoinExecTransformer =
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/EliminateDeduplicateAggregateWithAnyJoin.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/EliminateDeduplicateAggregateWithAnyJoin.scala
index 06a4199d53..6a7f8a165c 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/EliminateDeduplicateAggregateWithAnyJoin.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/EliminateDeduplicateAggregateWithAnyJoin.scala
@@ -22,6 +22,7 @@ import org.apache.gluten.execution._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
@@ -30,21 +31,48 @@ case class EliminateDeduplicateAggregateWithAnyJoin(spark:
SparkSession)
extends Rule[SparkPlan]
with Logging {
override def apply(plan: SparkPlan): SparkPlan = {
- if (!CHBackendSettings.eliminateDeduplicateAggregateWithAnyJoin()) {
+
+ if (
+ !spark.conf
+
.get(CHBackendSettings.GLUTEN_ELIMINATE_DEDUPLICATE_AGGREGATE_WITH_ANY_JOIN,
"true")
+ .toBoolean
+ ) {
return plan
}
plan.transformUp {
- case hashJoin: CHShuffledHashJoinExecTransformer =>
+ case hashJoin: CHShuffledHashJoinExecTransformer
+ if (hashJoin.buildSide == BuildRight && hashJoin.joinType ==
LeftOuter) =>
hashJoin.right match {
case aggregate: CHHashAggregateExecTransformer =>
+ if (
+ isDeduplicateAggregate(aggregate) &&
allGroupingKeysAreJoinKeys(hashJoin, aggregate)
+ ) {
+ hashJoin.copy(right = aggregate.child, isAnyJoin = true)
+ } else {
+ hashJoin
+ }
+ case project @ ProjectExecTransformer(_, aggregate:
CHHashAggregateExecTransformer) =>
if (
hashJoin.joinType == LeftOuter &&
+ isDeduplicateAggregate(aggregate) &&
+ allGroupingKeysAreJoinKeys(hashJoin, aggregate) &&
project.projectList.forall(
+ _.isInstanceOf[AttributeReference])
+ ) {
+ hashJoin.copy(right = project.copy(child = aggregate.child),
isAnyJoin = true)
+ } else {
+ hashJoin
+ }
+ case _ => hashJoin
+ }
+ case hashJoin: CHShuffledHashJoinExecTransformer
+ if (hashJoin.buildSide == BuildLeft && hashJoin.joinType ==
LeftOuter) =>
+ hashJoin.left match {
+ case aggregate: CHHashAggregateExecTransformer =>
+ if (
isDeduplicateAggregate(aggregate) &&
allGroupingKeysAreJoinKeys(hashJoin, aggregate)
) {
- val newHashJoin = hashJoin.copy(right = aggregate.child)
- newHashJoin.isAnyJoin = true
- newHashJoin
+ hashJoin.copy(left = aggregate.child, isAnyJoin = true)
} else {
hashJoin
}
@@ -55,10 +83,7 @@ case class EliminateDeduplicateAggregateWithAnyJoin(spark:
SparkSession)
allGroupingKeysAreJoinKeys(hashJoin, aggregate) &&
project.projectList.forall(
_.isInstanceOf[AttributeReference])
) {
- val newHashJoin =
- hashJoin.copy(right = project.copy(child = aggregate.child))
- newHashJoin.isAnyJoin = true
- newHashJoin
+ hashJoin.copy(left = project.copy(child = aggregate.child),
isAnyJoin = true)
} else {
hashJoin
}
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala
index 3fcc2d5369..da80d58f26 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala
@@ -334,7 +334,7 @@ case class JoinedAggregateAnalyzer(join: Join, subquery:
LogicalPlan) extends Lo
joinKeys.length != aggregate.groupingExpressions.length ||
!joinKeys.forall(k => outputGroupingKeys.exists(_.semanticEquals(k)))
) {
- logError(
+ logDebug(
s"xxx Join keys and grouping keys are not matched. joinKeys:
$joinKeys" +
s" outputGroupingKeys: $outputGroupingKeys")
return false
@@ -955,7 +955,7 @@ case class JoinAggregateToAggregateUnion(spark:
SparkSession)
analyzedAggregates.insert(0, rightAggregateAnalyzer.get)
collectSameKeysJoinedAggregates(join.left, analyzedAggregates)
} else {
- logError(
+ logDebug(
s"xxx Not have same keys. join keys:" +
s"${analyzedAggregates.head.getPrimeJoinKeys()} vs. " +
s"${rightAggregateAnalyzer.get.getPrimeJoinKeys()}")
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteSortMergeJoinToHashJoinRule.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteSortMergeJoinToHashJoinRule.scala
index 8c5ada043f..441a181fca 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteSortMergeJoinToHashJoinRule.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteSortMergeJoinToHashJoinRule.scala
@@ -80,7 +80,8 @@ case class RewriteSortMergeJoinToHashJoinRule(session:
SparkSession)
smj.condition,
newLeft,
newRight,
- smj.isSkewJoin)
+ smj.isSkewJoin,
+ false)
val validateResult = hashJoin.doValidate()
if (!validateResult.ok()) {
logError(s"Validation failed for ShuffledHashJoinExec:
${validateResult.reason()}")
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala
index 080fc9b5c9..169892c4da 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala
@@ -16,6 +16,8 @@
*/
package org.apache.gluten.execution
+import org.apache.gluten.backendsapi.clickhouse._
+
import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.sql.Row
@@ -59,6 +61,8 @@ class GlutenEliminateJoinSuite extends
GlutenClickHouseWholeStageTransformerSuit
.set("spark.sql.shuffle.partitions", "5")
.set("spark.sql.autoBroadcastJoinThreshold", "-1")
.set("spark.gluten.supported.scala.udfs",
"compare_substrings:compare_substrings")
+
.set(CHConfig.runtimeSettings("max_memory_usage_ratio_for_streaming_aggregating"),
"0.01")
+
.set(CHConfig.runtimeSettings("high_cardinality_threshold_for_streaming_aggregating"),
"0.2")
.set(
SQLConf.OPTIMIZER_EXCLUDED_RULES.key,
ConstantFolding.ruleName + "," + NullPropagation.ruleName)
@@ -469,4 +473,72 @@ class GlutenEliminateJoinSuite extends
GlutenClickHouseWholeStageTransformerSuit
assert(joins.length == 1)
})
}
+
+ // Ensure isAnyJoin is never lost after other rules are applied
+ test("lost any join setting") {
+ spark.sql("drop table if exists t_9267_1")
+ spark.sql("drop table if exists t_9267_2")
+ spark.sql("create table t_9267_1 (a bigint, b bigint) using parquet")
+ spark.sql("create table t_9267_2 (a bigint, b bigint) using parquet")
+ spark.sql("insert into t_9267_1 select id as a, id as b from
range(20000000)")
+ spark.sql("insert into t_9267_2 select id as a, id as b from
range(5000000)")
+ spark.sql("insert into t_9267_2 select id as a, id as b from
range(5000000)")
+
+ val sql =
+ """
+ |select count(1) as n1, count(a1, b1, a2) as n2 from(
+ | select t1.a as a1, t1.b as b1, t2.a as a2 from (
+ | select * from t_9267_1 where a >= 0 and b < 100000000 and b >= 0
+ | ) t1 left join (
+ | select a, b from t_9267_2 group by a, b
+ | ) t2 on t1.a = t2.a and t1.b = t2.b
+ |)""".stripMargin
+ compareResultsAgainstVanillaSpark(
+ sql,
+ true,
+ {
+ df =>
+ val joins = df.queryExecution.executedPlan.collect {
+ case join: ShuffledHashJoinExecTransformerBase => join
+ }
+ assert(joins.length == 1)
+ })
+
+ spark.sql("drop table t_9267_1")
+ spark.sql("drop table t_9267_2")
+ }
+
+ test("build left side") {
+ spark.sql("drop table if exists t_9267_1")
+ spark.sql("drop table if exists t_9267_2")
+ spark.sql("create table t_9267_1 (a bigint, b bigint) using parquet")
+ spark.sql("create table t_9267_2 (a bigint, b bigint) using parquet")
+ spark.sql("insert into t_9267_1 select id as a, id as b from
range(2000000)")
+ spark.sql("insert into t_9267_2 select id as a, id as b from
range(500000)")
+ spark.sql("insert into t_9267_2 select id as a, id as b from
range(500000)")
+
+ // The left table is smaller, so it will be used as the build side.
+ val sql =
+ """
+ |select count(1) as n1, count(a1, b1, a2) as n2 from(
+ | select t1.a as a1, t1.b as b1, t2.a as a2 from (
+ | select a, b from t_9267_2 group by a, b
+ | ) t1 left join (
+ | select * from t_9267_1 where a >= 0 and b != 100000000 and b >= 0
+ | ) t2 on t1.a = t2.a and t1.b = t2.b
+ |)""".stripMargin
+ compareResultsAgainstVanillaSpark(
+ sql,
+ true,
+ {
+ df =>
+ val joins = df.queryExecution.executedPlan.collect {
+ case join: ShuffledHashJoinExecTransformerBase => join
+ }
+ assert(joins.length == 1)
+ })
+
+ spark.sql("drop table t_9267_1")
+ spark.sql("drop table t_9267_2")
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]