This is an automated email from the ASF dual-hosted git repository.

godfrey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new b2e65a4  [FLINK-21923][table-planner-blink] Fix ClassCastException in 
SplitAggregateRule when a query contains both sum/count and avg function
b2e65a4 is described below

commit b2e65a41914766ab4b1f3495f7196611561fea4c
Author: Tartarus0zm <zhangma...@163.com>
AuthorDate: Tue Apr 6 16:41:56 2021 +0800

    [FLINK-21923][table-planner-blink] Fix ClassCastException in 
SplitAggregateRule when a query contains both sum/count and avg function
    
    This closes #15341
---
 .../plan/rules/logical/SplitAggregateRule.scala    | 32 ++++++++++++++--------
 .../plan/rules/logical/SplitAggregateRuleTest.xml  | 31 +++++++++++++++++++++
 .../rules/logical/SplitAggregateRuleTest.scala     | 19 +++++++++++++
 .../runtime/stream/sql/SplitAggregateITCase.scala  | 23 ++++++++++++++++
 4 files changed, 94 insertions(+), 11 deletions(-)

diff --git 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/SplitAggregateRule.scala
 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/SplitAggregateRule.scala
index be94ba1..31d1f25 100644
--- 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/SplitAggregateRule.scala
+++ 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/SplitAggregateRule.scala
@@ -27,7 +27,7 @@ import 
org.apache.flink.table.planner.plan.metadata.FlinkRelMetadataQuery
 import org.apache.flink.table.planner.plan.nodes.FlinkRelNode
 import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalAggregate
 import 
org.apache.flink.table.planner.plan.utils.AggregateUtil.doAllAggSupportSplit
-import org.apache.flink.table.planner.plan.utils.{ExpandUtil, WindowUtil}
+import org.apache.flink.table.planner.plan.utils.{AggregateUtil, ExpandUtil, 
WindowUtil}
 
 import org.apache.calcite.plan.RelOptRule.{any, operand}
 import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
@@ -138,9 +138,11 @@ class SplitAggregateRule extends RelOptRule(
     val windowProps = fmq.getRelWindowProperties(agg.getInput)
     val isWindowAgg = 
WindowUtil.groupingContainsWindowStartEnd(agg.getGroupSet, windowProps)
     val isProctimeWindowAgg = isWindowAgg && !windowProps.isRowtime
+    // TableAggregate is not supported. See also FLINK-21923.
+    val isTableAgg = AggregateUtil.isTableAggregate(agg.getAggCallList)
 
     agg.partialFinalType == PartialFinalType.NONE && 
agg.containsDistinctCall() &&
-      splitDistinctAggEnabled && isAllAggSplittable && !isProctimeWindowAgg
+      splitDistinctAggEnabled && isAllAggSplittable && !isProctimeWindowAgg && 
!isTableAgg
   }
 
   override def onMatch(call: RelOptRuleCall): Unit = {
@@ -280,11 +282,16 @@ class SplitAggregateRule extends RelOptRule(
     }
 
     // STEP 2.3: construct partial aggregates
-    relBuilder.aggregate(
-      relBuilder.groupKey(fullGroupSet, 
ImmutableList.of[ImmutableBitSet](fullGroupSet)),
+    // Create aggregate node directly to avoid ClassCastException.
+    // Please see FLINK-21923 for more details.
+    // TODO reuse aggregate function, see FLINK-22412
+    val partialAggregate = FlinkLogicalAggregate.create(
+      relBuilder.build(),
+      fullGroupSet,
+      ImmutableList.of[ImmutableBitSet](fullGroupSet),
       newPartialAggCalls)
-    relBuilder.peek().asInstanceOf[FlinkLogicalAggregate]
-      .setPartialFinalType(PartialFinalType.PARTIAL)
+    partialAggregate.setPartialFinalType(PartialFinalType.PARTIAL)
+    relBuilder.push(partialAggregate)
 
     // STEP 3: construct final aggregates
     val finalAggInputOffset = fullGroupSet.cardinality
@@ -306,13 +313,16 @@ class SplitAggregateRule extends RelOptRule(
         needMergeFinalAggOutput = true
       }
     }
-    relBuilder.aggregate(
-      relBuilder.groupKey(
-        SplitAggregateRule.remap(fullGroupSet, originalAggregate.getGroupSet),
-        SplitAggregateRule.remap(fullGroupSet, 
Seq(originalAggregate.getGroupSet))),
+    // Create aggregate node directly to avoid ClassCastException.
+    // Please see FLINK-21923 for more details.
+    // TODO reuse aggregate function, see FLINK-22412
+    val finalAggregate = FlinkLogicalAggregate.create(
+      relBuilder.build(),
+      SplitAggregateRule.remap(fullGroupSet, originalAggregate.getGroupSet),
+      SplitAggregateRule.remap(fullGroupSet, 
Seq(originalAggregate.getGroupSet)),
       finalAggCalls)
-    val finalAggregate = relBuilder.peek().asInstanceOf[FlinkLogicalAggregate]
     finalAggregate.setPartialFinalType(PartialFinalType.FINAL)
+    relBuilder.push(finalAggregate)
 
     // STEP 4: convert final aggregation output to the original aggregation 
output.
     // For example, aggregate function AVG is transformed to SUM0 and COUNT, 
so the output of
diff --git 
a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/SplitAggregateRuleTest.xml
 
b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/SplitAggregateRuleTest.xml
index 3895ee0..efe5bc6 100644
--- 
a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/SplitAggregateRuleTest.xml
+++ 
b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/SplitAggregateRuleTest.xml
@@ -430,4 +430,35 @@ FlinkLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], 
agg#1=[$SUM0($2)])
 ]]>
     </Resource>
   </TestCase>
+  <TestCase name="testAggFilterClauseBothWithAvgAndCount">
+       <Resource name="sql">
+         <![CDATA[
+SELECT
+  a,
+  COUNT(DISTINCT b) FILTER (WHERE NOT b = 2),
+  SUM(b) FILTER (WHERE NOT b = 5),
+  COUNT(b),
+  AVG(b),
+  SUM(b)
+FROM MyTable
+GROUP BY a
+]]>
+       </Resource>
+       <Resource name="ast">
+         <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[COUNT(DISTINCT $1) FILTER $2], 
EXPR$2=[SUM($1) FILTER $3], EXPR$3=[COUNT($1)], EXPR$4=[AVG($1)], 
EXPR$5=[SUM($1)])
++- LogicalProject(a=[$0], b=[$1], $f2=[IS TRUE(<>($1, 2))], $f3=[IS 
TRUE(<>($1, 5))])
+   +- LogicalTableScan(table=[[default_catalog, default_database, MyTable, 
source: [TestTableSource(a, b, c)]]])
+]]>
+       </Resource>
+       <Resource name="optimized rel plan">
+         <![CDATA[
+FlinkLogicalCalc(select=[a, $f1, $f2, $f3, CAST(IF(=($f5, 0:BIGINT), 
null:INTEGER, /($f4, $f5))) AS $f4, $f6])
++- FlinkLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], agg#1=[SUM($3)], 
agg#2=[$SUM0($4)], agg#3=[$SUM0($5)], agg#4=[$SUM0($6)], agg#5=[SUM($7)])
+   +- FlinkLogicalAggregate(group=[{0, 4}], agg#0=[COUNT(DISTINCT $1) FILTER 
$2], agg#1=[SUM($1) FILTER $3], agg#2=[COUNT($1)], agg#3=[$SUM0($1)], 
agg#4=[COUNT($1)], agg#5=[SUM($1)])
+      +- FlinkLogicalCalc(select=[a, b, IS TRUE(<>(b, 2)) AS $f2, IS 
TRUE(<>(b, 5)) AS $f3, MOD(HASH_CODE(b), 1024) AS $f4])
+         +- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog, 
default_database, MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, 
c])
+]]>
+       </Resource>
+  </TestCase>
 </Root>
diff --git 
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/SplitAggregateRuleTest.scala
 
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/SplitAggregateRuleTest.scala
index 4dbce13..d809dc4 100644
--- 
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/SplitAggregateRuleTest.scala
+++ 
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/SplitAggregateRuleTest.scala
@@ -186,4 +186,23 @@ class SplitAggregateRuleTest extends TableTestBase {
          |""".stripMargin
     util.verifyRelPlan(sqlQuery)
   }
+
+  @Test
+  def testAggFilterClauseBothWithAvgAndCount(): Unit = {
+    util.tableEnv.getConfig.getConfiguration.setBoolean(
+      OptimizerConfigOptions.TABLE_OPTIMIZER_DISTINCT_AGG_SPLIT_ENABLED, true)
+    val sqlQuery =
+      s"""
+         |SELECT
+         |  a,
+         |  COUNT(DISTINCT b) FILTER (WHERE NOT b = 2),
+         |  SUM(b) FILTER (WHERE NOT b = 5),
+         |  COUNT(b),
+         |  AVG(b),
+         |  SUM(b)
+         |FROM MyTable
+         |GROUP BY a
+         |""".stripMargin
+    util.verifyRelPlan(sqlQuery)
+  }
 }
diff --git 
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/SplitAggregateITCase.scala
 
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/SplitAggregateITCase.scala
index d799318..804c832 100644
--- 
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/SplitAggregateITCase.scala
+++ 
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/SplitAggregateITCase.scala
@@ -412,6 +412,29 @@ class SplitAggregateITCase(
     val expected = List("1,2,1,2,1", "2,4,3,4,3", "3,1,1,null,5", "4,2,2,6,5")
     assertEquals(expected.sorted, sink.getRetractResults.sorted)
   }
+
+  @Test
+  def testAggFilterClauseBothWithAvgAndCount(): Unit = {
+    val t1 = tEnv.sqlQuery(
+      s"""
+         |SELECT
+         |  a,
+         |  COUNT(DISTINCT b) FILTER (WHERE NOT b = 2),
+         |  SUM(b) FILTER (WHERE NOT b = 5),
+         |  COUNT(b),
+         |  SUM(b),
+         |  AVG(b)
+         |FROM T
+         |GROUP BY a
+       """.stripMargin)
+
+    val sink = new TestingRetractSink
+    t1.toRetractStream[Row].addSink(sink)
+    env.execute()
+
+    val expected = List("1,1,3,2,3,1", "2,3,24,8,29,3", "3,1,null,2,10,5", 
"4,2,6,4,21,5")
+    assertEquals(expected.sorted, sink.getRetractResults.sorted)
+  }
 }
 
 object SplitAggregateITCase {

Reply via email to