This is an automated email from the ASF dual-hosted git repository.

rui pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
     new e9034ff5e [GLUTEN-4946][CH] Fix avg(bigint) overflow (#5048)
e9034ff5e is described below

commit e9034ff5e2ec4cce8cd5defaf2ade9b44b8c8aa3
Author: loudongfeng <nemon...@qq.com>
AuthorDate: Mon Mar 25 12:55:00 2024 +0800

    [GLUTEN-4946][CH] Fix avg(bigint) overflow (#5048)
---
 .../clickhouse/CHSparkPlanExecApi.scala            |  2 +
 .../catalyst/CHAggregateFunctionRewriteRule.scala  | 60 ++++++++++++++++++++++
 .../execution/GlutenFunctionValidateSuite.scala    | 21 ++++++++
 .../main/scala/io/glutenproject/GlutenConfig.scala |  8 +++
 4 files changed, 91 insertions(+)

diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index 29af5a0e5..4b6ee1909 100644
--- a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -35,6 +35,7 @@ import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{GenShuffleWriterParameters, GlutenShuffleWriterWrapper, HashPartitioningWrapper}
 import org.apache.spark.shuffle.utils.CHShuffleUtil
 import org.apache.spark.sql.{SparkSession, Strategy}
+import org.apache.spark.sql.catalyst.CHAggregateFunctionRewriteRule
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions._
@@ -518,6 +519,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
   override def genExtendedOptimizers(): List[SparkSession => Rule[LogicalPlan]] = {
     List(
       spark => new CommonSubexpressionEliminateRule(spark, spark.sessionState.conf),
+      spark => CHAggregateFunctionRewriteRule(spark),
       _ => CountDistinctWithoutExpand
     )
   }
diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/catalyst/CHAggregateFunctionRewriteRule.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/catalyst/CHAggregateFunctionRewriteRule.scala
new file mode 100644
index 000000000..623db7993
--- /dev/null
+++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/catalyst/CHAggregateFunctionRewriteRule.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst
+
+import io.glutenproject.GlutenConfig
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.Cast
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.types._
+
+/**
+ * Avg(Int) function: CH use input type for intermediate sum type, while spark use double so need
+ * convert .
+ * @param spark
+ */
+case class CHAggregateFunctionRewriteRule(spark: SparkSession) extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
+    case a: Aggregate =>
+      a.transformExpressions {
+        case avgExpr @ AggregateExpression(avg: Average, _, _, _, _)
+            if GlutenConfig.getConf.enableCastAvgAggregateFunction &&
+              GlutenConfig.getConf.enableColumnarHashAgg &&
+              !avgExpr.isDistinct && isDataTypeNeedConvert(avg.child.dataType) =>
+          AggregateExpression(
+            avg.copy(child = Cast(avg.child, DoubleType)),
+            avgExpr.mode,
+            avgExpr.isDistinct,
+            avgExpr.filter,
+            avgExpr.resultId
+          )
+      }
+  }
+
+  private def isDataTypeNeedConvert(dataType: DataType): Boolean = {
+    dataType match {
+      case FloatType => true
+      case IntegerType => true
+      case LongType => true
+      case ShortType => true
+      case _ => false
+    }
+  }
+}
diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala
index 818fe4e5f..1a27e68fe 100644
--- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala
+++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala
@@ -639,4 +639,25 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS
     val sql = "select cast(concat(' ', cast(id as string)) as bigint) from range(10)"
     runQueryAndCompare(sql)(checkOperatorMatch[ProjectExecTransformer])
   }
+
+  test("avg(bigint) overflow") {
+    withSQLConf(
+      "spark.gluten.sql.columnar.forceShuffledHashJoin" -> "false",
+      "spark.sql.autoBroadcastJoinThreshold" -> "-1") {
+      withTable("myitem") {
+        sql("create table big_int(id bigint) using parquet")
+        sql("""
+            |insert into big_int values (9223372036854775807),
+            |(9223372036854775807),
+            |(9223372036854775807),
+            |(9223372036854775807)
+            |""".stripMargin)
+        val q = "select avg(id) from big_int"
+        runQueryAndCompare(q)(checkOperatorMatch[CHHashAggregateExecTransformer])
+        val disinctSQL = "select count(distinct id), avg(distinct id), avg(id) from big_int"
+        runQueryAndCompare(disinctSQL)(checkOperatorMatch[CHHashAggregateExecTransformer])
+      }
+    }
+  }
+
 }
diff --git a/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala b/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
index 48e37cdb3..4119a09fc 100644
--- a/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
+++ b/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
@@ -353,6 +353,8 @@ class GlutenConfig(conf: SQLConf) extends Logging {
   def enableColumnarProjectCollapse: Boolean = conf.getConf(ENABLE_COLUMNAR_PROJECT_COLLAPSE)
 
   def awsSdkLogLevel: String = conf.getConf(AWS_SDK_LOG_LEVEL)
+
+  def enableCastAvgAggregateFunction: Boolean = conf.getConf(COLUMNAR_NATIVE_CAST_AGGREGATE_ENABLED)
 }
 
 object GlutenConfig {
@@ -1691,4 +1693,10 @@ object GlutenConfig {
       .doc("Force fallback for orc char type scan.")
       .booleanConf
       .createWithDefault(true)
+
+  val COLUMNAR_NATIVE_CAST_AGGREGATE_ENABLED =
+    buildConf("spark.gluten.sql.columnar.cast.avg")
+      .internal()
+      .booleanConf
+      .createWithDefault(true)
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@gluten.apache.org
For additional commands, e-mail: commits-h...@gluten.apache.org