This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 49bf503 feat: Support Variance (#297)
49bf503 is described below
commit 49bf503780dfc93ede968b626214d53fb0953c34
Author: Huaxin Gao <[email protected]>
AuthorDate: Thu Apr 25 11:12:22 2024 -0700
feat: Support Variance (#297)
* feat: Support Variance
* Add StatisticsType in expr.proto
* add explainPlan info and fix fmt
* remove unnecessary cast
* remove unused import
---------
Co-authored-by: Huaxin Gao <[email protected]>
---
EXPRESSIONS.md | 5 +-
core/src/execution/datafusion/expressions/mod.rs | 1 +
.../execution/datafusion/expressions/variance.rs | 256 +++++++++++++++++++++
core/src/execution/datafusion/planner.rs | 25 ++
core/src/execution/proto/expr.proto | 13 ++
.../org/apache/comet/serde/QueryPlanSerde.scala | 42 +++-
.../apache/comet/exec/CometAggregateSuite.scala | 52 +++++
7 files changed, 392 insertions(+), 2 deletions(-)
diff --git a/EXPRESSIONS.md b/EXPRESSIONS.md
index 45c3684..f0a2f69 100644
--- a/EXPRESSIONS.md
+++ b/EXPRESSIONS.md
@@ -103,4 +103,7 @@ The following Spark expressions are currently available:
+ BitXor
+ BoolAnd
+ BoolOr
- + Covariance
+ + CovPopulation
+ + CovSample
+ + VariancePop
+ + VarianceSamp
diff --git a/core/src/execution/datafusion/expressions/mod.rs b/core/src/execution/datafusion/expressions/mod.rs
index 799790c..78763fc 100644
--- a/core/src/execution/datafusion/expressions/mod.rs
+++ b/core/src/execution/datafusion/expressions/mod.rs
@@ -34,3 +34,4 @@ pub mod subquery;
pub mod sum_decimal;
pub mod temporal;
mod utils;
+pub mod variance;
diff --git a/core/src/execution/datafusion/expressions/variance.rs b/core/src/execution/datafusion/expressions/variance.rs
new file mode 100644
index 0000000..6aae01e
--- /dev/null
+++ b/core/src/execution/datafusion/expressions/variance.rs
@@ -0,0 +1,256 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines physical expressions that can be evaluated at runtime during query execution
+
+use std::{any::Any, sync::Arc};
+
+use crate::execution::datafusion::expressions::{stats::StatsType, utils::down_cast_any_ref};
+use arrow::{
+ array::{ArrayRef, Float64Array},
+ datatypes::{DataType, Field},
+};
+use datafusion::logical_expr::Accumulator;
+use datafusion_common::{downcast_value, DataFusionError, Result, ScalarValue};
+use datafusion_physical_expr::{expressions::format_state_name, AggregateExpr, PhysicalExpr};
+
+/// VAR_SAMP and VAR_POP aggregate expression
+/// The implementation is mostly the same as DataFusion's. The reason we have our
+/// own implementation is that DataFusion uses UInt64 for the `count` state field,
+/// while Spark uses Double. We have also added `null_on_divide_by_zero`
+/// to be consistent with Spark's implementation.
+#[derive(Debug)]
+pub struct Variance {
+ name: String,
+ expr: Arc<dyn PhysicalExpr>,
+ stats_type: StatsType,
+ null_on_divide_by_zero: bool,
+}
+
+impl Variance {
+ /// Create a new VARIANCE aggregate function
+ pub fn new(
+ expr: Arc<dyn PhysicalExpr>,
+ name: impl Into<String>,
+ data_type: DataType,
+ stats_type: StatsType,
+ null_on_divide_by_zero: bool,
+ ) -> Self {
+        // The result of variance only supports the Float64 data type.
+ assert!(matches!(data_type, DataType::Float64));
+ Self {
+ name: name.into(),
+ expr,
+ stats_type,
+ null_on_divide_by_zero,
+ }
+ }
+}
+
+impl AggregateExpr for Variance {
+ /// Return a reference to Any that can be used for downcasting
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn field(&self) -> Result<Field> {
+ Ok(Field::new(&self.name, DataType::Float64, true))
+ }
+
+ fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+ Ok(Box::new(VarianceAccumulator::try_new(
+ self.stats_type,
+ self.null_on_divide_by_zero,
+ )?))
+ }
+
+ fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+ Ok(Box::new(VarianceAccumulator::try_new(
+ self.stats_type,
+ self.null_on_divide_by_zero,
+ )?))
+ }
+
+ fn state_fields(&self) -> Result<Vec<Field>> {
+ Ok(vec![
+ Field::new(
+ format_state_name(&self.name, "count"),
+ DataType::Float64,
+ true,
+ ),
+ Field::new(
+ format_state_name(&self.name, "mean"),
+ DataType::Float64,
+ true,
+ ),
+            Field::new(format_state_name(&self.name, "m2"), DataType::Float64, true),
+ ])
+ }
+
+ fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+ vec![self.expr.clone()]
+ }
+
+ fn name(&self) -> &str {
+ &self.name
+ }
+}
+
+impl PartialEq<dyn Any> for Variance {
+ fn eq(&self, other: &dyn Any) -> bool {
+ down_cast_any_ref(other)
+ .downcast_ref::<Self>()
+ .map(|x| {
+                self.name == x.name && self.expr.eq(&x.expr) && self.stats_type == x.stats_type
+ })
+ .unwrap_or(false)
+ }
+}
+
+/// An accumulator to compute variance
+#[derive(Debug)]
+pub struct VarianceAccumulator {
+ m2: f64,
+ mean: f64,
+ count: f64,
+ stats_type: StatsType,
+ null_on_divide_by_zero: bool,
+}
+
+impl VarianceAccumulator {
+ /// Creates a new `VarianceAccumulator`
+    pub fn try_new(s_type: StatsType, null_on_divide_by_zero: bool) -> Result<Self> {
+ Ok(Self {
+ m2: 0_f64,
+ mean: 0_f64,
+ count: 0_f64,
+ stats_type: s_type,
+ null_on_divide_by_zero,
+ })
+ }
+
+ pub fn get_count(&self) -> f64 {
+ self.count
+ }
+
+ pub fn get_mean(&self) -> f64 {
+ self.mean
+ }
+
+ pub fn get_m2(&self) -> f64 {
+ self.m2
+ }
+}
+
+impl Accumulator for VarianceAccumulator {
+ fn state(&mut self) -> Result<Vec<ScalarValue>> {
+ Ok(vec![
+ ScalarValue::from(self.count),
+ ScalarValue::from(self.mean),
+ ScalarValue::from(self.m2),
+ ])
+ }
+
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ let arr = downcast_value!(&values[0], Float64Array).iter().flatten();
+
+ for value in arr {
+ let new_count = self.count + 1.0;
+ let delta1 = value - self.mean;
+ let new_mean = delta1 / new_count + self.mean;
+ let delta2 = value - new_mean;
+ let new_m2 = self.m2 + delta1 * delta2;
+
+ self.count += 1.0;
+ self.mean = new_mean;
+ self.m2 = new_m2;
+ }
+
+ Ok(())
+ }
+
+ fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ let arr = downcast_value!(&values[0], Float64Array).iter().flatten();
+
+ for value in arr {
+ let new_count = self.count - 1.0;
+ let delta1 = self.mean - value;
+ let new_mean = delta1 / new_count + self.mean;
+ let delta2 = new_mean - value;
+ let new_m2 = self.m2 - delta1 * delta2;
+
+ self.count -= 1.0;
+ self.mean = new_mean;
+ self.m2 = new_m2;
+ }
+
+ Ok(())
+ }
+
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+ let counts = downcast_value!(states[0], Float64Array);
+ let means = downcast_value!(states[1], Float64Array);
+ let m2s = downcast_value!(states[2], Float64Array);
+
+ for i in 0..counts.len() {
+ let c = counts.value(i);
+ if c == 0_f64 {
+ continue;
+ }
+ let new_count = self.count + c;
+            let new_mean = self.mean * self.count / new_count + means.value(i) * c / new_count;
+ let delta = self.mean - means.value(i);
+            let new_m2 = self.m2 + m2s.value(i) + delta * delta * self.count * c / new_count;
+
+ self.count = new_count;
+ self.mean = new_mean;
+ self.m2 = new_m2;
+ }
+ Ok(())
+ }
+
+ fn evaluate(&mut self) -> Result<ScalarValue> {
+ let count = match self.stats_type {
+ StatsType::Population => self.count,
+ StatsType::Sample => {
+ if self.count > 0.0 {
+ self.count - 1.0
+ } else {
+ self.count
+ }
+ }
+ };
+
+ Ok(ScalarValue::Float64(match self.count {
+ count if count == 0.0 => None,
+ count if count == 1.0 => {
+ if let StatsType::Population = self.stats_type {
+ Some(0.0)
+ } else if self.null_on_divide_by_zero {
+ None
+ } else {
+ Some(f64::NAN)
+ }
+ }
+ _ => Some(self.m2 / count),
+ }))
+ }
+
+ fn size(&self) -> usize {
+ std::mem::size_of_val(self)
+ }
+}
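The accumulator above follows Welford's online recurrence for the running mean and m2 (the sum of squared deviations), and evaluate() divides m2 by count (population) or count - 1 (sample). A minimal standalone sketch of that recurrence, for illustration only (the function names are not part of Comet, and it assumes null_on_divide_by_zero = true so a zero divisor yields None rather than NaN):

fn variance(values: &[f64], sample: bool) -> Option<f64> {
    let (mut count, mut mean, mut m2) = (0.0_f64, 0.0_f64, 0.0_f64);
    for &v in values {
        count += 1.0;
        let delta1 = v - mean;
        mean += delta1 / count; // running mean
        let delta2 = v - mean;
        m2 += delta1 * delta2; // running sum of squared deviations
    }
    let divisor = if sample { count - 1.0 } else { count };
    if divisor <= 0.0 {
        None // divide-by-zero case; with null_on_divide_by_zero = false this would be NaN
    } else {
        Some(m2 / divisor)
    }
}

fn main() {
    // col1 in the test below holds (1, 2, 3): var_samp = 1.0, var_pop = 2/3.
    assert_eq!(variance(&[1.0, 2.0, 3.0], true), Some(1.0));
    assert_eq!(variance(&[1.0, 2.0, 3.0], false), Some(2.0 / 3.0));
}
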
diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index 5c379d4..7217479 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -75,6 +75,7 @@ use crate::{
subquery::Subquery,
sum_decimal::SumDecimal,
                temporal::{DateTruncExec, HourExec, MinuteExec, SecondExec, TimestampTruncExec},
+ variance::Variance,
NormalizeNaNAndZero,
},
operators::expand::CometExpandExec,
@@ -1235,6 +1236,30 @@ impl PhysicalPlanner {
StatsType::Population,
)))
}
+ AggExprStruct::Variance(expr) => {
+                let child = self.create_expr(expr.child.as_ref().unwrap(), schema.clone())?;
+                let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+ match expr.stats_type {
+ 0 => Ok(Arc::new(Variance::new(
+ child,
+ "variance",
+ datatype,
+ StatsType::Sample,
+ expr.null_on_divide_by_zero,
+ ))),
+ 1 => Ok(Arc::new(Variance::new(
+ child,
+ "variance_pop",
+ datatype,
+ StatsType::Population,
+ expr.null_on_divide_by_zero,
+ ))),
+ stats_type => Err(ExecutionError::GeneralError(format!(
+ "Unknown StatisticsType {:?} for Variance",
+ stats_type
+ ))),
+ }
+ }
}
}
diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto
index afe75ec..042a981 100644
--- a/core/src/execution/proto/expr.proto
+++ b/core/src/execution/proto/expr.proto
@@ -94,9 +94,15 @@ message AggExpr {
BitXorAgg bitXorAgg = 11;
CovSample covSample = 12;
CovPopulation covPopulation = 13;
+ Variance variance = 14;
}
}
+enum StatisticsType {
+ SAMPLE = 0;
+ POPULATION = 1;
+}
+
message Count {
repeated Expr children = 1;
}
@@ -165,6 +171,13 @@ message CovPopulation {
DataType datatype = 4;
}
+message Variance {
+ Expr child = 1;
+ bool null_on_divide_by_zero = 2;
+ DataType datatype = 3;
+ StatisticsType stats_type = 4;
+}
+
message Literal {
oneof value {
bool bool_val = 1;
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 9a12930..d08fb6b 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, CovPopulation, CovSample, Final, First, Last, Max, Min, Partial, Sum}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, CovPopulation, CovSample, Final, First, Last, Max, Min, Partial, Sum, VariancePop, VarianceSamp}
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.optimizer.{BuildRight, NormalizeNaNAndZero}
import org.apache.spark.sql.catalyst.plans._
@@ -464,6 +464,46 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
} else {
None
}
+ case variance @ VarianceSamp(child, nullOnDivideByZero) =>
+ val childExpr = exprToProto(child, inputs, binding)
+ val dataType = serializeDataType(variance.dataType)
+
+ if (childExpr.isDefined && dataType.isDefined) {
+ val varBuilder = ExprOuterClass.Variance.newBuilder()
+ varBuilder.setChild(childExpr.get)
+ varBuilder.setNullOnDivideByZero(nullOnDivideByZero)
+ varBuilder.setDatatype(dataType.get)
+ varBuilder.setStatsTypeValue(0)
+
+ Some(
+ ExprOuterClass.AggExpr
+ .newBuilder()
+ .setVariance(varBuilder)
+ .build())
+ } else {
+ withInfo(aggExpr, child)
+ None
+ }
+ case variancePop @ VariancePop(child, nullOnDivideByZero) =>
+ val childExpr = exprToProto(child, inputs, binding)
+ val dataType = serializeDataType(variancePop.dataType)
+
+ if (childExpr.isDefined && dataType.isDefined) {
+ val varBuilder = ExprOuterClass.Variance.newBuilder()
+ varBuilder.setChild(childExpr.get)
+ varBuilder.setNullOnDivideByZero(nullOnDivideByZero)
+ varBuilder.setDatatype(dataType.get)
+ varBuilder.setStatsTypeValue(1)
+
+ Some(
+ ExprOuterClass.AggExpr
+ .newBuilder()
+ .setVariance(varBuilder)
+ .build())
+ } else {
+ withInfo(aggExpr, child)
+ None
+ }
case fn =>
val msg = s"unsupported Spark aggregate function: ${fn.prettyName}"
emitWarning(msg)
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index f6415cb..bd4042e 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -1117,6 +1117,46 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
+ test("var_pop and var_samp") {
+ withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
+ Seq(true, false).foreach { cometColumnShuffleEnabled =>
+ withSQLConf(
+          CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> cometColumnShuffleEnabled.toString) {
+ Seq(true, false).foreach { dictionary =>
+ withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
+ Seq(true, false).foreach { nullOnDivideByZero =>
+ withSQLConf(
+                  "spark.sql.legacy.statisticalAggregate" -> nullOnDivideByZero.toString) {
+ val table = "test"
+ withTable(table) {
+                    sql(s"create table $table(col1 int, col2 int, col3 int, col4 float, col5 double, col6 int) using parquet")
+                    sql(s"insert into $table values(1, null, null, 1.1, 2.2, 1)," +
+                      " (2, null, null, 3.4, 5.6, 1), (3, null, 4, 7.9, 2.4, 2)")
+ val expectedNumOfCometAggregates = 2
+ checkSparkAnswerWithTolAndNumOfAggregates(
+                      "SELECT var_samp(col1), var_samp(col2), var_samp(col3), var_samp(col4), var_samp(col5) FROM test",
+ expectedNumOfCometAggregates)
+ checkSparkAnswerWithTolAndNumOfAggregates(
+                      "SELECT var_pop(col1), var_pop(col2), var_pop(col3), var_pop(col4), var_samp(col5) FROM test",
+ expectedNumOfCometAggregates)
+ checkSparkAnswerAndNumOfAggregates(
+                      "SELECT var_samp(col1), var_samp(col2), var_samp(col3), var_samp(col4), var_samp(col5)" +
+ " FROM test GROUP BY col6",
+ expectedNumOfCometAggregates)
+ checkSparkAnswerAndNumOfAggregates(
+                      "SELECT var_pop(col1), var_pop(col2), var_pop(col3), var_pop(col4), var_samp(col5)" +
+ " FROM test GROUP BY col6",
+ expectedNumOfCometAggregates)
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
  protected def checkSparkAnswerAndNumOfAggregates(query: String, numAggregates: Int): Unit = {
val df = sql(query)
checkSparkAnswer(df)
@@ -1126,6 +1166,18 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
      s"Expected $numAggregates Comet aggregate operators, but found $actualNumAggregates")
}
+ protected def checkSparkAnswerWithTolAndNumOfAggregates(
+ query: String,
+ numAggregates: Int,
+ absTol: Double = 1e-6): Unit = {
+ val df = sql(query)
+ checkSparkAnswerWithTol(df, absTol)
+ val actualNumAggregates = getNumCometHashAggregate(df)
+ assert(
+ actualNumAggregates == numAggregates,
+      s"Expected $numAggregates Comet aggregate operators, but found $actualNumAggregates")
+ }
+
def getNumCometHashAggregate(df: DataFrame): Int = {
val sparkPlan = stripAQEPlan(df.queryExecution.executedPlan)
sparkPlan.collect { case s: CometHashAggregateExec => s }.size
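Since expectedNumOfCometAggregates is 2, each query above runs a partial and a final Comet hash aggregate, so per-partition (count, mean, m2) states are combined the way merge_batch does. A small illustrative sketch of that pairwise combination (the names here are hypothetical, not Comet API):

fn merge(a: (f64, f64, f64), b: (f64, f64, f64)) -> (f64, f64, f64) {
    // Each tuple is (count, mean, m2); an empty state passes the other side through.
    let (n1, mean1, m2_1) = a;
    let (n2, mean2, m2_2) = b;
    if n1 == 0.0 {
        return b;
    }
    if n2 == 0.0 {
        return a;
    }
    let n = n1 + n2;
    let mean = mean1 * n1 / n + mean2 * n2 / n;
    let delta = mean1 - mean2;
    (n, mean, m2_1 + m2_2 + delta * delta * n1 * n2 / n)
}

fn main() {
    // Partition 1 holds (1, 2) and partition 2 holds (3); the merged state
    // matches a single pass over (1, 2, 3): count 3, mean 2, m2 2,
    // hence var_pop = 2/3 and var_samp = 1.0.
    let p1 = (2.0, 1.5, 0.5);
    let p2 = (1.0, 3.0, 0.0);
    assert_eq!(merge(p1, p2), (3.0, 2.0, 2.0));
}
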
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]