Repository: spark
Updated Branches:
  refs/heads/master d6fd9b376 -> 63eee86cc


[SPARK-9297] [SQL] Add covar_pop and covar_samp

JIRA: https://issues.apache.org/jira/browse/SPARK-9297

Add two aggregation functions: covar_pop and covar_samp.

Author: Liang-Chi Hsieh <vii...@gmail.com>
Author: Liang-Chi Hsieh <vii...@appier.com>

Closes #10029 from viirya/covar-funcs.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/63eee86c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/63eee86c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/63eee86c

Branch: refs/heads/master
Commit: 63eee86cc652c108ca7712c8c0a73db1ca89ae90
Parents: d6fd9b3
Author: Liang-Chi Hsieh <vii...@gmail.com>
Authored: Wed Jan 13 10:26:55 2016 -0800
Committer: Davies Liu <davies....@gmail.com>
Committed: Wed Jan 13 10:26:55 2016 -0800

----------------------------------------------------------------------
 .../catalyst/analysis/FunctionRegistry.scala    |   2 +
 .../expressions/aggregate/Covariance.scala      | 198 +++++++++++++++++++
 .../scala/org/apache/spark/sql/functions.scala  |  40 ++++
 .../hive/execution/AggregationQuerySuite.scala  |  32 +++
 4 files changed, 272 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/63eee86c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 5c2aa3c..d9009e3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -182,6 +182,8 @@ object FunctionRegistry {
     expression[Average]("avg"),
     expression[Corr]("corr"),
     expression[Count]("count"),
+    expression[CovPopulation]("covar_pop"),
+    expression[CovSample]("covar_samp"),
     expression[First]("first"),
     expression[First]("first_value"),
     expression[Last]("last"),

http://git-wip-us.apache.org/repos/asf/spark/blob/63eee86c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
new file mode 100644
index 0000000..f53b01b
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.types._
+
+/**
+ * Compute the covariance between two expressions.
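+ * When applied on empty data (i.e., count is zero), it returns NULL. The sample covariance
+ * additionally returns NULL when only a single non-null pair of values has been aggregated.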
+ */
+abstract class Covariance(left: Expression, right: Expression) extends 
ImperativeAggregate
+    with Serializable {
+  override def children: Seq[Expression] = Seq(left, right)
+
+  override def nullable: Boolean = true
+
+  override def dataType: DataType = DoubleType
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (left.dataType.isInstanceOf[DoubleType] && 
right.dataType.isInstanceOf[DoubleType]) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      TypeCheckResult.TypeCheckFailure(
+        s"covariance requires that both arguments are double type, " +
+          s"not (${left.dataType}, ${right.dataType}).")
+    }
+  }
+
+  override def aggBufferSchema: StructType = 
StructType.fromAttributes(aggBufferAttributes)
+
+  override def inputAggBufferAttributes: Seq[AttributeReference] = {
+    aggBufferAttributes.map(_.newInstance())
+  }
+
+  override val aggBufferAttributes: Seq[AttributeReference] = Seq(
+    AttributeReference("xAvg", DoubleType)(),
+    AttributeReference("yAvg", DoubleType)(),
+    AttributeReference("Ck", DoubleType)(),
+    AttributeReference("count", LongType)())
+
+  // Local cache of mutableAggBufferOffset(s) that will be used in update and merge
+  val xAvgOffset = mutableAggBufferOffset
+  val yAvgOffset = mutableAggBufferOffset + 1
+  val CkOffset = mutableAggBufferOffset + 2
+  val countOffset = mutableAggBufferOffset + 3
+
+  // Local cache of inputAggBufferOffset(s) that will be used in update and merge
+  val inputXAvgOffset = inputAggBufferOffset
+  val inputYAvgOffset = inputAggBufferOffset + 1
+  val inputCkOffset = inputAggBufferOffset + 2
+  val inputCountOffset = inputAggBufferOffset + 3
+
+  override def initialize(buffer: MutableRow): Unit = {
+    buffer.setDouble(xAvgOffset, 0.0)
+    buffer.setDouble(yAvgOffset, 0.0)
+    buffer.setDouble(CkOffset, 0.0)
+    buffer.setLong(countOffset, 0L)
+  }
+
+  override def update(buffer: MutableRow, input: InternalRow): Unit = {
+    val leftEval = left.eval(input)
+    val rightEval = right.eval(input)
+
+    if (leftEval != null && rightEval != null) {
+      val x = leftEval.asInstanceOf[Double]
+      val y = rightEval.asInstanceOf[Double]
+
+      var xAvg = buffer.getDouble(xAvgOffset)
+      var yAvg = buffer.getDouble(yAvgOffset)
+      var Ck = buffer.getDouble(CkOffset)
+      var count = buffer.getLong(countOffset)
+
+      val deltaX = x - xAvg
+      val deltaY = y - yAvg
+      count += 1
+      xAvg += deltaX / count
+      yAvg += deltaY / count
+      Ck += deltaX * (y - yAvg)
+
+      buffer.setDouble(xAvgOffset, xAvg)
+      buffer.setDouble(yAvgOffset, yAvg)
+      buffer.setDouble(CkOffset, Ck)
+      buffer.setLong(countOffset, count)
+    }
+  }
+
+  // Merge counters from other partitions. Formula can be found at:
+  // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
+  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+    val count2 = buffer2.getLong(inputCountOffset)
+
+    // We only merge the two buffers if at least one record has been aggregated in buffer2.
+    // We don't need to check the count in buffer1: if count2 is greater than zero, totalCount
+    // is greater than zero too, so the divisions below never divide by zero.
+    if (count2 > 0) {
+      var xAvg = buffer1.getDouble(xAvgOffset)
+      var yAvg = buffer1.getDouble(yAvgOffset)
+      var Ck = buffer1.getDouble(CkOffset)
+      var count = buffer1.getLong(countOffset)
+
+      val xAvg2 = buffer2.getDouble(inputXAvgOffset)
+      val yAvg2 = buffer2.getDouble(inputYAvgOffset)
+      val Ck2 = buffer2.getDouble(inputCkOffset)
+
+      val totalCount = count + count2
+      val deltaX = xAvg - xAvg2
+      val deltaY = yAvg - yAvg2
+      Ck += Ck2 + deltaX * deltaY * count / totalCount * count2
+      xAvg = (xAvg * count + xAvg2 * count2) / totalCount
+      yAvg = (yAvg * count + yAvg2 * count2) / totalCount
+      count = totalCount
+
+      buffer1.setDouble(xAvgOffset, xAvg)
+      buffer1.setDouble(yAvgOffset, yAvg)
+      buffer1.setDouble(CkOffset, Ck)
+      buffer1.setLong(countOffset, count)
+    }
+  }
+}
+
+case class CovSample(
+    left: Expression,
+    right: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
+  extends Covariance(left, right) {
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): 
ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): 
ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def eval(buffer: InternalRow): Any = {
+    val count = buffer.getLong(countOffset)
+    if (count > 1) {
+      val Ck = buffer.getDouble(CkOffset)
+      val cov = Ck / (count - 1)
+      if (cov.isNaN) {
+        null
+      } else {
+        cov
+      }
+    } else {
+      null
+    }
+  }
+}
+
+case class CovPopulation(
+    left: Expression,
+    right: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
+  extends Covariance(left, right) {
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): 
ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): 
ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def eval(buffer: InternalRow): Any = {
+    val count = buffer.getLong(countOffset)
+    if (count > 0) {
+      val Ck = buffer.getDouble(CkOffset)
+      if (Ck.isNaN) {
+        null
+      } else {
+        Ck / count
+      }
+    } else {
+      null
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/63eee86c/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 592d79d..71fea27 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -309,6 +309,46 @@ object functions extends LegacyFunctions {
     countDistinct(Column(columnName), columnNames.map(Column.apply) : _*)
 
   /**
+   * Aggregate function: returns the population covariance for two columns.
+   *
+   * @group agg_funcs
+   * @since 2.0.0
+   */
+  def covar_pop(column1: Column, column2: Column): Column = 
withAggregateFunction {
+    CovPopulation(column1.expr, column2.expr)
+  }
+
+  /**
+   * Aggregate function: returns the population covariance for two columns.
+   *
+   * @group agg_funcs
+   * @since 2.0.0
+   */
+  def covar_pop(columnName1: String, columnName2: String): Column = {
+    covar_pop(Column(columnName1), Column(columnName2))
+  }
+
+  /**
+   * Aggregate function: returns the sample covariance for two columns.
+   *
+   * @group agg_funcs
+   * @since 2.0.0
+   */
+  def covar_samp(column1: Column, column2: Column): Column = 
withAggregateFunction {
+    CovSample(column1.expr, column2.expr)
+  }
+
+  /**
+   * Aggregate function: returns the sample covariance for two columns.
+   *
+   * @group agg_funcs
+   * @since 2.0.0
+   */
+  def covar_samp(columnName1: String, columnName2: String): Column = {
+    covar_samp(Column(columnName1), Column(columnName2))
+  }
+
+  /**
    * Aggregate function: returns the first value in a group.
    *
    * @group agg_funcs

http://git-wip-us.apache.org/repos/asf/spark/blob/63eee86c/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 5550198..76b36aa 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -807,6 +807,38 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
     assert(math.abs(corr7 - 0.6633880657639323) < 1e-12)
   }
 
+  test("covariance: covar_pop and covar_samp") {
+    // non-trivial example. To reproduce in python, use:
+    // >>> import numpy as np
+    // >>> a = np.array(range(20))
+    // >>> b = np.array([x * x - 2 * x + 3.5 for x in range(20)])
+    // >>> np.cov(a, b, bias = 0)[0][1]
+    // 595.0
+    // >>> np.cov(a, b, bias = 1)[0][1]
+    // 565.25
+    val df = Seq.tabulate(20)(x => (1.0 * x, x * x - 2 * x + 3.5)).toDF("a", 
"b")
+    val cov_samp = df.groupBy().agg(covar_samp("a", 
"b")).collect()(0).getDouble(0)
+    assert(math.abs(cov_samp - 595.0) < 1e-12)
+
+    val cov_pop = df.groupBy().agg(covar_pop("a", 
"b")).collect()(0).getDouble(0)
+    assert(math.abs(cov_pop - 565.25) < 1e-12)
+
+    val df2 = Seq.tabulate(20)(x => (1 * x, x * x * x - 2)).toDF("a", "b")
+    val cov_samp2 = df2.groupBy().agg(covar_samp("a", 
"b")).collect()(0).getDouble(0)
+    assert(math.abs(cov_samp2 - 11564.0) < 1e-12)
+
+    val cov_pop2 = df2.groupBy().agg(covar_pop("a", 
"b")).collect()(0).getDouble(0)
+    assert(math.abs(cov_pop2 - 10985.799999999999) < 1e-12)
+
+    // one row test
+    val df3 = Seq.tabulate(1)(x => (1 * x, x * x * x - 2)).toDF("a", "b")
+    val cov_samp3 = df3.groupBy().agg(covar_samp("a", "b")).collect()(0).get(0)
+    assert(cov_samp3 == null)
+
+    val cov_pop3 = df3.groupBy().agg(covar_pop("a", 
"b")).collect()(0).getDouble(0)
+    assert(cov_pop3 == 0.0)
+  }
+
   test("no aggregation function (SPARK-11486)") {
     val df = sqlContext.range(20).selectExpr("id", "repeat(id, 1) as s")
       .groupBy("s").count()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to