This is an automated email from the ASF dual-hosted git repository.

shengkai pushed a commit to branch release-1.18
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/release-1.18 by this push:
     new f2584a1df36 [FLINK-33063][table-runtime] Fix udaf with complex user defined pojo object throw error while generate record equaliser (#23388)
f2584a1df36 is described below

commit f2584a1df364a14ff50b5a52fe7cf5e38d4cdc9a
Author: yunhong <337361...@qq.com>
AuthorDate: Tue Sep 12 11:08:48 2023 +0800

    [FLINK-33063][table-runtime] Fix udaf with complex user defined pojo object throw error while generate record equaliser (#23388)
    
    Co-authored-by: zhengyunhong.zyh <zhengyunhong....@alibaba-inc.com>
---
 .../plan/utils/JavaUserDefinedAggFunctions.java    | 80 ++++++++++++++++++++++
 .../runtime/stream/sql/AggregateITCase.scala       | 30 +++++++-
 .../table/runtime/typeutils/TypeCheckUtils.java    | 12 +++-
 3 files changed, 120 insertions(+), 2 deletions(-)

diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/utils/JavaUserDefinedAggFunctions.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/utils/JavaUserDefinedAggFunctions.java
index 64cb9bbd2bd..4e722be8ee3 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/utils/JavaUserDefinedAggFunctions.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/utils/JavaUserDefinedAggFunctions.java
@@ -25,7 +25,11 @@ import org.apache.flink.table.api.dataview.ListView;
 import org.apache.flink.table.api.dataview.MapView;
 import org.apache.flink.table.functions.AggregateFunction;
 
+import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
 
 /** Test aggregator functions. */
 public class JavaUserDefinedAggFunctions {
@@ -421,4 +425,80 @@ public class JavaUserDefinedAggFunctions {
             return Tuple1.of(0L);
         }
     }
+
+    /** User defined pojo object. */
+    public static class TestObject {
+        private final String a;
+
+        public TestObject(String a) {
+            this.a = a;
+        }
+
+        public String getA() {
+            return a;
+        }
+    }
+
+    /** User defined object. */
+    public static class UserDefinedObject {
+        // List with user defined pojo object.
+        public List<TestObject> testObjectList = new ArrayList<>();
+        // Map with user defined pojo object.
+        public Map<String, TestObject> testObjectMap = new HashMap<>();
+    }
+
+    /** User defined UDAF whose value and acc is user defined complex pojo object. */
+    public static class UserDefinedObjectUDAF
+            extends AggregateFunction<UserDefinedObject, UserDefinedObject> {
+        private static final String KEY = "key";
+
+        @Override
+        public UserDefinedObject getValue(UserDefinedObject accumulator) {
+            return accumulator;
+        }
+
+        @Override
+        public UserDefinedObject createAccumulator() {
+            return new UserDefinedObject();
+        }
+
+        public void accumulate(UserDefinedObject acc, String a) {
+            if (a != null) {
+                acc.testObjectList.add(new TestObject(a));
+                acc.testObjectMap.put(KEY, new TestObject(a));
+            }
+        }
+
+        public void retract(UserDefinedObject acc, UserDefinedObject a) {
+            // do nothing.
+        }
+    }
+
+    /** User defined UDAF whose value and acc is user defined complex pojo object. */
+    public static class UserDefinedObjectUDAF2
+            extends AggregateFunction<String, UserDefinedObject> {
+        private static final String KEY = "key";
+
+        @Override
+        public String getValue(UserDefinedObject accumulator) {
+            if (accumulator.testObjectMap.containsKey(KEY)) {
+                return accumulator.testObjectMap.get(KEY).getA();
+            }
+            return null;
+        }
+
+        @Override
+        public UserDefinedObject createAccumulator() {
+            return new UserDefinedObject();
+        }
+
+        public void accumulate(UserDefinedObject acc, UserDefinedObject a) {
+            acc.testObjectList = a.testObjectList;
+            acc.testObjectMap = a.testObjectMap;
+        }
+
+        public void retract(UserDefinedObject acc, UserDefinedObject a) {
+            // do nothing
+        }
+    }
 }
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala
index 152c8d02be4..4df20e1615d 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala
@@ -28,7 +28,7 @@ import org.apache.flink.table.api.bridge.scala._
 import org.apache.flink.table.api.internal.TableEnvironmentInternal
 import org.apache.flink.table.planner.factories.TestValuesTableFactory
 import 
org.apache.flink.table.planner.factories.TestValuesTableFactory.{changelogRow, 
registerData}
-import 
org.apache.flink.table.planner.plan.utils.JavaUserDefinedAggFunctions.VarSumAggFunction
+import 
org.apache.flink.table.planner.plan.utils.JavaUserDefinedAggFunctions.{UserDefinedObjectUDAF,
 UserDefinedObjectUDAF2, VarSumAggFunction}
 import 
org.apache.flink.table.planner.runtime.batch.sql.agg.{MyPojoAggFunction, 
VarArgsAggFunction}
 import org.apache.flink.table.planner.runtime.utils._
 import 
org.apache.flink.table.planner.runtime.utils.JavaUserDefinedAggFunctions.OverloadedMaxFunction
@@ -1359,6 +1359,34 @@ class AggregateITCase(aggMode: AggMode, miniBatch: MiniBatchMode, backend: State
     assertEquals(expected.sorted, sink.getRetractResults.sorted)
   }
 
+  @Test
+  def testUserDefinedObjectAgg(): Unit = {
+    tEnv.createTemporaryFunction("user_define_object", new UserDefinedObjectUDAF)
+    tEnv.createTemporaryFunction("user_define_object2", new UserDefinedObjectUDAF2)
+    val sqlQuery =
+      s"""
+         |select t1.a, user_define_object2(t1.d) from 
+         |(SELECT a, user_define_object(b) as d
+         |FROM MyTable GROUP BY a) t1
+         |group by t1.a
+         |""".stripMargin
+    val data = new mutable.MutableList[(Int, String)]
+    data.+=((1, "Sam"))
+    data.+=((1, "Jerry"))
+    data.+=((2, "Ali"))
+    data.+=((3, "Grace"))
+    data.+=((3, "Lucas"))
+
+    val t = failingDataSource(data).toTable(tEnv, 'a, 'b)
+    tEnv.createTemporaryView("MyTable", t)
+
+    val sink = new TestingRetractSink
+    tEnv.sqlQuery(sqlQuery).toRetractStream[Row].addSink(sink)
+    env.execute()
+    val expected = List("1,Jerry", "2,Ali", "3,Lucas")
+    assertEquals(expected.sorted, sink.getRetractResults.sorted)
+  }
+
   @Test
   def testSTDDEV(): Unit = {
    val sqlQuery = "SELECT STDDEV_SAMP(a), STDDEV_POP(a) FROM MyTable GROUP BY c"
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/typeutils/TypeCheckUtils.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/typeutils/TypeCheckUtils.java
index 1f90d48947b..c938932c389 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/typeutils/TypeCheckUtils.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/typeutils/TypeCheckUtils.java
@@ -34,6 +34,7 @@ import static org.apache.flink.table.types.logical.LogicalTypeRoot.MAP;
 import static org.apache.flink.table.types.logical.LogicalTypeRoot.MULTISET;
 import static org.apache.flink.table.types.logical.LogicalTypeRoot.RAW;
 import static org.apache.flink.table.types.logical.LogicalTypeRoot.ROW;
+import static org.apache.flink.table.types.logical.LogicalTypeRoot.STRUCTURED_TYPE;
 import static org.apache.flink.table.types.logical.LogicalTypeRoot.TIMESTAMP_WITHOUT_TIME_ZONE;
 import static org.apache.flink.table.types.logical.LogicalTypeRoot.TIMESTAMP_WITH_LOCAL_TIME_ZONE;
 import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.isRowtimeAttribute;
@@ -126,8 +127,17 @@ public class TypeCheckUtils {
         return type.getTypeRoot() == ROW;
     }
 
+    public static boolean isStructuredType(LogicalType type) {
+        return type.getTypeRoot() == STRUCTURED_TYPE;
+    }
+
     public static boolean isComparable(LogicalType type) {
-        return !isRaw(type) && !isMap(type) && !isMultiset(type) && !isRow(type) && !isArray(type);
+        return !isRaw(type)
+                && !isMap(type)
+                && !isMultiset(type)
+                && !isRow(type)
+                && !isArray(type)
+                && !isStructuredType(type);
     }
 
     public static boolean isMutable(LogicalType type) {

Reply via email to