[FLINK-6257] [table] Consistent naming of ProcessFunction and methods for OVER windows.
- Add check for sort order of OVER windows. This closes #3681. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/07f1b035 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/07f1b035 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/07f1b035 Branch: refs/heads/table-retraction Commit: 07f1b035ffbf07d160503c48e2c58a464ec5d014 Parents: 5ff9c99 Author: sunjincheng121 <sunjincheng...@gmail.com> Authored: Thu Apr 6 10:33:30 2017 +0800 Committer: Fabian Hueske <fhue...@apache.org> Committed: Thu Apr 6 16:33:45 2017 +0200 ---------------------------------------------------------------------- .../datastream/DataStreamOverAggregate.scala | 196 +++++-------- .../table/runtime/aggregate/AggregateUtil.scala | 122 +++----- ...ndedProcessingOverRangeProcessFunction.scala | 183 ------------ ...oundedProcessingOverRowProcessFunction.scala | 179 ------------ .../aggregate/ProcTimeBoundedRangeOver.scala | 182 ++++++++++++ .../aggregate/ProcTimeBoundedRowsOver.scala | 179 ++++++++++++ .../ProcTimeUnboundedNonPartitionedOver.scala | 96 ++++++ .../ProcTimeUnboundedPartitionedOver.scala | 84 ++++++ .../RangeClauseBoundedOverProcessFunction.scala | 201 ------------- .../aggregate/RowTimeBoundedRangeOver.scala | 200 +++++++++++++ .../aggregate/RowTimeBoundedRowsOver.scala | 222 ++++++++++++++ .../aggregate/RowTimeUnboundedOver.scala | 292 +++++++++++++++++++ .../RowsClauseBoundedOverProcessFunction.scala | 222 -------------- .../UnboundedEventTimeOverProcessFunction.scala | 292 ------------------- ...rtitionedProcessingOverProcessFunction.scala | 96 ------ ...UnboundedProcessingOverProcessFunction.scala | 84 ------ ...ProcessingOverRangeProcessFunctionTest.scala | 2 +- 17 files changed, 1380 insertions(+), 1452 deletions(-) ---------------------------------------------------------------------- 
http://git-wip-us.apache.org/repos/asf/flink/blob/07f1b035/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala index 947775b..2224752 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala @@ -17,20 +17,21 @@ */ package org.apache.flink.table.plan.nodes.datastream +import java.util.{List => JList} + import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} import org.apache.calcite.rel.`type`.RelDataType -import org.apache.calcite.rel.core.AggregateCall +import org.apache.calcite.rel.core.{AggregateCall, Window} +import org.apache.calcite.rel.core.Window.Group import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} +import org.apache.calcite.rel.RelFieldCollation.Direction.ASCENDING import org.apache.flink.api.java.typeutils.RowTypeInfo import org.apache.flink.streaming.api.datastream.DataStream import org.apache.flink.table.api.{StreamTableEnvironment, TableException} import org.apache.flink.table.calcite.FlinkTypeFactory -import org.apache.flink.table.runtime.aggregate._ import org.apache.flink.table.plan.nodes.OverAggregate +import org.apache.flink.table.runtime.aggregate._ import org.apache.flink.types.Row -import org.apache.calcite.rel.core.Window -import org.apache.calcite.rel.core.Window.Group -import java.util.{List => JList} import org.apache.flink.api.java.functions.NullByteKeySelector import org.apache.flink.table.codegen.CodeGenerator @@ -90,12 +91,20 @@ class 
DataStreamOverAggregate( val overWindow: org.apache.calcite.rel.core.Window.Group = logicWindow.groups.get(0) - val inputDS = input.asInstanceOf[DataStreamRel].translateToPlan(tableEnv) + val orderKeys = overWindow.orderKeys.getFieldCollations - if (overWindow.orderKeys.getFieldCollations.size() != 1) { + if (orderKeys.size() != 1) { throw new TableException( - "Unsupported use of OVER windows. The window may only be ordered by a single time column.") + "Unsupported use of OVER windows. The window can only be ordered by a single time column.") } + val orderKey = orderKeys.get(0) + + if (!orderKey.direction.equals(ASCENDING)) { + throw new TableException( + "Unsupported use of OVER windows. The window can only be ordered in ASCENDING mode.") + } + + val inputDS = input.asInstanceOf[DataStreamRel].translateToPlan(tableEnv) val generator = new CodeGenerator( tableEnv.getConfig, @@ -104,78 +113,69 @@ class DataStreamOverAggregate( val timeType = inputType .getFieldList - .get(overWindow.orderKeys.getFieldCollations.get(0).getFieldIndex) + .get(orderKey.getFieldIndex) .getValue + timeType match { case _: ProcTimeType => // proc-time OVER window if (overWindow.lowerBound.isUnbounded && overWindow.upperBound.isCurrentRow) { - // unbounded preceding OVER window - createUnboundedAndCurrentRowProcessingTimeOverWindow( + // unbounded OVER window + createUnboundedAndCurrentRowOverWindow( generator, - inputDS) + inputDS, + isRowTimeType = false, + isRowsClause = overWindow.isRows) } else if ( overWindow.lowerBound.isPreceding && !overWindow.lowerBound.isUnbounded && - overWindow.upperBound.isCurrentRow) { + overWindow.upperBound.isCurrentRow) { // bounded OVER window - if (overWindow.isRows) { - // ROWS clause bounded OVER window - createBoundedAndCurrentRowOverWindow( - generator, - inputDS, - isRangeClause = false, - isRowTimeType = false) - } else { - // RANGE clause bounded OVER window - createBoundedAndCurrentRowOverWindow( - generator, - inputDS, - isRangeClause = true, - 
isRowTimeType = false) - } + createBoundedAndCurrentRowOverWindow( + generator, + inputDS, + isRowTimeType = false, + isRowsClause = overWindow.isRows + ) } else { throw new TableException( - "processing-time OVER RANGE FOLLOWING window is not supported yet.") + "OVER RANGE FOLLOWING windows are not supported yet.") } case _: RowTimeType => // row-time OVER window if (overWindow.lowerBound.isPreceding && - overWindow.lowerBound.isUnbounded && overWindow.upperBound.isCurrentRow) { - // ROWS/RANGE clause unbounded OVER window - createUnboundedAndCurrentRowEventTimeOverWindow( + overWindow.lowerBound.isUnbounded && overWindow.upperBound.isCurrentRow) { + // unbounded OVER window + createUnboundedAndCurrentRowOverWindow( generator, inputDS, - overWindow.isRows) + isRowTimeType = true, + isRowsClause = overWindow.isRows + ) } else if (overWindow.lowerBound.isPreceding && overWindow.upperBound.isCurrentRow) { // bounded OVER window - if (overWindow.isRows) { - // ROWS clause bounded OVER window - createBoundedAndCurrentRowOverWindow( - generator, - inputDS, - isRangeClause = false, - isRowTimeType = true) - } else { - // RANGE clause bounded OVER window - createBoundedAndCurrentRowOverWindow( - generator, - inputDS, - isRangeClause = true, - isRowTimeType = true) - } + createBoundedAndCurrentRowOverWindow( + generator, + inputDS, + isRowTimeType = true, + isRowsClause = overWindow.isRows + ) } else { throw new TableException( - "row-time OVER RANGE FOLLOWING window is not supported yet.") + "OVER RANGE FOLLOWING windows are not supported yet.") } case _ => - throw new TableException(s"Unsupported time type {$timeType}") + throw new TableException( + "Unsupported time type {$timeType}. 
" + + "OVER windows do only support RowTimeType and ProcTimeType.") } } - def createUnboundedAndCurrentRowProcessingTimeOverWindow( + def createUnboundedAndCurrentRowOverWindow( generator: CodeGenerator, - inputDS: DataStream[Row]): DataStream[Row] = { + inputDS: DataStream[Row], + isRowTimeType: Boolean, + isRowsClause: Boolean): DataStream[Row] = { val overWindow: Group = logicWindow.groups.get(0) val partitionKeys: Array[Int] = overWindow.keys.toArray @@ -184,14 +184,17 @@ class DataStreamOverAggregate( // get the output types val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo] + val processFunction = AggregateUtil.createUnboundedOverProcessFunction( + generator, + namedAggregates, + inputType, + isRowTimeType, + partitionKeys.nonEmpty, + isRowsClause) + val result: DataStream[Row] = // partitioned aggregation if (partitionKeys.nonEmpty) { - val processFunction = AggregateUtil.createUnboundedProcessingOverProcessFunction( - generator, - namedAggregates, - inputType) - inputDS .keyBy(partitionKeys: _*) .process(processFunction) @@ -201,17 +204,19 @@ class DataStreamOverAggregate( } // non-partitioned aggregation else { - val processFunction = AggregateUtil.createUnboundedProcessingOverProcessFunction( - generator, - namedAggregates, - inputType, - isPartitioned = false) - - inputDS - .process(processFunction).setParallelism(1).setMaxParallelism(1) - .returns(rowTypeInfo) - .name(aggOpName) - .asInstanceOf[DataStream[Row]] + if (isRowTimeType) { + inputDS.keyBy(new NullByteKeySelector[Row]) + .process(processFunction).setParallelism(1).setMaxParallelism(1) + .returns(rowTypeInfo) + .name(aggOpName) + .asInstanceOf[DataStream[Row]] + } else { + inputDS + .process(processFunction).setParallelism(1).setMaxParallelism(1) + .returns(rowTypeInfo) + .name(aggOpName) + .asInstanceOf[DataStream[Row]] + } } result } @@ -219,15 +224,15 @@ class DataStreamOverAggregate( def createBoundedAndCurrentRowOverWindow( generator: 
CodeGenerator, inputDS: DataStream[Row], - isRangeClause: Boolean, - isRowTimeType: Boolean): DataStream[Row] = { + isRowTimeType: Boolean, + isRowsClause: Boolean): DataStream[Row] = { val overWindow: Group = logicWindow.groups.get(0) val partitionKeys: Array[Int] = overWindow.keys.toArray val namedAggregates: Seq[CalcitePair[AggregateCall, String]] = generateNamedAggregates val precedingOffset = - getLowerBoundary(logicWindow, overWindow, getInput()) + (if (isRangeClause) 0 else 1) + getLowerBoundary(logicWindow, overWindow, getInput()) + (if (isRowsClause) 1 else 0) // get the output types val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo] @@ -237,8 +242,9 @@ class DataStreamOverAggregate( namedAggregates, inputType, precedingOffset, - isRangeClause, - isRowTimeType) + isRowsClause, + isRowTimeType + ) val result: DataStream[Row] = // partitioned aggregation if (partitionKeys.nonEmpty) { @@ -253,49 +259,7 @@ class DataStreamOverAggregate( else { inputDS .keyBy(new NullByteKeySelector[Row]) - .process(processFunction) - .setParallelism(1) - .setMaxParallelism(1) - .returns(rowTypeInfo) - .name(aggOpName) - .asInstanceOf[DataStream[Row]] - } - result - } - - def createUnboundedAndCurrentRowEventTimeOverWindow( - generator: CodeGenerator, - inputDS: DataStream[Row], - isRows: Boolean): DataStream[Row] = { - - val overWindow: Group = logicWindow.groups.get(0) - val partitionKeys: Array[Int] = overWindow.keys.toArray - val namedAggregates: Seq[CalcitePair[AggregateCall, String]] = generateNamedAggregates - - // get the output types - val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo] - - val processFunction = AggregateUtil.createUnboundedEventTimeOverProcessFunction( - generator, - namedAggregates, - inputType, - isRows) - - val result: DataStream[Row] = - // partitioned aggregation - if (partitionKeys.nonEmpty) { - inputDS.keyBy(partitionKeys: _*) - .process(processFunction) - 
.returns(rowTypeInfo) - .name(aggOpName) - .asInstanceOf[DataStream[Row]] - } - // global non-partitioned aggregation - else { - inputDS.keyBy(new NullByteKeySelector[Row]) - .process(processFunction) - .setParallelism(1) - .setMaxParallelism(1) + .process(processFunction).setParallelism(1).setMaxParallelism(1) .returns(rowTypeInfo) .name(aggOpName) .asInstanceOf[DataStream[Row]] http://git-wip-us.apache.org/repos/asf/flink/blob/07f1b035/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala index fc03ac1..09d1a13 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala @@ -55,21 +55,23 @@ object AggregateUtil { type JavaList[T] = java.util.List[T] /** - * Create an [[org.apache.flink.streaming.api.functions.ProcessFunction]] to evaluate final - * aggregate value. + * Create an [[org.apache.flink.streaming.api.functions.ProcessFunction]] for unbounded OVER + * window to evaluate final aggregate value. 
* * @param generator code generator instance * @param namedAggregates List of calls to aggregate functions and their output field names - * @param inputType Input row type - * @param isPartitioned Flag to indicate whether the input is partitioned or not - * - * @return [[org.apache.flink.streaming.api.functions.ProcessFunction]] + * @param inputType Input row type + * @param isRowTimeType It is a tag that indicates whether the time type is rowTimeType + * @param isPartitioned It is a tag that indicates whether the input is partitioned + * @param isRowsClause It is a tag that indicates whether the OVER clause is ROWS clause */ - private[flink] def createUnboundedProcessingOverProcessFunction( - generator: CodeGenerator, - namedAggregates: Seq[CalcitePair[AggregateCall, String]], - inputType: RelDataType, - isPartitioned: Boolean = true): ProcessFunction[Row, Row] = { + private[flink] def createUnboundedOverProcessFunction( + generator: CodeGenerator, + namedAggregates: Seq[CalcitePair[AggregateCall, String]], + inputType: RelDataType, + isRowTimeType: Boolean, + isPartitioned: Boolean, + isRowsClause: Boolean): ProcessFunction[Row, Row] = { val (aggFields, aggregates) = transformToAggregateFunctions( @@ -95,14 +97,30 @@ object AggregateUtil { outputArity ) - if (isPartitioned) { - new UnboundedProcessingOverProcessFunction( - genFunction, - aggregationStateType) + if (isRowTimeType) { + if (isRowsClause) { + // ROWS unbounded over process function + new RowTimeUnboundedRowsOver( + genFunction, + aggregationStateType, + FlinkTypeFactory.toInternalRowTypeInfo(inputType)) + } else { + // RANGE unbounded over process function + new RowTimeUnboundedRangeOver( + genFunction, + aggregationStateType, + FlinkTypeFactory.toInternalRowTypeInfo(inputType)) + } } else { - new UnboundedNonPartitionedProcessingOverProcessFunction( - genFunction, - aggregationStateType) + if (isPartitioned) { + new ProcTimeUnboundedPartitionedOver( + genFunction, + aggregationStateType) + } else {
new ProcTimeUnboundedNonPartitionedOver( + genFunction, + aggregationStateType) + } } } @@ -114,7 +132,7 @@ object AggregateUtil { * @param namedAggregates List of calls to aggregate functions and their output field names * @param inputType Input row type * @param precedingOffset the preceding offset - * @param isRangeClause It is a tag that indicates whether the OVER clause is rangeClause + * @param isRowsClause It is a tag that indicates whether the OVER clause is ROWS clause * @param isRowTimeType It is a tag that indicates whether the time type is rowTimeType * @return [[org.apache.flink.streaming.api.functions.ProcessFunction]] */ @@ -123,7 +141,7 @@ object AggregateUtil { namedAggregates: Seq[CalcitePair[AggregateCall, String]], inputType: RelDataType, precedingOffset: Long, - isRangeClause: Boolean, + isRowsClause: Boolean, isRowTimeType: Boolean): ProcessFunction[Row, Row] = { val (aggFields, aggregates) = @@ -151,15 +169,15 @@ object AggregateUtil { ) if (isRowTimeType) { - if (isRangeClause) { - new RangeClauseBoundedOverProcessFunction( + if (isRowsClause) { + new RowTimeBoundedRowsOver( genFunction, aggregationStateType, inputRowType, precedingOffset ) } else { - new RowsClauseBoundedOverProcessFunction( + new RowTimeBoundedRangeOver( genFunction, aggregationStateType, inputRowType, @@ -167,14 +185,14 @@ object AggregateUtil { ) } } else { - if (isRangeClause) { - new BoundedProcessingOverRangeProcessFunction( + if (isRowsClause) { + new ProcTimeBoundedRowsOver( genFunction, precedingOffset, aggregationStateType, inputRowType) } else { - new BoundedProcessingOverRowProcessFunction( + new ProcTimeBoundedRangeOver( genFunction, precedingOffset, aggregationStateType, @@ -183,58 +201,6 @@ object AggregateUtil { } } - /** - * Create an [[ProcessFunction]] to evaluate final aggregate value. 
- * - * @param generator code generator instance - * @param namedAggregates List of calls to aggregate functions and their output field names - * @param inputType Input row type - * @param isRows Flag to indicate if whether this is a Row (true) or a Range (false) - * over window process - * @return [[UnboundedEventTimeOverProcessFunction]] - */ - private[flink] def createUnboundedEventTimeOverProcessFunction( - generator: CodeGenerator, - namedAggregates: Seq[CalcitePair[AggregateCall, String]], - inputType: RelDataType, - isRows: Boolean): UnboundedEventTimeOverProcessFunction = { - - val (aggFields, aggregates) = - transformToAggregateFunctions( - namedAggregates.map(_.getKey), - inputType, - needRetraction = false) - - val aggregationStateType: RowTypeInfo = createAccumulatorRowType(aggregates) - - val forwardMapping = (0 until inputType.getFieldCount).map(x => (x, x)).toArray - val aggMapping = aggregates.indices.map(x => x + inputType.getFieldCount).toArray - val outputArity = inputType.getFieldCount + aggregates.length - - val genFunction = generator.generateAggregations( - "UnboundedEventTimeOverAggregateHelper", - generator, - inputType, - aggregates, - aggFields, - aggMapping, - forwardMapping, - outputArity) - - if (isRows) { - // ROWS unbounded over process function - new UnboundedEventTimeRowsOverProcessFunction( - genFunction, - aggregationStateType, - FlinkTypeFactory.toInternalRowTypeInfo(inputType)) - } else { - // RANGE unbounded over process function - new UnboundedEventTimeRangeOverProcessFunction( - genFunction, - aggregationStateType, - FlinkTypeFactory.toInternalRowTypeInfo(inputType)) - } - } /** * Create a [[org.apache.flink.api.common.functions.MapFunction]] that prepares for aggregates. 
http://git-wip-us.apache.org/repos/asf/flink/blob/07f1b035/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRangeProcessFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRangeProcessFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRangeProcessFunction.scala deleted file mode 100644 index 8f3aa3e..0000000 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRangeProcessFunction.scala +++ /dev/null @@ -1,183 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.flink.table.runtime.aggregate - -import org.apache.flink.api.java.typeutils.RowTypeInfo -import org.apache.flink.configuration.Configuration -import org.apache.flink.streaming.api.functions.ProcessFunction -import org.apache.flink.types.Row -import org.apache.flink.util.Collector -import org.apache.flink.api.common.state.ValueState -import org.apache.flink.api.common.state.ValueStateDescriptor -import org.apache.flink.api.common.state.MapState -import org.apache.flink.api.common.state.MapStateDescriptor -import org.apache.flink.api.common.typeinfo.TypeInformation -import org.apache.flink.api.java.typeutils.ListTypeInfo -import java.util.{ArrayList, List => JList} - -import org.apache.flink.api.common.typeinfo.BasicTypeInfo -import org.apache.flink.table.codegen.{GeneratedAggregationsFunction, Compiler} -import org.slf4j.LoggerFactory - -/** - * Process Function used for the aggregate in bounded proc-time OVER window - * [[org.apache.flink.streaming.api.datastream.DataStream]] - * - * @param genAggregations Generated aggregate helper function - * @param precedingTimeBoundary Is used to indicate the processing time boundaries - * @param aggregatesTypeInfo row type info of aggregation - * @param inputType row type info of input row - */ -class BoundedProcessingOverRangeProcessFunction( - genAggregations: GeneratedAggregationsFunction, - precedingTimeBoundary: Long, - aggregatesTypeInfo: RowTypeInfo, - inputType: TypeInformation[Row]) - extends ProcessFunction[Row, Row] - with Compiler[GeneratedAggregations] { - - private var output: Row = _ - private var accumulatorState: ValueState[Row] = _ - private var rowMapState: MapState[Long, JList[Row]] = _ - - val LOG = LoggerFactory.getLogger(this.getClass) - private var function: GeneratedAggregations = _ - - override def open(config: Configuration) { - LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " + - s"Code:\n$genAggregations.code") - val clazz = compile( - 
getRuntimeContext.getUserCodeClassLoader, - genAggregations.name, - genAggregations.code) - LOG.debug("Instantiating AggregateHelper.") - function = clazz.newInstance() - output = function.createOutputRow() - - // We keep the elements received in a MapState indexed based on their ingestion time - val rowListTypeInfo: TypeInformation[JList[Row]] = - new ListTypeInfo[Row](inputType).asInstanceOf[TypeInformation[JList[Row]]] - val mapStateDescriptor: MapStateDescriptor[Long, JList[Row]] = - new MapStateDescriptor[Long, JList[Row]]("rowmapstate", - BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]], rowListTypeInfo) - rowMapState = getRuntimeContext.getMapState(mapStateDescriptor) - - val stateDescriptor: ValueStateDescriptor[Row] = - new ValueStateDescriptor[Row]("overState", aggregatesTypeInfo) - accumulatorState = getRuntimeContext.getState(stateDescriptor) - } - - override def processElement( - input: Row, - ctx: ProcessFunction[Row, Row]#Context, - out: Collector[Row]): Unit = { - - val currentTime = ctx.timerService.currentProcessingTime - // buffer the event incoming event - - // add current element to the window list of elements with corresponding timestamp - var rowList = rowMapState.get(currentTime) - // null value means that this si the first event received for this timestamp - if (rowList == null) { - rowList = new ArrayList[Row]() - // register timer to process event once the current millisecond passed - ctx.timerService.registerProcessingTimeTimer(currentTime + 1) - } - rowList.add(input) - rowMapState.put(currentTime, rowList) - - } - - override def onTimer( - timestamp: Long, - ctx: ProcessFunction[Row, Row]#OnTimerContext, - out: Collector[Row]): Unit = { - - // we consider the original timestamp of events that have registered this time trigger 1 ms ago - val currentTime = timestamp - 1 - var i = 0 - - // initialize the accumulators - var accumulators = accumulatorState.value() - - if (null == accumulators) { - accumulators = 
function.createAccumulators() - } - - // update the elements to be removed and retract them from aggregators - val limit = currentTime - precedingTimeBoundary - - // we iterate through all elements in the window buffer based on timestamp keys - // when we find timestamps that are out of interest, we retrieve corresponding elements - // and eliminate them. Multiple elements could have been received at the same timestamp - // the removal of old elements happens only once per proctime as onTimer is called only once - val iter = rowMapState.keys.iterator - val markToRemove = new ArrayList[Long]() - while (iter.hasNext) { - val elementKey = iter.next - if (elementKey < limit) { - // element key outside of window. Retract values - val elementsRemove = rowMapState.get(elementKey) - var iRemove = 0 - while (iRemove < elementsRemove.size()) { - val retractRow = elementsRemove.get(iRemove) - function.retract(accumulators, retractRow) - iRemove += 1 - } - // mark element for later removal not to modify the iterator over MapState - markToRemove.add(elementKey) - } - } - // need to remove in 2 steps not to have concurrent access errors via iterator to the MapState - i = 0 - while (i < markToRemove.size()) { - rowMapState.remove(markToRemove.get(i)) - i += 1 - } - - // get the list of elements of current proctime - val currentElements = rowMapState.get(currentTime) - // add current elements to aggregator. 
Multiple elements might have arrived in the same proctime - // the same accumulator value will be computed for all elements - var iElemenets = 0 - while (iElemenets < currentElements.size()) { - val input = currentElements.get(iElemenets) - function.accumulate(accumulators, input) - iElemenets += 1 - } - - // we need to build the output and emit for every event received at this proctime - iElemenets = 0 - while (iElemenets < currentElements.size()) { - val input = currentElements.get(iElemenets) - - // set the fields of the last event to carry on with the aggregates - function.setForwardedFields(input, output) - - // add the accumulators values to result - function.setAggregationResults(accumulators, output) - out.collect(output) - iElemenets += 1 - } - - // update the value of accumulators for future incremental computation - accumulatorState.update(accumulators) - - } - -} http://git-wip-us.apache.org/repos/asf/flink/blob/07f1b035/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRowProcessFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRowProcessFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRowProcessFunction.scala deleted file mode 100644 index d5ee4ae..0000000 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRowProcessFunction.scala +++ /dev/null @@ -1,179 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.flink.table.runtime.aggregate - -import java.util - -import org.apache.flink.configuration.Configuration -import org.apache.flink.streaming.api.functions.ProcessFunction -import org.apache.flink.types.Row -import org.apache.flink.util.{Collector, Preconditions} -import org.apache.flink.api.common.state.ValueStateDescriptor -import org.apache.flink.api.java.typeutils.RowTypeInfo -import org.apache.flink.api.common.state.ValueState -import org.apache.flink.api.common.state.MapState -import org.apache.flink.api.common.state.MapStateDescriptor -import org.apache.flink.api.common.typeinfo.TypeInformation -import org.apache.flink.api.java.typeutils.ListTypeInfo -import java.util.{List => JList} - -import org.apache.flink.api.common.typeinfo.BasicTypeInfo -import org.apache.flink.table.codegen.{GeneratedAggregationsFunction, Compiler} -import org.slf4j.LoggerFactory - -/** - * Process Function for ROW clause processing-time bounded OVER window - * - * @param genAggregations Generated aggregate helper function - * @param precedingOffset preceding offset - * @param aggregatesTypeInfo row type info of aggregation - * @param inputType row type info of input row - */ -class BoundedProcessingOverRowProcessFunction( - genAggregations: GeneratedAggregationsFunction, - precedingOffset: Long, - aggregatesTypeInfo: RowTypeInfo, - inputType: TypeInformation[Row]) - extends 
ProcessFunction[Row, Row] - with Compiler[GeneratedAggregations] { - - Preconditions.checkArgument(precedingOffset > 0) - - private var accumulatorState: ValueState[Row] = _ - private var rowMapState: MapState[Long, JList[Row]] = _ - private var output: Row = _ - private var counterState: ValueState[Long] = _ - private var smallestTsState: ValueState[Long] = _ - - val LOG = LoggerFactory.getLogger(this.getClass) - private var function: GeneratedAggregations = _ - - override def open(config: Configuration) { - LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " + - s"Code:\n$genAggregations.code") - val clazz = compile( - getRuntimeContext.getUserCodeClassLoader, - genAggregations.name, - genAggregations.code) - LOG.debug("Instantiating AggregateHelper.") - function = clazz.newInstance() - - output = function.createOutputRow() - // We keep the elements received in a Map state keyed - // by the ingestion time in the operator. - // we also keep counter of processed elements - // and timestamp of oldest element - val rowListTypeInfo: TypeInformation[JList[Row]] = - new ListTypeInfo[Row](inputType).asInstanceOf[TypeInformation[JList[Row]]] - - val mapStateDescriptor: MapStateDescriptor[Long, JList[Row]] = - new MapStateDescriptor[Long, JList[Row]]("windowBufferMapState", - BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]], rowListTypeInfo) - rowMapState = getRuntimeContext.getMapState(mapStateDescriptor) - - val aggregationStateDescriptor: ValueStateDescriptor[Row] = - new ValueStateDescriptor[Row]("aggregationState", aggregatesTypeInfo) - accumulatorState = getRuntimeContext.getState(aggregationStateDescriptor) - - val processedCountDescriptor : ValueStateDescriptor[Long] = - new ValueStateDescriptor[Long]("processedCountState", classOf[Long]) - counterState = getRuntimeContext.getState(processedCountDescriptor) - - val smallestTimestampDescriptor : ValueStateDescriptor[Long] = - new ValueStateDescriptor[Long]("smallestTSState", 
classOf[Long]) - smallestTsState = getRuntimeContext.getState(smallestTimestampDescriptor) - } - - override def processElement( - input: Row, - ctx: ProcessFunction[Row, Row]#Context, - out: Collector[Row]): Unit = { - - val currentTime = ctx.timerService.currentProcessingTime - - // initialize state for the processed element - var accumulators = accumulatorState.value - if (accumulators == null) { - accumulators = function.createAccumulators() - } - - // get smallest timestamp - var smallestTs = smallestTsState.value - if (smallestTs == 0L) { - smallestTs = currentTime - smallestTsState.update(smallestTs) - } - // get previous counter value - var counter = counterState.value - - if (counter == precedingOffset) { - val retractList = rowMapState.get(smallestTs) - - // get oldest element beyond buffer size - // and if oldest element exist, retract value - val retractRow = retractList.get(0) - function.retract(accumulators, retractRow) - retractList.remove(0) - - // if reference timestamp list not empty, keep the list - if (!retractList.isEmpty) { - rowMapState.put(smallestTs, retractList) - } // if smallest timestamp list is empty, remove and find new smallest - else { - rowMapState.remove(smallestTs) - val iter = rowMapState.keys.iterator - var currentTs: Long = 0L - var newSmallestTs: Long = Long.MaxValue - while (iter.hasNext) { - currentTs = iter.next - if (currentTs < newSmallestTs) { - newSmallestTs = currentTs - } - } - smallestTsState.update(newSmallestTs) - } - } // we update the counter only while buffer is getting filled - else { - counter += 1 - counterState.update(counter) - } - - // copy forwarded fields in output row - function.setForwardedFields(input, output) - - // accumulate current row and set aggregate in output row - function.accumulate(accumulators, input) - function.setAggregationResults(accumulators, output) - - // update map state, accumulator state, counter and timestamp - val currentTimeState = rowMapState.get(currentTime) - if 
(currentTimeState != null) { - currentTimeState.add(input) - rowMapState.put(currentTime, currentTimeState) - } else { // add new input - val newList = new util.ArrayList[Row] - newList.add(input) - rowMapState.put(currentTime, newList) - } - - accumulatorState.update(accumulators) - - out.collect(output) - } - -} http://git-wip-us.apache.org/repos/asf/flink/blob/07f1b035/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala new file mode 100644 index 0000000..7f87e50 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.table.runtime.aggregate + +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.configuration.Configuration +import org.apache.flink.streaming.api.functions.ProcessFunction +import org.apache.flink.types.Row +import org.apache.flink.util.Collector +import org.apache.flink.api.common.state.ValueState +import org.apache.flink.api.common.state.ValueStateDescriptor +import org.apache.flink.api.common.state.MapState +import org.apache.flink.api.common.state.MapStateDescriptor +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.java.typeutils.ListTypeInfo +import java.util.{ArrayList, List => JList} + +import org.apache.flink.api.common.typeinfo.BasicTypeInfo +import org.apache.flink.table.codegen.{GeneratedAggregationsFunction, Compiler} +import org.slf4j.LoggerFactory + +/** + * Process Function used for the aggregate in bounded proc-time OVER window + * [[org.apache.flink.streaming.api.datastream.DataStream]] + * + * @param genAggregations Generated aggregate helper function + * @param precedingTimeBoundary Is used to indicate the processing time boundaries + * @param aggregatesTypeInfo row type info of aggregation + * @param inputType row type info of input row + */ +class ProcTimeBoundedRangeOver( + genAggregations: GeneratedAggregationsFunction, + precedingTimeBoundary: Long, + aggregatesTypeInfo: RowTypeInfo, + inputType: TypeInformation[Row]) + extends ProcessFunction[Row, Row] + with Compiler[GeneratedAggregations] { + private var output: Row = _ + private var accumulatorState: ValueState[Row] = _ + private var rowMapState: MapState[Long, JList[Row]] = _ + + val LOG = LoggerFactory.getLogger(this.getClass) + private var function: GeneratedAggregations = _ + + override def open(config: Configuration) { + LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " + + s"Code:\n$genAggregations.code") + val clazz = compile( + 
getRuntimeContext.getUserCodeClassLoader, + genAggregations.name, + genAggregations.code) + LOG.debug("Instantiating AggregateHelper.") + function = clazz.newInstance() + output = function.createOutputRow() + + // We keep the elements received in a MapState indexed based on their ingestion time + val rowListTypeInfo: TypeInformation[JList[Row]] = + new ListTypeInfo[Row](inputType).asInstanceOf[TypeInformation[JList[Row]]] + val mapStateDescriptor: MapStateDescriptor[Long, JList[Row]] = + new MapStateDescriptor[Long, JList[Row]]("rowmapstate", + BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]], rowListTypeInfo) + rowMapState = getRuntimeContext.getMapState(mapStateDescriptor) + + val stateDescriptor: ValueStateDescriptor[Row] = + new ValueStateDescriptor[Row]("overState", aggregatesTypeInfo) + accumulatorState = getRuntimeContext.getState(stateDescriptor) + } + + override def processElement( + input: Row, + ctx: ProcessFunction[Row, Row]#Context, + out: Collector[Row]): Unit = { + + val currentTime = ctx.timerService.currentProcessingTime + // buffer the event incoming event + + // add current element to the window list of elements with corresponding timestamp + var rowList = rowMapState.get(currentTime) + // null value means that this si the first event received for this timestamp + if (rowList == null) { + rowList = new ArrayList[Row]() + // register timer to process event once the current millisecond passed + ctx.timerService.registerProcessingTimeTimer(currentTime + 1) + } + rowList.add(input) + rowMapState.put(currentTime, rowList) + + } + + override def onTimer( + timestamp: Long, + ctx: ProcessFunction[Row, Row]#OnTimerContext, + out: Collector[Row]): Unit = { + + // we consider the original timestamp of events that have registered this time trigger 1 ms ago + val currentTime = timestamp - 1 + var i = 0 + + // initialize the accumulators + var accumulators = accumulatorState.value() + + if (null == accumulators) { + accumulators = 
function.createAccumulators() + } + + // update the elements to be removed and retract them from aggregators + val limit = currentTime - precedingTimeBoundary + + // we iterate through all elements in the window buffer based on timestamp keys + // when we find timestamps that are out of interest, we retrieve corresponding elements + // and eliminate them. Multiple elements could have been received at the same timestamp + // the removal of old elements happens only once per proctime as onTimer is called only once + val iter = rowMapState.keys.iterator + val markToRemove = new ArrayList[Long]() + while (iter.hasNext) { + val elementKey = iter.next + if (elementKey < limit) { + // element key outside of window. Retract values + val elementsRemove = rowMapState.get(elementKey) + var iRemove = 0 + while (iRemove < elementsRemove.size()) { + val retractRow = elementsRemove.get(iRemove) + function.retract(accumulators, retractRow) + iRemove += 1 + } + // mark element for later removal not to modify the iterator over MapState + markToRemove.add(elementKey) + } + } + // need to remove in 2 steps not to have concurrent access errors via iterator to the MapState + i = 0 + while (i < markToRemove.size()) { + rowMapState.remove(markToRemove.get(i)) + i += 1 + } + + // get the list of elements of current proctime + val currentElements = rowMapState.get(currentTime) + // add current elements to aggregator. 
Multiple elements might have arrived in the same proctime + // the same accumulator value will be computed for all elements + var iElemenets = 0 + while (iElemenets < currentElements.size()) { + val input = currentElements.get(iElemenets) + function.accumulate(accumulators, input) + iElemenets += 1 + } + + // we need to build the output and emit for every event received at this proctime + iElemenets = 0 + while (iElemenets < currentElements.size()) { + val input = currentElements.get(iElemenets) + + // set the fields of the last event to carry on with the aggregates + function.setForwardedFields(input, output) + + // add the accumulators values to result + function.setAggregationResults(accumulators, output) + out.collect(output) + iElemenets += 1 + } + + // update the value of accumulators for future incremental computation + accumulatorState.update(accumulators) + + } + +} http://git-wip-us.apache.org/repos/asf/flink/blob/07f1b035/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRowsOver.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRowsOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRowsOver.scala new file mode 100644 index 0000000..31cfd73 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRowsOver.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.runtime.aggregate + +import java.util + +import org.apache.flink.configuration.Configuration +import org.apache.flink.streaming.api.functions.ProcessFunction +import org.apache.flink.types.Row +import org.apache.flink.util.{Collector, Preconditions} +import org.apache.flink.api.common.state.ValueStateDescriptor +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.api.common.state.ValueState +import org.apache.flink.api.common.state.MapState +import org.apache.flink.api.common.state.MapStateDescriptor +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.java.typeutils.ListTypeInfo +import java.util.{List => JList} + +import org.apache.flink.api.common.typeinfo.BasicTypeInfo +import org.apache.flink.table.codegen.{GeneratedAggregationsFunction, Compiler} +import org.slf4j.LoggerFactory + +/** + * Process Function for ROW clause processing-time bounded OVER window + * + * @param genAggregations Generated aggregate helper function + * @param precedingOffset preceding offset + * @param aggregatesTypeInfo row type info of aggregation + * @param inputType row type info of input row + */ +class ProcTimeBoundedRowsOver( + genAggregations: GeneratedAggregationsFunction, + precedingOffset: Long, + aggregatesTypeInfo: RowTypeInfo, + inputType: TypeInformation[Row]) + extends ProcessFunction[Row, Row] + with Compiler[GeneratedAggregations] { + + Preconditions.checkArgument(precedingOffset > 0) + + private var accumulatorState: ValueState[Row] = _ + private var 
rowMapState: MapState[Long, JList[Row]] = _ + private var output: Row = _ + private var counterState: ValueState[Long] = _ + private var smallestTsState: ValueState[Long] = _ + + val LOG = LoggerFactory.getLogger(this.getClass) + private var function: GeneratedAggregations = _ + + override def open(config: Configuration) { + LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " + + s"Code:\n$genAggregations.code") + val clazz = compile( + getRuntimeContext.getUserCodeClassLoader, + genAggregations.name, + genAggregations.code) + LOG.debug("Instantiating AggregateHelper.") + function = clazz.newInstance() + + output = function.createOutputRow() + // We keep the elements received in a Map state keyed + // by the ingestion time in the operator. + // we also keep counter of processed elements + // and timestamp of oldest element + val rowListTypeInfo: TypeInformation[JList[Row]] = + new ListTypeInfo[Row](inputType).asInstanceOf[TypeInformation[JList[Row]]] + + val mapStateDescriptor: MapStateDescriptor[Long, JList[Row]] = + new MapStateDescriptor[Long, JList[Row]]("windowBufferMapState", + BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]], rowListTypeInfo) + rowMapState = getRuntimeContext.getMapState(mapStateDescriptor) + + val aggregationStateDescriptor: ValueStateDescriptor[Row] = + new ValueStateDescriptor[Row]("aggregationState", aggregatesTypeInfo) + accumulatorState = getRuntimeContext.getState(aggregationStateDescriptor) + + val processedCountDescriptor : ValueStateDescriptor[Long] = + new ValueStateDescriptor[Long]("processedCountState", classOf[Long]) + counterState = getRuntimeContext.getState(processedCountDescriptor) + + val smallestTimestampDescriptor : ValueStateDescriptor[Long] = + new ValueStateDescriptor[Long]("smallestTSState", classOf[Long]) + smallestTsState = getRuntimeContext.getState(smallestTimestampDescriptor) + } + + override def processElement( + input: Row, + ctx: ProcessFunction[Row, Row]#Context, + out: 
Collector[Row]): Unit = { + + val currentTime = ctx.timerService.currentProcessingTime + + // initialize state for the processed element + var accumulators = accumulatorState.value + if (accumulators == null) { + accumulators = function.createAccumulators() + } + + // get smallest timestamp + var smallestTs = smallestTsState.value + if (smallestTs == 0L) { + smallestTs = currentTime + smallestTsState.update(smallestTs) + } + // get previous counter value + var counter = counterState.value + + if (counter == precedingOffset) { + val retractList = rowMapState.get(smallestTs) + + // get oldest element beyond buffer size + // and if oldest element exist, retract value + val retractRow = retractList.get(0) + function.retract(accumulators, retractRow) + retractList.remove(0) + + // if reference timestamp list not empty, keep the list + if (!retractList.isEmpty) { + rowMapState.put(smallestTs, retractList) + } // if smallest timestamp list is empty, remove and find new smallest + else { + rowMapState.remove(smallestTs) + val iter = rowMapState.keys.iterator + var currentTs: Long = 0L + var newSmallestTs: Long = Long.MaxValue + while (iter.hasNext) { + currentTs = iter.next + if (currentTs < newSmallestTs) { + newSmallestTs = currentTs + } + } + smallestTsState.update(newSmallestTs) + } + } // we update the counter only while buffer is getting filled + else { + counter += 1 + counterState.update(counter) + } + + // copy forwarded fields in output row + function.setForwardedFields(input, output) + + // accumulate current row and set aggregate in output row + function.accumulate(accumulators, input) + function.setAggregationResults(accumulators, output) + + // update map state, accumulator state, counter and timestamp + val currentTimeState = rowMapState.get(currentTime) + if (currentTimeState != null) { + currentTimeState.add(input) + rowMapState.put(currentTime, currentTimeState) + } else { // add new input + val newList = new util.ArrayList[Row] + newList.add(input) + 
rowMapState.put(currentTime, newList) + } + + accumulatorState.update(accumulators) + + out.collect(output) + } + +} http://git-wip-us.apache.org/repos/asf/flink/blob/07f1b035/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedNonPartitionedOver.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedNonPartitionedOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedNonPartitionedOver.scala new file mode 100644 index 0000000..6b9800b --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedNonPartitionedOver.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.table.runtime.aggregate + +import org.apache.flink.api.common.state.{ListState, ListStateDescriptor} +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.configuration.Configuration +import org.apache.flink.runtime.state.{FunctionInitializationContext, FunctionSnapshotContext} +import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction +import org.apache.flink.streaming.api.functions.ProcessFunction +import org.apache.flink.types.Row +import org.apache.flink.util.Collector +import org.apache.flink.table.codegen.{GeneratedAggregationsFunction, Compiler} +import org.slf4j.LoggerFactory + +/** + * Process Function for non-partitioned processing-time unbounded OVER window + * + * @param genAggregations Generated aggregate helper function + * @param aggregationStateType row type info of aggregation + */ +class ProcTimeUnboundedNonPartitionedOver( + genAggregations: GeneratedAggregationsFunction, + aggregationStateType: RowTypeInfo) + extends ProcessFunction[Row, Row] + with CheckpointedFunction + with Compiler[GeneratedAggregations] { + + private var accumulators: Row = _ + private var output: Row = _ + private var state: ListState[Row] = _ + val LOG = LoggerFactory.getLogger(this.getClass) + + private var function: GeneratedAggregations = _ + + override def open(config: Configuration) { + LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " + + s"Code:\n$genAggregations.code") + val clazz = compile( + getRuntimeContext.getUserCodeClassLoader, + genAggregations.name, + genAggregations.code) + LOG.debug("Instantiating AggregateHelper.") + function = clazz.newInstance() + + output = function.createOutputRow() + if (null == accumulators) { + val it = state.get().iterator() + if (it.hasNext) { + accumulators = it.next() + } else { + accumulators = function.createAccumulators() + } + } + } + + override def processElement( + input: Row, + ctx: ProcessFunction[Row, Row]#Context, + out: 
Collector[Row]): Unit = { + + function.setForwardedFields(input, output) + + function.accumulate(accumulators, input) + function.setAggregationResults(accumulators, output) + + out.collect(output) + } + + override def snapshotState(context: FunctionSnapshotContext): Unit = { + state.clear() + if (null != accumulators) { + state.add(accumulators) + } + } + + override def initializeState(context: FunctionInitializationContext): Unit = { + val accumulatorsDescriptor = new ListStateDescriptor[Row]("overState", aggregationStateType) + state = context.getOperatorStateStore.getOperatorState(accumulatorsDescriptor) + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/07f1b035/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedPartitionedOver.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedPartitionedOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedPartitionedOver.scala new file mode 100644 index 0000000..9baa6a3 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedPartitionedOver.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.runtime.aggregate + +import org.apache.flink.configuration.Configuration +import org.apache.flink.streaming.api.functions.ProcessFunction +import org.apache.flink.types.Row +import org.apache.flink.util.Collector +import org.apache.flink.api.common.state.ValueStateDescriptor +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.api.common.state.ValueState +import org.apache.flink.table.codegen.{GeneratedAggregationsFunction, Compiler} +import org.slf4j.LoggerFactory + +/** + * Process Function for processing-time unbounded OVER window + * + * @param genAggregations Generated aggregate helper function + * @param aggregationStateType row type info of aggregation + */ +class ProcTimeUnboundedPartitionedOver( + genAggregations: GeneratedAggregationsFunction, + aggregationStateType: RowTypeInfo) + extends ProcessFunction[Row, Row] + with Compiler[GeneratedAggregations] { + + private var output: Row = _ + private var state: ValueState[Row] = _ + val LOG = LoggerFactory.getLogger(this.getClass) + private var function: GeneratedAggregations = _ + + override def open(config: Configuration) { + LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " + + s"Code:\n$genAggregations.code") + val clazz = compile( + getRuntimeContext.getUserCodeClassLoader, + genAggregations.name, + genAggregations.code) + LOG.debug("Instantiating AggregateHelper.") + function = clazz.newInstance() + + output = function.createOutputRow() + val stateDescriptor: ValueStateDescriptor[Row] = + new 
ValueStateDescriptor[Row]("overState", aggregationStateType) + state = getRuntimeContext.getState(stateDescriptor) + } + + override def processElement( + input: Row, + ctx: ProcessFunction[Row, Row]#Context, + out: Collector[Row]): Unit = { + + var accumulators = state.value() + + if (null == accumulators) { + accumulators = function.createAccumulators() + } + + function.setForwardedFields(input, output) + + function.accumulate(accumulators, input) + function.setAggregationResults(accumulators, output) + + state.update(accumulators) + + out.collect(output) + } + +} http://git-wip-us.apache.org/repos/asf/flink/blob/07f1b035/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RangeClauseBoundedOverProcessFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RangeClauseBoundedOverProcessFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RangeClauseBoundedOverProcessFunction.scala deleted file mode 100644 index 0f1ef49..0000000 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RangeClauseBoundedOverProcessFunction.scala +++ /dev/null @@ -1,201 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.flink.table.runtime.aggregate - -import java.util.{List => JList, ArrayList => JArrayList} - -import org.apache.flink.api.common.state._ -import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} -import org.apache.flink.api.java.typeutils.{ListTypeInfo, RowTypeInfo} -import org.apache.flink.configuration.Configuration -import org.apache.flink.streaming.api.functions.ProcessFunction -import org.apache.flink.table.codegen.{GeneratedAggregationsFunction, Compiler} -import org.apache.flink.types.Row -import org.apache.flink.util.{Collector, Preconditions} -import org.slf4j.LoggerFactory - -/** - * Process Function for RANGE clause event-time bounded OVER window - * - * @param genAggregations Generated aggregate helper function - * @param aggregationStateType row type info of aggregation - * @param inputRowType row type info of input row - * @param precedingOffset preceding offset - */ -class RangeClauseBoundedOverProcessFunction( - genAggregations: GeneratedAggregationsFunction, - aggregationStateType: RowTypeInfo, - inputRowType: RowTypeInfo, - precedingOffset: Long) - extends ProcessFunction[Row, Row] - with Compiler[GeneratedAggregations] { - - Preconditions.checkNotNull(aggregationStateType) - Preconditions.checkNotNull(precedingOffset) - - private var output: Row = _ - - // the state which keeps the last triggering timestamp - private var lastTriggeringTsState: ValueState[Long] = _ - - // the state which used to materialize the accumulator for incremental calculation - private var accumulatorState: 
ValueState[Row] = _ - - // the state which keeps all the data that are not expired. - // The first element (as the mapState key) of the tuple is the time stamp. Per each time stamp, - // the second element of tuple is a list that contains the entire data of all the rows belonging - // to this time stamp. - private var dataState: MapState[Long, JList[Row]] = _ - - val LOG = LoggerFactory.getLogger(this.getClass) - private var function: GeneratedAggregations = _ - - override def open(config: Configuration) { - LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " + - s"Code:\n$genAggregations.code") - val clazz = compile( - getRuntimeContext.getUserCodeClassLoader, - genAggregations.name, - genAggregations.code) - LOG.debug("Instantiating AggregateHelper.") - function = clazz.newInstance() - - output = function.createOutputRow() - - val lastTriggeringTsDescriptor: ValueStateDescriptor[Long] = - new ValueStateDescriptor[Long]("lastTriggeringTsState", classOf[Long]) - lastTriggeringTsState = getRuntimeContext.getState(lastTriggeringTsDescriptor) - - val accumulatorStateDescriptor = - new ValueStateDescriptor[Row]("accumulatorState", aggregationStateType) - accumulatorState = getRuntimeContext.getState(accumulatorStateDescriptor) - - val keyTypeInformation: TypeInformation[Long] = - BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]] - val valueTypeInformation: TypeInformation[JList[Row]] = new ListTypeInfo[Row](inputRowType) - - val mapStateDescriptor: MapStateDescriptor[Long, JList[Row]] = - new MapStateDescriptor[Long, JList[Row]]( - "dataState", - keyTypeInformation, - valueTypeInformation) - - dataState = getRuntimeContext.getMapState(mapStateDescriptor) - } - - override def processElement( - input: Row, - ctx: ProcessFunction[Row, Row]#Context, - out: Collector[Row]): Unit = { - - // triggering timestamp for trigger calculation - val triggeringTs = ctx.timestamp - - val lastTriggeringTs = lastTriggeringTsState.value - - // check if the 
data is expired, if not, save the data and register event time timer - if (triggeringTs > lastTriggeringTs) { - val data = dataState.get(triggeringTs) - if (null != data) { - data.add(input) - dataState.put(triggeringTs, data) - } else { - val data = new JArrayList[Row] - data.add(input) - dataState.put(triggeringTs, data) - // register event time timer - ctx.timerService.registerEventTimeTimer(triggeringTs) - } - } - } - - override def onTimer( - timestamp: Long, - ctx: ProcessFunction[Row, Row]#OnTimerContext, - out: Collector[Row]): Unit = { - // gets all window data from state for the calculation - val inputs: JList[Row] = dataState.get(timestamp) - - if (null != inputs) { - - var accumulators = accumulatorState.value - var dataListIndex = 0 - var aggregatesIndex = 0 - - // initialize when first run or failover recovery per key - if (null == accumulators) { - accumulators = function.createAccumulators() - aggregatesIndex = 0 - } - - // keep up timestamps of retract data - val retractTsList: JList[Long] = new JArrayList[Long] - - // do retraction - val dataTimestampIt = dataState.keys.iterator - while (dataTimestampIt.hasNext) { - val dataTs: Long = dataTimestampIt.next() - val offset = timestamp - dataTs - if (offset > precedingOffset) { - val retractDataList = dataState.get(dataTs) - dataListIndex = 0 - while (dataListIndex < retractDataList.size()) { - val retractRow = retractDataList.get(dataListIndex) - function.retract(accumulators, retractRow) - dataListIndex += 1 - } - retractTsList.add(dataTs) - } - } - - // do accumulation - dataListIndex = 0 - while (dataListIndex < inputs.size()) { - val curRow = inputs.get(dataListIndex) - // accumulate current row - function.accumulate(accumulators, curRow) - dataListIndex += 1 - } - - // set aggregate in output row - function.setAggregationResults(accumulators, output) - - // copy forwarded fields to output row and emit output row - dataListIndex = 0 - while (dataListIndex < inputs.size()) { - aggregatesIndex = 0 
- function.setForwardedFields(inputs.get(dataListIndex), output) - out.collect(output) - dataListIndex += 1 - } - - // remove the data that has been retracted - dataListIndex = 0 - while (dataListIndex < retractTsList.size) { - dataState.remove(retractTsList.get(dataListIndex)) - dataListIndex += 1 - } - - // update state - accumulatorState.update(accumulators) - lastTriggeringTsState.update(timestamp) - } - } -} - - http://git-wip-us.apache.org/repos/asf/flink/blob/07f1b035/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRangeOver.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRangeOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRangeOver.scala new file mode 100644 index 0000000..03ca02c --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRangeOver.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.flink.table.runtime.aggregate
+
+import java.util.{List => JList, ArrayList => JArrayList}
+
+import org.apache.flink.api.common.state._
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
+import org.apache.flink.api.java.typeutils.{ListTypeInfo, RowTypeInfo}
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.streaming.api.functions.ProcessFunction
+import org.apache.flink.table.codegen.{GeneratedAggregationsFunction, Compiler}
+import org.apache.flink.types.Row
+import org.apache.flink.util.{Collector, Preconditions}
+import org.slf4j.LoggerFactory
+
+/**
+ * Process function for event-time bounded OVER windows with a RANGE clause.
+ *
+ * @param genAggregations Generated aggregate helper function
+ * @param aggregationStateType row type info of the accumulators row
+ * @param inputRowType row type info of input row
+ * @param precedingOffset width of the RANGE frame: rows older than (timestamp - precedingOffset) are retracted
+ */
+class RowTimeBoundedRangeOver(
+    genAggregations: GeneratedAggregationsFunction,
+    aggregationStateType: RowTypeInfo,
+    inputRowType: RowTypeInfo,
+    precedingOffset: Long)
+  extends ProcessFunction[Row, Row]
+    with Compiler[GeneratedAggregations] {
+  Preconditions.checkNotNull(aggregationStateType)
+  Preconditions.checkNotNull(precedingOffset)
+
+  private var output: Row = _
+
+  // the state which keeps the last timestamp whose timer fired; rows not newer than it are dropped
+  private var lastTriggeringTsState: ValueState[Long] = _
+
+  // the state used to materialize the accumulators for incremental calculation
+  private var accumulatorState: ValueState[Row] = _
+
+  // the state which keeps all rows that have not been retracted yet, grouped by timestamp:
+  // the map key is the row's event timestamp and the value is the list of all rows that
+  // arrived with that timestamp, in arrival order. A timestamp's entry is removed once its
+  // rows are retracted from the accumulators.
+  private var dataState: MapState[Long, JList[Row]] = _
+
+  val LOG = LoggerFactory.getLogger(this.getClass)
+  private var function: GeneratedAggregations = _
+
+  override def open(config: Configuration): Unit = {
+    LOG.debug(s"Compiling AggregateHelper: ${genAggregations.name} \n\n " +
+      s"Code:\n${genAggregations.code}")
+    val clazz = compile(
+      getRuntimeContext.getUserCodeClassLoader,
+      genAggregations.name,
+      genAggregations.code)
+    LOG.debug("Instantiating AggregateHelper.")
+    function = clazz.newInstance()
+
+    output = function.createOutputRow()
+
+    val lastTriggeringTsDescriptor: ValueStateDescriptor[Long] =
+      new ValueStateDescriptor[Long]("lastTriggeringTsState", classOf[Long])
+    lastTriggeringTsState = getRuntimeContext.getState(lastTriggeringTsDescriptor)
+
+    val accumulatorStateDescriptor =
+      new ValueStateDescriptor[Row]("accumulatorState", aggregationStateType)
+    accumulatorState = getRuntimeContext.getState(accumulatorStateDescriptor)
+
+    val keyTypeInformation: TypeInformation[Long] =
+      BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]]
+    val valueTypeInformation: TypeInformation[JList[Row]] = new ListTypeInfo[Row](inputRowType)
+
+    val mapStateDescriptor: MapStateDescriptor[Long, JList[Row]] =
+      new MapStateDescriptor[Long, JList[Row]](
+        "dataState",
+        keyTypeInformation,
+        valueTypeInformation)
+
+    dataState = getRuntimeContext.getMapState(mapStateDescriptor)
+  }
+
+  override def processElement(
+    input: Row,
+    ctx: ProcessFunction[Row, Row]#Context,
+    out: Collector[Row]): Unit = {
+
+    // triggering timestamp for trigger calculation
+    val triggeringTs = ctx.timestamp
+
+    val lastTriggeringTs = lastTriggeringTsState.value
+
+    // drop rows not newer than the last fired timestamp; otherwise buffer the row in dataState
+    if (triggeringTs > lastTriggeringTs) {
+      val data = dataState.get(triggeringTs)
+      if (null != data) {
+        data.add(input)
+        dataState.put(triggeringTs, data)
+      } else {
+        val data = new JArrayList[Row]
+        data.add(input)
+        dataState.put(triggeringTs, data)
+        // first row for this timestamp: register one event-time timer for it
+        ctx.timerService.registerEventTimeTimer(triggeringTs)
+      }
+    }
+  }
+
+  override def onTimer(
+    timestamp: Long,
+    ctx: ProcessFunction[Row, Row]#OnTimerContext,
+    out: Collector[Row]): Unit = {
+    // rows buffered for exactly this timestamp; they are accumulated and emitted now
+    val inputs: JList[Row] = dataState.get(timestamp)
+
+    if (null != inputs) {
+
+      var accumulators = accumulatorState.value
+      var dataListIndex = 0
+      var aggregatesIndex = 0
+
+      // initialize when first run or failover recovery per key
+      if (null == accumulators) {
+        accumulators = function.createAccumulators()
+        aggregatesIndex = 0
+      }
+
+      // timestamps whose rows get retracted below; removed from state after emission
+      val retractTsList: JList[Long] = new JArrayList[Long]
+
+      // retract every buffered row more than precedingOffset older than this timestamp
+      val dataTimestampIt = dataState.keys.iterator
+      while (dataTimestampIt.hasNext) {
+        val dataTs: Long = dataTimestampIt.next()
+        val offset = timestamp - dataTs
+        if (offset > precedingOffset) {
+          val retractDataList = dataState.get(dataTs)
+          dataListIndex = 0
+          while (dataListIndex < retractDataList.size()) {
+            val retractRow = retractDataList.get(dataListIndex)
+            function.retract(accumulators, retractRow)
+            dataListIndex += 1
+          }
+          retractTsList.add(dataTs)
+        }
+      }
+
+      // accumulate all rows of the current timestamp
+      dataListIndex = 0
+      while (dataListIndex < inputs.size()) {
+        val curRow = inputs.get(dataListIndex)
+        // accumulate current row
+        function.accumulate(accumulators, curRow)
+        dataListIndex += 1
+      }
+
+      // all rows of this timestamp share the same aggregate results
+      function.setAggregationResults(accumulators, output)
+
+      // copy forwarded fields to output row and emit output row
+      dataListIndex = 0
+      while (dataListIndex < inputs.size()) {
+        aggregatesIndex = 0
+        function.setForwardedFields(inputs.get(dataListIndex), output)
+        out.collect(output)
+        dataListIndex += 1
+      }
+
+      // remove the data that has been retracted
+      dataListIndex = 0
+      while (dataListIndex < retractTsList.size) {
+        dataState.remove(retractTsList.get(dataListIndex))
+        dataListIndex += 1
+      }
+
+      // update state
+      accumulatorState.update(accumulators)
+      lastTriggeringTsState.update(timestamp)
+    }
+  }
+}
+
+

http://git-wip-us.apache.org/repos/asf/flink/blob/07f1b035/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRowsOver.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRowsOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRowsOver.scala
new file mode 100644
index 0000000..4a9a14c
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRowsOver.scala
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.runtime.aggregate
+
+import java.util
+import java.util.{List => JList}
+
+import org.apache.flink.api.common.state._
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
+import org.apache.flink.api.java.typeutils.{ListTypeInfo, RowTypeInfo}
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.streaming.api.functions.ProcessFunction
+import org.apache.flink.types.Row
+import org.apache.flink.util.{Collector, Preconditions}
+import org.apache.flink.table.codegen.{GeneratedAggregationsFunction, Compiler}
+import org.slf4j.LoggerFactory
+
+/**
+ * Process function for event-time bounded OVER windows with a ROWS clause.
+ *
+ * @param genAggregations Generated aggregate helper function
+ * @param aggregationStateType row type info of the accumulators row
+ * @param inputRowType row type info of input row
+ * @param precedingOffset number of rows kept in the frame; once reached, the oldest row is retracted per new row
+ */
+class RowTimeBoundedRowsOver(
+    genAggregations: GeneratedAggregationsFunction,
+    aggregationStateType: RowTypeInfo,
+    inputRowType: RowTypeInfo,
+    precedingOffset: Long)
+  extends ProcessFunction[Row, Row]
+    with Compiler[GeneratedAggregations] {
+
+  Preconditions.checkNotNull(aggregationStateType)
+  Preconditions.checkNotNull(precedingOffset)
+
+  private var output: Row = _
+
+  // the state which keeps the last timestamp whose timer fired; rows not newer than it are dropped
+  private var lastTriggeringTsState: ValueState[Long] = _
+
+  // the state which keeps the number of accumulated rows (never grows beyond precedingOffset)
+  private var dataCountState: ValueState[Long] = _
+
+  // the state used to materialize the accumulators for incremental calculation
+  private var accumulatorState: ValueState[Row] = _
+
+  // the state which keeps all rows that have not been retracted yet, grouped by timestamp:
+  // the map key is the row's event timestamp and the value is the list of all rows that
+  // arrived with that timestamp, in arrival order. Retracted rows are removed from the
+  // front of the oldest list; an emptied list is removed entirely.
+  private var dataState: MapState[Long, JList[Row]] = _
+
+  val LOG = LoggerFactory.getLogger(this.getClass)
+  private var function: GeneratedAggregations = _
+
+  override def open(config: Configuration): Unit = {
+    LOG.debug(s"Compiling AggregateHelper: ${genAggregations.name} \n\n " +
+      s"Code:\n${genAggregations.code}")
+    val clazz = compile(
+      getRuntimeContext.getUserCodeClassLoader,
+      genAggregations.name,
+      genAggregations.code)
+    LOG.debug("Instantiating AggregateHelper.")
+    function = clazz.newInstance()
+
+    output = function.createOutputRow()
+
+    val lastTriggeringTsDescriptor: ValueStateDescriptor[Long] =
+      new ValueStateDescriptor[Long]("lastTriggeringTsState", classOf[Long])
+    lastTriggeringTsState = getRuntimeContext.getState(lastTriggeringTsDescriptor)
+
+    val dataCountStateDescriptor =
+      new ValueStateDescriptor[Long]("dataCountState", classOf[Long])
+    dataCountState = getRuntimeContext.getState(dataCountStateDescriptor)
+
+    val accumulatorStateDescriptor =
+      new ValueStateDescriptor[Row]("accumulatorState", aggregationStateType)
+    accumulatorState = getRuntimeContext.getState(accumulatorStateDescriptor)
+
+    val keyTypeInformation: TypeInformation[Long] =
+      BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]]
+    val valueTypeInformation: TypeInformation[JList[Row]] = new ListTypeInfo[Row](inputRowType)
+
+    val mapStateDescriptor: MapStateDescriptor[Long, JList[Row]] =
+      new MapStateDescriptor[Long, JList[Row]](
+        "dataState",
+        keyTypeInformation,
+        valueTypeInformation)
+
+    dataState = getRuntimeContext.getMapState(mapStateDescriptor)
+  }
+
+  override def processElement(
+    input: Row,
+    ctx: ProcessFunction[Row, Row]#Context,
+    out: Collector[Row]): Unit = {
+
+    // triggering timestamp for trigger calculation
+    val triggeringTs = ctx.timestamp
+
+    val lastTriggeringTs = lastTriggeringTsState.value
+    // drop rows not newer than the last fired timestamp; otherwise buffer the row in dataState
+
+    if (triggeringTs > lastTriggeringTs) {
+      val data = dataState.get(triggeringTs)
+      if (null != data) {
+        data.add(input)
+        dataState.put(triggeringTs, data)
+      } else {
+        val data = new util.ArrayList[Row]
+        data.add(input)
+        dataState.put(triggeringTs, data)
+        // first row for this timestamp: register one event-time timer for it
+        ctx.timerService.registerEventTimeTimer(triggeringTs)
+      }
+    }
+  }
+
+  override def onTimer(
+    timestamp: Long,
+    ctx: ProcessFunction[Row, Row]#OnTimerContext,
+    out: Collector[Row]): Unit = {
+
+    // rows buffered for exactly this timestamp; each one is emitted with its own aggregate
+    val inputs: JList[Row] = dataState.get(timestamp)
+
+    if (null != inputs) {
+
+      var accumulators = accumulatorState.value
+      var dataCount = dataCountState.value
+
+      var retractList: JList[Row] = null
+      var retractTs: Long = Long.MaxValue
+      var retractCnt: Int = 0
+      var i = 0
+
+      while (i < inputs.size) {
+        val input = inputs.get(i)
+
+        // initialize when first run or failover recovery per key
+        if (null == accumulators) {
+          accumulators = function.createAccumulators()
+        }
+
+        var retractRow: Row = null
+
+        if (dataCount >= precedingOffset) {
+          if (null == retractList) {
+            // find the oldest buffered timestamp
+            retractTs = Long.MaxValue
+            val dataTimestampIt = dataState.keys.iterator
+            while (dataTimestampIt.hasNext) {
+              val dataTs = dataTimestampIt.next
+              if (dataTs < retractTs) {
+                retractTs = dataTs
+              }
+            }
+            // get the oldest rows to retract them
+            retractList = dataState.get(retractTs)
+          }
+
+          retractRow = retractList.get(retractCnt)
+          retractCnt += 1
+
+          // the oldest list is fully retracted: drop it from state
+          if (retractList.size == retractCnt) {
+            dataState.remove(retractTs)
+            retractList = null
+            retractCnt = 0
+          }
+        } else {
+          dataCount += 1
+        }
+
+        // copy forwarded fields to output row
+        function.setForwardedFields(input, output)
+
+        // retract old row from accumulators
+        if (null != retractRow) {
+          function.retract(accumulators, retractRow)
+        }
+
+        // accumulate current row and set aggregate in output row
+        function.accumulate(accumulators, input)
+        function.setAggregationResults(accumulators, output)
+        i += 1
+
+        out.collect(output)
+      }
+
+      // drop the retracted prefix of the partially consumed oldest list, then persist counters
+      if (dataState.contains(retractTs)) {
+        if (retractCnt > 0) {
+          retractList.subList(0, retractCnt).clear()
+          dataState.put(retractTs, retractList)
+        }
+      }
+      dataCountState.update(dataCount)
+      accumulatorState.update(accumulators)
+    }
+
+    lastTriggeringTsState.update(timestamp)
+  }
+}
+
+