This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch FLINK-19449
in repository https://gitbox.apache.org/repos/asf/flink.git

commit d9012f8732101a8f5e341367653e10ce9d492605
Author: JingsongLi <lzljs3620...@aliyun.com>
AuthorDate: Fri Apr 23 17:02:56 2021 +0800

    [FLINK-19449][table-planner] LEAD/LAG cannot work correctly in streaming mode
---
 docs/data/sql_functions.yml                        |   6 +-
 .../functions/aggfunctions/LagAggFunction.java     | 159 +++++++++++++++
 .../stream/StreamExecGlobalWindowAggregate.java    |   4 +-
 .../stream/StreamExecLocalWindowAggregate.java     |   2 +-
 .../exec/stream/StreamExecWindowAggregate.java     |   2 +-
 .../plan/metadata/FlinkRelMdColumnInterval.scala   |  25 ++-
 .../StreamPhysicalGlobalWindowAggregate.scala      |   2 +-
 .../StreamPhysicalLocalWindowAggregate.scala       |   2 +-
 .../stream/StreamPhysicalWindowAggregate.scala     |   2 +-
 .../planner/plan/utils/AggFunctionFactory.scala    |  28 ++-
 .../table/planner/plan/utils/AggregateUtil.scala   |  27 ++-
 .../functions/aggfunctions/LagAggFunctionTest.java |  62 ++++++
 .../plan/metadata/FlinkRelMdHandlerTestBase.scala  |   9 +-
 .../runtime/stream/sql/OverAggregateITCase.scala   |  68 +++++++
 .../runtime/typeutils/LinkedListSerializer.java    | 213 +++++++++++++++++++++
 15 files changed, 576 insertions(+), 35 deletions(-)

diff --git a/docs/data/sql_functions.yml b/docs/data/sql_functions.yml
index 6b7caa15..51df9d1 100644
--- a/docs/data/sql_functions.yml
+++ b/docs/data/sql_functions.yml
@@ -674,10 +674,10 @@ aggregate:
  - sql: ROW_NUMBER()
    description: Assigns a unique, sequential number to each row, starting with one, according to the ordering of rows within the window partition. ROW_NUMBER and RANK are similar. ROW_NUMBER numbers all rows sequentially (for example 1, 2, 3, 4, 5). RANK provides the same numeric value for ties (for example 1, 2, 2, 4, 5).
   - sql: LEAD(expression [, offset] [, default])
-    description: Returns the value of expression at the offsetth row before the current row in the window. The default value of offset is 1 and the default value of default is NULL.
-  - sql: LAG(expression [, offset] [, default])
     description: Returns the value of expression at the offsetth row after the current row in the window. The default value of offset is 1 and the default value of default is NULL.
-  - sql: FIRST_VALUE(expression) 
+  - sql: LAG(expression [, offset] [, default])
+    description: Returns the value of expression at the offsetth row before the current row in the window. The default value of offset is 1 and the default value of default is NULL.
+  - sql: FIRST_VALUE(expression)
     description: Returns the first value in an ordered set of values.
   - sql: LAST_VALUE(expression)
     description: Returns the last value in an ordered set of values.
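Usage sketch (not part of the patch): with this change, LAG becomes usable in streaming OVER windows, while LEAD still throws in stream mode, as the AggFunctionFactory change below shows. A minimal sketch, assuming a streaming TableEnvironment and a table T1 with columns a, b and an event-time attribute rowtime (the same shape OverAggregateITCase below registers):

    import org.apache.flink.table.api.EnvironmentSettings;
    import org.apache.flink.table.api.TableEnvironment;

    public class StreamingLagSketch {
        public static void main(String[] args) {
            TableEnvironment tEnv = TableEnvironment.create(
                    EnvironmentSettings.newInstance().inStreamingMode().build());
            // T1 is assumed to be registered elsewhere with columns a, b
            // and an event-time attribute `rowtime`.
            tEnv.sqlQuery(
                    "SELECT a, b, "
                            + "LAG(b) OVER (PARTITION BY a ORDER BY rowtime), "
                            + "LAG(b, 2, CAST(10086 AS BIGINT)) OVER (PARTITION BY a ORDER BY rowtime) "
                            + "FROM T1");
        }
    }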
diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/LagAggFunction.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/LagAggFunction.java
new file mode 100644
index 0000000..3333865
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/LagAggFunction.java
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.functions.aggfunctions;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.TableException;
+import org.apache.flink.table.functions.AggregateFunction;
+import org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction;
+import org.apache.flink.table.runtime.typeutils.InternalSerializers;
+import org.apache.flink.table.runtime.typeutils.LinkedListSerializer;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.LogicalTypeRoot;
+import org.apache.flink.table.types.utils.DataTypeUtils;
+
+import java.util.Arrays;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Objects;
+
+/** Lag {@link AggregateFunction}. */
+public class LagAggFunction<T> extends BuiltInAggregateFunction<T, LagAggFunction.LagAcc<T>> {
+
+    private final transient DataType[] valueDataTypes;
+
+    @SuppressWarnings("unchecked")
+    public LagAggFunction(LogicalType[] valueTypes) {
+        this.valueDataTypes =
+                Arrays.stream(valueTypes)
+                        .map(DataTypeUtils::toInternalDataType)
+                        .toArray(DataType[]::new);
+        if (valueDataTypes.length == 3
+                && valueDataTypes[2].getLogicalType().getTypeRoot() != LogicalTypeRoot.NULL) {
+            if (valueDataTypes[0].getConversionClass() != valueDataTypes[2].getConversionClass()) {
+                throw new TableException(
+                        String.format(
+                                "Please explicitly cast default value %s to %s.",
+                                valueDataTypes[2], valueDataTypes[0]));
+            }
+        }
+    }
+
+    // --------------------------------------------------------------------------------------------
+    // Planning
+    // --------------------------------------------------------------------------------------------
+
+    @Override
+    public List<DataType> getArgumentDataTypes() {
+        return Arrays.asList(valueDataTypes);
+    }
+
+    @Override
+    public DataType getAccumulatorDataType() {
+        return DataTypes.STRUCTURED(
+                LagAcc.class,
+                DataTypes.FIELD("offset", DataTypes.INT()),
+                DataTypes.FIELD("defaultValue", valueDataTypes[0]),
+                DataTypes.FIELD("buffer", getLinkedListType()));
+    }
+
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    private DataType getLinkedListType() {
+        TypeSerializer<T> serializer =
+                InternalSerializers.create(getOutputDataType().getLogicalType());
+        return DataTypes.RAW(
+                LinkedList.class, (TypeSerializer) new LinkedListSerializer<>(serializer));
+    }
+
+    @Override
+    public DataType getOutputDataType() {
+        return valueDataTypes[0];
+    }
+
+    // --------------------------------------------------------------------------------------------
+    // Runtime
+    // --------------------------------------------------------------------------------------------
+
+    public void accumulate(LagAcc<T> acc, T value) throws Exception {
+        acc.buffer.add(value);
+        while (acc.buffer.size() > acc.offset + 1) {
+            acc.buffer.removeFirst();
+        }
+    }
+
+    public void accumulate(LagAcc<T> acc, T value, int offset) throws Exception {
+        acc.offset = offset;
+        accumulate(acc, value);
+    }
+
+    public void accumulate(LagAcc<T> acc, T value, int offset, T defaultValue) throws Exception {
+        acc.defaultValue = defaultValue;
+        accumulate(acc, value, offset);
+    }
+
+    public void resetAccumulator(LagAcc<T> acc) throws Exception {
+        acc.offset = 1;
+        acc.defaultValue = null;
+        acc.buffer.clear();
+    }
+
+    @Override
+    public T getValue(LagAcc<T> acc) {
+        if (acc.buffer.size() < acc.offset + 1) {
+            return acc.defaultValue;
+        } else if (acc.buffer.size() == acc.offset + 1) {
+            return acc.buffer.getFirst();
+        } else {
+            throw new TableException("Too more elements: " + acc);
+        }
+    }
+
+    @Override
+    public LagAcc<T> createAccumulator() {
+        return new LagAcc<>();
+    }
+
+    /** Accumulator for LAG. */
+    public static class LagAcc<T> {
+        public int offset = 1;
+        public T defaultValue = null;
+        public LinkedList<T> buffer = new LinkedList<>();
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) {
+                return true;
+            }
+            if (o == null || getClass() != o.getClass()) {
+                return false;
+            }
+            LagAcc<?> lagAcc = (LagAcc<?>) o;
+            return offset == lagAcc.offset
+                    && Objects.equals(defaultValue, lagAcc.defaultValue)
+                    && Objects.equals(buffer, lagAcc.buffer);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(offset, defaultValue, buffer);
+        }
+    }
+}
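To make the accumulator above concrete, here is a minimal trace (not part of the patch) for LAG(value, 2): accumulate() caps the buffer at offset + 1 = 3 elements, and getValue() returns the first buffered element once the buffer is full, i.e. the value two rows back; until then it returns the default (null here).

    import org.apache.flink.table.data.StringData;
    import org.apache.flink.table.planner.functions.aggfunctions.LagAggFunction;
    import org.apache.flink.table.types.logical.IntType;
    import org.apache.flink.table.types.logical.LogicalType;
    import org.apache.flink.table.types.logical.VarCharType;

    public class LagAccTrace {
        public static void main(String[] args) throws Exception {
            LagAggFunction<StringData> lag =
                    new LagAggFunction<>(new LogicalType[] {new VarCharType(), new IntType()});
            LagAggFunction.LagAcc<StringData> acc = lag.createAccumulator();
            lag.accumulate(acc, StringData.fromString("a"), 2); // buffer [a]       getValue -> null
            lag.accumulate(acc, StringData.fromString("b"), 2); // buffer [a, b]    getValue -> null
            lag.accumulate(acc, StringData.fromString("c"), 2); // buffer [a, b, c] getValue -> a
            lag.accumulate(acc, StringData.fromString("d"), 2); // buffer [b, c, d] getValue -> b
            System.out.println(lag.getValue(acc)); // prints "b", the value 2 rows before "d"
        }
    }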
diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGlobalWindowAggregate.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGlobalWindowAggregate.java
index 8df6f2a..41ab7a2 100644
--- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGlobalWindowAggregate.java
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGlobalWindowAggregate.java
@@ -145,14 +145,14 @@ public class StreamExecGlobalWindowAggregate extends StreamExecWindowAggregateBa
         final SliceAssigner sliceAssigner = createSliceAssigner(windowing, shiftTimeZone);
 
         final AggregateInfoList localAggInfoList =
-                AggregateUtil.deriveWindowAggregateInfoList(
+                AggregateUtil.deriveStreamWindowAggregateInfoList(
                         localAggInputRowType, // should use original input here
                         JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
                         windowing.getWindow(),
                         false); // isStateBackendDataViews
 
         final AggregateInfoList globalAggInfoList =
-                AggregateUtil.deriveWindowAggregateInfoList(
+                AggregateUtil.deriveStreamWindowAggregateInfoList(
                         localAggInputRowType, // should use original input here
                         JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
                         windowing.getWindow(),
diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecLocalWindowAggregate.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecLocalWindowAggregate.java
index f333255..18f8a8d 100644
--- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecLocalWindowAggregate.java
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecLocalWindowAggregate.java
@@ -122,7 +122,7 @@ public class StreamExecLocalWindowAggregate extends StreamExecWindowAggregateBas
         final SliceAssigner sliceAssigner = createSliceAssigner(windowing, shiftTimeZone);
 
         final AggregateInfoList aggInfoList =
-                AggregateUtil.deriveWindowAggregateInfoList(
+                AggregateUtil.deriveStreamWindowAggregateInfoList(
                         inputRowType,
                         JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
                         windowing.getWindow(),
diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWindowAggregate.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWindowAggregate.java
index 913abee..3229441 100644
--- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWindowAggregate.java
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWindowAggregate.java
@@ -143,7 +143,7 @@ public class StreamExecWindowAggregate extends StreamExecWindowAggregateBase {
         // Hopping window requires additional COUNT(*) to determine whether to register next timer
         // through whether the current fired window is empty, see SliceSharedWindowAggProcessor.
         final AggregateInfoList aggInfoList =
-                AggregateUtil.deriveWindowAggregateInfoList(
+                AggregateUtil.deriveStreamWindowAggregateInfoList(
                         inputRowType,
                         JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
                         windowing.getWindow(),
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala
index 23bd99c..f7c4641 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala
@@ -562,9 +562,10 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
       def getAggCallFromLocalAgg(
           index: Int,
           aggCalls: Seq[AggregateCall],
-          inputType: RelDataType): AggregateCall = {
+          inputType: RelDataType,
+          isBounded: Boolean): AggregateCall = {
         val outputIndexToAggCallIndexMap = AggregateUtil.getOutputIndexToAggCallIndexMap(
-          aggCalls, inputType)
+          aggCalls, inputType, isBounded)
         if (outputIndexToAggCallIndexMap.containsKey(index)) {
           val realIndex = outputIndexToAggCallIndexMap.get(index)
           aggCalls(realIndex)
@@ -576,9 +577,10 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
       def getAggCallIndexInLocalAgg(
           index: Int,
           globalAggCalls: Seq[AggregateCall],
-          inputRowType: RelDataType): Integer = {
+          inputRowType: RelDataType,
+          isBounded: Boolean): Integer = {
         val outputIndexToAggCallIndexMap = AggregateUtil.getOutputIndexToAggCallIndexMap(
-          globalAggCalls, inputRowType)
+          globalAggCalls, inputRowType, isBounded)
 
         outputIndexToAggCallIndexMap.foreach {
           case (k, v) => if (v == index) {
@@ -600,34 +602,37 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
           case agg: StreamPhysicalGlobalGroupAggregate
             if agg.aggCalls.length > aggCallIndex =>
             val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
-              aggCallIndex, agg.aggCalls, agg.localAggInputRowType)
+              aggCallIndex, agg.aggCalls, agg.localAggInputRowType, isBounded = false)
             if (aggCallIndexInLocalAgg != null) {
               return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg)
             } else {
               null
             }
           case agg: StreamPhysicalLocalGroupAggregate =>
-            getAggCallFromLocalAgg(aggCallIndex, agg.aggCalls, agg.getInput.getRowType)
+            getAggCallFromLocalAgg(
+              aggCallIndex, agg.aggCalls, agg.getInput.getRowType, isBounded = false)
           case agg: StreamPhysicalIncrementalGroupAggregate
             if agg.partialAggCalls.length > aggCallIndex =>
             agg.partialAggCalls(aggCallIndex)
           case agg: StreamPhysicalGroupWindowAggregate if agg.aggCalls.length > aggCallIndex =>
             agg.aggCalls(aggCallIndex)
           case agg: BatchPhysicalLocalHashAggregate =>
-            getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList, agg.getInput.getRowType)
+            getAggCallFromLocalAgg(
+              aggCallIndex, agg.getAggCallList, agg.getInput.getRowType, isBounded = true)
           case agg: BatchPhysicalHashAggregate if agg.isMerge =>
             val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
-              aggCallIndex, agg.getAggCallList, agg.aggInputRowType)
+              aggCallIndex, agg.getAggCallList, agg.aggInputRowType, isBounded = true)
             if (aggCallIndexInLocalAgg != null) {
               return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg)
             } else {
               null
             }
           case agg: BatchPhysicalLocalSortAggregate =>
-            getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList, agg.getInput.getRowType)
+            getAggCallFromLocalAgg(
+              aggCallIndex, agg.getAggCallList, agg.getInput.getRowType, isBounded = true)
           case agg: BatchPhysicalSortAggregate if agg.isMerge =>
             val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
-              aggCallIndex, agg.getAggCallList, agg.aggInputRowType)
+              aggCallIndex, agg.getAggCallList, agg.aggInputRowType, isBounded = true)
             if (aggCallIndexInLocalAgg != null) {
               return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg)
             } else {
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGlobalWindowAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGlobalWindowAggregate.scala
index bef2589..bdace61 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGlobalWindowAggregate.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGlobalWindowAggregate.scala
@@ -63,7 +63,7 @@ class StreamPhysicalGlobalWindowAggregate(
   extends SingleRel(cluster, traitSet, inputRel)
   with StreamPhysicalRel {
 
-  private lazy val aggInfoList = AggregateUtil.deriveWindowAggregateInfoList(
+  private lazy val aggInfoList = AggregateUtil.deriveStreamWindowAggregateInfoList(
     FlinkTypeFactory.toLogicalRowType(inputRowTypeOfLocalAgg),
     aggCalls,
     windowing.getWindow,
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalLocalWindowAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalLocalWindowAggregate.scala
index 518ccda..2823aab 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalLocalWindowAggregate.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalLocalWindowAggregate.scala
@@ -56,7 +56,7 @@ class StreamPhysicalLocalWindowAggregate(
   extends SingleRel(cluster, traitSet, inputRel)
   with StreamPhysicalRel {
 
-  private lazy val aggInfoList = AggregateUtil.deriveWindowAggregateInfoList(
+  private lazy val aggInfoList = AggregateUtil.deriveStreamWindowAggregateInfoList(
     FlinkTypeFactory.toLogicalRowType(inputRel.getRowType),
     aggCalls,
     windowing.getWindow,
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalWindowAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalWindowAggregate.scala
index 21a1f50..eaa70e2 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalWindowAggregate.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalWindowAggregate.scala
@@ -56,7 +56,7 @@ class StreamPhysicalWindowAggregate(
   extends SingleRel(cluster, traitSet, inputRel)
   with StreamPhysicalRel {
 
-  lazy val aggInfoList: AggregateInfoList = AggregateUtil.deriveWindowAggregateInfoList(
+  lazy val aggInfoList: AggregateInfoList = AggregateUtil.deriveStreamWindowAggregateInfoList(
     FlinkTypeFactory.toLogicalRowType(inputRel.getRowType),
     aggCalls,
     windowing.getWindow,
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala
index e271a74..4f8021e 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala
@@ -52,7 +52,8 @@ import scala.collection.JavaConversions._
 class AggFunctionFactory(
     inputRowType: RowType,
     orderKeyIndexes: Array[Int],
-    aggCallNeedRetractions: Array[Boolean]) {
+    aggCallNeedRetractions: Array[Boolean],
+    isBounded: Boolean) {
 
   /**
     * The entry point to create an aggregate function from the given [[AggregateCall]].
@@ -94,8 +95,12 @@ class AggFunctionFactory(
       case a: SqlRankFunction if a.getKind == SqlKind.DENSE_RANK =>
         createDenseRankAggFunction(argTypes)
 
-      case _: SqlLeadLagAggFunction =>
-        createLeadLagAggFunction(argTypes, index)
+      case func: SqlLeadLagAggFunction =>
+        if (isBounded) {
+          createBatchLeadLagAggFunction(argTypes, index)
+        } else {
+          createStreamLeadLagAggFunction(func, argTypes, index)
+        }
 
       case _: SqlSingleValueAggFunction =>
         createSingleValueAggFunction(argTypes)
@@ -328,7 +333,22 @@ class AggFunctionFactory(
     }
   }
 
-  private def createLeadLagAggFunction(
+  private def createStreamLeadLagAggFunction(
+      func: SqlLeadLagAggFunction,
+      argTypes: Array[LogicalType],
+      index: Int): UserDefinedFunction = {
+    if (func.getKind == SqlKind.LEAD) {
+      throw new TableException("LEAD Function is not supported in stream mode.")
+    }
+
+    if (aggCallNeedRetractions(index)) {
+      throw new TableException("LAG Function with retraction is not supported in stream mode.")
+    }
+
+    new LagAggFunction(argTypes)
+  }
+
+  private def createBatchLeadLagAggFunction(
      argTypes: Array[LogicalType], index: Int): UserDefinedFunction = {
     argTypes(0).getTypeRoot match {
       case TINYINT =>
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
index 9bfcdeb..3125238 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
@@ -153,6 +153,7 @@ object AggregateUtil extends Enumeration {
   def getOutputIndexToAggCallIndexMap(
       aggregateCalls: Seq[AggregateCall],
       inputType: RelDataType,
+      isBounded: Boolean,
       orderKeyIndexes: Array[Int] = null): util.Map[Integer, Integer] = {
     val aggInfos = transformToAggregateInfoList(
       FlinkTypeFactory.toLogicalRowType(inputType),
@@ -161,7 +162,8 @@ object AggregateUtil extends Enumeration {
       orderKeyIndexes,
       needInputCount = false,
       isStateBackedDataViews = false,
-      needDistinctInfo = false).aggInfos
+      needDistinctInfo = false,
+      isBounded).aggInfos
 
     val map = new util.HashMap[Integer, Integer]()
     var outputIndex = 0
@@ -248,7 +250,7 @@ object AggregateUtil extends Enumeration {
       isStateBackendDataViews = true)
   }
 
-  def deriveWindowAggregateInfoList(
+  def deriveStreamWindowAggregateInfoList(
       inputRowType: RowType,
       aggCalls: Seq[AggregateCall],
       windowSpec: WindowSpec,
@@ -271,7 +273,8 @@ object AggregateUtil extends Enumeration {
       orderKeyIndexes = null,
       needInputCount,
       isStateBackendDataViews,
-      needDistinctInfo = true)
+      needDistinctInfo = true,
+      isBounded = false)
   }
 
   def transformToBatchAggregateFunctions(
@@ -287,7 +290,8 @@ object AggregateUtil extends Enumeration {
       orderKeyIndexes,
       needInputCount = false,
       isStateBackedDataViews = false,
-      needDistinctInfo = false).aggInfos
+      needDistinctInfo = false,
+      isBounded = true).aggInfos
 
     val aggFields = aggInfos.map(_.argIndexes)
     val bufferTypes = aggInfos.map(_.externalAccTypes)
@@ -315,7 +319,8 @@ object AggregateUtil extends Enumeration {
       orderKeyIndexes,
       needInputCount = false,
       isStateBackedDataViews = false,
-      needDistinctInfo = false)
+      needDistinctInfo = false,
+      isBounded = true)
   }
 
   def transformToStreamAggregateInfoList(
@@ -332,7 +337,8 @@ object AggregateUtil extends Enumeration {
       orderKeyIndexes = null,
       needInputCount,
       isStateBackendDataViews,
-      needDistinctInfo)
+      needDistinctInfo,
+      isBounded = false)
   }
 
   /**
@@ -355,7 +361,8 @@ object AggregateUtil extends Enumeration {
       orderKeyIndexes: Array[Int],
       needInputCount: Boolean,
       isStateBackedDataViews: Boolean,
-      needDistinctInfo: Boolean): AggregateInfoList = {
+      needDistinctInfo: Boolean,
+      isBounded: Boolean): AggregateInfoList = {
 
     // Step-1:
     // if need inputCount, find count1 in the existed aggregate calls first,
@@ -375,7 +382,11 @@ object AggregateUtil extends Enumeration {
 
     // Step-3:
     // create aggregate information
-    val factory = new AggFunctionFactory(inputRowType, orderKeyIndexes, aggCallNeedRetractions)
+    val factory = new AggFunctionFactory(
+      inputRowType,
+      orderKeyIndexes,
+      aggCallNeedRetractions,
+      isBounded)
     val aggInfos = newAggCalls
       .zipWithIndex
       .map { case (call, index) =>
diff --git a/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/LagAggFunctionTest.java b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/LagAggFunctionTest.java
new file mode 100644
index 0000000..e3553d8
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/LagAggFunctionTest.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.functions.aggfunctions;
+
+import org.apache.flink.table.data.StringData;
+import org.apache.flink.table.functions.AggregateFunction;
+import org.apache.flink.table.types.logical.CharType;
+import org.apache.flink.table.types.logical.IntType;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.VarCharType;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.apache.flink.table.data.StringData.fromString;
+
+/** Test for {@link LagAggFunction}. */
+public class LagAggFunctionTest
+        extends AggFunctionTestBase<StringData, LagAggFunction.LagAcc<StringData>> {
+
+    @Override
+    protected List<List<StringData>> getInputValueSets() {
+        return Arrays.asList(
+                Collections.singletonList(fromString("1")),
+                Arrays.asList(fromString("1"), null),
+                Arrays.asList(null, null),
+                Arrays.asList(null, fromString("10")));
+    }
+
+    @Override
+    protected List<StringData> getExpectedResults() {
+        return Arrays.asList(null, fromString("1"), null, null);
+    }
+
+    @Override
+    protected AggregateFunction<StringData, LagAggFunction.LagAcc<StringData>> getAggregator() {
+        return new LagAggFunction<>(
+                new LogicalType[] {new VarCharType(), new IntType(), new CharType()});
+    }
+
+    @Override
+    protected Class<?> getAccClass() {
+        return LagAggFunction.LagAcc.class;
+    }
+}
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala
index 69a2d18..595eb41 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala
@@ -949,7 +949,8 @@ class FlinkRelMdHandlerTestBase {
     val aggFunctionFactory = new AggFunctionFactory(
       FlinkTypeFactory.toLogicalRowType(studentBatchScan.getRowType),
       Array.empty[Int],
-      Array.fill(aggCalls.size())(false))
+      Array.fill(aggCalls.size())(false),
+      false)
     val aggCallToAggFunction = aggCalls.zipWithIndex.map {
       case (call, index) => (call, aggFunctionFactory.createAggFunction(call, index))
     }
@@ -1157,7 +1158,8 @@ class FlinkRelMdHandlerTestBase {
     val aggFunctionFactory = new AggFunctionFactory(
       FlinkTypeFactory.toLogicalRowType(calcOnStudentScan.getRowType),
       Array.empty[Int],
-      Array.fill(aggCalls.size())(false))
+      Array.fill(aggCalls.size())(false),
+      false)
     val aggCallToAggFunction = aggCalls.zipWithIndex.map {
       case (call, index) => (call, aggFunctionFactory.createAggFunction(call, index))
     }
@@ -1324,7 +1326,8 @@ class FlinkRelMdHandlerTestBase {
     val aggFunctionFactory = new AggFunctionFactory(
       FlinkTypeFactory.toLogicalRowType(studentBatchScan.getRowType),
       Array.empty[Int],
-      Array.fill(aggCalls.size())(false))
+      Array.fill(aggCalls.size())(false),
+      false)
     val aggCallToAggFunction = aggCalls.zipWithIndex.map {
       case (call, index) => (call, aggFunctionFactory.createAggFunction(call, index))
     }
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverAggregateITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverAggregateITCase.scala
index 6deb647..e9a8bb3 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverAggregateITCase.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverAggregateITCase.scala
@@ -56,6 +56,74 @@ class OverAggregateITCase(mode: StateBackendMode) extends StreamingWithStateTest
   }
 
   @Test
+  def testLagFunction(): Unit = {
+    val sqlQuery = "SELECT a, b, c, " +
+        "  LAG(b) OVER(PARTITION BY a ORDER BY rowtime)," +
+        "  LAG(b, 2) OVER(PARTITION BY a ORDER BY rowtime)," +
+        "  LAG(b, 2, CAST(10086 AS BIGINT)) OVER(PARTITION BY a ORDER BY 
rowtime)" +
+        "FROM T1"
+
+    val data: Seq[Either[(Long, (Int, Long, String)), Long]] = Seq(
+      Left(14000001L, (1, 1L, "Hi")),
+      Left(14000005L, (1, 2L, "Hi")),
+      Left(14000002L, (1, 3L, "Hello")),
+      Left(14000003L, (1, 4L, "Hello")),
+      Left(14000003L, (1, 5L, "Hello")),
+      Right(14000020L),
+      Left(14000021L, (1, 6L, "Hello world")),
+      Left(14000022L, (1, 7L, "Hello world")),
+      Right(14000030L))
+
+    val source = failingDataSource(data)
+    val t1 = source.transform("TimeAssigner", new EventTimeProcessOperator[(Int, Long, String)])
+        .setParallelism(source.parallelism)
+        .toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime)
+
+    tEnv.registerTable("T1", t1)
+
+    val sink = new TestingAppendSink
+    tEnv.sqlQuery(sqlQuery).toAppendStream[Row].addSink(sink)
+    env.execute()
+
+    val expected = List(
+      s"1,1,Hi,null,null,10086",
+      s"1,3,Hello,1,null,10086",
+      s"1,4,Hello,4,3,3",
+      s"1,5,Hello,4,3,3",
+      s"1,2,Hi,5,4,4",
+      s"1,6,Hello world,2,5,5",
+      s"1,7,Hello world,6,2,2")
+    assertEquals(expected.sorted, sink.getAppendResults.sorted)
+  }
+
+  @Test
+  def testLeadFunction(): Unit = {
+    expectedException.expectMessage("LEAD Function is not supported in stream mode")
+
+    val sqlQuery = "SELECT a, b, c, " +
+        "  LEAD(b) OVER(PARTITION BY a ORDER BY rowtime)," +
+        "  LEAD(b, 2) OVER(PARTITION BY a ORDER BY rowtime)," +
+        "  LEAD(b, 2, CAST(10086 AS BIGINT)) OVER(PARTITION BY a ORDER BY 
rowtime)" +
+        "FROM T1"
+
+    val data: Seq[Either[(Long, (Int, Long, String)), Long]] = Seq(
+      Left(14000001L, (1, 1L, "Hi")),
+      Left(14000003L, (1, 5L, "Hello")),
+      Right(14000020L),
+      Left(14000021L, (1, 6L, "Hello world")),
+      Left(14000022L, (1, 7L, "Hello world")),
+      Right(14000030L))
+    val source = failingDataSource(data)
+    val t1 = source.transform("TimeAssigner", new EventTimeProcessOperator[(Int, Long, String)])
+        .setParallelism(source.parallelism)
+        .toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime)
+    tEnv.registerTable("T1", t1)
+    val sink = new TestingAppendSink
+    tEnv.sqlQuery(sqlQuery).toAppendStream[Row].addSink(sink)
+    env.execute()
+  }
+
+  @Test
   def testRowNumberOnOver(): Unit = {
     val t = failingDataSource(TestData.tupleData5)
       .toTable(tEnv, 'a, 'b, 'c, 'd, 'e, 'proctime.proctime)
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/typeutils/LinkedListSerializer.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/typeutils/LinkedListSerializer.java
new file mode 100644
index 0000000..2025735
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/typeutils/LinkedListSerializer.java
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.typeutils;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.CompositeTypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.base.ListSerializer;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+
+import java.io.IOException;
+import java.util.LinkedList;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/**
+ * A serializer for {@link LinkedList}. The serializer relies on an element serializer for the
+ * serialization of the list's elements.
+ *
+ * @param <T> The type of element in the list.
+ */
+@Internal
+public final class LinkedListSerializer<T> extends TypeSerializer<LinkedList<T>> {
+
+    private static final long serialVersionUID = 1L;
+
+    /** The serializer for the elements of the list. */
+    private final TypeSerializer<T> elementSerializer;
+
+    /**
+     * Creates a list serializer that uses the given serializer to serialize the list's elements.
+     *
+     * @param elementSerializer The serializer for the elements of the list
+     */
+    public LinkedListSerializer(TypeSerializer<T> elementSerializer) {
+        this.elementSerializer = checkNotNull(elementSerializer);
+    }
+
+    // ------------------------------------------------------------------------
+    //  ListSerializer specific properties
+    // ------------------------------------------------------------------------
+
+    /**
+     * Gets the serializer for the elements of the list.
+     *
+     * @return The serializer for the elements of the list
+     */
+    public TypeSerializer<T> getElementSerializer() {
+        return elementSerializer;
+    }
+
+    // ------------------------------------------------------------------------
+    //  Type Serializer implementation
+    // ------------------------------------------------------------------------
+
+    @Override
+    public boolean isImmutableType() {
+        return false;
+    }
+
+    @Override
+    public TypeSerializer<LinkedList<T>> duplicate() {
+        TypeSerializer<T> duplicateElement = elementSerializer.duplicate();
+        return duplicateElement == elementSerializer
+                ? this
+                : new LinkedListSerializer<>(duplicateElement);
+    }
+
+    @Override
+    public LinkedList<T> createInstance() {
+        return new LinkedList<>();
+    }
+
+    @Override
+    public LinkedList<T> copy(LinkedList<T> from) {
+        LinkedList<T> newList = new LinkedList<>();
+
+        // We iterate here rather than accessing by index, because we cannot be sure that
+        // the given list supports RandomAccess.
+        // The Iterator should be stack allocated on new JVMs (due to escape analysis)
+        for (T element : from) {
+            newList.add(elementSerializer.copy(element));
+        }
+        return newList;
+    }
+
+    @Override
+    public LinkedList<T> copy(LinkedList<T> from, LinkedList<T> reuse) {
+        return copy(from);
+    }
+
+    @Override
+    public int getLength() {
+        return -1; // var length
+    }
+
+    @Override
+    public void serialize(LinkedList<T> list, DataOutputView target) throws IOException {
+        final int size = list.size();
+        target.writeInt(size);
+
+        // We iterate here rather than accessing by index, because we cannot be sure that
+        // the given list supports RandomAccess.
+        // The Iterator should be stack allocated on new JVMs (due to escape analysis)
+        for (T element : list) {
+            elementSerializer.serialize(element, target);
+        }
+    }
+
+    @Override
+    public LinkedList<T> deserialize(DataInputView source) throws IOException {
+        final int size = source.readInt();
+        final LinkedList<T> list = new LinkedList<>();
+        for (int i = 0; i < size; i++) {
+            list.add(elementSerializer.deserialize(source));
+        }
+        return list;
+    }
+
+    @Override
+    public LinkedList<T> deserialize(LinkedList<T> reuse, DataInputView source) throws IOException {
+        return deserialize(source);
+    }
+
+    @Override
+    public void copy(DataInputView source, DataOutputView target) throws IOException {
+        // copy number of elements
+        final int num = source.readInt();
+        target.writeInt(num);
+        for (int i = 0; i < num; i++) {
+            elementSerializer.copy(source, target);
+        }
+    }
+
+    // --------------------------------------------------------------------
+
+    @Override
+    public boolean equals(Object obj) {
+        return obj == this
+                || (obj != null
+                        && obj.getClass() == getClass()
+                        && elementSerializer.equals(
+                                ((LinkedListSerializer<?>) obj).elementSerializer));
+    }
+
+    @Override
+    public int hashCode() {
+        return elementSerializer.hashCode();
+    }
+
+    // --------------------------------------------------------------------------------------------
+    // Serializer configuration snapshot & compatibility
+    // --------------------------------------------------------------------------------------------
+
+    @Override
+    public TypeSerializerSnapshot<LinkedList<T>> snapshotConfiguration() {
+        return new LinkedListSerializerSnapshot<>(this);
+    }
+
+    /** Snapshot class for the {@link LinkedListSerializer}. */
+    public static class LinkedListSerializerSnapshot<T>
+            extends CompositeTypeSerializerSnapshot<LinkedList<T>, LinkedListSerializer<T>> {
+
+        private static final int CURRENT_VERSION = 1;
+
+        /** Constructor for read instantiation. */
+        public LinkedListSerializerSnapshot() {
+            super(LinkedListSerializer.class);
+        }
+
+        /** Constructor to create the snapshot for writing. */
+        public LinkedListSerializerSnapshot(LinkedListSerializer<T> listSerializer) {
+            super(listSerializer);
+        }
+
+        @Override
+        public int getCurrentOuterSnapshotVersion() {
+            return CURRENT_VERSION;
+        }
+
+        @Override
+        protected LinkedListSerializer<T> createOuterSerializerWithNestedSerializers(
+                TypeSerializer<?>[] nestedSerializers) {
+            @SuppressWarnings("unchecked")
+            TypeSerializer<T> elementSerializer = (TypeSerializer<T>) nestedSerializers[0];
+            return new LinkedListSerializer<>(elementSerializer);
+        }
+
+        @Override
+        protected TypeSerializer<?>[] getNestedSerializers(
+                LinkedListSerializer<T> outerSerializer) {
+            return new TypeSerializer<?>[] {outerSerializer.getElementSerializer()};
+        }
+    }
+}
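A round-trip sketch (not part of the patch) showing the intended use of the serializer, with Flink's in-memory DataOutputSerializer/DataInputDeserializer standing in for a state backend. Note the wire format is simply an int length header followed by the elements, with no null markers, so elements must be non-null unless the element serializer itself handles null:

    import org.apache.flink.api.common.typeutils.base.LongSerializer;
    import org.apache.flink.core.memory.DataInputDeserializer;
    import org.apache.flink.core.memory.DataOutputSerializer;
    import org.apache.flink.table.runtime.typeutils.LinkedListSerializer;

    import java.util.LinkedList;

    public class LinkedListSerializerRoundTrip {
        public static void main(String[] args) throws Exception {
            LinkedListSerializer<Long> serializer =
                    new LinkedListSerializer<>(LongSerializer.INSTANCE);

            LinkedList<Long> list = new LinkedList<>();
            list.add(1L);
            list.add(2L);

            // Write: int size header, then each element via the element serializer.
            DataOutputSerializer out = new DataOutputSerializer(64);
            serializer.serialize(list, out);

            // Read the bytes back and compare.
            DataInputDeserializer in = new DataInputDeserializer(out.getCopyOfBuffer());
            LinkedList<Long> copy = serializer.deserialize(in);
            System.out.println(list.equals(copy)); // true
        }
    }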
