http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideWindowAggReduceCombineFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideWindowAggReduceCombineFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideWindowAggReduceCombineFunction.scala index 8eac79d..381d443 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideWindowAggReduceCombineFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideWindowAggReduceCombineFunction.scala @@ -20,7 +20,8 @@ package org.apache.flink.table.runtime.aggregate import java.lang.Iterable import org.apache.flink.api.common.functions.CombineFunction -import org.apache.flink.table.functions.{Accumulator, AggregateFunction} +import org.apache.flink.configuration.Configuration +import org.apache.flink.table.codegen.GeneratedAggregationsFunction import org.apache.flink.types.Row /** @@ -30,87 +31,62 @@ import org.apache.flink.types.Row * * It is used for sliding on batch for both time and count-windows. * - * @param aggregates aggregate functions. - * @param groupKeysMapping index mapping of group keys between intermediate aggregate Row - * and output Row. - * @param aggregateMapping index mapping between aggregate function list and aggregated value - * index in output Row. - * @param finalRowArity output row field count + * @param genPreAggregations Code-generated [[GeneratedAggregations]] for partial aggregation. + * @param genFinalAggregations Code-generated [[GeneratedAggregations]] for final aggregation. 
+ * @param keysAndAggregatesArity The total arity of keys and aggregates * @param finalRowWindowStartPos relative window-start position to last field of output row * @param finalRowWindowEndPos relative window-end position to last field of output row * @param windowSize size of the window, used to determine window-end for output row */ class DataSetSlideWindowAggReduceCombineFunction( - aggregates: Array[AggregateFunction[_ <: Any]], - groupKeysMapping: Array[(Int, Int)], - aggregateMapping: Array[(Int, Int)], - finalRowArity: Int, + genPreAggregations: GeneratedAggregationsFunction, + genFinalAggregations: GeneratedAggregationsFunction, + keysAndAggregatesArity: Int, finalRowWindowStartPos: Option[Int], finalRowWindowEndPos: Option[Int], windowSize: Long) extends DataSetSlideWindowAggReduceGroupFunction( - aggregates, - groupKeysMapping, - aggregateMapping, - finalRowArity, + genFinalAggregations, + keysAndAggregatesArity, finalRowWindowStartPos, finalRowWindowEndPos, windowSize) with CombineFunction[Row, Row] { - private val intermediateRowArity: Int = groupKeysMapping.length + aggregateMapping.length + 1 - private val intermediateRow: Row = new Row(intermediateRowArity) + private val intermediateRow: Row = new Row(keysAndAggregatesArity + 1) - override def combine(records: Iterable[Row]): Row = { - - // reset first accumulator - var i = 0 - while (i < aggregates.length) { - aggregates(i).resetAccumulator(accumulatorList(i).get(0)) - i += 1 - } - - val iterator = records.iterator() - while (iterator.hasNext) { - val record = iterator.next() + protected var preAggfunction: GeneratedAggregations = _ - // accumulate - i = 0 - while (i < aggregates.length) { - // insert received accumulator into acc list - val newAcc = record.getField(groupKeysMapping.length + i).asInstanceOf[Accumulator] - accumulatorList(i).set(1, newAcc) - // merge acc list - val retAcc = aggregates(i).merge(accumulatorList(i)) - // insert result into acc list - accumulatorList(i).set(0, retAcc) - 
i += 1 - } + override def open(config: Configuration): Unit = { + super.open(config) - // check if this record is the last record - if (!iterator.hasNext) { - // set group keys - i = 0 - while (i < groupKeysMapping.length) { - intermediateRow.setField(i, record.getField(i)) - i += 1 - } + LOG.debug(s"Compiling AggregateHelper: $genPreAggregations.name \n\n " + + s"Code:\n$genPreAggregations.code") + val clazz = compile( + getClass.getClassLoader, + genPreAggregations.name, + genPreAggregations.code) + LOG.debug("Instantiating AggregateHelper.") + preAggfunction = clazz.newInstance() + } - // set the partial accumulated result - i = 0 - while (i < aggregates.length) { - intermediateRow.setField(groupKeysMapping.length + i, accumulatorList(i).get(0)) - i += 1 - } + override def combine(records: Iterable[Row]): Row = { - intermediateRow.setField(windowStartPos, record.getField(windowStartPos)) + // reset accumulator + preAggfunction.resetAccumulator(accumulators) - return intermediateRow - } + val iterator = records.iterator() + var record: Row = null + while (iterator.hasNext) { + record = iterator.next() + preAggfunction.mergeAccumulatorsPair(accumulators, record) } + // set group keys and partial accumulated result + preAggfunction.setAggregationResults(accumulators, intermediateRow) + preAggfunction.setForwardedFields(record, intermediateRow) + + intermediateRow.setField(windowStartPos, record.getField(windowStartPos)) - // this code path should never be reached as we return before the loop finishes - // we need this to prevent a compiler error - throw new IllegalArgumentException("Group is empty. This should never happen.") + intermediateRow } }
http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideWindowAggReduceGroupFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideWindowAggReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideWindowAggReduceGroupFunction.scala index d6bc006..a221c53 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideWindowAggReduceGroupFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideWindowAggReduceGroupFunction.scala @@ -18,13 +18,13 @@ package org.apache.flink.table.runtime.aggregate import java.lang.Iterable -import java.util.{ArrayList => JArrayList} import org.apache.flink.api.common.functions.RichGroupReduceFunction import org.apache.flink.configuration.Configuration -import org.apache.flink.table.functions.{Accumulator, AggregateFunction} +import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} import org.apache.flink.types.Row -import org.apache.flink.util.{Collector, Preconditions} +import org.apache.flink.util.Collector +import org.slf4j.LoggerFactory /** * It wraps the aggregate logic inside of @@ -32,109 +32,72 @@ import org.apache.flink.util.{Collector, Preconditions} * * It is used for sliding on batch for both time and count-windows. * - * @param aggregates aggregate functions. - * @param groupKeysMapping index mapping of group keys between intermediate aggregate Row - * and output Row. - * @param aggregateMapping index mapping between aggregate function list and aggregated value - * index in output Row. 
- * @param finalRowArity output row field count + * @param genAggregations Code-generated [[GeneratedAggregations]] + * @param keysAndAggregatesArity The total arity of keys and aggregates * @param finalRowWindowStartPos relative window-start position to last field of output row * @param finalRowWindowEndPos relative window-end position to last field of output row * @param windowSize size of the window, used to determine window-end for output row */ class DataSetSlideWindowAggReduceGroupFunction( - aggregates: Array[AggregateFunction[_ <: Any]], - groupKeysMapping: Array[(Int, Int)], - aggregateMapping: Array[(Int, Int)], - finalRowArity: Int, + genAggregations: GeneratedAggregationsFunction, + keysAndAggregatesArity: Int, finalRowWindowStartPos: Option[Int], finalRowWindowEndPos: Option[Int], windowSize: Long) - extends RichGroupReduceFunction[Row, Row] { - - Preconditions.checkNotNull(aggregates) - Preconditions.checkNotNull(groupKeysMapping) + extends RichGroupReduceFunction[Row, Row] + with Compiler[GeneratedAggregations] { private var collector: TimeWindowPropertyCollector = _ + protected val windowStartPos: Int = keysAndAggregatesArity + private var output: Row = _ - private val accumulatorStartPos: Int = groupKeysMapping.length - protected val windowStartPos: Int = accumulatorStartPos + aggregates.length + protected var accumulators: Row = _ - val accumulatorList: Array[JArrayList[Accumulator]] = Array.fill(aggregates.length) { - new JArrayList[Accumulator](2) - } + val LOG = LoggerFactory.getLogger(this.getClass) + protected var function: GeneratedAggregations = _ override def open(config: Configuration) { - output = new Row(finalRowArity) + LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " + + s"Code:\n$genAggregations.code") + val clazz = compile( + getClass.getClassLoader, + genAggregations.name, + genAggregations.code) + LOG.debug("Instantiating AggregateHelper.") + function = clazz.newInstance() + + output = function.createOutputRow() 
+ accumulators = function.createAccumulators() collector = new TimeWindowPropertyCollector(finalRowWindowStartPos, finalRowWindowEndPos) - - // init lists with two empty accumulators - var i = 0 - while (i < aggregates.length) { - val accumulator = aggregates(i).createAccumulator() - accumulatorList(i).add(accumulator) - accumulatorList(i).add(accumulator) - i += 1 - } } override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = { - // reset first accumulator - var i = 0 - while (i < aggregates.length) { - aggregates(i).resetAccumulator(accumulatorList(i).get(0)) - i += 1 - } + // reset accumulator + function.resetAccumulator(accumulators) val iterator = records.iterator() + var record: Row = null while (iterator.hasNext) { - val record = iterator.next() - - // accumulate - i = 0 - while (i < aggregates.length) { - // insert received accumulator into acc list - val newAcc = record.getField(accumulatorStartPos + i).asInstanceOf[Accumulator] - accumulatorList(i).set(1, newAcc) - // merge acc list - val retAcc = aggregates(i).merge(accumulatorList(i)) - // insert result into acc list - accumulatorList(i).set(0, retAcc) - i += 1 - } + record = iterator.next() + function.mergeAccumulatorsPair(accumulators, record) + } - // check if this record is the last record - if (!iterator.hasNext) { - // set group keys value to final output - i = 0 - while (i < groupKeysMapping.length) { - val mapping = groupKeysMapping(i) - output.setField(mapping._1, record.getField(mapping._2)) - i += 1 - } + // set group keys value to final output + function.setForwardedFields(record, output) - // get final aggregate value and set to output. 
- i = 0 - while (i < aggregateMapping.length) { - val mapping = aggregateMapping(i) - val agg = aggregates(i) - val result = agg.getValue(accumulatorList(mapping._2).get(0)) - output.setField(mapping._1, result) - i += 1 - } + // get final aggregate value and set to output + function.setAggregationResults(accumulators, output) - // adds TimeWindow properties to output then emit output - if (finalRowWindowStartPos.isDefined || finalRowWindowEndPos.isDefined) { - collector.wrappedCollector = out - collector.windowStart = record.getField(windowStartPos).asInstanceOf[Long] - collector.windowEnd = collector.windowStart + windowSize + // adds TimeWindow properties to output then emit output + if (finalRowWindowStartPos.isDefined || finalRowWindowEndPos.isDefined) { + collector.wrappedCollector = out + collector.windowStart = record.getField(windowStartPos).asInstanceOf[Long] + collector.windowEnd = collector.windowStart + windowSize - collector.collect(output) - } else { - out.collect(output) - } - } + collector.collect(output) + } else { + out.collect(output) } } } http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala index 0a525f8..0e73f7b 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleCountWindowAggReduceGroupFunction.scala @@ -18,106 +18,69 @@ package 
org.apache.flink.table.runtime.aggregate import java.lang.Iterable -import java.util.{ArrayList => JArrayList} import org.apache.flink.api.common.functions.RichGroupReduceFunction import org.apache.flink.configuration.Configuration -import org.apache.flink.table.functions.{Accumulator, AggregateFunction} +import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} import org.apache.flink.types.Row -import org.apache.flink.util.{Collector, Preconditions} +import org.apache.flink.util.Collector +import org.slf4j.LoggerFactory /** * It wraps the aggregate logic inside of * [[org.apache.flink.api.java.operators.GroupReduceOperator]]. * It is only used for tumbling count-window on batch. * + * @param genAggregations Code-generated [[GeneratedAggregations]] * @param windowSize Tumble count window size - * @param aggregates The aggregate functions. - * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row - * and output Row. - * @param aggregateMapping The index mapping between aggregate function list and aggregated value - * index in output Row. 
- * @param finalRowArity The output row field count */ class DataSetTumbleCountWindowAggReduceGroupFunction( - private val windowSize: Long, - private val aggregates: Array[AggregateFunction[_ <: Any]], - private val groupKeysMapping: Array[(Int, Int)], - private val aggregateMapping: Array[(Int, Int)], - private val finalRowArity: Int) - extends RichGroupReduceFunction[Row, Row] { - - Preconditions.checkNotNull(aggregates) - Preconditions.checkNotNull(groupKeysMapping) + private val genAggregations: GeneratedAggregationsFunction, + private val windowSize: Long) + extends RichGroupReduceFunction[Row, Row] + with Compiler[GeneratedAggregations] { private var output: Row = _ - private val accumStartPos: Int = groupKeysMapping.length + private var accumulators: Row = _ - val accumulatorList: Array[JArrayList[Accumulator]] = Array.fill(aggregates.length) { - new JArrayList[Accumulator](2) - } + val LOG = LoggerFactory.getLogger(this.getClass) + private var function: GeneratedAggregations = _ override def open(config: Configuration) { - output = new Row(finalRowArity) - - // init lists with two empty accumulators - for (i <- aggregates.indices) { - val accumulator = aggregates(i).createAccumulator() - accumulatorList(i).add(accumulator) - accumulatorList(i).add(accumulator) - } + LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " + + s"Code:\n$genAggregations.code") + val clazz = compile( + getClass.getClassLoader, + genAggregations.name, + genAggregations.code) + LOG.debug("Instantiating AggregateHelper.") + function = clazz.newInstance() + + output = function.createOutputRow() + accumulators = function.createAccumulators() } override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = { var count: Long = 0 val iterator = records.iterator() - var i = 0 while (iterator.hasNext) { if (count == 0) { - // reset first accumulator - i = 0 - while (i < aggregates.length) { - aggregates(i).resetAccumulator(accumulatorList(i).get(0)) - i += 1 - } + 
function.resetAccumulator(accumulators) } val record = iterator.next() count += 1 - i = 0 - while (i < aggregates.length) { - // insert received accumulator into acc list - val newAcc = record.getField(accumStartPos + i).asInstanceOf[Accumulator] - accumulatorList(i).set(1, newAcc) - // merge acc list - val retAcc = aggregates(i).merge(accumulatorList(i)) - // insert result into acc list - accumulatorList(i).set(0, retAcc) - i += 1 - } + accumulators = function.mergeAccumulatorsPair(accumulators, record) if (windowSize == count) { // set group keys value to final output. - i = 0 - while (i < groupKeysMapping.length) { - val (after, previous) = groupKeysMapping(i) - output.setField(after, record.getField(previous)) - i += 1 - } - - // merge the accumulators and then get value for the final output - i = 0 - while (i < aggregateMapping.length) { - val (after, previous) = aggregateMapping(i) - val agg = aggregates(previous) - output.setField(after, agg.getValue(accumulatorList(previous).get(0))) - i += 1 - } + function.setForwardedFields(record, output) + function.setAggregationResults(accumulators, output) // emit the output out.collect(output) count = 0 http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala index 904c76c..4a459b2 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala +++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceCombineFunction.scala @@ -20,7 +20,8 @@ package org.apache.flink.table.runtime.aggregate import java.lang.Iterable import org.apache.flink.api.common.functions.CombineFunction -import org.apache.flink.table.functions.{Accumulator, AggregateFunction} +import org.apache.flink.configuration.Configuration +import org.apache.flink.table.codegen.GeneratedAggregationsFunction import org.apache.flink.types.Row /** @@ -29,34 +30,45 @@ import org.apache.flink.types.Row * [[org.apache.flink.api.java.operators.GroupCombineOperator]]. * It is used for tumbling time-window on batch. * - * @param windowSize Tumbling time window size - * @param windowStartPos The relative window-start field position to the last field of output row - * @param windowEndPos The relative window-end field position to the last field of output row - * @param aggregates The aggregate functions. - * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row - * and output Row. - * @param aggregateMapping The index mapping between aggregate function list and aggregated value - * index in output Row. - * @param finalRowArity The output row field count + * @param genPreAggregations Code-generated [[GeneratedAggregations]] for partial aggs. + * @param genFinalAggregations Code-generated [[GeneratedAggregations]] for final aggs. 
+ * @param windowSize Tumbling time window size + * @param windowStartPos The relative window-start field position to the last field of + * output row + * @param windowEndPos The relative window-end field position to the last field of + * output row + * @param keysAndAggregatesArity The total arity of keys and aggregates */ class DataSetTumbleTimeWindowAggReduceCombineFunction( + genPreAggregations: GeneratedAggregationsFunction, + genFinalAggregations: GeneratedAggregationsFunction, windowSize: Long, windowStartPos: Option[Int], windowEndPos: Option[Int], - aggregates: Array[AggregateFunction[_ <: Any]], - groupKeysMapping: Array[(Int, Int)], - aggregateMapping: Array[(Int, Int)], - finalRowArity: Int) + keysAndAggregatesArity: Int) extends DataSetTumbleTimeWindowAggReduceGroupFunction( + genFinalAggregations, windowSize, windowStartPos, windowEndPos, - aggregates, - groupKeysMapping, - aggregateMapping, - finalRowArity) + keysAndAggregatesArity) with CombineFunction[Row, Row] { + protected var preAggfunction: GeneratedAggregations = _ + + override def open(config: Configuration): Unit = { + super.open(config) + + LOG.debug(s"Compiling AggregateHelper: $genPreAggregations.name \n\n " + + s"Code:\n$genPreAggregations.code") + val clazz = compile( + getClass.getClassLoader, + genPreAggregations.name, + genPreAggregations.code) + LOG.debug("Instantiating AggregateHelper.") + preAggfunction = clazz.newInstance() + } + /** * For sub-grouped intermediate aggregate Rows, merge all of them into aggregate buffer, * @@ -69,47 +81,21 @@ class DataSetTumbleTimeWindowAggReduceCombineFunction( var last: Row = null val iterator = records.iterator() - // reset first accumulator in merge list - var i = 0 - while (i < aggregates.length) { - aggregates(i).resetAccumulator(accumulatorList(i).get(0)) - i += 1 - } + // reset accumulator + preAggfunction.resetAccumulator(accumulators) while (iterator.hasNext) { val record = iterator.next() - - i = 0 - while (i < aggregates.length) { - 
// insert received accumulator into acc list - val newAcc = record.getField(groupKeysMapping.length + i).asInstanceOf[Accumulator] - accumulatorList(i).set(1, newAcc) - // merge acc list - val retAcc = aggregates(i).merge(accumulatorList(i)) - // insert result into acc list - accumulatorList(i).set(0, retAcc) - i += 1 - } - + preAggfunction.mergeAccumulatorsPair(accumulators, record) last = record } - // set the partial merged result to the aggregateBuffer - i = 0 - while (i < aggregates.length) { - aggregateBuffer.setField(groupKeysMapping.length + i, accumulatorList(i).get(0)) - i += 1 - } - - // set group keys to aggregateBuffer. - i = 0 - while (i < groupKeysMapping.length) { - aggregateBuffer.setField(i, last.getField(i)) - i += 1 - } + // set group keys and partial merged result to aggregateBuffer + preAggfunction.setAggregationResults(accumulators, aggregateBuffer) + preAggfunction.setForwardedFields(last, aggregateBuffer) // set the rowtime attribute - val rowtimePos = groupKeysMapping.length + aggregates.length + val rowtimePos = keysAndAggregatesArity aggregateBuffer.setField(rowtimePos, last.getField(rowtimePos)) http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala index 99e2a0a..f4a1fc5 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala +++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetTumbleTimeWindowAggReduceGroupFunction.scala @@ -18,65 +18,56 @@ package org.apache.flink.table.runtime.aggregate import java.lang.Iterable -import java.util.{ArrayList => JArrayList} import org.apache.flink.api.common.functions.RichGroupReduceFunction import org.apache.flink.configuration.Configuration -import org.apache.flink.table.functions.{Accumulator, AggregateFunction} +import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} import org.apache.flink.types.Row -import org.apache.flink.util.{Collector, Preconditions} +import org.apache.flink.util.Collector +import org.slf4j.LoggerFactory /** * It wraps the aggregate logic inside of * [[org.apache.flink.api.java.operators.GroupReduceOperator]]. It is used for tumbling time-window * on batch. * + * @param genAggregations Code-generated [[GeneratedAggregations]] * @param windowSize Tumbling time window size * @param windowStartPos The relative window-start field position to the last field of output row * @param windowEndPos The relative window-end field position to the last field of output row - * @param aggregates The aggregate functions. - * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row - * and output Row. - * @param aggregateMapping The index mapping between aggregate function list and aggregated value - * index in output Row. 
- * @param finalRowArity The output row field count + * @param keysAndAggregatesArity The total arity of keys and aggregates */ class DataSetTumbleTimeWindowAggReduceGroupFunction( + genAggregations: GeneratedAggregationsFunction, windowSize: Long, windowStartPos: Option[Int], windowEndPos: Option[Int], - aggregates: Array[AggregateFunction[_ <: Any]], - groupKeysMapping: Array[(Int, Int)], - aggregateMapping: Array[(Int, Int)], - finalRowArity: Int) - extends RichGroupReduceFunction[Row, Row] { - - Preconditions.checkNotNull(aggregates) - Preconditions.checkNotNull(groupKeysMapping) + keysAndAggregatesArity: Int) + extends RichGroupReduceFunction[Row, Row] + with Compiler[GeneratedAggregations] { private var collector: TimeWindowPropertyCollector = _ - protected var aggregateBuffer: Row = _ - private var output: Row = _ - private val accumStartPos: Int = groupKeysMapping.length - private val rowtimePos: Int = accumStartPos + aggregates.length - private val intermediateRowArity: Int = rowtimePos + 1 + protected var aggregateBuffer: Row = new Row(keysAndAggregatesArity + 1) + private var output: Row = _ + protected var accumulators: Row = _ - val accumulatorList: Array[JArrayList[Accumulator]] = Array.fill(aggregates.length) { - new JArrayList[Accumulator](2) - } + val LOG = LoggerFactory.getLogger(this.getClass) + protected var function: GeneratedAggregations = _ override def open(config: Configuration) { - aggregateBuffer = new Row(intermediateRowArity) - output = new Row(finalRowArity) + LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " + + s"Code:\n$genAggregations.code") + val clazz = compile( + getClass.getClassLoader, + genAggregations.name, + genAggregations.code) + LOG.debug("Instantiating AggregateHelper.") + function = clazz.newInstance() + + output = function.createOutputRow() + accumulators = function.createAccumulators() collector = new TimeWindowPropertyCollector(windowStartPos, windowEndPos) - - // init lists with two empty 
accumulators - for (i <- aggregates.indices) { - val accumulator = aggregates(i).createAccumulator() - accumulatorList(i).add(accumulator) - accumulatorList(i).add(accumulator) - } } override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = { @@ -84,51 +75,23 @@ class DataSetTumbleTimeWindowAggReduceGroupFunction( var last: Row = null val iterator = records.iterator() - // reset first accumulator in merge list - var i = 0 - while (i < aggregates.length) { - aggregates(i).resetAccumulator(accumulatorList(i).get(0)) - i += 1 - } + // reset accumulator + function.resetAccumulator(accumulators) while (iterator.hasNext) { val record = iterator.next() - - i = 0 - while (i < aggregates.length) { - // insert received accumulator into acc list - val newAcc = record.getField(groupKeysMapping.length + i).asInstanceOf[Accumulator] - accumulatorList(i).set(1, newAcc) - // merge acc list - val retAcc = aggregates(i).merge(accumulatorList(i)) - // insert result into acc list - accumulatorList(i).set(0, retAcc) - i += 1 - } - + function.mergeAccumulatorsPair(accumulators, record) last = record } // set group keys value to final output. - i = 0 - while (i < groupKeysMapping.length) { - val (after, previous) = groupKeysMapping(i) - output.setField(after, last.getField(previous)) - i += 1 - } + function.setForwardedFields(last, output) // get final aggregate value and set to output. 
- i = 0 - while (i < aggregateMapping.length) { - val (after, previous) = aggregateMapping(i) - val agg = aggregates(previous) - val result = agg.getValue(accumulatorList(previous).get(0)) - output.setField(after, result) - i += 1 - } + function.setAggregationResults(accumulators, output) // get window start timestamp - val startTs: Long = last.getField(rowtimePos).asInstanceOf[Long] + val startTs: Long = last.getField(keysAndAggregatesArity).asInstanceOf[Long] // set collector and window collector.wrappedCollector = out http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetWindowAggMapFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetWindowAggMapFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetWindowAggMapFunction.scala index 5cc7ada..d49ed0e 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetWindowAggMapFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetWindowAggMapFunction.scala @@ -25,58 +25,60 @@ import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.typeutils.ResultTypeQueryable import org.apache.flink.configuration.Configuration import org.apache.flink.streaming.api.windowing.windows.TimeWindow -import org.apache.flink.table.functions.AggregateFunction +import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} import org.apache.flink.types.Row -import org.apache.flink.util.Preconditions - +import org.slf4j.LoggerFactory /** * This map function only works for windows on batch tables. * It appends an (aligned) rowtime field to the end of the output row. 
+ * + * @param genAggregations Code-generated [[GeneratedAggregations]] + * @param timeFieldPos Time field position in input row + * @param tumbleTimeWindowSize The size of tumble time window */ class DataSetWindowAggMapFunction( - private val aggregates: Array[AggregateFunction[_]], - private val aggFields: Array[Array[Int]], - private val groupingKeys: Array[Int], - private val timeFieldPos: Int, // time field position in input row + private val genAggregations: GeneratedAggregationsFunction, + private val timeFieldPos: Int, private val tumbleTimeWindowSize: Option[Long], @transient private val returnType: TypeInformation[Row]) - extends RichMapFunction[Row, Row] with ResultTypeQueryable[Row] { - - Preconditions.checkNotNull(aggregates) - Preconditions.checkNotNull(aggFields) - Preconditions.checkArgument(aggregates.length == aggFields.length) + extends RichMapFunction[Row, Row] + with ResultTypeQueryable[Row] + with Compiler[GeneratedAggregations] { + private var accs: Row = _ private var output: Row = _ - // add one more arity to store rowtime - private val partialRowLength = groupingKeys.length + aggregates.length + 1 - // rowtime index in the buffer output row - private val rowtimeIndex: Int = partialRowLength - 1 + + val LOG = LoggerFactory.getLogger(this.getClass) + private var function: GeneratedAggregations = _ override def open(config: Configuration) { - output = new Row(partialRowLength) + LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " + + s"Code:\n$genAggregations.code") + val clazz = compile( + getRuntimeContext.getUserCodeClassLoader, + genAggregations.name, + genAggregations.code) + LOG.debug("Instantiating AggregateHelper.") + function = clazz.newInstance() + + accs = function.createAccumulators() + output = function.createOutputRow() } override def map(input: Row): Row = { - var i = 0 - while (i < aggregates.length) { - val agg = aggregates(i) - val fieldValue = input.getField(aggFields(i)(0)) - val accumulator = 
agg.createAccumulator() - agg.accumulate(accumulator, fieldValue) - output.setField(groupingKeys.length + i, accumulator) - i += 1 - } + function.resetAccumulator(accs) - i = 0 - while (i < groupingKeys.length) { - output.setField(i, input.getField(groupingKeys(i))) - i += 1 - } + function.accumulate(accs, input) + + function.setAggregationResults(accs, output) + + function.setForwardedFields(input, output) val timeField = input.getField(timeFieldPos) val rowtime = getTimestamp(timeField) + val rowtimeIndex = output.getArity - 1 if (tumbleTimeWindowSize.isDefined) { // in case of tumble time window, align rowtime to window start to represent the window output.setField( http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala index 17a1128..bee39fa 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala @@ -27,7 +27,9 @@ import org.apache.flink.types.Row abstract class GeneratedAggregations extends Function { /** - * Calculate the results from accumulators, and set the results to the output + * Sets the results of the aggregations (partial or final) to the output row. + * Final results are computed with the aggregation function. + * Partial results are the accumulators themselves. 
* * @param accumulators the accumulators (saved in a row) which contains the current * aggregated results @@ -36,15 +38,22 @@ abstract class GeneratedAggregations extends Function { def setAggregationResults(accumulators: Row, output: Row) /** - * Copies forwarded fields from input row to output row. + * Copies forwarded fields, such as grouping keys, from input row to output row. * - * @param input input values bundled in a row - * @param output output results collected in a row + * @param input input values bundled in a row + * @param output output results collected in a row */ def setForwardedFields(input: Row, output: Row) /** - * Accumulate the input values to the accumulators + * Sets constant flags (boolean fields) to an output row. + * + * @param output The output row to which the constant flags are set. + */ + def setConstantFlags(output: Row) + + /** + * Accumulates the input values to the accumulators. * * @param accumulators the accumulators (saved in a row) which contains the current * aggregated results @@ -53,7 +62,7 @@ abstract class GeneratedAggregations extends Function { def accumulate(accumulators: Row, input: Row) /** - * Retract the input values from the accumulators + * Retracts the input values from the accumulators. * * @param accumulators the accumulators (saved in a row) which contains the current * aggregated results @@ -62,7 +71,7 @@ abstract class GeneratedAggregations extends Function { def retract(accumulators: Row, input: Row) /** - * Init the accumulators, and save them to a accumulators Row. + * Initializes the accumulators and saves them to an accumulators row. * * @return a row of accumulators which contains the aggregated results */ @@ -76,7 +85,7 @@ abstract class GeneratedAggregations extends Function { def createOutputRow(): Row /** - * Merges two rows of accumulators into one row + * Merges two rows of accumulators into one row.
* * @param a First row of accumulators * @param b The other row of accumulators @@ -84,4 +93,11 @@ abstract class GeneratedAggregations extends Function { */ def mergeAccumulatorsPair(a: Row, b: Row): Row + /** + * Resets all the accumulators. + * + * @param accumulators the accumulators (saved in a row) which contains the current + * aggregated results + */ + def resetAccumulator(accumulators: Row) } http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/AggregationsITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/AggregationsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/AggregationsITCase.scala index 5ac09b9..4838747 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/AggregationsITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/AggregationsITCase.scala @@ -339,5 +339,6 @@ class AggregationsITCase( TestBaseUtils.compareResultAsText(results.asJava, expected) } - case class WC(word: String, frequency: Long) } + +case class WC(word: String, frequency: Long) http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRangeProcessFunctionTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRangeProcessFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRangeProcessFunctionTest.scala index adc84bf..16c493e 100644 --- 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRangeProcessFunctionTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRangeProcessFunctionTest.scala @@ -159,12 +159,18 @@ class BoundedProcessingOverRangeProcessFunctionTest { | return new org.apache.flink.types.Row(7); | } | - | //The test won't use this method + |/******* This test does not use the following methods *******/ | public org.apache.flink.types.Row mergeAccumulatorsPair( | org.apache.flink.types.Row a, | org.apache.flink.types.Row b) { | return null; | } + | + | public void resetAccumulator(org.apache.flink.types.Row accs) { + | } + | + | public void setConstantFlags(org.apache.flink.types.Row output) { + | } |} """.stripMargin