This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 0e7c26bb6 perf: Use DataFusion's `count_udaf` instead of `SUM(IF(expr IS NOT NULL, 1, 0))` (#2407)
0e7c26bb6 is described below

commit 0e7c26bb6d6535bbfaa3b2cf26f4bf5f95a95ed0
Author: Andy Grove <[email protected]>
AuthorDate: Tue Sep 16 17:38:45 2025 -0600

    perf: Use DataFusion's `count_udaf` instead of `SUM(IF(expr IS NOT NULL, 1, 0))` (#2407)
---
 .gitignore                                         |  1 +
 native/core/src/execution/planner.rs               | 25 ++--------------
 .../sql/benchmark/CometAggregateBenchmark.scala    | 34 +++-------------------
 3 files changed, 7 insertions(+), 53 deletions(-)

diff --git a/.gitignore b/.gitignore
index 4157bf6f2..94877ced7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -17,3 +17,4 @@ dev/dist
 apache-rat-*.jar
 venv
 dev/release/comet-rm/workdir
+spark/benchmarks
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 6051a459e..0e832599d 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -30,6 +30,7 @@ use crate::{
 use arrow::compute::CastOptions;
 use arrow::datatypes::{DataType, Field, Schema, TimeUnit, DECIMAL128_MAX_PRECISION};
 use datafusion::functions_aggregate::bit_and_or_xor::{bit_and_udaf, bit_or_udaf, bit_xor_udaf};
+use datafusion::functions_aggregate::count::count_udaf;
 use datafusion::functions_aggregate::min_max::max_udaf;
 use datafusion::functions_aggregate::min_max::min_udaf;
 use datafusion::functions_aggregate::sum::sum_udaf;
@@ -1904,35 +1905,13 @@ impl PhysicalPlanner {
         match spark_expr.expr_struct.as_ref().unwrap() {
             AggExprStruct::Count(expr) => {
                 assert!(!expr.children.is_empty());
-                // Using `count_udaf` from Comet is exceptionally slow for some reason, so
-                // as a workaround we translate it to `SUM(IF(expr IS NOT NULL, 1, 0))`
-                // https://github.com/apache/datafusion-comet/issues/744
-
                 let children = expr
                     .children
                     .iter()
                     .map(|child| self.create_expr(child, Arc::clone(&schema)))
                     .collect::<Result<Vec<_>, _>>()?;
 
-                // create `IS NOT NULL expr` and join them with `AND` if there are multiple
-                let not_null_expr: Arc<dyn PhysicalExpr> = children.iter().skip(1).fold(
-                    Arc::new(IsNotNullExpr::new(Arc::clone(&children[0]))) as Arc<dyn PhysicalExpr>,
-                    |acc, child| {
-                        Arc::new(BinaryExpr::new(
-                            acc,
-                            DataFusionOperator::And,
-                            Arc::new(IsNotNullExpr::new(Arc::clone(child))),
-                        ))
-                    },
-                );
-
-                let child = Arc::new(IfExpr::new(
-                    not_null_expr,
-                    Arc::new(Literal::new(ScalarValue::Int64(Some(1)))),
-                    Arc::new(Literal::new(ScalarValue::Int64(Some(0)))),
-                ));
-
-                AggregateExprBuilder::new(sum_udaf(), vec![child])
+                AggregateExprBuilder::new(count_udaf(), children)
                     .schema(schema)
                     .alias("count")
                     .with_ignore_nulls(false)
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala
index 86b59050e..47fbe354f 100644
--- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala
@@ -64,13 +64,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
           spark.sql(query).noop()
         }
 
-        benchmark.addCase(s"SQL Parquet - Comet (Scan) ($aggregateFunction)") { _ =>
-          withSQLConf(CometConf.COMET_ENABLED.key -> "true") {
-            spark.sql(query).noop()
-          }
-        }
-
-        benchmark.addCase(s"SQL Parquet - Comet (Scan, Exec) ($aggregateFunction)") { _ =>
+        benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
           withSQLConf(
             CometConf.COMET_ENABLED.key -> "true",
             CometConf.COMET_EXEC_ENABLED.key -> "true") {
@@ -111,13 +105,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
           spark.sql(query).noop()
         }
 
-        benchmark.addCase(s"SQL Parquet - Comet (Scan) ($aggregateFunction)") { _ =>
-          withSQLConf(CometConf.COMET_ENABLED.key -> "true") {
-            spark.sql(query).noop()
-          }
-        }
-
-        benchmark.addCase(s"SQL Parquet - Comet (Scan, Exec) ($aggregateFunction)") { _ =>
+        benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
           withSQLConf(
             CometConf.COMET_ENABLED.key -> "true",
             CometConf.COMET_EXEC_ENABLED.key -> "true") {
@@ -153,15 +141,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
           spark.sql(query).noop()
         }
 
-        benchmark.addCase(s"SQL Parquet - Comet (Scan) ($aggregateFunction)") { _ =>
-          withSQLConf(
-            CometConf.COMET_ENABLED.key -> "true",
-            CometConf.COMET_MEMORY_OVERHEAD.key -> "1G") {
-            spark.sql(query).noop()
-          }
-        }
-
-        benchmark.addCase(s"SQL Parquet - Comet (Scan, Exec) ($aggregateFunction)") { _ =>
+        benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
           withSQLConf(
             CometConf.COMET_ENABLED.key -> "true",
             CometConf.COMET_EXEC_ENABLED.key -> "true",
@@ -198,13 +178,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
           spark.sql(query).noop()
         }
 
-        benchmark.addCase(s"SQL Parquet - Comet (Scan) ($aggregateFunction)") { _ =>
-          withSQLConf(CometConf.COMET_ENABLED.key -> "true") {
-            spark.sql(query).noop()
-          }
-        }
-
-        benchmark.addCase(s"SQL Parquet - Comet (Scan, Exec) ($aggregateFunction)") { _ =>
+        benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
           withSQLConf(
             CometConf.COMET_ENABLED.key -> "true",
             CometConf.COMET_EXEC_ENABLED.key -> "true") {
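
For reference, a minimal standalone sketch (not part of this commit) of how a physical COUNT aggregate is assembled from DataFusion's `count_udaf` via `AggregateExprBuilder`, mirroring what planner.rs now does instead of composing IsNotNullExpr/IfExpr and summing. The standalone `main`, the single column `a`, and the exact module paths are illustrative assumptions and may differ between DataFusion versions.

// Illustrative sketch only, not Comet source code.
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion::error::Result;
use datafusion::functions_aggregate::count::count_udaf;
// Note: the builder's module path has moved between DataFusion releases.
use datafusion::physical_expr::aggregate::AggregateExprBuilder;
use datafusion::physical_expr::expressions::col;

fn main() -> Result<()> {
    // A single nullable column; COUNT(a) counts only its non-null values.
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));

    // Build the COUNT aggregate directly from the UDAF, as the new planner
    // code does, rather than rewriting it to SUM(IF(a IS NOT NULL, 1, 0)).
    let _count = AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema)?])
        .schema(Arc::clone(&schema))
        .alias("count")
        .with_ignore_nulls(false)
        .build()?;

    Ok(())
}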

