This is an automated email from the ASF dual-hosted git repository. jchan pushed a commit to branch release-1.18 in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/release-1.18 by this push: new d7e9abe73a2 [FLINK-31788][table] TableAggregateFunction supports emitUpdateWithRetract d7e9abe73a2 is described below commit d7e9abe73a27edc2a27182b55307ff15d88f1042 Author: Jane Chan <qingyue....@gmail.com> AuthorDate: Sun Jan 14 13:10:21 2024 +0800 [FLINK-31788][table] TableAggregateFunction supports emitUpdateWithRetract This closes #24074 (cherry picked from commit 01569644aedb56f792c7f7e04f84612d405b0bdf) --- .../exec/stream/StreamExecGroupTableAggregate.java | 1 + .../codegen/agg/AggsHandlerCodeGenerator.scala | 52 +++++++++- .../planner/codegen/agg/ImperativeAggCodeGen.scala | 21 +++- .../utils/JavaUserDefinedTableAggFunctions.java | 114 +++++++++++++++++++++ .../stream/table/TableAggregateITCase.scala | 80 ++++++++++++++- .../operators/aggregate/GroupTableAggFunction.java | 9 +- 6 files changed, 268 insertions(+), 9 deletions(-) diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupTableAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupTableAggregate.java index 0f4f80f5c94..1d9e454bd11 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupTableAggregate.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupTableAggregate.java @@ -156,6 +156,7 @@ public class StreamExecGroupTableAggregate extends ExecNodeBase<RowData> accTypes, inputCountIndex, generateUpdateBefore, + generator.isIncrementalUpdate(), config.getStateRetentionTime()); final OneInputStreamOperator<RowData, RowData> operator = new KeyedProcessOperator<>(aggFunction); diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala index 583e49bd035..84dd8d83858 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala @@ -21,13 +21,15 @@ import org.apache.flink.api.common.typeutils.TypeSerializer import org.apache.flink.table.api.{DataTypes, TableException} import org.apache.flink.table.data.GenericRowData import org.apache.flink.table.expressions._ -import org.apache.flink.table.functions.{DeclarativeAggregateFunction, ImperativeAggregateFunction} +import org.apache.flink.table.functions.{DeclarativeAggregateFunction, ImperativeAggregateFunction, TableAggregateFunction, UserDefinedFunctionHelper} +import org.apache.flink.table.functions.TableAggregateFunction.RetractableCollector import org.apache.flink.table.planner.JLong import org.apache.flink.table.planner.codegen._ import org.apache.flink.table.planner.codegen.CodeGenUtils._ import org.apache.flink.table.planner.codegen.Indenter.toISC import org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator._ import org.apache.flink.table.planner.expressions.DeclarativeExpressionResolver.toRexInputRef +import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils import org.apache.flink.table.planner.plan.utils.AggregateInfoList import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala import org.apache.flink.table.runtime.dataview.{DataViewSpec, ListViewSpec, MapViewSpec, StateListView, StateMapView} @@ -86,6 +88,7 @@ class AggsHandlerCodeGenerator( private var isRetractNeeded = false private var isMergeNeeded = false private var isWindowSizeNeeded = false + private var isIncrementalUpdateNeeded = false var valueType: RowType = _ @@ -166,6 +169,14 @@ class AggsHandlerCodeGenerator( this } + /** + * Whether to update acc result incrementally. The value is true only for TableAggregateFunction + * with emitUpdateWithRetract method implemented. + */ + def isIncrementalUpdate: Boolean = { + isIncrementalUpdateNeeded + } + /** * Tells the generator to generate `merge(..)` method with the merged accumulator information for * the [[AggsHandleFunction]] and [[NamespaceAggsHandleFunction]]. Default not generate @@ -234,6 +245,20 @@ class AggsHandlerCodeGenerator( constants, relBuilder) case _: ImperativeAggregateFunction[_, _] => + aggInfo.function match { + case tableAggFunc: TableAggregateFunction[_, _] => + // If the user implements both the emitValue and emitUpdateWithRetract methods, + // the emitUpdateWithRetract method will be called with priority. + if ( + UserDefinedFunctionUtils.ifMethodExistInFunction( + UserDefinedFunctionHelper.TABLE_AGGREGATE_EMIT_RETRACT, + tableAggFunc) + ) { + this.isIncrementalUpdateNeeded = true + } + + case _ => + } new ImperativeAggCodeGen( ctx, aggInfo, @@ -247,7 +272,8 @@ class AggsHandlerCodeGenerator( hasNamespace, mergedAccOnHeap, mergedAccExternalTypes(aggBufferOffset), - copyInputField) + copyInputField, + isIncrementalUpdateNeeded) } aggBufferOffset = aggBufferOffset + aggInfo.externalAccTypes.length codegen @@ -447,6 +473,23 @@ class AggsHandlerCodeGenerator( val recordInputName = newName("recordInput") val recordToRowDataCode = genRecordToRowData(aggExternalType, recordInputName) + // for emitUpdateWithRetract, the collector needs to implement RetractableCollector + // and override retract method + val (collectorClassName, collectorRetractCode) = + if (isIncrementalUpdateNeeded) + ( + RETRACTABLE_COLLECTOR, + s""" + |@Override + |public void retract(Object $recordInputName) throws Exception { + | $ROW_DATA tempRowData = convertToRowData($recordInputName); + | result.replace(key, tempRowData); + | result.setRowKind($ROW_KIND.DELETE); + | $COLLECTOR_TERM.collect(result); + |} + |""".stripMargin) + else (COLLECTOR, "") + val functionName = newName(name) val functionCode = j""" @@ -527,7 +570,7 @@ class AggsHandlerCodeGenerator( ${ctx.reuseCloseCode()} } - private class $CONVERT_COLLECTOR_TYPE_TERM implements $COLLECTOR { + private class $CONVERT_COLLECTOR_TYPE_TERM implements $collectorClassName { private $COLLECTOR<$ROW_DATA> $COLLECTOR_TERM; private $ROW_DATA key; private $JOINED_ROW result; @@ -562,6 +605,8 @@ class AggsHandlerCodeGenerator( $COLLECTOR_TERM.collect(result); } + $collectorRetractCode + @Override public void close() { $COLLECTOR_TERM.close(); @@ -1255,6 +1300,7 @@ object AggsHandlerCodeGenerator { val STORE_TERM = "store" val COLLECTOR: String = className[Collector[_]] + val RETRACTABLE_COLLECTOR: String = className[RetractableCollector[_]] val COLLECTOR_TERM = "out" val MEMBER_COLLECTOR_TERM = "convertCollector" val CONVERT_COLLECTOR_TYPE_TERM = "ConvertCollector" diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/ImperativeAggCodeGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/ImperativeAggCodeGen.scala index 533c956d3a3..6add23ac9a8 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/ImperativeAggCodeGen.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/ImperativeAggCodeGen.scala @@ -20,6 +20,7 @@ package org.apache.flink.table.planner.codegen.agg import org.apache.flink.table.data.{GenericRowData, RowData, UpdatableRowData} import org.apache.flink.table.expressions.Expression import org.apache.flink.table.functions.{FunctionContext, ImperativeAggregateFunction, UserDefinedFunctionHelper} +import org.apache.flink.table.functions.TableAggregateFunction.RetractableCollector import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, ExprCodeGenerator, GeneratedExpression} import org.apache.flink.table.planner.codegen.CodeGenUtils._ import org.apache.flink.table.planner.codegen.GenerateUtils.generateFieldAccess @@ -69,6 +70,9 @@ import scala.collection.mutable.ArrayBuffer * whether the accumulators state has namespace * @param inputFieldCopy * copy input field element if true (only mutable type will be copied) + * @param isIncrementalUpdateNeeded + * whether the agg supports emitting incremental update, true for TableAggregateFunction if + * user-defined function implements emitUpdateWithRetract, otherwise false. */ class ImperativeAggCodeGen( ctx: CodeGeneratorContext, @@ -83,7 +87,8 @@ class ImperativeAggCodeGen( hasNamespace: Boolean, mergedAccOnHeap: Boolean, mergedAccExternalType: DataType, - inputFieldCopy: Boolean) + inputFieldCopy: Boolean, + isIncrementalUpdateNeeded: Boolean) extends AggCodeGen { private val SINGLE_ITERABLE = className[SingleElementIterator[_]] @@ -488,10 +493,14 @@ class ImperativeAggCodeGen( } if (needEmitValue) { + val (emitMethod, collectorClass) = + if (isIncrementalUpdateNeeded) + (UserDefinedFunctionHelper.TABLE_AGGREGATE_EMIT_RETRACT, classOf[RetractableCollector[_]]) + else (UserDefinedFunctionHelper.TABLE_AGGREGATE_EMIT, classOf[Collector[_]]) UserDefinedFunctionHelper.validateClassForRuntime( function.getClass, - UserDefinedFunctionHelper.TABLE_AGGREGATE_EMIT, - accumulatorClass ++ Array(classOf[Collector[_]]), + emitMethod, + accumulatorClass ++ Array(collectorClass), classOf[Unit], functionName ) @@ -500,7 +509,11 @@ class ImperativeAggCodeGen( def emitValue: String = { val accTerm = if (isAccTypeInternal) accInternalTerm else accExternalTerm - s"$functionTerm.emitValue($accTerm, $MEMBER_COLLECTOR_TERM);" + val finalEmitMethodName = + if (isIncrementalUpdateNeeded) UserDefinedFunctionHelper.TABLE_AGGREGATE_EMIT_RETRACT + else UserDefinedFunctionHelper.TABLE_AGGREGATE_EMIT + + s"$functionTerm.$finalEmitMethodName($accTerm, $MEMBER_COLLECTOR_TERM);" } override def setWindowSize(generator: ExprCodeGenerator): String = { diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/utils/JavaUserDefinedTableAggFunctions.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/utils/JavaUserDefinedTableAggFunctions.java new file mode 100644 index 00000000000..3f02991b0ce --- /dev/null +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/utils/JavaUserDefinedTableAggFunctions.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.runtime.utils; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.table.functions.TableAggregateFunction; +import org.apache.flink.util.Collector; + +/** Test table aggregate table functions. */ +public class JavaUserDefinedTableAggFunctions { + + /** Mutable accumulator of structured type for the table aggregate function. */ + public static class Top2Accumulator { + + public Integer first; + public Integer second; + public Integer previousFirst; + public Integer previousSecond; + } + + /** + * Function that takes (value INT), stores intermediate results in a structured type of {@link + * Top2Accumulator}, and returns the result as a structured type of {@link Tuple2} for value and + * rank. + */ + public static class Top2 + extends TableAggregateFunction<Tuple2<Integer, Integer>, Top2Accumulator> { + + @Override + public Top2Accumulator createAccumulator() { + Top2Accumulator acc = new Top2Accumulator(); + acc.first = Integer.MIN_VALUE; + acc.second = Integer.MIN_VALUE; + return acc; + } + + public void accumulate(Top2Accumulator acc, Integer value) { + if (value > acc.first) { + acc.second = acc.first; + acc.first = value; + } else if (value > acc.second) { + acc.second = value; + } + } + + public void merge(Top2Accumulator acc, Iterable<Top2Accumulator> it) { + for (Top2Accumulator otherAcc : it) { + accumulate(acc, otherAcc.first); + accumulate(acc, otherAcc.second); + } + } + + public void emitValue(Top2Accumulator acc, Collector<Tuple2<Integer, Integer>> out) { + // emit the value and rank + if (acc.first != Integer.MIN_VALUE) { + out.collect(Tuple2.of(acc.first, 1)); + } + if (acc.second != Integer.MIN_VALUE) { + out.collect(Tuple2.of(acc.second, 2)); + } + } + } + + /** Subclass of {@link Top2} to support emit incremental changes. */ + public static class IncrementalTop2 extends Top2 { + @Override + public Top2Accumulator createAccumulator() { + Top2Accumulator acc = super.createAccumulator(); + acc.previousFirst = Integer.MIN_VALUE; + acc.previousSecond = Integer.MIN_VALUE; + return acc; + } + + @Override + public void accumulate(Top2Accumulator acc, Integer value) { + acc.previousFirst = acc.first; + acc.previousSecond = acc.second; + super.accumulate(acc, value); + } + + public void emitUpdateWithRetract( + Top2Accumulator acc, RetractableCollector<Tuple2<Integer, Integer>> out) { + // emit the value and rank only if they're changed + if (!acc.first.equals(acc.previousFirst)) { + if (!acc.previousFirst.equals(Integer.MIN_VALUE)) { + out.retract(Tuple2.of(acc.previousFirst, 1)); + } + out.collect(Tuple2.of(acc.first, 1)); + } + if (!acc.second.equals(acc.previousSecond)) { + if (!acc.previousSecond.equals(Integer.MIN_VALUE)) { + out.retract(Tuple2.of(acc.previousSecond, 2)); + } + out.collect(Tuple2.of(acc.second, 2)); + } + } + } +} diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/TableAggregateITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/TableAggregateITCase.scala index 4076739004b..79f95056a29 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/TableAggregateITCase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/TableAggregateITCase.scala @@ -21,7 +21,7 @@ import org.apache.flink.api.common.time.Time import org.apache.flink.api.scala._ import org.apache.flink.table.api._ import org.apache.flink.table.api.bridge.scala._ -import org.apache.flink.table.planner.runtime.utils.{StreamingWithStateTestBase, TestingRetractSink} +import org.apache.flink.table.planner.runtime.utils.{JavaUserDefinedTableAggFunctions, StreamingWithStateTestBase, TestingRetractSink} import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedAggFunctions.OverloadedDoubleMaxFunction import org.apache.flink.table.planner.runtime.utils.StreamingWithStateTestBase.StateBackendMode import org.apache.flink.table.planner.runtime.utils.TestData.tupleData3 @@ -43,6 +43,84 @@ class TableAggregateITCase(mode: StateBackendMode) extends StreamingWithStateTes tEnv.getConfig.setIdleStateRetentionTime(Time.hours(1), Time.hours(2)) } + @Test + def testFlagAggregateWithOrWithoutIncrementalUpdate(): Unit = { + // Create a Table from the array of Rows + val table = tEnv.fromValues( + DataTypes.ROW( + DataTypes.FIELD("id", DataTypes.INT), + DataTypes.FIELD("name", DataTypes.STRING), + DataTypes.FIELD("price", DataTypes.INT)), + row(1, "Latte", 6: java.lang.Integer), + row(2, "Milk", 3: java.lang.Integer), + row(3, "Breve", 5: java.lang.Integer), + row(4, "Mocha", 8: java.lang.Integer), + row(5, "Tea", 4: java.lang.Integer) + ) + + // Register the table aggregate function + tEnv.createTemporarySystemFunction("top2", new JavaUserDefinedTableAggFunctions.Top2) + tEnv.createTemporarySystemFunction( + "incrementalTop2", + new JavaUserDefinedTableAggFunctions.IncrementalTop2) + + checkRank( + "top2", + List( + // output triggered by (1, "Latte", 6) + "(true,6,1)", + // output triggered by (2, "Milk", 3) + "(false,6,1)", + "(true,6,1)", + "(true,3,2)", + // output triggered by (3, "Breve", 5) + "(false,6,1)", + "(false,3,2)", + "(true,6,1)", + "(true,5,2)", + // output triggered by (4, "Mocha", 8) + "(false,6,1)", + "(false,5,2)", + "(true,8,1)", + "(true,6,2)", + // output triggered by (5, "Tea", 4) + "(false,8,1)", + "(false,6,2)", + "(true,8,1)", + "(true,6,2)" + ) + ) + checkRank( + "incrementalTop2", + List( + // output triggered by (1, "Latte", 6) + "(true,6,1)", + // output triggered by (2, "Milk", 3) + "(true,3,2)", + // output triggered by (3, "Breve", 5) + "(false,3,2)", + "(true,5,2)", + // output triggered by (4, "Mocha", 8) + "(false,6,1)", + "(true,8,1)", + "(false,5,2)", + "(true,6,2)" + ) + ) + + def checkRank(func: String, expectedResult: List[String]): Unit = { + val resultTable = + table + .flatAggregate(call(func, $("price")).as("top_price", "rank")) + .select($("top_price"), $("rank")) + + val sink = new TestingRetractSink() + resultTable.toRetractStream[Row].addSink(sink).setParallelism(1) + env.execute() + assertEquals(expectedResult, sink.getRawResults) + } + } + @Test def testGroupByFlatAggregate(): Unit = { val top3 = new Top3 diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupTableAggFunction.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupTableAggFunction.java index e951e2f4432..fb411d5d8ac 100644 --- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupTableAggFunction.java +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupTableAggFunction.java @@ -51,6 +51,8 @@ public class GroupTableAggFunction extends KeyedProcessFunction<RowData, RowData /** Whether this operator will generate UPDATE_BEFORE messages. */ private final boolean generateUpdateBefore; + private final boolean incrementalUpdate; + /** State idle retention time which unit is MILLISECONDS. */ private final long stateRetentionTime; @@ -69,6 +71,7 @@ public class GroupTableAggFunction extends KeyedProcessFunction<RowData, RowData * contain COUNT(*), i.e. doesn't contain retraction messages. We make sure there is a * COUNT(*) if input stream contains retraction. * @param generateUpdateBefore Whether this operator will generate UPDATE_BEFORE messages. + * @param incrementalUpdate Whether to update acc result incrementally. * @param stateRetentionTime state idle retention time which unit is MILLISECONDS. */ public GroupTableAggFunction( @@ -76,11 +79,13 @@ public class GroupTableAggFunction extends KeyedProcessFunction<RowData, RowData LogicalType[] accTypes, int indexOfCountStar, boolean generateUpdateBefore, + boolean incrementalUpdate, long stateRetentionTime) { this.genAggsHandler = genAggsHandler; this.accTypes = accTypes; this.recordCounter = RecordCounter.of(indexOfCountStar); this.generateUpdateBefore = generateUpdateBefore; + this.incrementalUpdate = incrementalUpdate; this.stateRetentionTime = stateRetentionTime; } @@ -117,7 +122,9 @@ public class GroupTableAggFunction extends KeyedProcessFunction<RowData, RowData // set accumulators to handler first function.setAccumulators(accumulators); - if (!firstRow && generateUpdateBefore) { + // when incrementalUpdate is required, there is no need to retract + // previous sent data which is not changed + if (!firstRow && !incrementalUpdate && generateUpdateBefore) { function.emitValue(out, currentKey, true); }