This is an automated email from the ASF dual-hosted git repository.

mahongbin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
     new b07b36960  [Gluten-4706] Add a mode to execute count distinct directly instead of Expand+Count (#4708)
b07b36960 is described below

commit b07b36960238db62647b6aeefc930b7fc3f271c7
Author: Hongbin Ma <mahong...@apache.org>
AuthorDate: Tue Mar 12 16:38:47 2024 +0800

    [Gluten-4706] Add a mode to execute count distinct directly instead of Expand+Count (#4708)
---
 .../clickhouse/CHSparkPlanExecApi.scala            |   6 +-
 .../execution/CHHashAggregateExecTransformer.scala |   6 ++
 .../execution/GlutenClickHouseHiveTableSuite.scala |  71 +++++++------
 .../GlutenClickHouseNativeWriteTableSuite.scala    |  38 +------
 ...lutenClickHouseWholeStageTransformerSuite.scala |  17 ++-
 .../GlutenClickhouseCountDistinctSuite.scala       | 118 +++++++++++++++++++++
 .../extension/CustomAggExpressionTransformer.scala |   2 +-
 .../AggregateFunctionPartialMerge.h                |   4 +-
 .../local-engine/Parser/AggregateFunctionParser.h  |  14 ++-
 .../CommonAggregateFunctionParser.cpp              |   2 +-
 .../expression/ExpressionMappings.scala            |   1 +
 .../extension/CountDistinctWithoutExpand.scala     |  49 +++++++++
 .../expressions/aggregate/CountDistinct.scala      |  59 +++++++++++
 .../main/scala/io/glutenproject/GlutenConfig.scala |  13 +++
 .../glutenproject/expression/ExpressionNames.scala |   1 +
 15 files changed, 322 insertions(+), 79 deletions(-)

diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index 43d8ed1bc..f32116728 100644
--- a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -22,6 +22,7 @@ import io.glutenproject.execution._
 import io.glutenproject.expression._
 import io.glutenproject.expression.ConverterUtils.FunctionConfig
 import io.glutenproject.extension.{FallbackBroadcastHashJoin, FallbackBroadcastHashJoinPrepQueryStage}
+import io.glutenproject.extension.CountDistinctWithoutExpand
 import io.glutenproject.extension.columnar.AddTransformHintRule
 import io.glutenproject.extension.columnar.MiscColumnarRules.TransformPreOverrides
 import io.glutenproject.substrait.expression.{ExpressionBuilder, ExpressionNode, WindowFunctionNode}
@@ -503,7 +504,10 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
    * @return
    */
   override def genExtendedOptimizers(): List[SparkSession => Rule[LogicalPlan]] = {
-    List(spark => new CommonSubexpressionEliminateRule(spark, spark.sessionState.conf))
+    List(
+      spark => new CommonSubexpressionEliminateRule(spark, spark.sessionState.conf),
+      _ => CountDistinctWithoutExpand
+    )
   }

   /**
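The hunk above is how the backend injects the new rule: genExtendedOptimizers contributes
logical-optimizer rules to the session, and CountDistinctWithoutExpand is a stateless object,
so it ignores the SparkSession argument. From the user's side the mode looks roughly like this
(a usage sketch only; it assumes a Gluten-enabled session with the ClickHouse backend, and the
session bootstrap is illustrative, not taken from this commit):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession
      .builder()
      .master("local[2]")
      .config("spark.plugins", "io.glutenproject.GlutenPlugin") // assumed plugin class
      .config("spark.gluten.sql.countDistinctWithoutExpand", "true")
      .getOrCreate()

    // Two distinct aggregates normally force an Expand in vanilla Spark; with
    // the rewrite enabled the plan should instead show count_distinct
    // aggregations and no Expand node.
    val df = spark.sql(
      "select count(distinct a), count(distinct b) from values (1, 1), (1, 2), (2, 2) as t(a, b)")
    df.explain()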
diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala b/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala
index c6b99a6bc..85b967446 100644
--- a/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala
+++ b/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala
@@ -358,6 +358,12 @@ case class CHHashAggregateExecTransformer(
           (makeStructType(fields), attr.nullable)
         case expr if "bloom_filter_agg".equals(expr.prettyName) =>
           (makeStructTypeSingleOne(expr.children.head.dataType, attr.nullable), attr.nullable)
+        case cd: CountDistinct =>
+          var fields = Seq[(DataType, Boolean)]()
+          for (child <- cd.children) {
+            fields = fields :+ (child.dataType, child.nullable)
+          }
+          (makeStructType(fields), false)
         case _ =>
           (makeStructTypeSingleOne(attr.dataType, attr.nullable), attr.nullable)
       }
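The hunk above wires the new aggregate into the ClickHouse transformer's intermediate-result
typing: the merge buffer for CountDistinct is a struct with one field per argument, and the
final count itself is non-nullable. A minimal sketch of that derivation (field names are
illustrative; the real code delegates to makeStructType):

    import org.apache.spark.sql.types._

    def countDistinctBufferType(children: Seq[(DataType, Boolean)]): StructType =
      StructType(children.zipWithIndex.map {
        case ((dataType, nullable), i) => StructField(s"arg$i", dataType, nullable)
      })

    // count(distinct a, b) with a nullable int and a non-null string:
    val buffer = countDistinctBufferType(Seq((IntegerType, true), (StringType, false)))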
diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseHiveTableSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseHiveTableSuite.scala
index 938f8e6d1..b40a0fe0d 100644
--- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseHiveTableSuite.scala
+++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseHiveTableSuite.scala
@@ -19,8 +19,10 @@ package io.glutenproject.execution
 import io.glutenproject.GlutenConfig
 import io.glutenproject.utils.UTSystemParameters

-import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf}
-import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.SPARK_VERSION_SHORT
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.delta.DeltaLog
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.hive.HiveTableScanExecTransformer
@@ -29,11 +31,11 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.hadoop.fs.Path

 import java.io.{File, PrintWriter}
-import java.sql.Timestamp
+import java.sql.{Date, Timestamp}

 import scala.reflect.ClassTag

-case class AllDataTypesWithComplextType(
+case class AllDataTypesWithComplexType(
     string_field: String = null,
     int_field: java.lang.Integer = null,
     long_field: java.lang.Long = null,
@@ -51,6 +53,36 @@ case class AllDataTypesWithComplextType(
     mapValueContainsNull: Map[Int, Option[Long]] = null
 )

+object AllDataTypesWithComplexType {
+  def genTestData(): Seq[AllDataTypesWithComplexType] = {
+    (0 to 199).map {
+      i =>
+        if (i % 100 == 1) {
+          AllDataTypesWithComplexType()
+        } else {
+          AllDataTypesWithComplexType(
+            s"$i",
+            i,
+            i.toLong,
+            i.toFloat,
+            i.toDouble,
+            i.toShort,
+            i.toByte,
+            i % 2 == 0,
+            new java.math.BigDecimal(i + ".56"),
+            Date.valueOf(new Date(System.currentTimeMillis()).toLocalDate.plusDays(i % 10)),
+            Timestamp.valueOf(
+              new Timestamp(System.currentTimeMillis()).toLocalDateTime.plusDays(i % 10)),
+            Seq.apply(i + 1, i + 2, i + 3),
+            Seq.apply(Option.apply(i + 1), Option.empty, Option.apply(i + 3)),
+            Map.apply((i + 1, i + 2), (i + 3, i + 4)),
+            Map.empty
+          )
+        }
+    }
+  }
+}
+
 class GlutenClickHouseHiveTableSuite
   extends GlutenClickHouseWholeStageTransformerSuite
   with AdaptiveSparkPlanHelper {
@@ -170,38 +202,13 @@ class GlutenClickHouseHiveTableSuite
       "map_field map<int, long>," +
       "map_field_with_null map<int, long>) stored as %s".format(fileFormat)

-  def genTestData(): Seq[AllDataTypesWithComplextType] = {
-    (0 to 199).map {
-      i =>
-        if (i % 100 == 1) {
-          AllDataTypesWithComplextType()
-        } else {
-          AllDataTypesWithComplextType(
-            s"$i",
-            i,
-            i.toLong,
-            i.toFloat,
-            i.toDouble,
-            i.toShort,
-            i.toByte,
-            i % 2 == 0,
-            new java.math.BigDecimal(i + ".56"),
-            new java.sql.Date(System.currentTimeMillis()),
-            new Timestamp(System.currentTimeMillis()),
-            Seq.apply(i + 1, i + 2, i + 3),
-            Seq.apply(Option.apply(i + 1), Option.empty, Option.apply(i + 3)),
-            Map.apply((i + 1, i + 2), (i + 3, i + 4)),
-            Map.empty
-          )
-        }
-    }
-  }
-
   protected def initializeTable(
       table_name: String,
       table_create_sql: String,
       partitions: Seq[String]): Unit = {
-    spark.createDataFrame(genTestData()).createOrReplaceTempView("tmp_t")
+    spark
+      .createDataFrame(AllDataTypesWithComplexType.genTestData())
+      .createOrReplaceTempView("tmp_t")
     val truncate_sql = "truncate table %s".format(table_name)
     val drop_sql = "drop table if exists %s".format(table_name)
     spark.sql(drop_sql)
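genTestData is now shared through the new companion object above, and it derives dates and
timestamps by offsetting a base day with i % 10 instead of using the raw current time. That
gives distinct-count assertions a stable, known cardinality. A standalone sketch of the idiom:

    import java.sql.Date

    // 200 rows cycle through exactly 10 distinct dates (nulls aside, which the
    // generator produces separately for every i % 100 == 1 row).
    val base = new Date(System.currentTimeMillis()).toLocalDate
    val dates = (0 to 199).map(i => Date.valueOf(base.plusDays(i % 10)))
    assert(dates.distinct.size == 10)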
diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseNativeWriteTableSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseNativeWriteTableSuite.scala
index 5e9c2d1fc..724fb2721 100644
--- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseNativeWriteTableSuite.scala
+++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseNativeWriteTableSuite.scala
@@ -17,6 +17,7 @@
 package io.glutenproject.execution

 import io.glutenproject.GlutenConfig
+import io.glutenproject.execution.AllDataTypesWithComplexType.genTestData
 import io.glutenproject.utils.UTSystemParameters

 import org.apache.spark.SparkConf
@@ -27,8 +28,6 @@ import org.apache.spark.sql.test.SharedSparkSession

 import org.scalatest.BeforeAndAfterAll

-import java.sql.{Date, Timestamp}
-
 class GlutenClickHouseNativeWriteTableSuite
   extends GlutenClickHouseWholeStageTransformerSuite
   with AdaptiveSparkPlanHelper
@@ -90,41 +89,6 @@ class GlutenClickHouseNativeWriteTableSuite
   private val table_name_vanilla_template = "hive_%s_test_written_by_vanilla"
   private val formats = Array("orc", "parquet")

-  def genTestData(): Seq[AllDataTypesWithComplextType] = {
-    (0 to 199).map {
-      i =>
-        if (i % 100 == 1) {
-          AllDataTypesWithComplextType()
-        } else {
-          AllDataTypesWithComplextType(
-            s"$i",
-            i,
-            i.toLong,
-            i.toFloat,
-            i.toDouble,
-            i.toShort,
-            i.toByte,
-            i % 2 == 0,
-            new java.math.BigDecimal(i + ".56"),
-            Date.valueOf(new Date(System.currentTimeMillis()).toLocalDate.plusDays(i % 10)),
-            Timestamp.valueOf(
-              new Timestamp(System.currentTimeMillis()).toLocalDateTime.plusDays(i % 10)),
-            Seq.apply(i + 1, i + 2, i + 3),
-            Seq.apply(Option.apply(i + 1), Option.empty, Option.apply(i + 3)),
-            Map.apply((i + 1, i + 2), (i + 3, i + 4)),
-            Map.empty
-          )
-        }
-    }
-  }
-
-  protected def initializeTable(table_name: String, table_create_sql: String): Unit = {
-    spark.createDataFrame(genTestData()).createOrReplaceTempView("tmp_t")
-    spark.sql(s"drop table IF EXISTS $table_name")
-    spark.sql(table_create_sql)
-    spark.sql("insert into %s select * from tmp_t".format(table_name))
-  }
-
   override protected def afterAll(): Unit = {
     DeltaLog.clearCache()
diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseWholeStageTransformerSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseWholeStageTransformerSuite.scala
index 28da546fc..1890d00e3 100644
--- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseWholeStageTransformerSuite.scala
+++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseWholeStageTransformerSuite.scala
@@ -16,6 +16,12 @@
  */
 package io.glutenproject.execution

+import io.glutenproject.GlutenConfig
+import io.glutenproject.utils.UTSystemParameters
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig
+
 import org.apache.commons.io.FileUtils

 import java.io.File
@@ -37,6 +43,15 @@ class GlutenClickHouseWholeStageTransformerSuite extends WholeStageTransformerSu
     }
   }

+  override protected def sparkConf: SparkConf =
+    super.sparkConf
+      .set(GlutenConfig.GLUTEN_LIB_PATH, UTSystemParameters.getClickHouseLibPath())
+      .set(
+        "spark.gluten.sql.columnar.backend.ch.use.v2",
+        ClickHouseConfig.DEFAULT_USE_DATASOURCE_V2)
+      .set("spark.gluten.sql.enable.native.validation", "false")
+      .set("spark.sql.warehouse.dir", warehouse)
+
   override def beforeAll(): Unit = {
     // prepare working paths
     val basePathDir = new File(basePath)
@@ -56,6 +71,6 @@ class GlutenClickHouseWholeStageTransformerSuite extends WholeStageTransformerSu
   protected val hiveMetaStoreDB = metaStorePathAbsolute + "/metastore_db"

   override protected val backend: String = "ch"
-  override protected val resourcePath: String = ""
+  final override protected val resourcePath: String = "" // ch not need this
   override protected val fileFormat: String = "parquet"
 }
diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickhouseCountDistinctSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickhouseCountDistinctSuite.scala
new file mode 100644
index 000000000..4ba9882b6
--- /dev/null
+++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickhouseCountDistinctSuite.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.execution
+
+import io.glutenproject.execution.AllDataTypesWithComplexType.genTestData
+
+import org.apache.spark.SparkConf
+class GlutenClickhouseCountDistinctSuite extends GlutenClickHouseWholeStageTransformerSuite {
+
+  override protected def sparkConf: SparkConf = {
+    super.sparkConf
+      .set("spark.gluten.sql.countDistinctWithoutExpand", "true")
+      .set("spark.sql.adaptive.enabled", "false")
+  }
+
+  test("check count distinct correctness") {
+    // simple case
+    var sql = "select count(distinct(a)) from values (1,1,1), (2,2,2) as data(a,b,c)"
+    compareResultsAgainstVanillaSpark(sql, true, { _ => })
+
+    // with null
+    sql = "select count(distinct(a)) from " +
+      "values (1,1,1), (2,2,2), (1,3,3), (null,4,4), (null,5,5) as data(a,b,c)"
+    compareResultsAgainstVanillaSpark(sql, true, { _ => })
+
+    // three CD
+    sql = "select count(distinct(b)), count(distinct(a)),count(distinct c) from " +
+      "values (0, null,1), (0,null,1), (1, 1,1), (2, 2, 1) ,(2,2,2) as data(a,b,c)"
+    compareResultsAgainstVanillaSpark(sql, true, { _ => })
+
+    // count distinct with multiple args
+    sql = "select count(distinct(a,b)), count(distinct(a,b,c)) from " +
+      "values (0, null,1), (0,null,1), (1, 1,1), (2, 2, 1) ,(2,2,2) as data(a,b,c)"
+    compareResultsAgainstVanillaSpark(sql, true, { _ => })
+  }
+
+  test("check count distinct execution plan") {
+    val sql =
+      "select count(distinct(b)), count(distinct a, b) from " +
+        "values (0, null,1), (1, 1,1), (2, 2,1), (1, 2,1) ,(2,2,2) as data(a,b,c) group by c"
+
+    val df = spark.sql(sql)
+    WholeStageTransformerSuite.checkFallBack(df)
+
+    val planExecs = df.queryExecution.executedPlan.collect {
+      case aggTransformer: HashAggregateExecBaseTransformer => aggTransformer
+    }
+
+    planExecs.head.aggregateExpressions.foreach {
+      expr => assert(expr.toString().startsWith("countdistinct"))
+    }
+    planExecs(1).aggregateExpressions.foreach {
+      expr => assert(expr.toString().startsWith("partial_countdistinct"))
+    }
+  }
+
+  test("check all data types") {
+    spark.createDataFrame(genTestData()).createOrReplaceTempView("all_data_types")
+
+    // Vanilla does not support map
+    for (
+      field <- AllDataTypesWithComplexType().getClass.getDeclaredFields.filterNot(
+        p => p.getName.startsWith("map"))
+    ) {
+      val sql = s"select count(distinct(${field.getName})) from all_data_types"
+      compareResultsAgainstVanillaSpark(sql, true, { _ => })
+      spark.sql(sql).show
+    }
+
+    // just test success run
+    for (
+      field <- AllDataTypesWithComplexType().getClass.getDeclaredFields.filter(
+        p => p.getName.startsWith("map"))
+    ) {
+      val sql = s"select count(distinct(${field.getName})) from all_data_types"
+      spark.sql(sql).show
+    }
+  }
+
+  test("check count distinct with agg fallback") {
+    // skewness agg is not supported, will cause fallback
+    val sql = "select count(distinct(a,b)) , skewness(b) from " +
+      "values (0, null,1), (0,null,1), (1, 1,1), (2, 2, 1) ,(2,2,2),(3,3,3) as data(a,b,c)"
+    assertThrows[UnsupportedOperationException] {
+      spark.sql(sql).show
+    }
+  }
+
+  test("check count distinct with expr fallback") {
+    // try_add is not supported, will cause fallback after a project operator
+    val sql = s"""
+      select count(distinct(a,b)) , try_add(c,b) from
+      values (0, null,1), (0,null,2), (1, 1,4) as data(a,b,c) group by try_add(c,b)
+      """;
+    val df = spark.sql(sql)
+    WholeStageTransformerSuite.checkFallBack(df, noFallback = false)
+  }
+
+  test("check count distinct with filter") {
+    val sql = "select count(distinct(a,b)) FILTER (where c <3) from " +
+      "values (0, null,1), (0,null,1), (1, 1,1), (2, 2, 1) ,(2,2,2),(3,3,3) as data(a,b,c)"
+    compareResultsAgainstVanillaSpark(sql, true, { _ => })
+  }
+}
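The fallback tests above illustrate the warning attached to the new config: once the rewrite
has replaced Count with the CountDistinct UDAF, the query can no longer fall back to vanilla
Spark, because the UDAF deliberately has no row-based implementation. A defensive usage sketch
(assuming the conf is runtime-settable, as SQLConf-backed Gluten options are by default) is to
scope the flag to workloads known to stay on the native path:

    spark.conf.set("spark.gluten.sql.countDistinctWithoutExpand", "true")
    try {
      spark
        .sql("select count(distinct a, b) from values (1, 2), (1, 2) as t(a, b)")
        .show()
    } finally {
      spark.conf.set("spark.gluten.sql.countDistinctWithoutExpand", "false")
    }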
diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala
index 76ec0d7ad..9fbc2aa3b 100644
--- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala
+++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala
@@ -16,7 +16,7 @@
  */
 package io.glutenproject.execution.extension

-import io.glutenproject.expression._
+import io.glutenproject.expression.Sig
 import io.glutenproject.extension.ExpressionExtensionTrait

 import org.apache.spark.sql.catalyst.expressions._
diff --git a/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionPartialMerge.h b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionPartialMerge.h
index bf2becf32..f77c253f5 100644
--- a/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionPartialMerge.h
+++ b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionPartialMerge.h
@@ -50,7 +50,8 @@ private:
 public:
     AggregateFunctionPartialMerge(const AggregateFunctionPtr & nested_, const DataTypePtr & argument, const Array & params_)
-        : IAggregateFunctionHelper<AggregateFunctionPartialMerge>({argument}, params_, createResultType(nested_)), nested_func(nested_)
+        : IAggregateFunctionHelper<AggregateFunctionPartialMerge>({argument}, params_, createResultType(nested_))
+        , nested_func(nested_)
     {
         const DataTypeAggregateFunction * data_type = typeid_cast<const DataTypeAggregateFunction *>(argument.get());
@@ -115,5 +116,4 @@ public:
     AggregateFunctionPtr getNestedFunction() const override { return nested_func; }
 };
-
 }
diff --git a/cpp-ch/local-engine/Parser/AggregateFunctionParser.h b/cpp-ch/local-engine/Parser/AggregateFunctionParser.h
index a9840eeef..e2444361f 100644
--- a/cpp-ch/local-engine/Parser/AggregateFunctionParser.h
+++ b/cpp-ch/local-engine/Parser/AggregateFunctionParser.h
@@ -29,7 +29,6 @@
 namespace local_engine
 {
-
 class AggregateFunctionParser
 {
 public:
@@ -78,7 +77,11 @@ public:
         }
     };

-    AggregateFunctionParser(SerializedPlanParser * plan_parser_) : plan_parser(plan_parser_) { }
+    AggregateFunctionParser(SerializedPlanParser * plan_parser_)
+        : plan_parser(plan_parser_)
+    {
+    }
+
     virtual ~AggregateFunctionParser() = default;

     virtual String getName() const = 0;
@@ -93,6 +96,7 @@ public:
     /// Do some preprojections for the function arguments, and return the necessary arguments for the CH function.
     virtual DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments(const CommonFunctionInfo & func_info, const String & ch_func_name, DB::ActionsDAGPtr & actions_dag) const;
+
     DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments(const CommonFunctionInfo & func_info, DB::ActionsDAGPtr & actions_dag) const
     {
         return parseFunctionArguments(func_info, getCHFunctionName(func_info), actions_dag);
     }
@@ -106,7 +110,9 @@ public:
     /// Make a postprojection for the function result.
     virtual const DB::ActionsDAG::Node * convertNodeTypeIfNeeded(
-        const CommonFunctionInfo & func_info, const DB::ActionsDAG::Node * func_node, DB::ActionsDAGPtr & actions_dag, bool withNullability) const;
+        const CommonFunctionInfo & func_info,
+        const DB::ActionsDAG::Node * func_node,
+        DB::ActionsDAGPtr & actions_dag, bool withNullability) const;

     /// Parameters are only used in aggregate functions at present. e.g. percentiles(0.5)(x).
     /// 0.5 is the parameter of percentiles function.
@@ -159,6 +165,7 @@ protected:
     SerializedPlanParser * plan_parser;
     Poco::Logger * logger = &Poco::Logger::get("AggregateFunctionParserFactory");
 };
+
 using AggregateFunctionParserPtr = std::shared_ptr<AggregateFunctionParser>;
 using AggregateFunctionParserCreator = std::function<AggregateFunctionParserPtr(SerializedPlanParser *)>;

@@ -200,5 +207,4 @@ struct AggregateFunctionParserRegister
 {
     AggregateFunctionParserRegister() { AggregateFunctionParserFactory::instance().registerAggregateFunctionParser<Parser>(Parser::name); }
 };
-
 }
diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp b/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp
index fc334a214..afe65f793 100644
--- a/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp
+++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp
@@ -47,5 +47,5 @@ REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(RowNumber, row_number, row_number)
 REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(Ntile, ntile, ntile)
 REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(PercentRank, percent_rank, percent_rank)
 REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(CumeDist, cume_dist, cume_dist)
-
+REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(CountDistinct, count_distinct, uniqExact)
 }
diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala
index b6a874756..a063c244b 100644
--- a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala
@@ -251,6 +251,7 @@ object ExpressionMappings {
     Sig[Sum](SUM),
     Sig[Average](AVG),
     Sig[Count](COUNT),
+    Sig[CountDistinct](COUNT_DISTINCT),
     Sig[Min](MIN),
     Sig[Max](MAX),
     Sig[MaxBy](MAX_BY),
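This one-line mapping, together with the C++ registration above, completes the chain from
Catalyst to ClickHouse: the CountDistinct expression is advertised to Substrait under the name
count_distinct, which the ClickHouse backend resolves to its native uniqExact aggregate. The
mapping entry shown standalone (exactly as the diff adds it):

    import io.glutenproject.expression.Sig
    import org.apache.spark.sql.catalyst.expressions.aggregate.CountDistinct

    // ExpressionNames.COUNT_DISTINCT = "count_distinct"
    val entry = Sig[CountDistinct]("count_distinct")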
diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/CountDistinctWithoutExpand.scala b/gluten-core/src/main/scala/io/glutenproject/extension/CountDistinctWithoutExpand.scala
new file mode 100644
index 000000000..c7f57030f
--- /dev/null
+++ b/gluten-core/src/main/scala/io/glutenproject/extension/CountDistinctWithoutExpand.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.extension
+
+import io.glutenproject.GlutenConfig
+
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Count, CountDistinct}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE_EXPRESSION
+
+/**
+ * By converting Count(with isDistinct=true) to a UDAF called CountDistinct, we can avoid the Expand
+ * operator in the physical plan.
+ *
+ * This rule takes no effect unless spark.gluten.enabled and
+ * spark.gluten.sql.countDistinctWithoutExpand are both true
+ */
+object CountDistinctWithoutExpand extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    if (
+      GlutenConfig.getConf.enableGluten && GlutenConfig.getConf.enableCountDistinctWithoutExpand
+    ) {
+      plan.transformAllExpressionsWithPruning(_.containsPattern(AGGREGATE_EXPRESSION)) {
+        case ae: AggregateExpression if ae.isDistinct && ae.aggregateFunction.isInstanceOf[Count] =>
+          ae.copy(
+            aggregateFunction =
+              CountDistinct.apply(ae.aggregateFunction.asInstanceOf[Count].children),
+            isDistinct = false)
+      }
+    } else {
+      plan
+    }
+  }
+}
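The rule's effect on a single aggregate expression, as a worked sketch (a Literal stands in
for a real column reference; the copy mirrors the rule body above):

    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.expressions.aggregate.{Count, CountDistinct}

    // before: count(DISTINCT x) with isDistinct = true, planned as Expand + Count
    val before = Count(Seq(Literal(1))).toAggregateExpression(isDistinct = true)

    // after: count_distinct(x) with isDistinct = false, so no Expand is planned
    val after = before.copy(
      aggregateFunction = CountDistinct(before.aggregateFunction.children),
      isDistinct = false)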
diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountDistinct.scala b/gluten-core/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountDistinct.scala
new file mode 100644
index 000000000..54918fcc7
--- /dev/null
+++ b/gluten-core/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountDistinct.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types._
+
+/**
+ * By default, spark execute count distinct as Expand + Count. This works reliably but may be
+ * slower. We allow user to inject a custom count distinct function to speed up the execution. Check
+ * the optimizer rule at CountDistinctWithoutExpand
+ */
+case class CountDistinct(children: Seq[Expression]) extends DeclarativeAggregate {
+
+  override def nullable: Boolean = false
+
+  // Return data type.
+  override def dataType: DataType = LongType
+
+  protected lazy val cd = AttributeReference("count_distinct", LongType, nullable = false)()
+
+  override lazy val aggBufferAttributes = cd :: Nil
+
+  override lazy val initialValues =
+    throw new UnsupportedOperationException(
+      "count distinct does not have non-columnar implementation")
+
+  override lazy val mergeExpressions =
+    throw new UnsupportedOperationException(
+      "count distinct does not have non-columnar implementation")
+
+  override lazy val evaluateExpression =
+    throw new UnsupportedOperationException(
+      "count distinct does not have non-columnar implementation")
+
+  override def defaultResult: Option[Literal] = Option(Literal(0L))
+
+  override lazy val updateExpressions =
+    throw new UnsupportedOperationException(
+      "count distinct does not have non-columnar implementation")
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): CountDistinct =
+    copy(children = newChildren)
+}
diff --git a/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala b/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
index 87f5fc4f0..8a8bd4152 100644
--- a/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
+++ b/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
@@ -94,6 +94,9 @@ class GlutenConfig(conf: SQLConf) extends Logging {
   def enableCommonSubexpressionEliminate: Boolean =
     conf.getConf(ENABLE_COMMON_SUBEXPRESSION_ELIMINATE)

+  def enableCountDistinctWithoutExpand: Boolean =
+    conf.getConf(ENABLE_COUNT_DISTINCT_WITHOUT_EXPAND)
+
   def veloxOrcScanEnabled: Boolean =
     conf.getConf(VELOX_ORC_SCAN_ENABLED)

@@ -1537,6 +1540,16 @@
       .booleanConf
       .createWithDefault(true)

+  val ENABLE_COUNT_DISTINCT_WITHOUT_EXPAND =
+    buildConf("spark.gluten.sql.countDistinctWithoutExpand")
+      .internal()
+      .doc(
+        "Convert Count Distinct to a UDAF called count_distinct to " +
+          "prevent SparkPlanner converting it to Expand+Count. WARNING: " +
+          "When enabled, count distinct queries will fail to fallback!!!")
+      .booleanConf
+      .createWithDefault(false)
+
   val COLUMNAR_VELOX_BLOOM_FILTER_EXPECTED_NUM_ITEMS =
     buildConf("spark.gluten.sql.columnar.backend.velox.bloomFilter.expectedNumItems")
       .internal()
diff --git a/shims/common/src/main/scala/io/glutenproject/expression/ExpressionNames.scala b/shims/common/src/main/scala/io/glutenproject/expression/ExpressionNames.scala
index c688c8abd..d2fc4f9ec 100644
--- a/shims/common/src/main/scala/io/glutenproject/expression/ExpressionNames.scala
+++ b/shims/common/src/main/scala/io/glutenproject/expression/ExpressionNames.scala
@@ -22,6 +22,7 @@ object ExpressionNames {
   final val SUM = "sum"
   final val AVG = "avg"
   final val COUNT = "count"
+  final val COUNT_DISTINCT = "count_distinct"
   final val MIN = "min"
   final val MAX = "max"
   final val MAX_BY = "max_by"

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@gluten.apache.org
For additional commands, e-mail: commits-h...@gluten.apache.org