Repository: spark
Updated Branches:
  refs/heads/master f5a9526fe -> d292f7483


[SPARK-11420] Updating Stddev support via Imperative Aggregate

Switched stddev support from DeclarativeAggregate to ImperativeAggregate.

Author: JihongMa <linlin200...@gmail.com>

Closes #9380 from JihongMA/SPARK-11420.
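
For context: a DeclarativeAggregate describes an aggregate purely as Catalyst
expressions (initial values, update, merge, evaluate) that the planner
composes, while an ImperativeAggregate implements those steps as methods that
read and write a shared aggregation buffer, addressed through
mutableAggBufferOffset and inputAggBufferOffset. A minimal sketch of the
contrast, with simplified signatures that are not Spark's actual trait
definitions:

    // Illustrative only: simplified shapes of the two aggregate styles,
    // not Spark's actual trait definitions.
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{Expression, MutableRow}

    // Declarative: the whole aggregate is a set of Catalyst expressions
    // that the planner can compose and code-generate.
    trait SimpleDeclarativeAgg {
      def initialValues: Seq[Expression]      // expressions that seed the buffer
      def updateExpressions: Seq[Expression]  // applied once per input row
      def mergeExpressions: Seq[Expression]   // combine two partial buffers
      def evaluateExpression: Expression      // compute the final result
    }

    // Imperative: the aggregate mutates a shared buffer row directly,
    // locating its slots through explicit offsets.
    trait SimpleImperativeAgg {
      def mutableAggBufferOffset: Int         // start of this function's slots
      def inputAggBufferOffset: Int           // offset within incoming partial buffers
      def initialize(buffer: MutableRow): Unit
      def update(buffer: MutableRow, input: InternalRow): Unit
      def merge(buffer: MutableRow, inputBuffer: InternalRow): Unit
      def eval(buffer: InternalRow): Any
    }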


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d292f748
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d292f748
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d292f748

Branch: refs/heads/master
Commit: d292f74831de7e69c852ed26d9c15df85b4fb568
Parents: f5a9526
Author: JihongMa <linlin200...@gmail.com>
Authored: Thu Nov 12 13:47:34 2015 -0800
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Thu Nov 12 13:47:34 2015 -0800

----------------------------------------------------------------------
 R/pkg/inst/tests/test_sparkSQL.R                |   4 +-
 python/pyspark/sql/dataframe.py                 |   2 +-
 .../catalyst/analysis/HiveTypeCoercion.scala    |   6 +-
 .../expressions/aggregate/Kurtosis.scala        |   4 +-
 .../expressions/aggregate/Skewness.scala        |   4 +-
 .../catalyst/expressions/aggregate/Stddev.scala | 128 +++++--------------
 .../scala/org/apache/spark/sql/functions.scala  |   2 +-
 .../spark/sql/DataFrameAggregateSuite.scala     |   4 +-
 .../org/apache/spark/sql/DataFrameSuite.scala   |   2 +-
 .../org/apache/spark/sql/SQLQuerySuite.scala    |  11 +-
 10 files changed, 52 insertions(+), 115 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d292f748/R/pkg/inst/tests/test_sparkSQL.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 9e453a1..af024e6 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -1007,7 +1007,7 @@ test_that("group by, agg functions", {
   df3 <- agg(gd, age = "stddev")
   expect_is(df3, "DataFrame")
   df3_local <- collect(df3)
-  expect_equal(0, df3_local[df3_local$name == "Andy",][1, 2])
+  expect_true(is.nan(df3_local[df3_local$name == "Andy",][1, 2]))
 
   df4 <- agg(gd, sumAge = sum(df$age))
   expect_is(df4, "DataFrame")
@@ -1038,7 +1038,7 @@ test_that("group by, agg functions", {
   df7 <- agg(gd2, value = "stddev")
   df7_local <- collect(df7)
   expect_true(abs(df7_local[df7_local$name == "ID1",][1, 2] - 6.928203) < 1e-6)
-  expect_equal(0, df7_local[df7_local$name == "ID2",][1, 2])
+  expect_true(is.nan(df7_local[df7_local$name == "ID2",][1, 2]))
 
   mockLines3 <- c("{\"name\":\"Andy\", \"age\":30}",
                   "{\"name\":\"Andy\", \"age\":30}",

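The assertions above capture the semantic change: under the moment-based
implementation, the sample standard deviation of a group with a single
observation is NaN (the n - 1 divisor is zero) rather than 0. A hypothetical
spark-shell check of the same behavior from Scala (data and column names are
illustrative):

    // Hypothetical check: one row in a group => sample stddev is NaN, not 0.
    import org.apache.spark.sql.functions.stddev
    val df = sqlContext.createDataFrame(Seq(("Andy", 30))).toDF("name", "age")
    df.groupBy("name").agg(stddev(df("age"))).show()
    // the aggregated value for Andy now prints as NaN
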
http://git-wip-us.apache.org/repos/asf/spark/blob/d292f748/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 0dd75ba..ad6ad02 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -761,7 +761,7 @@ class DataFrame(object):
         +-------+------------------+-----+
         |  count|                 2|    2|
         |   mean|               3.5| null|
-        | stddev|2.1213203435596424| null|
+        | stddev|2.1213203435596424|  NaN|
         |    min|                 2|Alice|
         |    max|                 5|  Bob|
         +-------+------------------+-----+

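The describe() doctest above now reports NaN instead of null for the stddev of
the non-numeric column: casting the string values to double yields all nulls,
so the aggregate sees n = 0. A sketch of the same check from the Scala API
(names are illustrative):

    // Sketch: describe() surfaces the stddev of a string column as NaN,
    // since every value casts to null and the count of inputs is zero.
    val people = sqlContext.createDataFrame(Seq((2, "Alice"), (5, "Bob"))).toDF("age", "name")
    people.describe("age", "name").show()
    // the "stddev" row now shows NaN (not null) under "name"
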
http://git-wip-us.apache.org/repos/asf/spark/blob/d292f748/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index bf2bff0..92188ee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -297,8 +297,10 @@ object HiveTypeCoercion {
 
       case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
       case Average(e @ StringType()) => Average(Cast(e, DoubleType))
-      case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType))
-      case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType))
+      case StddevPop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
+        StddevPop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
+      case StddevSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
+        StddevSamp(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
      case VariancePop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
        VariancePop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
      case VarianceSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>

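Because StddevPop and StddevSamp now carry the two buffer offsets as
constructor parameters, the coercion rule has to destructure and rebuild them
with all three arguments, mirroring the existing VariancePop/VarianceSamp
cases. A simplified restatement of what the rule does for a string-typed child
(expr stands for any aggregate expression being coerced; mOff/iOff are just
shorter local names):

    // Simplified view: cast a string child to double, threading the
    // buffer offsets through unchanged.
    expr match {
      case StddevPop(e @ StringType(), mOff, iOff) =>
        StddevPop(Cast(e, DoubleType), mOff, iOff)
      case StddevSamp(e @ StringType(), mOff, iOff) =>
        StddevSamp(Cast(e, DoubleType), mOff, iOff)
      case other => other
    }
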
http://git-wip-us.apache.org/repos/asf/spark/blob/d292f748/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
index bae78d9..8fa3aac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
@@ -42,9 +42,11 @@ case class Kurtosis(child: Expression,
       s"$prettyName requires ${momentOrder + 1} central moments, received: 
${moments.length}")
     val m2 = moments(2)
     val m4 = moments(4)
+
     if (n == 0.0 || m2 == 0.0) {
       Double.NaN
-    } else {
+    }
+    else {
       n * m4 / (m2 * m2) - 3.0
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/d292f748/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
index c593074..e1c01a5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
@@ -41,9 +41,11 @@ case class Skewness(child: Expression,
       s"$prettyName requires ${momentOrder + 1} central moments, received: 
${moments.length}")
     val m2 = moments(2)
     val m3 = moments(3)
+
     if (n == 0.0 || m2 == 0.0) {
       Double.NaN
-    } else {
+    }
+    else {
       math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2)
     }
   }

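Both moment-based statistics guard the same degenerate cases: when n is 0 or
the second central moment m2 is 0 (a constant column), the formula would
divide by zero, so the result is defined as NaN. A worked example for the
constant input {3, 3, 3}, assuming the moments are un-normalized sums of
powers of deviations:

    // Worked example of the m2 == 0 guard for a constant column {3, 3, 3}.
    val xs = Seq(3.0, 3.0, 3.0)
    val n = xs.length.toDouble
    val mean = xs.sum / n                                  // 3.0
    val m2 = xs.map(x => math.pow(x - mean, 2)).sum        // 0.0
    val m3 = xs.map(x => math.pow(x - mean, 3)).sum        // 0.0
    val skewness =
      if (n == 0.0 || m2 == 0.0) Double.NaN                // guard avoids 0/0
      else math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2)
    // skewness == NaN
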
http://git-wip-us.apache.org/repos/asf/spark/blob/d292f748/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
index 2748009..05dd5e3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
@@ -17,117 +17,55 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.TypeUtils
-import org.apache.spark.sql.types._
 
+case class StddevSamp(child: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
+  extends CentralMomentAgg(child) {
 
-// Compute the population standard deviation of a column
-case class StddevPop(child: Expression) extends StddevAgg(child) {
-  override def isSample: Boolean = false
-  override def prettyName: String = "stddev_pop"
-}
-
-
-// Compute the sample standard deviation of a column
-case class StddevSamp(child: Expression) extends StddevAgg(child) {
-  override def isSample: Boolean = true
-  override def prettyName: String = "stddev_samp"
-}
-
-
-// Compute standard deviation based on online algorithm specified here:
-// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
-abstract class StddevAgg(child: Expression) extends DeclarativeAggregate {
+  def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
 
-  def isSample: Boolean
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
 
-  override def children: Seq[Expression] = child :: Nil
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
 
-  override def nullable: Boolean = true
-
-  override def dataType: DataType = resultType
-
-  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+  override def prettyName: String = "stddev_samp"
 
-  override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForNumericExpr(child.dataType, "function stddev")
+  override protected val momentOrder = 2
 
-  private lazy val resultType = DoubleType
+  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+    require(moments.length == momentOrder + 1,
+      s"$prettyName requires ${momentOrder + 1} central moment, received: 
${moments.length}")
 
-  private lazy val count = AttributeReference("count", resultType)()
-  private lazy val avg = AttributeReference("avg", resultType)()
-  private lazy val mk = AttributeReference("mk", resultType)()
+    if (n == 0.0 || n == 1.0) Double.NaN else math.sqrt(moments(2) / (n - 1.0))
+  }
+}
 
-  override lazy val aggBufferAttributes = count :: avg :: mk :: Nil
+case class StddevPop(
+    child: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
+  extends CentralMomentAgg(child) {
 
-  override lazy val initialValues: Seq[Expression] = Seq(
-    /* count = */ Cast(Literal(0), resultType),
-    /* avg = */ Cast(Literal(0), resultType),
-    /* mk = */ Cast(Literal(0), resultType)
-  )
+  def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
 
-  override lazy val updateExpressions: Seq[Expression] = {
-    val value = Cast(child, resultType)
-    val newCount = count + Cast(Literal(1), resultType)
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
 
-    // update average
-    // avg = avg + (value - avg)/count
-    val newAvg = avg + (value - avg) / newCount
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
 
-    // update sum ofference from mean
-    // Mk = Mk + (value - preAvg) * (value - updatedAvg)
-    val newMk = mk + (value - avg) * (value - newAvg)
+  override def prettyName: String = "stddev_pop"
 
-    Seq(
-      /* count = */ If(IsNull(child), count, newCount),
-      /* avg = */ If(IsNull(child), avg, newAvg),
-      /* mk = */ If(IsNull(child), mk, newMk)
-    )
-  }
+  override protected val momentOrder = 2
 
-  override lazy val mergeExpressions: Seq[Expression] = {
-
-    // count merge
-    val newCount = count.left + count.right
-
-    // average merge
-    val newAvg = ((avg.left * count.left) + (avg.right * count.right)) / newCount
-
-    // update sum of square differences
-    val newMk = {
-      val avgDelta = avg.right - avg.left
-      val mkDelta = (avgDelta * avgDelta) * (count.left * count.right) / newCount
-      mk.left + mk.right + mkDelta
-    }
-
-    Seq(
-      /* count = */ If(IsNull(count.left), count.right,
-                       If(IsNull(count.right), count.left, newCount)),
-      /* avg = */ If(IsNull(avg.left), avg.right,
-                     If(IsNull(avg.right), avg.left, newAvg)),
-      /* mk = */ If(IsNull(mk.left), mk.right,
-                    If(IsNull(mk.right), mk.left, newMk))
-    )
-  }
+  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
+    require(moments.length == momentOrder + 1,
+      s"$prettyName requires ${momentOrder + 1} central moment, received: 
${moments.length}")
 
-  override lazy val evaluateExpression: Expression = {
-    // when count == 0, return null
-    // when count == 1, return 0
-    // when count >1
-    // stddev_samp = sqrt (mk/(count -1))
-    // stddev_pop = sqrt (mk/count)
-    val varCol =
-      if (isSample) {
-        mk / Cast(count - Cast(Literal(1), resultType), resultType)
-      } else {
-        mk / count
-      }
-
-    If(EqualTo(count, Cast(Literal(0), resultType)), Cast(Literal(null), resultType),
-      If(EqualTo(count, Cast(Literal(1), resultType)), Cast(Literal(0), resultType),
-        Cast(Sqrt(varCol), resultType)))
+    if (n == 0.0) Double.NaN else math.sqrt(moments(2) / n)
   }
 }

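After this rewrite, StddevSamp and StddevPop inherit the shared buffer
handling (count, mean, and higher central moments) from CentralMomentAgg and
only supply the final statistic: sqrt(M2 / (n - 1)) for the sample estimator
(NaN when n is 0 or 1) and sqrt(M2 / n) for the population estimator (NaN when
n is 0). A worked check of the two divisors on xs = {2.0, 4.0}, assuming
moments(2) holds the un-normalized second central moment, i.e. the sum of
squared deviations:

    // Worked example: sample vs. population stddev for xs = {2.0, 4.0}.
    val xs = Seq(2.0, 4.0)
    val n = xs.length.toDouble                             // 2.0
    val mean = xs.sum / n                                  // 3.0
    val m2 = xs.map(x => math.pow(x - mean, 2)).sum        // 2.0
    val stddevSamp = if (n == 0.0 || n == 1.0) Double.NaN
                     else math.sqrt(m2 / (n - 1.0))        // ~1.4142
    val stddevPop = if (n == 0.0) Double.NaN
                    else math.sqrt(m2 / n)                 // 1.0
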
http://git-wip-us.apache.org/repos/asf/spark/blob/d292f748/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index b6330e2..53cc6e0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -397,7 +397,7 @@ object functions extends LegacyFunctions {
   def stddev(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) }
 
   /**
-   * Aggregate function: returns the unbiased sample standard deviation of
+   * Aggregate function: returns the sample standard deviation of
    * the expression in a group.
    *
    * @group agg_funcs

http://git-wip-us.apache.org/repos/asf/spark/blob/d292f748/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index eb1ee26..432e8d1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -195,7 +195,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
   }
 
   test("stddev") {
-    val testData2ADev = math.sqrt(4 / 5.0)
+    val testData2ADev = math.sqrt(4.0 / 5.0)
     checkAnswer(
       testData2.agg(stddev('a), stddev_pop('a), stddev_samp('a)),
       Row(testData2ADev, math.sqrt(4 / 6.0), testData2ADev))
@@ -205,7 +205,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
     val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
     checkAnswer(
     emptyTableData.agg(stddev('a), stddev_pop('a), stddev_samp('a)),
-    Row(null, null, null))
+    Row(Double.NaN, Double.NaN, Double.NaN))
   }
 
   test("zero sum") {

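For reference, the expected constants follow from the testData2 fixture,
whose column a holds {1, 1, 2, 2, 3, 3}: the mean is 2 and the sum of squared
deviations is 4, giving stddev_samp = sqrt(4/5) and stddev_pop = sqrt(4/6). A
quick arithmetic check:

    // Quick check of the expected constants for a = {1, 1, 2, 2, 3, 3}.
    val a = Seq(1.0, 1.0, 2.0, 2.0, 3.0, 3.0)
    val mean = a.sum / a.length                            // 2.0
    val m2 = a.map(x => math.pow(x - mean, 2)).sum         // 4.0
    math.sqrt(m2 / (a.length - 1))                         // stddev_samp ~ 0.8944
    math.sqrt(m2 / a.length)                               // stddev_pop  ~ 0.8165
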
http://git-wip-us.apache.org/repos/asf/spark/blob/d292f748/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index e4f23fe..35cdab5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -459,7 +459,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val emptyDescribeResult = Seq(
       Row("count", "0", "0"),
       Row("mean", null, null),
-      Row("stddev", null, null),
+      Row("stddev", "NaN", "NaN"),
       Row("min", null, null),
       Row("max", null, null))
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d292f748/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 52a561d..167aea8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -314,13 +314,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       testCodeGen(
         "SELECT min(key) FROM testData3x",
         Row(1) :: Nil)
-      // STDDEV
-      testCodeGen(
-        "SELECT a, stddev(b), stddev_pop(b) FROM testData2 GROUP BY a",
-        (1 to 3).map(i => Row(i, math.sqrt(0.5), math.sqrt(0.25))))
-      testCodeGen(
-        "SELECT stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2",
-        Row(math.sqrt(1.5 / 5), math.sqrt(1.5 / 6), math.sqrt(1.5 / 5)) :: Nil)
       // Some combinations.
       testCodeGen(
         """
@@ -341,8 +334,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
         Row(100, 1, 50.5, 300, 100) :: Nil)
       // Aggregate with Code generation handling all null values
       testCodeGen(
-        "SELECT  sum('a'), avg('a'), stddev('a'), count(null) FROM testData",
-        Row(null, null, null, 0) :: Nil)
+        "SELECT  sum('a'), avg('a'), count(null) FROM testData",
+        Row(null, null, 0) :: Nil)
     } finally {
       sqlContext.dropTempTable("testData3x")
     }

