This is an automated email from the ASF dual-hosted git repository. godfrey pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push: new b2e65a4 [FLINK-21923][table-planner-blink] Fix ClassCastException in SplitAggregateRule when a query contains both sum/count and avg function b2e65a4 is described below commit b2e65a41914766ab4b1f3495f7196611561fea4c Author: Tartarus0zm <zhangma...@163.com> AuthorDate: Tue Apr 6 16:41:56 2021 +0800 [FLINK-21923][table-planner-blink] Fix ClassCastException in SplitAggregateRule when a query contains both sum/count and avg function This closes #15341 --- .../plan/rules/logical/SplitAggregateRule.scala | 32 ++++++++++++++-------- .../plan/rules/logical/SplitAggregateRuleTest.xml | 31 +++++++++++++++++++++ .../rules/logical/SplitAggregateRuleTest.scala | 19 +++++++++++++ .../runtime/stream/sql/SplitAggregateITCase.scala | 23 ++++++++++++++++ 4 files changed, 94 insertions(+), 11 deletions(-) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/SplitAggregateRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/SplitAggregateRule.scala index be94ba1..31d1f25 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/SplitAggregateRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/SplitAggregateRule.scala @@ -27,7 +27,7 @@ import org.apache.flink.table.planner.plan.metadata.FlinkRelMetadataQuery import org.apache.flink.table.planner.plan.nodes.FlinkRelNode import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalAggregate import org.apache.flink.table.planner.plan.utils.AggregateUtil.doAllAggSupportSplit -import org.apache.flink.table.planner.plan.utils.{ExpandUtil, WindowUtil} +import org.apache.flink.table.planner.plan.utils.{AggregateUtil, ExpandUtil, WindowUtil} import org.apache.calcite.plan.RelOptRule.{any, operand} import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} @@ -138,9 +138,11 @@ class SplitAggregateRule extends RelOptRule( val windowProps = fmq.getRelWindowProperties(agg.getInput) val isWindowAgg = WindowUtil.groupingContainsWindowStartEnd(agg.getGroupSet, windowProps) val isProctimeWindowAgg = isWindowAgg && !windowProps.isRowtime + // TableAggregate is not supported. see also FLINK-21923. + val isTableAgg = AggregateUtil.isTableAggregate(agg.getAggCallList) agg.partialFinalType == PartialFinalType.NONE && agg.containsDistinctCall() && - splitDistinctAggEnabled && isAllAggSplittable && !isProctimeWindowAgg + splitDistinctAggEnabled && isAllAggSplittable && !isProctimeWindowAgg && !isTableAgg } override def onMatch(call: RelOptRuleCall): Unit = { @@ -280,11 +282,16 @@ class SplitAggregateRule extends RelOptRule( } // STEP 2.3: construct partial aggregates - relBuilder.aggregate( - relBuilder.groupKey(fullGroupSet, ImmutableList.of[ImmutableBitSet](fullGroupSet)), + // Create aggregate node directly to avoid ClassCastException, + // Please see FLINK-21923 for more details. + // TODO reuse aggregate function, see FLINK-22412 + val partialAggregate = FlinkLogicalAggregate.create( + relBuilder.build(), + fullGroupSet, + ImmutableList.of[ImmutableBitSet](fullGroupSet), newPartialAggCalls) - relBuilder.peek().asInstanceOf[FlinkLogicalAggregate] - .setPartialFinalType(PartialFinalType.PARTIAL) + partialAggregate.setPartialFinalType(PartialFinalType.PARTIAL) + relBuilder.push(partialAggregate) // STEP 3: construct final aggregates val finalAggInputOffset = fullGroupSet.cardinality @@ -306,13 +313,16 @@ class SplitAggregateRule extends RelOptRule( needMergeFinalAggOutput = true } } - relBuilder.aggregate( - relBuilder.groupKey( - SplitAggregateRule.remap(fullGroupSet, originalAggregate.getGroupSet), - SplitAggregateRule.remap(fullGroupSet, Seq(originalAggregate.getGroupSet))), + // Create aggregate node directly to avoid ClassCastException, + // Please see FLINK-21923 for more details. + // TODO reuse aggregate function, see FLINK-22412 + val finalAggregate = FlinkLogicalAggregate.create( + relBuilder.build(), + SplitAggregateRule.remap(fullGroupSet, originalAggregate.getGroupSet), + SplitAggregateRule.remap(fullGroupSet, Seq(originalAggregate.getGroupSet)), finalAggCalls) - val finalAggregate = relBuilder.peek().asInstanceOf[FlinkLogicalAggregate] finalAggregate.setPartialFinalType(PartialFinalType.FINAL) + relBuilder.push(finalAggregate) // STEP 4: convert final aggregation output to the original aggregation output. // For example, aggregate function AVG is transformed to SUM0 and COUNT, so the output of diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/SplitAggregateRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/SplitAggregateRuleTest.xml index 3895ee0..efe5bc6 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/SplitAggregateRuleTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/SplitAggregateRuleTest.xml @@ -430,4 +430,35 @@ FlinkLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[$SUM0($2)]) ]]> </Resource> </TestCase> + <TestCase name="testAggFilterClauseBothWithAvgAndCount"> + <Resource name="sql"> + <![CDATA[ +SELECT + a, + COUNT(DISTINCT b) FILTER (WHERE NOT b = 2), + SUM(b) FILTER (WHERE NOT b = 5), + COUNT(b), + AVG(b), + SUM(b) +FROM MyTable +GROUP BY a +]]> + </Resource> + <Resource name="ast"> + <![CDATA[ +LogicalAggregate(group=[{0}], EXPR$1=[COUNT(DISTINCT $1) FILTER $2], EXPR$2=[SUM($1) FILTER $3], EXPR$3=[COUNT($1)], EXPR$4=[AVG($1)], EXPR$5=[SUM($1)]) ++- LogicalProject(a=[$0], b=[$1], $f2=[IS TRUE(<>($1, 2))], $f3=[IS TRUE(<>($1, 5))]) + +- LogicalTableScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]]) +]]> + </Resource> + <Resource name="optimized rel plan"> + <![CDATA[ +FlinkLogicalCalc(select=[a, $f1, $f2, $f3, CAST(IF(=($f5, 0:BIGINT), null:INTEGER, /($f4, $f5))) AS $f4, $f6]) ++- FlinkLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], agg#1=[SUM($3)], agg#2=[$SUM0($4)], agg#3=[$SUM0($5)], agg#4=[$SUM0($6)], agg#5=[SUM($7)]) + +- FlinkLogicalAggregate(group=[{0, 4}], agg#0=[COUNT(DISTINCT $1) FILTER $2], agg#1=[SUM($1) FILTER $3], agg#2=[COUNT($1)], agg#3=[$SUM0($1)], agg#4=[COUNT($1)], agg#5=[SUM($1)]) + +- FlinkLogicalCalc(select=[a, b, IS TRUE(<>(b, 2)) AS $f2, IS TRUE(<>(b, 5)) AS $f3, MOD(HASH_CODE(b), 1024) AS $f4]) + +- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) +]]> + </Resource> + </TestCase> </Root> diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/SplitAggregateRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/SplitAggregateRuleTest.scala index 4dbce13..d809dc4 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/SplitAggregateRuleTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/SplitAggregateRuleTest.scala @@ -186,4 +186,23 @@ class SplitAggregateRuleTest extends TableTestBase { |""".stripMargin util.verifyRelPlan(sqlQuery) } + + @Test + def testAggFilterClauseBothWithAvgAndCount(): Unit = { + util.tableEnv.getConfig.getConfiguration.setBoolean( + OptimizerConfigOptions.TABLE_OPTIMIZER_DISTINCT_AGG_SPLIT_ENABLED, true) + val sqlQuery = + s""" + |SELECT + | a, + | COUNT(DISTINCT b) FILTER (WHERE NOT b = 2), + | SUM(b) FILTER (WHERE NOT b = 5), + | COUNT(b), + | AVG(b), + | SUM(b) + |FROM MyTable + |GROUP BY a + |""".stripMargin + util.verifyRelPlan(sqlQuery) + } } diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/SplitAggregateITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/SplitAggregateITCase.scala index d799318..804c832 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/SplitAggregateITCase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/SplitAggregateITCase.scala @@ -412,6 +412,29 @@ class SplitAggregateITCase( val expected = List("1,2,1,2,1", "2,4,3,4,3", "3,1,1,null,5", "4,2,2,6,5") assertEquals(expected.sorted, sink.getRetractResults.sorted) } + + @Test + def testAggFilterClauseBothWithAvgAndCount(): Unit = { + val t1 = tEnv.sqlQuery( + s""" + |SELECT + | a, + | COUNT(DISTINCT b) FILTER (WHERE NOT b = 2), + | SUM(b) FILTER (WHERE NOT b = 5), + | COUNT(b), + | SUM(b), + | AVG(b) + |FROM T + |GROUP BY a + """.stripMargin) + + val sink = new TestingRetractSink + t1.toRetractStream[Row].addSink(sink) + env.execute() + + val expected = List("1,1,3,2,3,1", "2,3,24,8,29,3", "3,1,null,2,10,5", "4,2,6,4,21,5") + assertEquals(expected.sorted, sink.getRetractResults.sorted) + } } object SplitAggregateITCase {