This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 49bf503  feat: Support Variance (#297)
49bf503 is described below

commit 49bf503780dfc93ede968b626214d53fb0953c34
Author: Huaxin Gao <[email protected]>
AuthorDate: Thu Apr 25 11:12:22 2024 -0700

    feat: Support Variance (#297)
    
    * feat: Support Variance
    
    * Add StatisticsType in expr.proto
    
    * add explainPlan info and fix fmt
    
    * remove unnecessary cast
    
    * remove unused import
    
    ---------
    
    Co-authored-by: Huaxin Gao <[email protected]>
---
 EXPRESSIONS.md                                     |   5 +-
 core/src/execution/datafusion/expressions/mod.rs   |   1 +
 .../execution/datafusion/expressions/variance.rs   | 256 +++++++++++++++++++++
 core/src/execution/datafusion/planner.rs           |  25 ++
 core/src/execution/proto/expr.proto                |  13 ++
 .../org/apache/comet/serde/QueryPlanSerde.scala    |  42 +++-
 .../apache/comet/exec/CometAggregateSuite.scala    |  52 +++++
 7 files changed, 392 insertions(+), 2 deletions(-)

diff --git a/EXPRESSIONS.md b/EXPRESSIONS.md
index 45c3684..f0a2f69 100644
--- a/EXPRESSIONS.md
+++ b/EXPRESSIONS.md
@@ -103,4 +103,7 @@ The following Spark expressions are currently available:
     + BitXor
     + BoolAnd
     + BoolOr
-    + Covariance
+    + CovPopulation
+    + CovSample
+    + VariancePop
+    + VarianceSamp
diff --git a/core/src/execution/datafusion/expressions/mod.rs b/core/src/execution/datafusion/expressions/mod.rs
index 799790c..78763fc 100644
--- a/core/src/execution/datafusion/expressions/mod.rs
+++ b/core/src/execution/datafusion/expressions/mod.rs
@@ -34,3 +34,4 @@ pub mod subquery;
 pub mod sum_decimal;
 pub mod temporal;
 mod utils;
+pub mod variance;
diff --git a/core/src/execution/datafusion/expressions/variance.rs b/core/src/execution/datafusion/expressions/variance.rs
new file mode 100644
index 0000000..6aae01e
--- /dev/null
+++ b/core/src/execution/datafusion/expressions/variance.rs
@@ -0,0 +1,256 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines physical expressions that can be evaluated at runtime during query execution
+
+use std::{any::Any, sync::Arc};
+
+use crate::execution::datafusion::expressions::{stats::StatsType, utils::down_cast_any_ref};
+use arrow::{
+    array::{ArrayRef, Float64Array},
+    datatypes::{DataType, Field},
+};
+use datafusion::logical_expr::Accumulator;
+use datafusion_common::{downcast_value, DataFusionError, Result, ScalarValue};
+use datafusion_physical_expr::{expressions::format_state_name, AggregateExpr, PhysicalExpr};
+
+/// VAR_SAMP and VAR_POP aggregate expression
+/// The implementation is mostly the same as DataFusion's. The reason we have our own
+/// implementation is that DataFusion uses UInt64 for the `count` state field, while
+/// Spark uses Double. We also added `null_on_divide_by_zero` to be consistent with
+/// Spark's implementation.
+#[derive(Debug)]
+pub struct Variance {
+    name: String,
+    expr: Arc<dyn PhysicalExpr>,
+    stats_type: StatsType,
+    null_on_divide_by_zero: bool,
+}
+
+impl Variance {
+    /// Create a new VARIANCE aggregate function
+    pub fn new(
+        expr: Arc<dyn PhysicalExpr>,
+        name: impl Into<String>,
+        data_type: DataType,
+        stats_type: StatsType,
+        null_on_divide_by_zero: bool,
+    ) -> Self {
+        // the result of variance only supports the FLOAT64 data type.
+        assert!(matches!(data_type, DataType::Float64));
+        Self {
+            name: name.into(),
+            expr,
+            stats_type,
+            null_on_divide_by_zero,
+        }
+    }
+}
+
+impl AggregateExpr for Variance {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn field(&self) -> Result<Field> {
+        Ok(Field::new(&self.name, DataType::Float64, true))
+    }
+
+    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        Ok(Box::new(VarianceAccumulator::try_new(
+            self.stats_type,
+            self.null_on_divide_by_zero,
+        )?))
+    }
+
+    fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        Ok(Box::new(VarianceAccumulator::try_new(
+            self.stats_type,
+            self.null_on_divide_by_zero,
+        )?))
+    }
+
+    fn state_fields(&self) -> Result<Vec<Field>> {
+        Ok(vec![
+            Field::new(
+                format_state_name(&self.name, "count"),
+                DataType::Float64,
+                true,
+            ),
+            Field::new(
+                format_state_name(&self.name, "mean"),
+                DataType::Float64,
+                true,
+            ),
+            Field::new(format_state_name(&self.name, "m2"), DataType::Float64, true),
+        ])
+    }
+
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.expr.clone()]
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+}
+
+impl PartialEq<dyn Any> for Variance {
+    fn eq(&self, other: &dyn Any) -> bool {
+        down_cast_any_ref(other)
+            .downcast_ref::<Self>()
+            .map(|x| {
+                self.name == x.name && self.expr.eq(&x.expr) && self.stats_type == x.stats_type
+            })
+            .unwrap_or(false)
+    }
+}
+
+/// An accumulator to compute variance
+#[derive(Debug)]
+pub struct VarianceAccumulator {
+    m2: f64,
+    mean: f64,
+    count: f64,
+    stats_type: StatsType,
+    null_on_divide_by_zero: bool,
+}
+
+impl VarianceAccumulator {
+    /// Creates a new `VarianceAccumulator`
+    pub fn try_new(s_type: StatsType, null_on_divide_by_zero: bool) -> Result<Self> {
+        Ok(Self {
+            m2: 0_f64,
+            mean: 0_f64,
+            count: 0_f64,
+            stats_type: s_type,
+            null_on_divide_by_zero,
+        })
+    }
+
+    pub fn get_count(&self) -> f64 {
+        self.count
+    }
+
+    pub fn get_mean(&self) -> f64 {
+        self.mean
+    }
+
+    pub fn get_m2(&self) -> f64 {
+        self.m2
+    }
+}
+
+impl Accumulator for VarianceAccumulator {
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        Ok(vec![
+            ScalarValue::from(self.count),
+            ScalarValue::from(self.mean),
+            ScalarValue::from(self.m2),
+        ])
+    }
+
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let arr = downcast_value!(&values[0], Float64Array).iter().flatten();
+
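+        // Welford's online algorithm: for each non-null value, update the running
+        // count, mean, and M2 (the sum of squared differences from the mean).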
+        for value in arr {
+            let new_count = self.count + 1.0;
+            let delta1 = value - self.mean;
+            let new_mean = delta1 / new_count + self.mean;
+            let delta2 = value - new_mean;
+            let new_m2 = self.m2 + delta1 * delta2;
+
+            self.count += 1.0;
+            self.mean = new_mean;
+            self.m2 = new_m2;
+        }
+
+        Ok(())
+    }
+
+    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let arr = downcast_value!(&values[0], Float64Array).iter().flatten();
+
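+        // Reverse of the Welford update above; used by the sliding-window
+        // accumulator to remove values that have left the window.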
+        for value in arr {
+            let new_count = self.count - 1.0;
+            let delta1 = self.mean - value;
+            let new_mean = delta1 / new_count + self.mean;
+            let delta2 = new_mean - value;
+            let new_m2 = self.m2 - delta1 * delta2;
+
+            self.count -= 1.0;
+            self.mean = new_mean;
+            self.m2 = new_m2;
+        }
+
+        Ok(())
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        let counts = downcast_value!(states[0], Float64Array);
+        let means = downcast_value!(states[1], Float64Array);
+        let m2s = downcast_value!(states[2], Float64Array);
+
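+        // Combine partial (count, mean, m2) states pairwise using the standard
+        // parallel variance merge: M2 gains a cross term delta^2 * n1 * n2 / (n1 + n2).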
+        for i in 0..counts.len() {
+            let c = counts.value(i);
+            if c == 0_f64 {
+                continue;
+            }
+            let new_count = self.count + c;
+            let new_mean = self.mean * self.count / new_count + means.value(i) * c / new_count;
+            let delta = self.mean - means.value(i);
+            let new_m2 = self.m2 + m2s.value(i) + delta * delta * self.count * c / new_count;
+
+            self.count = new_count;
+            self.mean = new_mean;
+            self.m2 = new_m2;
+        }
+        Ok(())
+    }
+
+    fn evaluate(&mut self) -> Result<ScalarValue> {
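+        // Population variance divides M2 by count, sample variance by count - 1.
+        // Empty input yields NULL; a single row yields 0.0 for population and,
+        // for sample, NULL when null_on_divide_by_zero is set, otherwise NaN.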
+        let count = match self.stats_type {
+            StatsType::Population => self.count,
+            StatsType::Sample => {
+                if self.count > 0.0 {
+                    self.count - 1.0
+                } else {
+                    self.count
+                }
+            }
+        };
+
+        Ok(ScalarValue::Float64(match self.count {
+            count if count == 0.0 => None,
+            count if count == 1.0 => {
+                if let StatsType::Population = self.stats_type {
+                    Some(0.0)
+                } else if self.null_on_divide_by_zero {
+                    None
+                } else {
+                    Some(f64::NAN)
+                }
+            }
+            _ => Some(self.m2 / count),
+        }))
+    }
+
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self)
+    }
+}
diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index 5c379d4..7217479 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -75,6 +75,7 @@ use crate::{
                 subquery::Subquery,
                 sum_decimal::SumDecimal,
                 temporal::{DateTruncExec, HourExec, MinuteExec, SecondExec, TimestampTruncExec},
+                variance::Variance,
                 NormalizeNaNAndZero,
             },
             operators::expand::CometExpandExec,
@@ -1235,6 +1236,30 @@ impl PhysicalPlanner {
                     StatsType::Population,
                 )))
             }
+            AggExprStruct::Variance(expr) => {
+                let child = self.create_expr(expr.child.as_ref().unwrap(), schema.clone())?;
+                let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+                match expr.stats_type {
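+                    // stats_type values follow the StatisticsType enum in expr.proto:
+                    // 0 = SAMPLE, 1 = POPULATION.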
+                    0 => Ok(Arc::new(Variance::new(
+                        child,
+                        "variance",
+                        datatype,
+                        StatsType::Sample,
+                        expr.null_on_divide_by_zero,
+                    ))),
+                    1 => Ok(Arc::new(Variance::new(
+                        child,
+                        "variance_pop",
+                        datatype,
+                        StatsType::Population,
+                        expr.null_on_divide_by_zero,
+                    ))),
+                    stats_type => Err(ExecutionError::GeneralError(format!(
+                        "Unknown StatisticsType {:?} for Variance",
+                        stats_type
+                    ))),
+                }
+            }
         }
     }
 
diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto
index afe75ec..042a981 100644
--- a/core/src/execution/proto/expr.proto
+++ b/core/src/execution/proto/expr.proto
@@ -94,9 +94,15 @@ message AggExpr {
     BitXorAgg bitXorAgg = 11;
     CovSample covSample = 12;
     CovPopulation covPopulation = 13;
+    Variance variance = 14;
   }
 }
 
+enum StatisticsType {
+  SAMPLE = 0;
+  POPULATION = 1;
+}
+
 message Count {
    repeated Expr children = 1;
 }
@@ -165,6 +171,13 @@ message CovPopulation {
   DataType datatype = 4;
 }
 
+message Variance {
+  Expr child = 1;
+  bool null_on_divide_by_zero = 2;
+  DataType datatype = 3;
+  StatisticsType stats_type = 4;
+}
+
 message Literal {
   oneof value {
     bool bool_val = 1;
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 9a12930..d08fb6b 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, CovPopulation, CovSample, Final, First, Last, Max, Min, Partial, Sum}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, CovPopulation, CovSample, Final, First, Last, Max, Min, Partial, Sum, VariancePop, VarianceSamp}
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.optimizer.{BuildRight, NormalizeNaNAndZero}
 import org.apache.spark.sql.catalyst.plans._
@@ -464,6 +464,46 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
         } else {
           None
         }
+      case variance @ VarianceSamp(child, nullOnDivideByZero) =>
+        val childExpr = exprToProto(child, inputs, binding)
+        val dataType = serializeDataType(variance.dataType)
+
+        if (childExpr.isDefined && dataType.isDefined) {
+          val varBuilder = ExprOuterClass.Variance.newBuilder()
+          varBuilder.setChild(childExpr.get)
+          varBuilder.setNullOnDivideByZero(nullOnDivideByZero)
+          varBuilder.setDatatype(dataType.get)
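+          // StatisticsType.SAMPLE (value 0) in expr.proto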
+          varBuilder.setStatsTypeValue(0)
+
+          Some(
+            ExprOuterClass.AggExpr
+              .newBuilder()
+              .setVariance(varBuilder)
+              .build())
+        } else {
+          withInfo(aggExpr, child)
+          None
+        }
+      case variancePop @ VariancePop(child, nullOnDivideByZero) =>
+        val childExpr = exprToProto(child, inputs, binding)
+        val dataType = serializeDataType(variancePop.dataType)
+
+        if (childExpr.isDefined && dataType.isDefined) {
+          val varBuilder = ExprOuterClass.Variance.newBuilder()
+          varBuilder.setChild(childExpr.get)
+          varBuilder.setNullOnDivideByZero(nullOnDivideByZero)
+          varBuilder.setDatatype(dataType.get)
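+          // StatisticsType.POPULATION (value 1) in expr.proto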
+          varBuilder.setStatsTypeValue(1)
+
+          Some(
+            ExprOuterClass.AggExpr
+              .newBuilder()
+              .setVariance(varBuilder)
+              .build())
+        } else {
+          withInfo(aggExpr, child)
+          None
+        }
       case fn =>
         val msg = s"unsupported Spark aggregate function: ${fn.prettyName}"
         emitWarning(msg)
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index f6415cb..bd4042e 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -1117,6 +1117,46 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("var_pop and var_samp") {
+    withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
+      Seq(true, false).foreach { cometColumnShuffleEnabled =>
+        withSQLConf(
+          CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> cometColumnShuffleEnabled.toString) {
+          Seq(true, false).foreach { dictionary =>
+            withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
+              Seq(true, false).foreach { nullOnDivideByZero =>
+                withSQLConf(
+                  "spark.sql.legacy.statisticalAggregate" -> nullOnDivideByZero.toString) {
+                  val table = "test"
+                  withTable(table) {
+                    sql(s"create table $table(col1 int, col2 int, col3 int, col4 float, col5 double, col6 int) using parquet")
+                    sql(s"insert into $table values(1, null, null, 1.1, 2.2, 1)," +
+                      " (2, null, null, 3.4, 5.6, 1), (3, null, 4, 7.9, 2.4, 2)")
+                    val expectedNumOfCometAggregates = 2
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      "SELECT var_samp(col1), var_samp(col2), var_samp(col3), var_samp(col4), var_samp(col5) FROM test",
+                      expectedNumOfCometAggregates)
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      "SELECT var_pop(col1), var_pop(col2), var_pop(col3), var_pop(col4), var_samp(col5) FROM test",
+                      expectedNumOfCometAggregates)
+                    checkSparkAnswerAndNumOfAggregates(
+                      "SELECT var_samp(col1), var_samp(col2), var_samp(col3), var_samp(col4), var_samp(col5)" +
+                        " FROM test GROUP BY col6",
+                      expectedNumOfCometAggregates)
+                    checkSparkAnswerAndNumOfAggregates(
+                      "SELECT var_pop(col1), var_pop(col2), var_pop(col3), var_pop(col4), var_samp(col5)" +
+                        " FROM test GROUP BY col6",
+                      expectedNumOfCometAggregates)
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
   protected def checkSparkAnswerAndNumOfAggregates(query: String, numAggregates: Int): Unit = {
     val df = sql(query)
     checkSparkAnswer(df)
@@ -1126,6 +1166,18 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       s"Expected $numAggregates Comet aggregate operators, but found $actualNumAggregates")
   }
 
+  protected def checkSparkAnswerWithTolAndNumOfAggregates(
+      query: String,
+      numAggregates: Int,
+      absTol: Double = 1e-6): Unit = {
+    val df = sql(query)
+    checkSparkAnswerWithTol(df, absTol)
+    val actualNumAggregates = getNumCometHashAggregate(df)
+    assert(
+      actualNumAggregates == numAggregates,
+      s"Expected $numAggregates Comet aggregate operators, but found $actualNumAggregates")
+  }
+
   def getNumCometHashAggregate(df: DataFrame): Int = {
     val sparkPlan = stripAQEPlan(df.queryExecution.executedPlan)
     sparkPlan.collect { case s: CometHashAggregateExec => s }.size


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
