Repository: flink
Updated Branches:
  refs/heads/master 976e03c1e -> fe2c61a28
[FLINK-5658] [table] Add event-time OVER ROWS/RANGE UNBOUNDED PRECEDING aggregation to SQL. This closes #3386. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/fe2c61a2 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/fe2c61a2 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/fe2c61a2 Branch: refs/heads/master Commit: fe2c61a28e6a5300b2cf4c1e50ee884b51ef42c9 Parents: 7a9d39f Author: hongyuhong 00223286 <hongyuh...@huawei.com> Authored: Fri Mar 24 09:31:59 2017 +0800 Committer: Fabian Hueske <fhue...@apache.org> Committed: Fri Mar 24 20:19:17 2017 +0100 ---------------------------------------------------------------------- .../datastream/DataStreamOverAggregate.scala | 62 +++-- .../table/runtime/aggregate/AggregateUtil.scala | 70 +++-- .../UnboundedEventTimeOverProcessFunction.scala | 224 ++++++++++++++++ .../table/api/scala/stream/sql/SqlITCase.scala | 263 ++++++++++++++++++- .../scala/stream/sql/WindowAggregateTest.scala | 64 ++++- 5 files changed, 634 insertions(+), 49 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/fe2c61a2/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala index 547c875..3dd7ee2 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala @@ -104,7 +104,7 @@ class DataStreamOverAggregate( case _: ProcTimeType => // proc-time OVER window if (overWindow.lowerBound.isUnbounded && overWindow.upperBound.isCurrentRow) { - // non-bounded OVER window + // unbounded preceding OVER window createUnboundedAndCurrentRowProcessingTimeOverWindow(inputDS) } else if ( overWindow.lowerBound.isPreceding && !overWindow.lowerBound.isUnbounded && @@ -126,23 +126,15 @@ class DataStreamOverAggregate( } case _: RowTimeType => // row-time OVER window - if (overWindow.lowerBound.isUnbounded && overWindow.upperBound.isCurrentRow) { - // non-bounded OVER window - if (overWindow.isRows) { - // ROWS clause unbounded OVER window - throw new TableException( - "ROWS clause unbounded row-time OVER window no supported yet.") - } else { - // RANGE clause unbounded OVER window - throw new TableException( - "RANGE clause unbounded row-time OVER window no supported yet.") - } - } else if (overWindow.lowerBound.isPreceding && !overWindow.lowerBound.isUnbounded && - overWindow.upperBound.isCurrentRow) { + if (overWindow.lowerBound.isPreceding && + overWindow.lowerBound.isUnbounded && overWindow.upperBound.isCurrentRow) { + // unbounded preceding OVER window + createUnboundedAndCurrentRowEventTimeOverWindow(inputDS) + } else if (overWindow.lowerBound.isPreceding && overWindow.upperBound.isCurrentRow) { // bounded OVER window if (overWindow.isRows) { // ROWS clause bounded OVER window - createRowsClauseBoundedAndCurrentRowOverWindow(inputDS, true) + createRowsClauseBoundedAndCurrentRowOverWindow(inputDS, isRowTimeType = true) } else { // RANGE clause bounded OVER window 
throw new TableException( @@ -187,7 +179,7 @@ class DataStreamOverAggregate( val processFunction = AggregateUtil.createUnboundedProcessingOverProcessFunction( namedAggregates, inputType, - false) + isPartitioned = false) inputDS .process(processFunction).setParallelism(1).setMaxParallelism(1) @@ -205,7 +197,6 @@ class DataStreamOverAggregate( val overWindow: Group = logicWindow.groups.get(0) val partitionKeys: Array[Int] = overWindow.keys.toArray val namedAggregates: Seq[CalcitePair[AggregateCall, String]] = generateNamedAggregates - val inputFields = (0 until inputType.getFieldCount).toArray val precedingOffset = getLowerBoundary(logicWindow, overWindow, getInput()) + 1 @@ -216,7 +207,6 @@ class DataStreamOverAggregate( val processFunction = AggregateUtil.createRowsClauseBoundedOverProcessFunction( namedAggregates, inputType, - inputFields, precedingOffset, isRowTimeType ) @@ -244,6 +234,42 @@ class DataStreamOverAggregate( result } + def createUnboundedAndCurrentRowEventTimeOverWindow( + inputDS: DataStream[Row]): DataStream[Row] = { + + val overWindow: Group = logicWindow.groups.get(0) + val partitionKeys: Array[Int] = overWindow.keys.toArray + val namedAggregates: Seq[CalcitePair[AggregateCall, String]] = generateNamedAggregates + + // get the output types + val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo] + + val processFunction = AggregateUtil.createUnboundedEventTimeOverProcessFunction( + namedAggregates, + inputType) + + val result: DataStream[Row] = + // partitioned aggregation + if (partitionKeys.nonEmpty) { + inputDS.keyBy(partitionKeys: _*) + .process(processFunction) + .returns(rowTypeInfo) + .name(aggOpName) + .asInstanceOf[DataStream[Row]] + } + // global non-partitioned aggregation + else { + inputDS.keyBy(new NullByteKeySelector[Row]) + .process(processFunction) + .setParallelism(1) + .setMaxParallelism(1) + .returns(rowTypeInfo) + .name(aggOpName) + .asInstanceOf[DataStream[Row]] + } + result + } + private def generateNamedAggregates: Seq[CalcitePair[AggregateCall, String]] = { val overWindow: Group = logicWindow.groups.get(0) http://git-wip-us.apache.org/repos/asf/flink/blob/fe2c61a2/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala index 0084ee5..fdac692 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala @@ -104,7 +104,6 @@ object AggregateUtil { private[flink] def createRowsClauseBoundedOverProcessFunction( namedAggregates: Seq[CalcitePair[AggregateCall, String]], inputType: RelDataType, - inputFields: Array[Int], precedingOffset: Long, isRowTimeType: Boolean): ProcessFunction[Row, Row] = { @@ -114,26 +113,49 @@ object AggregateUtil { inputType, needRetraction = true) - val aggregationStateType: RowTypeInfo = - createDataSetAggregateBufferDataType(Array(), aggregates, inputType) + val aggregationStateType: RowTypeInfo = createAccumulatorRowType(aggregates) + val inputRowType = FlinkTypeFactory.toInternalRowTypeInfo(inputType).asInstanceOf[RowTypeInfo] - val inputRowType: RowTypeInfo = - 
createDataSetAggregateBufferDataType(inputFields, Array(), inputType) + if (isRowTimeType) { + new RowsClauseBoundedOverProcessFunction( + aggregates, + aggFields, + inputType.getFieldCount, + aggregationStateType, + inputRowType, + precedingOffset + ) + } else { + throw TableException( + "Bounded partitioned proc-time OVER aggregation is not supported yet.") + } + } - val processFunction = if (isRowTimeType) { - new RowsClauseBoundedOverProcessFunction( - aggregates, - aggFields, - inputType.getFieldCount, - aggregationStateType, - inputRowType, - precedingOffset - ) - } else { - throw TableException( - "Bounded partitioned proc-time OVER aggregation is not supported yet.") - } - processFunction + /** + * Create an [[ProcessFunction]] to evaluate final aggregate value. + * + * @param namedAggregates List of calls to aggregate functions and their output field names + * @param inputType Input row type + * @return [[UnboundedEventTimeOverProcessFunction]] + */ + private[flink] def createUnboundedEventTimeOverProcessFunction( + namedAggregates: Seq[CalcitePair[AggregateCall, String]], + inputType: RelDataType): UnboundedEventTimeOverProcessFunction = { + + val (aggFields, aggregates) = + transformToAggregateFunctions( + namedAggregates.map(_.getKey), + inputType, + needRetraction = false) + + val aggregationStateType: RowTypeInfo = createAccumulatorRowType(aggregates) + + new UnboundedEventTimeOverProcessFunction( + aggregates, + aggFields, + inputType.getFieldCount, + aggregationStateType, + FlinkTypeFactory.toInternalRowTypeInfo(inputType)) } /** @@ -595,7 +617,7 @@ object AggregateUtil { // compute preaggregation type val preAggFieldTypes = gkeyInFields .map(inputType.getFieldList.get(_).getType) - .map(FlinkTypeFactory.toTypeInfo) ++ createAccumulatorType(inputType, aggregates) + .map(FlinkTypeFactory.toTypeInfo) ++ createAccumulatorType(aggregates) val preAggRowType = new RowTypeInfo(preAggFieldTypes: _*) ( @@ -701,7 +723,7 @@ object AggregateUtil { val aggResultTypes = namedAggregates.map(a => FlinkTypeFactory.toTypeInfo(a.left.getType)) - val accumulatorRowType = createAccumulatorRowType(inputType, aggregates) + val accumulatorRowType = createAccumulatorRowType(aggregates) val aggResultRowType = new RowTypeInfo(aggResultTypes: _*) val aggFunction = new AggregateAggFunction(aggregates, aggFields) @@ -1029,7 +1051,6 @@ object AggregateUtil { } private def createAccumulatorType( - inputType: RelDataType, aggregates: Array[TableAggregateFunction[_]]): Seq[TypeInformation[_]] = { val aggTypes: Seq[TypeInformation[_]] = @@ -1068,7 +1089,7 @@ object AggregateUtil { .map(FlinkTypeFactory.toTypeInfo) // get all field data types of all intermediate aggregates - val aggTypes: Seq[TypeInformation[_]] = createAccumulatorType(inputType, aggregates) + val aggTypes: Seq[TypeInformation[_]] = createAccumulatorType(aggregates) // concat group key types, aggregation types, and window key types val allFieldTypes: Seq[TypeInformation[_]] = windowKeyTypes match { @@ -1079,10 +1100,9 @@ object AggregateUtil { } private def createAccumulatorRowType( - inputType: RelDataType, aggregates: Array[TableAggregateFunction[_]]): RowTypeInfo = { - val aggTypes: Seq[TypeInformation[_]] = createAccumulatorType(inputType, aggregates) + val aggTypes: Seq[TypeInformation[_]] = createAccumulatorType(aggregates) new RowTypeInfo(aggTypes: _*) } 
http://git-wip-us.apache.org/repos/asf/flink/blob/fe2c61a2/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/UnboundedEventTimeOverProcessFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/UnboundedEventTimeOverProcessFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/UnboundedEventTimeOverProcessFunction.scala new file mode 100644 index 0000000..7616ede --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/UnboundedEventTimeOverProcessFunction.scala @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.runtime.aggregate + +import java.util +import java.util.{List => JList} + +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} +import org.apache.flink.configuration.Configuration +import org.apache.flink.types.Row +import org.apache.flink.streaming.api.functions.ProcessFunction +import org.apache.flink.util.{Collector, Preconditions} +import org.apache.flink.api.common.state._ +import org.apache.flink.api.java.typeutils.ListTypeInfo +import org.apache.flink.streaming.api.operators.TimestampedCollector +import org.apache.flink.table.functions.{Accumulator, AggregateFunction} + + +/** + * A ProcessFunction to support unbounded event-time over-window + * + * @param aggregates the aggregate functions + * @param aggFields the filed index which the aggregate functions use + * @param forwardedFieldCount the input fields count + * @param intermediateType the intermediate row tye which the state saved + * @param inputType the input row tye which the state saved + * + */ +class UnboundedEventTimeOverProcessFunction( + private val aggregates: Array[AggregateFunction[_]], + private val aggFields: Array[Int], + private val forwardedFieldCount: Int, + private val intermediateType: TypeInformation[Row], + private val inputType: TypeInformation[Row]) + extends ProcessFunction[Row, Row]{ + + Preconditions.checkNotNull(aggregates) + Preconditions.checkNotNull(aggFields) + Preconditions.checkArgument(aggregates.length == aggFields.length) + + private var output: Row = _ + // state to hold the accumulators of the aggregations + private var accumulatorState: ValueState[Row] = _ + // state to hold rows until the next watermark arrives + private var rowMapState: MapState[Long, JList[Row]] = _ + // list to sort timestamps to access rows in timestamp order + private var sortedTimestamps: util.LinkedList[Long] = _ + + + override def open(config: Configuration) { + output = new Row(forwardedFieldCount + aggregates.length) + 
sortedTimestamps = new util.LinkedList[Long]() + + // initialize accumulator state + val accDescriptor: ValueStateDescriptor[Row] = + new ValueStateDescriptor[Row]("accumulatorstate", intermediateType) + accumulatorState = getRuntimeContext.getState[Row](accDescriptor) + + // initialize row state + val rowListTypeInfo: TypeInformation[JList[Row]] = new ListTypeInfo[Row](inputType) + val mapStateDescriptor: MapStateDescriptor[Long, JList[Row]] = + new MapStateDescriptor[Long, JList[Row]]("rowmapstate", + BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]], rowListTypeInfo) + rowMapState = getRuntimeContext.getMapState(mapStateDescriptor) + } + + /** + * Puts an element from the input stream into state if it is not late. + * Registers a timer for the next watermark. + * + * @param input The input value. + * @param ctx The ctx to register timer or get current time + * @param out The collector for returning result values. + * + */ + override def processElement( + input: Row, + ctx: ProcessFunction[Row, Row]#Context, + out: Collector[Row]): Unit = { + + val timestamp = ctx.timestamp() + val curWatermark = ctx.timerService().currentWatermark() + + // discard late record + if (timestamp >= curWatermark) { + // ensure every key just registers one timer + ctx.timerService.registerEventTimeTimer(curWatermark + 1) + + // put row into state + var rowList = rowMapState.get(timestamp) + if (rowList == null) { + rowList = new util.ArrayList[Row]() + } + rowList.add(input) + rowMapState.put(timestamp, rowList) + } + } + + /** + * Called when a watermark arrived. + * Sorts records according the timestamp, computes aggregates, and emits all records with + * timestamp smaller than the watermark in timestamp order. + * + * @param timestamp The timestamp of the firing timer. + * @param ctx The ctx to register timer or get current time + * @param out The collector for returning result values. 
+ */ + override def onTimer( + timestamp: Long, + ctx: ProcessFunction[Row, Row]#OnTimerContext, + out: Collector[Row]): Unit = { + + Preconditions.checkArgument(out.isInstanceOf[TimestampedCollector[Row]]) + val collector = out.asInstanceOf[TimestampedCollector[Row]] + + val keyIterator = rowMapState.keys.iterator + if (keyIterator.hasNext) { + val curWatermark = ctx.timerService.currentWatermark + var existEarlyRecord: Boolean = false + var i = 0 + + // sort the record timestamps + do { + val recordTime = keyIterator.next + // only take timestamps smaller/equal to the watermark + if (recordTime <= curWatermark) { + insertToSortedList(recordTime) + } else { + existEarlyRecord = true + } + } while (keyIterator.hasNext) + + // get last accumulator + var lastAccumulator = accumulatorState.value + if (lastAccumulator == null) { + // initialize accumulator + lastAccumulator = new Row(aggregates.length) + while (i < aggregates.length) { + lastAccumulator.setField(i, aggregates(i).createAccumulator()) + i += 1 + } + } + + // emit the rows in order + while (!sortedTimestamps.isEmpty) { + val curTimestamp = sortedTimestamps.removeFirst() + val curRowList = rowMapState.get(curTimestamp) + collector.setAbsoluteTimestamp(curTimestamp) + + var j = 0 + while (j < curRowList.size) { + val curRow = curRowList.get(j) + i = 0 + + // copy forwarded fields to output row + while (i < forwardedFieldCount) { + output.setField(i, curRow.getField(i)) + i += 1 + } + + // update accumulators and copy aggregates to output row + i = 0 + while (i < aggregates.length) { + val index = forwardedFieldCount + i + val accumulator = lastAccumulator.getField(i).asInstanceOf[Accumulator] + aggregates(i).accumulate(accumulator, curRow.getField(aggFields(i))) + output.setField(index, aggregates(i).getValue(accumulator)) + i += 1 + } + // emit output row + collector.collect(output) + j += 1 + } + rowMapState.remove(curTimestamp) + } + + accumulatorState.update(lastAccumulator) + + // if are are rows with timestamp > watermark, register a timer for the next watermark + if (existEarlyRecord) { + ctx.timerService.registerEventTimeTimer(curWatermark + 1) + } + } + } + + /** + * Inserts timestamps in order into a linked list. + * + * If timestamps arrive in order (as in case of using the RocksDB state backend) this is just + * an append with O(1). 
+ */ + private def insertToSortedList(recordTimeStamp: Long) = { + val listIterator = sortedTimestamps.listIterator(sortedTimestamps.size) + var continue = true + while (listIterator.hasPrevious && continue) { + val timestamp = listIterator.previous + if (recordTimeStamp >= timestamp) { + listIterator.next + listIterator.add(recordTimeStamp) + continue = false + } + } + + if (continue) { + sortedTimestamps.addFirst(recordTimeStamp) + } + } + +} http://git-wip-us.apache.org/repos/asf/flink/blob/fe2c61a2/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala index 19350a7..34a68b2 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala @@ -19,8 +19,8 @@ package org.apache.flink.table.api.scala.stream.sql import org.apache.flink.api.scala._ -import org.apache.flink.streaming.api.functions.source.SourceFunction import org.apache.flink.table.api.scala.stream.sql.SqlITCase.EventTimeSourceFunction +import org.apache.flink.streaming.api.functions.source.SourceFunction import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment import org.apache.flink.streaming.api.watermark.Watermark import org.apache.flink.table.api.{TableEnvironment, TableException} @@ -436,6 +436,266 @@ class SqlITCase extends StreamingWithStateTestBase { env.execute() } + /** test sliding event-time unbounded window with partition by **/ + @Test + def testUnboundedEventTimeRowWindowWithPartition(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) + env.setStateBackend(getStateBackend) + StreamITCase.testResults = mutable.MutableList() + env.setParallelism(1) + + val sqlQuery = "SELECT a, b, c, " + + "SUM(b) over (" + + "partition by a order by rowtime() range between unbounded preceding and current row), " + + "count(b) over (" + + "partition by a order by rowtime() range between unbounded preceding and current row), " + + "avg(b) over (" + + "partition by a order by rowtime() range between unbounded preceding and current row), " + + "max(b) over (" + + "partition by a order by rowtime() range between unbounded preceding and current row), " + + "min(b) over (" + + "partition by a order by rowtime() range between unbounded preceding and current row) " + + "from T1" + + val data = Seq( + Left(14000005L, (1, 1L, "Hi")), + Left(14000000L, (2, 1L, "Hello")), + Left(14000002L, (3, 1L, "Hello")), + Left(14000003L, (1, 2L, "Hello")), + Left(14000004L, (1, 3L, "Hello world")), + Left(14000007L, (3, 2L, "Hello world")), + Left(14000008L, (2, 2L, "Hello world")), + Right(14000010L), + // the next 3 elements are late + Left(14000008L, (1, 4L, "Hello world")), + Left(14000008L, (2, 3L, "Hello world")), + Left(14000008L, (3, 3L, "Hello world")), + Left(14000012L, (1, 5L, "Hello world")), + Right(14000020L), + Left(14000021L, (1, 6L, "Hello world")), + // the next 3 elements are late + Left(14000019L, (1, 6L, "Hello world")), + Left(14000018L, (2, 4L, "Hello world")), + Left(14000018L, (3, 
4L, "Hello world")), + Left(14000022L, (2, 5L, "Hello world")), + Left(14000022L, (3, 5L, "Hello world")), + Left(14000024L, (1, 7L, "Hello world")), + Left(14000023L, (1, 8L, "Hello world")), + Left(14000021L, (1, 9L, "Hello world")), + Right(14000030L) + ) + + val t1 = env.addSource(new EventTimeSourceFunction[(Int, Long, String)](data)) + .toTable(tEnv).as('a, 'b, 'c) + + tEnv.registerTable("T1", t1) + + val result = tEnv.sql(sqlQuery).toDataStream[Row] + result.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList( + "1,2,Hello,2,1,2,2,2", + "1,3,Hello world,5,2,2,3,2", + "1,1,Hi,6,3,2,3,1", + "2,1,Hello,1,1,1,1,1", + "2,2,Hello world,3,2,1,2,1", + "3,1,Hello,1,1,1,1,1", + "3,2,Hello world,3,2,1,2,1", + "1,5,Hello world,11,4,2,5,1", + "1,6,Hello world,17,5,3,6,1", + "1,9,Hello world,26,6,4,9,1", + "1,8,Hello world,34,7,4,9,1", + "1,7,Hello world,41,8,5,9,1", + "2,5,Hello world,8,3,2,5,1", + "3,5,Hello world,8,3,2,5,1" + ) + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + /** test sliding event-time unbounded window with partition by **/ + @Test + def testUnboundedEventTimeRowWindowWithPartitionMultiThread(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) + env.setStateBackend(getStateBackend) + StreamITCase.testResults = mutable.MutableList() + + val sqlQuery = "SELECT a, b, c, " + + "SUM(b) over (" + + "partition by a order by rowtime() range between unbounded preceding and current row), " + + "count(b) over (" + + "partition by a order by rowtime() range between unbounded preceding and current row), " + + "avg(b) over (" + + "partition by a order by rowtime() range between unbounded preceding and current row), " + + "max(b) over (" + + "partition by a order by rowtime() range between unbounded preceding and current row), " + + "min(b) over (" + + "partition by a order by rowtime() range between unbounded preceding and current row) " + + "from T1" + + val data = Seq( + Left(14000005L, (1, 1L, "Hi")), + Left(14000000L, (2, 1L, "Hello")), + Left(14000002L, (3, 1L, "Hello")), + Left(14000003L, (1, 2L, "Hello")), + Left(14000004L, (1, 3L, "Hello world")), + Left(14000007L, (3, 2L, "Hello world")), + Left(14000008L, (2, 2L, "Hello world")), + Right(14000010L), + Left(14000012L, (1, 5L, "Hello world")), + Right(14000020L), + Left(14000021L, (1, 6L, "Hello world")), + Left(14000023L, (2, 5L, "Hello world")), + Left(14000024L, (3, 5L, "Hello world")), + Left(14000026L, (1, 7L, "Hello world")), + Left(14000025L, (1, 8L, "Hello world")), + Left(14000022L, (1, 9L, "Hello world")), + Right(14000030L) + ) + + val t1 = env.addSource(new EventTimeSourceFunction[(Int, Long, String)](data)) + .toTable(tEnv).as('a, 'b, 'c) + + tEnv.registerTable("T1", t1) + + val result = tEnv.sql(sqlQuery).toDataStream[Row] + result.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList( + "1,2,Hello,2,1,2,2,2", + "1,3,Hello world,5,2,2,3,2", + "1,1,Hi,6,3,2,3,1", + "2,1,Hello,1,1,1,1,1", + "2,2,Hello world,3,2,1,2,1", + "3,1,Hello,1,1,1,1,1", + "3,2,Hello world,3,2,1,2,1", + "1,5,Hello world,11,4,2,5,1", + "1,6,Hello world,17,5,3,6,1", + "1,9,Hello world,26,6,4,9,1", + "1,8,Hello world,34,7,4,9,1", + "1,7,Hello world,41,8,5,9,1", + "2,5,Hello world,8,3,2,5,1", + "3,5,Hello world,8,3,2,5,1" + ) + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + /** test 
sliding event-time unbounded window without partitiion by **/ + @Test + def testUnboundedEventTimeRowWindowWithoutPartition(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) + env.setStateBackend(getStateBackend) + StreamITCase.testResults = mutable.MutableList() + env.setParallelism(1) + + val sqlQuery = "SELECT a, b, c, " + + "SUM(b) over (order by rowtime() range between unbounded preceding and current row), " + + "count(b) over (order by rowtime() range between unbounded preceding and current row), " + + "avg(b) over (order by rowtime() range between unbounded preceding and current row), " + + "max(b) over (order by rowtime() range between unbounded preceding and current row), " + + "min(b) over (order by rowtime() range between unbounded preceding and current row) " + + "from T1" + + val data = Seq( + Left(14000005L, (1, 1L, "Hi")), + Left(14000000L, (2, 2L, "Hello")), + Left(14000002L, (3, 5L, "Hello")), + Left(14000003L, (1, 3L, "Hello")), + Left(14000004L, (3, 7L, "Hello world")), + Left(14000007L, (4, 9L, "Hello world")), + Left(14000008L, (5, 8L, "Hello world")), + Right(14000010L), + // this element will be discard because it is late + Left(14000008L, (6, 8L, "Hello world")), + Right(14000020L), + Left(14000021L, (6, 8L, "Hello world")), + Right(14000030L) + ) + + val t1 = env.addSource(new EventTimeSourceFunction[(Int, Long, String)](data)) + .toTable(tEnv).as('a, 'b, 'c) + + tEnv.registerTable("T1", t1) + + val result = tEnv.sql(sqlQuery).toDataStream[Row] + result.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList( + "2,2,Hello,2,1,2,2,2", + "3,5,Hello,7,2,3,5,2", + "1,3,Hello,10,3,3,5,2", + "3,7,Hello world,17,4,4,7,2", + "1,1,Hi,18,5,3,7,1", + "4,9,Hello world,27,6,4,9,1", + "5,8,Hello world,35,7,5,9,1", + "6,8,Hello world,43,8,5,9,1") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + /** test sliding event-time unbounded window without partitiion by and arrive early **/ + @Test + def testUnboundedEventTimeRowWindowArriveEarly(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) + env.setStateBackend(getStateBackend) + StreamITCase.testResults = mutable.MutableList() + env.setParallelism(1) + + val sqlQuery = "SELECT a, b, c, " + + "SUM(b) over (order by rowtime() range between unbounded preceding and current row), " + + "count(b) over (order by rowtime() range between unbounded preceding and current row), " + + "avg(b) over (order by rowtime() range between unbounded preceding and current row), " + + "max(b) over (order by rowtime() range between unbounded preceding and current row), " + + "min(b) over (order by rowtime() range between unbounded preceding and current row) " + + "from T1" + + val data = Seq( + Left(14000005L, (1, 1L, "Hi")), + Left(14000000L, (2, 2L, "Hello")), + Left(14000002L, (3, 5L, "Hello")), + Left(14000003L, (1, 3L, "Hello")), + // next three elements are early + Left(14000012L, (3, 7L, "Hello world")), + Left(14000013L, (4, 9L, "Hello world")), + Left(14000014L, (5, 8L, "Hello world")), + Right(14000010L), + Left(14000011L, (6, 8L, "Hello world")), + // next element is early + Left(14000021L, (6, 8L, "Hello world")), + Right(14000020L), + Right(14000030L) + ) + + val t1 = env.addSource(new 
EventTimeSourceFunction[(Int, Long, String)](data)) + .toTable(tEnv).as('a, 'b, 'c) + + tEnv.registerTable("T1", t1) + + val result = tEnv.sql(sqlQuery).toDataStream[Row] + result.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList( + "2,2,Hello,2,1,2,2,2", + "3,5,Hello,7,2,3,5,2", + "1,3,Hello,10,3,3,5,2", + "1,1,Hi,11,4,2,5,1", + "6,8,Hello world,19,5,3,8,1", + "3,7,Hello world,26,6,4,8,1", + "4,9,Hello world,35,7,5,9,1", + "5,8,Hello world,43,8,5,9,1", + "6,8,Hello world,51,9,5,9,1") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } } object SqlITCase { @@ -451,5 +711,4 @@ object SqlITCase { override def cancel(): Unit = ??? } - } http://git-wip-us.apache.org/repos/asf/flink/blob/fe2c61a2/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala index 9a425b3..7b8b2df 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala @@ -241,12 +241,67 @@ class WindowAggregateTest extends TableTestBase { } @Test + def testUnboundNonPartitionedEventTimeWindowWithRange() = { + val sql = "SELECT " + + "c, " + + "count(a) OVER (ORDER BY RowTime() RANGE UNBOUNDED preceding) as cnt1, " + + "sum(a) OVER (ORDER BY RowTime() RANGE UNBOUNDED preceding) as cnt2 " + + "from MyTable" + + val expected = + unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamOverAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "a", "c", "ROWTIME() AS $2") + ), + term("orderBy", "ROWTIME"), + term("range", "BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"), + term("select", "a", "c", "ROWTIME", "COUNT(a) AS w0$o0", "$SUM0(a) AS w0$o1") + ), + term("select", "c", "w0$o0 AS cnt1", "CASE(>(w0$o0, 0)", "CAST(w0$o1), null) AS cnt2") + ) + streamUtil.verifySql(sql, expected) + } + + @Test + def testUnboundPartitionedEventTimeWindowWithRange() = { + val sql = "SELECT " + + "c, " + + "count(a) OVER (PARTITION BY c ORDER BY RowTime() RANGE UNBOUNDED preceding) as cnt1, " + + "sum(a) OVER (PARTITION BY c ORDER BY RowTime() RANGE UNBOUNDED preceding) as cnt2 " + + "from MyTable" + + val expected = + unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamOverAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "a", "c", "ROWTIME() AS $2") + ), + term("partitionBy", "c"), + term("orderBy", "ROWTIME"), + term("range", "BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"), + term("select", "a", "c", "ROWTIME", "COUNT(a) AS w0$o0", "$SUM0(a) AS w0$o1") + ), + term("select", "c", "w0$o0 AS cnt1", "CASE(>(w0$o0, 0)", "CAST(w0$o1), null) AS cnt2") + ) + streamUtil.verifySql(sql, expected) + } + + @Test def testBoundPartitionedRowTimeWindowWithRow() = { val sql = "SELECT " + - "c, " + - "count(a) OVER (PARTITION BY c ORDER BY RowTime() ROWS BETWEEN 5 preceding AND " + - "CURRENT ROW) as cnt1 " + - "from MyTable" + "c, " + + "count(a) OVER (PARTITION BY c ORDER BY RowTime() ROWS BETWEEN 5 preceding AND " + + "CURRENT ROW) as cnt1 " + + "from MyTable" val expected = unaryNode( @@ -294,4 
+349,5 @@ class WindowAggregateTest extends TableTestBase { ) streamUtil.verifySql(sql, expected) } + }
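----------------------------------------------------------------------

For reference, a minimal end-to-end usage sketch of the event-time OVER RANGE UNBOUNDED PRECEDING aggregation added by this commit, modeled on the ITCases above. The toy input, the example class name, the timestamp assignment via assignAscendingTimestamps, and the print() sink are illustrative assumptions; the query shape and the tEnv.sql(...).toDataStream[Row] wiring follow the tests in SqlITCase.

import org.apache.flink.api.scala._
import org.apache.flink.streaming.api.TimeCharacteristic
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.api.scala._
import org.apache.flink.types.Row

object UnboundedEventTimeOverExample {
  def main(args: Array[String]): Unit = {
    val env = StreamExecutionEnvironment.getExecutionEnvironment
    env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
    env.setParallelism(1)
    val tEnv = TableEnvironment.getTableEnvironment(env)

    // toy input (a, b, c); event timestamps are derived from field b for simplicity
    val input = env
      .fromElements((1, 1L, "Hello"), (2, 2L, "Hello"), (1, 3L, "World"), (2, 4L, "World"))
      .assignAscendingTimestamps(_._2 * 1000L)

    tEnv.registerTable("T1", input.toTable(tEnv).as('a, 'b, 'c))

    // event-time OVER window, unbounded preceding up to the current row
    val query =
      "SELECT a, b, c, " +
        "SUM(b) OVER (PARTITION BY a ORDER BY rowtime() RANGE UNBOUNDED PRECEDING), " +
        "COUNT(b) OVER (PARTITION BY a ORDER BY rowtime() RANGE UNBOUNDED PRECEDING) " +
        "FROM T1"

    // results are emitted per watermark, in timestamp order per key
    val result = tEnv.sql(query).toDataStream[Row]
    result.print()
    env.execute()
  }
}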