This is an automated email from the ASF dual-hosted git repository.

jchan pushed a commit to branch release-1.18
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/release-1.18 by this push:
     new d7e9abe73a2 [FLINK-31788][table] TableAggregateFunction supports emitUpdateWithRetract
d7e9abe73a2 is described below

commit d7e9abe73a27edc2a27182b55307ff15d88f1042
Author: Jane Chan <qingyue....@gmail.com>
AuthorDate: Sun Jan 14 13:10:21 2024 +0800

    [FLINK-31788][table] TableAggregateFunction supports emitUpdateWithRetract
    
    This closes #24074
    
    (cherry picked from commit 01569644aedb56f792c7f7e04f84612d405b0bdf)
---
 .../exec/stream/StreamExecGroupTableAggregate.java |   1 +
 .../codegen/agg/AggsHandlerCodeGenerator.scala     |  52 +++++++++-
 .../planner/codegen/agg/ImperativeAggCodeGen.scala |  21 +++-
 .../utils/JavaUserDefinedTableAggFunctions.java    | 114 +++++++++++++++++++++
 .../stream/table/TableAggregateITCase.scala        |  80 ++++++++++++++-
 .../operators/aggregate/GroupTableAggFunction.java |   9 +-
 6 files changed, 268 insertions(+), 9 deletions(-)
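
In short: when a user-defined TableAggregateFunction implements
emitUpdateWithRetract(ACC, RetractableCollector), the generated aggregate handler now
calls that method instead of emitValue, so only the changed rows are retracted and
re-emitted. A minimal sketch of such a function and its Table API usage follows; the
hypothetical Top1 class below is only illustrative, modelled on the IncrementalTop2
test added in this commit, and is not part of the diff itself.

    // Illustrative sketch (hypothetical class, not part of this change):
    // a Top-1 table aggregate that emits incremental updates.
    import org.apache.flink.api.java.tuple.Tuple2;
    import org.apache.flink.table.functions.TableAggregateFunction;

    public class Top1 extends TableAggregateFunction<Tuple2<Integer, Integer>, Top1.Top1Accumulator> {

        public static class Top1Accumulator {
            public Integer current;  // current maximum seen so far
            public Integer previous; // maximum emitted for the last record, used for retraction
        }

        @Override
        public Top1Accumulator createAccumulator() {
            Top1Accumulator acc = new Top1Accumulator();
            acc.current = Integer.MIN_VALUE;
            acc.previous = Integer.MIN_VALUE;
            return acc;
        }

        public void accumulate(Top1Accumulator acc, Integer value) {
            // remember what was emitted for the previous record, then update the maximum
            acc.previous = acc.current;
            if (value > acc.current) {
                acc.current = value;
            }
        }

        // When both emitValue and emitUpdateWithRetract are implemented, the planner
        // prefers this method: only the changed row is retracted and re-emitted.
        public void emitUpdateWithRetract(
                Top1Accumulator acc, RetractableCollector<Tuple2<Integer, Integer>> out) {
            if (!acc.current.equals(acc.previous)) {
                if (!acc.previous.equals(Integer.MIN_VALUE)) {
                    out.retract(Tuple2.of(acc.previous, 1)); // withdraw the stale row
                }
                out.collect(Tuple2.of(acc.current, 1));      // emit the new top value with rank 1
            }
        }
    }

    // Usage with the Table API (mirrors the ITCase added below):
    //   tEnv.createTemporarySystemFunction("top1", new Top1());
    //   table.flatAggregate(call("top1", $("price")).as("v", "rank")).select($("v"), $("rank"));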

diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupTableAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupTableAggregate.java
index 0f4f80f5c94..1d9e454bd11 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupTableAggregate.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupTableAggregate.java
@@ -156,6 +156,7 @@ public class StreamExecGroupTableAggregate extends ExecNodeBase<RowData>
                         accTypes,
                         inputCountIndex,
                         generateUpdateBefore,
+                        generator.isIncrementalUpdate(),
                         config.getStateRetentionTime());
         final OneInputStreamOperator<RowData, RowData> operator =
                 new KeyedProcessOperator<>(aggFunction);
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala
index 583e49bd035..84dd8d83858 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala
@@ -21,13 +21,15 @@ import org.apache.flink.api.common.typeutils.TypeSerializer
 import org.apache.flink.table.api.{DataTypes, TableException}
 import org.apache.flink.table.data.GenericRowData
 import org.apache.flink.table.expressions._
-import org.apache.flink.table.functions.{DeclarativeAggregateFunction, ImperativeAggregateFunction}
+import org.apache.flink.table.functions.{DeclarativeAggregateFunction, ImperativeAggregateFunction, TableAggregateFunction, UserDefinedFunctionHelper}
+import org.apache.flink.table.functions.TableAggregateFunction.RetractableCollector
 import org.apache.flink.table.planner.JLong
 import org.apache.flink.table.planner.codegen._
 import org.apache.flink.table.planner.codegen.CodeGenUtils._
 import org.apache.flink.table.planner.codegen.Indenter.toISC
 import org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator._
 import org.apache.flink.table.planner.expressions.DeclarativeExpressionResolver.toRexInputRef
+import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils
 import org.apache.flink.table.planner.plan.utils.AggregateInfoList
 import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala
 import org.apache.flink.table.runtime.dataview.{DataViewSpec, ListViewSpec, MapViewSpec, StateListView, StateMapView}
@@ -86,6 +88,7 @@ class AggsHandlerCodeGenerator(
   private var isRetractNeeded = false
   private var isMergeNeeded = false
   private var isWindowSizeNeeded = false
+  private var isIncrementalUpdateNeeded = false
 
   var valueType: RowType = _
 
@@ -166,6 +169,14 @@ class AggsHandlerCodeGenerator(
     this
   }
 
+  /**
+   * Whether to update the acc result incrementally. The value is true only for a
+   * TableAggregateFunction that implements emitUpdateWithRetract.
+   */
+  def isIncrementalUpdate: Boolean = {
+    isIncrementalUpdateNeeded
+  }
+
   /**
    * Tells the generator to generate `merge(..)` method with the merged accumulator information for
    * the [[AggsHandleFunction]] and [[NamespaceAggsHandleFunction]]. Default not generate
@@ -234,6 +245,20 @@ class AggsHandlerCodeGenerator(
               constants,
               relBuilder)
           case _: ImperativeAggregateFunction[_, _] =>
+            aggInfo.function match {
+              case tableAggFunc: TableAggregateFunction[_, _] =>
+                // If the user implements both the emitValue and emitUpdateWithRetract methods,
+                // the emitUpdateWithRetract method takes precedence.
+                if (
+                  UserDefinedFunctionUtils.ifMethodExistInFunction(
+                    UserDefinedFunctionHelper.TABLE_AGGREGATE_EMIT_RETRACT,
+                    tableAggFunc)
+                ) {
+                  this.isIncrementalUpdateNeeded = true
+                }
+
+              case _ =>
+            }
             new ImperativeAggCodeGen(
               ctx,
               aggInfo,
@@ -247,7 +272,8 @@ class AggsHandlerCodeGenerator(
               hasNamespace,
               mergedAccOnHeap,
               mergedAccExternalTypes(aggBufferOffset),
-              copyInputField)
+              copyInputField,
+              isIncrementalUpdateNeeded)
         }
         aggBufferOffset = aggBufferOffset + aggInfo.externalAccTypes.length
         codegen
@@ -447,6 +473,23 @@ class AggsHandlerCodeGenerator(
     val recordInputName = newName("recordInput")
     val recordToRowDataCode = genRecordToRowData(aggExternalType, recordInputName)
 
+    // for emitUpdateWithRetract, the collector needs to implement RetractableCollector
+    // and override the retract method
+    val (collectorClassName, collectorRetractCode) =
+      if (isIncrementalUpdateNeeded)
+        (
+          RETRACTABLE_COLLECTOR,
+          s"""
+             |@Override
+             |public void retract(Object $recordInputName) throws Exception {
+             |  $ROW_DATA tempRowData = convertToRowData($recordInputName);
+             |  result.replace(key, tempRowData);
+             |  result.setRowKind($ROW_KIND.DELETE);
+             |  $COLLECTOR_TERM.collect(result);
+             |}
+             |""".stripMargin)
+      else (COLLECTOR, "")
+
     val functionName = newName(name)
     val functionCode =
       j"""
@@ -527,7 +570,7 @@ class AggsHandlerCodeGenerator(
             ${ctx.reuseCloseCode()}
           }
 
-          private class $CONVERT_COLLECTOR_TYPE_TERM implements $COLLECTOR {
+          private class $CONVERT_COLLECTOR_TYPE_TERM implements $collectorClassName {
             private $COLLECTOR<$ROW_DATA> $COLLECTOR_TERM;
             private $ROW_DATA key;
             private $JOINED_ROW result;
@@ -562,6 +605,8 @@ class AggsHandlerCodeGenerator(
               $COLLECTOR_TERM.collect(result);
             }
 
+            $collectorRetractCode
+
             @Override
             public void close() {
               $COLLECTOR_TERM.close();
@@ -1255,6 +1300,7 @@ object AggsHandlerCodeGenerator {
   val STORE_TERM = "store"
 
   val COLLECTOR: String = className[Collector[_]]
+  val RETRACTABLE_COLLECTOR: String = className[RetractableCollector[_]]
   val COLLECTOR_TERM = "out"
   val MEMBER_COLLECTOR_TERM = "convertCollector"
   val CONVERT_COLLECTOR_TYPE_TERM = "ConvertCollector"
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/ImperativeAggCodeGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/ImperativeAggCodeGen.scala
index 533c956d3a3..6add23ac9a8 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/ImperativeAggCodeGen.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/ImperativeAggCodeGen.scala
@@ -20,6 +20,7 @@ package org.apache.flink.table.planner.codegen.agg
 import org.apache.flink.table.data.{GenericRowData, RowData, UpdatableRowData}
 import org.apache.flink.table.expressions.Expression
 import org.apache.flink.table.functions.{FunctionContext, ImperativeAggregateFunction, UserDefinedFunctionHelper}
+import org.apache.flink.table.functions.TableAggregateFunction.RetractableCollector
 import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, ExprCodeGenerator, GeneratedExpression}
 import org.apache.flink.table.planner.codegen.CodeGenUtils._
 import org.apache.flink.table.planner.codegen.GenerateUtils.generateFieldAccess
@@ -69,6 +70,9 @@ import scala.collection.mutable.ArrayBuffer
  *   whether the accumulators state has namespace
  * @param inputFieldCopy
  *   copy input field element if true (only mutable type will be copied)
+ * @param isIncrementalUpdateNeeded
+ *   whether the agg supports emitting incremental updates; true for a TableAggregateFunction
+ *   whose user-defined function implements emitUpdateWithRetract, otherwise false.
  */
 class ImperativeAggCodeGen(
     ctx: CodeGeneratorContext,
@@ -83,7 +87,8 @@ class ImperativeAggCodeGen(
     hasNamespace: Boolean,
     mergedAccOnHeap: Boolean,
     mergedAccExternalType: DataType,
-    inputFieldCopy: Boolean)
+    inputFieldCopy: Boolean,
+    isIncrementalUpdateNeeded: Boolean)
   extends AggCodeGen {
 
   private val SINGLE_ITERABLE = className[SingleElementIterator[_]]
@@ -488,10 +493,14 @@ class ImperativeAggCodeGen(
     }
 
     if (needEmitValue) {
+      val (emitMethod, collectorClass) =
+        if (isIncrementalUpdateNeeded)
+          (UserDefinedFunctionHelper.TABLE_AGGREGATE_EMIT_RETRACT, classOf[RetractableCollector[_]])
+        else (UserDefinedFunctionHelper.TABLE_AGGREGATE_EMIT, classOf[Collector[_]])
       UserDefinedFunctionHelper.validateClassForRuntime(
         function.getClass,
-        UserDefinedFunctionHelper.TABLE_AGGREGATE_EMIT,
-        accumulatorClass ++ Array(classOf[Collector[_]]),
+        emitMethod,
+        accumulatorClass ++ Array(collectorClass),
         classOf[Unit],
         functionName
       )
@@ -500,7 +509,11 @@ class ImperativeAggCodeGen(
 
   def emitValue: String = {
     val accTerm = if (isAccTypeInternal) accInternalTerm else accExternalTerm
-    s"$functionTerm.emitValue($accTerm, $MEMBER_COLLECTOR_TERM);"
+    val finalEmitMethodName =
+      if (isIncrementalUpdateNeeded) UserDefinedFunctionHelper.TABLE_AGGREGATE_EMIT_RETRACT
+      else UserDefinedFunctionHelper.TABLE_AGGREGATE_EMIT
+
+    s"$functionTerm.$finalEmitMethodName($accTerm, $MEMBER_COLLECTOR_TERM);"
   }
 
   override def setWindowSize(generator: ExprCodeGenerator): String = {
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/utils/JavaUserDefinedTableAggFunctions.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/utils/JavaUserDefinedTableAggFunctions.java
new file mode 100644
index 00000000000..3f02991b0ce
--- /dev/null
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/utils/JavaUserDefinedTableAggFunctions.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.runtime.utils;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.table.functions.TableAggregateFunction;
+import org.apache.flink.util.Collector;
+
+/** Test table aggregate functions. */
+public class JavaUserDefinedTableAggFunctions {
+
+    /** Mutable accumulator of structured type for the table aggregate function. */
+    public static class Top2Accumulator {
+
+        public Integer first;
+        public Integer second;
+        public Integer previousFirst;
+        public Integer previousSecond;
+    }
+
+    /**
+     * Function that takes (value INT), stores intermediate results in a structured type of {@link
+     * Top2Accumulator}, and returns the result as a structured type of {@link Tuple2} for value and
+     * rank.
+     */
+    public static class Top2
+            extends TableAggregateFunction<Tuple2<Integer, Integer>, Top2Accumulator> {
+
+        @Override
+        public Top2Accumulator createAccumulator() {
+            Top2Accumulator acc = new Top2Accumulator();
+            acc.first = Integer.MIN_VALUE;
+            acc.second = Integer.MIN_VALUE;
+            return acc;
+        }
+
+        public void accumulate(Top2Accumulator acc, Integer value) {
+            if (value > acc.first) {
+                acc.second = acc.first;
+                acc.first = value;
+            } else if (value > acc.second) {
+                acc.second = value;
+            }
+        }
+
+        public void merge(Top2Accumulator acc, Iterable<Top2Accumulator> it) {
+            for (Top2Accumulator otherAcc : it) {
+                accumulate(acc, otherAcc.first);
+                accumulate(acc, otherAcc.second);
+            }
+        }
+
+        public void emitValue(Top2Accumulator acc, Collector<Tuple2<Integer, Integer>> out) {
+            // emit the value and rank
+            if (acc.first != Integer.MIN_VALUE) {
+                out.collect(Tuple2.of(acc.first, 1));
+            }
+            if (acc.second != Integer.MIN_VALUE) {
+                out.collect(Tuple2.of(acc.second, 2));
+            }
+        }
+    }
+
+    /** Subclass of {@link Top2} to support emitting incremental changes. */
+    public static class IncrementalTop2 extends Top2 {
+        @Override
+        public Top2Accumulator createAccumulator() {
+            Top2Accumulator acc = super.createAccumulator();
+            acc.previousFirst = Integer.MIN_VALUE;
+            acc.previousSecond = Integer.MIN_VALUE;
+            return acc;
+        }
+
+        @Override
+        public void accumulate(Top2Accumulator acc, Integer value) {
+            acc.previousFirst = acc.first;
+            acc.previousSecond = acc.second;
+            super.accumulate(acc, value);
+        }
+
+        public void emitUpdateWithRetract(
+                Top2Accumulator acc, RetractableCollector<Tuple2<Integer, Integer>> out) {
+            // emit the value and rank only if they're changed
+            if (!acc.first.equals(acc.previousFirst)) {
+                if (!acc.previousFirst.equals(Integer.MIN_VALUE)) {
+                    out.retract(Tuple2.of(acc.previousFirst, 1));
+                }
+                out.collect(Tuple2.of(acc.first, 1));
+            }
+            if (!acc.second.equals(acc.previousSecond)) {
+                if (!acc.previousSecond.equals(Integer.MIN_VALUE)) {
+                    out.retract(Tuple2.of(acc.previousSecond, 2));
+                }
+                out.collect(Tuple2.of(acc.second, 2));
+            }
+        }
+    }
+}
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/TableAggregateITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/TableAggregateITCase.scala
index 4076739004b..79f95056a29 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/TableAggregateITCase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/TableAggregateITCase.scala
@@ -21,7 +21,7 @@ import org.apache.flink.api.common.time.Time
 import org.apache.flink.api.scala._
 import org.apache.flink.table.api._
 import org.apache.flink.table.api.bridge.scala._
-import org.apache.flink.table.planner.runtime.utils.{StreamingWithStateTestBase, TestingRetractSink}
+import org.apache.flink.table.planner.runtime.utils.{JavaUserDefinedTableAggFunctions, StreamingWithStateTestBase, TestingRetractSink}
 import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedAggFunctions.OverloadedDoubleMaxFunction
 import org.apache.flink.table.planner.runtime.utils.StreamingWithStateTestBase.StateBackendMode
 import org.apache.flink.table.planner.runtime.utils.TestData.tupleData3
@@ -43,6 +43,84 @@ class TableAggregateITCase(mode: StateBackendMode) extends StreamingWithStateTes
     tEnv.getConfig.setIdleStateRetentionTime(Time.hours(1), Time.hours(2))
   }
 
+  @Test
+  def testFlatAggregateWithOrWithoutIncrementalUpdate(): Unit = {
+    // Create a Table from the array of Rows
+    val table = tEnv.fromValues(
+      DataTypes.ROW(
+        DataTypes.FIELD("id", DataTypes.INT),
+        DataTypes.FIELD("name", DataTypes.STRING),
+        DataTypes.FIELD("price", DataTypes.INT)),
+      row(1, "Latte", 6: java.lang.Integer),
+      row(2, "Milk", 3: java.lang.Integer),
+      row(3, "Breve", 5: java.lang.Integer),
+      row(4, "Mocha", 8: java.lang.Integer),
+      row(5, "Tea", 4: java.lang.Integer)
+    )
+
+    // Register the table aggregate function
+    tEnv.createTemporarySystemFunction("top2", new 
JavaUserDefinedTableAggFunctions.Top2)
+    tEnv.createTemporarySystemFunction(
+      "incrementalTop2",
+      new JavaUserDefinedTableAggFunctions.IncrementalTop2)
+
+    checkRank(
+      "top2",
+      List(
+        // output triggered by (1, "Latte", 6)
+        "(true,6,1)",
+        // output triggered by (2, "Milk", 3)
+        "(false,6,1)",
+        "(true,6,1)",
+        "(true,3,2)",
+        // output triggered by (3, "Breve", 5)
+        "(false,6,1)",
+        "(false,3,2)",
+        "(true,6,1)",
+        "(true,5,2)",
+        // output triggered by (4, "Mocha", 8)
+        "(false,6,1)",
+        "(false,5,2)",
+        "(true,8,1)",
+        "(true,6,2)",
+        // output triggered by (5, "Tea", 4)
+        "(false,8,1)",
+        "(false,6,2)",
+        "(true,8,1)",
+        "(true,6,2)"
+      )
+    )
+    checkRank(
+      "incrementalTop2",
+      List(
+        // output triggered by (1, "Latte", 6)
+        "(true,6,1)",
+        // output triggered by (2, "Milk", 3)
+        "(true,3,2)",
+        // output triggered by (3, "Breve", 5)
+        "(false,3,2)",
+        "(true,5,2)",
+        // output triggered by (4, "Mocha", 8)
+        "(false,6,1)",
+        "(true,8,1)",
+        "(false,5,2)",
+        "(true,6,2)"
+      )
+    )
+
+    def checkRank(func: String, expectedResult: List[String]): Unit = {
+      val resultTable =
+        table
+          .flatAggregate(call(func, $("price")).as("top_price", "rank"))
+          .select($("top_price"), $("rank"))
+
+      val sink = new TestingRetractSink()
+      resultTable.toRetractStream[Row].addSink(sink).setParallelism(1)
+      env.execute()
+      assertEquals(expectedResult, sink.getRawResults)
+    }
+  }
+
   @Test
   def testGroupByFlatAggregate(): Unit = {
     val top3 = new Top3
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupTableAggFunction.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupTableAggFunction.java
index e951e2f4432..fb411d5d8ac 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupTableAggFunction.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupTableAggFunction.java
@@ -51,6 +51,8 @@ public class GroupTableAggFunction extends KeyedProcessFunction<RowData, RowData
     /** Whether this operator will generate UPDATE_BEFORE messages. */
     private final boolean generateUpdateBefore;
 
+    private final boolean incrementalUpdate;
+
     /** State idle retention time which unit is MILLISECONDS. */
     private final long stateRetentionTime;
 
@@ -69,6 +71,7 @@ public class GroupTableAggFunction extends KeyedProcessFunction<RowData, RowData
      *     contain COUNT(*), i.e. doesn't contain retraction messages. We make sure there is a
      *     COUNT(*) if input stream contains retraction.
      * @param generateUpdateBefore Whether this operator will generate UPDATE_BEFORE messages.
+     * @param incrementalUpdate Whether to update the acc result incrementally.
      * @param stateRetentionTime state idle retention time which unit is MILLISECONDS.
      */
     public GroupTableAggFunction(
@@ -76,11 +79,13 @@ public class GroupTableAggFunction extends KeyedProcessFunction<RowData, RowData
             LogicalType[] accTypes,
             int indexOfCountStar,
             boolean generateUpdateBefore,
+            boolean incrementalUpdate,
             long stateRetentionTime) {
         this.genAggsHandler = genAggsHandler;
         this.accTypes = accTypes;
         this.recordCounter = RecordCounter.of(indexOfCountStar);
         this.generateUpdateBefore = generateUpdateBefore;
+        this.incrementalUpdate = incrementalUpdate;
         this.stateRetentionTime = stateRetentionTime;
     }
 
@@ -117,7 +122,9 @@ public class GroupTableAggFunction extends KeyedProcessFunction<RowData, RowData
         // set accumulators to handler first
         function.setAccumulators(accumulators);
 
-        if (!firstRow && generateUpdateBefore) {
+        // when incrementalUpdate is required, there is no need to retract
+        // previously sent data that has not changed
+        if (!firstRow && !incrementalUpdate && generateUpdateBefore) {
             function.emitValue(out, currentKey, true);
         }
 
