This is an automated email from the ASF dual-hosted git repository. shengkai pushed a commit to branch release-1.18 in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/release-1.18 by this push: new f2584a1df36 [FLINK-33063][table-runtime] Fix udaf with complex user defined pojo object throw error while generate record equaliser (#23388) f2584a1df36 is described below commit f2584a1df364a14ff50b5a52fe7cf5e38d4cdc9a Author: yunhong <337361...@qq.com> AuthorDate: Tue Sep 12 11:08:48 2023 +0800 [FLINK-33063][table-runtime] Fix udaf with complex user defined pojo object throw error while generate record equaliser (#23388) Co-authored-by: zhengyunhong.zyh <zhengyunhong....@alibaba-inc.com> --- .../plan/utils/JavaUserDefinedAggFunctions.java | 80 ++++++++++++++++++++++ .../runtime/stream/sql/AggregateITCase.scala | 30 +++++++- .../table/runtime/typeutils/TypeCheckUtils.java | 12 +++- 3 files changed, 120 insertions(+), 2 deletions(-) diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/utils/JavaUserDefinedAggFunctions.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/utils/JavaUserDefinedAggFunctions.java index 64cb9bbd2bd..4e722be8ee3 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/utils/JavaUserDefinedAggFunctions.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/utils/JavaUserDefinedAggFunctions.java @@ -25,7 +25,11 @@ import org.apache.flink.table.api.dataview.ListView; import org.apache.flink.table.api.dataview.MapView; import org.apache.flink.table.functions.AggregateFunction; +import java.util.ArrayList; +import java.util.HashMap; import java.util.Iterator; +import java.util.List; +import java.util.Map; /** Test aggregator functions. */ public class JavaUserDefinedAggFunctions { @@ -421,4 +425,80 @@ public class JavaUserDefinedAggFunctions { return Tuple1.of(0L); } } + + /** User defined pojo object. 
*/ public static class TestObject { + private final String a; + + public TestObject(String a) { + this.a = a; + } + + public String getA() { + return a; + } + } + + /** User defined object holding a list and a map of {@link TestObject}. */ + public static class UserDefinedObject { + // List of user defined pojo objects. + public List<TestObject> testObjectList = new ArrayList<>(); + // Map from string key to user defined pojo object. + public Map<String, TestObject> testObjectMap = new HashMap<>(); + } + + /** User defined UDAF whose value and accumulator are both a complex user defined pojo object ({@link UserDefinedObject}). */ + public static class UserDefinedObjectUDAF + extends AggregateFunction<UserDefinedObject, UserDefinedObject> { + private static final String KEY = "key"; + + @Override + public UserDefinedObject getValue(UserDefinedObject accumulator) { + return accumulator; + } + + @Override + public UserDefinedObject createAccumulator() { + return new UserDefinedObject(); + } + + public void accumulate(UserDefinedObject acc, String a) { + if (a != null) { + acc.testObjectList.add(new TestObject(a)); + acc.testObjectMap.put(KEY, new TestObject(a)); + } + } + + public void retract(UserDefinedObject acc, UserDefinedObject a) { + // do nothing. NOTE(review): accumulate takes a String input but this retract takes a UserDefinedObject, so it does not mirror accumulate; confirm whether the parameter should be a String. + } + } + + /** User defined UDAF whose accumulator is a complex user defined pojo object ({@link UserDefinedObject}) and whose value is {@code getA()} of the {@link TestObject} stored under {@code KEY}, or null if there is none.
*/ public static class UserDefinedObjectUDAF2 + extends AggregateFunction<String, UserDefinedObject> { + private static final String KEY = "key"; + + @Override + public String getValue(UserDefinedObject accumulator) { + if (accumulator.testObjectMap.containsKey(KEY)) { + return accumulator.testObjectMap.get(KEY).getA(); + } + return null; + } + + @Override + public UserDefinedObject createAccumulator() { + return new UserDefinedObject(); + } + + public void accumulate(UserDefinedObject acc, UserDefinedObject a) { + acc.testObjectList = a.testObjectList; + acc.testObjectMap = a.testObjectMap; + } + + public void retract(UserDefinedObject acc, UserDefinedObject a) { + // do nothing: retractions are ignored; accumulate replaces the accumulator's list and map with those of the latest input + } + } } diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala index 152c8d02be4..4df20e1615d 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala @@ -28,7 +28,7 @@ import org.apache.flink.table.api.bridge.scala._ import org.apache.flink.table.api.internal.TableEnvironmentInternal import org.apache.flink.table.planner.factories.TestValuesTableFactory import org.apache.flink.table.planner.factories.TestValuesTableFactory.{changelogRow, registerData} -import org.apache.flink.table.planner.plan.utils.JavaUserDefinedAggFunctions.VarSumAggFunction +import org.apache.flink.table.planner.plan.utils.JavaUserDefinedAggFunctions.{UserDefinedObjectUDAF, UserDefinedObjectUDAF2, VarSumAggFunction} import org.apache.flink.table.planner.runtime.batch.sql.agg.{MyPojoAggFunction, VarArgsAggFunction} import org.apache.flink.table.planner.runtime.utils._ import
org.apache.flink.table.planner.runtime.utils.JavaUserDefinedAggFunctions.OverloadedMaxFunction @@ -1359,6 +1359,34 @@ class AggregateITCase(aggMode: AggMode, miniBatch: MiniBatchMode, backend: State assertEquals(expected.sorted, sink.getRetractResults.sorted) } + @Test + def testUserDefinedObjectAgg(): Unit = { + tEnv.createTemporaryFunction("user_define_object", new UserDefinedObjectUDAF) + tEnv.createTemporaryFunction("user_define_object2", new UserDefinedObjectUDAF2) + val sqlQuery = + s""" + |select t1.a, user_define_object2(t1.d) from + |(SELECT a, user_define_object(b) as d + |FROM MyTable GROUP BY a) t1 + |group by t1.a + |""".stripMargin + val data = new mutable.MutableList[(Int, String)] + data.+=((1, "Sam")) + data.+=((1, "Jerry")) + data.+=((2, "Ali")) + data.+=((3, "Grace")) + data.+=((3, "Lucas")) + + val t = failingDataSource(data).toTable(tEnv, 'a, 'b) + tEnv.createTemporaryView("MyTable", t) + + val sink = new TestingRetractSink + tEnv.sqlQuery(sqlQuery).toRetractStream[Row].addSink(sink) + env.execute() + val expected = List("1,Jerry", "2,Ali", "3,Lucas") + assertEquals(expected.sorted, sink.getRetractResults.sorted) + } + @Test def testSTDDEV(): Unit = { val sqlQuery = "SELECT STDDEV_SAMP(a), STDDEV_POP(a) FROM MyTable GROUP BY c" diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/typeutils/TypeCheckUtils.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/typeutils/TypeCheckUtils.java index 1f90d48947b..c938932c389 100644 --- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/typeutils/TypeCheckUtils.java +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/typeutils/TypeCheckUtils.java @@ -34,6 +34,7 @@ import static org.apache.flink.table.types.logical.LogicalTypeRoot.MAP; import static org.apache.flink.table.types.logical.LogicalTypeRoot.MULTISET; import static
org.apache.flink.table.types.logical.LogicalTypeRoot.RAW; import static org.apache.flink.table.types.logical.LogicalTypeRoot.ROW; +import static org.apache.flink.table.types.logical.LogicalTypeRoot.STRUCTURED_TYPE; import static org.apache.flink.table.types.logical.LogicalTypeRoot.TIMESTAMP_WITHOUT_TIME_ZONE; import static org.apache.flink.table.types.logical.LogicalTypeRoot.TIMESTAMP_WITH_LOCAL_TIME_ZONE; import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.isRowtimeAttribute; @@ -126,8 +127,17 @@ public class TypeCheckUtils { return type.getTypeRoot() == ROW; } + public static boolean isStructuredType(LogicalType type) { + return type.getTypeRoot() == STRUCTURED_TYPE; + } + public static boolean isComparable(LogicalType type) { - return !isRaw(type) && !isMap(type) && !isMultiset(type) && !isRow(type) && !isArray(type); + return !isRaw(type) + && !isMap(type) + && !isMultiset(type) + && !isRow(type) + && !isArray(type) + && !isStructuredType(type); } public static boolean isMutable(LogicalType type) {