This is an automated email from the ASF dual-hosted git repository. fpaul pushed a commit to branch release-2.1 in repository https://gitbox.apache.org/repos/asf/flink.git
commit 74404f8e8bc8e711114b01b351cbd0f81e08e6ce Author: Gustavo de Morais <[email protected]> AuthorDate: Fri Oct 24 11:57:15 2025 +0200 [FLINK-38554][table] Fix rowCount cost for FlinkLogicalMultiJoin --- .../plan/nodes/logical/FlinkLogicalMultiJoin.java | 2 +- .../planner/plan/stream/sql/MultiJoinTest.java | 39 ++++++++++++++ .../planner/plan/stream/sql/MultiJoinTest.xml | 62 ++++++++++++++++++++++ 3 files changed, 102 insertions(+), 1 deletion(-) diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/logical/FlinkLogicalMultiJoin.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/logical/FlinkLogicalMultiJoin.java index 5429629f991..21bb14c73a0 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/logical/FlinkLogicalMultiJoin.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/logical/FlinkLogicalMultiJoin.java @@ -153,7 +153,7 @@ public class FlinkLogicalMultiJoin extends AbstractRelNode implements FlinkLogic final Double averageRowSize = mq.getAverageRowSize(input); final double dAverageRowSize = averageRowSize == null ? 100.0 : averageRowSize; - rowCount *= inputRowCount; + rowCount += inputRowCount; cpu += inputRowCount; io += inputRowCount * dAverageRowSize; } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/MultiJoinTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/MultiJoinTest.java index a6a488d3db3..8121d763d84 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/MultiJoinTest.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/MultiJoinTest.java @@ -171,6 +171,45 @@ public class MultiJoinTest extends TableTestBase { + "LEFT JOIN Payments p ON u.user_id_0 = p.user_id_2"); } + @Test + void testTwoWayJoinWithUnion() { + util.tableEnv() + .executeSql( + "CREATE TABLE Orders2 (" + + " order_id STRING PRIMARY KEY NOT ENFORCED," + + " user_id_1 STRING," + + " product STRING" + + ") WITH ('connector' = 'values', 'changelog-mode' = 'I,D')"); + + util.verifyRelPlan( + "WITH OrdersUnion as (" + + "SELECT * FROM Orders " + + "UNION ALL " + + "SELECT * FROM Orders2" + + ") " + + "SELECT * FROM OrdersUnion o " + + "LEFT JOIN Users u " + + "ON o.user_id_1 = u.user_id_0"); + } + + @Test + void testTwoWayJoinWithRank() { + util.getTableEnv() + .getConfig() + .set(OptimizerConfigOptions.TABLE_OPTIMIZER_MULTI_JOIN_ENABLED, true); + + util.verifyRelPlan( + "WITH JoinedEvents as (" + + "SELECT e1.id as id, e1.val, e1.rowtime as `rowtime`, e2.price " + + "FROM EventTable1 e1 " + + "JOIN EventTable2 e2 ON e1.id = e2.id) " + + "SELECT id, val, `rowtime` FROM (" + + "SELECT *, " + + "ROW_NUMBER() OVER (PARTITION BY id ORDER BY `rowtime` DESC) as ts " + + "FROM JoinedEvents) " + + "WHERE ts = 1"); + } + @Test void testFourWayComplexJoinRelPlan() { util.verifyRelPlan( diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/MultiJoinTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/MultiJoinTest.xml index ba05d7b6778..7df8d36563a 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/MultiJoinTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/MultiJoinTest.xml @@ -703,6 +703,68 @@ Calc(select=[user_id_0, CAST('Gus' AS VARCHAR(2147483647)) AS name, order_id, CA +- Exchange(distribution=[hash[user_id_2]]) +- Calc(select=[payment_id, price, user_id_2], where=[(price > 10)]) +- TableSourceScan(table=[[default_catalog, default_database, Payments, filter=[]]], fields=[payment_id, price, user_id_2]) +]]> + </Resource> + </TestCase> + <TestCase name="testTwoWayJoinWithUnion"> + <Resource name="sql"> + <![CDATA[WITH OrdersUnion as (SELECT * FROM Orders UNION ALL SELECT * FROM Orders2) SELECT * FROM OrdersUnion o LEFT JOIN Users u ON o.user_id_1 = u.user_id_0]]> + </Resource> + <Resource name="ast"> + <![CDATA[ +LogicalProject(order_id=[$0], user_id_1=[$1], product=[$2], user_id_0=[$3], name=[$4], cash=[$5]) ++- LogicalJoin(condition=[=($1, $3)], joinType=[left]) + :- LogicalUnion(all=[true]) + : :- LogicalProject(order_id=[$0], user_id_1=[$1], product=[$2]) + : : +- LogicalTableScan(table=[[default_catalog, default_database, Orders]]) + : +- LogicalProject(order_id=[$0], user_id_1=[$1], product=[$2]) + : +- LogicalTableScan(table=[[default_catalog, default_database, Orders2]]) + +- LogicalTableScan(table=[[default_catalog, default_database, Users]]) +]]> + </Resource> + <Resource name="optimized rel plan"> + <![CDATA[ +MultiJoin(commonJoinKey=[user_id_1], joinTypes=[LEFT], inputUniqueKeys=[noUniqueKey, (user_id_0)], joinConditions=[=(user_id_1, user_id_0)], select=[order_id,user_id_1,product,user_id_0,name,cash], rowType=[RecordType(VARCHAR(2147483647) order_id, VARCHAR(2147483647) user_id_1, VARCHAR(2147483647) product, VARCHAR(2147483647) user_id_0, VARCHAR(2147483647) name, INTEGER cash)]) +:- Exchange(distribution=[hash[user_id_1]]) +: +- Union(all=[true], union=[order_id, user_id_1, product]) +: :- TableSourceScan(table=[[default_catalog, default_database, Orders]], fields=[order_id, user_id_1, product]) +: +- TableSourceScan(table=[[default_catalog, default_database, Orders2]], fields=[order_id, user_id_1, product]) ++- Exchange(distribution=[hash[user_id_0]]) + +- TableSourceScan(table=[[default_catalog, default_database, Users]], fields=[user_id_0, name, cash]) +]]> + </Resource> + </TestCase> + <TestCase name="testTwoWayJoinWithRank"> + <Resource name="sql"> + <![CDATA[WITH JoinedEvents as (SELECT e1.id as id, e1.val, e1.rowtime as `rowtime`, e2.price FROM EventTable1 e1 JOIN EventTable2 e2 ON e1.id = e2.id) SELECT id, val, `rowtime` FROM (SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY `rowtime` DESC) as ts FROM JoinedEvents) WHERE ts = 1]]> + </Resource> + <Resource name="ast"> + <![CDATA[ +LogicalProject(id=[$0], val=[$1], rowtime=[$2]) ++- LogicalFilter(condition=[=($4, 1)]) + +- LogicalProject(id=[$0], val=[$1], rowtime=[$2], price=[$3], ts=[ROW_NUMBER() OVER (PARTITION BY $0 ORDER BY $2 DESC NULLS LAST)]) + +- LogicalProject(id=[$0], val=[$1], rowtime=[$2], price=[$4]) + +- LogicalJoin(condition=[=($0, $3)], joinType=[inner]) + :- LogicalWatermarkAssigner(rowtime=[rowtime], watermark=[-($2, 5000:INTERVAL SECOND)]) + : +- LogicalTableScan(table=[[default_catalog, default_database, EventTable1]]) + +- LogicalWatermarkAssigner(rowtime=[rowtime], watermark=[-($2, 5000:INTERVAL SECOND)]) + +- LogicalTableScan(table=[[default_catalog, default_database, EventTable2]]) +]]> + </Resource> + <Resource name="optimized rel plan"> + <![CDATA[ +Rank(strategy=[AppendFastStrategy], rankType=[ROW_NUMBER], rankRange=[rankStart=1, rankEnd=1], partitionBy=[id], orderBy=[rowtime DESC], select=[id, val, rowtime]) ++- Exchange(distribution=[hash[id]]) + +- Calc(select=[id, val, rowtime]) + +- MultiJoin(commonJoinKey=[id], joinTypes=[INNER], inputUniqueKeys=[noUniqueKey, noUniqueKey], joinConditions=[=(id, id0)], select=[id,val,rowtime,id0,price,rowtime0], rowType=[RecordType(VARCHAR(2147483647) id, INTEGER val, TIMESTAMP(3) rowtime, VARCHAR(2147483647) id0, DOUBLE price, TIMESTAMP(3) rowtime0)]) + :- Exchange(distribution=[hash[id]]) + : +- Calc(select=[id, val, CAST(rowtime AS TIMESTAMP(3)) AS rowtime]) + : +- WatermarkAssigner(rowtime=[rowtime], watermark=[-(rowtime, 5000:INTERVAL SECOND)]) + : +- TableSourceScan(table=[[default_catalog, default_database, EventTable1]], fields=[id, val, rowtime]) + +- Exchange(distribution=[hash[id]]) + +- Calc(select=[id, price, CAST(rowtime AS TIMESTAMP(3)) AS rowtime]) + +- WatermarkAssigner(rowtime=[rowtime], watermark=[-(rowtime, 5000:INTERVAL SECOND)]) + +- TableSourceScan(table=[[default_catalog, default_database, EventTable2]], fields=[id, price, rowtime]) ]]> </Resource> </TestCase>
