[FLINK-5990] [table] Add event-time OVER ROWS BETWEEN x PRECEDING aggregation to SQL.
This closes #3585. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/7a9d39fe Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/7a9d39fe Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/7a9d39fe Branch: refs/heads/master Commit: 7a9d39fe9f659d43bf4719a2981f6c4771ffbe48 Parents: 6949c8c Author: 金竹 <jincheng.su...@alibaba-inc.com> Authored: Sun Mar 19 23:31:00 2017 +0800 Committer: Fabian Hueske <fhue...@apache.org> Committed: Fri Mar 24 20:19:17 2017 +0100 ---------------------------------------------------------------------- .../flink/table/plan/nodes/OverAggregate.scala | 31 ++- .../datastream/DataStreamOverAggregate.scala | 149 +++++++++--- .../table/runtime/aggregate/AggregateUtil.scala | 48 +++- .../RowsClauseBoundedOverProcessFunction.scala | 239 +++++++++++++++++++ .../table/api/scala/stream/sql/SqlITCase.scala | 139 ++++++++++- .../scala/stream/sql/WindowAggregateTest.scala | 55 +++++ 6 files changed, 623 insertions(+), 38 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/7a9d39fe/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala index 793ab23..91c8cef 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala @@ -18,12 +18,15 @@ package org.apache.flink.table.plan.nodes -import org.apache.calcite.rel.RelFieldCollation +import org.apache.calcite.rel.{RelFieldCollation, RelNode} import 
org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFieldImpl} import org.apache.calcite.rel.core.AggregateCall import org.apache.calcite.rel.core.Window.Group +import org.apache.calcite.rel.core.Window +import org.apache.calcite.rex.{RexInputRef} import org.apache.flink.table.runtime.aggregate.AggregateUtil._ import org.apache.flink.table.functions.{ProcTimeType, RowTimeType} + import scala.collection.JavaConverters._ trait OverAggregate { @@ -46,8 +49,16 @@ trait OverAggregate { orderingString } - private[flink] def windowRange(overWindow: Group): String = { - s"BETWEEN ${overWindow.lowerBound} AND ${overWindow.upperBound}" + private[flink] def windowRange( + logicWindow: Window, + overWindow: Group, + input: RelNode): String = { + if (overWindow.lowerBound.isPreceding && !overWindow.lowerBound.isUnbounded) { + s"BETWEEN ${getLowerBoundary(logicWindow, overWindow, input)} PRECEDING " + + s"AND ${overWindow.upperBound}" + } else { + s"BETWEEN ${overWindow.lowerBound} AND ${overWindow.upperBound}" + } } private[flink] def aggregationToString( @@ -92,4 +103,18 @@ trait OverAggregate { }.mkString(", ") } + private[flink] def getLowerBoundary( + logicWindow: Window, + overWindow: Group, + input: RelNode): Long = { + + val ref: RexInputRef = overWindow.lowerBound.getOffset.asInstanceOf[RexInputRef] + val lowerBoundIndex = input.getRowType.getFieldCount - ref.getIndex; + val lowerBound = logicWindow.constants.get(lowerBoundIndex).getValue2 + lowerBound match { + case x: java.math.BigDecimal => x.asInstanceOf[java.math.BigDecimal].longValue() + case _ => lowerBound.asInstanceOf[Long] + } + } + } http://git-wip-us.apache.org/repos/asf/flink/blob/7a9d39fe/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala index 34b3b0f..547c875 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala @@ -32,6 +32,7 @@ import org.apache.calcite.rel.core.Window import org.apache.calcite.rel.core.Window.Group import java.util.{List => JList} +import org.apache.flink.api.java.functions.NullByteKeySelector import org.apache.flink.table.functions.{ProcTimeType, RowTimeType} import org.apache.flink.table.runtime.aggregate.AggregateUtil.CalcitePair @@ -70,9 +71,9 @@ class DataStreamOverAggregate( super.explainTerms(pw) .itemIf("partitionBy", partitionToString(inputType, partitionKeys), partitionKeys.nonEmpty) - .item("orderBy",orderingToString(inputType, overWindow.orderKeys.getFieldCollations)) - .itemIf("rows", windowRange(overWindow), overWindow.isRows) - .itemIf("range", windowRange(overWindow), !overWindow.isRows) + .item("orderBy",orderingToString(inputType, overWindow.orderKeys.getFieldCollations)) + .itemIf("rows", windowRange(logicWindow, overWindow, getInput), overWindow.isRows) + .itemIf("range", windowRange(logicWindow, overWindow, getInput), !overWindow.isRows) .item( "select", aggregationToString( inputType, @@ -99,20 +100,58 @@ class DataStreamOverAggregate( .getFieldList .get(overWindow.orderKeys.getFieldCollations.get(0).getFieldIndex) .getValue - timeType match { case _: ProcTimeType => - // both ROWS and RANGE clause with UNBOUNDED PRECEDING and CURRENT ROW condition. 
- if (overWindow.lowerBound.isUnbounded && - overWindow.upperBound.isCurrentRow) { + // proc-time OVER window + if (overWindow.lowerBound.isUnbounded && overWindow.upperBound.isCurrentRow) { + // non-bounded OVER window createUnboundedAndCurrentRowProcessingTimeOverWindow(inputDS) + } else if ( + overWindow.lowerBound.isPreceding && !overWindow.lowerBound.isUnbounded && + overWindow.upperBound.isCurrentRow) { + // bounded OVER window + if (overWindow.isRows) { + // ROWS clause bounded OVER window + throw new TableException( + "ROWS clause bounded proc-time OVER window no supported yet.") + } else { + // RANGE clause bounded OVER window + throw new TableException( + "RANGE clause bounded proc-time OVER window no supported yet.") + } } else { throw new TableException( - "OVER window only support ProcessingTime UNBOUNDED PRECEDING and CURRENT ROW " + - "condition.") + "OVER window only support ProcessingTime UNBOUNDED PRECEDING and CURRENT ROW " + + "condition.") } case _: RowTimeType => - throw new TableException("OVER Window of the EventTime type is not currently supported.") + // row-time OVER window + if (overWindow.lowerBound.isUnbounded && overWindow.upperBound.isCurrentRow) { + // non-bounded OVER window + if (overWindow.isRows) { + // ROWS clause unbounded OVER window + throw new TableException( + "ROWS clause unbounded row-time OVER window no supported yet.") + } else { + // RANGE clause unbounded OVER window + throw new TableException( + "RANGE clause unbounded row-time OVER window no supported yet.") + } + } else if (overWindow.lowerBound.isPreceding && !overWindow.lowerBound.isUnbounded && + overWindow.upperBound.isCurrentRow) { + // bounded OVER window + if (overWindow.isRows) { + // ROWS clause bounded OVER window + createRowsClauseBoundedAndCurrentRowOverWindow(inputDS, true) + } else { + // RANGE clause bounded OVER window + throw new TableException( + "RANGE clause bounded row-time OVER window no supported yet.") + } + } else { + throw new 
TableException( + "row-time OVER window only support CURRENT ROW condition.") + } case _ => throw new TableException(s"Unsupported time type {$timeType}") } @@ -120,7 +159,7 @@ class DataStreamOverAggregate( } def createUnboundedAndCurrentRowProcessingTimeOverWindow( - inputDS: DataStream[Row]): DataStream[Row] = { + inputDS: DataStream[Row]): DataStream[Row] = { val overWindow: Group = logicWindow.groups.get(0) val partitionKeys: Array[Int] = overWindow.keys.toArray @@ -130,32 +169,78 @@ class DataStreamOverAggregate( val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo] val result: DataStream[Row] = - // partitioned aggregation - if (partitionKeys.nonEmpty) { - val processFunction = AggregateUtil.CreateUnboundedProcessingOverProcessFunction( - namedAggregates, - inputType) + // partitioned aggregation + if (partitionKeys.nonEmpty) { + val processFunction = AggregateUtil.createUnboundedProcessingOverProcessFunction( + namedAggregates, + inputType) - inputDS + inputDS .keyBy(partitionKeys: _*) .process(processFunction) .returns(rowTypeInfo) .name(aggOpName) .asInstanceOf[DataStream[Row]] - } - // non-partitioned aggregation - else { - val processFunction = AggregateUtil.CreateUnboundedProcessingOverProcessFunction( - namedAggregates, - inputType, - false) - - inputDS - .process(processFunction).setParallelism(1).setMaxParallelism(1) - .returns(rowTypeInfo) - .name(aggOpName) - .asInstanceOf[DataStream[Row]] - } + } + // non-partitioned aggregation + else { + val processFunction = AggregateUtil.createUnboundedProcessingOverProcessFunction( + namedAggregates, + inputType, + false) + + inputDS + .process(processFunction).setParallelism(1).setMaxParallelism(1) + .returns(rowTypeInfo) + .name(aggOpName) + .asInstanceOf[DataStream[Row]] + } + result + } + + def createRowsClauseBoundedAndCurrentRowOverWindow( + inputDS: DataStream[Row], + isRowTimeType: Boolean = false): DataStream[Row] = { + + val overWindow: Group = 
logicWindow.groups.get(0) + val partitionKeys: Array[Int] = overWindow.keys.toArray + val namedAggregates: Seq[CalcitePair[AggregateCall, String]] = generateNamedAggregates + val inputFields = (0 until inputType.getFieldCount).toArray + + val precedingOffset = + getLowerBoundary(logicWindow, overWindow, getInput()) + 1 + + // get the output types + val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo] + + val processFunction = AggregateUtil.createRowsClauseBoundedOverProcessFunction( + namedAggregates, + inputType, + inputFields, + precedingOffset, + isRowTimeType + ) + val result: DataStream[Row] = + // partitioned aggregation + if (partitionKeys.nonEmpty) { + inputDS + .keyBy(partitionKeys: _*) + .process(processFunction) + .returns(rowTypeInfo) + .name(aggOpName) + .asInstanceOf[DataStream[Row]] + } + // non-partitioned aggregation + else { + inputDS + .keyBy(new NullByteKeySelector[Row]) + .process(processFunction) + .setParallelism(1) + .setMaxParallelism(1) + .returns(rowTypeInfo) + .name(aggOpName) + .asInstanceOf[DataStream[Row]] + } result } @@ -180,7 +265,7 @@ class DataStreamOverAggregate( } }ORDER BY: ${orderingToString(inputType, overWindow.orderKeys.getFieldCollations)}, " + s"${if (overWindow.isRows) "ROWS" else "RANGE"}" + - s"${windowRange(overWindow)}, " + + s"${windowRange(logicWindow, overWindow, getInput)}, " + s"select: (${ aggregationToString( inputType, http://git-wip-us.apache.org/repos/asf/flink/blob/7a9d39fe/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala index 9feec17..0084ee5 100644 --- 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala @@ -61,7 +61,7 @@ object AggregateUtil { * @param isPartitioned Flag to indicate whether the input is partitioned or not * @return [[org.apache.flink.streaming.api.functions.ProcessFunction]] */ - private[flink] def CreateUnboundedProcessingOverProcessFunction( + private[flink] def createUnboundedProcessingOverProcessFunction( namedAggregates: Seq[CalcitePair[AggregateCall, String]], inputType: RelDataType, isPartitioned: Boolean = true): ProcessFunction[Row, Row] = { @@ -91,6 +91,52 @@ object AggregateUtil { } /** + * Create an [[org.apache.flink.streaming.api.functions.ProcessFunction]] for ROWS clause + * bounded OVER window to evaluate final aggregate value. + * + * @param namedAggregates List of calls to aggregate functions and their output field names + * @param inputType Input row type + * @param inputFields All input fields + * @param precedingOffset the preceding offset + * @param isRowTimeType It is a tag that indicates whether the time type is rowTimeType + * @return [[org.apache.flink.streaming.api.functions.ProcessFunction]] + */ + private[flink] def createRowsClauseBoundedOverProcessFunction( + namedAggregates: Seq[CalcitePair[AggregateCall, String]], + inputType: RelDataType, + inputFields: Array[Int], + precedingOffset: Long, + isRowTimeType: Boolean): ProcessFunction[Row, Row] = { + + val (aggFields, aggregates) = + transformToAggregateFunctions( + namedAggregates.map(_.getKey), + inputType, + needRetraction = true) + + val aggregationStateType: RowTypeInfo = + createDataSetAggregateBufferDataType(Array(), aggregates, inputType) + + val inputRowType: RowTypeInfo = + createDataSetAggregateBufferDataType(inputFields, Array(), inputType) + + val processFunction = if (isRowTimeType) { + new RowsClauseBoundedOverProcessFunction( + 
aggregates, + aggFields, + inputType.getFieldCount, + aggregationStateType, + inputRowType, + precedingOffset + ) + } else { + throw TableException( + "Bounded partitioned proc-time OVER aggregation is not supported yet.") + } + processFunction + } + + /** * Create a [[org.apache.flink.api.common.functions.MapFunction]] that prepares for aggregates. * The output of the function contains the grouping keys and the timestamp and the intermediate * aggregate values of all aggregate function. The timestamp field is aligned to time window http://git-wip-us.apache.org/repos/asf/flink/blob/7a9d39fe/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowsClauseBoundedOverProcessFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowsClauseBoundedOverProcessFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowsClauseBoundedOverProcessFunction.scala new file mode 100644 index 0000000..1678d57 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowsClauseBoundedOverProcessFunction.scala @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.runtime.aggregate + +import java.util +import java.util.{List => JList} + +import org.apache.flink.api.common.state._ +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} +import org.apache.flink.api.java.typeutils.{ListTypeInfo, RowTypeInfo} +import org.apache.flink.configuration.Configuration +import org.apache.flink.streaming.api.functions.ProcessFunction +import org.apache.flink.table.functions.{Accumulator, AggregateFunction} +import org.apache.flink.types.Row +import org.apache.flink.util.{Collector, Preconditions} + +/** + * Process Function for ROWS clause event-time bounded OVER window + * + * @param aggregates the list of all [[AggregateFunction]] used for this aggregation + * @param aggFields the position (in the input Row) of the input value for each aggregate + * @param forwardedFieldCount the count of forwarded fields. 
+ * @param aggregationStateType the row type info of aggregation + * @param inputRowType the row type info of input row + * @param precedingOffset the preceding offset + */ +class RowsClauseBoundedOverProcessFunction( + private val aggregates: Array[AggregateFunction[_]], + private val aggFields: Array[Int], + private val forwardedFieldCount: Int, + private val aggregationStateType: RowTypeInfo, + private val inputRowType: RowTypeInfo, + private val precedingOffset: Long) + extends ProcessFunction[Row, Row] { + + Preconditions.checkNotNull(aggregates) + Preconditions.checkNotNull(aggFields) + Preconditions.checkArgument(aggregates.length == aggFields.length) + Preconditions.checkNotNull(forwardedFieldCount) + Preconditions.checkNotNull(aggregationStateType) + Preconditions.checkNotNull(precedingOffset) + + private var output: Row = _ + + // the state which keeps the last triggering timestamp + private var lastTriggeringTsState: ValueState[Long] = _ + + // the state which keeps the count of data + private var dataCountState: ValueState[Long] = _ + + // the state which used to materialize the accumulator for incremental calculation + private var accumulatorState: ValueState[Row] = _ + + // the state which keeps all the data that are not expired. + // The first element (as the mapState key) of the tuple is the time stamp. Per each time stamp, + // the second element of tuple is a list that contains the entire data of all the rows belonging + // to this time stamp. 
+ private var dataState: MapState[Long, JList[Row]] = _ + + override def open(config: Configuration) { + + output = new Row(forwardedFieldCount + aggregates.length) + + val lastTriggeringTsDescriptor: ValueStateDescriptor[Long] = + new ValueStateDescriptor[Long]("lastTriggeringTsState", classOf[Long]) + lastTriggeringTsState = getRuntimeContext.getState(lastTriggeringTsDescriptor) + + val dataCountStateDescriptor = + new ValueStateDescriptor[Long]("dataCountState", classOf[Long]) + dataCountState = getRuntimeContext.getState(dataCountStateDescriptor) + + val accumulatorStateDescriptor = + new ValueStateDescriptor[Row]("accumulatorState", aggregationStateType) + accumulatorState = getRuntimeContext.getState(accumulatorStateDescriptor) + + val keyTypeInformation: TypeInformation[Long] = + BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]] + val valueTypeInformation: TypeInformation[JList[Row]] = new ListTypeInfo[Row](inputRowType) + + val mapStateDescriptor: MapStateDescriptor[Long, JList[Row]] = + new MapStateDescriptor[Long, JList[Row]]( + "dataState", + keyTypeInformation, + valueTypeInformation) + + dataState = getRuntimeContext.getMapState(mapStateDescriptor) + + } + + override def processElement( + input: Row, + ctx: ProcessFunction[Row, Row]#Context, + out: Collector[Row]): Unit = { + + // triggering timestamp for trigger calculation + val triggeringTs = ctx.timestamp + + val lastTriggeringTs = lastTriggeringTsState.value + // check if the data is expired, if not, save the data and register event time timer + + if (triggeringTs > lastTriggeringTs) { + val data = dataState.get(triggeringTs) + if (null != data) { + data.add(input) + dataState.put(triggeringTs, data) + } else { + val data = new util.ArrayList[Row] + data.add(input) + dataState.put(triggeringTs, data) + // register event time timer + ctx.timerService.registerEventTimeTimer(triggeringTs) + } + } + } + + override def onTimer( + timestamp: Long, + ctx: ProcessFunction[Row, 
Row]#OnTimerContext, + out: Collector[Row]): Unit = { + + // gets all window data from state for the calculation + val inputs: JList[Row] = dataState.get(timestamp) + + if (null != inputs) { + + var accumulators = accumulatorState.value + var dataCount = dataCountState.value + + var retractList: JList[Row] = null + var retractTs: Long = Long.MaxValue + var retractCnt: Int = 0 + var j = 0 + var i = 0 + + while (j < inputs.size) { + val input = inputs.get(j) + + // initialize when first run or failover recovery per key + if (null == accumulators) { + accumulators = new Row(aggregates.length) + i = 0 + while (i < aggregates.length) { + accumulators.setField(i, aggregates(i).createAccumulator()) + i += 1 + } + } + + var retractRow: Row = null + + if (dataCount >= precedingOffset) { + if (null == retractList) { + // find the smallest timestamp + retractTs = Long.MaxValue + val dataTimestampIt = dataState.keys.iterator + while (dataTimestampIt.hasNext) { + val dataTs = dataTimestampIt.next + if (dataTs < retractTs) { + retractTs = dataTs + } + } + // get the oldest rows to retract them + retractList = dataState.get(retractTs) + } + + retractRow = retractList.get(retractCnt) + retractCnt += 1 + + // remove retracted values from state + if (retractList.size == retractCnt) { + dataState.remove(retractTs) + retractList = null + retractCnt = 0 + } + } else { + dataCount += 1 + } + + // copy forwarded fields to output row + i = 0 + while (i < forwardedFieldCount) { + output.setField(i, input.getField(i)) + i += 1 + } + + // retract old row from accumulators + if (null != retractRow) { + i = 0 + while (i < aggregates.length) { + val accumulator = accumulators.getField(i).asInstanceOf[Accumulator] + aggregates(i).retract(accumulator, retractRow.getField(aggFields(i))) + i += 1 + } + } + + // accumulate current row and set aggregate in output row + i = 0 + while (i < aggregates.length) { + val index = forwardedFieldCount + i + val accumulator = 
accumulators.getField(i).asInstanceOf[Accumulator] + aggregates(i).accumulate(accumulator, input.getField(aggFields(i))) + output.setField(index, aggregates(i).getValue(accumulator)) + i += 1 + } + j += 1 + + out.collect(output) + } + + // update all states + if (dataState.contains(retractTs)) { + if (retractCnt > 0) { + retractList.subList(0, retractCnt).clear() + dataState.put(retractTs, retractList) + } + } + dataCountState.update(dataCount) + accumulatorState.update(accumulators) + } + + lastTriggeringTsState.update(timestamp) + } +} + + http://git-wip-us.apache.org/repos/asf/flink/blob/7a9d39fe/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala index d5a140a..19350a7 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala @@ -19,14 +19,18 @@ package org.apache.flink.table.api.scala.stream.sql import org.apache.flink.api.scala._ +import org.apache.flink.streaming.api.functions.source.SourceFunction +import org.apache.flink.table.api.scala.stream.sql.SqlITCase.EventTimeSourceFunction import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment +import org.apache.flink.streaming.api.watermark.Watermark import org.apache.flink.table.api.{TableEnvironment, TableException} import org.apache.flink.table.api.scala._ -import org.apache.flink.table.api.scala.stream.utils.{StreamingWithStateTestBase, StreamITCase, -StreamTestData} +import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData, StreamingWithStateTestBase} import 
org.apache.flink.types.Row import org.junit.Assert._ import org.junit._ +import org.apache.flink.streaming.api.TimeCharacteristic +import org.apache.flink.streaming.api.functions.source.SourceFunction.SourceContext import scala.collection.mutable @@ -293,6 +297,120 @@ class SqlITCase extends StreamingWithStateTestBase { assertEquals(expected.sorted, StreamITCase.testResults.sorted) } + @Test + def testBoundPartitionedEventTimeWindowWithRow(): Unit = { + val data = Seq( + Left((1L, (1L, 1, "Hello"))), + Left((2L, (2L, 2, "Hello"))), + Left((1L, (1L, 1, "Hello"))), + Left((2L, (2L, 2, "Hello"))), + Left((2L, (2L, 2, "Hello"))), + Left((1L, (1L, 1, "Hello"))), + Left((3L, (7L, 7, "Hello World"))), + Left((1L, (7L, 7, "Hello World"))), + Left((1L, (7L, 7, "Hello World"))), + Right(2L), + Left((3L, (3L, 3, "Hello"))), + Left((4L, (4L, 4, "Hello"))), + Left((5L, (5L, 5, "Hello"))), + Left((6L, (6L, 6, "Hello"))), + Left((20L, (20L, 20, "Hello World"))), + Right(6L), + Left((8L, (8L, 8, "Hello World"))), + Left((7L, (7L, 7, "Hello World"))), + Right(20L)) + + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) + env.setStateBackend(getStateBackend) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val t1 = env + .addSource[(Long, Int, String)](new EventTimeSourceFunction[(Long, Int, String)](data)) + .toTable(tEnv).as('a, 'b, 'c) + + tEnv.registerTable("T1", t1) + + val sqlQuery = "SELECT " + + "c, a, " + + "count(a) OVER (PARTITION BY c ORDER BY RowTime() ROWS BETWEEN 2 preceding AND CURRENT ROW)" + + ", sum(a) OVER (PARTITION BY c ORDER BY RowTime() ROWS BETWEEN 2 preceding AND CURRENT ROW)" + + " from T1" + + val result = tEnv.sql(sqlQuery).toDataStream[Row] + result.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList( + "Hello,1,1,1", "Hello,1,2,2", "Hello,1,3,3", + "Hello,2,3,4", "Hello,2,3,5","Hello,2,3,6", + "Hello,3,3,7", 
"Hello,4,3,9", "Hello,5,3,12", + "Hello,6,3,15", + "Hello World,7,1,7", "Hello World,7,2,14", "Hello World,7,3,21", + "Hello World,7,3,21", "Hello World,8,3,22", "Hello World,20,3,35") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testBoundNonPartitionedEventTimeWindowWithRow(): Unit = { + + val data = Seq( + Left((2L, (2L, 2, "Hello"))), + Left((2L, (2L, 2, "Hello"))), + Left((1L, (1L, 1, "Hello"))), + Left((1L, (1L, 1, "Hello"))), + Left((2L, (2L, 2, "Hello"))), + Left((1L, (1L, 1, "Hello"))), + Left((20L, (20L, 20, "Hello World"))), // early row + Right(3L), + Left((2L, (2L, 2, "Hello"))), // late row + Left((3L, (3L, 3, "Hello"))), + Left((4L, (4L, 4, "Hello"))), + Left((5L, (5L, 5, "Hello"))), + Left((6L, (6L, 6, "Hello"))), + Left((7L, (7L, 7, "Hello World"))), + Right(7L), + Left((9L, (9L, 9, "Hello World"))), + Left((8L, (8L, 8, "Hello World"))), + Left((8L, (8L, 8, "Hello World"))), + Right(20L)) + + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) + env.setStateBackend(getStateBackend) + env.setParallelism(1) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val t1 = env + .addSource[(Long, Int, String)](new EventTimeSourceFunction[(Long, Int, String)](data)) + .toTable(tEnv).as('a, 'b, 'c) + + tEnv.registerTable("T1", t1) + + val sqlQuery = "SELECT " + + "c, a, " + + "count(a) OVER (ORDER BY RowTime() ROWS BETWEEN 2 preceding AND CURRENT ROW)," + + "sum(a) OVER (ORDER BY RowTime() ROWS BETWEEN 2 preceding AND CURRENT ROW)" + + "from T1" + + val result = tEnv.sql(sqlQuery).toDataStream[Row] + result.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList( + "Hello,1,1,1", "Hello,1,2,2", "Hello,1,3,3", + "Hello,2,3,4", "Hello,2,3,5", "Hello,2,3,6", + "Hello,3,3,7", + "Hello,4,3,9", "Hello,5,3,12", + "Hello,6,3,15", "Hello World,7,3,18", + "Hello World,8,3,21", "Hello 
World,8,3,23", + "Hello World,9,3,25", + "Hello World,20,3,37") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + /** * All aggregates must be computed on the same window. */ @@ -317,4 +435,21 @@ class SqlITCase extends StreamingWithStateTestBase { result.addSink(new StreamITCase.StringSink) env.execute() } + +} + +object SqlITCase { + + class EventTimeSourceFunction[T]( + dataWithTimestampList: Seq[Either[(Long, T), Long]]) extends SourceFunction[T] { + override def run(ctx: SourceContext[T]): Unit = { + dataWithTimestampList.foreach { + case Left(t) => ctx.collectWithTimestamp(t._2, t._1) + case Right(w) => ctx.emitWatermark(new Watermark(w)) + } + } + + override def cancel(): Unit = ??? + } + } http://git-wip-us.apache.org/repos/asf/flink/blob/7a9d39fe/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala index a25e59c..9a425b3 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala @@ -239,4 +239,59 @@ class WindowAggregateTest extends TableTestBase { ) streamUtil.verifySql(sql, expected) } + + @Test + def testBoundPartitionedRowTimeWindowWithRow() = { + val sql = "SELECT " + + "c, " + + "count(a) OVER (PARTITION BY c ORDER BY RowTime() ROWS BETWEEN 5 preceding AND " + + "CURRENT ROW) as cnt1 " + + "from MyTable" + + val expected = + unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamOverAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "a", "c", 
"ROWTIME() AS $2") + ), + term("partitionBy", "c"), + term("orderBy", "ROWTIME"), + term("rows", "BETWEEN 5 PRECEDING AND CURRENT ROW"), + term("select", "a", "c", "ROWTIME", "COUNT(a) AS w0$o0") + ), + term("select", "c", "w0$o0 AS $1") + ) + streamUtil.verifySql(sql, expected) + } + + @Test + def testBoundNonPartitionedRowTimeWindowWithRow() = { + val sql = "SELECT " + + "c, " + + "count(a) OVER (ORDER BY RowTime() ROWS BETWEEN 5 preceding AND " + + "CURRENT ROW) as cnt1 " + + "from MyTable" + + val expected = + unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamOverAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "a", "c", "ROWTIME() AS $2") + ), + term("orderBy", "ROWTIME"), + term("rows", "BETWEEN 5 PRECEDING AND CURRENT ROW"), + term("select", "a", "c", "ROWTIME", "COUNT(a) AS w0$o0") + ), + term("select", "c", "w0$o0 AS $1") + ) + streamUtil.verifySql(sql, expected) + } }