[FLINK-6361] [table] Refactor the AggregateFunction interface and built-in aggregates.
This closes #3762. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/bc6409d6 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/bc6409d6 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/bc6409d6 Branch: refs/heads/master Commit: bc6409d624df54c2309c8bdb767f95de74ea1475 Parents: fe01892 Author: shaoxuan-wang <wshaox...@gmail.com> Authored: Tue Apr 25 00:28:37 2017 +0800 Committer: Fabian Hueske <fhue...@apache.org> Committed: Tue Apr 25 14:21:05 2017 +0200 ---------------------------------------------------------------------- .../org/apache/flink/table/api/Types.scala | 3 +- .../flink/table/codegen/CodeGenerator.scala | 60 +++--- .../table/functions/AggregateFunction.scala | 157 +++++++------- .../functions/aggfunctions/AvgAggFunction.scala | 206 ++++++++----------- .../aggfunctions/CountAggFunction.scala | 39 ++-- .../functions/aggfunctions/MaxAggFunction.scala | 48 ++--- .../MaxAggFunctionWithRetract.scala | 86 ++++---- .../functions/aggfunctions/MinAggFunction.scala | 48 ++--- .../MinAggFunctionWithRetract.scala | 86 ++++---- .../functions/aggfunctions/SumAggFunction.scala | 89 ++++---- .../SumWithRetractAggFunction.scala | 107 +++++----- .../table/runtime/aggregate/AggregateUtil.scala | 96 ++++++--- .../aggregate/GeneratedAggregations.scala | 32 +++ .../aggfunctions/AggFunctionTestBase.scala | 62 +++--- .../aggfunctions/AvgFunctionTest.scala | 23 ++- .../aggfunctions/CountAggFunctionTest.scala | 6 +- .../aggfunctions/MaxAggFunctionTest.scala | 44 ++-- .../MaxWithRetractAggFunctionTest.scala | 47 +++-- .../aggfunctions/MinAggFunctionTest.scala | 45 ++-- .../MinWithRetractAggFunctionTest.scala | 47 +++-- .../aggfunctions/SumAggFunctionTest.scala | 31 +-- .../SumWithRetractAggFunctionTest.scala | 31 ++- ...ProcessingOverRangeProcessFunctionTest.scala | 28 ++- 23 files changed, 772 insertions(+), 649 deletions(-) ---------------------------------------------------------------------- 
http://git-wip-us.apache.org/repos/asf/flink/blob/bc6409d6/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/Types.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/Types.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/Types.scala index 262a452..d82b990 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/Types.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/Types.scala @@ -17,9 +17,8 @@ */ package org.apache.flink.table.api -import org.apache.flink.api.common.typeinfo.{Types, TypeInformation} +import org.apache.flink.api.common.typeinfo.{Types => JTypes, TypeInformation} import org.apache.flink.table.typeutils.TimeIntervalTypeInfo -import org.apache.flink.api.common.typeinfo.{Types => JTypes} /** * This class enumerates all supported types of the Table API. http://git-wip-us.apache.org/repos/asf/flink/blob/bc6409d6/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala index 510a870..298fb70 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala @@ -265,14 +265,16 @@ class CodeGenerator( name: String, generator: CodeGenerator, inputType: RelDataType, - aggregates: Array[AggregateFunction[_ <: Any]], + aggregates: Array[AggregateFunction[_ <: Any, _ <: Any]], aggFields: Array[Array[Int]], aggMapping: Array[Int], partialResults: Boolean, fwdMapping: Array[Int], mergeMapping: 
Option[Array[Int]], constantFlags: Option[Array[(Int, Boolean)]], - outputArity: Int) + outputArity: Int, + needRetract: Boolean, + needMerge: Boolean) : GeneratedAggregationsFunction = { // get unique function name @@ -364,9 +366,16 @@ class CodeGenerator( | ${parameters(i)});""".stripMargin }.mkString("\n") - j"""$sig { - |$retract - | }""".stripMargin + if (needRetract) { + j""" + |$sig { + |$retract + | }""".stripMargin + } else { + j""" + |$sig { + | }""".stripMargin + } } def genCreateAccumulators: String = { @@ -471,11 +480,9 @@ class CodeGenerator( j""" | ${accTypes(i)} aAcc$i = (${accTypes(i)}) a.getField($i); | ${accTypes(i)} bAcc$i = (${accTypes(i)}) b.getField(${mapping(i)}); - | accList$i.set(0, aAcc$i); - | accList$i.set(1, bAcc$i); - | a.setField( - | $i, - | ${aggs(i)}.merge(accList$i)); + | accIt$i.setElement(bAcc$i); + | ${aggs(i)}.merge(aAcc$i, accIt$i); + | a.setField($i, aAcc$i); """.stripMargin }.mkString("\n") val ret: String = @@ -483,29 +490,27 @@ class CodeGenerator( | return a; """.stripMargin - j""" - |$sig { - |$merge - |$ret - | }""".stripMargin + if (needMerge) { + j""" + |$sig { + |$merge + |$ret + | }""".stripMargin + } else { + j""" + |$sig { + |$ret + | }""".stripMargin + } } def genMergeList: String = { { + val singleIterableClass = "org.apache.flink.table.runtime.aggregate.SingleElementIterable" for (i <- accTypes.indices) yield j""" - | private final java.util.ArrayList<${accTypes(i)}> accList$i = - | new java.util.ArrayList<${accTypes(i)}>(2); - """.stripMargin - }.mkString("\n") - } - - def initMergeList: String = { - { - for (i <- accTypes.indices) yield - j""" - | accList$i.add(${aggs(i)}.createAccumulator()); - | accList$i.add(${aggs(i)}.createAccumulator()); + | private final $singleIterableClass<${accTypes(i)}> accIt$i = + | new $singleIterableClass<${accTypes(i)}>(); """.stripMargin }.mkString("\n") } @@ -538,7 +543,6 @@ class CodeGenerator( | $genMergeList | public $funcName() throws Exception { | ${reuseInitCode()} - 
| $initMergeList | } | ${reuseConstructorCode(funcName)} | http://git-wip-us.apache.org/repos/asf/flink/blob/bc6409d6/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala index a67ccaa..7a74112 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala @@ -17,36 +17,100 @@ */ package org.apache.flink.table.functions -import java.util.{List => JList} - -import org.apache.flink.api.common.typeinfo.TypeInformation -import org.apache.flink.table.api.TableException - /** * Base class for User-Defined Aggregates. * - * @tparam T the type of the aggregation result + * The behavior of an [[AggregateFunction]] can be defined by implementing a series of custom + * methods. An [[AggregateFunction]] needs at least three methods: + * - createAccumulator, + * - accumulate, and + * - getValue. + * + * There are a few other methods that can be optional to have: + * - retract, + * - merge, + * - resetAccumulator, and + * - getAccumulatorType. + * + * All these methods muse be declared publicly, not static and named exactly as the names + * mentioned above. The methods createAccumulator and getValue are defined in the + * [[AggregateFunction]] functions, while other methods are explained below. + * + * + * {{{ + * Processes the input values and update the provided accumulator instance. The method + * accumulate can be overloaded with different custom types and arguments. An AggregateFunction + * requires at least one accumulate() method. 
+ * + * @param accumulator the accumulator which contains the current aggregated results + * @param [user defined inputs] the input value (usually obtained from a new arrived data). + * + * def accumulate(accumulator: ACC, [user defined inputs]): Unit + * }}} + * + * + * {{{ + * Retracts the input values from the accumulator instance. The current design assumes the + * inputs are the values that have been previously accumulated. The method retract can be + * overloaded with different custom types and arguments. This function must be implemented for + * datastream bounded over aggregate. + * + * @param accumulator the accumulator which contains the current aggregated results + * @param [user defined inputs] the input value (usually obtained from a new arrived data). + * + * def retract(accumulator: ACC, [user defined inputs]): Unit + * }}} + * + * + * {{{ + * Merges a group of accumulator instances into one accumulator instance. This function must be + * implemented for datastream session window grouping aggregate and dataset grouping aggregate. + * + * @param accumulator the accumulator which will keep the merged aggregate results. It should + * be noted that the accumulator may contain the previous aggregated + * results. Therefore user should not replace or clean this instance in the + * custom merge method. + * @param its an [[java.lang.Iterable]] pointed to a group of accumulators that will be + * merged. + + * def merge(accumulator: ACC, its: java.lang.Iterable[ACC]): Unit + * }}} + * + * + * {{{ + * Resets the accumulator for this [[AggregateFunction]]. This function must be implemented for + * dataset grouping aggregate. + * + * @param accumulator the accumulator which needs to be reset + + * def resetAccumulator(accumulator: ACC): Unit + * }}} + * + * + * {{{ + * Returns the [[org.apache.flink.api.common.typeinfo.TypeInformation]] of the accumulator. 
This + * function is optional and can be implemented if the accumulator type cannot automatically + * inferred from the instance returned by createAccumulator method. + * + * @return the type information for the accumulator. + + * def getAccumulatorType: TypeInformation[_] + * }}} + * + * + * @tparam T the type of the aggregation result + * @tparam ACC base class for aggregate Accumulator. The accumulator is used to keep the aggregated + * values which are needed to compute an aggregation result. AggregateFunction + * represents its state using accumulator, thereby the state of the AggregateFunction + * must be put into the accumulator. */ -abstract class AggregateFunction[T] extends UserDefinedFunction { +abstract class AggregateFunction[T, ACC] extends UserDefinedFunction { /** * Creates and init the Accumulator for this [[AggregateFunction]]. * * @return the accumulator with the initial value */ - def createAccumulator(): Accumulator - - /** - * Retracts the input values from the accumulator instance. The current design assumes the - * inputs are the values that have been previously accumulated. - * - * @param accumulator the accumulator which contains the current - * aggregated results - * @param input the input value (usually obtained from a new arrived data) - */ - def retract(accumulator: Accumulator, input: Any): Unit = { - throw TableException("Retract is an optional method. There is no default implementation. You " + - "must implement one for yourself.") - } + def createAccumulator(): ACC /** * Called every time when an aggregation result should be materialized. @@ -58,54 +122,5 @@ abstract class AggregateFunction[T] extends UserDefinedFunction { * aggregated results * @return the aggregation result */ - def getValue(accumulator: Accumulator): T - - /** - * Processes the input values and update the provided accumulator instance. 
- * - * @param accumulator the accumulator which contains the current - * aggregated results - * @param input the input value (usually obtained from a new arrived data) - */ - def accumulate(accumulator: Accumulator, input: Any): Unit - - /** - * Merges a list of accumulator instances into one accumulator instance. - * - * IMPORTANT: You may only return a new accumulator instance or the first accumulator of the - * input list. If you return another instance, the result of the aggregation function might be - * incorrect. - * - * @param accumulators the [[java.util.List]] of accumulators that will be merged - * @return the resulting accumulator - */ - def merge(accumulators: JList[Accumulator]): Accumulator - - /** - * Resets the Accumulator for this [[AggregateFunction]]. - * - * @param accumulator the accumulator which needs to be reset - */ - def resetAccumulator(accumulator: Accumulator): Unit - - /** - * Returns the [[TypeInformation]] of the accumulator. - * This function is optional and can be implemented if the accumulator type cannot automatically - * inferred from the instance returned by [[createAccumulator()]]. - * - * @return The type information for the accumulator. - */ - def getAccumulatorType: TypeInformation[_] = null + def getValue(accumulator: ACC): T } - -/** - * Base class for aggregate Accumulator. The accumulator is used to keep the - * aggregated values which are needed to compute an aggregation result. - * The state of the function must be put into the accumulator. - * - * TODO: We have the plan to have the accumulator and return types of - * functions dynamically provided by the users. This needs the refactoring - * of the AggregateFunction interface with the code generation. We will remove - * the [[Accumulator]] once codeGen for UDAGG is completed (FLINK-5813). 
- */ -trait Accumulator http://git-wip-us.apache.org/repos/asf/flink/blob/bc6409d6/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala index 4837139..3f4e5db 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala @@ -18,15 +18,15 @@ package org.apache.flink.table.functions.aggfunctions import java.math.{BigDecimal, BigInteger} -import java.util.{List => JList} +import java.lang.{Iterable => JIterable} import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2} import org.apache.flink.api.java.typeutils.TupleTypeInfo -import org.apache.flink.table.functions.{Accumulator, AggregateFunction} +import org.apache.flink.table.functions.AggregateFunction /** The initial accumulator for Integral Avg aggregate function */ -class IntegralAvgAccumulator extends JTuple2[Long, Long] with Accumulator { +class IntegralAvgAccumulator extends JTuple2[Long, Long] { f0 = 0L //sum f1 = 0L //count } @@ -36,57 +36,51 @@ class IntegralAvgAccumulator extends JTuple2[Long, Long] with Accumulator { * * @tparam T the type for the aggregation result */ -abstract class IntegralAvgAggFunction[T] extends AggregateFunction[T] { +abstract class IntegralAvgAggFunction[T] extends AggregateFunction[T, IntegralAvgAccumulator] { - override def createAccumulator(): Accumulator = { + override def createAccumulator(): IntegralAvgAccumulator = { new IntegralAvgAccumulator } - override def 
accumulate(accumulator: Accumulator, value: Any): Unit = { + def accumulate(acc: IntegralAvgAccumulator, value: Any): Unit = { if (value != null) { val v = value.asInstanceOf[Number].longValue() - val accum = accumulator.asInstanceOf[IntegralAvgAccumulator] - accum.f0 += v - accum.f1 += 1L + acc.f0 += v + acc.f1 += 1L } } - override def retract(accumulator: Accumulator, value: Any): Unit = { + def retract(acc: IntegralAvgAccumulator, value: Any): Unit = { if (value != null) { val v = value.asInstanceOf[Number].longValue() - val accum = accumulator.asInstanceOf[IntegralAvgAccumulator] - accum.f0 -= v - accum.f1 -= 1L + acc.f0 -= v + acc.f1 -= 1L } } - override def getValue(accumulator: Accumulator): T = { - val accum = accumulator.asInstanceOf[IntegralAvgAccumulator] - if (accum.f1 == 0) { + override def getValue(acc: IntegralAvgAccumulator): T = { + if (acc.f1 == 0) { null.asInstanceOf[T] } else { - resultTypeConvert(accum.f0 / accum.f1) + resultTypeConvert(acc.f0 / acc.f1) } } - override def merge(accumulators: JList[Accumulator]): Accumulator = { - val ret = accumulators.get(0).asInstanceOf[IntegralAvgAccumulator] - var i: Int = 1 - while (i < accumulators.size()) { - val a = accumulators.get(i).asInstanceOf[IntegralAvgAccumulator] - ret.f1 += a.f1 - ret.f0 += a.f0 - i += 1 + def merge(acc: IntegralAvgAccumulator, its: JIterable[IntegralAvgAccumulator]): Unit = { + val iter = its.iterator() + while (iter.hasNext) { + val a = iter.next() + acc.f1 += a.f1 + acc.f0 += a.f0 } - ret } - override def resetAccumulator(accumulator: Accumulator): Unit = { - accumulator.asInstanceOf[IntegralAvgAccumulator].f0 = 0L - accumulator.asInstanceOf[IntegralAvgAccumulator].f1 = 0L + def resetAccumulator(acc: IntegralAvgAccumulator): Unit = { + acc.f0 = 0L + acc.f1 = 0L } - override def getAccumulatorType: TypeInformation[_] = { + def getAccumulatorType: TypeInformation[_] = { new TupleTypeInfo( new IntegralAvgAccumulator().getClass, BasicTypeInfo.LONG_TYPE_INFO, @@ -126,7 +120,7 @@ 
class IntAvgAggFunction extends IntegralAvgAggFunction[Int] { /** The initial accumulator for Big Integral Avg aggregate function */ class BigIntegralAvgAccumulator - extends JTuple2[BigInteger, Long] with Accumulator { + extends JTuple2[BigInteger, Long] { f0 = BigInteger.ZERO //sum f1 = 0L //count } @@ -136,57 +130,52 @@ class BigIntegralAvgAccumulator * * @tparam T the type for the aggregation result */ -abstract class BigIntegralAvgAggFunction[T] extends AggregateFunction[T] { +abstract class BigIntegralAvgAggFunction[T] + extends AggregateFunction[T, BigIntegralAvgAccumulator] { - override def createAccumulator(): Accumulator = { + override def createAccumulator(): BigIntegralAvgAccumulator = { new BigIntegralAvgAccumulator } - override def accumulate(accumulator: Accumulator, value: Any): Unit = { + def accumulate(acc: BigIntegralAvgAccumulator, value: Any): Unit = { if (value != null) { val v = value.asInstanceOf[Long] - val a = accumulator.asInstanceOf[BigIntegralAvgAccumulator] - a.f0 = a.f0.add(BigInteger.valueOf(v)) - a.f1 += 1L + acc.f0 = acc.f0.add(BigInteger.valueOf(v)) + acc.f1 += 1L } } - override def retract(accumulator: Accumulator, value: Any): Unit = { + def retract(acc: BigIntegralAvgAccumulator, value: Any): Unit = { if (value != null) { val v = value.asInstanceOf[Long] - val a = accumulator.asInstanceOf[BigIntegralAvgAccumulator] - a.f0 = a.f0.subtract(BigInteger.valueOf(v)) - a.f1 -= 1L + acc.f0 = acc.f0.subtract(BigInteger.valueOf(v)) + acc.f1 -= 1L } } - override def getValue(accumulator: Accumulator): T = { - val a = accumulator.asInstanceOf[BigIntegralAvgAccumulator] - if (a.f1 == 0) { + override def getValue(acc: BigIntegralAvgAccumulator): T = { + if (acc.f1 == 0) { null.asInstanceOf[T] } else { - resultTypeConvert(a.f0.divide(BigInteger.valueOf(a.f1))) + resultTypeConvert(acc.f0.divide(BigInteger.valueOf(acc.f1))) } } - override def merge(accumulators: JList[Accumulator]): Accumulator = { - val ret = 
accumulators.get(0).asInstanceOf[BigIntegralAvgAccumulator] - var i: Int = 1 - while (i < accumulators.size()) { - val a = accumulators.get(i).asInstanceOf[BigIntegralAvgAccumulator] - ret.f1 += a.f1 - ret.f0 = ret.f0.add(a.f0) - i += 1 + def merge(acc: BigIntegralAvgAccumulator, its: JIterable[BigIntegralAvgAccumulator]): Unit = { + val iter = its.iterator() + while (iter.hasNext) { + val a = iter.next() + acc.f1 += a.f1 + acc.f0 = acc.f0.add(a.f0) } - ret } - override def resetAccumulator(accumulator: Accumulator): Unit = { - accumulator.asInstanceOf[BigIntegralAvgAccumulator].f0 = BigInteger.ZERO - accumulator.asInstanceOf[BigIntegralAvgAccumulator].f1 = 0 + def resetAccumulator(acc: BigIntegralAvgAccumulator): Unit = { + acc.f0 = BigInteger.ZERO + acc.f1 = 0 } - override def getAccumulatorType: TypeInformation[_] = { + def getAccumulatorType: TypeInformation[_] = { new TupleTypeInfo( new BigIntegralAvgAccumulator().getClass, BasicTypeInfo.BIG_INT_TYPE_INFO, @@ -212,7 +201,7 @@ class LongAvgAggFunction extends BigIntegralAvgAggFunction[Long] { } /** The initial accumulator for Floating Avg aggregate function */ -class FloatingAvgAccumulator extends JTuple2[Double, Long] with Accumulator { +class FloatingAvgAccumulator extends JTuple2[Double, Long] { f0 = 0 //sum f1 = 0L //count } @@ -222,57 +211,51 @@ class FloatingAvgAccumulator extends JTuple2[Double, Long] with Accumulator { * * @tparam T the type for the aggregation result */ -abstract class FloatingAvgAggFunction[T] extends AggregateFunction[T] { +abstract class FloatingAvgAggFunction[T] extends AggregateFunction[T, FloatingAvgAccumulator] { - override def createAccumulator(): Accumulator = { + override def createAccumulator(): FloatingAvgAccumulator = { new FloatingAvgAccumulator } - override def accumulate(accumulator: Accumulator, value: Any): Unit = { + def accumulate(acc: FloatingAvgAccumulator, value: Any): Unit = { if (value != null) { val v = value.asInstanceOf[Number].doubleValue() - val accum = 
accumulator.asInstanceOf[FloatingAvgAccumulator] - accum.f0 += v - accum.f1 += 1L + acc.f0 += v + acc.f1 += 1L } } - override def retract(accumulator: Accumulator, value: Any): Unit = { + def retract(acc: FloatingAvgAccumulator, value: Any): Unit = { if (value != null) { val v = value.asInstanceOf[Number].doubleValue() - val accum = accumulator.asInstanceOf[FloatingAvgAccumulator] - accum.f0 -= v - accum.f1 -= 1L + acc.f0 -= v + acc.f1 -= 1L } } - override def getValue(accumulator: Accumulator): T = { - val accum = accumulator.asInstanceOf[FloatingAvgAccumulator] - if (accum.f1 == 0) { + override def getValue(acc: FloatingAvgAccumulator): T = { + if (acc.f1 == 0) { null.asInstanceOf[T] } else { - resultTypeConvert(accum.f0 / accum.f1) + resultTypeConvert(acc.f0 / acc.f1) } } - override def merge(accumulators: JList[Accumulator]): Accumulator = { - val ret = accumulators.get(0).asInstanceOf[FloatingAvgAccumulator] - var i: Int = 1 - while (i < accumulators.size()) { - val a = accumulators.get(i).asInstanceOf[FloatingAvgAccumulator] - ret.f1 += a.f1 - ret.f0 += a.f0 - i += 1 + def merge(acc: FloatingAvgAccumulator, its: JIterable[FloatingAvgAccumulator]): Unit = { + val iter = its.iterator() + while (iter.hasNext) { + val a = iter.next() + acc.f1 += a.f1 + acc.f0 += a.f0 } - ret } - override def resetAccumulator(accumulator: Accumulator): Unit = { - accumulator.asInstanceOf[FloatingAvgAccumulator].f0 = 0 - accumulator.asInstanceOf[FloatingAvgAccumulator].f1 = 0L + def resetAccumulator(acc: FloatingAvgAccumulator): Unit = { + acc.f0 = 0 + acc.f1 = 0L } - override def getAccumulatorType: TypeInformation[_] = { + def getAccumulatorType: TypeInformation[_] = { new TupleTypeInfo( new FloatingAvgAccumulator().getClass, BasicTypeInfo.DOUBLE_TYPE_INFO, @@ -304,8 +287,7 @@ class DoubleAvgAggFunction extends FloatingAvgAggFunction[Double] { } /** The initial accumulator for Big Decimal Avg aggregate function */ -class DecimalAvgAccumulator - extends JTuple2[BigDecimal, Long] 
with Accumulator { +class DecimalAvgAccumulator extends JTuple2[BigDecimal, Long] { f0 = BigDecimal.ZERO //sum f1 = 0L //count } @@ -313,57 +295,51 @@ class DecimalAvgAccumulator /** * Base class for built-in Big Decimal Avg aggregate function */ -class DecimalAvgAggFunction extends AggregateFunction[BigDecimal] { +class DecimalAvgAggFunction extends AggregateFunction[BigDecimal, DecimalAvgAccumulator] { - override def createAccumulator(): Accumulator = { + override def createAccumulator(): DecimalAvgAccumulator = { new DecimalAvgAccumulator } - override def accumulate(accumulator: Accumulator, value: Any): Unit = { + def accumulate(acc: DecimalAvgAccumulator, value: Any): Unit = { if (value != null) { val v = value.asInstanceOf[BigDecimal] - val accum = accumulator.asInstanceOf[DecimalAvgAccumulator] - accum.f0 = accum.f0.add(v) - accum.f1 += 1L + acc.f0 = acc.f0.add(v) + acc.f1 += 1L } } - override def retract(accumulator: Accumulator, value: Any): Unit = { + def retract(acc: DecimalAvgAccumulator, value: Any): Unit = { if (value != null) { val v = value.asInstanceOf[BigDecimal] - val accum = accumulator.asInstanceOf[DecimalAvgAccumulator] - accum.f0 = accum.f0.subtract(v) - accum.f1 -= 1L + acc.f0 = acc.f0.subtract(v) + acc.f1 -= 1L } } - override def getValue(accumulator: Accumulator): BigDecimal = { - val a = accumulator.asInstanceOf[DecimalAvgAccumulator] - if (a.f1 == 0) { + override def getValue(acc: DecimalAvgAccumulator): BigDecimal = { + if (acc.f1 == 0) { null.asInstanceOf[BigDecimal] } else { - a.f0.divide(BigDecimal.valueOf(a.f1)) + acc.f0.divide(BigDecimal.valueOf(acc.f1)) } } - override def merge(accumulators: JList[Accumulator]): Accumulator = { - val ret = accumulators.get(0).asInstanceOf[DecimalAvgAccumulator] - var i: Int = 1 - while (i < accumulators.size()) { - val a = accumulators.get(i).asInstanceOf[DecimalAvgAccumulator] - ret.f0 = ret.f0.add(a.f0) - ret.f1 += a.f1 - i += 1 + def merge(acc: DecimalAvgAccumulator, its: 
JIterable[DecimalAvgAccumulator]): Unit = { + val iter = its.iterator() + while (iter.hasNext) { + val a = iter.next() + acc.f0 = acc.f0.add(a.f0) + acc.f1 += a.f1 } - ret } - override def resetAccumulator(accumulator: Accumulator): Unit = { - accumulator.asInstanceOf[DecimalAvgAccumulator].f0 = BigDecimal.ZERO - accumulator.asInstanceOf[DecimalAvgAccumulator].f1 = 0L + def resetAccumulator(acc: DecimalAvgAccumulator): Unit = { + acc.f0 = BigDecimal.ZERO + acc.f1 = 0L } - override def getAccumulatorType: TypeInformation[_] = { + def getAccumulatorType: TypeInformation[_] = { new TupleTypeInfo( new DecimalAvgAccumulator().getClass, BasicTypeInfo.BIG_DEC_TYPE_INFO, http://git-wip-us.apache.org/repos/asf/flink/blob/bc6409d6/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala index 231337a..77341cd 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala @@ -17,58 +17,55 @@ */ package org.apache.flink.table.functions.aggfunctions -import java.util.{List => JList} +import java.lang.{Iterable => JIterable} import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} import org.apache.flink.api.java.tuple.{Tuple1 => JTuple1} import org.apache.flink.api.java.typeutils.TupleTypeInfo -import org.apache.flink.table.functions.{Accumulator, AggregateFunction} +import org.apache.flink.table.functions.AggregateFunction /** The initial accumulator for count aggregate function */ -class CountAccumulator extends 
JTuple1[Long] with Accumulator { +class CountAccumulator extends JTuple1[Long] { f0 = 0L //count } /** * built-in count aggregate function */ -class CountAggFunction extends AggregateFunction[Long] { +class CountAggFunction extends AggregateFunction[Long, CountAccumulator] { - override def accumulate(accumulator: Accumulator, value: Any): Unit = { + def accumulate(acc: CountAccumulator, value: Any): Unit = { if (value != null) { - accumulator.asInstanceOf[CountAccumulator].f0 += 1L + acc.f0 += 1L } } - override def retract(accumulator: Accumulator, value: Any): Unit = { + def retract(acc: CountAccumulator, value: Any): Unit = { if (value != null) { - accumulator.asInstanceOf[CountAccumulator].f0 -= 1L + acc.f0 -= 1L } } - override def getValue(accumulator: Accumulator): Long = { - accumulator.asInstanceOf[CountAccumulator].f0 + override def getValue(acc: CountAccumulator): Long = { + acc.f0 } - override def merge(accumulators: JList[Accumulator]): Accumulator = { - val ret = accumulators.get(0).asInstanceOf[CountAccumulator] - var i: Int = 1 - while (i < accumulators.size()) { - ret.f0 += accumulators.get(i).asInstanceOf[CountAccumulator].f0 - i += 1 + def merge(acc: CountAccumulator, its: JIterable[CountAccumulator]): Unit = { + val iter = its.iterator() + while (iter.hasNext) { + acc.f0 += iter.next().f0 } - ret } - override def createAccumulator(): Accumulator = { + override def createAccumulator(): CountAccumulator = { new CountAccumulator } - override def resetAccumulator(accumulator: Accumulator): Unit = { - accumulator.asInstanceOf[CountAccumulator].f0 = 0L + def resetAccumulator(acc: CountAccumulator): Unit = { + acc.f0 = 0L } - override def getAccumulatorType(): TypeInformation[_] = { + def getAccumulatorType(): TypeInformation[_] = { new TupleTypeInfo((new CountAccumulator).getClass, BasicTypeInfo.LONG_TYPE_INFO) } } 
http://git-wip-us.apache.org/repos/asf/flink/blob/bc6409d6/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala index 2e666fa..96ee8d1 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala @@ -18,69 +18,65 @@ package org.apache.flink.table.functions.aggfunctions import java.math.BigDecimal -import java.util.{List => JList} +import java.lang.{Iterable => JIterable} import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2} import org.apache.flink.api.java.typeutils.TupleTypeInfo -import org.apache.flink.table.functions.{Accumulator, AggregateFunction} +import org.apache.flink.table.functions.AggregateFunction /** The initial accumulator for Max aggregate function */ -class MaxAccumulator[T] extends JTuple2[T, Boolean] with Accumulator +class MaxAccumulator[T] extends JTuple2[T, Boolean] /** * Base class for built-in Max aggregate function * * @tparam T the type for the aggregation result */ -abstract class MaxAggFunction[T](implicit ord: Ordering[T]) extends AggregateFunction[T] { +abstract class MaxAggFunction[T](implicit ord: Ordering[T]) + extends AggregateFunction[T, MaxAccumulator[T]] { - override def createAccumulator(): Accumulator = { + override def createAccumulator(): MaxAccumulator[T] = { val acc = new MaxAccumulator[T] acc.f0 = getInitValue acc.f1 = false acc } - override def accumulate(accumulator: Accumulator, value: Any): Unit = { + def 
accumulate(acc: MaxAccumulator[T], value: Any): Unit = { if (value != null) { val v = value.asInstanceOf[T] - val a = accumulator.asInstanceOf[MaxAccumulator[T]] - if (!a.f1 || ord.compare(a.f0, v) < 0) { - a.f0 = v - a.f1 = true + if (!acc.f1 || ord.compare(acc.f0, v) < 0) { + acc.f0 = v + acc.f1 = true } } } - override def getValue(accumulator: Accumulator): T = { - val a = accumulator.asInstanceOf[MaxAccumulator[T]] - if (a.f1) { - a.f0 + override def getValue(acc: MaxAccumulator[T]): T = { + if (acc.f1) { + acc.f0 } else { null.asInstanceOf[T] } } - override def merge(accumulators: JList[Accumulator]): Accumulator = { - val ret = accumulators.get(0) - var i: Int = 1 - while (i < accumulators.size()) { - val a = accumulators.get(i).asInstanceOf[MaxAccumulator[T]] + def merge(acc: MaxAccumulator[T], its: JIterable[MaxAccumulator[T]]): Unit = { + val iter = its.iterator() + while (iter.hasNext) { + val a = iter.next() if (a.f1) { - accumulate(ret.asInstanceOf[MaxAccumulator[T]], a.f0) + accumulate(acc, a.f0) } - i += 1 } - ret } - override def resetAccumulator(accumulator: Accumulator): Unit = { - accumulator.asInstanceOf[MaxAccumulator[T]].f0 = getInitValue - accumulator.asInstanceOf[MaxAccumulator[T]].f1 = false + def resetAccumulator(acc: MaxAccumulator[T]): Unit = { + acc.f0 = getInitValue + acc.f1 = false } - override def getAccumulatorType(): TypeInformation[_] = { + def getAccumulatorType(): TypeInformation[_] = { new TupleTypeInfo( new MaxAccumulator[T].getClass, getValueTypeInfo, http://git-wip-us.apache.org/repos/asf/flink/blob/bc6409d6/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionWithRetract.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionWithRetract.scala 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionWithRetract.scala index 14ceba2..6f18739 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionWithRetract.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionWithRetract.scala @@ -18,15 +18,16 @@ package org.apache.flink.table.functions.aggfunctions import java.math.BigDecimal -import java.util.{HashMap => JHashMap, List => JList} +import java.util.{HashMap => JHashMap} +import java.lang.{Iterable => JIterable} import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2} import org.apache.flink.api.java.typeutils.{MapTypeInfo, TupleTypeInfo} -import org.apache.flink.table.functions.{Accumulator, AggregateFunction} +import org.apache.flink.table.functions.AggregateFunction /** The initial accumulator for Max with retraction aggregate function */ -class MaxWithRetractAccumulator[T] extends JTuple2[T, JHashMap[T, Long]] with Accumulator +class MaxWithRetractAccumulator[T] extends JTuple2[T, JHashMap[T, Long]] /** * Base class for built-in Max with retraction aggregate function @@ -34,110 +35,105 @@ class MaxWithRetractAccumulator[T] extends JTuple2[T, JHashMap[T, Long]] with Ac * @tparam T the type for the aggregation result */ abstract class MaxWithRetractAggFunction[T](implicit ord: Ordering[T]) - extends AggregateFunction[T] { + extends AggregateFunction[T, MaxWithRetractAccumulator[T]] { - override def createAccumulator(): Accumulator = { + override def createAccumulator(): MaxWithRetractAccumulator[T] = { val acc = new MaxWithRetractAccumulator[T] acc.f0 = getInitValue //max acc.f1 = new JHashMap[T, Long]() //store the count for each value acc } - override def accumulate(accumulator: Accumulator, value: Any): Unit = { + def accumulate(acc: 
MaxWithRetractAccumulator[T], value: Any): Unit = { if (value != null) { val v = value.asInstanceOf[T] - val a = accumulator.asInstanceOf[MaxWithRetractAccumulator[T]] - if (a.f1.size() == 0 || (ord.compare(a.f0, v) < 0)) { - a.f0 = v + if (acc.f1.size() == 0 || (ord.compare(acc.f0, v) < 0)) { + acc.f0 = v } - if (!a.f1.containsKey(v)) { - a.f1.put(v, 1L) + if (!acc.f1.containsKey(v)) { + acc.f1.put(v, 1L) } else { - var count = a.f1.get(v) + var count = acc.f1.get(v) count += 1L - a.f1.put(v, count) + acc.f1.put(v, count) } } } - override def retract(accumulator: Accumulator, value: Any): Unit = { + def retract(acc: MaxWithRetractAccumulator[T], value: Any): Unit = { if (value != null) { val v = value.asInstanceOf[T] - val a = accumulator.asInstanceOf[MaxWithRetractAccumulator[T]] - var count = a.f1.get(v) + var count = acc.f1.get(v) count -= 1L if (count == 0) { //remove the key v from the map if the number of appearance of the value v is 0 - a.f1.remove(v) + acc.f1.remove(v) //if the total count is 0, we could just simply set the f0(max) to the initial value - if (a.f1.size() == 0) { - a.f0 = getInitValue + if (acc.f1.size() == 0) { + acc.f0 = getInitValue return } //if v is the current max value, we have to iterate the map to find the 2nd biggest // value to replace v as the max value - if (v == a.f0) { - val iterator = a.f1.keySet().iterator() + if (v == acc.f0) { + val iterator = acc.f1.keySet().iterator() var key = iterator.next() - a.f0 = key + acc.f0 = key while (iterator.hasNext()) { key = iterator.next() - if (ord.compare(a.f0, key) < 0) { - a.f0 = key + if (ord.compare(acc.f0, key) < 0) { + acc.f0 = key } } } } else { - a.f1.put(v, count) + acc.f1.put(v, count) } } } - override def getValue(accumulator: Accumulator): T = { - val a = accumulator.asInstanceOf[MaxWithRetractAccumulator[T]] - if (a.f1.size() != 0) { - a.f0 + override def getValue(acc: MaxWithRetractAccumulator[T]): T = { + if (acc.f1.size() != 0) { + acc.f0 } else { null.asInstanceOf[T] } } 
- override def merge(accumulators: JList[Accumulator]): Accumulator = { - val ret = accumulators.get(0).asInstanceOf[MaxWithRetractAccumulator[T]] - var i: Int = 1 - while (i < accumulators.size()) { - val a = accumulators.get(i).asInstanceOf[MaxWithRetractAccumulator[T]] + def merge(acc: MaxWithRetractAccumulator[T], + its: JIterable[MaxWithRetractAccumulator[T]]): Unit = { + val iter = its.iterator() + while (iter.hasNext) { + val a = iter.next() if (a.f1.size() != 0) { // set max element - if (ord.compare(ret.f0, a.f0) < 0) { - ret.f0 = a.f0 + if (ord.compare(acc.f0, a.f0) < 0) { + acc.f0 = a.f0 } // merge the count for each key val iterator = a.f1.keySet().iterator() while (iterator.hasNext()) { val key = iterator.next() - if (ret.f1.containsKey(key)) { - ret.f1.put(key, ret.f1.get(key) + a.f1.get(key)) + if (acc.f1.containsKey(key)) { + acc.f1.put(key, acc.f1.get(key) + a.f1.get(key)) } else { - ret.f1.put(key, a.f1.get(key)) + acc.f1.put(key, a.f1.get(key)) } } } - i += 1 } - ret } - override def resetAccumulator(accumulator: Accumulator): Unit = { - accumulator.asInstanceOf[MaxWithRetractAccumulator[T]].f0 = getInitValue - accumulator.asInstanceOf[MaxWithRetractAccumulator[T]].f1.clear() + def resetAccumulator(acc: MaxWithRetractAccumulator[T]): Unit = { + acc.f0 = getInitValue + acc.f1.clear() } - override def getAccumulatorType(): TypeInformation[_] = { + def getAccumulatorType(): TypeInformation[_] = { new TupleTypeInfo( new MaxWithRetractAccumulator[T].getClass, getValueTypeInfo, http://git-wip-us.apache.org/repos/asf/flink/blob/bc6409d6/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala index 
75a8ebc..88d7afd 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala @@ -18,69 +18,65 @@ package org.apache.flink.table.functions.aggfunctions import java.math.BigDecimal -import java.util.{List => JList} +import java.lang.{Iterable => JIterable} import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2} import org.apache.flink.api.java.typeutils.TupleTypeInfo -import org.apache.flink.table.functions.{Accumulator, AggregateFunction} +import org.apache.flink.table.functions.AggregateFunction /** The initial accumulator for Min aggregate function */ -class MinAccumulator[T] extends JTuple2[T, Boolean] with Accumulator +class MinAccumulator[T] extends JTuple2[T, Boolean] /** * Base class for built-in Min aggregate function * * @tparam T the type for the aggregation result */ -abstract class MinAggFunction[T](implicit ord: Ordering[T]) extends AggregateFunction[T] { +abstract class MinAggFunction[T](implicit ord: Ordering[T]) + extends AggregateFunction[T, MinAccumulator[T]] { - override def createAccumulator(): Accumulator = { + override def createAccumulator(): MinAccumulator[T] = { val acc = new MinAccumulator[T] acc.f0 = getInitValue acc.f1 = false acc } - override def accumulate(accumulator: Accumulator, value: Any): Unit = { + def accumulate(acc: MinAccumulator[T], value: Any): Unit = { if (value != null) { val v = value.asInstanceOf[T] - val a = accumulator.asInstanceOf[MinAccumulator[T]] - if (!a.f1 || ord.compare(a.f0, v) > 0) { - a.f0 = v - a.f1 = true + if (!acc.f1 || ord.compare(acc.f0, v) > 0) { + acc.f0 = v + acc.f1 = true } } } - override def getValue(accumulator: Accumulator): T = { - val a = accumulator.asInstanceOf[MinAccumulator[T]] - if (a.f1) { - a.f0 + override def getValue(acc: 
MinAccumulator[T]): T = { + if (acc.f1) { + acc.f0 } else { null.asInstanceOf[T] } } - override def merge(accumulators: JList[Accumulator]): Accumulator = { - val ret = accumulators.get(0) - var i: Int = 1 - while (i < accumulators.size()) { - val a = accumulators.get(i).asInstanceOf[MinAccumulator[T]] + def merge(acc: MinAccumulator[T], its: JIterable[MinAccumulator[T]]): Unit = { + val iter = its.iterator() + while (iter.hasNext) { + val a = iter.next() if (a.f1) { - accumulate(ret.asInstanceOf[MinAccumulator[T]], a.f0) + accumulate(acc, a.f0) } - i += 1 } - ret } - override def resetAccumulator(accumulator: Accumulator): Unit = { - accumulator.asInstanceOf[MinAccumulator[T]].f0 = getInitValue - accumulator.asInstanceOf[MinAccumulator[T]].f1 = false + def resetAccumulator(acc: MinAccumulator[T]): Unit = { + acc.f0 = getInitValue + acc.f1 = false } - override def getAccumulatorType(): TypeInformation[_] = { + def getAccumulatorType(): TypeInformation[_] = { new TupleTypeInfo( new MinAccumulator[T].getClass, getValueTypeInfo, http://git-wip-us.apache.org/repos/asf/flink/blob/bc6409d6/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionWithRetract.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionWithRetract.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionWithRetract.scala index 6f2c3a1..2d3348b 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionWithRetract.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionWithRetract.scala @@ -18,15 +18,16 @@ package org.apache.flink.table.functions.aggfunctions import java.math.BigDecimal -import java.util.{HashMap => JHashMap, List => JList} +import 
java.util.{HashMap => JHashMap} +import java.lang.{Iterable => JIterable} import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2} import org.apache.flink.api.java.typeutils.{MapTypeInfo, TupleTypeInfo} -import org.apache.flink.table.functions.{Accumulator, AggregateFunction} +import org.apache.flink.table.functions.AggregateFunction /** The initial accumulator for Min with retraction aggregate function */ -class MinWithRetractAccumulator[T] extends JTuple2[T, JHashMap[T, Long]] with Accumulator +class MinWithRetractAccumulator[T] extends JTuple2[T, JHashMap[T, Long]] /** * Base class for built-in Min with retraction aggregate function @@ -34,110 +35,105 @@ class MinWithRetractAccumulator[T] extends JTuple2[T, JHashMap[T, Long]] with Ac * @tparam T the type for the aggregation result */ abstract class MinWithRetractAggFunction[T](implicit ord: Ordering[T]) - extends AggregateFunction[T] { + extends AggregateFunction[T, MinWithRetractAccumulator[T]] { - override def createAccumulator(): Accumulator = { + override def createAccumulator(): MinWithRetractAccumulator[T] = { val acc = new MinWithRetractAccumulator[T] acc.f0 = getInitValue //min acc.f1 = new JHashMap[T, Long]() //store the count for each value acc } - override def accumulate(accumulator: Accumulator, value: Any): Unit = { + def accumulate(acc: MinWithRetractAccumulator[T], value: Any): Unit = { if (value != null) { val v = value.asInstanceOf[T] - val a = accumulator.asInstanceOf[MinWithRetractAccumulator[T]] - if (a.f1.size() == 0 || (ord.compare(a.f0, v) > 0)) { - a.f0 = v + if (acc.f1.size() == 0 || (ord.compare(acc.f0, v) > 0)) { + acc.f0 = v } - if (!a.f1.containsKey(v)) { - a.f1.put(v, 1L) + if (!acc.f1.containsKey(v)) { + acc.f1.put(v, 1L) } else { - var count = a.f1.get(v) + var count = acc.f1.get(v) count += 1L - a.f1.put(v, count) + acc.f1.put(v, count) } } } - override def retract(accumulator: Accumulator, value: 
Any): Unit = { + def retract(acc: MinWithRetractAccumulator[T], value: Any): Unit = { if (value != null) { val v = value.asInstanceOf[T] - val a = accumulator.asInstanceOf[MinWithRetractAccumulator[T]] - var count = a.f1.get(v) + var count = acc.f1.get(v) count -= 1L if (count == 0) { //remove the key v from the map if the number of appearance of the value v is 0 - a.f1.remove(v) + acc.f1.remove(v) //if the total count is 0, we could just simply set the f0(min) to the initial value - if (a.f1.size() == 0) { - a.f0 = getInitValue + if (acc.f1.size() == 0) { + acc.f0 = getInitValue return } //if v is the current min value, we have to iterate the map to find the 2nd smallest // value to replace v as the min value - if (v == a.f0) { - val iterator = a.f1.keySet().iterator() + if (v == acc.f0) { + val iterator = acc.f1.keySet().iterator() var key = iterator.next() - a.f0 = key + acc.f0 = key while (iterator.hasNext()) { key = iterator.next() - if (ord.compare(a.f0, key) > 0) { - a.f0 = key + if (ord.compare(acc.f0, key) > 0) { + acc.f0 = key } } } } else { - a.f1.put(v, count) + acc.f1.put(v, count) } } } - override def getValue(accumulator: Accumulator): T = { - val a = accumulator.asInstanceOf[MinWithRetractAccumulator[T]] - if (a.f1.size() != 0) { - a.f0 + override def getValue(acc: MinWithRetractAccumulator[T]): T = { + if (acc.f1.size() != 0) { + acc.f0 } else { null.asInstanceOf[T] } } - override def merge(accumulators: JList[Accumulator]): Accumulator = { - val ret = accumulators.get(0).asInstanceOf[MinWithRetractAccumulator[T]] - var i: Int = 1 - while (i < accumulators.size()) { - val a = accumulators.get(i).asInstanceOf[MinWithRetractAccumulator[T]] + def merge(acc: MinWithRetractAccumulator[T], + its: JIterable[MinWithRetractAccumulator[T]]): Unit = { + val iter = its.iterator() + while (iter.hasNext) { + val a = iter.next() if (a.f1.size() != 0) { // set min element - if (ord.compare(ret.f0, a.f0) > 0) { - ret.f0 = a.f0 + if (ord.compare(acc.f0, a.f0) > 0) { 
+ acc.f0 = a.f0 } // merge the count for each key val iterator = a.f1.keySet().iterator() while (iterator.hasNext()) { val key = iterator.next() - if (ret.f1.containsKey(key)) { - ret.f1.put(key, ret.f1.get(key) + a.f1.get(key)) + if (acc.f1.containsKey(key)) { + acc.f1.put(key, acc.f1.get(key) + a.f1.get(key)) } else { - ret.f1.put(key, a.f1.get(key)) + acc.f1.put(key, a.f1.get(key)) } } } - i += 1 } - ret } - override def resetAccumulator(accumulator: Accumulator): Unit = { - accumulator.asInstanceOf[MinWithRetractAccumulator[T]].f0 = getInitValue - accumulator.asInstanceOf[MinWithRetractAccumulator[T]].f1.clear() + def resetAccumulator(acc: MinWithRetractAccumulator[T]): Unit = { + acc.f0 = getInitValue + acc.f1.clear() } - override def getAccumulatorType(): TypeInformation[_] = { + def getAccumulatorType(): TypeInformation[_] = { new TupleTypeInfo( new MinWithRetractAccumulator[T].getClass, getValueTypeInfo, http://git-wip-us.apache.org/repos/asf/flink/blob/bc6409d6/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala index 8ee9862..55996ac 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala @@ -18,70 +18,65 @@ package org.apache.flink.table.functions.aggfunctions import java.math.BigDecimal -import java.util.{List => JList} +import java.lang.{Iterable => JIterable} import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2} import 
org.apache.flink.api.java.typeutils.TupleTypeInfo -import org.apache.flink.table.functions.{Accumulator, AggregateFunction} +import org.apache.flink.table.functions.AggregateFunction /** The initial accumulator for Sum aggregate function */ -class SumAccumulator[T] extends JTuple2[T, Boolean] with Accumulator +class SumAccumulator[T] extends JTuple2[T, Boolean] /** * Base class for built-in Sum aggregate function * * @tparam T the type for the aggregation result */ -abstract class SumAggFunction[T: Numeric] extends AggregateFunction[T] { +abstract class SumAggFunction[T: Numeric] extends AggregateFunction[T, SumAccumulator[T]] { private val numeric = implicitly[Numeric[T]] - override def createAccumulator(): Accumulator = { + override def createAccumulator(): SumAccumulator[T] = { val acc = new SumAccumulator[T]() acc.f0 = numeric.zero //sum acc.f1 = false acc } - override def accumulate(accumulator: Accumulator, value: Any): Unit = { + def accumulate(accumulator: SumAccumulator[T], value: Any): Unit = { if (value != null) { val v = value.asInstanceOf[T] - val a = accumulator.asInstanceOf[SumAccumulator[T]] - a.f0 = numeric.plus(v, a.f0) - a.f1 = true + accumulator.f0 = numeric.plus(v, accumulator.f0) + accumulator.f1 = true } } - override def getValue(accumulator: Accumulator): T = { - val a = accumulator.asInstanceOf[SumAccumulator[T]] - if (a.f1) { - a.f0 + override def getValue(accumulator: SumAccumulator[T]): T = { + if (accumulator.f1) { + accumulator.f0 } else { null.asInstanceOf[T] } } - override def merge(accumulators: JList[Accumulator]): Accumulator = { - val ret = accumulators.get(0).asInstanceOf[SumAccumulator[T]] - var i: Int = 1 - while (i < accumulators.size()) { - val a = accumulators.get(i).asInstanceOf[SumAccumulator[T]] + def merge(acc: SumAccumulator[T], its: JIterable[SumAccumulator[T]]): Unit = { + val iter = its.iterator() + while (iter.hasNext) { + val a = iter.next() if (a.f1) { - ret.f0 = numeric.plus(ret.f0, a.f0) - ret.f1 = true + 
acc.f0 = numeric.plus(acc.f0, a.f0) + acc.f1 = true } - i += 1 } - ret } - override def resetAccumulator(accumulator: Accumulator): Unit = { - accumulator.asInstanceOf[SumAccumulator[T]].f0 = numeric.zero - accumulator.asInstanceOf[SumAccumulator[T]].f1 = false + def resetAccumulator(acc: SumAccumulator[T]): Unit = { + acc.f0 = numeric.zero + acc.f1 = false } - override def getAccumulatorType(): TypeInformation[_] = { + def getAccumulatorType(): TypeInformation[_] = { new TupleTypeInfo( (new SumAccumulator).getClass, getValueTypeInfo, @@ -134,7 +129,7 @@ class DoubleSumAggFunction extends SumAggFunction[Double] { } /** The initial accumulator for Big Decimal Sum aggregate function */ -class DecimalSumAccumulator extends JTuple2[BigDecimal, Boolean] with Accumulator { +class DecimalSumAccumulator extends JTuple2[BigDecimal, Boolean] { f0 = BigDecimal.ZERO f1 = false } @@ -142,49 +137,45 @@ class DecimalSumAccumulator extends JTuple2[BigDecimal, Boolean] with Accumulato /** * Built-in Big Decimal Sum aggregate function */ -class DecimalSumAggFunction extends AggregateFunction[BigDecimal] { +class DecimalSumAggFunction extends AggregateFunction[BigDecimal, DecimalSumAccumulator] { - override def createAccumulator(): Accumulator = { + override def createAccumulator(): DecimalSumAccumulator = { new DecimalSumAccumulator } - override def accumulate(accumulator: Accumulator, value: Any): Unit = { + def accumulate(acc: DecimalSumAccumulator, value: Any): Unit = { if (value != null) { val v = value.asInstanceOf[BigDecimal] - val accum = accumulator.asInstanceOf[DecimalSumAccumulator] - accum.f0 = accum.f0.add(v) - accum.f1 = true + acc.f0 = acc.f0.add(v) + acc.f1 = true } } - override def getValue(accumulator: Accumulator): BigDecimal = { - if (!accumulator.asInstanceOf[DecimalSumAccumulator].f1) { + override def getValue(acc: DecimalSumAccumulator): BigDecimal = { + if (!acc.f1) { null.asInstanceOf[BigDecimal] } else { - accumulator.asInstanceOf[DecimalSumAccumulator].f0 + 
acc.f0 } } - override def merge(accumulators: JList[Accumulator]): Accumulator = { - val ret = accumulators.get(0).asInstanceOf[DecimalSumAccumulator] - var i: Int = 1 - while (i < accumulators.size()) { - val a = accumulators.get(i).asInstanceOf[DecimalSumAccumulator] + def merge(acc: DecimalSumAccumulator, its: JIterable[DecimalSumAccumulator]): Unit = { + val iter = its.iterator() + while (iter.hasNext) { + val a = iter.next() if (a.f1) { - ret.f0 = ret.f0.add(a.f0) - ret.f1 = true + acc.f0 = acc.f0.add(a.f0) + acc.f1 = true } - i += 1 } - ret } - override def resetAccumulator(accumulator: Accumulator): Unit = { - accumulator.asInstanceOf[DecimalSumAccumulator].f0 = BigDecimal.ZERO - accumulator.asInstanceOf[DecimalSumAccumulator].f1 = false + def resetAccumulator(acc: DecimalSumAccumulator): Unit = { + acc.f0 = BigDecimal.ZERO + acc.f1 = false } - override def getAccumulatorType(): TypeInformation[_] = { + def getAccumulatorType(): TypeInformation[_] = { new TupleTypeInfo( (new DecimalSumAccumulator).getClass, BasicTypeInfo.BIG_DEC_TYPE_INFO, http://git-wip-us.apache.org/repos/asf/flink/blob/bc6409d6/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumWithRetractAggFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumWithRetractAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumWithRetractAggFunction.scala index 928be11..7f68d11 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumWithRetractAggFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumWithRetractAggFunction.scala @@ -18,77 +18,73 @@ package org.apache.flink.table.functions.aggfunctions import java.math.BigDecimal -import java.util.{List => JList} +import 
java.lang.{Iterable => JIterable} import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2} import org.apache.flink.api.java.typeutils.TupleTypeInfo -import org.apache.flink.table.functions.{Accumulator, AggregateFunction} +import org.apache.flink.table.functions.AggregateFunction /** The initial accumulator for Sum with retract aggregate function */ -class SumWithRetractAccumulator[T] extends JTuple2[T, Long] with Accumulator +class SumWithRetractAccumulator[T] extends JTuple2[T, Long] /** * Base class for built-in Sum with retract aggregate function * * @tparam T the type for the aggregation result */ -abstract class SumWithRetractAggFunction[T: Numeric] extends AggregateFunction[T] { +abstract class SumWithRetractAggFunction[T: Numeric] + extends AggregateFunction[T, SumWithRetractAccumulator[T]] { private val numeric = implicitly[Numeric[T]] - override def createAccumulator(): Accumulator = { + override def createAccumulator(): SumWithRetractAccumulator[T] = { val acc = new SumWithRetractAccumulator[T]() acc.f0 = numeric.zero //sum acc.f1 = 0L //total count acc } - override def accumulate(accumulator: Accumulator, value: Any): Unit = { + def accumulate(acc: SumWithRetractAccumulator[T], value: Any): Unit = { if (value != null) { val v = value.asInstanceOf[T] - val a = accumulator.asInstanceOf[SumWithRetractAccumulator[T]] - a.f0 = numeric.plus(a.f0, v) - a.f1 += 1 + acc.f0 = numeric.plus(acc.f0, v) + acc.f1 += 1 } } - override def retract(accumulator: Accumulator, value: Any): Unit = { + def retract(acc: SumWithRetractAccumulator[T], value: Any): Unit = { if (value != null) { val v = value.asInstanceOf[T] - val a = accumulator.asInstanceOf[SumWithRetractAccumulator[T]] - a.f0 = numeric.minus(a.f0, v) - a.f1 -= 1 + acc.f0 = numeric.minus(acc.f0, v) + acc.f1 -= 1 } } - override def getValue(accumulator: Accumulator): T = { - val a = accumulator.asInstanceOf[SumWithRetractAccumulator[T]] 
- if (a.f1 > 0) { - a.f0 + override def getValue(acc: SumWithRetractAccumulator[T]): T = { + if (acc.f1 > 0) { + acc.f0 } else { null.asInstanceOf[T] } } - override def merge(accumulators: JList[Accumulator]): Accumulator = { - val ret = accumulators.get(0).asInstanceOf[SumWithRetractAccumulator[T]] - var i: Int = 1 - while (i < accumulators.size()) { - val a = accumulators.get(i).asInstanceOf[SumWithRetractAccumulator[T]] - ret.f0 = numeric.plus(ret.f0, a.f0) - ret.f1 += a.f1 - i += 1 + def merge(acc: SumWithRetractAccumulator[T], + its: JIterable[SumWithRetractAccumulator[T]]): Unit = { + val iter = its.iterator() + while (iter.hasNext) { + val a = iter.next() + acc.f0 = numeric.plus(acc.f0, a.f0) + acc.f1 += a.f1 } - ret } - override def resetAccumulator(accumulator: Accumulator): Unit = { - accumulator.asInstanceOf[SumWithRetractAccumulator[T]].f0 = numeric.zero - accumulator.asInstanceOf[SumWithRetractAccumulator[T]].f1 = 0L + def resetAccumulator(acc: SumWithRetractAccumulator[T]): Unit = { + acc.f0 = numeric.zero + acc.f1 = 0L } - override def getAccumulatorType(): TypeInformation[_] = { + def getAccumulatorType(): TypeInformation[_] = { new TupleTypeInfo( (new SumWithRetractAccumulator).getClass, getValueTypeInfo, @@ -141,7 +137,7 @@ class DoubleSumWithRetractAggFunction extends SumWithRetractAggFunction[Double] } /** The initial accumulator for Big Decimal Sum with retract aggregate function */ -class DecimalSumWithRetractAccumulator extends JTuple2[BigDecimal, Long] with Accumulator { +class DecimalSumWithRetractAccumulator extends JTuple2[BigDecimal, Long] { f0 = BigDecimal.ZERO f1 = 0L } @@ -149,56 +145,53 @@ class DecimalSumWithRetractAccumulator extends JTuple2[BigDecimal, Long] with Ac /** * Built-in Big Decimal Sum with retract aggregate function */ -class DecimalSumWithRetractAggFunction extends AggregateFunction[BigDecimal] { +class DecimalSumWithRetractAggFunction + extends AggregateFunction[BigDecimal, DecimalSumWithRetractAccumulator] { - 
override def createAccumulator(): Accumulator = { + override def createAccumulator(): DecimalSumWithRetractAccumulator = { new DecimalSumWithRetractAccumulator } - override def accumulate(accumulator: Accumulator, value: Any): Unit = { + def accumulate(acc: DecimalSumWithRetractAccumulator, value: Any): Unit = { if (value != null) { val v = value.asInstanceOf[BigDecimal] - val accum = accumulator.asInstanceOf[DecimalSumWithRetractAccumulator] - accum.f0 = accum.f0.add(v) - accum.f1 += 1L + acc.f0 = acc.f0.add(v) + acc.f1 += 1L } } - override def retract(accumulator: Accumulator, value: Any): Unit = { + def retract(acc: DecimalSumWithRetractAccumulator, value: Any): Unit = { if (value != null) { val v = value.asInstanceOf[BigDecimal] - val accum = accumulator.asInstanceOf[DecimalSumWithRetractAccumulator] - accum.f0 = accum.f0.subtract(v) - accum.f1 -= 1L + acc.f0 = acc.f0.subtract(v) + acc.f1 -= 1L } } - override def getValue(accumulator: Accumulator): BigDecimal = { - if (accumulator.asInstanceOf[DecimalSumWithRetractAccumulator].f1 == 0) { + override def getValue(acc: DecimalSumWithRetractAccumulator): BigDecimal = { + if (acc.f1 == 0) { null.asInstanceOf[BigDecimal] } else { - accumulator.asInstanceOf[DecimalSumWithRetractAccumulator].f0 + acc.f0 } } - override def merge(accumulators: JList[Accumulator]): Accumulator = { - val ret = accumulators.get(0).asInstanceOf[DecimalSumWithRetractAccumulator] - var i: Int = 1 - while (i < accumulators.size()) { - val a = accumulators.get(i).asInstanceOf[DecimalSumWithRetractAccumulator] - ret.f0 = ret.f0.add(a.f0) - ret.f1 += a.f1 - i += 1 + def merge(acc: DecimalSumWithRetractAccumulator, + its: JIterable[DecimalSumWithRetractAccumulator]): Unit = { + val iter = its.iterator() + while (iter.hasNext) { + val a = iter.next() + acc.f0 = acc.f0.add(a.f0) + acc.f1 += a.f1 } - ret } - override def resetAccumulator(accumulator: Accumulator): Unit = { - accumulator.asInstanceOf[DecimalSumWithRetractAccumulator].f0 = 
BigDecimal.ZERO - accumulator.asInstanceOf[DecimalSumWithRetractAccumulator].f1 = 0L + def resetAccumulator(acc: DecimalSumWithRetractAccumulator): Unit = { + acc.f0 = BigDecimal.ZERO + acc.f1 = 0L } - override def getAccumulatorType(): TypeInformation[_] = { + def getAccumulatorType(): TypeInformation[_] = { new TupleTypeInfo( (new DecimalSumWithRetractAccumulator).getClass, BasicTypeInfo.BIG_DEC_TYPE_INFO, http://git-wip-us.apache.org/repos/asf/flink/blob/bc6409d6/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala index 2c503c6..a82f383 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala @@ -17,6 +17,7 @@ */ package org.apache.flink.table.runtime.aggregate +import java.lang.reflect.Method import java.util import org.apache.calcite.rel.`type`._ @@ -73,11 +74,12 @@ object AggregateUtil { isPartitioned: Boolean, isRowsClause: Boolean): ProcessFunction[Row, Row] = { + val needRetract = false val (aggFields, aggregates) = transformToAggregateFunctions( namedAggregates.map(_.getKey), inputType, - needRetraction = false) + needRetract) val aggregationStateType: RowTypeInfo = createDataSetAggregateBufferDataType(Array(), aggregates, inputType) @@ -97,7 +99,9 @@ object AggregateUtil { forwardMapping, None, None, - outputArity + outputArity, + needRetract, + needMerge = false ) if (isRowTimeType) { @@ -147,11 +151,12 @@ object AggregateUtil { isRowsClause: Boolean, isRowTimeType: Boolean): ProcessFunction[Row, Row] = { + val needRetract = true val (aggFields, 
aggregates) = transformToAggregateFunctions( namedAggregates.map(_.getKey), inputType, - needRetraction = true) + needRetract) val aggregationStateType: RowTypeInfo = createAccumulatorRowType(aggregates) val inputRowType = FlinkTypeFactory.toInternalRowTypeInfo(inputType).asInstanceOf[RowTypeInfo] @@ -171,7 +176,9 @@ object AggregateUtil { forwardMapping, None, None, - outputArity + outputArity, + needRetract, + needMerge = false ) if (isRowTimeType) { @@ -239,10 +246,11 @@ object AggregateUtil { isParserCaseSensitive: Boolean) : MapFunction[Row, Row] = { + val needRetract = false val (aggFieldIndexes, aggregates) = transformToAggregateFunctions( namedAggregates.map(_.getKey), inputType, - needRetraction = false) + needRetract) val mapReturnType: RowTypeInfo = createDataSetAggregateBufferDataType( @@ -293,7 +301,9 @@ object AggregateUtil { groupings, None, None, - outputArity + outputArity, + needRetract, + needMerge = false ) new DataSetWindowAggMapFunction( @@ -339,10 +349,11 @@ object AggregateUtil { isParserCaseSensitive: Boolean) : RichGroupReduceFunction[Row, Row] = { + val needRetract = false val (aggFieldIndexes, aggregates) = transformToAggregateFunctions( namedAggregates.map(_.getKey), inputType, - needRetraction = false) + needRetract) val returnType: RowTypeInfo = createDataSetAggregateBufferDataType( groupings, @@ -366,7 +377,9 @@ object AggregateUtil { groupings, Some(aggregates.indices.map(_ + groupings.length).toArray), None, - keysAndAggregatesArity + 1 + keysAndAggregatesArity + 1, + needRetract, + needMerge = true ) new DataSetSlideTimeWindowAggReduceGroupFunction( genFunction, @@ -447,10 +460,11 @@ object AggregateUtil { isInputCombined: Boolean = false) : RichGroupReduceFunction[Row, Row] = { + val needRetract = false val (aggFieldIndexes, aggregates) = transformToAggregateFunctions( namedAggregates.map(_.getKey), inputType, - needRetraction = false) + needRetract) val aggMapping = aggregates.indices.toArray.map(_ + groupings.length) @@ -465,7 
+479,9 @@ object AggregateUtil { groupings, Some(aggregates.indices.map(_ + groupings.length).toArray), None, - outputType.getFieldCount + outputType.getFieldCount, + needRetract, + needMerge = true ) val genFinalAggFunction = generator.generateAggregations( @@ -479,7 +495,9 @@ object AggregateUtil { groupings.indices.toArray, Some(aggregates.indices.map(_ + groupings.length).toArray), None, - outputType.getFieldCount + outputType.getFieldCount, + needRetract, + needMerge = true ) val keysAndAggregatesArity = groupings.length + namedAggregates.length @@ -586,10 +604,11 @@ object AggregateUtil { inputType: RelDataType, groupings: Array[Int]): MapPartitionFunction[Row, Row] = { + val needRetract = false val (aggFieldIndexes, aggregates) = transformToAggregateFunctions( namedAggregates.map(_.getKey), inputType, - needRetraction = false) + needRetract) val aggMapping = aggregates.indices.map(_ + groupings.length).toArray @@ -615,7 +634,9 @@ object AggregateUtil { groupings.indices.toArray, Some(aggregates.indices.map(_ + groupings.length).toArray), None, - groupings.length + aggregates.length + 2 + groupings.length + aggregates.length + 2, + needRetract, + needMerge = true ) new DataSetSessionWindowAggregatePreProcessor( @@ -654,10 +675,11 @@ object AggregateUtil { groupings: Array[Int]) : GroupCombineFunction[Row, Row] = { + val needRetract = false val (aggFieldIndexes, aggregates) = transformToAggregateFunctions( namedAggregates.map(_.getKey), inputType, - needRetraction = false) + needRetract) val aggMapping = aggregates.indices.map(_ + groupings.length).toArray @@ -684,7 +706,9 @@ object AggregateUtil { groupings.indices.toArray, Some(aggregates.indices.map(_ + groupings.length).toArray), None, - groupings.length + aggregates.length + 2 + groupings.length + aggregates.length + 2, + needRetract, + needMerge = true ) new DataSetSessionWindowAggregatePreProcessor( @@ -715,10 +739,11 @@ object AggregateUtil { Option[TypeInformation[Row]], RichGroupReduceFunction[Row, 
Row]) = { + val needRetract = false val (aggInFields, aggregates) = transformToAggregateFunctions( namedAggregates.map(_.getKey), inputType, - needRetraction = false) + needRetract) val (gkeyOutMapping, aggOutMapping) = getOutputMappings( namedAggregates, @@ -760,7 +785,9 @@ object AggregateUtil { groupings, None, None, - groupings.length + aggregates.length + groupings.length + aggregates.length, + needRetract, + needMerge = false ) // compute mapping of forwarded grouping keys @@ -784,7 +811,9 @@ object AggregateUtil { gkeyMapping, Some(aggregates.indices.map(_ + groupings.length).toArray), constantFlags, - outputType.getFieldCount + outputType.getFieldCount, + needRetract, + needMerge = true ) ( @@ -805,7 +834,9 @@ object AggregateUtil { groupings, None, constantFlags, - outputType.getFieldCount + outputType.getFieldCount, + needRetract, + needMerge = false ) ( @@ -874,11 +905,12 @@ object AggregateUtil { outputType: RelDataType) : (DataStreamAggFunction[Row, Row, Row], RowTypeInfo, RowTypeInfo) = { + val needRetract = false val (aggFields, aggregates) = transformToAggregateFunctions( namedAggregates.map(_.getKey), inputType, - needRetraction = false) + needRetract) val aggMapping = aggregates.indices.toArray val outputArity = aggregates.length @@ -894,7 +926,9 @@ object AggregateUtil { Array(), // no fields are forwarded None, None, - outputArity + outputArity, + needRetract, + needMerge = true ) val aggResultTypes = namedAggregates.map(a => FlinkTypeFactory.toTypeInfo(a.left.getType)) @@ -926,7 +960,7 @@ object AggregateUtil { * Return true if all aggregates can be partially merged. False otherwise. 
*/ private[flink] def doAllSupportPartialMerge( - aggregateList: Array[TableAggregateFunction[_ <: Any]]): Boolean = { + aggregateList: Array[TableAggregateFunction[_ <: Any, _ <: Any]]): Boolean = { aggregateList.forall(ifMethodExistInFunction("merge", _)) } @@ -1033,11 +1067,11 @@ object AggregateUtil { aggregateCalls: Seq[AggregateCall], inputType: RelDataType, needRetraction: Boolean) - : (Array[Array[Int]], Array[TableAggregateFunction[_ <: Any]]) = { + : (Array[Array[Int]], Array[TableAggregateFunction[_ <: Any, _ <: Any]]) = { // store the aggregate fields of each aggregate function, by the same order of aggregates. val aggFieldIndexes = new Array[Array[Int]](aggregateCalls.size) - val aggregates = new Array[TableAggregateFunction[_ <: Any]](aggregateCalls.size) + val aggregates = new Array[TableAggregateFunction[_ <: Any, _ <: Any]](aggregateCalls.size) // create aggregate function instances by function type and aggregate field data type. aggregateCalls.zipWithIndex.foreach { case (aggregateCall, index) => @@ -1232,12 +1266,18 @@ object AggregateUtil { } private def createAccumulatorType( - aggregates: Array[TableAggregateFunction[_]]): Seq[TypeInformation[_]] = { + aggregates: Array[TableAggregateFunction[_, _]]): Seq[TypeInformation[_]] = { val aggTypes: Seq[TypeInformation[_]] = aggregates.map { agg => - val accType = agg.getAccumulatorType + val accType = try { + val method: Method = agg.getClass.getMethod("getAccumulatorType") + method.invoke(agg).asInstanceOf[TypeInformation[_]] + } catch { + case _: NoSuchMethodException => null + case ite: Throwable => throw new TableException("Unexpected exception:", ite) + } if (accType != null) { accType } else { @@ -1259,7 +1299,7 @@ object AggregateUtil { private def createDataSetAggregateBufferDataType( groupings: Array[Int], - aggregates: Array[TableAggregateFunction[_]], + aggregates: Array[TableAggregateFunction[_, _]], inputType: RelDataType, windowKeyTypes: Option[Array[TypeInformation[_]]] = None): 
RowTypeInfo = { @@ -1281,7 +1321,7 @@ object AggregateUtil { } private[flink] def createAccumulatorRowType( - aggregates: Array[TableAggregateFunction[_]]): RowTypeInfo = { + aggregates: Array[TableAggregateFunction[_, _]]): RowTypeInfo = { val aggTypes: Seq[TypeInformation[_]] = createAccumulatorType(aggregates) http://git-wip-us.apache.org/repos/asf/flink/blob/bc6409d6/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala index bee39fa..5f48e09 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala @@ -101,3 +101,35 @@ abstract class GeneratedAggregations extends Function { */ def resetAccumulator(accumulators: Row) } + +class SingleElementIterable[T] extends java.lang.Iterable[T] { + + class SingleElementIterator extends java.util.Iterator[T] { + + var element: T = _ + var newElement: Boolean = false + + override def hasNext: Boolean = newElement + + override def next(): T = { + if (newElement) { + newElement = false + element + } else { + throw new java.util.NoSuchElementException + } + } + + override def remove(): Unit = throw new java.lang.UnsupportedOperationException /* removal is not supported; the exception must be thrown, not merely constructed */ + } + + val it = new SingleElementIterator + + def setElement(element: T): Unit = it.element = element + + override def iterator(): java.util.Iterator[T] = { + it.newElement = true + it + } +} +
http://git-wip-us.apache.org/repos/asf/flink/blob/bc6409d6/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/AggFunctionTestBase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/AggFunctionTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/AggFunctionTestBase.scala index cb1137f..39b9ec3 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/AggFunctionTestBase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/AggFunctionTestBase.scala @@ -17,9 +17,10 @@ */ package org.apache.flink.table.functions.aggfunctions +import java.lang.reflect.Method import java.math.BigDecimal import java.util.{ArrayList => JArrayList, List => JList} -import org.apache.flink.table.functions.{Accumulator, AggregateFunction} +import org.apache.flink.table.functions.AggregateFunction import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._ import org.junit.Assert.assertEquals import org.junit.Test @@ -29,14 +30,18 @@ import org.junit.Test * * @tparam T the type for the aggregation result */ -abstract class AggFunctionTestBase[T] { +abstract class AggFunctionTestBase[T, ACC] { def inputValueSets: Seq[Seq[_]] def expectedResults: Seq[T] - def aggregator: AggregateFunction[T] + def aggregator: AggregateFunction[T, ACC] - def supportRetraction: Boolean = true + val accType = aggregator.getClass.getMethod("createAccumulator").getReturnType + + def accumulateFunc: Method = aggregator.getClass.getMethod("accumulate", accType, classOf[Any]) + + def retractFunc: Method = null @Test // test aggregate and retract functions without partial merge @@ -47,52 +52,55 @@ abstract class AggFunctionTestBase[T] { val result = aggregator.getValue(accumulator) 
validateResult[T](expected, result) - if (supportRetraction) { + if (ifMethodExistInFunction("retract", aggregator)) { retractVals(accumulator, vals) val expectedAccum = aggregator.createAccumulator() //The two accumulators should be exactly same - validateResult[Accumulator](expectedAccum, accumulator) + validateResult[ACC](expectedAccum, accumulator) } } } @Test - // test aggregate functions with partial merge def testAggregateWithMerge(): Unit = { if (ifMethodExistInFunction("merge", aggregator)) { + val mergeFunc = + aggregator.getClass.getMethod("merge", accType, classOf[java.lang.Iterable[ACC]]) // iterate over input sets for ((vals, expected) <- inputValueSets.zip(expectedResults)) { //equally split the vals sequence into two sequences val (firstVals, secondVals) = vals.splitAt(vals.length / 2) //1. verify merge with accumulate - val accumulators: JList[Accumulator] = new JArrayList[Accumulator]() - accumulators.add(accumulateVals(firstVals)) + val accumulators: JList[ACC] = new JArrayList[ACC]() accumulators.add(accumulateVals(secondVals)) - val accumulator = aggregator.merge(accumulators) - val result = aggregator.getValue(accumulator) + val acc = accumulateVals(firstVals) + + mergeFunc.invoke(aggregator, acc.asInstanceOf[Object], accumulators) + val result = aggregator.getValue(acc) validateResult[T](expected, result) //2. verify merge with accumulate & retract - if (supportRetraction) { - retractVals(accumulator, vals) + if (ifMethodExistInFunction("retract", aggregator)) { + retractVals(acc, vals) val expectedAccum = aggregator.createAccumulator() //The two accumulators should be exactly same - validateResult[Accumulator](expectedAccum, accumulator) + validateResult[ACC](expectedAccum, acc) } } // iterate over input sets for ((vals, expected) <- inputValueSets.zip(expectedResults)) { //3. 
test partial merge with an empty accumulator - val accumulators: JList[Accumulator] = new JArrayList[Accumulator]() - accumulators.add(accumulateVals(vals)) + val accumulators: JList[ACC] = new JArrayList[ACC]() accumulators.add(aggregator.createAccumulator()) - val accumulator = aggregator.merge(accumulators) - val result = aggregator.getValue(accumulator) + val acc = accumulateVals(vals) + + mergeFunc.invoke(aggregator, acc.asInstanceOf[Object], accumulators) + val result = aggregator.getValue(acc) validateResult[T](expected, result) } } @@ -103,13 +111,14 @@ abstract class AggFunctionTestBase[T] { def testResetAccumulator(): Unit = { if (ifMethodExistInFunction("resetAccumulator", aggregator)) { + val resetAccFunc = aggregator.getClass.getMethod("resetAccumulator", accType) // iterate over input sets for ((vals, expected) <- inputValueSets.zip(expectedResults)) { val accumulator = accumulateVals(vals) - aggregator.resetAccumulator(accumulator) + resetAccFunc.invoke(aggregator, accumulator.asInstanceOf[Object]) val expectedAccum = aggregator.createAccumulator() //The accumulator after reset should be exactly same as the new accumulator - validateResult[Accumulator](expectedAccum, accumulator) + validateResult[ACC](expectedAccum, accumulator) } } } @@ -130,13 +139,18 @@ abstract class AggFunctionTestBase[T] { } } - private def accumulateVals(vals: Seq[_]): Accumulator = { + private def accumulateVals(vals: Seq[_]): ACC = { val accumulator = aggregator.createAccumulator() - vals.foreach(v => aggregator.accumulate(accumulator, v)) + vals.foreach( + v => + accumulateFunc.invoke(aggregator, accumulator.asInstanceOf[Object], v.asInstanceOf[Object]) + ) accumulator } - private def retractVals(accumulator:Accumulator, vals: Seq[_]) = { - vals.foreach(v => aggregator.retract(accumulator, v)) + private def retractVals(accumulator:ACC, vals: Seq[_]) = { + vals.foreach( + v => retractFunc.invoke(aggregator, accumulator.asInstanceOf[Object], v.asInstanceOf[Object]) + ) } }