[FLINK-6242] [table] Add code generation for DataSet Aggregates

This closes #3735.
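For orientation before the diff: generateAggregations() now emits one compiled class per operator that bundles all aggregation logic. The emitted classes extend the GeneratedAggregations base class (also touched by this patch). A simplified sketch of that contract, condensed from the method names used throughout the diff below (not the verbatim file), is:

import org.apache.flink.api.common.functions.Function
import org.apache.flink.types.Row

// Simplified sketch of the base class that the generated code extends.
abstract class GeneratedAggregationsSketch extends Function {
  // write final results, or the raw accumulators when partialResults is set
  def setAggregationResults(accumulators: Row, output: Row): Unit
  // add/remove one input row to/from the accumulators
  def accumulate(accumulators: Row, input: Row): Unit
  def retract(accumulators: Row, input: Row): Unit
  // create fresh accumulator and output rows
  def createAccumulators(): Row
  def createOutputRow(): Row
  // copy forwarded fields (e.g. grouping keys) from input to output
  def setForwardedFields(input: Row, output: Row): Unit
  // set constant grouping-set indicator flags on the output row
  def setConstantFlags(output: Row): Unit
  // merge the accumulators of two rows; the result is written into row a
  def mergeAccumulatorsPair(a: Row, b: Row): Row
  // reset all accumulators in place
  def resetAccumulator(accumulators: Row): Unit
}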
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/3b4542b8
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/3b4542b8
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/3b4542b8

Branch: refs/heads/master
Commit: 3b4542b8f0981f01e42c861bccbc67c8b3a20fdd
Parents: 4024aff
Author: shaoxuan-wang <wshaox...@gmail.com>
Authored: Tue Apr 18 21:45:49 2017 +0800
Committer: Fabian Hueske <fhue...@apache.org>
Committed: Fri Apr 21 21:28:54 2017 +0200

----------------------------------------------------------------------
 .../flink/table/codegen/CodeGenerator.scala     | 224 +++++-----
 .../plan/nodes/dataset/DataSetAggregate.scala   |  11 +-
 .../nodes/dataset/DataSetWindowAggregate.scala  |  25 +-
 .../table/runtime/aggregate/AggregateUtil.scala | 359 +++++++++++--------
 .../runtime/aggregate/DataSetAggFunction.scala  | 109 ++----
 .../aggregate/DataSetFinalAggFunction.scala     | 121 ++-----
 .../aggregate/DataSetPreAggFunction.scala       |  93 ++---
 ...SetSessionWindowAggReduceGroupFunction.scala | 116 ++----
 ...aSetSessionWindowAggregatePreProcessor.scala | 145 +++-----
 ...tSlideTimeWindowAggReduceGroupFunction.scala | 177 +++------
 ...SetSlideWindowAggReduceCombineFunction.scala |  98 ++---
 ...taSetSlideWindowAggReduceGroupFunction.scala | 123 +++----
 ...umbleCountWindowAggReduceGroupFunction.scala |  89 ++---
 ...mbleTimeWindowAggReduceCombineFunction.scala |  88 ++---
 ...TumbleTimeWindowAggReduceGroupFunction.scala |  99 ++---
 .../aggregate/DataSetWindowAggMapFunction.scala |  64 ++--
 .../aggregate/GeneratedAggregations.scala       |  32 +-
 .../scala/batch/table/AggregationsITCase.scala  |   3 +-
 ...ProcessingOverRangeProcessFunctionTest.scala |   8 +-
 19 files changed, 878 insertions(+), 1106 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala index c6e3c9a..510a870 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala @@ -250,57 +250,88 @@ class CodeGenerator( * @param aggregates All aggregate functions * @param aggFields Indexes of the input fields for all aggregate functions * @param aggMapping The mapping of aggregates to output fields + * @param partialResults A flag defining whether final or partial results (accumulators) are set + * to the output row. * @param fwdMapping The mapping of input fields to output fields + * @param mergeMapping An optional mapping to specify the accumulators to merge. If not set, we + * assume that both rows have the accumulators at the same position. + * @param constantFlags An optional parameter to define where to set constant boolean flags in + * the output row. * @param outputArity The number of fields in the output row. 
* * @return A GeneratedAggregationsFunction */ def generateAggregations( - name: String, - generator: CodeGenerator, - inputType: RelDataType, - aggregates: Array[AggregateFunction[_ <: Any]], - aggFields: Array[Array[Int]], - aggMapping: Array[Int], - fwdMapping: Array[(Int, Int)], - outputArity: Int) + name: String, + generator: CodeGenerator, + inputType: RelDataType, + aggregates: Array[AggregateFunction[_ <: Any]], + aggFields: Array[Array[Int]], + aggMapping: Array[Int], + partialResults: Boolean, + fwdMapping: Array[Int], + mergeMapping: Option[Array[Int]], + constantFlags: Option[Array[(Int, Boolean)]], + outputArity: Int) : GeneratedAggregationsFunction = { - def genSetAggregationResults( - accTypes: Array[String], - aggs: Array[String], - aggMapping: Array[Int]): String = { + // get unique function name + val funcName = newName(name) + // register UDAGGs + val aggs = aggregates.map(a => generator.addReusableFunction(a)) + // get java types of accumulators + val accTypes = aggregates.map { a => + a.getClass.getMethod("createAccumulator").getReturnType.getCanonicalName + } + + // get java types of input fields + val javaTypes = inputType.getFieldList + .map(f => FlinkTypeFactory.toTypeInfo(f.getType)) + .map(t => t.getTypeClass.getCanonicalName) + // get parameter lists for aggregation functions + val parameters = aggFields.map {inFields => + val fields = for (f <- inFields) yield s"(${javaTypes(f)}) input.getField($f)" + fields.mkString(", ") + } + + def genSetAggregationResults: String = { val sig: String = j""" - | public void setAggregationResults( - | org.apache.flink.types.Row accs, - | org.apache.flink.types.Row output)""".stripMargin + | public final void setAggregationResults( + | org.apache.flink.types.Row accs, + | org.apache.flink.types.Row output)""".stripMargin val setAggs: String = { for (i <- aggs.indices) yield - j""" - | org.apache.flink.table.functions.AggregateFunction baseClass$i = - | (org.apache.flink.table.functions.AggregateFunction) ${aggs(i)}; - | - | output.setField( - | ${aggMapping(i)}, - | baseClass$i.getValue((${accTypes(i)}) accs.getField($i)));""".stripMargin + + if (partialResults) { + j""" + | output.setField( + | ${aggMapping(i)}, + | (${accTypes(i)}) accs.getField($i));""".stripMargin + } else { + j""" + | org.apache.flink.table.functions.AggregateFunction baseClass$i = + | (org.apache.flink.table.functions.AggregateFunction) ${aggs(i)}; + | + | output.setField( + | ${aggMapping(i)}, + | baseClass$i.getValue((${accTypes(i)}) accs.getField($i)));""".stripMargin + } }.mkString("\n") - j"""$sig { + j""" + |$sig { |$setAggs | }""".stripMargin } - def genAccumulate( - accTypes: Array[String], - aggs: Array[String], - parameters: Array[String]): String = { + def genAccumulate: String = { val sig: String = j""" - | public void accumulate( + | public final void accumulate( | org.apache.flink.types.Row accs, | org.apache.flink.types.Row input)""".stripMargin @@ -317,14 +348,11 @@ class CodeGenerator( | }""".stripMargin } - def genRetract( - accTypes: Array[String], - aggs: Array[String], - parameters: Array[String]): String = { + def genRetract: String = { val sig: String = j""" - | public void retract( + | public final void retract( | org.apache.flink.types.Row accs, | org.apache.flink.types.Row input)""".stripMargin @@ -341,12 +369,11 @@ class CodeGenerator( | }""".stripMargin } - def genCreateAccumulators( - aggs: Array[String]): String = { + def genCreateAccumulators: String = { val sig: String = j""" - | public org.apache.flink.types.Row 
createAccumulators() + | public final org.apache.flink.types.Row createAccumulators() | """.stripMargin val init: String = j""" @@ -373,22 +400,24 @@ class CodeGenerator( | }""".stripMargin } - def genSetForwardedFields( - forwardMapping: Array[(Int, Int)]): String = { + def genSetForwardedFields: String = { val sig: String = j""" - | public void setForwardedFields( + | public final void setForwardedFields( | org.apache.flink.types.Row input, | org.apache.flink.types.Row output) | """.stripMargin + val forward: String = { - for (i <- forwardMapping.indices) yield - j""" - | output.setField( - | ${forwardMapping(i)._1}, - | input.getField(${forwardMapping(i)._2}));""" - .stripMargin + for (i <- fwdMapping.indices if fwdMapping(i) >= 0) yield + { + j""" + | output.setField( + | $i, + | input.getField(${fwdMapping(i)}));""" + .stripMargin + } }.mkString("\n") j"""$sig { @@ -396,20 +425,44 @@ class CodeGenerator( | }""".stripMargin } - def genCreateOutputRow(outputArity: Int): String = { + def genSetConstantFlags: String = { + + val sig: String = + j""" + | public final void setConstantFlags(org.apache.flink.types.Row output) + | """.stripMargin + + val setFlags: String = if (constantFlags.isDefined) { + { + for (cf <- constantFlags.get) yield { + j""" + | output.setField(${cf._1}, ${if (cf._2) "true" else "false"});""" + .stripMargin + } + }.mkString("\n") + } else { + "" + } + + j"""$sig { + |$setFlags + | }""".stripMargin + } + + def genCreateOutputRow: String = { j""" - | public org.apache.flink.types.Row createOutputRow() { + | public final org.apache.flink.types.Row createOutputRow() { | return new org.apache.flink.types.Row($outputArity); | }""".stripMargin } - def genMergeAccumulatorsPair( - accTypes: Array[String], - aggs: Array[String]): String = { + def genMergeAccumulatorsPair: String = { + + val mapping = mergeMapping.getOrElse(aggs.indices.toArray) val sig: String = j""" - | public org.apache.flink.types.Row mergeAccumulatorsPair( + | public final org.apache.flink.types.Row mergeAccumulatorsPair( | org.apache.flink.types.Row a, | org.apache.flink.types.Row b) """.stripMargin @@ -417,7 +470,7 @@ class CodeGenerator( for (i <- aggs.indices) yield j""" | ${accTypes(i)} aAcc$i = (${accTypes(i)}) a.getField($i); - | ${accTypes(i)} bAcc$i = (${accTypes(i)}) b.getField($i); + | ${accTypes(i)} bAcc$i = (${accTypes(i)}) b.getField(${mapping(i)}); | accList$i.set(0, aAcc$i); | accList$i.set(1, bAcc$i); | a.setField( @@ -430,75 +483,76 @@ class CodeGenerator( | return a; """.stripMargin - j"""$sig { + j""" + |$sig { |$merge |$ret | }""".stripMargin } - def genMergeList(accTypes: Array[String]): String = { + def genMergeList: String = { { for (i <- accTypes.indices) yield j""" - | java.util.ArrayList<${accTypes(i)}> accList$i; + | private final java.util.ArrayList<${accTypes(i)}> accList$i = + | new java.util.ArrayList<${accTypes(i)}>(2); """.stripMargin }.mkString("\n") } - def initMergeList( - accTypes: Array[String], - aggs: Array[String]): String = { + def initMergeList: String = { { for (i <- accTypes.indices) yield j""" - | accList$i = new java.util.ArrayList<${accTypes(i)}>(2); | accList$i.add(${aggs(i)}.createAccumulator()); | accList$i.add(${aggs(i)}.createAccumulator()); """.stripMargin }.mkString("\n") } - // get unique function name - val funcName = newName(name) - // register UDAGGs - val aggs = aggregates.map(a => generator.addReusableFunction(a)) - // get java types of accumulators - val accTypes = aggregates.map { a => - 
a.getClass.getMethod("createAccumulator").getReturnType.getCanonicalName - } + def genResetAccumulator: String = { - // get java types of input fields - val javaTypes = inputType.getFieldList - .map(f => FlinkTypeFactory.toTypeInfo(f.getType)) - .map(t => t.getTypeClass.getCanonicalName) - // get parameter lists for aggregation functions - val parameters = aggFields.map {inFields => - val fields = for (f <- inFields) yield s"(${javaTypes(f)}) input.getField($f)" - fields.mkString(", ") + val sig: String = + j""" + | public final void resetAccumulator( + | org.apache.flink.types.Row accs)""".stripMargin + + val reset: String = { + for (i <- aggs.indices) yield + j""" + | ${aggs(i)}.resetAccumulator( + | ((${accTypes(i)}) accs.getField($i)));""".stripMargin + }.mkString("\n") + + j"""$sig { + |$reset + | }""".stripMargin } var funcCode = j""" - |public class $funcName + |public final class $funcName | extends org.apache.flink.table.runtime.aggregate.GeneratedAggregations { | | ${reuseMemberCode()} - | ${genMergeList(accTypes)} + | $genMergeList | public $funcName() throws Exception { | ${reuseInitCode()} - | ${initMergeList(accTypes, aggs)} + | $initMergeList | } | ${reuseConstructorCode(funcName)} | """.stripMargin - funcCode += genSetAggregationResults(accTypes, aggs, aggMapping) + "\n" - funcCode += genAccumulate(accTypes, aggs, parameters) + "\n" - funcCode += genRetract(accTypes, aggs, parameters) + "\n" - funcCode += genCreateAccumulators(aggs) + "\n" - funcCode += genSetForwardedFields(fwdMapping) + "\n" - funcCode += genCreateOutputRow(outputArity) + "\n" - funcCode += genMergeAccumulatorsPair(accTypes, aggs) + "\n" + funcCode += genSetAggregationResults + "\n" + funcCode += genAccumulate + "\n" + funcCode += genRetract + "\n" + funcCode += genCreateAccumulators + "\n" + funcCode += genSetForwardedFields + "\n" + funcCode += genSetConstantFlags + "\n" + funcCode += genCreateOutputRow + "\n" + funcCode += genMergeAccumulatorsPair + "\n" + funcCode += genResetAccumulator + "\n" funcCode += "}" GeneratedAggregationsFunction(funcName, funcCode) http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala index 5a4aa59..b92775c 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala @@ -29,6 +29,7 @@ import org.apache.flink.api.java.DataSet import org.apache.flink.api.java.typeutils.RowTypeInfo import org.apache.flink.table.api.BatchTableEnvironment import org.apache.flink.table.calcite.FlinkTypeFactory +import org.apache.flink.table.codegen.CodeGenerator import org.apache.flink.table.plan.nodes.CommonAggregate import org.apache.flink.table.runtime.aggregate.{AggregateUtil, DataSetPreAggFunction} import org.apache.flink.table.runtime.aggregate.AggregateUtil.CalcitePair @@ -89,19 +90,25 @@ class DataSetAggregate( override def translateToPlan(tableEnv: BatchTableEnvironment): DataSet[Row] = { + val inputDS = getInput.asInstanceOf[DataSetRel].translateToPlan(tableEnv) + + val generator = new CodeGenerator( + 
tableEnv.getConfig, + false, + inputDS.getType) + val ( preAgg: Option[DataSetPreAggFunction], preAggType: Option[TypeInformation[Row]], finalAgg: GroupReduceFunction[Row, Row] ) = AggregateUtil.createDataSetAggregateFunctions( + generator, namedAggregates, inputType, rowRelDataType, grouping, inGroupingSet) - val inputDS = getInput.asInstanceOf[DataSetRel].translateToPlan(tableEnv) - val aggString = aggregationToString(inputType, grouping, getRowType, namedAggregates, Nil) val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo] http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala index a94deb1..96c427e 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala @@ -28,6 +28,7 @@ import org.apache.flink.api.java.typeutils.ResultTypeQueryable import org.apache.flink.table.api.BatchTableEnvironment import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty import org.apache.flink.table.calcite.FlinkTypeFactory +import org.apache.flink.table.codegen.CodeGenerator import org.apache.flink.table.plan.logical._ import org.apache.flink.table.plan.nodes.CommonAggregate import org.apache.flink.table.runtime.aggregate.AggregateUtil.{CalcitePair, _} @@ -109,21 +110,28 @@ class DataSetWindowAggregate( val inputDS = getInput.asInstanceOf[DataSetRel].translateToPlan(tableEnv) + val generator = new CodeGenerator( + tableEnv.getConfig, + false, + inputDS.getType) + // whether identifiers are matched case-sensitively val caseSensitive = tableEnv.getFrameworkConfig.getParserConfig.caseSensitive() window match { case EventTimeTumblingGroupWindow(_, _, size) => createEventTimeTumblingWindowDataSet( + generator, inputDS, isTimeInterval(size.resultType), caseSensitive) case EventTimeSessionGroupWindow(_, _, gap) => - createEventTimeSessionWindowDataSet(inputDS, caseSensitive) + createEventTimeSessionWindowDataSet(generator, inputDS, caseSensitive) case EventTimeSlidingGroupWindow(_, _, size, slide) => createEventTimeSlidingWindowDataSet( + generator, inputDS, isTimeInterval(size.resultType), asLong(size), @@ -139,17 +147,20 @@ class DataSetWindowAggregate( } private def createEventTimeTumblingWindowDataSet( + generator: CodeGenerator, inputDS: DataSet[Row], isTimeWindow: Boolean, isParserCaseSensitive: Boolean): DataSet[Row] = { val mapFunction = createDataSetWindowPrepareMapFunction( + generator, window, namedAggregates, grouping, inputType, isParserCaseSensitive) val groupReduceFunction = createDataSetWindowAggregationGroupReduceFunction( + generator, window, namedAggregates, inputType, @@ -195,6 +206,7 @@ class DataSetWindowAggregate( } private[this] def createEventTimeSessionWindowDataSet( + generator: CodeGenerator, inputDS: DataSet[Row], isParserCaseSensitive: Boolean): DataSet[Row] = { @@ -203,6 +215,7 @@ class DataSetWindowAggregate( // create mapFunction for initializing the aggregations val mapFunction = 
createDataSetWindowPrepareMapFunction( + generator, window, namedAggregates, grouping, @@ -229,6 +242,7 @@ class DataSetWindowAggregate( if (groupingKeys.length > 0) { // create groupCombineFunction for combine the aggregations val combineGroupFunction = createDataSetWindowAggregationCombineFunction( + generator, window, namedAggregates, inputType, @@ -236,6 +250,7 @@ class DataSetWindowAggregate( // create groupReduceFunction for calculating the aggregations val groupReduceFunction = createDataSetWindowAggregationGroupReduceFunction( + generator, window, namedAggregates, inputType, @@ -257,6 +272,7 @@ class DataSetWindowAggregate( } else { // non-grouping window val mapPartitionFunction = createDataSetWindowAggregationMapPartitionFunction( + generator, window, namedAggregates, inputType, @@ -264,6 +280,7 @@ class DataSetWindowAggregate( // create groupReduceFunction for calculating the aggregations val groupReduceFunction = createDataSetWindowAggregationGroupReduceFunction( + generator, window, namedAggregates, inputType, @@ -288,6 +305,7 @@ class DataSetWindowAggregate( // create groupReduceFunction for calculating the aggregations val groupReduceFunction = createDataSetWindowAggregationGroupReduceFunction( + generator, window, namedAggregates, inputType, @@ -303,6 +321,7 @@ class DataSetWindowAggregate( } else { // non-grouping window val groupReduceFunction = createDataSetWindowAggregationGroupReduceFunction( + generator, window, namedAggregates, inputType, @@ -320,6 +339,7 @@ class DataSetWindowAggregate( } private def createEventTimeSlidingWindowDataSet( + generator: CodeGenerator, inputDS: DataSet[Row], isTimeWindow: Boolean, size: Long, @@ -330,6 +350,7 @@ class DataSetWindowAggregate( // create MapFunction for initializing the aggregations // it aligns the rowtime for pre-tumbling in case of a time-window for partial aggregates val mapFunction = createDataSetWindowPrepareMapFunction( + generator, window, namedAggregates, grouping, @@ -365,6 +386,7 @@ class DataSetWindowAggregate( // create GroupReduceFunction // for pre-tumbling and replicating/omitting the content for each pane val prepareReduceFunction = createDataSetSlideWindowPrepareGroupReduceFunction( + generator, window, namedAggregates, grouping, @@ -401,6 +423,7 @@ class DataSetWindowAggregate( // create GroupReduceFunction for final aggregation and conversion to output row val aggregateReduceFunction = createDataSetWindowAggregationGroupReduceFunction( + generator, window, namedAggregates, inputType, http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala index da57153..2c503c6 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala @@ -82,7 +82,7 @@ object AggregateUtil { val aggregationStateType: RowTypeInfo = createDataSetAggregateBufferDataType(Array(), aggregates, inputType) - val forwardMapping = (0 until inputType.getFieldCount).map(x => (x, x)).toArray + val forwardMapping = (0 until inputType.getFieldCount).toArray val aggMapping = 
aggregates.indices.map(x => x + inputType.getFieldCount).toArray val outputArity = inputType.getFieldCount + aggregates.length @@ -93,7 +93,10 @@ object AggregateUtil { aggregates, aggFields, aggMapping, + partialResults = false, forwardMapping, + None, + None, outputArity ) @@ -153,7 +156,7 @@ object AggregateUtil { val aggregationStateType: RowTypeInfo = createAccumulatorRowType(aggregates) val inputRowType = FlinkTypeFactory.toInternalRowTypeInfo(inputType).asInstanceOf[RowTypeInfo] - val forwardMapping = (0 until inputType.getFieldCount).map(x => (x, x)).toArray + val forwardMapping = (0 until inputType.getFieldCount).toArray val aggMapping = aggregates.indices.map(x => x + inputType.getFieldCount).toArray val outputArity = inputType.getFieldCount + aggregates.length @@ -164,7 +167,10 @@ object AggregateUtil { aggregates, aggFields, aggMapping, + partialResults = false, forwardMapping, + None, + None, outputArity ) @@ -225,6 +231,7 @@ object AggregateUtil { * NOTE: this function is only used for time based window on batch tables. */ def createDataSetWindowPrepareMapFunction( + generator: CodeGenerator, window: LogicalWindow, namedAggregates: Seq[CalcitePair[AggregateCall, String]], groupings: Array[Int], @@ -249,7 +256,7 @@ object AggregateUtil { val timeFieldPos = getTimeFieldPosition(time, inputType, isParserCaseSensitive) (timeFieldPos, Some(asLong(size))) - case EventTimeTumblingGroupWindow(_, time, size) => + case EventTimeTumblingGroupWindow(_, time, _) => val timeFieldPos = getTimeFieldPosition(time, inputType, isParserCaseSensitive) (timeFieldPos, None) @@ -272,10 +279,25 @@ object AggregateUtil { throw new UnsupportedOperationException(s"$window is currently not supported on batch") } - new DataSetWindowAggMapFunction( + val aggMapping = aggregates.indices.toArray.map(_ + groupings.length) + val outputArity = aggregates.length + groupings.length + 1 + + val genFunction = generator.generateAggregations( + "DataSetAggregatePrepareMapHelper", + generator, + inputType, aggregates, aggFieldIndexes, + aggMapping, + partialResults = true, groupings, + None, + None, + outputArity + ) + + new DataSetWindowAggMapFunction( + genFunction, timeFieldPos, tumbleTimeWindowSize, mapReturnType) @@ -309,6 +331,7 @@ object AggregateUtil { * NOTE: this function is only used for sliding windows with partial aggregates on batch tables. 
*/ def createDataSetSlideWindowPrepareGroupReduceFunction( + generator: CodeGenerator, window: LogicalWindow, namedAggregates: Seq[CalcitePair[AggregateCall, String]], groupings: Array[Int], @@ -316,10 +339,10 @@ object AggregateUtil { isParserCaseSensitive: Boolean) : RichGroupReduceFunction[Row, Row] = { - val aggregates = transformToAggregateFunctions( + val (aggFieldIndexes, aggregates) = transformToAggregateFunctions( namedAggregates.map(_.getKey), inputType, - needRetraction = false)._2 + needRetraction = false) val returnType: RowTypeInfo = createDataSetAggregateBufferDataType( groupings, @@ -327,13 +350,27 @@ object AggregateUtil { inputType, Some(Array(BasicTypeInfo.LONG_TYPE_INFO))) + val keysAndAggregatesArity = groupings.length + namedAggregates.length + window match { case EventTimeSlidingGroupWindow(_, _, size, slide) if isTimeInterval(size.resultType) => // sliding time-window for partial aggregations - new DataSetSlideTimeWindowAggReduceGroupFunction( + val genFunction = generator.generateAggregations( + "DataSetAggregatePrepareMapHelper", + generator, + inputType, aggregates, - groupings.length, - returnType.getArity - 1, + aggFieldIndexes, + aggregates.indices.map(_ + groupings.length).toArray, + partialResults = true, + groupings, + Some(aggregates.indices.map(_ + groupings.length).toArray), + None, + keysAndAggregatesArity + 1 + ) + new DataSetSlideTimeWindowAggReduceGroupFunction( + genFunction, + keysAndAggregatesArity, asLong(size), asLong(slide), returnType) @@ -400,6 +437,7 @@ object AggregateUtil { * NOTE: this function is only used for window on batch tables. */ def createDataSetWindowAggregationGroupReduceFunction( + generator: CodeGenerator, window: LogicalWindow, namedAggregates: Seq[CalcitePair[AggregateCall, String]], inputType: RelDataType, @@ -414,19 +452,37 @@ object AggregateUtil { inputType, needRetraction = false) - // the mapping relation between field index of intermediate aggregate Row and output Row. - val groupingOffsetMapping = getGroupKeysMapping(inputType, outputType, groupings) + val aggMapping = aggregates.indices.toArray.map(_ + groupings.length) + + val genPreAggFunction = generator.generateAggregations( + "GroupingWindowAggregateHelper", + generator, + inputType, + aggregates, + aggFieldIndexes, + aggMapping, + partialResults = true, + groupings, + Some(aggregates.indices.map(_ + groupings.length).toArray), + None, + outputType.getFieldCount + ) - // the mapping relation between aggregate function index in list and its corresponding - // field index in output Row. 
- val aggOffsetMapping = getAggregateMapping(namedAggregates, outputType) + val genFinalAggFunction = generator.generateAggregations( + "GroupingWindowAggregateHelper", + generator, + inputType, + aggregates, + aggFieldIndexes, + aggMapping, + partialResults = false, + groupings.indices.toArray, + Some(aggregates.indices.map(_ + groupings.length).toArray), + None, + outputType.getFieldCount + ) - if (groupingOffsetMapping.length != groupings.length || - aggOffsetMapping.length != namedAggregates.length) { - throw new TableException( - "Could not find output field in input data type " + - "or aggregate functions.") - } + val keysAndAggregatesArity = groupings.length + namedAggregates.length window match { case EventTimeTumblingGroupWindow(_, _, size) if isTimeInterval(size.resultType) => @@ -435,41 +491,33 @@ object AggregateUtil { if (doAllSupportPartialMerge(aggregates)) { // for incremental aggregations new DataSetTumbleTimeWindowAggReduceCombineFunction( + genPreAggFunction, + genFinalAggFunction, asLong(size), startPos, endPos, - aggregates, - groupingOffsetMapping, - aggOffsetMapping, - outputType.getFieldCount) + keysAndAggregatesArity) } else { // for non-incremental aggregations new DataSetTumbleTimeWindowAggReduceGroupFunction( + genFinalAggFunction, asLong(size), startPos, endPos, - aggregates, - groupingOffsetMapping, - aggOffsetMapping, outputType.getFieldCount) } case EventTimeTumblingGroupWindow(_, _, size) => // tumbling count window new DataSetTumbleCountWindowAggReduceGroupFunction( - asLong(size), - aggregates, - groupingOffsetMapping, - aggOffsetMapping, - outputType.getFieldCount) + genFinalAggFunction, + asLong(size)) case EventTimeSessionGroupWindow(_, _, gap) => val (startPos, endPos) = computeWindowStartEndPropertyPos(properties) new DataSetSessionWindowAggReduceGroupFunction( - aggregates, - groupingOffsetMapping, - aggOffsetMapping, - outputType.getFieldCount, + genFinalAggFunction, + keysAndAggregatesArity, startPos, endPos, asLong(gap), @@ -480,10 +528,9 @@ object AggregateUtil { if (doAllSupportPartialMerge(aggregates)) { // for partial aggregations new DataSetSlideWindowAggReduceCombineFunction( - aggregates, - groupingOffsetMapping, - aggOffsetMapping, - outputType.getFieldCount, + genPreAggFunction, + genFinalAggFunction, + keysAndAggregatesArity, startPos, endPos, asLong(size)) @@ -491,10 +538,8 @@ object AggregateUtil { else { // for non-partial aggregations new DataSetSlideWindowAggReduceGroupFunction( - aggregates, - groupingOffsetMapping, - aggOffsetMapping, - outputType.getFieldCount, + genFinalAggFunction, + keysAndAggregatesArity, startPos, endPos, asLong(size)) @@ -502,10 +547,8 @@ object AggregateUtil { case EventTimeSlidingGroupWindow(_, _, size, _) => new DataSetSlideWindowAggReduceGroupFunction( - aggregates, - groupingOffsetMapping, - aggOffsetMapping, - outputType.getFieldCount, + genFinalAggFunction, + keysAndAggregatesArity, None, None, asLong(size)) @@ -537,15 +580,20 @@ object AggregateUtil { * */ def createDataSetWindowAggregationMapPartitionFunction( + generator: CodeGenerator, window: LogicalWindow, namedAggregates: Seq[CalcitePair[AggregateCall, String]], inputType: RelDataType, groupings: Array[Int]): MapPartitionFunction[Row, Row] = { - val aggregates = transformToAggregateFunctions( + val (aggFieldIndexes, aggregates) = transformToAggregateFunctions( namedAggregates.map(_.getKey), inputType, - needRetraction = false)._2 + needRetraction = false) + + val aggMapping = aggregates.indices.map(_ + groupings.length).toArray + + val 
keysAndAggregatesArity = groupings.length + namedAggregates.length window match { case EventTimeSessionGroupWindow(_, _, gap) => @@ -556,9 +604,23 @@ object AggregateUtil { inputType, Option(Array(BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.LONG_TYPE_INFO))) - new DataSetSessionWindowAggregatePreProcessor( + val genFunction = generator.generateAggregations( + "GroupingWindowAggregateHelper", + generator, + inputType, aggregates, - groupings, + aggFieldIndexes, + aggMapping, + partialResults = true, + groupings.indices.toArray, + Some(aggregates.indices.map(_ + groupings.length).toArray), + None, + groupings.length + aggregates.length + 2 + ) + + new DataSetSessionWindowAggregatePreProcessor( + genFunction, + keysAndAggregatesArity, asLong(gap), combineReturnType) case _ => @@ -585,16 +647,21 @@ object AggregateUtil { * */ private[flink] def createDataSetWindowAggregationCombineFunction( + generator: CodeGenerator, window: LogicalWindow, namedAggregates: Seq[CalcitePair[AggregateCall, String]], inputType: RelDataType, groupings: Array[Int]) : GroupCombineFunction[Row, Row] = { - val aggregates = transformToAggregateFunctions( + val (aggFieldIndexes, aggregates) = transformToAggregateFunctions( namedAggregates.map(_.getKey), inputType, - needRetraction = false)._2 + needRetraction = false) + + val aggMapping = aggregates.indices.map(_ + groupings.length).toArray + + val keysAndAggregatesArity = groupings.length + namedAggregates.length window match { @@ -606,9 +673,23 @@ object AggregateUtil { inputType, Option(Array(BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.LONG_TYPE_INFO))) - new DataSetSessionWindowAggregatePreProcessor( + val genFunction = generator.generateAggregations( + "GroupingWindowAggregateHelper", + generator, + inputType, aggregates, - groupings, + aggFieldIndexes, + aggMapping, + partialResults = true, + groupings.indices.toArray, + Some(aggregates.indices.map(_ + groupings.length).toArray), + None, + groupings.length + aggregates.length + 2 + ) + + new DataSetSessionWindowAggregatePreProcessor( + genFunction, + keysAndAggregatesArity, asLong(gap), combineReturnType) @@ -625,6 +706,7 @@ object AggregateUtil { * respective output type are generated as well. 
*/ private[flink] def createDataSetAggregateFunctions( + generator: CodeGenerator, namedAggregates: Seq[CalcitePair[AggregateCall, String]], inputType: RelDataType, outputType: RelDataType, @@ -645,51 +727,91 @@ object AggregateUtil { outputType ) - val groupingSetsMapping: Array[(Int, Int)] = if (inGroupingSet) { - getGroupingSetsIndicatorMapping(inputType, outputType) + val constantFlags: Option[Array[(Int, Boolean)]] = + if (inGroupingSet) { + + val groupingSetsMapping = getGroupingSetsIndicatorMapping(inputType, outputType) + val nonNullKeysFields = gkeyOutMapping.map(_._1) + val flags = for ((in, out) <- groupingSetsMapping) yield + (out, !nonNullKeysFields.contains(in)) + Some(flags) } else { - Array() + None } - if (doAllSupportPartialMerge(aggregates)) { + val aggOutFields = aggOutMapping.map(_._1) - // compute grouping key and aggregation positions - val gkeyInFields = gkeyOutMapping.map(_._2) - val gkeyOutFields = gkeyOutMapping.map(_._1) - val aggOutFields = aggOutMapping.map(_._1) + if (doAllSupportPartialMerge(aggregates)) { // compute preaggregation type - val preAggFieldTypes = gkeyInFields + val preAggFieldTypes = gkeyOutMapping.map(_._2) .map(inputType.getFieldList.get(_).getType) .map(FlinkTypeFactory.toTypeInfo) ++ createAccumulatorType(aggregates) val preAggRowType = new RowTypeInfo(preAggFieldTypes: _*) + val genPreAggFunction = generator.generateAggregations( + "DataSetAggregatePrepareMapHelper", + generator, + inputType, + aggregates, + aggInFields, + aggregates.indices.map(_ + groupings.length).toArray, + partialResults = true, + groupings, + None, + None, + groupings.length + aggregates.length + ) + + // compute mapping of forwarded grouping keys + val gkeyMapping: Array[Int] = if (gkeyOutMapping.nonEmpty) { + val gkeyOutFields = gkeyOutMapping.map(_._1) + val mapping = Array.fill[Int](gkeyOutFields.max + 1)(-1) + gkeyOutFields.zipWithIndex.foreach(m => mapping(m._1) = m._2) + mapping + } else { + new Array[Int](0) + } + + val genFinalAggFunction = generator.generateAggregations( + "DataSetAggregateFinalHelper", + generator, + inputType, + aggregates, + aggInFields, + aggOutFields, + partialResults = false, + gkeyMapping, + Some(aggregates.indices.map(_ + groupings.length).toArray), + constantFlags, + outputType.getFieldCount + ) + ( - Some(new DataSetPreAggFunction( - aggregates, - aggInFields, - gkeyInFields - )), + Some(new DataSetPreAggFunction(genPreAggFunction)), Some(preAggRowType), - new DataSetFinalAggFunction( - aggregates, - aggOutFields, - gkeyOutFields, - groupingSetsMapping, - outputType.getFieldCount) + new DataSetFinalAggFunction(genFinalAggFunction) ) } else { + val genFunction = generator.generateAggregations( + "DataSetAggregateHelper", + generator, + inputType, + aggregates, + aggInFields, + aggOutFields, + partialResults = false, + groupings, + None, + constantFlags, + outputType.getFieldCount + ) + ( None, None, - new DataSetAggFunction( - aggregates, - aggInFields, - aggOutMapping, - gkeyOutMapping, - groupingSetsMapping, - outputType.getFieldCount) + new DataSetAggFunction(genFunction) ) } @@ -768,15 +890,12 @@ object AggregateUtil { aggregates, aggFields, aggMapping, - Array(), - outputArity) - - val aggregateMapping = getAggregateMapping(namedAggregates, outputType) - - if (aggregateMapping.length != namedAggregates.length) { - throw new TableException( - "Could not find output field in input data type or aggregate functions.") - } + partialResults = false, + Array(), // no fields are forwarded + None, + None, + outputArity + ) val 
aggResultTypes = namedAggregates.map(a => FlinkTypeFactory.toTypeInfo(a.left.getType)) @@ -1169,62 +1288,6 @@ object AggregateUtil { new RowTypeInfo(aggTypes: _*) } - // Find the mapping between the index of aggregate list and aggregated value index in output Row. - private def getAggregateMapping( - namedAggregates: Seq[CalcitePair[AggregateCall, String]], - outputType: RelDataType): Array[(Int, Int)] = { - - // the mapping relation between aggregate function index in list and its corresponding - // field index in output Row. - var aggOffsetMapping = ArrayBuffer[(Int, Int)]() - - outputType.getFieldList.zipWithIndex.foreach { - case (outputFieldType, outputIndex) => - namedAggregates.zipWithIndex.foreach { - case (namedAggCall, aggregateIndex) => - if (namedAggCall.getValue.equals(outputFieldType.getName) && - namedAggCall.getKey.getType.equals(outputFieldType.getType)) { - aggOffsetMapping += ((outputIndex, aggregateIndex)) - } - } - } - - aggOffsetMapping.toArray - } - - // Find the mapping between the index of group key in intermediate aggregate Row and its index - // in output Row. - private def getGroupKeysMapping( - inputDatType: RelDataType, - outputType: RelDataType, - groupKeys: Array[Int]): Array[(Int, Int)] = { - - // the mapping relation between field index of intermediate aggregate Row and output Row. - var groupingOffsetMapping = ArrayBuffer[(Int, Int)]() - - outputType.getFieldList.zipWithIndex.foreach { - case (outputFieldType, outputIndex) => - inputDatType.getFieldList.zipWithIndex.foreach { - // find the field index in input data type. - case (inputFieldType, inputIndex) => - if (outputFieldType.getName.equals(inputFieldType.getName) && - outputFieldType.getType.equals(inputFieldType.getType)) { - // as aggregated field in output data type would not have a matched field in - // input data, so if inputIndex is not -1, it must be a group key. Then we can - // find the field index in buffer data by the group keys index mapping between - // input data and buffer data. 
- for (i <- groupKeys.indices) { - if (inputIndex == groupKeys(i)) { - groupingOffsetMapping += ((outputIndex, i)) - } - } - } - } - } - - groupingOffsetMapping.toArray - } - private def getTimeFieldPosition( timeField: Expression, inputType: RelDataType, http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetAggFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetAggFunction.scala index 867943e..5f459f9 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetAggFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetAggFunction.scala @@ -21,101 +21,66 @@ import java.lang.Iterable import org.apache.flink.api.common.functions.RichGroupReduceFunction import org.apache.flink.configuration.Configuration -import org.apache.flink.table.functions.{Accumulator, AggregateFunction} +import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} import org.apache.flink.types.Row -import org.apache.flink.util.{Collector, Preconditions} +import org.apache.flink.util.Collector +import org.slf4j.LoggerFactory /** * [[RichGroupReduceFunction]] to compute aggregates that do not support pre-aggregation for batch * (DataSet) queries. * - * @param aggregates The aggregate functions. - * @param aggInFields The positions of the aggregation input fields. - * @param gkeyOutMapping The mapping of group keys between input and output positions. - * @param aggOutMapping The mapping of aggregates to output positions. - * @param groupingSetsMapping The mapping of grouping set keys between input and output positions. - * @param finalRowArity The arity of the final resulting row. 
+ * @param genAggregations Code-generated [[GeneratedAggregations]] */ class DataSetAggFunction( - private val aggregates: Array[AggregateFunction[_ <: Any]], - private val aggInFields: Array[Array[Int]], - private val aggOutMapping: Array[(Int, Int)], - private val gkeyOutMapping: Array[(Int, Int)], - private val groupingSetsMapping: Array[(Int, Int)], - private val finalRowArity: Int) - extends RichGroupReduceFunction[Row, Row] { - - Preconditions.checkNotNull(aggregates) - Preconditions.checkNotNull(aggInFields) - Preconditions.checkNotNull(aggOutMapping) - Preconditions.checkNotNull(gkeyOutMapping) - Preconditions.checkNotNull(groupingSetsMapping) + private val genAggregations: GeneratedAggregationsFunction) + extends RichGroupReduceFunction[Row, Row] + with Compiler[GeneratedAggregations] { private var output: Row = _ + private var accumulators: Row = _ - private var intermediateGKeys: Option[Array[Int]] = None - private var accumulators: Array[Accumulator] = _ + val LOG = LoggerFactory.getLogger(this.getClass) + private var function: GeneratedAggregations = _ override def open(config: Configuration) { - accumulators = new Array(aggregates.length) - output = new Row(finalRowArity) - - if (!groupingSetsMapping.isEmpty) { - intermediateGKeys = Some(gkeyOutMapping.map(_._1)) - } + LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " + + s"Code:\n$genAggregations.code") + val clazz = compile( + getClass.getClassLoader, + genAggregations.name, + genAggregations.code) + LOG.debug("Instantiating AggregateHelper.") + function = clazz.newInstance() + + output = function.createOutputRow() + accumulators = function.createAccumulators() } override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = { - // create accumulators - var i = 0 - while (i < aggregates.length) { - accumulators(i) = aggregates(i).createAccumulator() - i += 1 - } + // reset accumulators + function.resetAccumulator(accumulators) val iterator = records.iterator() + var record: Row = null while (iterator.hasNext) { - val record = iterator.next() + record = iterator.next() // accumulate - i = 0 - while (i < aggregates.length) { - aggregates(i).accumulate(accumulators(i), record.getField(aggInFields(i)(0))) - i += 1 - } - - // check if this record is the last record - if (!iterator.hasNext) { - // set group keys value to final output - i = 0 - while (i < gkeyOutMapping.length) { - val (out, in) = gkeyOutMapping(i) - output.setField(out, record.getField(in)) - i += 1 - } - - // set agg results to output - i = 0 - while (i < aggOutMapping.length) { - val (out, in) = aggOutMapping(i) - output.setField(out, aggregates(in).getValue(accumulators(in))) - i += 1 - } - - // set grouping set flags to output - if (intermediateGKeys.isDefined) { - i = 0 - while (i < groupingSetsMapping.length) { - val (in, out) = groupingSetsMapping(i) - output.setField(out, !intermediateGKeys.get.contains(in)) - i += 1 - } - } - - out.collect(output) - } + function.accumulate(accumulators, record) } + + // set group keys value to final output + function.setForwardedFields(record, output) + + // set agg results to output + function.setAggregationResults(accumulators, output) + + // set grouping set flags to output + function.setConstantFlags(output) + + out.collect(output) } } http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetFinalAggFunction.scala ---------------------------------------------------------------------- diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetFinalAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetFinalAggFunction.scala index e3db7a2..9b81992 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetFinalAggFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetFinalAggFunction.scala @@ -19,60 +19,43 @@ package org.apache.flink.table.runtime.aggregate import java.lang.Iterable -import java.util.{ArrayList => JArrayList} import org.apache.flink.api.common.functions.RichGroupReduceFunction import org.apache.flink.configuration.Configuration -import org.apache.flink.table.functions.{Accumulator, AggregateFunction} +import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} import org.apache.flink.types.Row -import org.apache.flink.util.{Collector, Preconditions} +import org.apache.flink.util.Collector +import org.slf4j.LoggerFactory /** * [[RichGroupReduceFunction]] to compute the final result of a pre-aggregated aggregation * for batch (DataSet) queries. * - * @param aggregates The aggregate functions. - * @param aggOutFields The positions of the aggregation results in the output - * @param gkeyOutFields The positions of the grouping keys in the output - * @param groupingSetsMapping The mapping of grouping set keys between input and output positions. - * @param finalRowArity The arity of the final resulting row + * @param genAggregations Code-generated [[GeneratedAggregations]] */ class DataSetFinalAggFunction( - private val aggregates: Array[AggregateFunction[_ <: Any]], - private val aggOutFields: Array[Int], - private val gkeyOutFields: Array[Int], - private val groupingSetsMapping: Array[(Int, Int)], - private val finalRowArity: Int) - extends RichGroupReduceFunction[Row, Row] { - - Preconditions.checkNotNull(aggregates) - Preconditions.checkNotNull(aggOutFields) - Preconditions.checkNotNull(gkeyOutFields) - Preconditions.checkNotNull(groupingSetsMapping) + private val genAggregations: GeneratedAggregationsFunction) + extends RichGroupReduceFunction[Row, Row] + with Compiler[GeneratedAggregations] { private var output: Row = _ + private var accumulators: Row = _ - private val intermediateGKeys: Option[Array[Int]] = if (!groupingSetsMapping.isEmpty) { - Some(gkeyOutFields) - } else { - None - } - - private val numAggs = aggregates.length - private val numGKeys = gkeyOutFields.length - - private val accumulators: Array[JArrayList[Accumulator]] = - Array.fill(numAggs)(new JArrayList[Accumulator](2)) + val LOG = LoggerFactory.getLogger(this.getClass) + private var function: GeneratedAggregations = _ override def open(config: Configuration) { - output = new Row(finalRowArity) - - // init lists with two empty accumulators - for (i <- aggregates.indices) { - val accumulator = aggregates(i).createAccumulator() - accumulators(i).add(accumulator) - accumulators(i).add(accumulator) - } + LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " + + s"Code:\n$genAggregations.code") + val clazz = compile( + getClass.getClassLoader, + genAggregations.name, + genAggregations.code) + LOG.debug("Instantiating AggregateHelper.") + function = clazz.newInstance() + + output = function.createOutputRow() + accumulators = function.createAccumulators() } override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = { @@ -80,56 +63,24 @@ class 
DataSetFinalAggFunction( val iterator = records.iterator() // reset first accumulator - var i = 0 - while (i < aggregates.length) { - aggregates(i).resetAccumulator(accumulators(i).get(0)) - i += 1 - } + function.resetAccumulator(accumulators) + var record: Row = null while (iterator.hasNext) { - val record = iterator.next() - + record = iterator.next() // accumulate - i = 0 - while (i < aggregates.length) { - // insert received accumulator into acc list - val newAcc = record.getField(numGKeys + i).asInstanceOf[Accumulator] - accumulators(i).set(1, newAcc) - // merge acc list - val retAcc = aggregates(i).merge(accumulators(i)) - // insert result into acc list - accumulators(i).set(0, retAcc) - i += 1 - } - - // check if this record is the last record - if (!iterator.hasNext) { - // set group keys value to final output - i = 0 - while (i < gkeyOutFields.length) { - output.setField(gkeyOutFields(i), record.getField(i)) - i += 1 - } - - // get final aggregate value and set to output. - i = 0 - while (i < aggOutFields.length) { - output.setField(aggOutFields(i), aggregates(i).getValue(accumulators(i).get(0))) - i += 1 - } - - // set grouping set flags to output - if (intermediateGKeys.isDefined) { - i = 0 - while (i < groupingSetsMapping.length) { - val (in, out) = groupingSetsMapping(i) - output.setField(out, !intermediateGKeys.get.contains(in)) - i += 1 - } - } - - out.collect(output) - } + function.mergeAccumulatorsPair(accumulators, record) } + + // set group keys value to final output + function.setForwardedFields(record, output) + + // get final aggregate value and set to output. + function.setAggregationResults(accumulators, output) + + // set grouping set flags to output + function.setConstantFlags(output) + + out.collect(output) } } http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetPreAggFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetPreAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetPreAggFunction.scala index db49a53..8febe3e 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetPreAggFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetPreAggFunction.scala @@ -21,85 +21,64 @@ import java.lang.Iterable import org.apache.flink.api.common.functions._ import org.apache.flink.configuration.Configuration -import org.apache.flink.table.functions.{Accumulator, AggregateFunction} +import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} import org.apache.flink.types.Row -import org.apache.flink.util.{Collector, Preconditions} +import org.apache.flink.util.Collector +import org.slf4j.LoggerFactory /** * [[GroupCombineFunction]] and [[MapPartitionFunction]] to compute pre-aggregates for batch * (DataSet) queries. * - * @param aggregates The aggregate functions. - * @param aggInFields The positions of the aggregation input fields. - * @param groupingKeys The positions of the grouping keys in the input. 
+ * @param genAggregations Code-generated [[GeneratedAggregations]] */ -class DataSetPreAggFunction( - private val aggregates: Array[AggregateFunction[_ <: Any]], - private val aggInFields: Array[Array[Int]], - private val groupingKeys: Array[Int]) +class DataSetPreAggFunction(genAggregations: GeneratedAggregationsFunction) extends AbstractRichFunction with GroupCombineFunction[Row, Row] - with MapPartitionFunction[Row, Row] { - - Preconditions.checkNotNull(aggregates) - Preconditions.checkNotNull(aggInFields) - Preconditions.checkNotNull(groupingKeys) + with MapPartitionFunction[Row, Row] + with Compiler[GeneratedAggregations] { private var output: Row = _ - private var accumulators: Array[Accumulator] = _ + private var accumulators: Row = _ + + val LOG = LoggerFactory.getLogger(this.getClass) + private var function: GeneratedAggregations = _ override def open(config: Configuration) { - accumulators = new Array(aggregates.length) - output = new Row(groupingKeys.length + aggregates.length) + LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " + + s"Code:\n$genAggregations.code") + val clazz = compile( + getClass.getClassLoader, + genAggregations.name, + genAggregations.code) + LOG.debug("Instantiating AggregateHelper.") + function = clazz.newInstance() + + output = function.createOutputRow() + accumulators = function.createAccumulators() } override def combine(values: Iterable[Row], out: Collector[Row]): Unit = { - preaggregate(values, out) - } - - override def mapPartition(values: Iterable[Row], out: Collector[Row]): Unit = { - preaggregate(values, out) - } - - def preaggregate(records: Iterable[Row], out: Collector[Row]): Unit = { + // reset accumulators + function.resetAccumulator(accumulators) - // create accumulators - var i = 0 - while (i < aggregates.length) { - accumulators(i) = aggregates(i).createAccumulator() - i += 1 - } - - val iterator = records.iterator() + val iterator = values.iterator() + var record: Row = null while (iterator.hasNext) { - val record = iterator.next() - + record = iterator.next() // accumulate - i = 0 - while (i < aggregates.length) { - aggregates(i).accumulate(accumulators(i), record.getField(aggInFields(i)(0))) - i += 1 - } - // check if this record is the last record - if (!iterator.hasNext) { - // set group keys value to output - i = 0 - while (i < groupingKeys.length) { - output.setField(i, record.getField(groupingKeys(i))) - i += 1 - } + function.accumulate(accumulators, record) + } - // set agg results to output - i = 0 - while (i < accumulators.length) { - output.setField(groupingKeys.length + i, accumulators(i)) - i += 1 - } + // set group keys and accumulators to output + function.setAggregationResults(accumulators, output) + function.setForwardedFields(record, output) - out.collect(output) - } - } + out.collect(output) } + override def mapPartition(values: Iterable[Row], out: Collector[Row]): Unit = { + combine(values, out) + } } http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggReduceGroupFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggReduceGroupFunction.scala index d108570..95699a2 100644 --- 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggReduceGroupFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggReduceGroupFunction.scala @@ -18,13 +18,13 @@ package org.apache.flink.table.runtime.aggregate import java.lang.Iterable -import java.util.{ArrayList => JArrayList} import org.apache.flink.api.common.functions.RichGroupReduceFunction import org.apache.flink.types.Row import org.apache.flink.configuration.Configuration -import org.apache.flink.table.functions.{Accumulator, AggregateFunction} -import org.apache.flink.util.{Collector, Preconditions} +import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} +import org.apache.flink.util.Collector +import org.slf4j.LoggerFactory /** * It wraps the aggregate logic inside of @@ -40,53 +40,45 @@ import org.apache.flink.util.{Collector, Preconditions} * 2. when partial aggregate is supported, the input data structure of reduce is * |groupKey1|groupKey2|sum1|count1|sum2|count2|windowStart|windowEnd| * - * @param aggregates The aggregate functions. - * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row - * and output Row. - * @param aggregateMapping The index mapping between aggregate function list and - * aggregated value index in output Row. - * @param finalRowArity The output row field count. + * @param genAggregations Code-generated [[GeneratedAggregations]] + * @param keysAndAggregatesArity The total arity of keys and aggregates * @param finalRowWindowStartPos The relative window-start field position. * @param finalRowWindowEndPos The relative window-end field position. * @param gap Session time window gap. 
*/ class DataSetSessionWindowAggReduceGroupFunction( - aggregates: Array[AggregateFunction[_ <: Any]], - groupKeysMapping: Array[(Int, Int)], - aggregateMapping: Array[(Int, Int)], - finalRowArity: Int, + genAggregations: GeneratedAggregationsFunction, + keysAndAggregatesArity: Int, finalRowWindowStartPos: Option[Int], finalRowWindowEndPos: Option[Int], gap: Long, isInputCombined: Boolean) - extends RichGroupReduceFunction[Row, Row] { + extends RichGroupReduceFunction[Row, Row] + with Compiler[GeneratedAggregations] { - Preconditions.checkNotNull(aggregates) - Preconditions.checkNotNull(groupKeysMapping) + private var collector: TimeWindowPropertyCollector = _ + private val intermediateRowWindowStartPos = keysAndAggregatesArity + private val intermediateRowWindowEndPos = keysAndAggregatesArity + 1 - private var aggregateBuffer: Row = _ private var output: Row = _ - private var collector: TimeWindowPropertyCollector = _ - private val accumStartPos: Int = groupKeysMapping.length - private val intermediateRowArity: Int = accumStartPos + aggregates.length + 2 - private val intermediateRowWindowStartPos = intermediateRowArity - 2 - private val intermediateRowWindowEndPos = intermediateRowArity - 1 + private var accumulators: Row = _ - val accumulatorList: Array[JArrayList[Accumulator]] = Array.fill(aggregates.length) { - new JArrayList[Accumulator](2) - } + val LOG = LoggerFactory.getLogger(this.getClass) + private var function: GeneratedAggregations = _ override def open(config: Configuration) { - aggregateBuffer = new Row(intermediateRowArity) - output = new Row(finalRowArity) + LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " + + s"Code:\n$genAggregations.code") + val clazz = compile( + getClass.getClassLoader, + genAggregations.name, + genAggregations.code) + LOG.debug("Instantiating AggregateHelper.") + function = clazz.newInstance() + + output = function.createOutputRow() + accumulators = function.createAccumulators() collector = new TimeWindowPropertyCollector(finalRowWindowStartPos, finalRowWindowEndPos) - - // init lists with two empty accumulators - for (i <- aggregates.indices) { - val accumulator = aggregates(i).createAccumulator() - accumulatorList(i).add(accumulator) - accumulatorList(i).add(accumulator) - } } /** @@ -105,13 +97,8 @@ class DataSetSessionWindowAggReduceGroupFunction( var windowEnd: java.lang.Long = null var currentRowTime: java.lang.Long = null - - // reset first accumulator in merge list - var i = 0 - while (i < aggregates.length) { - aggregates(i).resetAccumulator(accumulatorList(i).get(0)) - i += 1 - } + // reset accumulator + function.resetAccumulator(accumulators) val iterator = records.iterator() @@ -125,38 +112,18 @@ class DataSetSessionWindowAggReduceGroupFunction( // calculate the current window and open a new window if (null != windowEnd) { // evaluate and emit the current window's result. - doEvaluateAndCollect(out, accumulatorList, windowStart, windowEnd) - - // reset first accumulator in list - i = 0 - while (i < aggregates.length) { - aggregates(i).resetAccumulator(accumulatorList(i).get(0)) - i += 1 - } + doEvaluateAndCollect(out, windowStart, windowEnd) + // reset accumulator + function.resetAccumulator(accumulators) } else { - // set group keys value to final output. 
-          i = 0
-          while (i < groupKeysMapping.length) {
-            val (after, previous) = groupKeysMapping(i)
-            output.setField(after, record.getField(previous))
-            i += 1
-          }
+          // set keys to output
+          function.setForwardedFields(record, output)
         }
         windowStart = record.getField(intermediateRowWindowStartPos).asInstanceOf[Long]
       }
 
-      i = 0
-      while (i < aggregates.length) {
-        // insert received accumulator into acc list
-        val newAcc = record.getField(accumStartPos + i).asInstanceOf[Accumulator]
-        accumulatorList(i).set(1, newAcc)
-        // merge acc list
-        val retAcc = aggregates(i).merge(accumulatorList(i))
-        // insert result into acc list
-        accumulatorList(i).set(0, retAcc)
-        i += 1
-      }
+      function.mergeAccumulatorsPair(accumulators, record)
 
       windowEnd = if (isInputCombined) {
         // partial aggregate is supported
@@ -167,15 +134,13 @@ class DataSetSessionWindowAggReduceGroupFunction(
       }
     }
     // evaluate and emit the current window's result.
-    doEvaluateAndCollect(out, accumulatorList, windowStart, windowEnd)
+    doEvaluateAndCollect(out, windowStart, windowEnd)
   }
 
   /**
     * Evaluate and emit the data of the current window.
    *
    * @param out the collection of the aggregate results
-   * @param accumulatorList an array (indexed by aggregate index) of the accumulator lists for
-   *                        each aggregate
    * @param windowStart the window's start attribute value is the min (rowtime) of all rows
    *                    in the window.
    * @param windowEnd the window's end property value is max (rowtime) + gap for all rows
@@ -183,18 +148,11 @@ class DataSetSessionWindowAggReduceGroupFunction(
    */
   def doEvaluateAndCollect(
     out: Collector[Row],
-    accumulatorList: Array[JArrayList[Accumulator]],
     windowStart: Long,
     windowEnd: Long): Unit = {
 
-    // merge the accumulators and then get value for the final output
-    var i = 0
-    while (i < aggregateMapping.length) {
-      val (after, previous) = aggregateMapping(i)
-      val agg = aggregates(previous)
-      output.setField(after, agg.getValue(accumulatorList(previous).get(0)))
-      i += 1
-    }
+    // set value for the final output
+    function.setAggregationResults(accumulators, output)
 
     // adds TimeWindow properties to output then emit output
     if (finalRowWindowStartPos.isDefined || finalRowWindowEndPos.isDefined) {

http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregatePreProcessor.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregatePreProcessor.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregatePreProcessor.scala
index acd9e63..22a2682 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregatePreProcessor.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregatePreProcessor.scala
@@ -18,55 +18,55 @@ package org.apache.flink.table.runtime.aggregate
 
 import java.lang.Iterable
-import java.util.{ArrayList => JArrayList}
 
 import org.apache.flink.api.common.functions.{AbstractRichFunction, GroupCombineFunction, MapPartitionFunction}
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.typeutils.ResultTypeQueryable
 import org.apache.flink.types.Row
 import org.apache.flink.configuration.Configuration
-import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
-import org.apache.flink.util.{Collector, Preconditions}
+import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction}
+import org.apache.flink.util.Collector
+import org.slf4j.LoggerFactory
 
 /**
   * This wraps the aggregate logic inside of
   * [[org.apache.flink.api.java.operators.GroupCombineOperator]].
   *
-  * @param aggregates The aggregate functions.
-  * @param groupingKeys The indexes of the grouping fields.
+  * @param genAggregations Code-generated [[GeneratedAggregations]]
+  * @param keysAndAggregatesArity The total arity of keys and aggregates
   * @param gap Session time window gap.
   * @param intermediateRowType Intermediate row data type.
   */
 class DataSetSessionWindowAggregatePreProcessor(
-    aggregates: Array[AggregateFunction[_ <: Any]],
-    groupingKeys: Array[Int],
+    genAggregations: GeneratedAggregationsFunction,
+    keysAndAggregatesArity: Int,
     gap: Long,
     @transient intermediateRowType: TypeInformation[Row])
   extends AbstractRichFunction
     with MapPartitionFunction[Row,Row]
     with GroupCombineFunction[Row,Row]
-    with ResultTypeQueryable[Row] {
+    with ResultTypeQueryable[Row]
+    with Compiler[GeneratedAggregations] {
 
-  Preconditions.checkNotNull(aggregates)
-  Preconditions.checkNotNull(groupingKeys)
+  private var output: Row = _
+  private val rowTimeFieldPos = keysAndAggregatesArity
+  private var accumulators: Row = _
 
-  private var aggregateBuffer: Row = _
-  private val accumStartPos: Int = groupingKeys.length
-  private val rowTimeFieldPos = accumStartPos + aggregates.length
-
-  val accumulatorList: Array[JArrayList[Accumulator]] = Array.fill(aggregates.length) {
-    new JArrayList[Accumulator](2)
-  }
+  val LOG = LoggerFactory.getLogger(this.getClass)
+  private var function: GeneratedAggregations = _
 
   override def open(config: Configuration) {
-    aggregateBuffer = new Row(rowTimeFieldPos + 2)
-
-    // init lists with two empty accumulators
-    for (i <- aggregates.indices) {
-      val accumulator = aggregates(i).createAccumulator()
-      accumulatorList(i).add(accumulator)
-      accumulatorList(i).add(accumulator)
-    }
+    LOG.debug(s"Compiling AggregateHelper: ${genAggregations.name} \n\n " +
+      s"Code:\n${genAggregations.code}")
+    val clazz = compile(
+      getClass.getClassLoader,
+      genAggregations.name,
+      genAggregations.code)
+    LOG.debug("Instantiating AggregateHelper.")
+    function = clazz.newInstance()
+
+    accumulators = function.createAccumulators()
+    output = function.createOutputRow()
   }
 
   /**
@@ -79,43 +79,13 @@ class DataSetSessionWindowAggregatePreProcessor(
    *
    */
   override def combine(records: Iterable[Row], out: Collector[Row]): Unit = {
-    preProcessing(records, out)
-  }
-
-  /**
-   * Divide window based on the rowtime
-   * (current'rowtime - previous'rowtime > gap), and then merge data (within a unified window)
-   * into an aggregate buffer.
-   *
-   * @param records Intermediate aggregate Rows.
-   * @return Pre partition intermediate aggregate Row.
-   *
-   */
-  override def mapPartition(records: Iterable[Row], out: Collector[Row]): Unit = {
-    preProcessing(records, out)
-  }
-
-  /**
-   * Intermediate aggregate Rows, divide window based on the rowtime
-   * (current'rowtime - previous'rowtime > gap), and then merge data (within a unified window)
-   * into an aggregate buffer.
-   *
-   * @param records Intermediate aggregate Rows.
-   * @return PreProcessing intermediate aggregate Row.
-   *
-   */
-  private def preProcessing(records: Iterable[Row], out: Collector[Row]): Unit = {
     var windowStart: java.lang.Long = null
     var windowEnd: java.lang.Long = null
     var currentRowTime: java.lang.Long = null
 
-    // reset first accumulator in merge list
-    var i = 0
-    while (i < aggregates.length) {
-      aggregates(i).resetAccumulator(accumulatorList(i).get(0))
-      i += 1
-    }
+    // reset accumulator
+    function.resetAccumulator(accumulators)
 
     val iterator = records.iterator()
@@ -128,51 +98,44 @@ class DataSetSessionWindowAggregatePreProcessor(
         // calculate the current window and open a new window.
         if (windowEnd != null) {
           // emit the current window's merged data
-          doCollect(out, accumulatorList, windowStart, windowEnd)
-
-          // reset first value of accumulator list
-          i = 0
-          while (i < aggregates.length) {
-            aggregates(i).resetAccumulator(accumulatorList(i).get(0))
-            i += 1
-          }
+          doCollect(out, windowStart, windowEnd)
+
+          // reset accumulator
+          function.resetAccumulator(accumulators)
         } else {
           // set group keys to aggregateBuffer.
-          i = 0
-          while (i < groupingKeys.length) {
-            aggregateBuffer.setField(i, record.getField(i))
-            i += 1
-          }
+          function.setForwardedFields(record, output)
         }
         windowStart = record.getField(rowTimeFieldPos).asInstanceOf[Long]
       }
 
-      i = 0
-      while (i < aggregates.length) {
-        // insert received accumulator into acc list
-        val newAcc = record.getField(accumStartPos + i).asInstanceOf[Accumulator]
-        accumulatorList(i).set(1, newAcc)
-        // merge acc list
-        val retAcc = aggregates(i).merge(accumulatorList(i))
-        // insert result into acc list
-        accumulatorList(i).set(0, retAcc)
-        i += 1
-      }
+      function.mergeAccumulatorsPair(accumulators, record)
 
       // the current rowtime is the last rowtime of the next calculation.
      windowEnd = currentRowTime + gap
     }
     // emit the merged data of the current window.
-    doCollect(out, accumulatorList, windowStart, windowEnd)
+    doCollect(out, windowStart, windowEnd)
+  }
+
+  /**
+   * Divide window based on the rowtime
+   * (current'rowtime - previous'rowtime > gap), and then merge data (within a unified window)
+   * into an aggregate buffer.
+   *
+   * @param records Intermediate aggregate Rows.
+   * @return Pre partition intermediate aggregate Row.
+   *
+   */
+  override def mapPartition(records: Iterable[Row], out: Collector[Row]): Unit = {
+    combine(records, out)
  }
 
   /**
    * Emit the merged data of the current window.
    *
    * @param out the collection of the aggregate results
-   * @param accumulatorList an array (indexed by aggregate index) of the accumulator lists for
-   *                        each aggregate
    * @param windowStart the window's start attribute value is the min (rowtime)
    *                    of all rows in the window.
    * @param windowEnd the window's end property value is max (rowtime) + gap
@@ -180,24 +143,18 @@ class DataSetSessionWindowAggregatePreProcessor(
    */
   def doCollect(
     out: Collector[Row],
-    accumulatorList: Array[JArrayList[Accumulator]],
     windowStart: Long,
     windowEnd: Long): Unit = {
 
-    // merge the accumulators into one accumulator
-    var i = 0
-    while (i < aggregates.length) {
-      aggregateBuffer.setField(accumStartPos + i, accumulatorList(i).get(0))
-      i += 1
-    }
+    function.setAggregationResults(accumulators, output)
 
     // intermediate Row WindowStartPos is rowtime pos.
-    aggregateBuffer.setField(rowTimeFieldPos, windowStart)
+    output.setField(rowTimeFieldPos, windowStart)
 
     // intermediate Row WindowEndPos is rowtime pos + 1.
-    aggregateBuffer.setField(rowTimeFieldPos + 1, windowEnd)
+    output.setField(rowTimeFieldPos + 1, windowEnd)
 
-    out.collect(aggregateBuffer)
+    out.collect(output)
   }
 
   override def getProducedType: TypeInformation[Row] = {

http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideTimeWindowAggReduceGroupFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideTimeWindowAggReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideTimeWindowAggReduceGroupFunction.scala
index 422989b..b3a19a4 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideTimeWindowAggReduceGroupFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideTimeWindowAggReduceGroupFunction.scala
@@ -18,16 +18,16 @@ package org.apache.flink.table.runtime.aggregate
 
 import java.lang.Iterable
-import java.util.{ArrayList => JArrayList}
 
 import org.apache.flink.api.common.functions.{CombineFunction, RichGroupReduceFunction}
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.typeutils.ResultTypeQueryable
 import org.apache.flink.configuration.Configuration
 import org.apache.flink.streaming.api.windowing.windows.TimeWindow
-import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction}
 import org.apache.flink.types.Row
-import org.apache.flink.util.{Collector, Preconditions}
+import org.apache.flink.util.Collector
+import org.slf4j.LoggerFactory
 
 /**
   * It is used for sliding windows on batch for time-windows. It takes a prepared input row (with
@@ -38,106 +38,78 @@ import org.apache.flink.util.{Collector, Preconditions}
   * it does no final aggregate evaluation. It also includes the logic of
   * [[DataSetSlideTimeWindowAggFlatMapFunction]].
  *
-  * @param aggregates aggregate functions
-  * @param groupingKeysLength number of grouping keys
-  * @param timeFieldPos position of aligned time field
+  * @param genAggregations Code-generated [[GeneratedAggregations]]
+  * @param keysAndAggregatesArity The total arity of keys and aggregates
   * @param windowSize window size of the sliding window
   * @param windowSlide window slide of the sliding window
   * @param returnType return type of this function
   */
 class DataSetSlideTimeWindowAggReduceGroupFunction(
-    private val aggregates: Array[AggregateFunction[_ <: Any]],
-    private val groupingKeysLength: Int,
-    private val timeFieldPos: Int,
+    private val genAggregations: GeneratedAggregationsFunction,
+    private val keysAndAggregatesArity: Int,
     private val windowSize: Long,
     private val windowSlide: Long,
     @transient private val returnType: TypeInformation[Row])
   extends RichGroupReduceFunction[Row, Row]
     with CombineFunction[Row, Row]
-    with ResultTypeQueryable[Row] {
+    with ResultTypeQueryable[Row]
+    with Compiler[GeneratedAggregations] {
 
-  Preconditions.checkNotNull(aggregates)
+  private val timeFieldPos = returnType.getArity - 1
+  private val intermediateWindowStartPos = keysAndAggregatesArity
 
   protected var intermediateRow: Row = _
-  // add one field to store window start
-  protected val intermediateRowArity: Int = groupingKeysLength + aggregates.length + 1
-  protected val accumulatorList: Array[JArrayList[Accumulator]] = Array.fill(aggregates.length) {
-    new JArrayList[Accumulator](2)
-  }
-  private val intermediateWindowStartPos: Int = intermediateRowArity - 1
+  private var accumulators: Row = _
+
+  val LOG = LoggerFactory.getLogger(this.getClass)
+  private var function: GeneratedAggregations = _
 
   override def open(config: Configuration) {
-    intermediateRow = new Row(intermediateRowArity)
-
-    // init lists with two empty accumulators
-    var i = 0
-    while (i < aggregates.length) {
-      val accumulator = aggregates(i).createAccumulator()
-      accumulatorList(i).add(accumulator)
-      accumulatorList(i).add(accumulator)
-      i += 1
-    }
+    LOG.debug(s"Compiling AggregateHelper: ${genAggregations.name} \n\n " +
+      s"Code:\n${genAggregations.code}")
+    val clazz = compile(
+      getClass.getClassLoader,
+      genAggregations.name,
+      genAggregations.code)
+    LOG.debug("Instantiating AggregateHelper.")
+    function = clazz.newInstance()
+
+    accumulators = function.createAccumulators()
+    intermediateRow = function.createOutputRow()
   }
 
   override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = {
 
     // reset first accumulator
-    var i = 0
-    while (i < aggregates.length) {
-      val accumulator = aggregates(i).createAccumulator()
-      accumulatorList(i).set(0, accumulator)
-      i += 1
-    }
+    function.resetAccumulator(accumulators)
 
     val iterator = records.iterator()
+    var record: Row = null
 
     while (iterator.hasNext) {
-      val record = iterator.next()
+      record = iterator.next()
 
       // accumulate
-      i = 0
-      while (i < aggregates.length) {
-        // insert received accumulator into acc list
-        val newAcc = record.getField(groupingKeysLength + i).asInstanceOf[Accumulator]
-        accumulatorList(i).set(1, newAcc)
-        // merge acc list
-        val retAcc = aggregates(i).merge(accumulatorList(i))
-        // insert result into acc list
-        accumulatorList(i).set(0, retAcc)
-        i += 1
-      }
+      function.mergeAccumulatorsPair(accumulators, record)
+    }
+
+    val windowStart = record.getField(timeFieldPos).asInstanceOf[Long]
+
+    // adopted from SlidingEventTimeWindows.assignWindows
+    var start: Long = TimeWindow.getWindowStartWithOffset(windowStart, 0, windowSlide)
+
+    // skip preparing output if it is not necessary
+    if (start > windowStart - windowSize) {
+
+      // set group keys and partial accumulated result
+      function.setAggregationResults(accumulators, intermediateRow)
+      function.setForwardedFields(record, intermediateRow)
 
-      // trigger tumbling evaluation
-      if (!iterator.hasNext) {
-        val windowStart = record.getField(timeFieldPos).asInstanceOf[Long]
-
-        // adopted from SlidingEventTimeWindows.assignWindows
-        var start: Long = TimeWindow.getWindowStartWithOffset(windowStart, 0, windowSlide)
-
-        // skip preparing output if it is not necessary
-        if (start > windowStart - windowSize) {
-
-          // set group keys
-          i = 0
-          while (i < groupingKeysLength) {
-            intermediateRow.setField(i, record.getField(i))
-            i += 1
-          }
-
-          // set accumulators
-          i = 0
-          while (i < aggregates.length) {
-            intermediateRow.setField(groupingKeysLength + i, accumulatorList(i).get(0))
-            i += 1
-          }
-
-          // adopted from SlidingEventTimeWindows.assignWindows
-          while (start > windowStart - windowSize) {
-            intermediateRow.setField(intermediateWindowStartPos, start)
-            out.collect(intermediateRow)
-            start -= windowSlide
-          }
-        }
+      // adopted from SlidingEventTimeWindows.assignWindows
+      while (start > windowStart - windowSize) {
+        intermediateRow.setField(intermediateWindowStartPos, start)
+        out.collect(intermediateRow)
+        start -= windowSlide
       }
     }
   }
@@ -145,54 +117,21 @@ class DataSetSlideTimeWindowAggReduceGroupFunction(
   override def combine(records: Iterable[Row]): Row = {
 
     // reset first accumulator
-    var i = 0
-    while (i < aggregates.length) {
-      aggregates(i).resetAccumulator(accumulatorList(i).get(0))
-      i += 1
-    }
+    function.resetAccumulator(accumulators)
 
     val iterator = records.iterator()
+    var record: Row = null
 
     while (iterator.hasNext) {
-      val record = iterator.next()
-
-      i = 0
-      while (i < aggregates.length) {
-        // insert received accumulator into acc list
-        val newAcc = record.getField(groupingKeysLength + i).asInstanceOf[Accumulator]
-        accumulatorList(i).set(1, newAcc)
-        // merge acc list
-        val retAcc = aggregates(i).merge(accumulatorList(i))
-        // insert result into acc list
-        accumulatorList(i).set(0, retAcc)
-        i += 1
-      }
-
-      // check if this record is the last record
-      if (!iterator.hasNext) {
-
-        // set group keys
-        i = 0
-        while (i < groupingKeysLength) {
-          intermediateRow.setField(i, record.getField(i))
-          i += 1
-        }
-
-        // set accumulators
-        i = 0
-        while (i < aggregates.length) {
-          intermediateRow.setField(groupingKeysLength + i, accumulatorList(i).get(0))
-          i += 1
-        }
-
-        intermediateRow.setField(timeFieldPos, record.getField(timeFieldPos))
-
-        return intermediateRow
-      }
+      record = iterator.next()
+      function.mergeAccumulatorsPair(accumulators, record)
     }
+    // set group keys and partial accumulated result
+    function.setAggregationResults(accumulators, intermediateRow)
+    function.setForwardedFields(record, intermediateRow)
+
+    intermediateRow.setField(timeFieldPos, record.getField(timeFieldPos))
 
-    // this code path should never be reached as we return before the loop finishes
-    // we need this to prevent a compiler error
-    throw new IllegalArgumentException("Group is empty. This should never happen.")
+    intermediateRow
   }
 
   override def getProducedType: TypeInformation[Row] = {
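A note on the sliding-window emission in DataSetSlideTimeWindowAggReduceGroupFunction.reduce() above: the group's accumulators are merged once, and one intermediate row is then emitted per sliding pane that covers the last row's aligned timestamp, mirroring SlidingEventTimeWindows.assignWindows. Below is a minimal, self-contained Scala sketch of that pane arithmetic; the object and parameter names are illustrative and not part of this commit, the offset is assumed to be 0, and all values are millisecond timestamps.

    // Sketch of the pane-assignment loop used in reduce() above (illustrative only).
    object SlidingPaneSketch {

      // same arithmetic as TimeWindow.getWindowStartWithOffset(timestamp, 0, slide)
      def lastPaneStart(timestamp: Long, slide: Long): Long =
        timestamp - (timestamp + slide) % slide

      // invoke `emit` once per pane [start, start + size) that contains `timestamp`
      def emitPanes(timestamp: Long, size: Long, slide: Long)(emit: Long => Unit): Unit = {
        var start = lastPaneStart(timestamp, slide)
        while (start > timestamp - size) {
          emit(start)
          start -= slide
        }
      }

      def main(args: Array[String]): Unit = {
        // size 10s, slide 2s: timestamp 5000 falls into panes 4000, 2000, 0, -2000, -4000
        emitPanes(5000L, 10000L, 2000L)(s => println(s"pane start: $s"))
      }
    }

This is also why setAggregationResults and setForwardedFields run once per group, while the loop only rewrites the window-start field before each collect.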
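The two session-window functions above share one grouping decision: rows arrive sorted by rowtime, and a window is closed as soon as a row's rowtime exceeds the running windowEnd (last rowtime + gap); within a window, the generated mergeAccumulatorsPair folds the incoming accumulators. A compact sketch of just that split, under the assumption of (rowtime, value) pairs instead of Rows and a Vector standing in for the merged accumulators:

    // Sketch of the session split used in reduce()/combine() above (illustrative only).
    def sessionize(rows: Seq[(Long, String)], gap: Long): Seq[(Long, Long, Seq[String])] = {
      val sessions = scala.collection.mutable.ArrayBuffer.empty[(Long, Long, Seq[String])]
      var windowStart = -1L
      var windowEnd = -1L
      var acc = Vector.empty[String]
      for ((rowtime, value) <- rows) {
        // gap exceeded: emit the current window and open a new one
        if (windowEnd >= 0 && rowtime > windowEnd) {
          sessions += ((windowStart, windowEnd, acc))
          acc = Vector.empty
          windowStart = rowtime
        } else if (windowStart < 0) {
          windowStart = rowtime
        }
        acc :+= value               // stands in for function.mergeAccumulatorsPair
        windowEnd = rowtime + gap   // the window end trails the last row by `gap`
      }
      if (acc.nonEmpty) sessions += ((windowStart, windowEnd, acc))
      sessions.toSeq
    }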