This is an automated email from the ASF dual-hosted git repository.

rui pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new e9034ff5e [GLUTEN-4946][CH] Fix avg(bigint) overflow (#5048)
e9034ff5e is described below

commit e9034ff5e2ec4cce8cd5defaf2ade9b44b8c8aa3
Author: loudongfeng <nemon...@qq.com>
AuthorDate: Mon Mar 25 12:55:00 2024 +0800

    [GLUTEN-4946][CH] Fix avg(bigint) overflow (#5048)
---
 .../clickhouse/CHSparkPlanExecApi.scala            |  2 +
 .../catalyst/CHAggregateFunctionRewriteRule.scala  | 60 ++++++++++++++++++++++
 .../execution/GlutenFunctionValidateSuite.scala    | 21 ++++++++
 .../main/scala/io/glutenproject/GlutenConfig.scala |  8 +++
 4 files changed, 91 insertions(+)

diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index 29af5a0e5..4b6ee1909 100644
--- a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -35,6 +35,7 @@ import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{GenShuffleWriterParameters, GlutenShuffleWriterWrapper, HashPartitioningWrapper}
 import org.apache.spark.shuffle.utils.CHShuffleUtil
 import org.apache.spark.sql.{SparkSession, Strategy}
+import org.apache.spark.sql.catalyst.CHAggregateFunctionRewriteRule
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions._
@@ -518,6 +519,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
   override def genExtendedOptimizers(): List[SparkSession => Rule[LogicalPlan]] = {
     List(
       spark => new CommonSubexpressionEliminateRule(spark, spark.sessionState.conf),
+      spark => CHAggregateFunctionRewriteRule(spark),
       _ => CountDistinctWithoutExpand
     )
   }
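
For context: each entry in genExtendedOptimizers() is a SparkSession =>
Rule[LogicalPlan] builder. As a minimal sketch (not part of this commit, and
the actual wiring inside Gluten may differ), such a builder can be registered
through Spark's standard extension API like this:

    import org.apache.spark.sql.SparkSessionExtensions
    import org.apache.spark.sql.catalyst.CHAggregateFunctionRewriteRule

    class ExampleExtensions extends (SparkSessionExtensions => Unit) {
      override def apply(ext: SparkSessionExtensions): Unit = {
        // Register the builder so the rule runs with the extra optimizer
        // rules; Spark invokes it once per session with that session.
        ext.injectOptimizerRule(spark => CHAggregateFunctionRewriteRule(spark))
      }
    }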
diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/catalyst/CHAggregateFunctionRewriteRule.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/catalyst/CHAggregateFunctionRewriteRule.scala
new file mode 100644
index 000000000..623db7993
--- /dev/null
+++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/catalyst/CHAggregateFunctionRewriteRule.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst
+
+import io.glutenproject.GlutenConfig
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.Cast
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.types._
+
+/**
+ * Avg(Int) function: ClickHouse uses the input type for the intermediate sum, while Spark
+ * uses double, so the input needs to be cast to double to avoid overflow.
+ * @param spark
+ */
+case class CHAggregateFunctionRewriteRule(spark: SparkSession) extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
+    case a: Aggregate =>
+      a.transformExpressions {
+        case avgExpr @ AggregateExpression(avg: Average, _, _, _, _)
+            if GlutenConfig.getConf.enableCastAvgAggregateFunction &&
+              GlutenConfig.getConf.enableColumnarHashAgg &&
+              !avgExpr.isDistinct && isDataTypeNeedConvert(avg.child.dataType) =>
+          AggregateExpression(
+            avg.copy(child = Cast(avg.child, DoubleType)),
+            avgExpr.mode,
+            avgExpr.isDistinct,
+            avgExpr.filter,
+            avgExpr.resultId
+          )
+      }
+  }
+
+  private def isDataTypeNeedConvert(dataType: DataType): Boolean = {
+    dataType match {
+      case FloatType => true
+      case IntegerType => true
+      case LongType => true
+      case ShortType => true
+      case _ => false
+    }
+  }
+}
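
To make the rewrite concrete, here is a minimal sketch (not part of the
commit) of applying the rule to an analyzed plan, assuming a Gluten-enabled
session where both config flags are at their default of true:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.catalyst.CHAggregateFunctionRewriteRule

    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    // range(10) exposes id as bigint, so Average.child has LongType
    val plan = spark.sql("select avg(id) from range(10)").queryExecution.analyzed

    // avg(id) becomes avg(cast(id as double)): the intermediate sum then
    // accumulates as Float64 on the ClickHouse side instead of Int64.
    val rewritten = CHAggregateFunctionRewriteRule(spark).apply(plan)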
diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala
index 818fe4e5f..1a27e68fe 100644
--- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala
+++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala
@@ -639,4 +639,25 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS
     val sql = "select  cast(concat(' ', cast(id as string)) as bigint) from range(10)"
     runQueryAndCompare(sql)(checkOperatorMatch[ProjectExecTransformer])
   }
+
+  test("avg(bigint) overflow") {
+    withSQLConf(
+      "spark.gluten.sql.columnar.forceShuffledHashJoin" -> "false",
+      "spark.sql.autoBroadcastJoinThreshold" -> "-1") {
+      withTable("big_int") {
+        sql("create table big_int(id bigint) using parquet")
+        sql("""
+              |insert into big_int values (9223372036854775807),
+              |(9223372036854775807),
+              |(9223372036854775807),
+              |(9223372036854775807)
+              |""".stripMargin)
+        val q = "select avg(id) from big_int"
+        runQueryAndCompare(q)(checkOperatorMatch[CHHashAggregateExecTransformer])
+        val distinctSQL = "select count(distinct id), avg(distinct id), avg(id) from big_int"
+        runQueryAndCompare(distinctSQL)(checkOperatorMatch[CHHashAggregateExecTransformer])
+      }
+    }
+  }
+
 }
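
For intuition on what the new test exercises: summing four Long.MaxValue
values wraps around in 64-bit integer arithmetic, while a double accumulator
keeps the magnitude (at reduced precision). A quick sketch of the arithmetic:

    val v = Long.MaxValue        // 9223372036854775807
    val wrapped = v + v + v + v  // Int64 wraps to -4, so the avg would be -1
    val expected = v.toDouble    // ~9.223372036854776E18, the correct avg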
diff --git a/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala b/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
index 48e37cdb3..4119a09fc 100644
--- a/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
+++ b/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
@@ -353,6 +353,8 @@ class GlutenConfig(conf: SQLConf) extends Logging {
   def enableColumnarProjectCollapse: Boolean = conf.getConf(ENABLE_COLUMNAR_PROJECT_COLLAPSE)
 
   def awsSdkLogLevel: String = conf.getConf(AWS_SDK_LOG_LEVEL)
+
+  def enableCastAvgAggregateFunction: Boolean = conf.getConf(COLUMNAR_NATIVE_CAST_AGGREGATE_ENABLED)
 }
 
 object GlutenConfig {
@@ -1691,4 +1693,10 @@ object GlutenConfig {
       .doc("Force fallback for orc char type scan.")
       .booleanConf
       .createWithDefault(true)
+
+  val COLUMNAR_NATIVE_CAST_AGGREGATE_ENABLED =
+    buildConf("spark.gluten.sql.columnar.cast.avg")
+      .internal()
+      .booleanConf
+      .createWithDefault(true)
 }
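
The new flag is internal and defaults to true; if the extra cast is not
wanted, it can be disabled per session, for example:

    spark.conf.set("spark.gluten.sql.columnar.cast.avg", "false")
    // equivalently in SQL: SET spark.gluten.sql.columnar.cast.avg=false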


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@gluten.apache.org
For additional commands, e-mail: commits-h...@gluten.apache.org
