This is an automated email from the ASF dual-hosted git repository.

mahongbin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new b07b36960 [Gluten-4706] Add a mode to execute count distinct directly 
instead of Expand+Count (#4708)
b07b36960 is described below

commit b07b36960238db62647b6aeefc930b7fc3f271c7
Author: Hongbin Ma <mahong...@apache.org>
AuthorDate: Tue Mar 12 16:38:47 2024 +0800

    [Gluten-4706] Add a mode to execute count distinct directly instead of 
Expand+Count (#4708)
---
 .../clickhouse/CHSparkPlanExecApi.scala            |   6 +-
 .../execution/CHHashAggregateExecTransformer.scala |   6 ++
 .../execution/GlutenClickHouseHiveTableSuite.scala |  71 +++++++------
 .../GlutenClickHouseNativeWriteTableSuite.scala    |  38 +------
 ...lutenClickHouseWholeStageTransformerSuite.scala |  17 ++-
 .../GlutenClickhouseCountDistinctSuite.scala       | 118 +++++++++++++++++++++
 .../extension/CustomAggExpressionTransformer.scala |   2 +-
 .../AggregateFunctionPartialMerge.h                |   4 +-
 .../local-engine/Parser/AggregateFunctionParser.h  |  14 ++-
 .../CommonAggregateFunctionParser.cpp              |   2 +-
 .../expression/ExpressionMappings.scala            |   1 +
 .../extension/CountDistinctWithoutExpand.scala     |  49 +++++++++
 .../expressions/aggregate/CountDistinct.scala      |  59 +++++++++++
 .../main/scala/io/glutenproject/GlutenConfig.scala |  13 +++
 .../glutenproject/expression/ExpressionNames.scala |   1 +
 15 files changed, 322 insertions(+), 79 deletions(-)

diff --git 
a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
 
b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index 43d8ed1bc..f32116728 100644
--- 
a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ 
b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -22,6 +22,7 @@ import io.glutenproject.execution._
 import io.glutenproject.expression._
 import io.glutenproject.expression.ConverterUtils.FunctionConfig
 import io.glutenproject.extension.{FallbackBroadcastHashJoin, 
FallbackBroadcastHashJoinPrepQueryStage}
+import io.glutenproject.extension.CountDistinctWithoutExpand
 import io.glutenproject.extension.columnar.AddTransformHintRule
 import 
io.glutenproject.extension.columnar.MiscColumnarRules.TransformPreOverrides
 import io.glutenproject.substrait.expression.{ExpressionBuilder, 
ExpressionNode, WindowFunctionNode}
@@ -503,7 +504,10 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
    * @return
    */
   override def genExtendedOptimizers(): List[SparkSession => 
Rule[LogicalPlan]] = {
-    List(spark => new CommonSubexpressionEliminateRule(spark, 
spark.sessionState.conf))
+    List(
+      spark => new CommonSubexpressionEliminateRule(spark, 
spark.sessionState.conf),
+      // Rewrites count(distinct ...) into CountDistinct; the rule is a no-op unless both Gluten
+      // and spark.gluten.sql.countDistinctWithoutExpand are enabled.
+      _ => CountDistinctWithoutExpand
+    )
   }
 
   /**
diff --git 
a/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala
 
b/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala
index c6b99a6bc..85b967446 100644
--- 
a/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala
+++ 
b/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala
@@ -358,6 +358,12 @@ case class CHHashAggregateExecTransformer(
               (makeStructType(fields), attr.nullable)
             case expr if "bloom_filter_agg".equals(expr.prettyName) =>
               (makeStructTypeSingleOne(expr.children.head.dataType, 
attr.nullable), attr.nullable)
+            case cd: CountDistinct =>
+              var fields = Seq[(DataType, Boolean)]()
+              for (child <- cd.children) {
+                fields = fields :+ (child.dataType, child.nullable)
+              }
+              (makeStructType(fields), false)
             case _ =>
               (makeStructTypeSingleOne(attr.dataType, attr.nullable), 
attr.nullable)
           }
diff --git 
a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseHiveTableSuite.scala
 
b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseHiveTableSuite.scala
index 938f8e6d1..b40a0fe0d 100644
--- 
a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseHiveTableSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseHiveTableSuite.scala
@@ -19,8 +19,10 @@ package io.glutenproject.execution
 import io.glutenproject.GlutenConfig
 import io.glutenproject.utils.UTSystemParameters
 
-import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf}
-import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.SPARK_VERSION_SHORT
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.delta.DeltaLog
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.hive.HiveTableScanExecTransformer
@@ -29,11 +31,11 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.hadoop.fs.Path
 
 import java.io.{File, PrintWriter}
-import java.sql.Timestamp
+import java.sql.{Date, Timestamp}
 
 import scala.reflect.ClassTag
 
-case class AllDataTypesWithComplextType(
+case class AllDataTypesWithComplexType(
     string_field: String = null,
     int_field: java.lang.Integer = null,
     long_field: java.lang.Long = null,
@@ -51,6 +53,36 @@ case class AllDataTypesWithComplextType(
     mapValueContainsNull: Map[Int, Option[Long]] = null
 )
 
+object AllDataTypesWithComplexType {
+
+  /**
+   * Generates 200 rows of test data. Rows with `i % 100 == 1` use the no-argument constructor
+   * (null defaults); every other row derives all of its fields from the row index `i`.
+   */
+  def genTestData(): Seq[AllDataTypesWithComplexType] = {
+    // Read the clock once so that `plusDays(i % 10)` yields exactly 10 distinct non-null
+    // date/timestamp values. Reading it per row lets timestamps drift by a few milliseconds
+    // between rows, which defeats the intended repetition (relevant for count distinct tests).
+    val nowMillis = System.currentTimeMillis()
+    val baseDate = new Date(nowMillis).toLocalDate
+    val baseTimestamp = new Timestamp(nowMillis).toLocalDateTime
+    (0 to 199).map {
+      i =>
+        if (i % 100 == 1) {
+          AllDataTypesWithComplexType()
+        } else {
+          AllDataTypesWithComplexType(
+            s"$i",
+            i,
+            i.toLong,
+            i.toFloat,
+            i.toDouble,
+            i.toShort,
+            i.toByte,
+            i % 2 == 0,
+            new java.math.BigDecimal(i + ".56"),
+            Date.valueOf(baseDate.plusDays(i % 10)),
+            Timestamp.valueOf(baseTimestamp.plusDays(i % 10)),
+            Seq.apply(i + 1, i + 2, i + 3),
+            Seq.apply(Option.apply(i + 1), Option.empty, Option.apply(i + 3)),
+            Map.apply((i + 1, i + 2), (i + 3, i + 4)),
+            Map.empty
+          )
+        }
+    }
+  }
+}
+
 class GlutenClickHouseHiveTableSuite
   extends GlutenClickHouseWholeStageTransformerSuite
   with AdaptiveSparkPlanHelper {
@@ -170,38 +202,13 @@ class GlutenClickHouseHiveTableSuite
       "map_field map<int, long>," +
       "map_field_with_null map<int, long>) stored as %s".format(fileFormat)
 
-  def genTestData(): Seq[AllDataTypesWithComplextType] = {
-    (0 to 199).map {
-      i =>
-        if (i % 100 == 1) {
-          AllDataTypesWithComplextType()
-        } else {
-          AllDataTypesWithComplextType(
-            s"$i",
-            i,
-            i.toLong,
-            i.toFloat,
-            i.toDouble,
-            i.toShort,
-            i.toByte,
-            i % 2 == 0,
-            new java.math.BigDecimal(i + ".56"),
-            new java.sql.Date(System.currentTimeMillis()),
-            new Timestamp(System.currentTimeMillis()),
-            Seq.apply(i + 1, i + 2, i + 3),
-            Seq.apply(Option.apply(i + 1), Option.empty, Option.apply(i + 3)),
-            Map.apply((i + 1, i + 2), (i + 3, i + 4)),
-            Map.empty
-          )
-        }
-    }
-  }
-
   protected def initializeTable(
       table_name: String,
       table_create_sql: String,
       partitions: Seq[String]): Unit = {
-    spark.createDataFrame(genTestData()).createOrReplaceTempView("tmp_t")
+    spark
+      .createDataFrame(AllDataTypesWithComplexType.genTestData())
+      .createOrReplaceTempView("tmp_t")
     val truncate_sql = "truncate table %s".format(table_name)
     val drop_sql = "drop table if exists %s".format(table_name)
     spark.sql(drop_sql)
diff --git 
a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseNativeWriteTableSuite.scala
 
b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseNativeWriteTableSuite.scala
index 5e9c2d1fc..724fb2721 100644
--- 
a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseNativeWriteTableSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseNativeWriteTableSuite.scala
@@ -17,6 +17,7 @@
 package io.glutenproject.execution
 
 import io.glutenproject.GlutenConfig
+import io.glutenproject.execution.AllDataTypesWithComplexType.genTestData
 import io.glutenproject.utils.UTSystemParameters
 
 import org.apache.spark.SparkConf
@@ -27,8 +28,6 @@ import org.apache.spark.sql.test.SharedSparkSession
 
 import org.scalatest.BeforeAndAfterAll
 
-import java.sql.{Date, Timestamp}
-
 class GlutenClickHouseNativeWriteTableSuite
   extends GlutenClickHouseWholeStageTransformerSuite
   with AdaptiveSparkPlanHelper
@@ -90,41 +89,6 @@ class GlutenClickHouseNativeWriteTableSuite
   private val table_name_vanilla_template = "hive_%s_test_written_by_vanilla"
   private val formats = Array("orc", "parquet")
 
-  def genTestData(): Seq[AllDataTypesWithComplextType] = {
-    (0 to 199).map {
-      i =>
-        if (i % 100 == 1) {
-          AllDataTypesWithComplextType()
-        } else {
-          AllDataTypesWithComplextType(
-            s"$i",
-            i,
-            i.toLong,
-            i.toFloat,
-            i.toDouble,
-            i.toShort,
-            i.toByte,
-            i % 2 == 0,
-            new java.math.BigDecimal(i + ".56"),
-            Date.valueOf(new 
Date(System.currentTimeMillis()).toLocalDate.plusDays(i % 10)),
-            Timestamp.valueOf(
-              new 
Timestamp(System.currentTimeMillis()).toLocalDateTime.plusDays(i % 10)),
-            Seq.apply(i + 1, i + 2, i + 3),
-            Seq.apply(Option.apply(i + 1), Option.empty, Option.apply(i + 3)),
-            Map.apply((i + 1, i + 2), (i + 3, i + 4)),
-            Map.empty
-          )
-        }
-    }
-  }
-
-  protected def initializeTable(table_name: String, table_create_sql: String): 
Unit = {
-    spark.createDataFrame(genTestData()).createOrReplaceTempView("tmp_t")
-    spark.sql(s"drop table IF EXISTS $table_name")
-    spark.sql(table_create_sql)
-    spark.sql("insert into %s select * from tmp_t".format(table_name))
-  }
-
   override protected def afterAll(): Unit = {
     DeltaLog.clearCache()
 
diff --git 
a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseWholeStageTransformerSuite.scala
 
b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseWholeStageTransformerSuite.scala
index 28da546fc..1890d00e3 100644
--- 
a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseWholeStageTransformerSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseWholeStageTransformerSuite.scala
@@ -16,6 +16,12 @@
  */
 package io.glutenproject.execution
 
+import io.glutenproject.GlutenConfig
+import io.glutenproject.utils.UTSystemParameters
+
+import org.apache.spark.SparkConf
+import 
org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig
+
 import org.apache.commons.io.FileUtils
 
 import java.io.File
@@ -37,6 +43,15 @@ class GlutenClickHouseWholeStageTransformerSuite extends 
WholeStageTransformerSu
     }
   }
 
+  /**
+   * Base Spark configuration for the ClickHouse backend suites built on this class; subclasses
+   * add their own settings on top of it via `super.sparkConf`.
+   */
+  override protected def sparkConf: SparkConf = {
+    val conf = super.sparkConf
+    conf.set(GlutenConfig.GLUTEN_LIB_PATH, UTSystemParameters.getClickHouseLibPath())
+    conf.set(
+      "spark.gluten.sql.columnar.backend.ch.use.v2",
+      ClickHouseConfig.DEFAULT_USE_DATASOURCE_V2)
+    conf.set("spark.gluten.sql.enable.native.validation", "false")
+    conf.set("spark.sql.warehouse.dir", warehouse)
+  }
+
   override def beforeAll(): Unit = {
     // prepare working paths
     val basePathDir = new File(basePath)
@@ -56,6 +71,6 @@ class GlutenClickHouseWholeStageTransformerSuite extends 
WholeStageTransformerSu
   protected val hiveMetaStoreDB = metaStorePathAbsolute + "/metastore_db"
 
   override protected val backend: String = "ch"
-  override protected val resourcePath: String = ""
+  final override protected val resourcePath: String = "" // not needed by the ClickHouse backend
   override protected val fileFormat: String = "parquet"
 }
diff --git 
a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickhouseCountDistinctSuite.scala
 
b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickhouseCountDistinctSuite.scala
new file mode 100644
index 000000000..4ba9882b6
--- /dev/null
+++ 
b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickhouseCountDistinctSuite.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.execution
+
+import io.glutenproject.execution.AllDataTypesWithComplexType.genTestData
+
+import org.apache.spark.SparkConf
+
+/**
+ * Tests for running count distinct as a single `CountDistinct` aggregate (enabled through
+ * spark.gluten.sql.countDistinctWithoutExpand) instead of Expand + Count.
+ */
+class GlutenClickhouseCountDistinctSuite extends GlutenClickHouseWholeStageTransformerSuite {
+
+  override protected def sparkConf: SparkConf = {
+    super.sparkConf
+      .set("spark.gluten.sql.countDistinctWithoutExpand", "true")
+      .set("spark.sql.adaptive.enabled", "false")
+  }
+
+  test("check count distinct correctness") {
+    Seq(
+      // simple case
+      "select count(distinct(a))  from values (1,1,1), (2,2,2) as data(a,b,c)",
+      // with null
+      "select count(distinct(a))  from " +
+        "values (1,1,1), (2,2,2), (1,3,3), (null,4,4), (null,5,5) as data(a,b,c)",
+      // three CD
+      "select count(distinct(b)), count(distinct(a)),count(distinct c)  from " +
+        "values (0, null,1), (0,null,1), (1, 1,1), (2, 2, 1) ,(2,2,2) as data(a,b,c)",
+      // count distinct with multiple args
+      "select count(distinct(a,b)), count(distinct(a,b,c))  from " +
+        "values (0, null,1), (0,null,1), (1, 1,1), (2, 2, 1) ,(2,2,2) as data(a,b,c)"
+    ).foreach(sql => compareResultsAgainstVanillaSpark(sql, true, { _ => }))
+  }
+
+  test("check count distinct execution plan") {
+    val sql =
+      "select count(distinct(b)), count(distinct a, b)  from " +
+        "values (0, null,1), (1, 1,1), (2, 2,1), (1, 2,1) ,(2,2,2) as data(a,b,c) group by c"
+
+    val df = spark.sql(sql)
+    WholeStageTransformerSuite.checkFallBack(df)
+
+    val planExecs = df.queryExecution.executedPlan.collect {
+      case aggTransformer: HashAggregateExecBaseTransformer => aggTransformer
+    }
+    // The first collected aggregate is expected to be the final one, the second the partial one.
+    assert(planExecs.size >= 2, s"expected final and partial aggregates, got: $planExecs")
+
+    planExecs.head.aggregateExpressions.foreach {
+      expr => assert(expr.toString().startsWith("countdistinct"))
+    }
+    planExecs(1).aggregateExpressions.foreach {
+      expr => assert(expr.toString().startsWith("partial_countdistinct"))
+    }
+  }
+
+  test("check all data types") {
+    spark.createDataFrame(genTestData()).createOrReplaceTempView("all_data_types")
+
+    val (mapFields, otherFields) = classOf[AllDataTypesWithComplexType].getDeclaredFields
+      .map(_.getName)
+      .partition(_.startsWith("map"))
+
+    // Vanilla does not support map, so only non-map fields are compared against vanilla Spark.
+    otherFields.foreach {
+      name =>
+        val sql = s"select count(distinct($name)) from all_data_types"
+        compareResultsAgainstVanillaSpark(sql, true, { _ => })
+    }
+
+    // Map fields are only checked to run successfully.
+    mapFields.foreach {
+      name => spark.sql(s"select count(distinct($name)) from all_data_types").collect()
+    }
+  }
+
+  test("check count distinct with agg fallback") {
+    // skewness agg is not supported, will cause fallback. CountDistinct has no vanilla
+    // implementation, so the fallen-back aggregate fails instead of running on vanilla Spark.
+    val sql = "select count(distinct(a,b)) , skewness(b) from " +
+      "values (0, null,1), (0,null,1), (1, 1,1), (2, 2, 1) ,(2,2,2),(3,3,3) as data(a,b,c)"
+    assertThrows[UnsupportedOperationException] {
+      spark.sql(sql).show
+    }
+  }
+
+  test("check count distinct with expr fallback") {
+    // try_add is not supported, will cause fallback after a project operator
+    val sql = """
+      select count(distinct(a,b)) , try_add(c,b) from
+      values (0, null,1), (0,null,2), (1, 1,4) as data(a,b,c) group by try_add(c,b)
+      """
+    val df = spark.sql(sql)
+    WholeStageTransformerSuite.checkFallBack(df, noFallback = false)
+  }
+
+  test("check count distinct with filter") {
+    val sql = "select count(distinct(a,b)) FILTER (where c <3) from " +
+      "values (0, null,1), (0,null,1), (1, 1,1), (2, 2, 1) ,(2,2,2),(3,3,3) as data(a,b,c)"
+    compareResultsAgainstVanillaSpark(sql, true, { _ => })
+  }
+}
diff --git 
a/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala
 
b/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala
index 76ec0d7ad..9fbc2aa3b 100644
--- 
a/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala
+++ 
b/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala
@@ -16,7 +16,7 @@
  */
 package io.glutenproject.execution.extension
 
-import io.glutenproject.expression._
+import io.glutenproject.expression.Sig
 import io.glutenproject.extension.ExpressionExtensionTrait
 
 import org.apache.spark.sql.catalyst.expressions._
diff --git 
a/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionPartialMerge.h 
b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionPartialMerge.h
index bf2becf32..f77c253f5 100644
--- a/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionPartialMerge.h
+++ b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionPartialMerge.h
@@ -50,7 +50,8 @@ private:
 
 public:
     AggregateFunctionPartialMerge(const AggregateFunctionPtr & nested_, const 
DataTypePtr & argument, const Array & params_)
-        : IAggregateFunctionHelper<AggregateFunctionPartialMerge>({argument}, 
params_, createResultType(nested_)), nested_func(nested_)
+        : IAggregateFunctionHelper<AggregateFunctionPartialMerge>({argument}, 
params_, createResultType(nested_))
+        , nested_func(nested_)
     {
         const DataTypeAggregateFunction * data_type = typeid_cast<const 
DataTypeAggregateFunction *>(argument.get());
 
@@ -115,5 +116,4 @@ public:
 
     AggregateFunctionPtr getNestedFunction() const override { return 
nested_func; }
 };
-
 }
diff --git a/cpp-ch/local-engine/Parser/AggregateFunctionParser.h 
b/cpp-ch/local-engine/Parser/AggregateFunctionParser.h
index a9840eeef..e2444361f 100644
--- a/cpp-ch/local-engine/Parser/AggregateFunctionParser.h
+++ b/cpp-ch/local-engine/Parser/AggregateFunctionParser.h
@@ -29,7 +29,6 @@
 
 namespace local_engine
 {
-
 class AggregateFunctionParser
 {
 public:
@@ -78,7 +77,11 @@ public:
         }
     };
 
-    AggregateFunctionParser(SerializedPlanParser * plan_parser_) : 
plan_parser(plan_parser_) { }
+    /// Stores `plan_parser_` as a raw pointer; this class does not delete it (its destructor
+    /// is defaulted).
+    AggregateFunctionParser(SerializedPlanParser * plan_parser_)
+        : plan_parser(plan_parser_)
+    {
+    }
+
     virtual ~AggregateFunctionParser() = default;
 
     virtual String getName() const = 0;
@@ -93,6 +96,7 @@ public:
     /// Do some preprojections for the function arguments, and return the 
necessary arguments for the CH function.
     virtual DB::ActionsDAG::NodeRawConstPtrs
     parseFunctionArguments(const CommonFunctionInfo & func_info, const String 
& ch_func_name, DB::ActionsDAGPtr & actions_dag) const;
+
     DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments(const 
CommonFunctionInfo & func_info, DB::ActionsDAGPtr & actions_dag) const
     {
         return parseFunctionArguments(func_info, getCHFunctionName(func_info), 
actions_dag);
@@ -106,7 +110,9 @@ public:
 
     /// Make a postprojection for the function result.
     virtual const DB::ActionsDAG::Node * convertNodeTypeIfNeeded(
-        const CommonFunctionInfo & func_info, const DB::ActionsDAG::Node * 
func_node, DB::ActionsDAGPtr & actions_dag, bool withNullability) const;
+        const CommonFunctionInfo & func_info,
+        const DB::ActionsDAG::Node * func_node,
+        DB::ActionsDAGPtr & actions_dag, bool withNullability) const;
 
     /// Parameters are only used in aggregate functions at present. e.g. 
percentiles(0.5)(x).
     /// 0.5 is the parameter of percentiles function.
@@ -159,6 +165,7 @@ protected:
     SerializedPlanParser * plan_parser;
     Poco::Logger * logger = 
&Poco::Logger::get("AggregateFunctionParserFactory");
 };
+
 using AggregateFunctionParserPtr = std::shared_ptr<AggregateFunctionParser>;
 using AggregateFunctionParserCreator = 
std::function<AggregateFunctionParserPtr(SerializedPlanParser *)>;
 
@@ -200,5 +207,4 @@ struct AggregateFunctionParserRegister
 {
     AggregateFunctionParserRegister() { 
AggregateFunctionParserFactory::instance().registerAggregateFunctionParser<Parser>(Parser::name);
 }
 };
-
 }
diff --git 
a/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp
 
b/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp
index fc334a214..afe65f793 100644
--- 
a/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp
+++ 
b/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp
@@ -47,5 +47,5 @@ REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(RowNumber, 
row_number, row_number)
 REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(Ntile, ntile, ntile)
 REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(PercentRank, percent_rank, 
percent_rank)
 REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(CumeDist, cume_dist, cume_dist)
-
+// Spark's count_distinct aggregate is executed with ClickHouse's uniqExact.
+REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(CountDistinct, count_distinct, 
uniqExact)
 }
diff --git 
a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala
 
b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala
index b6a874756..a063c244b 100644
--- 
a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala
+++ 
b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala
@@ -251,6 +251,7 @@ object ExpressionMappings {
     Sig[Sum](SUM),
     Sig[Average](AVG),
     Sig[Count](COUNT),
+    Sig[CountDistinct](COUNT_DISTINCT),
     Sig[Min](MIN),
     Sig[Max](MAX),
     Sig[MaxBy](MAX_BY),
diff --git 
a/gluten-core/src/main/scala/io/glutenproject/extension/CountDistinctWithoutExpand.scala
 
b/gluten-core/src/main/scala/io/glutenproject/extension/CountDistinctWithoutExpand.scala
new file mode 100644
index 000000000..c7f57030f
--- /dev/null
+++ 
b/gluten-core/src/main/scala/io/glutenproject/extension/CountDistinctWithoutExpand.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.extension
+
+import io.glutenproject.GlutenConfig
+
+import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
Count, CountDistinct}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE_EXPRESSION
+
+/**
+ * Rewrites distinct `Count` aggregates (`count(DISTINCT ...)`) into the `CountDistinct`
+ * aggregate function with `isDistinct = false`, so that the planner does not introduce the Expand
+ * operator it normally uses for distinct aggregation. The aggregate's mode and any FILTER clause
+ * are kept unchanged; other distinct aggregates are left untouched.
+ *
+ * This rule takes no effect unless spark.gluten.enabled and
+ * spark.gluten.sql.countDistinctWithoutExpand are both true.
+ */
+object CountDistinctWithoutExpand extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    val glutenConf = GlutenConfig.getConf
+    if (glutenConf.enableGluten && glutenConf.enableCountDistinctWithoutExpand) {
+      plan.transformAllExpressionsWithPruning(_.containsPattern(AGGREGATE_EXPRESSION)) {
+        case ae: AggregateExpression if ae.isDistinct =>
+          ae.aggregateFunction match {
+            case count: Count =>
+              ae.copy(aggregateFunction = CountDistinct(count.children), isDistinct = false)
+            // Other distinct aggregate functions keep the default Expand-based planning.
+            case _ => ae
+          }
+      }
+    } else {
+      plan
+    }
+  }
+}
diff --git 
a/gluten-core/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountDistinct.scala
 
b/gluten-core/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountDistinct.scala
new file mode 100644
index 000000000..54918fcc7
--- /dev/null
+++ 
b/gluten-core/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountDistinct.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types._
+
+/**
+ * By default, Spark executes count distinct as Expand + Count. This works reliably but may be
+ * slower. This aggregate function lets the native backend execute count distinct directly; it is
+ * introduced by the optimizer rule CountDistinctWithoutExpand.
+ *
+ * It has no row-based (vanilla Spark) implementation: the buffer initialization, update, merge
+ * and evaluation expressions all throw UnsupportedOperationException, so an aggregate containing
+ * it can only be executed natively.
+ */
+case class CountDistinct(children: Seq[Expression]) extends DeclarativeAggregate {
+
+  override def nullable: Boolean = false
+
+  // Return data type.
+  override def dataType: DataType = LongType
+
+  protected lazy val cd: AttributeReference =
+    AttributeReference("count_distinct", LongType, nullable = false)()
+
+  override lazy val aggBufferAttributes: Seq[AttributeReference] = cd :: Nil
+
+  override lazy val initialValues: Seq[Expression] = unsupported
+
+  override lazy val updateExpressions: Seq[Expression] = unsupported
+
+  override lazy val mergeExpressions: Seq[Expression] = unsupported
+
+  override lazy val evaluateExpression: Expression = unsupported
+
+  override def defaultResult: Option[Literal] = Some(Literal(0L))
+
+  /** Shared failure for every non-columnar evaluation path of this function. */
+  private def unsupported: Nothing =
+    throw new UnsupportedOperationException(
+      "count distinct does not have non-columnar implementation")
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): CountDistinct =
+    copy(children = newChildren)
+}
diff --git a/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala 
b/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
index 87f5fc4f0..8a8bd4152 100644
--- a/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
+++ b/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
@@ -94,6 +94,9 @@ class GlutenConfig(conf: SQLConf) extends Logging {
   def enableCommonSubexpressionEliminate: Boolean =
     conf.getConf(ENABLE_COMMON_SUBEXPRESSION_ELIMINATE)
 
+  /** Value of spark.gluten.sql.countDistinctWithoutExpand (default false). */
+  def enableCountDistinctWithoutExpand: Boolean =
+    conf.getConf(ENABLE_COUNT_DISTINCT_WITHOUT_EXPAND)
+
   def veloxOrcScanEnabled: Boolean =
     conf.getConf(VELOX_ORC_SCAN_ENABLED)
 
@@ -1537,6 +1540,16 @@ object GlutenConfig {
       .booleanConf
       .createWithDefault(true)
 
+  // Consumed by the CountDistinctWithoutExpand optimizer rule.
+  val ENABLE_COUNT_DISTINCT_WITHOUT_EXPAND =
+    buildConf("spark.gluten.sql.countDistinctWithoutExpand")
+      .internal()
+      .doc(
+        "Convert Count Distinct to a UDAF called count_distinct to " +
+          "prevent SparkPlanner converting it to Expand+Count. WARNING: " +
+          "when enabled, count distinct cannot fall back to vanilla Spark; a query whose " +
+          "count distinct aggregation is not offloaded to Gluten will fail.")
+      .booleanConf
+      .createWithDefault(false)
+
   val COLUMNAR_VELOX_BLOOM_FILTER_EXPECTED_NUM_ITEMS =
     
buildConf("spark.gluten.sql.columnar.backend.velox.bloomFilter.expectedNumItems")
       .internal()
diff --git 
a/shims/common/src/main/scala/io/glutenproject/expression/ExpressionNames.scala 
b/shims/common/src/main/scala/io/glutenproject/expression/ExpressionNames.scala
index c688c8abd..d2fc4f9ec 100644
--- 
a/shims/common/src/main/scala/io/glutenproject/expression/ExpressionNames.scala
+++ 
b/shims/common/src/main/scala/io/glutenproject/expression/ExpressionNames.scala
@@ -22,6 +22,7 @@ object ExpressionNames {
   final val SUM = "sum"
   final val AVG = "avg"
   final val COUNT = "count"
+  final val COUNT_DISTINCT = "count_distinct"
   final val MIN = "min"
   final val MAX = "max"
   final val MAX_BY = "max_by"


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@gluten.apache.org
For additional commands, e-mail: commits-h...@gluten.apache.org

Reply via email to