This is an automated email from the ASF dual-hosted git repository.
lgbo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new c2a578c337 enable lazy expand for avg and sum(decimal) (#7840)
c2a578c337 is described below
commit c2a578c3379adafb79da9c06df3df4e41eb9139a
Author: lgbo <[email protected]>
AuthorDate: Fri Nov 8 17:48:33 2024 +0800
enable lazy expand for avg and sum(decimal) (#7840)
---
.../gluten/extension/LazyAggregateExpandRule.scala | 15 +++++++---
.../GlutenClickHouseTPCHSaltNullParquetSuite.scala | 35 ++++++++++++++++++++++
2 files changed, 46 insertions(+), 4 deletions(-)
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/LazyAggregateExpandRule.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/LazyAggregateExpandRule.scala
index e06503a5e1..86b28ab1f7 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/LazyAggregateExpandRule.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/LazyAggregateExpandRule.scala
@@ -190,10 +190,12 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
// 2. if any aggregate function uses attributes which is not from expand's child, we don't
// enable this
if (
- !aggregate.aggregateExpressions.forall(
+ !aggregate.aggregateExpressions.forall {
e =>
isValidAggregateFunction(e) &&
- e.aggregateFunction.references.forall(expandOutputAttributes.contains(_)))
+ e.aggregateFunction.references.forall(
+ attr => expandOutputAttributes.find(_.semanticEquals(attr)).isDefined)
+ }
) {
logDebug(s"xxx Some aggregate functions are not supported")
return false
@@ -267,7 +269,8 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
case _: Count => true
case _: Max => true
case _: Min => true
- case sum: Sum => !sum.dataType.isInstanceOf[DecimalType]
+ case _: Average => true
+ case _: Sum => true
case _ => false
}
}
@@ -275,7 +278,11 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
def getReplaceAttribute(
toReplace: Attribute,
attributesToReplace: Map[Attribute, Attribute]): Attribute = {
- attributesToReplace.getOrElse(toReplace, toReplace)
+ val kv = attributesToReplace.find(kv => kv._1.semanticEquals(toReplace))
+ kv match {
+ case Some((_, v)) => v
+ case None => toReplace
+ }
}
def buildReplaceAttributeMap(expand: ExpandExecTransformer): Map[Attribute, Attribute] = {
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index 12047b300c..9affdeb7f7 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -3068,6 +3068,41 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
compareResultsAgainstVanillaSpark(sql, true, checkLazyExpand)
}
+ test("GLUTEN-7647 lazy expand for avg and sum") {
+ val create_table_sql =
+ """
+ |create table test_7647(x bigint, y bigint, z bigint, v decimal(10, 2)) using parquet
+ |""".stripMargin
+ spark.sql(create_table_sql)
+ val insert_data_sql =
+ """
+ |insert into test_7647 values
+ |(1, 1, 1, 1.0),
+ |(2, 2, 2, 2.0),
+ |(3, 3, 3, 3.0),
+ |(2,2,1, 4.0)
+ |""".stripMargin
+ spark.sql(insert_data_sql)
+
+ def checkLazyExpand(df: DataFrame): Unit = {
+ val expands = collectWithSubqueries(df.queryExecution.executedPlan) {
+ case e: ExpandExecTransformer if (e.child.isInstanceOf[HashAggregateExecBaseTransformer]) =>
+ e
+ }
+ assert(expands.size == 1)
+ }
+
+ var sql = "select x, y, avg(z), sum(v) from test_7647 group by x, y with cube order by x, y"
+ compareResultsAgainstVanillaSpark(sql, true, checkLazyExpand)
+ sql =
+ "select x, y, count(distinct z), avg(v) from test_7647 group by x, y with cube order by x, y"
+ compareResultsAgainstVanillaSpark(sql, true, checkLazyExpand)
+ sql =
+ "select x, y, count(distinct z), sum(v) from test_7647 group by x, y with cube order by x, y"
+ compareResultsAgainstVanillaSpark(sql, true, checkLazyExpand)
+ spark.sql("drop table if exists test_7647")
+ }
+
test("GLUTEN-7759: Fix bug of agg pre-project push down") {
val table_create_sql =
"create table test_tbl_7759(id bigint, name string, day string) using
parquet"
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]