Repository: spark Updated Branches: refs/heads/master e0fc9c7e5 -> 3bd6f5d2a
[SPARK-11490][SQL] variance should alias var_samp instead of var_pop. stddev is an alias for stddev_samp. variance should be consistent with stddev. Also took the chance to remove internal Stddev and Variance, and only kept StddevSamp/StddevPop and VarianceSamp/VariancePop. Author: Reynold Xin <r...@databricks.com> Closes #9449 from rxin/SPARK-11490. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3bd6f5d2 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3bd6f5d2 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3bd6f5d2 Branch: refs/heads/master Commit: 3bd6f5d2ae503468de0e218d51c331e249a862bb Parents: e0fc9c7 Author: Reynold Xin <r...@databricks.com> Authored: Wed Nov 4 09:34:52 2015 -0800 Committer: Yin Huai <yh...@databricks.com> Committed: Wed Nov 4 09:34:52 2015 -0800 ---------------------------------------------------------------------- .../catalyst/analysis/FunctionRegistry.scala | 4 +- .../catalyst/analysis/HiveTypeCoercion.scala | 2 - .../apache/spark/sql/catalyst/dsl/package.scala | 8 ---- .../expressions/aggregate/functions.scala | 29 ------------- .../catalyst/expressions/aggregate/utils.scala | 12 ------ .../sql/catalyst/expressions/aggregates.scala | 45 +++++--------------- .../scala/org/apache/spark/sql/DataFrame.scala | 2 +- .../org/apache/spark/sql/GroupedData.scala | 4 +- .../scala/org/apache/spark/sql/functions.scala | 9 ++-- .../spark/sql/DataFrameAggregateSuite.scala | 17 +++----- .../org/apache/spark/sql/SQLQuerySuite.scala | 14 +++--- 11 files changed, 32 insertions(+), 114 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/3bd6f5d2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 24c1a7b..d4334d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -187,11 +187,11 @@ object FunctionRegistry { expression[Max]("max"), expression[Average]("mean"), expression[Min]("min"), - expression[Stddev]("stddev"), + expression[StddevSamp]("stddev"), expression[StddevPop]("stddev_pop"), expression[StddevSamp]("stddev_samp"), expression[Sum]("sum"), - expression[Variance]("variance"), + expression[VarianceSamp]("variance"), expression[VariancePop]("var_pop"), expression[VarianceSamp]("var_samp"), expression[Skewness]("skewness"), http://git-wip-us.apache.org/repos/asf/spark/blob/3bd6f5d2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 3c67567..84e2b13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -297,10 +297,8 @@ object HiveTypeCoercion { case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) - case Stddev(e @ StringType()) => Stddev(Cast(e, DoubleType)) case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType)) case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType)) - case Variance(e @ StringType()) => Variance(Cast(e, DoubleType)) case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType)) case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType)) case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType)) http://git-wip-us.apache.org/repos/asf/spark/blob/3bd6f5d2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 787f67a..d8df664 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -159,14 +159,6 @@ package object dsl { def lower(e: Expression): Expression = Lower(e) def sqrt(e: Expression): Expression = Sqrt(e) def abs(e: Expression): Expression = Abs(e) - def stddev(e: Expression): Expression = Stddev(e) - def stddev_pop(e: Expression): Expression = StddevPop(e) - def stddev_samp(e: Expression): Expression = StddevSamp(e) - def variance(e: Expression): Expression = Variance(e) - def var_pop(e: Expression): Expression = VariancePop(e) - def var_samp(e: Expression): Expression = VarianceSamp(e) - def skewness(e: Expression): Expression = Skewness(e) - def kurtosis(e: Expression): Expression = Kurtosis(e) implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name } // TODO more implicit class for literal? http://git-wip-us.apache.org/repos/asf/spark/blob/3bd6f5d2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index f2c3eca..10dc5e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -328,13 +328,6 @@ case class Min(child: Expression) extends DeclarativeAggregate { override val evaluateExpression = min } -// Compute the sample standard deviation of a column -case class Stddev(child: Expression) extends StddevAgg(child) { - - override def isSample: Boolean = true - override def prettyName: String = "stddev" -} - // Compute the population standard deviation of a column case class StddevPop(child: Expression) extends StddevAgg(child) { @@ -1274,28 +1267,6 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w } } -case class Variance(child: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def prettyName: String = "variance" - - override protected val momentOrder = 2 - - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { - require(moments.length == momentOrder + 1, - s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") - - if (n == 0.0) Double.NaN else moments(2) / n - } -} - case class VarianceSamp(child: Expression, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { http://git-wip-us.apache.org/repos/asf/spark/blob/3bd6f5d2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala index 564174f..644c621 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala @@ -97,12 +97,6 @@ object Utils { mode = aggregate.Complete, isDistinct = false) - case expressions.Stddev(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Stddev(child), - mode = aggregate.Complete, - isDistinct = false) - case expressions.StddevPop(child) => aggregate.AggregateExpression2( aggregateFunction = aggregate.StddevPop(child), @@ -139,12 +133,6 @@ object Utils { mode = aggregate.Complete, isDistinct = false) - case expressions.Variance(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Variance(child), - mode = aggregate.Complete, - isDistinct = false) - case expressions.VariancePop(child) => aggregate.AggregateExpression2( aggregateFunction = aggregate.VariancePop(child), http://git-wip-us.apache.org/repos/asf/spark/blob/3bd6f5d2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index bf59660..89d63ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -785,13 +785,6 @@ abstract class StddevAgg1(child: Expression) extends UnaryExpression with Partia } -// Compute the sample standard deviation of a column -case class Stddev(child: Expression) extends StddevAgg1(child) { - - override def toString: String = s"STDDEV($child)" - override def isSample: Boolean = true -} - // Compute the population standard deviation of a column case class StddevPop(child: Expression) extends StddevAgg1(child) { @@ -807,20 +800,21 @@ case class StddevSamp(child: Expression) extends StddevAgg1(child) { } case class ComputePartialStd(child: Expression) extends UnaryExpression with AggregateExpression1 { - def this() = this(null) - - override def children: Seq[Expression] = child :: Nil - override def nullable: Boolean = false - override def dataType: DataType = ArrayType(DoubleType) - override def toString: String = s"computePartialStddev($child)" - override def newInstance(): ComputePartialStdFunction = - new ComputePartialStdFunction(child, this) + def this() = this(null) + + override def children: Seq[Expression] = child :: Nil + override def nullable: Boolean = false + override def dataType: DataType = ArrayType(DoubleType) + override def toString: String = s"computePartialStddev($child)" + override def newInstance(): ComputePartialStdFunction = + new ComputePartialStdFunction(child, this) } case class ComputePartialStdFunction ( expr: Expression, base: AggregateExpression1 -) extends AggregateFunction1 { + ) extends AggregateFunction1 { + def this() = this(null, null) // Required for serialization private val computeType = DoubleType @@ -1049,25 +1043,6 @@ case class Skewness(child: Expression) extends UnaryExpression with AggregateExp } // placeholder -case class Variance(child: Expression) extends UnaryExpression with AggregateExpression1 { - - override def newInstance(): AggregateFunction1 = { - throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " + - "please set spark.sql.useAggregate2 = true") - } - - override def nullable: Boolean = false - - override def dataType: DoubleType.type = DoubleType - - override def foldable: Boolean = false - - override def prettyName: String = "variance" - - override def toString: String = s"VARIANCE($child)" -} - -// placeholder case class VariancePop(child: Expression) extends UnaryExpression with AggregateExpression1 { override def newInstance(): AggregateFunction1 = { http://git-wip-us.apache.org/repos/asf/spark/blob/3bd6f5d2/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index fc0ab63..5e9c7ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1383,7 +1383,7 @@ class DataFrame private[sql]( val statistics = List[(String, Expression => Expression)]( "count" -> Count, "mean" -> Average, - "stddev" -> Stddev, + "stddev" -> StddevSamp, "min" -> Min, "max" -> Max) http://git-wip-us.apache.org/repos/asf/spark/blob/3bd6f5d2/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index c2b2a40..7cf66b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -96,10 +96,10 @@ class GroupedData protected[sql]( case "avg" | "average" | "mean" => Average case "max" => Max case "min" => Min - case "stddev" | "std" => Stddev + case "stddev" | "std" => StddevSamp case "stddev_pop" => StddevPop case "stddev_samp" => StddevSamp - case "variance" => Variance + case "variance" => VarianceSamp case "var_pop" => VariancePop case "var_samp" => VarianceSamp case "sum" => Sum http://git-wip-us.apache.org/repos/asf/spark/blob/3bd6f5d2/sql/core/src/main/scala/org/apache/spark/sql/functions.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c8c5283..c70c965 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -329,13 +329,12 @@ object functions { def skewness(e: Column): Column = Skewness(e.expr) /** - * Aggregate function: returns the unbiased sample standard deviation of - * the expression in a group. + * Aggregate function: alias for [[stddev_samp]]. * * @group agg_funcs * @since 1.6.0 */ - def stddev(e: Column): Column = Stddev(e.expr) + def stddev(e: Column): Column = StddevSamp(e.expr) /** * Aggregate function: returns the unbiased sample standard deviation of @@ -388,12 +387,12 @@ object functions { def sumDistinct(columnName: String): Column = sumDistinct(Column(columnName)) /** - * Aggregate function: returns the population variance of the values in a group. + * Aggregate function: alias for [[var_samp]]. * * @group agg_funcs * @since 1.6.0 */ - def variance(e: Column): Column = Variance(e.expr) + def variance(e: Column): Column = VarianceSamp(e.expr) /** * Aggregate function: returns the unbiased variance of the values in a group. http://git-wip-us.apache.org/repos/asf/spark/blob/3bd6f5d2/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 9b23977..b0e2ffa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -226,23 +226,18 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { val absTol = 1e-8 val sparkVariance = testData2.agg(variance('a)) - val expectedVariance = Row(4.0 / 6.0) - checkAggregatesWithTol(sparkVariance, expectedVariance, absTol) + checkAggregatesWithTol(sparkVariance, Row(4.0 / 5.0), absTol) val sparkVariancePop = testData2.agg(var_pop('a)) - checkAggregatesWithTol(sparkVariancePop, expectedVariance, absTol) + checkAggregatesWithTol(sparkVariancePop, Row(4.0 / 6.0), absTol) val sparkVarianceSamp = testData2.agg(var_samp('a)) - val expectedVarianceSamp = Row(4.0 / 5.0) - checkAggregatesWithTol(sparkVarianceSamp, expectedVarianceSamp, absTol) + checkAggregatesWithTol(sparkVarianceSamp, Row(4.0 / 5.0), absTol) val sparkSkewness = testData2.agg(skewness('a)) - val expectedSkewness = Row(0.0) - checkAggregatesWithTol(sparkSkewness, expectedSkewness, absTol) + checkAggregatesWithTol(sparkSkewness, Row(0.0), absTol) val sparkKurtosis = testData2.agg(kurtosis('a)) - val expectedKurtosis = Row(-1.5) - checkAggregatesWithTol(sparkKurtosis, expectedKurtosis, absTol) - + checkAggregatesWithTol(sparkKurtosis, Row(-1.5), absTol) } test("zero moments") { @@ -251,7 +246,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { checkAnswer( emptyTableData.agg(variance('a)), - Row(0.0)) + Row(Double.NaN)) checkAnswer( emptyTableData.agg(var_samp('a)), http://git-wip-us.apache.org/repos/asf/spark/blob/3bd6f5d2/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 6388a8b..5731a35 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -536,7 +536,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer( sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a)," + "AVG(a), VARIANCE(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"), - Row(0, -1.5, 1, 3, 2, 2.0 / 3.0, 1, 6, 3) + Row(0, -1.5, 1, 3, 2, 1.0, 1, 6, 3) ) } @@ -757,7 +757,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("variance") { val absTol = 1e-8 val sparkAnswer = sql("SELECT VARIANCE(a) FROM testData2") - val expectedAnswer = Row(4.0 / 6.0) + val expectedAnswer = Row(0.8) checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) } @@ -784,16 +784,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("stddev agg") { checkAnswer( - sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"), + sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"), (1 to 3).map(i => Row(i, math.sqrt(1.0 / 2.0), math.sqrt(1.0 / 4.0), math.sqrt(1.0 / 2.0)))) } test("variance agg") { val absTol = 1e-8 - val sparkAnswer = sql("SELECT a, variance(b), var_samp(b), var_pop(b)" + - "FROM testData2 GROUP BY a") - val expectedAnswer = (1 to 3).map(i => Row(i, 1.0 / 4.0, 1.0 / 2.0, 1.0 / 4.0)) - checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) + checkAggregatesWithTol( + sql("SELECT a, variance(b), var_samp(b), var_pop(b) FROM testData2 GROUP BY a"), + (1 to 3).map(i => Row(i, 1.0 / 2.0, 1.0 / 2.0, 1.0 / 4.0)), + absTol) } test("skewness and kurtosis agg") { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org