This is an automated email from the ASF dual-hosted git repository. lzljs3620320 pushed a commit to branch FLINK-19449 in repository https://gitbox.apache.org/repos/asf/flink.git
commit d9012f8732101a8f5e341367653e10ce9d492605 Author: JingsongLi <lzljs3620...@aliyun.com> AuthorDate: Fri Apr 23 17:02:56 2021 +0800 [FLINK-19449][table-planner] LEAD/LAG cannot work correctly in streaming mode --- docs/data/sql_functions.yml | 6 +- .../functions/aggfunctions/LagAggFunction.java | 159 +++++++++++++++ .../stream/StreamExecGlobalWindowAggregate.java | 4 +- .../stream/StreamExecLocalWindowAggregate.java | 2 +- .../exec/stream/StreamExecWindowAggregate.java | 2 +- .../plan/metadata/FlinkRelMdColumnInterval.scala | 25 ++- .../StreamPhysicalGlobalWindowAggregate.scala | 2 +- .../StreamPhysicalLocalWindowAggregate.scala | 2 +- .../stream/StreamPhysicalWindowAggregate.scala | 2 +- .../planner/plan/utils/AggFunctionFactory.scala | 28 ++- .../table/planner/plan/utils/AggregateUtil.scala | 27 ++- .../functions/aggfunctions/LagAggFunctionTest.java | 62 ++++++ .../plan/metadata/FlinkRelMdHandlerTestBase.scala | 9 +- .../runtime/stream/sql/OverAggregateITCase.scala | 68 +++++++ .../runtime/typeutils/LinkedListSerializer.java | 213 +++++++++++++++++++++ 15 files changed, 576 insertions(+), 35 deletions(-) diff --git a/docs/data/sql_functions.yml b/docs/data/sql_functions.yml index 6b7caa15..51df9d1 100644 --- a/docs/data/sql_functions.yml +++ b/docs/data/sql_functions.yml @@ -674,10 +674,10 @@ aggregate: - sql: ROW_NUMER() description: Assigns a unique, sequential number to each row, starting with one, according to the ordering of rows within the window partition. ROW_NUMBER and RANK are similar. ROW_NUMBER numbers all rows sequentially (for example 1, 2, 3, 4, 5). RANK provides the same numeric value for ties (for example 1, 2, 2, 4, 5). - sql: LEAD(expression [, offset] [, default]) - description: Returns the value of expression at the offsetth row before the current row in the window. The default value of offset is 1 and the default value of default is NULL. - - sql: LAG(expression [, offset] [, default]) description: Returns the value of expression at the offsetth row after the current row in the window. The default value of offset is 1 and the default value of default is NULL. - - sql: FIRST_VALUE(expression) + - sql: LAG(expression [, offset] [, default]) + description: Returns the value of expression at the offsetth row before the current row in the window. The default value of offset is 1 and the default value of default is NULL. + - sql: FIRST_VALUE(expression) description: Returns the first value in an ordered set of values. - sql: LAST_VALUE(expression) description: Returns the last value in an ordered set of values. diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/LagAggFunction.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/LagAggFunction.java new file mode 100644 index 0000000..3333865 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/LagAggFunction.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.functions.aggfunctions; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.TableException; +import org.apache.flink.table.functions.AggregateFunction; +import org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction; +import org.apache.flink.table.runtime.typeutils.InternalSerializers; +import org.apache.flink.table.runtime.typeutils.LinkedListSerializer; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.LogicalTypeRoot; +import org.apache.flink.table.types.utils.DataTypeUtils; + +import java.util.Arrays; +import java.util.LinkedList; +import java.util.List; +import java.util.Objects; + +/** Lag {@link AggregateFunction}. */ +public class LagAggFunction<T> extends BuiltInAggregateFunction<T, LagAggFunction.LagAcc<T>> { + + private final transient DataType[] valueDataTypes; + + @SuppressWarnings("unchecked") + public LagAggFunction(LogicalType[] valueTypes) { + this.valueDataTypes = + Arrays.stream(valueTypes) + .map(DataTypeUtils::toInternalDataType) + .toArray(DataType[]::new); + if (valueDataTypes.length == 3 + && valueDataTypes[2].getLogicalType().getTypeRoot() != LogicalTypeRoot.NULL) { + if (valueDataTypes[0].getConversionClass() != valueDataTypes[2].getConversionClass()) { + throw new TableException( + String.format( + "Please explicitly cast default value %s to %s.", + valueDataTypes[2], valueDataTypes[1])); + } + } + } + + // -------------------------------------------------------------------------------------------- + // Planning + // -------------------------------------------------------------------------------------------- + + @Override + public List<DataType> getArgumentDataTypes() { + return Arrays.asList(valueDataTypes); + } + + @Override + public DataType getAccumulatorDataType() { + return DataTypes.STRUCTURED( + LagAcc.class, + DataTypes.FIELD("offset", DataTypes.INT()), + DataTypes.FIELD("defaultValue", valueDataTypes[0]), + DataTypes.FIELD("buffer", getLinkedListType())); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private DataType getLinkedListType() { + TypeSerializer<T> serializer = + InternalSerializers.create(getOutputDataType().getLogicalType()); + return DataTypes.RAW( + LinkedList.class, (TypeSerializer) new LinkedListSerializer<>(serializer)); + } + + @Override + public DataType getOutputDataType() { + return valueDataTypes[0]; + } + + // -------------------------------------------------------------------------------------------- + // Runtime + // -------------------------------------------------------------------------------------------- + + public void accumulate(LagAcc<T> acc, T value) throws Exception { + acc.buffer.add(value); + while (acc.buffer.size() > acc.offset + 1) { + acc.buffer.removeFirst(); + } + } + + public void accumulate(LagAcc<T> acc, T value, int offset) throws Exception { + acc.offset = offset; + accumulate(acc, value); + } + + public void accumulate(LagAcc<T> acc, T value, int offset, T defaultValue) throws Exception { + acc.defaultValue = defaultValue; + accumulate(acc, value, offset); + } + + public void resetAccumulator(LagAcc<T> acc) throws Exception { + acc.offset = 1; + acc.defaultValue = null; + acc.buffer.clear(); + } + + @Override + public T getValue(LagAcc<T> acc) { + if (acc.buffer.size() < acc.offset + 1) { + return acc.defaultValue; + } else if (acc.buffer.size() == acc.offset + 1) { + return acc.buffer.getFirst(); + } else { + throw new TableException("Too more elements: " + acc); + } + } + + @Override + public LagAcc<T> createAccumulator() { + return new LagAcc<>(); + } + + /** Accumulator for LAG. */ + public static class LagAcc<T> { + public int offset = 1; + public T defaultValue = null; + public LinkedList<T> buffer = new LinkedList<>(); + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + LagAcc<?> lagAcc = (LagAcc<?>) o; + return offset == lagAcc.offset + && Objects.equals(defaultValue, lagAcc.defaultValue) + && Objects.equals(buffer, lagAcc.buffer); + } + + @Override + public int hashCode() { + return Objects.hash(offset, defaultValue, buffer); + } + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGlobalWindowAggregate.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGlobalWindowAggregate.java index 8df6f2a..41ab7a2 100644 --- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGlobalWindowAggregate.java +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGlobalWindowAggregate.java @@ -145,14 +145,14 @@ public class StreamExecGlobalWindowAggregate extends StreamExecWindowAggregateBa final SliceAssigner sliceAssigner = createSliceAssigner(windowing, shiftTimeZone); final AggregateInfoList localAggInfoList = - AggregateUtil.deriveWindowAggregateInfoList( + AggregateUtil.deriveStreamWindowAggregateInfoList( localAggInputRowType, // should use original input here JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)), windowing.getWindow(), false); // isStateBackendDataViews final AggregateInfoList globalAggInfoList = - AggregateUtil.deriveWindowAggregateInfoList( + AggregateUtil.deriveStreamWindowAggregateInfoList( localAggInputRowType, // should use original input here JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)), windowing.getWindow(), diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecLocalWindowAggregate.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecLocalWindowAggregate.java index f333255..18f8a8d 100644 --- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecLocalWindowAggregate.java +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecLocalWindowAggregate.java @@ -122,7 +122,7 @@ public class StreamExecLocalWindowAggregate extends StreamExecWindowAggregateBas final SliceAssigner sliceAssigner = createSliceAssigner(windowing, shiftTimeZone); final AggregateInfoList aggInfoList = - AggregateUtil.deriveWindowAggregateInfoList( + AggregateUtil.deriveStreamWindowAggregateInfoList( inputRowType, JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)), windowing.getWindow(), diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWindowAggregate.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWindowAggregate.java index 913abee..3229441 100644 --- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWindowAggregate.java +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWindowAggregate.java @@ -143,7 +143,7 @@ public class StreamExecWindowAggregate extends StreamExecWindowAggregateBase { // Hopping window requires additional COUNT(*) to determine whether to register next timer // through whether the current fired window is empty, see SliceSharedWindowAggProcessor. final AggregateInfoList aggInfoList = - AggregateUtil.deriveWindowAggregateInfoList( + AggregateUtil.deriveStreamWindowAggregateInfoList( inputRowType, JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)), windowing.getWindow(), diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala index 23bd99c..f7c4641 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala @@ -562,9 +562,10 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { def getAggCallFromLocalAgg( index: Int, aggCalls: Seq[AggregateCall], - inputType: RelDataType): AggregateCall = { + inputType: RelDataType, + isBounded: Boolean): AggregateCall = { val outputIndexToAggCallIndexMap = AggregateUtil.getOutputIndexToAggCallIndexMap( - aggCalls, inputType) + aggCalls, inputType, isBounded) if (outputIndexToAggCallIndexMap.containsKey(index)) { val realIndex = outputIndexToAggCallIndexMap.get(index) aggCalls(realIndex) @@ -576,9 +577,10 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { def getAggCallIndexInLocalAgg( index: Int, globalAggCalls: Seq[AggregateCall], - inputRowType: RelDataType): Integer = { + inputRowType: RelDataType, + isBounded: Boolean): Integer = { val outputIndexToAggCallIndexMap = AggregateUtil.getOutputIndexToAggCallIndexMap( - globalAggCalls, inputRowType) + globalAggCalls, inputRowType, isBounded) outputIndexToAggCallIndexMap.foreach { case (k, v) => if (v == index) { @@ -600,34 +602,37 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { case agg: StreamPhysicalGlobalGroupAggregate if agg.aggCalls.length > aggCallIndex => val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg( - aggCallIndex, agg.aggCalls, agg.localAggInputRowType) + aggCallIndex, agg.aggCalls, agg.localAggInputRowType, isBounded = false) if (aggCallIndexInLocalAgg != null) { return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg) } else { null } case agg: StreamPhysicalLocalGroupAggregate => - getAggCallFromLocalAgg(aggCallIndex, agg.aggCalls, agg.getInput.getRowType) + getAggCallFromLocalAgg( + aggCallIndex, agg.aggCalls, agg.getInput.getRowType, isBounded = false) case agg: StreamPhysicalIncrementalGroupAggregate if agg.partialAggCalls.length > aggCallIndex => agg.partialAggCalls(aggCallIndex) case agg: StreamPhysicalGroupWindowAggregate if agg.aggCalls.length > aggCallIndex => agg.aggCalls(aggCallIndex) case agg: BatchPhysicalLocalHashAggregate => - getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList, agg.getInput.getRowType) + getAggCallFromLocalAgg( + aggCallIndex, agg.getAggCallList, agg.getInput.getRowType, isBounded = true) case agg: BatchPhysicalHashAggregate if agg.isMerge => val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg( - aggCallIndex, agg.getAggCallList, agg.aggInputRowType) + aggCallIndex, agg.getAggCallList, agg.aggInputRowType, isBounded = true) if (aggCallIndexInLocalAgg != null) { return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg) } else { null } case agg: BatchPhysicalLocalSortAggregate => - getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList, agg.getInput.getRowType) + getAggCallFromLocalAgg( + aggCallIndex, agg.getAggCallList, agg.getInput.getRowType, isBounded = true) case agg: BatchPhysicalSortAggregate if agg.isMerge => val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg( - aggCallIndex, agg.getAggCallList, agg.aggInputRowType) + aggCallIndex, agg.getAggCallList, agg.aggInputRowType, isBounded = true) if (aggCallIndexInLocalAgg != null) { return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg) } else { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGlobalWindowAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGlobalWindowAggregate.scala index bef2589..bdace61 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGlobalWindowAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGlobalWindowAggregate.scala @@ -63,7 +63,7 @@ class StreamPhysicalGlobalWindowAggregate( extends SingleRel(cluster, traitSet, inputRel) with StreamPhysicalRel { - private lazy val aggInfoList = AggregateUtil.deriveWindowAggregateInfoList( + private lazy val aggInfoList = AggregateUtil.deriveStreamWindowAggregateInfoList( FlinkTypeFactory.toLogicalRowType(inputRowTypeOfLocalAgg), aggCalls, windowing.getWindow, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalLocalWindowAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalLocalWindowAggregate.scala index 518ccda..2823aab 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalLocalWindowAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalLocalWindowAggregate.scala @@ -56,7 +56,7 @@ class StreamPhysicalLocalWindowAggregate( extends SingleRel(cluster, traitSet, inputRel) with StreamPhysicalRel { - private lazy val aggInfoList = AggregateUtil.deriveWindowAggregateInfoList( + private lazy val aggInfoList = AggregateUtil.deriveStreamWindowAggregateInfoList( FlinkTypeFactory.toLogicalRowType(inputRel.getRowType), aggCalls, windowing.getWindow, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalWindowAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalWindowAggregate.scala index 21a1f50..eaa70e2 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalWindowAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalWindowAggregate.scala @@ -56,7 +56,7 @@ class StreamPhysicalWindowAggregate( extends SingleRel(cluster, traitSet, inputRel) with StreamPhysicalRel { - lazy val aggInfoList: AggregateInfoList = AggregateUtil.deriveWindowAggregateInfoList( + lazy val aggInfoList: AggregateInfoList = AggregateUtil.deriveStreamWindowAggregateInfoList( FlinkTypeFactory.toLogicalRowType(inputRel.getRowType), aggCalls, windowing.getWindow, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala index e271a74..4f8021e 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala @@ -52,7 +52,8 @@ import scala.collection.JavaConversions._ class AggFunctionFactory( inputRowType: RowType, orderKeyIndexes: Array[Int], - aggCallNeedRetractions: Array[Boolean]) { + aggCallNeedRetractions: Array[Boolean], + isBounded: Boolean) { /** * The entry point to create an aggregate function from the given [[AggregateCall]]. @@ -94,8 +95,12 @@ class AggFunctionFactory( case a: SqlRankFunction if a.getKind == SqlKind.DENSE_RANK => createDenseRankAggFunction(argTypes) - case _: SqlLeadLagAggFunction => - createLeadLagAggFunction(argTypes, index) + case func: SqlLeadLagAggFunction => + if (isBounded) { + createBatchLeadLagAggFunction(argTypes, index) + } else { + createStreamLeadLagAggFunction(func, argTypes, index) + } case _: SqlSingleValueAggFunction => createSingleValueAggFunction(argTypes) @@ -328,7 +333,22 @@ class AggFunctionFactory( } } - private def createLeadLagAggFunction( + private def createStreamLeadLagAggFunction( + func: SqlLeadLagAggFunction, + argTypes: Array[LogicalType], + index: Int): UserDefinedFunction = { + if (func.getKind == SqlKind.LEAD) { + throw new TableException("LEAD Function is not supported in stream mode.") + } + + if (aggCallNeedRetractions(index)) { + throw new TableException("LAG Function with retraction is not supported in stream mode.") + } + + new LagAggFunction(argTypes) + } + + private def createBatchLeadLagAggFunction( argTypes: Array[LogicalType], index: Int): UserDefinedFunction = { argTypes(0).getTypeRoot match { case TINYINT => diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala index 9bfcdeb..3125238 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala @@ -153,6 +153,7 @@ object AggregateUtil extends Enumeration { def getOutputIndexToAggCallIndexMap( aggregateCalls: Seq[AggregateCall], inputType: RelDataType, + isBounded: Boolean, orderKeyIndexes: Array[Int] = null): util.Map[Integer, Integer] = { val aggInfos = transformToAggregateInfoList( FlinkTypeFactory.toLogicalRowType(inputType), @@ -161,7 +162,8 @@ object AggregateUtil extends Enumeration { orderKeyIndexes, needInputCount = false, isStateBackedDataViews = false, - needDistinctInfo = false).aggInfos + needDistinctInfo = false, + isBounded).aggInfos val map = new util.HashMap[Integer, Integer]() var outputIndex = 0 @@ -248,7 +250,7 @@ object AggregateUtil extends Enumeration { isStateBackendDataViews = true) } - def deriveWindowAggregateInfoList( + def deriveStreamWindowAggregateInfoList( inputRowType: RowType, aggCalls: Seq[AggregateCall], windowSpec: WindowSpec, @@ -271,7 +273,8 @@ object AggregateUtil extends Enumeration { orderKeyIndexes = null, needInputCount, isStateBackendDataViews, - needDistinctInfo = true) + needDistinctInfo = true, + isBounded = false) } def transformToBatchAggregateFunctions( @@ -287,7 +290,8 @@ object AggregateUtil extends Enumeration { orderKeyIndexes, needInputCount = false, isStateBackedDataViews = false, - needDistinctInfo = false).aggInfos + needDistinctInfo = false, + isBounded = true).aggInfos val aggFields = aggInfos.map(_.argIndexes) val bufferTypes = aggInfos.map(_.externalAccTypes) @@ -315,7 +319,8 @@ object AggregateUtil extends Enumeration { orderKeyIndexes, needInputCount = false, isStateBackedDataViews = false, - needDistinctInfo = false) + needDistinctInfo = false, + isBounded = true) } def transformToStreamAggregateInfoList( @@ -332,7 +337,8 @@ object AggregateUtil extends Enumeration { orderKeyIndexes = null, needInputCount, isStateBackendDataViews, - needDistinctInfo) + needDistinctInfo, + isBounded = false) } /** @@ -355,7 +361,8 @@ object AggregateUtil extends Enumeration { orderKeyIndexes: Array[Int], needInputCount: Boolean, isStateBackedDataViews: Boolean, - needDistinctInfo: Boolean): AggregateInfoList = { + needDistinctInfo: Boolean, + isBounded: Boolean): AggregateInfoList = { // Step-1: // if need inputCount, find count1 in the existed aggregate calls first, @@ -375,7 +382,11 @@ object AggregateUtil extends Enumeration { // Step-3: // create aggregate information - val factory = new AggFunctionFactory(inputRowType, orderKeyIndexes, aggCallNeedRetractions) + val factory = new AggFunctionFactory( + inputRowType, + orderKeyIndexes, + aggCallNeedRetractions, + isBounded) val aggInfos = newAggCalls .zipWithIndex .map { case (call, index) => diff --git a/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/LagAggFunctionTest.java b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/LagAggFunctionTest.java new file mode 100644 index 0000000..e3553d8 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/LagAggFunctionTest.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.functions.aggfunctions; + +import org.apache.flink.table.data.StringData; +import org.apache.flink.table.functions.AggregateFunction; +import org.apache.flink.table.types.logical.CharType; +import org.apache.flink.table.types.logical.IntType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.VarCharType; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.apache.flink.table.data.StringData.fromString; + +/** Test for {@link LagAggFunction}. */ +public class LagAggFunctionTest + extends AggFunctionTestBase<StringData, LagAggFunction.LagAcc<StringData>> { + + @Override + protected List<List<StringData>> getInputValueSets() { + return Arrays.asList( + Collections.singletonList(fromString("1")), + Arrays.asList(fromString("1"), null), + Arrays.asList(null, null), + Arrays.asList(null, fromString("10"))); + } + + @Override + protected List<StringData> getExpectedResults() { + return Arrays.asList(null, fromString("1"), null, null); + } + + @Override + protected AggregateFunction<StringData, LagAggFunction.LagAcc<StringData>> getAggregator() { + return new LagAggFunction<>( + new LogicalType[] {new VarCharType(), new IntType(), new CharType()}); + } + + @Override + protected Class<?> getAccClass() { + return LagAggFunction.LagAcc.class; + } +} diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala index 69a2d18..595eb41 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala @@ -949,7 +949,8 @@ class FlinkRelMdHandlerTestBase { val aggFunctionFactory = new AggFunctionFactory( FlinkTypeFactory.toLogicalRowType(studentBatchScan.getRowType), Array.empty[Int], - Array.fill(aggCalls.size())(false)) + Array.fill(aggCalls.size())(false), + false) val aggCallToAggFunction = aggCalls.zipWithIndex.map { case (call, index) => (call, aggFunctionFactory.createAggFunction(call, index)) } @@ -1157,7 +1158,8 @@ class FlinkRelMdHandlerTestBase { val aggFunctionFactory = new AggFunctionFactory( FlinkTypeFactory.toLogicalRowType(calcOnStudentScan.getRowType), Array.empty[Int], - Array.fill(aggCalls.size())(false)) + Array.fill(aggCalls.size())(false), + false) val aggCallToAggFunction = aggCalls.zipWithIndex.map { case (call, index) => (call, aggFunctionFactory.createAggFunction(call, index)) } @@ -1324,7 +1326,8 @@ class FlinkRelMdHandlerTestBase { val aggFunctionFactory = new AggFunctionFactory( FlinkTypeFactory.toLogicalRowType(studentBatchScan.getRowType), Array.empty[Int], - Array.fill(aggCalls.size())(false)) + Array.fill(aggCalls.size())(false), + false) val aggCallToAggFunction = aggCalls.zipWithIndex.map { case (call, index) => (call, aggFunctionFactory.createAggFunction(call, index)) } diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverAggregateITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverAggregateITCase.scala index 6deb647..e9a8bb3 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverAggregateITCase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverAggregateITCase.scala @@ -56,6 +56,74 @@ class OverAggregateITCase(mode: StateBackendMode) extends StreamingWithStateTest } @Test + def testLagFunction(): Unit = { + val sqlQuery = "SELECT a, b, c, " + + " LAG(b) OVER(PARTITION BY a ORDER BY rowtime)," + + " LAG(b, 2) OVER(PARTITION BY a ORDER BY rowtime)," + + " LAG(b, 2, CAST(10086 AS BIGINT)) OVER(PARTITION BY a ORDER BY rowtime)" + + "FROM T1" + + val data: Seq[Either[(Long, (Int, Long, String)), Long]] = Seq( + Left(14000001L, (1, 1L, "Hi")), + Left(14000005L, (1, 2L, "Hi")), + Left(14000002L, (1, 3L, "Hello")), + Left(14000003L, (1, 4L, "Hello")), + Left(14000003L, (1, 5L, "Hello")), + Right(14000020L), + Left(14000021L, (1, 6L, "Hello world")), + Left(14000022L, (1, 7L, "Hello world")), + Right(14000030L)) + + val source = failingDataSource(data) + val t1 = source.transform("TimeAssigner", new EventTimeProcessOperator[(Int, Long, String)]) + .setParallelism(source.parallelism) + .toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime) + + tEnv.registerTable("T1", t1) + + val sink = new TestingAppendSink + tEnv.sqlQuery(sqlQuery).toAppendStream[Row].addSink(sink) + env.execute() + + val expected = List( + s"1,1,Hi,null,null,10086", + s"1,3,Hello,1,null,10086", + s"1,4,Hello,4,3,3", + s"1,5,Hello,4,3,3", + s"1,2,Hi,5,4,4", + s"1,6,Hello world,2,5,5", + s"1,7,Hello world,6,2,2") + assertEquals(expected.sorted, sink.getAppendResults.sorted) + } + + @Test + def testLeadFunction(): Unit = { + expectedException.expectMessage("LEAD Function is not supported in stream mode") + + val sqlQuery = "SELECT a, b, c, " + + " LEAD(b) OVER(PARTITION BY a ORDER BY rowtime)," + + " LEAD(b, 2) OVER(PARTITION BY a ORDER BY rowtime)," + + " LEAD(b, 2, CAST(10086 AS BIGINT)) OVER(PARTITION BY a ORDER BY rowtime)" + + "FROM T1" + + val data: Seq[Either[(Long, (Int, Long, String)), Long]] = Seq( + Left(14000001L, (1, 1L, "Hi")), + Left(14000003L, (1, 5L, "Hello")), + Right(14000020L), + Left(14000021L, (1, 6L, "Hello world")), + Left(14000022L, (1, 7L, "Hello world")), + Right(14000030L)) + val source = failingDataSource(data) + val t1 = source.transform("TimeAssigner", new EventTimeProcessOperator[(Int, Long, String)]) + .setParallelism(source.parallelism) + .toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime) + tEnv.registerTable("T1", t1) + val sink = new TestingAppendSink + tEnv.sqlQuery(sqlQuery).toAppendStream[Row].addSink(sink) + env.execute() + } + + @Test def testRowNumberOnOver(): Unit = { val t = failingDataSource(TestData.tupleData5) .toTable(tEnv, 'a, 'b, 'c, 'd, 'e, 'proctime.proctime) diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/typeutils/LinkedListSerializer.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/typeutils/LinkedListSerializer.java new file mode 100644 index 0000000..2025735 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/typeutils/LinkedListSerializer.java @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.typeutils; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.CompositeTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.ListSerializer; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; + +import java.io.IOException; +import java.util.LinkedList; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * A serializer for {@link LinkedList}. The serializer relies on an element serializer for the + * serialization of the list's elements. + * + * @param <T> The type of element in the list. + */ +@Internal +public final class LinkedListSerializer<T> extends TypeSerializer<LinkedList<T>> { + + private static final long serialVersionUID = 1L; + + /** The serializer for the elements of the list. */ + private final TypeSerializer<T> elementSerializer; + + /** + * Creates a list serializer that uses the given serializer to serialize the list's elements. + * + * @param elementSerializer The serializer for the elements of the list + */ + public LinkedListSerializer(TypeSerializer<T> elementSerializer) { + this.elementSerializer = checkNotNull(elementSerializer); + } + + // ------------------------------------------------------------------------ + // ListSerializer specific properties + // ------------------------------------------------------------------------ + + /** + * Gets the serializer for the elements of the list. + * + * @return The serializer for the elements of the list + */ + public TypeSerializer<T> getElementSerializer() { + return elementSerializer; + } + + // ------------------------------------------------------------------------ + // Type Serializer implementation + // ------------------------------------------------------------------------ + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public TypeSerializer<LinkedList<T>> duplicate() { + TypeSerializer<T> duplicateElement = elementSerializer.duplicate(); + return duplicateElement == elementSerializer + ? this + : new LinkedListSerializer<>(duplicateElement); + } + + @Override + public LinkedList<T> createInstance() { + return new LinkedList<>(); + } + + @Override + public LinkedList<T> copy(LinkedList<T> from) { + LinkedList<T> newList = new LinkedList<>(); + + // We iterate here rather than accessing by index, because we cannot be sure that + // the given list supports RandomAccess. + // The Iterator should be stack allocated on new JVMs (due to escape analysis) + for (T element : from) { + newList.add(elementSerializer.copy(element)); + } + return newList; + } + + @Override + public LinkedList<T> copy(LinkedList<T> from, LinkedList<T> reuse) { + return copy(from); + } + + @Override + public int getLength() { + return -1; // var length + } + + @Override + public void serialize(LinkedList<T> list, DataOutputView target) throws IOException { + final int size = list.size(); + target.writeInt(size); + + // We iterate here rather than accessing by index, because we cannot be sure that + // the given list supports RandomAccess. + // The Iterator should be stack allocated on new JVMs (due to escape analysis) + for (T element : list) { + elementSerializer.serialize(element, target); + } + } + + @Override + public LinkedList<T> deserialize(DataInputView source) throws IOException { + final int size = source.readInt(); + final LinkedList<T> list = new LinkedList<>(); + for (int i = 0; i < size; i++) { + list.add(elementSerializer.deserialize(source)); + } + return list; + } + + @Override + public LinkedList<T> deserialize(LinkedList<T> reuse, DataInputView source) throws IOException { + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + // copy number of elements + final int num = source.readInt(); + target.writeInt(num); + for (int i = 0; i < num; i++) { + elementSerializer.copy(source, target); + } + } + + // -------------------------------------------------------------------- + + @Override + public boolean equals(Object obj) { + return obj == this + || (obj != null + && obj.getClass() == getClass() + && elementSerializer.equals( + ((LinkedListSerializer<?>) obj).elementSerializer)); + } + + @Override + public int hashCode() { + return elementSerializer.hashCode(); + } + + // -------------------------------------------------------------------------------------------- + // Serializer configuration snapshot & compatibility + // -------------------------------------------------------------------------------------------- + + @Override + public TypeSerializerSnapshot<LinkedList<T>> snapshotConfiguration() { + return new LinkedListSerializerSnapshot<>(this); + } + + /** Snapshot class for the {@link ListSerializer}. */ + public static class LinkedListSerializerSnapshot<T> + extends CompositeTypeSerializerSnapshot<LinkedList<T>, LinkedListSerializer<T>> { + + private static final int CURRENT_VERSION = 1; + + /** Constructor for read instantiation. */ + public LinkedListSerializerSnapshot() { + super(LinkedListSerializer.class); + } + + /** Constructor to create the snapshot for writing. */ + public LinkedListSerializerSnapshot(LinkedListSerializer<T> listSerializer) { + super(listSerializer); + } + + @Override + public int getCurrentOuterSnapshotVersion() { + return CURRENT_VERSION; + } + + @Override + protected LinkedListSerializer<T> createOuterSerializerWithNestedSerializers( + TypeSerializer<?>[] nestedSerializers) { + @SuppressWarnings("unchecked") + TypeSerializer<T> elementSerializer = (TypeSerializer<T>) nestedSerializers[0]; + return new LinkedListSerializer<>(elementSerializer); + } + + @Override + protected TypeSerializer<?>[] getNestedSerializers( + LinkedListSerializer<T> outerSerializer) { + return new TypeSerializer<?>[] {outerSerializer.getElementSerializer()}; + } + } +}