This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 6d23e284 feat: correlation support (#456)
6d23e284 is described below
commit 6d23e2846c5ba0049bb7dc25955ab5d717953f9a
Author: Huaxin Gao <[email protected]>
AuthorDate: Thu May 23 08:20:27 2024 -0700
feat: correlation support (#456)
* feat: correlation support
* fmt
* remove un-used import
* address comments
* address comment
---------
Co-authored-by: Huaxin Gao <[email protected]>
---
.../datafusion/expressions/correlation.rs | 256 +++++++++++++++++++++
core/src/execution/datafusion/expressions/mod.rs | 1 +
core/src/execution/datafusion/planner.rs | 13 ++
core/src/execution/proto/expr.proto | 8 +
docs/source/user-guide/expressions.md | 1 +
.../org/apache/comet/serde/QueryPlanSerde.scala | 23 +-
.../apache/comet/exec/CometAggregateSuite.scala | 151 ++++++++++++
7 files changed, 452 insertions(+), 1 deletion(-)
diff --git a/core/src/execution/datafusion/expressions/correlation.rs b/core/src/execution/datafusion/expressions/correlation.rs
new file mode 100644
index 00000000..c83341e5
--- /dev/null
+++ b/core/src/execution/datafusion/expressions/correlation.rs
@@ -0,0 +1,256 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::compute::{and, filter, is_not_null};
+
+use std::{any::Any, sync::Arc};
+
+use crate::execution::datafusion::expressions::{
+    covariance::CovarianceAccumulator, stats::StatsType, stddev::StddevAccumulator,
+ utils::down_cast_any_ref,
+};
+use arrow::{
+ array::ArrayRef,
+ datatypes::{DataType, Field},
+};
+use datafusion::logical_expr::Accumulator;
+use datafusion_common::{Result, ScalarValue};
+use datafusion_physical_expr::{expressions::format_state_name, AggregateExpr, PhysicalExpr};
+
+/// CORR aggregate expression
+/// The implementation is mostly the same as DataFusion's. The reason we have
+/// our own implementation is that DataFusion uses UInt64 for the `count`
+/// state field, while Spark uses Double. We have also added
+/// `null_on_divide_by_zero` to be consistent with Spark's implementation.
+#[derive(Debug)]
+pub struct Correlation {
+ name: String,
+ expr1: Arc<dyn PhysicalExpr>,
+ expr2: Arc<dyn PhysicalExpr>,
+ null_on_divide_by_zero: bool,
+}
+
+impl Correlation {
+ pub fn new(
+ expr1: Arc<dyn PhysicalExpr>,
+ expr2: Arc<dyn PhysicalExpr>,
+ name: impl Into<String>,
+ data_type: DataType,
+ null_on_divide_by_zero: bool,
+ ) -> Self {
+        // the result of correlation only supports the Float64 data type.
+ assert!(matches!(data_type, DataType::Float64));
+ Self {
+ name: name.into(),
+ expr1,
+ expr2,
+ null_on_divide_by_zero,
+ }
+ }
+}
+
+impl AggregateExpr for Correlation {
+ /// Return a reference to Any that can be used for downcasting
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn field(&self) -> Result<Field> {
+ Ok(Field::new(&self.name, DataType::Float64, true))
+ }
+
+ fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+ Ok(Box::new(CorrelationAccumulator::try_new(
+ self.null_on_divide_by_zero,
+ )?))
+ }
+
+ fn state_fields(&self) -> Result<Vec<Field>> {
+ Ok(vec![
+ Field::new(
+ format_state_name(&self.name, "count"),
+ DataType::Float64,
+ true,
+ ),
+ Field::new(
+ format_state_name(&self.name, "mean1"),
+ DataType::Float64,
+ true,
+ ),
+ Field::new(
+ format_state_name(&self.name, "mean2"),
+ DataType::Float64,
+ true,
+ ),
+ Field::new(
+ format_state_name(&self.name, "algo_const"),
+ DataType::Float64,
+ true,
+ ),
+ Field::new(
+ format_state_name(&self.name, "m2_1"),
+ DataType::Float64,
+ true,
+ ),
+ Field::new(
+ format_state_name(&self.name, "m2_2"),
+ DataType::Float64,
+ true,
+ ),
+ ])
+ }
+
+ fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+ vec![self.expr1.clone(), self.expr2.clone()]
+ }
+
+ fn name(&self) -> &str {
+ &self.name
+ }
+}
+
+impl PartialEq<dyn Any> for Correlation {
+ fn eq(&self, other: &dyn Any) -> bool {
+ down_cast_any_ref(other)
+ .downcast_ref::<Self>()
+ .map(|x| {
+ self.name == x.name
+ && self.expr1.eq(&x.expr1)
+ && self.expr2.eq(&x.expr2)
+ && self.null_on_divide_by_zero == x.null_on_divide_by_zero
+ })
+ .unwrap_or(false)
+ }
+}
+
+/// An accumulator to compute correlation
+#[derive(Debug)]
+pub struct CorrelationAccumulator {
+ covar: CovarianceAccumulator,
+ stddev1: StddevAccumulator,
+ stddev2: StddevAccumulator,
+ null_on_divide_by_zero: bool,
+}
+
+impl CorrelationAccumulator {
+ /// Creates a new `CorrelationAccumulator`
+ pub fn try_new(null_on_divide_by_zero: bool) -> Result<Self> {
+ Ok(Self {
+ covar: CovarianceAccumulator::try_new(StatsType::Population)?,
+            stddev1: StddevAccumulator::try_new(StatsType::Population, null_on_divide_by_zero)?,
+            stddev2: StddevAccumulator::try_new(StatsType::Population, null_on_divide_by_zero)?,
+ null_on_divide_by_zero,
+ })
+ }
+}
+
+impl Accumulator for CorrelationAccumulator {
+ fn state(&mut self) -> Result<Vec<ScalarValue>> {
+ Ok(vec![
+ ScalarValue::from(self.covar.get_count()),
+ ScalarValue::from(self.covar.get_mean1()),
+ ScalarValue::from(self.covar.get_mean2()),
+ ScalarValue::from(self.covar.get_algo_const()),
+ ScalarValue::from(self.stddev1.get_m2()),
+ ScalarValue::from(self.stddev2.get_m2()),
+ ])
+ }
+
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let values = if values[0].null_count() != 0 || values[1].null_count() != 0 {
+            let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
+ let values1 = filter(&values[0], &mask)?;
+ let values2 = filter(&values[1], &mask)?;
+
+ vec![values1, values2]
+ } else {
+ values.to_vec()
+ };
+
+ if !values[0].is_empty() && !values[1].is_empty() {
+ self.covar.update_batch(&values)?;
+ self.stddev1.update_batch(&values[0..1])?;
+ self.stddev2.update_batch(&values[1..2])?;
+ }
+
+ Ok(())
+ }
+
+ fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let values = if values[0].null_count() != 0 || values[1].null_count() != 0 {
+            let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
+ let values1 = filter(&values[0], &mask)?;
+ let values2 = filter(&values[1], &mask)?;
+
+ vec![values1, values2]
+ } else {
+ values.to_vec()
+ };
+
+ self.covar.retract_batch(&values)?;
+ self.stddev1.retract_batch(&values[0..1])?;
+ self.stddev2.retract_batch(&values[1..2])?;
+ Ok(())
+ }
+
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+ let states_c = [
+ states[0].clone(),
+ states[1].clone(),
+ states[2].clone(),
+ states[3].clone(),
+ ];
+        let states_s1 = [states[0].clone(), states[1].clone(), states[4].clone()];
+        let states_s2 = [states[0].clone(), states[2].clone(), states[5].clone()];
+
+ if states[0].len() > 0 && states[1].len() > 0 && states[2].len() > 0 {
+ self.covar.merge_batch(&states_c)?;
+ self.stddev1.merge_batch(&states_s1)?;
+ self.stddev2.merge_batch(&states_s2)?;
+ }
+ Ok(())
+ }
+
+ fn evaluate(&mut self) -> Result<ScalarValue> {
+ let covar = self.covar.evaluate()?;
+ let stddev1 = self.stddev1.evaluate()?;
+ let stddev2 = self.stddev2.evaluate()?;
+
+ match (covar, stddev1, stddev2) {
+ (
+ ScalarValue::Float64(Some(c)),
+ ScalarValue::Float64(Some(s1)),
+ ScalarValue::Float64(Some(s2)),
+            ) if s1 != 0.0 && s2 != 0.0 => Ok(ScalarValue::Float64(Some(c / (s1 * s2)))),
+ _ if self.null_on_divide_by_zero => Ok(ScalarValue::Float64(None)),
+ _ => {
+ if self.covar.get_count() == 1.0 {
+ return Ok(ScalarValue::Float64(Some(f64::NAN)));
+ }
+ Ok(ScalarValue::Float64(None))
+ }
+ }
+ }
+
+ fn size(&self) -> usize {
+        std::mem::size_of_val(self) - std::mem::size_of_val(&self.covar) + self.covar.size()
+ - std::mem::size_of_val(&self.stddev1)
+ + self.stddev1.size()
+ - std::mem::size_of_val(&self.stddev2)
+ + self.stddev2.size()
+ }
+}
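
The accumulator above composes the existing CovarianceAccumulator and two StddevAccumulators (all population statistics), and evaluate() computes Pearson's correlation as covariance / (stddev1 * stddev2), falling back to NULL or NaN when either standard deviation is zero. As a rough, standalone sketch of that final step only (not part of the commit), assuming plain f64 inputs in place of the ScalarValue match used above:

// Illustrative only: mirrors the divide-by-zero handling in `evaluate`,
// with plain f64s standing in for the already-aggregated statistics.
fn corr(covar: f64, stddev1: f64, stddev2: f64, count: f64, null_on_divide_by_zero: bool) -> Option<f64> {
    if stddev1 != 0.0 && stddev2 != 0.0 {
        // Pearson correlation: cov_pop(x, y) / (stddev_pop(x) * stddev_pop(y))
        Some(covar / (stddev1 * stddev2))
    } else if null_on_divide_by_zero {
        // Spark's default: return NULL when either stddev is zero
        None
    } else if count == 1.0 {
        // legacy Spark behaviour: a single row yields NaN
        Some(f64::NAN)
    } else {
        None
    }
}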
diff --git a/core/src/execution/datafusion/expressions/mod.rs b/core/src/execution/datafusion/expressions/mod.rs
index 10cac169..9db4b65b 100644
--- a/core/src/execution/datafusion/expressions/mod.rs
+++ b/core/src/execution/datafusion/expressions/mod.rs
@@ -27,6 +27,7 @@ pub use normalize_nan::NormalizeNaNAndZero;
pub mod avg;
pub mod avg_decimal;
pub mod bloom_filter_might_contain;
+pub mod correlation;
pub mod covariance;
pub mod stats;
pub mod stddev;
diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index 59818857..01d89238 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -67,6 +67,7 @@ use crate::{
bloom_filter_might_contain::BloomFilterMightContain,
cast::{Cast, EvalMode},
checkoverflow::CheckOverflow,
+ correlation::Correlation,
covariance::Covariance,
if_expr::IfExpr,
scalar_funcs::create_comet_physical_fun,
@@ -1310,6 +1311,18 @@ impl PhysicalPlanner {
))),
}
}
+ AggExprStruct::Correlation(expr) => {
+                let child1 = self.create_expr(expr.child1.as_ref().unwrap(), schema.clone())?;
+                let child2 = self.create_expr(expr.child2.as_ref().unwrap(), schema.clone())?;
+                let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+ Ok(Arc::new(Correlation::new(
+ child1,
+ child2,
+ "correlation",
+ datatype,
+ expr.null_on_divide_by_zero,
+ )))
+ }
}
}
diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto
index ee3de865..be85e8a9 100644
--- a/core/src/execution/proto/expr.proto
+++ b/core/src/execution/proto/expr.proto
@@ -96,6 +96,7 @@ message AggExpr {
CovPopulation covPopulation = 13;
Variance variance = 14;
Stddev stddev = 15;
+ Correlation correlation = 16;
}
}
@@ -186,6 +187,13 @@ message Stddev {
StatisticsType stats_type = 4;
}
+message Correlation {
+ Expr child1 = 1;
+ Expr child2 = 2;
+ bool null_on_divide_by_zero = 3;
+ DataType datatype = 4;
+}
+
message Literal {
oneof value {
bool bool_val = 1;
diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md
index 38c86c72..521699d3 100644
--- a/docs/source/user-guide/expressions.md
+++ b/docs/source/user-guide/expressions.md
@@ -109,3 +109,4 @@ The following Spark expressions are currently available:
- VarianceSamp
- StddevPop
- StddevSamp
+ - Corr
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index cf7c86a9..6333650d 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, CovPopulation, CovSample, Final, First, Last, Max, Min, Partial, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Corr, Count, CovPopulation, CovSample, Final, First, Last, Max, Min, Partial, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.optimizer.{BuildRight, NormalizeNaNAndZero}
import org.apache.spark.sql.catalyst.plans._
@@ -547,6 +547,27 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
withInfo(aggExpr, child)
None
}
+ case corr @ Corr(child1, child2, nullOnDivideByZero) =>
+ val child1Expr = exprToProto(child1, inputs, binding)
+ val child2Expr = exprToProto(child2, inputs, binding)
+ val dataType = serializeDataType(corr.dataType)
+
+        if (child1Expr.isDefined && child2Expr.isDefined && dataType.isDefined) {
+ val corrBuilder = ExprOuterClass.Correlation.newBuilder()
+ corrBuilder.setChild1(child1Expr.get)
+ corrBuilder.setChild2(child2Expr.get)
+ corrBuilder.setNullOnDivideByZero(nullOnDivideByZero)
+ corrBuilder.setDatatype(dataType.get)
+
+ Some(
+ ExprOuterClass.AggExpr
+ .newBuilder()
+ .setCorrelation(corrBuilder)
+ .build())
+ } else {
+ withInfo(aggExpr, child1, child2)
+ None
+ }
case fn =>
val msg = s"unsupported Spark aggregate function: ${fn.prettyName}"
emitWarning(msg)
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index 310a24ee..d36534ee 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -1212,6 +1212,157 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
+ test("correlation") {
+ withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
+ Seq(true, false).foreach { cometColumnShuffleEnabled =>
+ withSQLConf(
+          CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> cometColumnShuffleEnabled.toString) {
+ Seq(true, false).foreach { dictionary =>
+ withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
+ Seq(true, false).foreach { nullOnDivideByZero =>
+ withSQLConf(
+ "spark.sql.legacy.statisticalAggregate" ->
nullOnDivideByZero.toString) {
+ val table = "test"
+ withTable(table) {
+ sql(
+ s"create table $table(col1 double, col2 double, col3
double) using parquet")
+ sql(s"insert into $table values(1, 4, 1), (2, 5, 1), (3,
6, 2)")
+ val expectedNumOfCometAggregates = 2
+
+ checkSparkAnswerAndNumOfAggregates(
+ s"SELECT corr(col1, col2) FROM $table",
+ expectedNumOfCometAggregates)
+
+ checkSparkAnswerAndNumOfAggregates(
+ s"SELECT corr(col1, col2) FROM $table GROUP BY col3",
+ expectedNumOfCometAggregates)
+ }
+
+ withTable(table) {
+ sql(
+ s"create table $table(col1 double, col2 double, col3
double) using parquet")
+ sql(s"insert into $table values(1, 4, 3), (2, -5, 3), (3,
6, 1)")
+ val expectedNumOfCometAggregates = 2
+
+ checkSparkAnswerWithTolAndNumOfAggregates(
+ s"SELECT corr(col1, col2) FROM $table",
+ expectedNumOfCometAggregates)
+
+ checkSparkAnswerWithTolAndNumOfAggregates(
+ s"SELECT corr(col1, col2) FROM $table GROUP BY col3",
+ expectedNumOfCometAggregates)
+ }
+
+ withTable(table) {
+ sql(
+ s"create table $table(col1 double, col2 double, col3
double) using parquet")
+ sql(s"insert into $table values(1.1, 4.1, 2.3), (2, 5,
1.5), (3, 6, 2.3)")
+ val expectedNumOfCometAggregates = 2
+
+ checkSparkAnswerWithTolAndNumOfAggregates(
+ s"SELECT corr(col1, col2) FROM $table",
+ expectedNumOfCometAggregates)
+
+ checkSparkAnswerWithTolAndNumOfAggregates(
+ s"SELECT corr(col1, col2) FROM $table GROUP BY col3",
+ expectedNumOfCometAggregates)
+ }
+
+ withTable(table) {
+ sql(
+ s"create table $table(col1 double, col2 double, col3
double) using parquet")
+ sql(s"insert into $table values(1, 4, 1), (2, 5, 2), (3,
6, 3), (1.1, 4.4, 1), (2.2, 5.5, 2), (3.3, 6.6, 3)")
+ val expectedNumOfCometAggregates = 2
+
+ checkSparkAnswerWithTolAndNumOfAggregates(
+ s"SELECT corr(col1, col2) FROM $table",
+ expectedNumOfCometAggregates)
+
+ checkSparkAnswerWithTolAndNumOfAggregates(
+ s"SELECT corr(col1, col2) FROM $table GROUP BY col3",
+ expectedNumOfCometAggregates)
+ }
+
+ withTable(table) {
+ sql(s"create table $table(col1 int, col2 int, col3 int)
using parquet")
+ sql(s"insert into $table values(1, 4, 1), (2, 5, 2), (3,
6, 3)")
+ val expectedNumOfCometAggregates = 2
+
+ checkSparkAnswerWithTolAndNumOfAggregates(
+ s"SELECT corr(col1, col2) FROM $table",
+ expectedNumOfCometAggregates)
+
+ checkSparkAnswerWithTolAndNumOfAggregates(
+ s"SELECT corr(col1, col2) FROM $table GROUP BY col3",
+ expectedNumOfCometAggregates)
+ }
+
+ withTable(table) {
+ sql(s"create table $table(col1 int, col2 int, col3 int)
using parquet")
+ sql(
+ s"insert into $table values(1, 4, 2), (null, null, 2),
(3, 6, 1), (3, 3, 1)")
+ val expectedNumOfCometAggregates = 2
+
+ checkSparkAnswerWithTolAndNumOfAggregates(
+ s"SELECT corr(col1, col2) FROM $table",
+ expectedNumOfCometAggregates)
+
+ checkSparkAnswerWithTolAndNumOfAggregates(
+ s"SELECT corr(col1, col2) FROM $table GROUP BY col3",
+ expectedNumOfCometAggregates)
+ }
+
+ withTable(table) {
+ sql(s"create table $table(col1 int, col2 int, col3 int)
using parquet")
+ sql(s"insert into $table values(1, 4, 1), (null, 5, 1),
(2, 5, 2), (9, null, 2), (3, 6, 2)")
+ val expectedNumOfCometAggregates = 2
+
+ checkSparkAnswerWithTolAndNumOfAggregates(
+ s"SELECT corr(col1, col2) FROM $table",
+ expectedNumOfCometAggregates)
+
+ checkSparkAnswerWithTolAndNumOfAggregates(
+ s"SELECT corr(col1, col2) FROM $table GROUP BY col3",
+ expectedNumOfCometAggregates)
+ }
+
+ withTable(table) {
+ sql(s"create table $table(col1 int, col2 int, col3 int)
using parquet")
+ sql(s"insert into $table values(null, null, 1), (1, 2, 1),
(null, null, 2)")
+ val expectedNumOfCometAggregates = 2
+
+ checkSparkAnswerWithTolAndNumOfAggregates(
+ s"SELECT corr(col1, col2) FROM $table",
+ expectedNumOfCometAggregates)
+
+ checkSparkAnswerWithTolAndNumOfAggregates(
+ s"SELECT corr(col1, col2) FROM $table GROUP BY col3",
+ expectedNumOfCometAggregates)
+ }
+
+ withTable(table) {
+ sql(s"create table $table(col1 int, col2 int, col3 int)
using parquet")
+ sql(
+ s"insert into $table values(null, null, 1), (null, null,
1), (null, null, 2)")
+ val expectedNumOfCometAggregates = 2
+
+ checkSparkAnswerWithTolAndNumOfAggregates(
+ s"SELECT corr(col1, col2) FROM $table",
+ expectedNumOfCometAggregates)
+
+ checkSparkAnswerWithTolAndNumOfAggregates(
+ s"SELECT corr(col1, col2) FROM $table GROUP BY col3",
+ expectedNumOfCometAggregates)
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
protected def checkSparkAnswerAndNumOfAggregates(query: String, numAggregates: Int): Unit = {
val df = sql(query)
checkSparkAnswer(df)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]