[FLINK-5990] [table] Add event-time OVER ROWS BETWEEN x PRECEDING aggregation to SQL.

This closes #3585.
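
A minimal usage sketch of what this patch enables, adapted from the SqlITCase tests below.
It is a fragment meant to run inside a main method or test; the registered event-time table
"T1" with fields 'a, 'b, 'c is an assumption for illustration and not part of the patch:

    import org.apache.flink.api.scala._
    import org.apache.flink.streaming.api.TimeCharacteristic
    import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
    import org.apache.flink.table.api.TableEnvironment
    import org.apache.flink.table.api.scala._
    import org.apache.flink.types.Row

    val env = StreamExecutionEnvironment.getExecutionEnvironment
    env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
    val tEnv = TableEnvironment.getTableEnvironment(env)

    // event-time bounded OVER window: aggregate over the 2 preceding rows and the current row
    val sqlQuery = "SELECT c, a, " +
      "count(a) OVER (PARTITION BY c ORDER BY RowTime() ROWS BETWEEN 2 preceding AND CURRENT ROW), " +
      "sum(a) OVER (PARTITION BY c ORDER BY RowTime() ROWS BETWEEN 2 preceding AND CURRENT ROW) " +
      "FROM T1"

    val result = tEnv.sql(sqlQuery).toDataStream[Row]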


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/7a9d39fe
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/7a9d39fe
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/7a9d39fe

Branch: refs/heads/master
Commit: 7a9d39fe9f659d43bf4719a2981f6c4771ffbe48
Parents: 6949c8c
Author: 金竹 <jincheng.su...@alibaba-inc.com>
Authored: Sun Mar 19 23:31:00 2017 +0800
Committer: Fabian Hueske <fhue...@apache.org>
Committed: Fri Mar 24 20:19:17 2017 +0100

----------------------------------------------------------------------
 .../flink/table/plan/nodes/OverAggregate.scala  |  31 ++-
 .../datastream/DataStreamOverAggregate.scala    | 149 +++++++++---
 .../table/runtime/aggregate/AggregateUtil.scala |  48 +++-
 .../RowsClauseBoundedOverProcessFunction.scala  | 239 +++++++++++++++++++
 .../table/api/scala/stream/sql/SqlITCase.scala  | 139 ++++++++++-
 .../scala/stream/sql/WindowAggregateTest.scala  |  55 +++++
 6 files changed, 623 insertions(+), 38 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/7a9d39fe/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala
index 793ab23..91c8cef 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/OverAggregate.scala
@@ -18,12 +18,15 @@
 
 package org.apache.flink.table.plan.nodes
 
-import org.apache.calcite.rel.RelFieldCollation
+import org.apache.calcite.rel.{RelFieldCollation, RelNode}
 import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFieldImpl}
 import org.apache.calcite.rel.core.AggregateCall
 import org.apache.calcite.rel.core.Window.Group
+import org.apache.calcite.rel.core.Window
+import org.apache.calcite.rex.{RexInputRef}
 import org.apache.flink.table.runtime.aggregate.AggregateUtil._
 import org.apache.flink.table.functions.{ProcTimeType, RowTimeType}
+
 import scala.collection.JavaConverters._
 
 trait OverAggregate {
@@ -46,8 +49,16 @@ trait OverAggregate {
     orderingString
   }
 
-  private[flink] def windowRange(overWindow: Group): String = {
-    s"BETWEEN ${overWindow.lowerBound} AND ${overWindow.upperBound}"
+  private[flink] def windowRange(
+    logicWindow: Window,
+    overWindow: Group,
+    input: RelNode): String = {
+    if (overWindow.lowerBound.isPreceding && !overWindow.lowerBound.isUnbounded) {
+      s"BETWEEN ${getLowerBoundary(logicWindow, overWindow, input)} PRECEDING " +
+          s"AND ${overWindow.upperBound}"
+    } else {
+      s"BETWEEN ${overWindow.lowerBound} AND ${overWindow.upperBound}"
+    }
   }
 
   private[flink] def aggregationToString(
@@ -92,4 +103,18 @@ trait OverAggregate {
     }.mkString(", ")
   }
 
+  private[flink] def getLowerBoundary(
+    logicWindow: Window,
+    overWindow: Group,
+    input: RelNode): Long = {
+
+    val ref: RexInputRef = overWindow.lowerBound.getOffset.asInstanceOf[RexInputRef]
+    val lowerBoundIndex = input.getRowType.getFieldCount - ref.getIndex
+    val lowerBound = logicWindow.constants.get(lowerBoundIndex).getValue2
+    lowerBound match {
+      case x: java.math.BigDecimal => x.longValue()
+      case _ => lowerBound.asInstanceOf[Long]
+    }
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/7a9d39fe/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala
index 34b3b0f..547c875 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala
@@ -32,6 +32,7 @@ import org.apache.calcite.rel.core.Window
 import org.apache.calcite.rel.core.Window.Group
 import java.util.{List => JList}
 
+import org.apache.flink.api.java.functions.NullByteKeySelector
 import org.apache.flink.table.functions.{ProcTimeType, RowTimeType}
 import org.apache.flink.table.runtime.aggregate.AggregateUtil.CalcitePair
 
@@ -70,9 +71,9 @@ class DataStreamOverAggregate(
 
     super.explainTerms(pw)
       .itemIf("partitionBy", partitionToString(inputType, partitionKeys), 
partitionKeys.nonEmpty)
-        .item("orderBy",orderingToString(inputType, 
overWindow.orderKeys.getFieldCollations))
-      .itemIf("rows", windowRange(overWindow), overWindow.isRows)
-      .itemIf("range", windowRange(overWindow), !overWindow.isRows)
+      .item("orderBy",orderingToString(inputType, 
overWindow.orderKeys.getFieldCollations))
+      .itemIf("rows", windowRange(logicWindow, overWindow, getInput), 
overWindow.isRows)
+      .itemIf("range", windowRange(logicWindow, overWindow, getInput), 
!overWindow.isRows)
       .item(
         "select", aggregationToString(
           inputType,
@@ -99,20 +100,58 @@ class DataStreamOverAggregate(
       .getFieldList
       .get(overWindow.orderKeys.getFieldCollations.get(0).getFieldIndex)
       .getValue
-
     timeType match {
       case _: ProcTimeType =>
-        // both ROWS and RANGE clause with UNBOUNDED PRECEDING and CURRENT ROW condition.
-        if (overWindow.lowerBound.isUnbounded &&
-          overWindow.upperBound.isCurrentRow) {
+        // proc-time OVER window
+        if (overWindow.lowerBound.isUnbounded && overWindow.upperBound.isCurrentRow) {
+          // non-bounded OVER window
           createUnboundedAndCurrentRowProcessingTimeOverWindow(inputDS)
+        } else if (
+          overWindow.lowerBound.isPreceding && !overWindow.lowerBound.isUnbounded &&
+              overWindow.upperBound.isCurrentRow) {
+          // bounded OVER window
+          if (overWindow.isRows) {
+            // ROWS clause bounded OVER window
+            throw new TableException(
+              "ROWS clause bounded proc-time OVER window not supported yet.")
+          } else {
+            // RANGE clause bounded OVER window
+            throw new TableException(
+              "RANGE clause bounded proc-time OVER window not supported yet.")
+          }
         } else {
           throw new TableException(
-              "OVER window only support ProcessingTime UNBOUNDED PRECEDING and CURRENT ROW " +
-              "condition.")
+            "OVER window only support ProcessingTime UNBOUNDED PRECEDING and CURRENT ROW " +
+                "condition.")
         }
       case _: RowTimeType =>
-        throw new TableException("OVER Window of the EventTime type is not currently supported.")
+        // row-time OVER window
+        if (overWindow.lowerBound.isUnbounded && overWindow.upperBound.isCurrentRow) {
+          // non-bounded OVER window
+          if (overWindow.isRows) {
+            // ROWS clause unbounded OVER window
+            throw new TableException(
+              "ROWS clause unbounded row-time OVER window not supported yet.")
+          } else {
+            // RANGE clause unbounded OVER window
+            throw new TableException(
+              "RANGE clause unbounded row-time OVER window not supported yet.")
+          }
+        } else if (overWindow.lowerBound.isPreceding && !overWindow.lowerBound.isUnbounded &&
+            overWindow.upperBound.isCurrentRow) {
+          // bounded OVER window
+          if (overWindow.isRows) {
+            // ROWS clause bounded OVER window
+            createRowsClauseBoundedAndCurrentRowOverWindow(inputDS, true)
+          } else {
+            // RANGE clause bounded OVER window
+            throw new TableException(
+              "RANGE clause bounded row-time OVER window not supported yet.")
+          }
+        } else {
+          throw new TableException(
+            "row-time OVER window only supports CURRENT ROW condition.")
+        }
       case _ =>
         throw new TableException(s"Unsupported time type {$timeType}")
     }
@@ -120,7 +159,7 @@ class DataStreamOverAggregate(
   }
 
   def createUnboundedAndCurrentRowProcessingTimeOverWindow(
-    inputDS: DataStream[Row]): DataStream[Row]  = {
+    inputDS: DataStream[Row]): DataStream[Row] = {
 
     val overWindow: Group = logicWindow.groups.get(0)
     val partitionKeys: Array[Int] = overWindow.keys.toArray
@@ -130,32 +169,78 @@ class DataStreamOverAggregate(
     val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo]
 
     val result: DataStream[Row] =
-        // partitioned aggregation
-        if (partitionKeys.nonEmpty) {
-          val processFunction = AggregateUtil.CreateUnboundedProcessingOverProcessFunction(
-            namedAggregates,
-            inputType)
+    // partitioned aggregation
+      if (partitionKeys.nonEmpty) {
+        val processFunction = AggregateUtil.createUnboundedProcessingOverProcessFunction(
+          namedAggregates,
+          inputType)
 
-          inputDS
+        inputDS
           .keyBy(partitionKeys: _*)
           .process(processFunction)
           .returns(rowTypeInfo)
           .name(aggOpName)
           .asInstanceOf[DataStream[Row]]
-        }
-        // non-partitioned aggregation
-        else {
-          val processFunction = AggregateUtil.CreateUnboundedProcessingOverProcessFunction(
-            namedAggregates,
-            inputType,
-            false)
-
-          inputDS
-            .process(processFunction).setParallelism(1).setMaxParallelism(1)
-            .returns(rowTypeInfo)
-            .name(aggOpName)
-            .asInstanceOf[DataStream[Row]]
-        }
+      }
+      // non-partitioned aggregation
+      else {
+        val processFunction = AggregateUtil.createUnboundedProcessingOverProcessFunction(
+          namedAggregates,
+          inputType,
+          false)
+
+        inputDS
+          .process(processFunction).setParallelism(1).setMaxParallelism(1)
+          .returns(rowTypeInfo)
+          .name(aggOpName)
+          .asInstanceOf[DataStream[Row]]
+      }
+    result
+  }
+
+  def createRowsClauseBoundedAndCurrentRowOverWindow(
+    inputDS: DataStream[Row],
+    isRowTimeType: Boolean = false): DataStream[Row] = {
+
+    val overWindow: Group = logicWindow.groups.get(0)
+    val partitionKeys: Array[Int] = overWindow.keys.toArray
+    val namedAggregates: Seq[CalcitePair[AggregateCall, String]] = generateNamedAggregates
+    val inputFields = (0 until inputType.getFieldCount).toArray
+
+    val precedingOffset =
+      getLowerBoundary(logicWindow, overWindow, getInput()) + 1
+
+    // get the output types
+    val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo]
+
+    val processFunction = AggregateUtil.createRowsClauseBoundedOverProcessFunction(
+      namedAggregates,
+      inputType,
+      inputFields,
+      precedingOffset,
+      isRowTimeType
+    )
+    val result: DataStream[Row] =
+    // partitioned aggregation
+      if (partitionKeys.nonEmpty) {
+        inputDS
+          .keyBy(partitionKeys: _*)
+          .process(processFunction)
+          .returns(rowTypeInfo)
+          .name(aggOpName)
+          .asInstanceOf[DataStream[Row]]
+      }
+      // non-partitioned aggregation
+      else {
+        inputDS
+          .keyBy(new NullByteKeySelector[Row])
+          .process(processFunction)
+          .setParallelism(1)
+          .setMaxParallelism(1)
+          .returns(rowTypeInfo)
+          .name(aggOpName)
+          .asInstanceOf[DataStream[Row]]
+      }
     result
   }
 
@@ -180,7 +265,7 @@ class DataStreamOverAggregate(
       }
     }ORDER BY: ${orderingToString(inputType, overWindow.orderKeys.getFieldCollations)}, " +
       s"${if (overWindow.isRows) "ROWS" else "RANGE"}" +
-      s"${windowRange(overWindow)}, " +
+      s"${windowRange(logicWindow, overWindow, getInput)}, " +
       s"select: (${
         aggregationToString(
           inputType,

http://git-wip-us.apache.org/repos/asf/flink/blob/7a9d39fe/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
index 9feec17..0084ee5 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
@@ -61,7 +61,7 @@ object AggregateUtil {
     * @param isPartitioned Flag to indicate whether the input is partitioned or not
     * @return [[org.apache.flink.streaming.api.functions.ProcessFunction]]
     */
-  private[flink] def CreateUnboundedProcessingOverProcessFunction(
+  private[flink] def createUnboundedProcessingOverProcessFunction(
     namedAggregates: Seq[CalcitePair[AggregateCall, String]],
     inputType: RelDataType,
     isPartitioned: Boolean = true): ProcessFunction[Row, Row] = {
@@ -91,6 +91,52 @@ object AggregateUtil {
   }
 
   /**
+    * Create an [[org.apache.flink.streaming.api.functions.ProcessFunction]] for ROWS clause
+    * bounded OVER windows to evaluate the final aggregate value.
+    *
+    * @param namedAggregates List of calls to aggregate functions and their output field names
+    * @param inputType       Input row type
+    * @param inputFields     All input fields
+    * @param precedingOffset The preceding offset
+    * @param isRowTimeType   Flag to indicate whether the time type is row time
+    * @return [[org.apache.flink.streaming.api.functions.ProcessFunction]]
+    */
+  private[flink] def createRowsClauseBoundedOverProcessFunction(
+    namedAggregates: Seq[CalcitePair[AggregateCall, String]],
+    inputType: RelDataType,
+    inputFields: Array[Int],
+    precedingOffset: Long,
+    isRowTimeType: Boolean): ProcessFunction[Row, Row] = {
+
+    val (aggFields, aggregates) =
+      transformToAggregateFunctions(
+        namedAggregates.map(_.getKey),
+        inputType,
+        needRetraction = true)
+
+    val aggregationStateType: RowTypeInfo =
+      createDataSetAggregateBufferDataType(Array(), aggregates, inputType)
+
+    val inputRowType: RowTypeInfo =
+      createDataSetAggregateBufferDataType(inputFields, Array(), inputType)
+
+    val processFunction = if (isRowTimeType) {
+      new RowsClauseBoundedOverProcessFunction(
+        aggregates,
+        aggFields,
+        inputType.getFieldCount,
+        aggregationStateType,
+        inputRowType,
+        precedingOffset
+      )
+    } else {
+      throw TableException(
+        "Bounded partitioned proc-time OVER aggregation is not supported yet.")
+    }
+    processFunction
+  }
+
+  /**
     * Create a [[org.apache.flink.api.common.functions.MapFunction]] that prepares for aggregates.
     * The output of the function contains the grouping keys and the timestamp and the intermediate
     * aggregate values of all aggregate function. The timestamp field is aligned to time window

http://git-wip-us.apache.org/repos/asf/flink/blob/7a9d39fe/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowsClauseBoundedOverProcessFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowsClauseBoundedOverProcessFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowsClauseBoundedOverProcessFunction.scala
new file mode 100644
index 0000000..1678d57
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowsClauseBoundedOverProcessFunction.scala
@@ -0,0 +1,239 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.runtime.aggregate
+
+import java.util
+import java.util.{List => JList}
+
+import org.apache.flink.api.common.state._
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
+import org.apache.flink.api.java.typeutils.{ListTypeInfo, RowTypeInfo}
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.streaming.api.functions.ProcessFunction
+import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+import org.apache.flink.types.Row
+import org.apache.flink.util.{Collector, Preconditions}
+
+/**
+ * Process Function for ROWS clause event-time bounded OVER window
+ *
+ * @param aggregates           the list of all [[AggregateFunction]] used for this aggregation
+ * @param aggFields            the position (in the input Row) of the input value for each aggregate
+ * @param forwardedFieldCount  the count of forwarded fields.
+ * @param aggregationStateType the row type info of aggregation
+ * @param inputRowType         the row type info of input row
+ * @param precedingOffset      the preceding offset
+ */
+class RowsClauseBoundedOverProcessFunction(
+    private val aggregates: Array[AggregateFunction[_]],
+    private val aggFields: Array[Int],
+    private val forwardedFieldCount: Int,
+    private val aggregationStateType: RowTypeInfo,
+    private val inputRowType: RowTypeInfo,
+    private val precedingOffset: Long)
+  extends ProcessFunction[Row, Row] {
+
+  Preconditions.checkNotNull(aggregates)
+  Preconditions.checkNotNull(aggFields)
+  Preconditions.checkArgument(aggregates.length == aggFields.length)
+  Preconditions.checkNotNull(forwardedFieldCount)
+  Preconditions.checkNotNull(aggregationStateType)
+  Preconditions.checkNotNull(precedingOffset)
+
+  private var output: Row = _
+
+  // the state which keeps the last triggering timestamp
+  private var lastTriggeringTsState: ValueState[Long] = _
+
+  // the state which keeps the count of data
+  private var dataCountState: ValueState[Long] = _
+
+  // the state which is used to materialize the accumulator for incremental calculation
+  private var accumulatorState: ValueState[Row] = _
+
+  // the state which keeps all the data that are not expired.
+  // The first element (as the mapState key) of the tuple is the time stamp. Per each time stamp,
+  // the second element of the tuple is a list that contains the entire data of all the rows
+  // belonging to this time stamp.
+  private var dataState: MapState[Long, JList[Row]] = _
+
+  override def open(config: Configuration) {
+
+    output = new Row(forwardedFieldCount + aggregates.length)
+
+    val lastTriggeringTsDescriptor: ValueStateDescriptor[Long] =
+      new ValueStateDescriptor[Long]("lastTriggeringTsState", classOf[Long])
+    lastTriggeringTsState = getRuntimeContext.getState(lastTriggeringTsDescriptor)
+
+    val dataCountStateDescriptor =
+      new ValueStateDescriptor[Long]("dataCountState", classOf[Long])
+    dataCountState = getRuntimeContext.getState(dataCountStateDescriptor)
+
+    val accumulatorStateDescriptor =
+      new ValueStateDescriptor[Row]("accumulatorState", aggregationStateType)
+    accumulatorState = getRuntimeContext.getState(accumulatorStateDescriptor)
+
+    val keyTypeInformation: TypeInformation[Long] =
+      BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]]
+    val valueTypeInformation: TypeInformation[JList[Row]] = new ListTypeInfo[Row](inputRowType)
+
+    val mapStateDescriptor: MapStateDescriptor[Long, JList[Row]] =
+      new MapStateDescriptor[Long, JList[Row]](
+        "dataState",
+        keyTypeInformation,
+        valueTypeInformation)
+
+    dataState = getRuntimeContext.getMapState(mapStateDescriptor)
+
+  }
+
+  override def processElement(
+    input: Row,
+    ctx: ProcessFunction[Row, Row]#Context,
+    out: Collector[Row]): Unit = {
+
+    // triggering timestamp for trigger calculation
+    val triggeringTs = ctx.timestamp
+
+    val lastTriggeringTs = lastTriggeringTsState.value
+    // check if the data is expired; if not, save the data and register an event-time timer
+
+    if (triggeringTs > lastTriggeringTs) {
+      val data = dataState.get(triggeringTs)
+      if (null != data) {
+        data.add(input)
+        dataState.put(triggeringTs, data)
+      } else {
+        val data = new util.ArrayList[Row]
+        data.add(input)
+        dataState.put(triggeringTs, data)
+        // register event time timer
+        ctx.timerService.registerEventTimeTimer(triggeringTs)
+      }
+    }
+  }
+
+  override def onTimer(
+    timestamp: Long,
+    ctx: ProcessFunction[Row, Row]#OnTimerContext,
+    out: Collector[Row]): Unit = {
+
+    // gets all window data from state for the calculation
+    val inputs: JList[Row] = dataState.get(timestamp)
+
+    if (null != inputs) {
+
+      var accumulators = accumulatorState.value
+      var dataCount = dataCountState.value
+
+      var retractList: JList[Row] = null
+      var retractTs: Long = Long.MaxValue
+      var retractCnt: Int = 0
+      var j = 0
+      var i = 0
+
+      while (j < inputs.size) {
+        val input = inputs.get(j)
+
+        // initialize when first run or failover recovery per key
+        if (null == accumulators) {
+          accumulators = new Row(aggregates.length)
+          i = 0
+          while (i < aggregates.length) {
+            accumulators.setField(i, aggregates(i).createAccumulator())
+            i += 1
+          }
+        }
+
+        var retractRow: Row = null
+
+        if (dataCount >= precedingOffset) {
+          if (null == retractList) {
+            // find the smallest timestamp
+            retractTs = Long.MaxValue
+            val dataTimestampIt = dataState.keys.iterator
+            while (dataTimestampIt.hasNext) {
+              val dataTs = dataTimestampIt.next
+              if (dataTs < retractTs) {
+                retractTs = dataTs
+              }
+            }
+            // get the oldest rows to retract them
+            retractList = dataState.get(retractTs)
+          }
+
+          retractRow = retractList.get(retractCnt)
+          retractCnt += 1
+
+          // remove retracted values from state
+          if (retractList.size == retractCnt) {
+            dataState.remove(retractTs)
+            retractList = null
+            retractCnt = 0
+          }
+        } else {
+          dataCount += 1
+        }
+
+        // copy forwarded fields to output row
+        i = 0
+        while (i < forwardedFieldCount) {
+          output.setField(i, input.getField(i))
+          i += 1
+        }
+
+        // retract old row from accumulators
+        if (null != retractRow) {
+          i = 0
+          while (i < aggregates.length) {
+            val accumulator = accumulators.getField(i).asInstanceOf[Accumulator]
+            aggregates(i).retract(accumulator, retractRow.getField(aggFields(i)))
+            i += 1
+          }
+        }
+
+        // accumulate current row and set aggregate in output row
+        i = 0
+        while (i < aggregates.length) {
+          val index = forwardedFieldCount + i
+          val accumulator = accumulators.getField(i).asInstanceOf[Accumulator]
+          aggregates(i).accumulate(accumulator, input.getField(aggFields(i)))
+          output.setField(index, aggregates(i).getValue(accumulator))
+          i += 1
+        }
+        j += 1
+
+        out.collect(output)
+      }
+
+      // update all states
+      if (dataState.contains(retractTs)) {
+        if (retractCnt > 0) {
+          retractList.subList(0, retractCnt).clear()
+          dataState.put(retractTs, retractList)
+        }
+      }
+      dataCountState.update(dataCount)
+      accumulatorState.update(accumulators)
+    }
+
+    lastTriggeringTsState.update(timestamp)
+  }
+}
+
+
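
The process function above keeps, per key, the not-yet-expired rows in MapState keyed by
timestamp; once precedingOffset rows have been accumulated (precedingOffset is passed in as
the lower boundary plus one, i.e. the full window size), every further row retracts the
oldest buffered row from the accumulators. A minimal, Flink-independent sketch of that
accumulate/retract behaviour for SUM over ROWS BETWEEN 2 PRECEDING AND CURRENT ROW; the
object name and value sequence are illustrative assumptions:

    object RowsWindowSketch {
      def main(args: Array[String]): Unit = {
        val precedingRows = 2                  // the x in "ROWS BETWEEN x PRECEDING AND CURRENT ROW"
        val values = Seq(1L, 2L, 3L, 4L, 5L)   // values of one partition, in event-time order

        var window = List.empty[Long]          // rows currently covered by the OVER window
        var sum = 0L                           // incrementally maintained accumulator

        values.foreach { v =>
          if (window.size > precedingRows) {
            // more than x rows buffered: retract the oldest row before accumulating the new one
            sum -= window.head
            window = window.tail
          }
          window = window :+ v
          sum += v
          println(s"value=$v windowSum=$sum")  // prints sums 1, 3, 6, 9, 12
        }
      }
    }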

http://git-wip-us.apache.org/repos/asf/flink/blob/7a9d39fe/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala
index d5a140a..19350a7 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala
@@ -19,14 +19,18 @@
 package org.apache.flink.table.api.scala.stream.sql
 
 import org.apache.flink.api.scala._
+import org.apache.flink.streaming.api.functions.source.SourceFunction
+import org.apache.flink.table.api.scala.stream.sql.SqlITCase.EventTimeSourceFunction
 import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
+import org.apache.flink.streaming.api.watermark.Watermark
 import org.apache.flink.table.api.{TableEnvironment, TableException}
 import org.apache.flink.table.api.scala._
-import org.apache.flink.table.api.scala.stream.utils.{StreamingWithStateTestBase, StreamITCase,
-StreamTestData}
+import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData, StreamingWithStateTestBase}
 import org.apache.flink.types.Row
 import org.junit.Assert._
 import org.junit._
+import org.apache.flink.streaming.api.TimeCharacteristic
+import org.apache.flink.streaming.api.functions.source.SourceFunction.SourceContext
 
 import scala.collection.mutable
 
@@ -293,6 +297,120 @@ class SqlITCase extends StreamingWithStateTestBase {
     assertEquals(expected.sorted, StreamITCase.testResults.sorted)
   }
 
+  @Test
+  def testBoundPartitionedEventTimeWindowWithRow(): Unit = {
+    val data = Seq(
+      Left((1L, (1L, 1, "Hello"))),
+      Left((2L, (2L, 2, "Hello"))),
+      Left((1L, (1L, 1, "Hello"))),
+      Left((2L, (2L, 2, "Hello"))),
+      Left((2L, (2L, 2, "Hello"))),
+      Left((1L, (1L, 1, "Hello"))),
+      Left((3L, (7L, 7, "Hello World"))),
+      Left((1L, (7L, 7, "Hello World"))),
+      Left((1L, (7L, 7, "Hello World"))),
+      Right(2L),
+      Left((3L, (3L, 3, "Hello"))),
+      Left((4L, (4L, 4, "Hello"))),
+      Left((5L, (5L, 5, "Hello"))),
+      Left((6L, (6L, 6, "Hello"))),
+      Left((20L, (20L, 20, "Hello World"))),
+      Right(6L),
+      Left((8L, (8L, 8, "Hello World"))),
+      Left((7L, (7L, 7, "Hello World"))),
+      Right(20L))
+
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
+    env.setStateBackend(getStateBackend)
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    StreamITCase.clear
+
+    val t1 = env
+      .addSource[(Long, Int, String)](new EventTimeSourceFunction[(Long, Int, String)](data))
+      .toTable(tEnv).as('a, 'b, 'c)
+
+    tEnv.registerTable("T1", t1)
+
+    val sqlQuery = "SELECT " +
+      "c, a, " +
+      "count(a) OVER (PARTITION BY c ORDER BY RowTime() ROWS BETWEEN 2 
preceding AND CURRENT ROW)" +
+      ", sum(a) OVER (PARTITION BY c ORDER BY RowTime() ROWS BETWEEN 2 
preceding AND CURRENT ROW)" +
+      " from T1"
+
+    val result = tEnv.sql(sqlQuery).toDataStream[Row]
+    result.addSink(new StreamITCase.StringSink)
+    env.execute()
+
+    val expected = mutable.MutableList(
+      "Hello,1,1,1", "Hello,1,2,2", "Hello,1,3,3",
+      "Hello,2,3,4", "Hello,2,3,5","Hello,2,3,6",
+      "Hello,3,3,7", "Hello,4,3,9", "Hello,5,3,12",
+      "Hello,6,3,15",
+      "Hello World,7,1,7", "Hello World,7,2,14", "Hello World,7,3,21",
+      "Hello World,7,3,21", "Hello World,8,3,22", "Hello World,20,3,35")
+    assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+  }
+
+  @Test
+  def testBoundNonPartitionedEventTimeWindowWithRow(): Unit = {
+
+    val data = Seq(
+      Left((2L, (2L, 2, "Hello"))),
+      Left((2L, (2L, 2, "Hello"))),
+      Left((1L, (1L, 1, "Hello"))),
+      Left((1L, (1L, 1, "Hello"))),
+      Left((2L, (2L, 2, "Hello"))),
+      Left((1L, (1L, 1, "Hello"))),
+      Left((20L, (20L, 20, "Hello World"))), // early row
+      Right(3L),
+      Left((2L, (2L, 2, "Hello"))), // late row
+      Left((3L, (3L, 3, "Hello"))),
+      Left((4L, (4L, 4, "Hello"))),
+      Left((5L, (5L, 5, "Hello"))),
+      Left((6L, (6L, 6, "Hello"))),
+      Left((7L, (7L, 7, "Hello World"))),
+      Right(7L),
+      Left((9L, (9L, 9, "Hello World"))),
+      Left((8L, (8L, 8, "Hello World"))),
+      Left((8L, (8L, 8, "Hello World"))),
+      Right(20L))
+
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
+    env.setStateBackend(getStateBackend)
+    env.setParallelism(1)
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    StreamITCase.clear
+
+    val t1 = env
+      .addSource[(Long, Int, String)](new EventTimeSourceFunction[(Long, Int, String)](data))
+      .toTable(tEnv).as('a, 'b, 'c)
+
+    tEnv.registerTable("T1", t1)
+
+    val sqlQuery = "SELECT " +
+      "c, a, " +
+      "count(a) OVER (ORDER BY RowTime() ROWS BETWEEN 2 preceding AND CURRENT 
ROW)," +
+      "sum(a) OVER (ORDER BY RowTime() ROWS BETWEEN 2 preceding AND CURRENT 
ROW)" +
+      "from T1"
+
+    val result = tEnv.sql(sqlQuery).toDataStream[Row]
+    result.addSink(new StreamITCase.StringSink)
+    env.execute()
+
+    val expected = mutable.MutableList(
+      "Hello,1,1,1", "Hello,1,2,2", "Hello,1,3,3",
+      "Hello,2,3,4", "Hello,2,3,5", "Hello,2,3,6",
+      "Hello,3,3,7",
+      "Hello,4,3,9", "Hello,5,3,12",
+      "Hello,6,3,15", "Hello World,7,3,18",
+      "Hello World,8,3,21", "Hello World,8,3,23",
+      "Hello World,9,3,25",
+      "Hello World,20,3,37")
+    assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+  }
+
   /**
     *  All aggregates must be computed on the same window.
     */
@@ -317,4 +435,21 @@ class SqlITCase extends StreamingWithStateTestBase {
     result.addSink(new StreamITCase.StringSink)
     env.execute()
   }
+
+}
+
+object SqlITCase {
+
+  class EventTimeSourceFunction[T](
+      dataWithTimestampList: Seq[Either[(Long, T), Long]]) extends SourceFunction[T] {
+    override def run(ctx: SourceContext[T]): Unit = {
+      dataWithTimestampList.foreach {
+        case Left(t) => ctx.collectWithTimestamp(t._2, t._1)
+        case Right(w) => ctx.emitWatermark(new Watermark(w))
+      }
+    }
+
+    override def cancel(): Unit = ???
+  }
+
 }
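
In the tests above, the EventTimeSourceFunction helper (defined at the end of this file)
encodes its input as Either values: Left((ts, record)) is collected with event-time
timestamp ts via collectWithTimestamp, and Right(wm) emits a Watermark(wm), which is what
eventually fires the OVER window's event-time timers. A minimal illustrative sequence
(the concrete values are assumptions, not taken from the tests):

    val data: Seq[Either[(Long, (Long, Int, String)), Long]] = Seq(
      Left((1L, (1L, 1, "Hello"))),  // record (a = 1L, b = 1, c = "Hello") with timestamp 1
      Left((2L, (2L, 2, "Hello"))),  // record with timestamp 2
      Right(2L))                     // watermark 2: fires event-time timers up to timestamp 2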

http://git-wip-us.apache.org/repos/asf/flink/blob/7a9d39fe/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala
index a25e59c..9a425b3 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala
@@ -239,4 +239,59 @@ class WindowAggregateTest extends TableTestBase {
       )
     streamUtil.verifySql(sql, expected)
   }
+
+  @Test
+  def testBoundPartitionedRowTimeWindowWithRow() = {
+    val sql = "SELECT " +
+        "c, " +
+        "count(a) OVER (PARTITION BY c ORDER BY RowTime() ROWS BETWEEN 5 
preceding AND " +
+        "CURRENT ROW) as cnt1 " +
+        "from MyTable"
+
+    val expected =
+      unaryNode(
+        "DataStreamCalc",
+        unaryNode(
+          "DataStreamOverAggregate",
+          unaryNode(
+            "DataStreamCalc",
+            streamTableNode(0),
+            term("select", "a", "c", "ROWTIME() AS $2")
+          ),
+          term("partitionBy", "c"),
+          term("orderBy", "ROWTIME"),
+          term("rows", "BETWEEN 5 PRECEDING AND CURRENT ROW"),
+          term("select", "a", "c", "ROWTIME", "COUNT(a) AS w0$o0")
+        ),
+        term("select", "c", "w0$o0 AS $1")
+      )
+    streamUtil.verifySql(sql, expected)
+  }
+
+  @Test
+  def testBoundNonPartitionedRowTimeWindowWithRow() = {
+    val sql = "SELECT " +
+        "c, " +
+        "count(a) OVER (ORDER BY RowTime() ROWS BETWEEN 5 preceding AND " +
+        "CURRENT ROW) as cnt1 " +
+        "from MyTable"
+
+    val expected =
+      unaryNode(
+        "DataStreamCalc",
+        unaryNode(
+          "DataStreamOverAggregate",
+          unaryNode(
+            "DataStreamCalc",
+            streamTableNode(0),
+            term("select", "a", "c", "ROWTIME() AS $2")
+          ),
+          term("orderBy", "ROWTIME"),
+          term("rows", "BETWEEN 5 PRECEDING AND CURRENT ROW"),
+          term("select", "a", "c", "ROWTIME", "COUNT(a) AS w0$o0")
+        ),
+        term("select", "c", "w0$o0 AS $1")
+      )
+    streamUtil.verifySql(sql, expected)
+  }
 }
