This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 6d23e284 feat: correlation support (#456)
6d23e284 is described below

commit 6d23e2846c5ba0049bb7dc25955ab5d717953f9a
Author: Huaxin Gao <[email protected]>
AuthorDate: Thu May 23 08:20:27 2024 -0700

    feat: correlation support (#456)
    
    * feat: correlation support
    
    * fmt
    
    * remove un-used import
    
    * address comments
    
    * address comment
    
    ---------
    
    Co-authored-by: Huaxin Gao <[email protected]>
---
 .../datafusion/expressions/correlation.rs          | 256 +++++++++++++++++++++
 core/src/execution/datafusion/expressions/mod.rs   |   1 +
 core/src/execution/datafusion/planner.rs           |  13 ++
 core/src/execution/proto/expr.proto                |   8 +
 docs/source/user-guide/expressions.md              |   1 +
 .../org/apache/comet/serde/QueryPlanSerde.scala    |  23 +-
 .../apache/comet/exec/CometAggregateSuite.scala    | 151 ++++++++++++
 7 files changed, 452 insertions(+), 1 deletion(-)

diff --git a/core/src/execution/datafusion/expressions/correlation.rs 
b/core/src/execution/datafusion/expressions/correlation.rs
new file mode 100644
index 00000000..c83341e5
--- /dev/null
+++ b/core/src/execution/datafusion/expressions/correlation.rs
@@ -0,0 +1,256 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::compute::{and, filter, is_not_null};
+
+use std::{any::Any, sync::Arc};
+
+use crate::execution::datafusion::expressions::{
+    covariance::CovarianceAccumulator, stats::StatsType, 
stddev::StddevAccumulator,
+    utils::down_cast_any_ref,
+};
+use arrow::{
+    array::ArrayRef,
+    datatypes::{DataType, Field},
+};
+use datafusion::logical_expr::Accumulator;
+use datafusion_common::{Result, ScalarValue};
+use datafusion_physical_expr::{expressions::format_state_name, AggregateExpr, 
PhysicalExpr};
+
+/// CORR aggregate expression.
+///
+/// The result is `covariance / (stddev1 * stddev2)` of the two inputs, computed by
+/// `CorrelationAccumulator`. The implementation mostly mirrors DataFusion's. We keep
+/// our own copy because DataFusion uses UInt64 for the `count` state field, while
+/// Spark uses Double. We also added `null_on_divide_by_zero` to be consistent with
+/// Spark's implementation.
+#[derive(Debug)]
+pub struct Correlation {
+    /// Name of the output field; also used as the prefix of every state field name.
+    name: String,
+    /// First input expression.
+    expr1: Arc<dyn PhysicalExpr>,
+    /// Second input expression.
+    expr2: Arc<dyn PhysicalExpr>,
+    /// When true, undefined results (e.g. a zero standard deviation) evaluate to NULL.
+    /// When false, a single-row group evaluates to NaN instead.
+    null_on_divide_by_zero: bool,
+}
+
+impl Correlation {
+    /// Creates a CORR aggregate over `expr1` and `expr2`.
+    ///
+    /// `name` becomes the output field name. `null_on_divide_by_zero` selects what
+    /// the result is when the correlation is undefined (see
+    /// `CorrelationAccumulator::evaluate`).
+    ///
+    /// # Panics
+    ///
+    /// Panics if `data_type` is not `DataType::Float64`; the result of a correlation
+    /// is always a double.
+    pub fn new(
+        expr1: Arc<dyn PhysicalExpr>,
+        expr2: Arc<dyn PhysicalExpr>,
+        name: impl Into<String>,
+        data_type: DataType,
+        null_on_divide_by_zero: bool,
+    ) -> Self {
+        // The result of a correlation only supports the Float64 data type.
+        assert!(
+            matches!(data_type, DataType::Float64),
+            "Correlation only supports a Float64 result type, got {:?}",
+            data_type
+        );
+        Self {
+            name: name.into(),
+            expr1,
+            expr2,
+            null_on_divide_by_zero,
+        }
+    }
+}
+
+impl AggregateExpr for Correlation {
+    /// Returns `self` as `Any` so callers can downcast back to `Correlation`.
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// The result column: a nullable Float64 named after this expression.
+    fn field(&self) -> Result<Field> {
+        let nullable = true;
+        Ok(Field::new(&self.name, DataType::Float64, nullable))
+    }
+
+    /// Creates a fresh accumulator that uses this expression's divide-by-zero policy.
+    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        let accumulator = CorrelationAccumulator::try_new(self.null_on_divide_by_zero)?;
+        Ok(Box::new(accumulator))
+    }
+
+    /// Six nullable Float64 intermediate-state fields, listed in the same order in
+    /// which `CorrelationAccumulator::state` emits its values.
+    fn state_fields(&self) -> Result<Vec<Field>> {
+        const STATE_NAMES: [&str; 6] = ["count", "mean1", "mean2", "algo_const", "m2_1", "m2_2"];
+        Ok(STATE_NAMES
+            .iter()
+            .map(|&state| {
+                Field::new(format_state_name(&self.name, state), DataType::Float64, true)
+            })
+            .collect())
+    }
+
+    /// The two input expressions, `expr1` first.
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![Arc::clone(&self.expr1), Arc::clone(&self.expr2)]
+    }
+
+    /// The name given to this expression at construction.
+    fn name(&self) -> &str {
+        self.name.as_str()
+    }
+}
+
+impl PartialEq<dyn Any> for Correlation {
+    /// Equal only to another `Correlation` whose name, both input expressions and
+    /// divide-by-zero policy all match; any other type compares unequal.
+    fn eq(&self, other: &dyn Any) -> bool {
+        match down_cast_any_ref(other).downcast_ref::<Self>() {
+            Some(that) => {
+                self.name == that.name
+                    && self.expr1.eq(&that.expr1)
+                    && self.expr2.eq(&that.expr2)
+                    && self.null_on_divide_by_zero == that.null_on_divide_by_zero
+            }
+            None => false,
+        }
+    }
+}
+
+/// An accumulator to compute correlation.
+///
+/// Combines a covariance accumulator with one standard-deviation accumulator per
+/// input, all using population statistics. The final value is
+/// `covariance / (stddev1 * stddev2)`.
+#[derive(Debug)]
+pub struct CorrelationAccumulator {
+    /// Covariance of the two inputs; also supplies the count and means in `state`.
+    covar: CovarianceAccumulator,
+    /// Standard deviation of the first input.
+    stddev1: StddevAccumulator,
+    /// Standard deviation of the second input.
+    stddev2: StddevAccumulator,
+    /// Policy for undefined results; consulted in `evaluate`.
+    null_on_divide_by_zero: bool,
+}
+
+impl CorrelationAccumulator {
+    /// Creates an empty `CorrelationAccumulator`.
+    ///
+    /// All inner accumulators use population statistics, and both standard-deviation
+    /// accumulators share the given divide-by-zero policy.
+    pub fn try_new(null_on_divide_by_zero: bool) -> Result<Self> {
+        let covar = CovarianceAccumulator::try_new(StatsType::Population)?;
+        let stddev1 = StddevAccumulator::try_new(StatsType::Population, null_on_divide_by_zero)?;
+        let stddev2 = StddevAccumulator::try_new(StatsType::Population, null_on_divide_by_zero)?;
+        Ok(Self {
+            covar,
+            stddev1,
+            stddev2,
+            null_on_divide_by_zero,
+        })
+    }
+}
+
+/// Drops every row in which either of the two inputs is null.
+///
+/// Correlation is only defined over pairs where both values are present, so a row
+/// with a null on either side must not reach any inner accumulator. When neither
+/// input has nulls, the inputs are returned unchanged (cheap `Arc` clones).
+fn filter_null_pairs(values: &[ArrayRef]) -> Result<Vec<ArrayRef>> {
+    if values[0].null_count() == 0 && values[1].null_count() == 0 {
+        return Ok(values.to_vec());
+    }
+    let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
+    Ok(vec![filter(&values[0], &mask)?, filter(&values[1], &mask)?])
+}
+
+impl Accumulator for CorrelationAccumulator {
+    /// Emits the intermediate state as six Float64 values:
+    /// `[count, mean1, mean2, algo_const, m2_1, m2_2]`.
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        Ok(vec![
+            ScalarValue::from(self.covar.get_count()),
+            ScalarValue::from(self.covar.get_mean1()),
+            ScalarValue::from(self.covar.get_mean2()),
+            ScalarValue::from(self.covar.get_algo_const()),
+            ScalarValue::from(self.stddev1.get_m2()),
+            ScalarValue::from(self.stddev2.get_m2()),
+        ])
+    }
+
+    /// Adds a batch of `(expr1, expr2)` values, skipping rows where either is null.
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let values = filter_null_pairs(values)?;
+        if !values[0].is_empty() && !values[1].is_empty() {
+            self.covar.update_batch(&values)?;
+            self.stddev1.update_batch(&values[0..1])?;
+            self.stddev2.update_batch(&values[1..2])?;
+        }
+        Ok(())
+    }
+
+    /// Removes a batch previously added with `update_batch`.
+    ///
+    /// Applies the same null filtering and empty-batch guard as `update_batch`, so
+    /// exactly the rows that were added get retracted.
+    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let values = filter_null_pairs(values)?;
+        if !values[0].is_empty() && !values[1].is_empty() {
+            self.covar.retract_batch(&values)?;
+            self.stddev1.retract_batch(&values[0..1])?;
+            self.stddev2.retract_batch(&values[1..2])?;
+        }
+        Ok(())
+    }
+
+    /// Merges partial states laid out as produced by `state`.
+    ///
+    /// The covariance accumulator consumes `[count, mean1, mean2, algo_const]`, and
+    /// each standard-deviation accumulator consumes `[count, mean, m2]` for its input.
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        if states[0].is_empty() || states[1].is_empty() || states[2].is_empty() {
+            return Ok(());
+        }
+
+        let covar_states = [
+            Arc::clone(&states[0]),
+            Arc::clone(&states[1]),
+            Arc::clone(&states[2]),
+            Arc::clone(&states[3]),
+        ];
+        let stddev1_states = [
+            Arc::clone(&states[0]),
+            Arc::clone(&states[1]),
+            Arc::clone(&states[4]),
+        ];
+        let stddev2_states = [
+            Arc::clone(&states[0]),
+            Arc::clone(&states[2]),
+            Arc::clone(&states[5]),
+        ];
+
+        self.covar.merge_batch(&covar_states)?;
+        self.stddev1.merge_batch(&stddev1_states)?;
+        self.stddev2.merge_batch(&stddev2_states)?;
+        Ok(())
+    }
+
+    /// Returns `covariance / (stddev1 * stddev2)`.
+    ///
+    /// When that is undefined (a zero standard deviation, or a null component), the
+    /// result is NULL if `null_on_divide_by_zero` is set. Otherwise it is NaN when
+    /// exactly one row was accumulated, and NULL for any other count.
+    fn evaluate(&mut self) -> Result<ScalarValue> {
+        let covar = self.covar.evaluate()?;
+        let stddev1 = self.stddev1.evaluate()?;
+        let stddev2 = self.stddev2.evaluate()?;
+
+        let result = match (covar, stddev1, stddev2) {
+            (
+                ScalarValue::Float64(Some(c)),
+                ScalarValue::Float64(Some(s1)),
+                ScalarValue::Float64(Some(s2)),
+            ) if s1 != 0.0 && s2 != 0.0 => Some(c / (s1 * s2)),
+            _ if self.null_on_divide_by_zero => None,
+            _ if self.covar.get_count() == 1.0 => Some(f64::NAN),
+            _ => None,
+        };
+        Ok(ScalarValue::Float64(result))
+    }
+
+    /// Size of this accumulator. The inline sizes of the nested accumulators are
+    /// replaced by the sizes they report themselves. The terms are ordered so the
+    /// `usize` arithmetic never goes negative.
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self) - std::mem::size_of_val(&self.covar) + self.covar.size()
+            - std::mem::size_of_val(&self.stddev1)
+            + self.stddev1.size()
+            - std::mem::size_of_val(&self.stddev2)
+            + self.stddev2.size()
+    }
+}
diff --git a/core/src/execution/datafusion/expressions/mod.rs 
b/core/src/execution/datafusion/expressions/mod.rs
index 10cac169..9db4b65b 100644
--- a/core/src/execution/datafusion/expressions/mod.rs
+++ b/core/src/execution/datafusion/expressions/mod.rs
@@ -27,6 +27,7 @@ pub use normalize_nan::NormalizeNaNAndZero;
 pub mod avg;
 pub mod avg_decimal;
 pub mod bloom_filter_might_contain;
+pub mod correlation;
 pub mod covariance;
 pub mod stats;
 pub mod stddev;
diff --git a/core/src/execution/datafusion/planner.rs 
b/core/src/execution/datafusion/planner.rs
index 59818857..01d89238 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -67,6 +67,7 @@ use crate::{
                 bloom_filter_might_contain::BloomFilterMightContain,
                 cast::{Cast, EvalMode},
                 checkoverflow::CheckOverflow,
+                correlation::Correlation,
                 covariance::Covariance,
                 if_expr::IfExpr,
                 scalar_funcs::create_comet_physical_fun,
@@ -1310,6 +1311,18 @@ impl PhysicalPlanner {
                     ))),
                 }
             }
+            AggExprStruct::Correlation(expr) => {
+                let child1 = self.create_expr(expr.child1.as_ref().unwrap(), 
schema.clone())?;
+                let child2 = self.create_expr(expr.child2.as_ref().unwrap(), 
schema.clone())?;
+                let datatype = 
to_arrow_datatype(expr.datatype.as_ref().unwrap());
+                Ok(Arc::new(Correlation::new(
+                    child1,
+                    child2,
+                    "correlation",
+                    datatype,
+                    expr.null_on_divide_by_zero,
+                )))
+            }
         }
     }
 
diff --git a/core/src/execution/proto/expr.proto 
b/core/src/execution/proto/expr.proto
index ee3de865..be85e8a9 100644
--- a/core/src/execution/proto/expr.proto
+++ b/core/src/execution/proto/expr.proto
@@ -96,6 +96,7 @@ message AggExpr {
     CovPopulation covPopulation = 13;
     Variance variance = 14;
     Stddev stddev = 15;
+    Correlation correlation = 16;
   }
 }
 
@@ -186,6 +187,13 @@ message Stddev {
   StatisticsType stats_type = 4;
 }
 
+// Serialized form of Spark's `corr(child1, child2)` aggregate.
+message Correlation {
+  // First input expression.
+  Expr child1 = 1;
+  // Second input expression.
+  Expr child2 = 2;
+  // When true, undefined results are NULL; when false, a single-row group yields NaN.
+  bool null_on_divide_by_zero = 3;
+  // Result data type; the native `Correlation::new` only accepts Float64.
+  DataType datatype = 4;
+}
+
 message Literal {
   oneof value {
     bool bool_val = 1;
diff --git a/docs/source/user-guide/expressions.md 
b/docs/source/user-guide/expressions.md
index 38c86c72..521699d3 100644
--- a/docs/source/user-guide/expressions.md
+++ b/docs/source/user-guide/expressions.md
@@ -109,3 +109,4 @@ The following Spark expressions are currently available:
   - VarianceSamp
   - StddevPop
   - StddevSamp
+  - Corr
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala 
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index cf7c86a9..6333650d 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
-import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, CovPopulation, CovSample, 
Final, First, Last, Max, Min, Partial, StddevPop, StddevSamp, Sum, VariancePop, 
VarianceSamp}
+import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
Average, BitAndAgg, BitOrAgg, BitXorAgg, Corr, Count, CovPopulation, CovSample, 
Final, First, Last, Max, Min, Partial, StddevPop, StddevSamp, Sum, VariancePop, 
VarianceSamp}
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.optimizer.{BuildRight, 
NormalizeNaNAndZero}
 import org.apache.spark.sql.catalyst.plans._
@@ -547,6 +547,27 @@ object QueryPlanSerde extends Logging with 
ShimQueryPlanSerde with CometExprShim
           withInfo(aggExpr, child)
           None
         }
+      case corr @ Corr(child1, child2, nullOnDivideByZero) =>
+        val child1Expr = exprToProto(child1, inputs, binding)
+        val child2Expr = exprToProto(child2, inputs, binding)
+        val dataType = serializeDataType(corr.dataType)
+
+        if (child1Expr.isDefined && child2Expr.isDefined && 
dataType.isDefined) {
+          val corrBuilder = ExprOuterClass.Correlation.newBuilder()
+          corrBuilder.setChild1(child1Expr.get)
+          corrBuilder.setChild2(child2Expr.get)
+          corrBuilder.setNullOnDivideByZero(nullOnDivideByZero)
+          corrBuilder.setDatatype(dataType.get)
+
+          Some(
+            ExprOuterClass.AggExpr
+              .newBuilder()
+              .setCorrelation(corrBuilder)
+              .build())
+        } else {
+          withInfo(aggExpr, child1, child2)
+          None
+        }
       case fn =>
         val msg = s"unsupported Spark aggregate function: ${fn.prettyName}"
         emitWarning(msg)
diff --git 
a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala 
b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index 310a24ee..d36534ee 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -1212,6 +1212,157 @@ class CometAggregateSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
     }
   }
 
+  test("correlation") {
+    // Each case is (column type, rows to insert, whether Spark's answer must match exactly).
+    // Only the first data set is compared exactly; the others use a tolerance.
+    val cases = Seq(
+      ("double", "(1, 4, 1), (2, 5, 1), (3, 6, 2)", true),
+      ("double", "(1, 4, 3), (2, -5, 3), (3, 6, 1)", false),
+      ("double", "(1.1, 4.1, 2.3), (2, 5, 1.5), (3, 6, 2.3)", false),
+      (
+        "double",
+        "(1, 4, 1), (2, 5, 2), (3, 6, 3), (1.1, 4.4, 1), (2.2, 5.5, 2), (3.3, 6.6, 3)",
+        false),
+      ("int", "(1, 4, 1), (2, 5, 2), (3, 6, 3)", false),
+      ("int", "(1, 4, 2), (null, null, 2), (3, 6, 1), (3, 3, 1)", false),
+      ("int", "(1, 4, 1), (null, 5, 1), (2, 5, 2), (9, null, 2), (3, 6, 2)", false),
+      ("int", "(null, null, 1), (1, 2, 1), (null, null, 2)", false),
+      ("int", "(null, null, 1), (null, null, 1), (null, null, 2)", false))
+
+    // Creates a fresh table for one case and checks both a global and a grouped corr().
+    def checkCorrelation(colType: String, rows: String, exact: Boolean): Unit = {
+      val table = "test"
+      withTable(table) {
+        sql(s"create table $table(col1 $colType, col2 $colType, col3 $colType) using parquet")
+        sql(s"insert into $table values $rows")
+        val expectedNumOfCometAggregates = 2
+        Seq(
+          s"SELECT corr(col1, col2) FROM $table",
+          s"SELECT corr(col1, col2) FROM $table GROUP BY col3").foreach { query =>
+          if (exact) {
+            checkSparkAnswerAndNumOfAggregates(query, expectedNumOfCometAggregates)
+          } else {
+            checkSparkAnswerWithTolAndNumOfAggregates(query, expectedNumOfCometAggregates)
+          }
+        }
+      }
+    }
+
+    withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
+      Seq(true, false).foreach { cometColumnShuffleEnabled =>
+        withSQLConf(
+          CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> cometColumnShuffleEnabled.toString) {
+          Seq(true, false).foreach { dictionary =>
+            withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
+              // The legacy flag is the inverse of Spark's `nullOnDivideByZero`: when it is
+              // true, undefined correlations become NaN instead of NULL.
+              Seq(true, false).foreach { legacyStatisticalAggregate =>
+                withSQLConf(
+                  "spark.sql.legacy.statisticalAggregate" -> legacyStatisticalAggregate.toString) {
+                  cases.foreach { case (colType, rows, exact) =>
+                    checkCorrelation(colType, rows, exact)
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
   protected def checkSparkAnswerAndNumOfAggregates(query: String, 
numAggregates: Int): Unit = {
     val df = sql(query)
     checkSparkAnswer(df)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to