[FLINK-4604] [table] Add support for standard deviation/variance

This closes #3260.
Old PR: This closes #2762. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/0af57fc1 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/0af57fc1 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/0af57fc1 Branch: refs/heads/master Commit: 0af57fc1800d7430843a1e14bb70168bbd750389 Parents: 6d0c4c3 Author: Anton Mushin <anton_mus...@epam.com> Authored: Fri Feb 3 14:06:49 2017 +0400 Committer: twalthr <twal...@apache.org> Committed: Wed May 3 14:43:06 2017 +0200 ---------------------------------------------------------------------- docs/dev/table_api.md | 47 ++++- .../flink/table/api/scala/expressionDsl.scala | 28 +++ .../table/expressions/ExpressionParser.scala | 43 +++- .../flink/table/expressions/aggregations.scala | 65 ++++++ .../aggfunctions/Sum0AggFunction.scala | 91 +++++++++ .../flink/table/plan/rules/FlinkRuleSets.scala | 3 + .../rules/dataSet/DataSetAggregateRule.scala | 8 +- .../DataSetAggregateWithNullValuesRule.scala | 9 +- .../flink/table/validate/FunctionCatalog.scala | 10 + .../table/api/java/batch/sql/SqlITCase.java | 53 ++++- .../scala/batch/sql/AggregationsITCase.scala | 200 +++++++++++++++++++ .../scala/batch/table/AggregationsITCase.scala | 74 +++++++ .../batch/utils/TableProgramsTestBase.scala | 5 + 13 files changed, 625 insertions(+), 11 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/0af57fc1/docs/dev/table_api.md ---------------------------------------------------------------------- diff --git a/docs/dev/table_api.md b/docs/dev/table_api.md index 2b777c6..ec32e9b 100644 --- a/docs/dev/table_api.md +++ b/docs/dev/table_api.md @@ -1066,7 +1066,7 @@ dataType = "BYTE" | "SHORT" | "INT" | "LONG" | "FLOAT" | "DOUBLE" | "BOOLEAN" | as = composite , ".as(" , fieldReference , ")" ; -aggregation = composite , ( ".sum" | ".min" | ".max" | ".count" | ".avg" | ".start" | ".end" ) , [ "()" ] ; 
+aggregation = composite , ( ".sum" | ".min" | ".max" | ".count" | ".avg" | ".start" | ".end" | ".stddev_pop" | ".stddev_samp" | ".var_pop" | ".var_samp" ) , [ "()" ] ; if = composite , ".?(" , expression , "," , expression , ")" ; @@ -5233,7 +5233,7 @@ AVG(numeric) <p>Returns the average (arithmetic mean) of <i>numeric</i> across all input values.</p> </td> </tr> - + <tr> <td> {% highlight text %} @@ -5376,6 +5376,49 @@ ELEMENT(ARRAY) <p>Returns the sole element of an array with a single element. Returns <code>null</code> if the array is empty. Throws an exception if the array has more than one element.</p> </td> </tr> +<tr> + <td> + {% highlight text %} +STDDEV_POP(value) +{% endhighlight %} + </td> + <td> + <p>Returns the standard deviation of numeric <i>value</i></p> + </td> + </tr> + +<tr> + <td> + {% highlight text %} +STDDEV_SAMP(value) +{% endhighlight %} + </td> + <td> + <p>Returns the sample standard deviation of numeric <i>value</i></p> + </td> + </tr> + + <tr> + <td> + {% highlight text %} +VAR_POP(value) +{% endhighlight %} + </td> + <td> + <p>Returns the variance of numeric <i>value</i></p> + </td> + </tr> + + <tr> + <td> + {% highlight text %} +VAR_SAMP (value) +{% endhighlight %} + </td> + <td> + <p>Returns the sample variance of numeric <i>value</i></p> + </td> + </tr> </tbody> </table> http://git-wip-us.apache.org/repos/asf/flink/blob/0af57fc1/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala index a512098..b1d16b3 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala @@ 
-165,6 +165,12 @@ trait ImplicitExpressionOperations { def sum = Sum(expr) /** + * Returns the sum of the values which go into it like [[Sum]]. + * It differs in that when no non null values are applied zero is returned instead of null. + */ + def sum0 = Sum0(expr) + + /** * Returns the minimum value of field across all input values. */ def min = Min(expr) @@ -185,6 +191,28 @@ trait ImplicitExpressionOperations { def avg = Avg(expr) /** + * Returns the population standard deviation of an expression. + * (the square root of [[VarPop]]) + */ + def stddev_pop = StddevPop(expr) + + /** + * Returns the sample standard deviation of an expression. + * (the square root of [[VarSamp]]). + */ + def stddev_samp = StddevSamp(expr) + + /** + * Returns the population standard variance of an expression. + */ + def var_pop = VarPop(expr) + + /** + * Returns the sample variance of a given expression. + */ + def var_samp = VarSamp(expr) + + /** * Converts a value to a given type. * * e.g. "42".cast(Types.INT) leads to 42. 
http://git-wip-us.apache.org/repos/asf/flink/blob/0af57fc1/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala index 113b85a..1356416 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala @@ -57,6 +57,11 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { lazy val SUM: Keyword = Keyword("sum") lazy val START: Keyword = Keyword("start") lazy val END: Keyword = Keyword("end") + lazy val SUM0: Keyword = Keyword("sum0") + lazy val STDDEV_POP: Keyword = Keyword("stddev_pop") + lazy val STDDEV_SAMP: Keyword = Keyword("stddev_samp") + lazy val VAR_POP: Keyword = Keyword("var_pop") + lazy val VAR_SAMP: Keyword = Keyword("var_samp") lazy val CAST: Keyword = Keyword("cast") lazy val NULL: Keyword = Keyword("Null") lazy val IF: Keyword = Keyword("?") @@ -95,7 +100,8 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { lazy val ASIN: Keyword = Keyword("asin") def functionIdent: ExpressionParser.Parser[String] = - not(ARRAY) ~ not(AS) ~ not(COUNT) ~ not(AVG) ~ not(MIN) ~ not(MAX) ~ + not(ARRAY) ~ not(AS) ~ not(COUNT) ~ not(AVG) ~ not(MIN) ~ not(MAX) ~ not(STDDEV_POP) ~ + not(STDDEV_SAMP) ~ not(VAR_SAMP) ~ not(VAR_POP) ~ not(SUM) ~ not(START) ~ not(END)~ not(CAST) ~ not(NULL) ~ not(IF) ~ not(CURRENT_ROW) ~ not(UNBOUNDED_ROW) ~ not(CURRENT_RANGE) ~ not(UNBOUNDED_RANGE) ~> super.ident @@ -205,6 +211,9 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { lazy val suffixSum: PackratParser[Expression] = composite <~ "." 
~ SUM ~ opt("()") ^^ { e => Sum(e) } + lazy val suffixSum0: PackratParser[Expression] = + composite <~ "." ~ SUM0 ~ opt("()") ^^ { e => Sum0(e) } + lazy val suffixMin: PackratParser[Expression] = composite <~ "." ~ MIN ~ opt("()") ^^ { e => Min(e) } @@ -223,6 +232,17 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { lazy val suffixEnd: PackratParser[Expression] = composite <~ "." ~ END ~ opt("()") ^^ { e => WindowEnd(e) } + lazy val suffixStddevPop: PackratParser[Expression] = + composite <~ "." ~ STDDEV_POP ~ opt("()") ^^ { e => StddevPop(e) } + + lazy val suffixStddevSamp: PackratParser[Expression] = + composite <~ "." ~ STDDEV_SAMP ~ opt("()") ^^ { e => StddevSamp(e) } + + lazy val suffixVarSamp: PackratParser[Expression] = + composite <~ "." ~ VAR_SAMP ~ opt("()") ^^ { e => VarSamp(e) } + + lazy val suffixVarPop: PackratParser[Expression] = + composite <~ "." ~ VAR_POP ~ opt("()") ^^ { e => VarPop(e) } lazy val suffixCast: PackratParser[Expression] = composite ~ "." ~ CAST ~ "(" ~ dataType ~ ")" ^^ { case e ~ _ ~ _ ~ _ ~ dt ~ _ => Cast(e, dt) @@ -320,8 +340,9 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { composite <~ "." 
~ ASIN ~ opt("()") ^^ { e => Asin(e) } lazy val suffixed: PackratParser[Expression] = - suffixTimeInterval | suffixRowInterval | suffixStart | suffixEnd | suffixAgg | + suffixTimeInterval | suffixRowInterval | suffixStart | suffixSum0 | suffixEnd | suffixAgg | suffixCast | suffixAs | suffixTrim | suffixTrimWithoutArgs | suffixIf | suffixAsc | + suffixStddevPop | suffixStddevSamp | suffixVarPop | suffixVarSamp | suffixDesc | suffixToDate | suffixToTimestamp | suffixToTime | suffixExtract | suffixFloor | suffixCeil | suffixGet | suffixFlattening | suffixAsin | suffixFunctionCall | suffixFunctionCallOneArg // function call must always be at the end @@ -334,6 +355,9 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { lazy val prefixSum: PackratParser[Expression] = SUM ~ "(" ~> expression <~ ")" ^^ { e => Sum(e) } + lazy val prefixSum0: PackratParser[Expression] = + SUM0 ~ "(" ~> expression <~ ")" ^^ { e => Sum0(e) } + lazy val prefixMin: PackratParser[Expression] = MIN ~ "(" ~> expression <~ ")" ^^ { e => Min(e) } @@ -352,6 +376,17 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { lazy val prefixEnd: PackratParser[Expression] = END ~ "(" ~> expression <~ ")" ^^ { e => WindowEnd(e) } + lazy val prefixStddevPop: PackratParser[Expression] = + STDDEV_POP ~ "(" ~> expression <~ ")" ^^ { e => StddevPop(e) } + + lazy val prefixStddevSamp: PackratParser[Expression] = + STDDEV_SAMP ~ "(" ~> expression <~ ")" ^^ { e => StddevSamp(e) } + + lazy val prefixVarSamp: PackratParser[Expression] = + VAR_SAMP ~ "(" ~> expression <~ ")" ^^ { e => VarSamp(e) } + + lazy val prefixVarPop: PackratParser[Expression] = + VAR_POP ~ "(" ~> expression <~ ")" ^^ { e => VarPop(e) } lazy val prefixCast: PackratParser[Expression] = CAST ~ "(" ~ expression ~ "," ~ dataType ~ ")" ^^ { case _ ~ _ ~ e ~ _ ~ dt ~ _ => Cast(e, dt) @@ -411,7 +446,9 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { ASIN ~ "(" ~> composite <~ ")" ^^ { e => 
Asin(e) } lazy val prefixed: PackratParser[Expression] = - prefixArray | prefixAgg | prefixStart | prefixEnd | prefixCast | prefixAs | prefixTrim | + prefixArray | prefixSum | prefixSum0 | prefixAgg | prefixStart | prefixEnd | + prefixCast | prefixAs | prefixTrim | prefixStddevPop | prefixStddevSamp | prefixVarSamp | + prefixVarPop | prefixTrimWithoutArgs | prefixIf | prefixExtract | prefixFloor | prefixCeil | prefixGet | prefixFlattening | prefixAsin | prefixFunctionCall | prefixFunctionCallOneArg // function call must always be at the end http://git-wip-us.apache.org/repos/asf/flink/blob/0af57fc1/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala index 4ef5209..82f4428 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala @@ -66,6 +66,19 @@ case class Sum(child: Expression) extends Aggregation { } } +case class Sum0(child: Expression) extends Aggregation { + override def toString = s"sum0($child)" + + override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { + relBuilder.aggregateCall(SqlStdOperatorTable.SUM0, false, null, name, child.toRexNode) + } + + override private[flink] def resultType = child.resultType + + override private[flink] def validateInput = + TypeCheckUtils.assertNumericExpr(child.resultType, "sum0") +} + case class Min(child: Expression) extends Aggregation { override def toString = s"min($child)" @@ -130,3 +143,55 @@ case class Avg(child: Expression) extends Aggregation { new SqlAvgAggFunction(AVG) } } + +case class 
StddevPop(child: Expression) extends Aggregation { + override def toString = s"stddev_pop($child)" + + override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { + relBuilder.aggregateCall(SqlStdOperatorTable.STDDEV_POP, false, null, name, child.toRexNode) + } + + override private[flink] def resultType = child.resultType + + override private[flink] def validateInput = + TypeCheckUtils.assertNumericExpr(child.resultType, "stddev_pop") +} + +case class StddevSamp(child: Expression) extends Aggregation { + override def toString = s"stddev_samp($child)" + + override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { + relBuilder.aggregateCall(SqlStdOperatorTable.STDDEV_SAMP, false, null, name, child.toRexNode) + } + + override private[flink] def resultType = child.resultType + + override private[flink] def validateInput = + TypeCheckUtils.assertNumericExpr(child.resultType, "stddev_samp") +} + +case class VarPop(child: Expression) extends Aggregation { + override def toString = s"var_pop($child)" + + override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { + relBuilder.aggregateCall(SqlStdOperatorTable.VAR_POP, false, null, name, child.toRexNode) + } + + override private[flink] def resultType = child.resultType + + override private[flink] def validateInput = + TypeCheckUtils.assertNumericExpr(child.resultType, "var_pop") +} + +case class VarSamp(child: Expression) extends Aggregation { + override def toString = s"var_samp($child)" + + override private[flink] def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = { + relBuilder.aggregateCall(SqlStdOperatorTable.VAR_SAMP, false, null, name, child.toRexNode) + } + + override private[flink] def resultType = child.resultType + + override private[flink] def validateInput = + TypeCheckUtils.assertNumericExpr(child.resultType, "var_samp") +} 
http://git-wip-us.apache.org/repos/asf/flink/blob/0af57fc1/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/Sum0AggFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/Sum0AggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/Sum0AggFunction.scala new file mode 100644 index 0000000..6a24fbe --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/Sum0AggFunction.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.table.functions.aggfunctions + +import java.math.BigDecimal + +import org.apache.flink.api.common.typeinfo.BasicTypeInfo +import org.apache.flink.table.functions.Accumulator + +abstract class Sum0AggFunction[T: Numeric] extends SumAggFunction[T] { + + override def getValue(accumulator: Accumulator): T = { + val a = accumulator.asInstanceOf[SumAccumulator[T]] + if (a.f1) { + a.f0 + } else { + 0.asInstanceOf[T] + } + } +} + +/** + * Built-in Byte Sum0 aggregate function + */ +class ByteSum0AggFunction extends Sum0AggFunction[Byte] { + override def getValueTypeInfo = BasicTypeInfo.BYTE_TYPE_INFO +} + +/** + * Built-in Short Sum0 aggregate function + */ +class ShortSum0AggFunction extends Sum0AggFunction[Short] { + override def getValueTypeInfo = BasicTypeInfo.SHORT_TYPE_INFO +} + +/** + * Built-in Int Sum0 aggregate function + */ +class IntSum0AggFunction extends Sum0AggFunction[Int] { + override def getValueTypeInfo = BasicTypeInfo.INT_TYPE_INFO +} + +/** + * Built-in Long Sum0 aggregate function + */ +class LongSum0AggFunction extends Sum0AggFunction[Long] { + override def getValueTypeInfo = BasicTypeInfo.LONG_TYPE_INFO +} + +/** + * Built-in Float Sum0 aggregate function + */ +class FloatSum0AggFunction extends Sum0AggFunction[Float] { + override def getValueTypeInfo = BasicTypeInfo.FLOAT_TYPE_INFO +} + +/** + * Built-in Double Sum0 aggregate function + */ +class DoubleSum0AggFunction extends Sum0AggFunction[Double] { + override def getValueTypeInfo = BasicTypeInfo.DOUBLE_TYPE_INFO +} + +/** + * Built-in Big Decimal Sum0 aggregate function + */ +class DecimalSum0AggFunction extends DecimalSumAggFunction { + + override def getValue(accumulator: Accumulator): BigDecimal = { + if (!accumulator.asInstanceOf[DecimalSumAccumulator].f1) { + 0.asInstanceOf[BigDecimal] + } else { + accumulator.asInstanceOf[DecimalSumAccumulator].f0 + } + } +} 
http://git-wip-us.apache.org/repos/asf/flink/blob/0af57fc1/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala index 0bee4e5..6ebbae4 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala @@ -74,6 +74,9 @@ object FlinkRuleSets { // expand distinct aggregate to normal aggregate with groupby AggregateExpandDistinctAggregatesRule.JOIN, + //aggregate reduce rule (deviation/variance functions) + AggregateReduceFunctionsRule.INSTANCE, + // remove unnecessary sort rule SortRemoveRule.INSTANCE, http://git-wip-us.apache.org/repos/asf/flink/blob/0af57fc1/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala index b4d5bc9..faaeb97 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala @@ -21,6 +21,7 @@ package org.apache.flink.table.plan.rules.dataSet import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelTraitSet} import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.convert.ConverterRule +import org.apache.calcite.sql.SqlKind import 
org.apache.flink.table.plan.nodes.FlinkConventions import org.apache.flink.table.plan.nodes.dataset.{DataSetAggregate, DataSetUnion} import org.apache.flink.table.plan.nodes.logical.FlinkLogicalAggregate @@ -54,7 +55,12 @@ class DataSetAggregateRule // check if we have distinct aggregates val distinctAggs = agg.getAggCallList.exists(_.isDistinct) - !distinctAggs + val supported = agg.getAggCallList.map(_.getAggregation.getKind).forall { + case SqlKind.STDDEV_POP | SqlKind.STDDEV_SAMP | SqlKind.VAR_POP | SqlKind.VAR_SAMP => false + case _ => true + } + + !distinctAggs && supported } override def convert(rel: RelNode): RelNode = { http://git-wip-us.apache.org/repos/asf/flink/blob/0af57fc1/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala index d183e60..9636980 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala @@ -17,11 +17,13 @@ */ package org.apache.flink.table.plan.rules.dataSet + import com.google.common.collect.ImmutableList import org.apache.calcite.plan._ import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.convert.ConverterRule import org.apache.calcite.rex.RexLiteral +import org.apache.calcite.sql.SqlKind import org.apache.flink.table.plan.nodes.FlinkConventions import org.apache.flink.table.plan.nodes.dataset.DataSetAggregate import org.apache.flink.table.plan.nodes.logical.{FlinkLogicalAggregate, FlinkLogicalUnion, 
FlinkLogicalValues} @@ -51,7 +53,12 @@ class DataSetAggregateWithNullValuesRule // check if we have distinct aggregates val distinctAggs = agg.getAggCallList.exists(_.isDistinct) - !distinctAggs + val supported = agg.getAggCallList.map(_.getAggregation.getKind).forall { + case SqlKind.STDDEV_POP | SqlKind.STDDEV_SAMP | SqlKind.VAR_POP | SqlKind.VAR_SAMP => false + case _ => true + } + + !distinctAggs && supported } override def convert(rel: RelNode): RelNode = { http://git-wip-us.apache.org/repos/asf/flink/blob/0af57fc1/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala index 729ad48..cb37ce4 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala @@ -169,6 +169,11 @@ object FunctionCatalog { "max" -> classOf[Max], "min" -> classOf[Min], "sum" -> classOf[Sum], + "sum0" -> classOf[Sum0], + "stddev_pop" -> classOf[StddevPop], + "stddev_samp" -> classOf[StddevSamp], + "var_pop" -> classOf[VarPop], + "var_samp" -> classOf[VarSamp], // string functions "charLength" -> classOf[CharLength], @@ -305,10 +310,15 @@ class BasicOperatorTable extends ReflectiveSqlOperatorTable { SqlStdOperatorTable.GROUPING_ID, // AGGREGATE OPERATORS SqlStdOperatorTable.SUM, + SqlStdOperatorTable.SUM0, SqlStdOperatorTable.COUNT, SqlStdOperatorTable.MIN, SqlStdOperatorTable.MAX, SqlStdOperatorTable.AVG, + SqlStdOperatorTable.STDDEV_POP, + SqlStdOperatorTable.STDDEV_SAMP, + SqlStdOperatorTable.VAR_POP, + SqlStdOperatorTable.VAR_SAMP, // ARRAY OPERATORS SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR, SqlStdOperatorTable.ITEM, 
http://git-wip-us.apache.org/repos/asf/flink/blob/0af57fc1/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/batch/sql/SqlITCase.java ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/batch/sql/SqlITCase.java b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/batch/sql/SqlITCase.java index 114226c..a5d2021 100644 --- a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/batch/sql/SqlITCase.java +++ b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/batch/sql/SqlITCase.java @@ -20,6 +20,7 @@ package org.apache.flink.table.api.java.batch.sql; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.api.java.typeutils.MapTypeInfo; @@ -36,9 +37,12 @@ import org.apache.flink.test.javaApiOperators.util.CollectionDataSets; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; +import scala.collection.JavaConversions; +import scala.collection.mutable.Buffer; import java.util.ArrayList; import java.util.Collections; +import java.util.Arrays; import java.util.List; import java.util.Map; @@ -136,7 +140,7 @@ public class SqlITCase extends TableProgramsCollectionTestBase { DataSet<Tuple5<Integer, Long, Integer, String, Long>> ds2 = CollectionDataSets.get5TupleDataSet(env); tableEnv.registerDataSet("t1", ds1, "a, b, c"); - tableEnv.registerDataSet("t2",ds2, "d, e, f, g, h"); + tableEnv.registerDataSet("t2", ds2, "d, e, f, g, h"); String sqlQuery = "SELECT c, g FROM t1, t2 WHERE b = e"; Table result = tableEnv.sql(sqlQuery); @@ -156,9 +160,7 @@ public class SqlITCase extends 
TableProgramsCollectionTestBase { rows.add(new Tuple2<>(1, Collections.singletonMap("foo", "bar"))); rows.add(new Tuple2<>(2, Collections.singletonMap("foo", "spam"))); - TypeInformation<Tuple2<Integer, Map<String, String>>> ty = new TupleTypeInfo<>( - BasicTypeInfo.INT_TYPE_INFO, - new MapTypeInfo<>(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO)); + TypeInformation<Tuple2<Integer, Map<String, String>>> ty = new TupleTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO, new MapTypeInfo<>(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO)); DataSet<Tuple2<Integer, Map<String, String>>> ds1 = env.fromCollection(rows, ty); tableEnv.registerDataSet("t1", ds1, "a, b"); @@ -171,4 +173,47 @@ public class SqlITCase extends TableProgramsCollectionTestBase { String expected = "bar\n" + "spam\n"; compareResultAsText(results, expected); } + + @Test + public void testDeviationAggregation() throws Exception { + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env, config()); + + DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env); + tableEnv.registerDataSet("AggTable", ds, "x, y, z"); + + Buffer<String> columnForAgg = JavaConversions.asScalaBuffer(Arrays.asList("x, y".split(","))); + + String sqlQuery = getSelectQueryFromTemplate("AVG(?),STDDEV_POP(?),STDDEV_SAMP(?),VAR_POP(?),VAR_SAMP(?)", columnForAgg, "AggTable"); + Table result = tableEnv.sql(sqlQuery); + + String sqlQuery1 = getSelectQueryFromTemplate("SUM(?)/COUNT(?), " + + "SQRT( (SUM(? * ?) - SUM(?) * SUM(?) / COUNT(?)) / COUNT(?)), " + + "SQRT( (SUM(? * ?) - SUM(?) * SUM(?) / COUNT(?)) / CASE COUNT(?) WHEN 1 THEN NULL ELSE COUNT(?) - 1 END), " + + "(SUM(? * ?) - SUM(?) * SUM(?) / COUNT(?)) / COUNT(?), " + + "(SUM(? * ?) - SUM(?) * SUM(?) / COUNT(?)) / CASE COUNT(?) WHEN 1 THEN NULL ELSE COUNT(?) 
- 1 END", columnForAgg, "AggTable"); + + Table expected = tableEnv.sql(sqlQuery1); + + DataSet<Row> resultSet = tableEnv.toDataSet(result, Row.class); + List<Row> results = resultSet.collect(); + + DataSet<Row> expectedResultSet = tableEnv.toDataSet(expected, Row.class); + String expectedResults = expectedResultSet.map(new MapFunction<Row, Object>() { + @Override + public Object map(Row value) throws Exception { + StringBuilder stringBuffer = new StringBuilder(); + + int arityCount = value.getArity(); + + for (int i = 0; i < arityCount; i++) { + Object product = value.getField(i); + stringBuffer.append(Double.valueOf(product.toString()).intValue()).append(","); + } + return stringBuffer.substring(0, stringBuffer.length() - 1); + } + }).collect().get(0).toString(); + + compareResultAsText(results, expectedResults); + } } http://git-wip-us.apache.org/repos/asf/flink/blob/0af57fc1/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsITCase.scala index 600c15b..4fbd734 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsITCase.scala @@ -295,6 +295,206 @@ class AggregationsITCase( TestBaseUtils.compareResultAsText(results3.asJava, expected3) } + + + @Test + def testSqrtOfAggregatedSet(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds = env.fromElements((1.0f, 1), (2.0f, 2)).toTable(tEnv) + + tEnv.registerTable("MyTable", ds) + + val sqlQuery = "SELECT " + + 
"SQRT((SUM(a * a) - SUM(a) * SUM(a) / COUNT(a)) / COUNT(a)) " + + "from (select _1 as a from MyTable)" + + val expected = "0.5" + val results = tEnv.sql(sqlQuery).toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testStddevPopAggregate(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds = env.fromElements( + (1: Byte, 1 : Short, 1, 1L, 1F, 1D), + (2: Byte, 2 : Short, 2, 2L, 2F, 2D)).toTable(tEnv) + tEnv.registerTable("myTable", ds) + val columns = Array("_1","_2","_3","_4","_5","_6") + + val sqlQuery = getSelectQueryFromTemplate("STDDEV_POP(?)")(columns,"myTable") + + val actualResult = tEnv.sql(sqlQuery).toDataSet[Row].collect() + val expectedResult = "0,0,0,0,0.5,0.5" + TestBaseUtils.compareOrderedResultAsText(actualResult.asJava, expectedResult) + } + + @Test + def testStddevPopAggregateWithOtherAggreagteSUM0(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds = env.fromElements( + (1: Byte, 1 : Short, 1, 1L, 1F, 1D), + (2: Byte, 2 : Short, 2, 2L, 2F, 2D)).toTable(tEnv) + tEnv.registerTable("myTable", ds) + val columns = Array("_1","_2","_3","_4","_5","_6") + + val sqlQuery = getSelectQueryFromTemplate("STDDEV_POP(?), " + + "$sum0(?), " + + "avg(?), " + + "max(?), " + + "min(?), " + + "count(?)" ) (columns,"myTable") + + val actualResult = tEnv.sql(sqlQuery).toDataSet[Row].collect() + + val expectedResult = + "0,3,1,2,1,2," + + "0,3,1,2,1,2," + + "0,3,1,2,1,2," + + "0,3,1,2,1,2," + + "0.5,3.0,1.5,2.0,1.0,2," + + "0.5,3.0,1.5,2.0,1.0,2" + TestBaseUtils.compareOrderedResultAsText(actualResult.asJava, expectedResult) + } + + @Test + def testStddevPopAggregateWithOtherAggreagte(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds = 
env.fromElements( + (1: Byte, 1 : Short, 1, 1L, 1F, 1D), + (2: Byte, 2 : Short, 2, 2L, 2F, 2D)).toTable(tEnv) + tEnv.registerTable("myTable", ds) + val columns = Array("_1","_2","_3","_4","_5","_6") + + val sqlQuery = getSelectQueryFromTemplate("STDDEV_POP(?), " + + "sum(?), " + + "avg(?), " + + "max(?), " + + "min(?), " + + "count(?)" )(columns,"myTable") + + val actualResult = tEnv.sql(sqlQuery).toDataSet[Row].collect() + + val expectedResult = + "0,3,1,2,1,2," + + "0,3,1,2,1,2," + + "0,3,1,2,1,2," + + "0,3,1,2,1,2," + + "0.5,3.0,1.5,2.0,1.0,2," + + "0.5,3.0,1.5,2.0,1.0,2" + TestBaseUtils.compareOrderedResultAsText(actualResult.asJava, expectedResult) + } + + @Test + def testStddevSampAggregate(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds1 = env.fromElements( + (1: Byte, 1 : Short, 1, 1L, 1F, 1D), + (2: Byte, 2 : Short, 2, 2L, 2F, 2D)).toTable(tEnv) + tEnv.registerTable("myTable", ds1) + val columns = Seq("_1","_2","_3","_4","_5","_6") + + val sqlQuery = getSelectQueryFromTemplate("STDDEV_SAMP(?)")(columns,"myTable") + val sqlExpectedQuery = getSelectQueryFromTemplate( + "SQRT((SUM(? * ?) - SUM(?) * SUM(?) / COUNT(?)) / " + + "CASE " + + "COUNT(?) WHEN 1 THEN NULL " + + "ELSE COUNT(?) 
- 1 " + + "END)")(columns,"myTable") + + val actualResult = tEnv.sql(sqlQuery).toDataSet[Row].collect() + .head + .toString + .split(",").map(x=>"%.5f".format(x.toFloat)) + + val expectedResult = tEnv.sql(sqlExpectedQuery).toDataSet[Row].collect() + .head + .toString + .split(",").map(x=>"%.5f".format(x.toFloat)) + + Assert.assertEquals(expectedResult.mkString(","), actualResult.mkString(",")) + } + + @Test + def testVarPopAggregate(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds = env.fromElements( + (1: Byte, 1 : Short, 1, 1L, 1F, 1D), + (2: Byte, 2 : Short, 2, 2L, 2F, 2D)).toTable(tEnv) + tEnv.registerTable("myTable", ds) + val columns = Seq("_1","_2","_3","_4","_5","_6") + + val sqlQuery = getSelectQueryFromTemplate("var_pop(?)")(columns,"myTable") + val sqlExpectedQuery = getSelectQueryFromTemplate( + "(SUM(? * ?) - SUM(?) * SUM(?) / COUNT(?)) / COUNT(?)")(columns,"myTable") + + val actualResult = tEnv.sql(sqlQuery).toDataSet[Row].collect() + val expectedResult = tEnv.sql(sqlExpectedQuery) + .toDataSet[Row] + .collect().head + .toString + TestBaseUtils.compareOrderedResultAsText(actualResult.asJava, expectedResult) + } + + @Test + def testVarSampAggregate(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds = env.fromElements( + (1: Byte, 1 : Short, 1, 1L, 1F, 1D), + (2: Byte, 2 : Short, 2, 2L, 2F, 2D)).toTable(tEnv) + tEnv.registerTable("myTable", ds) + val columns = Seq("_1","_2","_3","_4","_5","_6") + + val sqlQuery = getSelectQueryFromTemplate("var_samp(?)")(columns,"myTable") + val sqlExpectedQuery = getSelectQueryFromTemplate( + "(SUM(? * ?) - SUM(?) * SUM(?) / COUNT(?)) / CASE COUNT(?) WHEN 1 THEN NULL " + + "ELSE COUNT(?) 
- 1 END")(columns,"myTable") + + val actualResult = tEnv.sql(sqlQuery).toDataSet[Row].collect() + val expectedResult = tEnv.sql(sqlExpectedQuery) + .toDataSet[Row] + .collect().head + .toString + TestBaseUtils.compareOrderedResultAsText(actualResult.asJava, expectedResult) + } + + @Test + def testSumNullElements(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val sqlQuery = getSelectQueryFromTemplate("$sum0(?)")( + Seq("_1","_2","_3","_4","_5","_6"), + "(select * from MyTable where _1 = 4)" + ) + + val ds = env.fromElements( + (1: Byte, 2L,1D,1F,1,1:Short ), + (2: Byte, 2L,1D,1F,1,1:Short )) + tEnv.registerDataSet("MyTable", ds) + + val result = tEnv.sql(sqlQuery) + + val expected = "null,null,null,null,null,null" + val results = result.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + @Test def testTumbleWindowAggregate(): Unit = { http://git-wip-us.apache.org/repos/asf/flink/blob/0af57fc1/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/AggregationsITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/AggregationsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/AggregationsITCase.scala index 4838747..050f5f4 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/AggregationsITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/AggregationsITCase.scala @@ -339,6 +339,80 @@ class AggregationsITCase( TestBaseUtils.compareResultAsText(results.asJava, expected) } + @Test + def testAnalyticAggregation(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = 
TableEnvironment.getTableEnvironment(env)
    val t = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv)
      .select('_1.stddev_pop, '_1.stddev_samp, '_1.var_pop, '_1.var_samp)
    val results = t.toDataSet[Row].collect()
    // Int column: results are truncated to integers.
    val expected = "6,6,36,38"
    TestBaseUtils.compareResultAsText(results.asJava, expected)
  }

  /** SQL-style (function-call) and method-style analytic expressions must agree. */
  @Test
  def testSQLStyleAnalyticAggregation(): Unit = {
    val env = ExecutionEnvironment.getExecutionEnvironment
    val tEnv = TableEnvironment.getTableEnvironment(env)
    val t = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c)
      .select(
        """stddev_pop(a) as a1, a.stddev_pop as a2,
          |stddev_samp (a) as b1, a.stddev_samp as b2,
          |var_pop (a) as c1, a.var_pop as c2,
          |var_samp (a) as d1, a.var_samp as d2
        """.stripMargin)
    val expected = "6,6,6,6,36,36,38,38"
    val results = t.toDataSet[Row].collect()
    TestBaseUtils.compareResultAsText(results.asJava, expected)
  }

  /** Every supported numeric field type works with all four analytic aggregates. */
  @Test
  def testWorkingAnalyticAggregationDataTypes(): Unit = {
    val env = ExecutionEnvironment.getExecutionEnvironment
    val tEnv = TableEnvironment.getTableEnvironment(env)
    val ds = env.fromElements(
      (1: Byte, 1: Short, 1, 1L, 1.0f, 1.0d),
      (2: Byte, 2: Short, 2, 2L, 2.0f, 2.0d)).toTable(tEnv)
    val res = ds.select('_1.stddev_pop, '_2.stddev_pop, '_3.stddev_pop,
      '_4.stddev_pop, '_5.stddev_pop, '_6.stddev_pop,
      '_1.stddev_samp, '_2.stddev_samp, '_3.stddev_samp,
      '_4.stddev_samp, '_5.stddev_samp, '_6.stddev_samp,
      '_1.var_pop, '_2.var_pop, '_3.var_pop,
      '_4.var_pop, '_5.var_pop, '_6.var_pop,
      '_1.var_samp, '_2.var_samp, '_3.var_samp,
      '_4.var_samp, '_5.var_samp, '_6.var_samp)
    // Integral columns truncate; float/double keep fractional precision.
    val expected =
      "0,0,0," +
        "0,0.5,0.5," +
        "1,1,1," +
        "1,0.70710677,0.7071067811865476," +
        "0,0,0," +
        "0,0.25,0.25," +
        "1,1,1," +
        "1,0.5,0.5"
    val results = res.toDataSet[Row].collect()
    TestBaseUtils.compareResultAsText(results.asJava, expected)
  }

  /** Grouped stddev_pop over a case-class (POJO-style) table. */
  @Test
  def testPojoAnalyticAggregation(): Unit = {
    val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env)
    val input = env.fromElements(
      MyWC("hello", 1),
      MyWC("hello", 8),
      MyWC("ciao", 3),
      MyWC("hola", 1),
      MyWC("hola", 8))
    val expr = input.toTable(tEnv)
    val result = expr
      .groupBy('word)
      .select('word, 'frequency.stddev_pop)
      .toDataSet[MyWC]
    val mappedResult = result.map(w => (w.word, w.frequency)).collect()
    // Long frequency: stddev_pop is truncated (e.g. hola {1,8} -> 3.5 -> 3).
    val expected = "(hola,3)\n(ciao,0)\n(hello,3)"
    TestBaseUtils.compareResultAsText(mappedResult.asJava, expected)
  }
}

case class WC(word: String, frequency: Long)

// --- patch continues in a different file (retained from the original diff): ---
// flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/
//   batch/utils/TableProgramsTestBase.scala
// @@ -37,6 +37,11 @@ class TableProgramsTestBase(

    }
    conf
  }

  /**
    * Builds a "SELECT ... FROM ..." query by instantiating `selectBlock` once per
    * column, replacing every '?' placeholder with the column name.
    *
    * @param selectBlock select-list template containing '?' placeholders
    * @param columnsName column names substituted into the template, one clause each
    * @param table       table (or subquery) to select from
    * @return the assembled SQL query string
    */
  def getSelectQueryFromTemplate(selectBlock: String)
      (columnsName: Seq[String], table: String): String = {
    s"SELECT ${columnsName.map(x => selectBlock.replace("?", x)).mkString(",")} FROM $table"
  }
}

object TableProgramsTestBase {