This is an automated email from the ASF dual-hosted git repository.

lincoln pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 475f45ba0fb [FLINK-27519][table-planner] Fix column name conflicts in StreamPhysicalOverAggregate
475f45ba0fb is described below

commit 475f45ba0fb78f81adaa627ed2e8fbdcd71b83f6
Author: lincoln lee <[email protected]>
AuthorDate: Tue Aug 6 19:30:20 2024 +0800

    [FLINK-27519][table-planner] Fix column name conflicts in StreamPhysicalOverAggregate
    
    This closes #25152
---
 .../batch/BatchPhysicalOverAggregateRule.scala     | 23 ++------------
 .../stream/StreamPhysicalOverAggregateRule.scala   | 12 +++++--
 .../planner/plan/utils/OverAggregateUtil.scala     | 23 +++++++++++++-
 .../plan/batch/sql/agg/OverAggregateTest.xml       | 37 ++++++++++++++++++++++
 .../plan/stream/sql/agg/OverAggregateTest.xml      | 35 ++++++++++++++++++++
 .../plan/batch/sql/agg/OverAggregateTest.scala     | 28 ++++++++++++++++
 .../plan/stream/sql/agg/OverAggregateTest.scala    | 27 ++++++++++++++++
 7 files changed, 162 insertions(+), 23 deletions(-)

diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalOverAggregateRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalOverAggregateRule.scala
index 6d20ce229ec..a428d03f95f 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalOverAggregateRule.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalOverAggregateRule.scala
@@ -26,14 +26,13 @@ import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalOverAggrega
 import 
org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchPhysicalOverAggregate,
 BatchPhysicalOverAggregateBase, BatchPhysicalPythonOverAggregate}
 import org.apache.flink.table.planner.plan.utils.{AggregateUtil, 
OverAggregateUtil, SortUtil}
 import org.apache.flink.table.planner.plan.utils.PythonUtil.isPythonAggregate
-import org.apache.flink.table.planner.typeutils.RowTypeUtils
 import org.apache.flink.table.planner.utils.ShortcutUtils
 
-import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelOptRuleCall, 
RelOptUtil}
+import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
 import org.apache.calcite.plan.RelOptRule._
 import org.apache.calcite.rel._
 import org.apache.calcite.rel.`type`.RelDataType
-import org.apache.calcite.rel.core.{AggregateCall, Window}
+import org.apache.calcite.rel.core.Window
 import org.apache.calcite.rel.core.Window.Group
 import org.apache.calcite.rex.{RexInputRef, RexNode, RexShuttle}
 import org.apache.calcite.sql.SqlAggFunction
@@ -107,7 +106,7 @@ class BatchPhysicalOverAggregateRule
           (group, aggCallToAggFunction)
       }
 
-      val outputRowType = inferOutputRowType(
+      val outputRowType = OverAggregateUtil.inferOutputRowType(
         logicWindow.getCluster,
         inputRowType,
         groupToAggCallToAggFunction.flatMap(_._2).map(_._1))
@@ -198,22 +197,6 @@ class BatchPhysicalOverAggregateRule
     isSatisfied
   }
 
-  private def inferOutputRowType(
-      cluster: RelOptCluster,
-      inputType: RelDataType,
-      aggCalls: Seq[AggregateCall]): RelDataType = {
-
-    val inputNameList = inputType.getFieldNames
-    val inputTypeList = inputType.getFieldList.asScala.map(field => 
field.getType)
-
-    // we should avoid duplicated names with input column names
-    val aggNames = RowTypeUtils.getUniqueName(aggCalls.map(_.getName), 
inputNameList)
-    val aggTypes = aggCalls.map(_.getType)
-
-    val typeFactory = cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]
-    typeFactory.createStructType(inputTypeList ++ aggTypes, inputNameList ++ 
aggNames)
-  }
-
   private def adjustGroup(
       groupBuffer: ArrayBuffer[Window.Group],
       groupIdx: Int,
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalOverAggregateRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalOverAggregateRule.scala
index 60fdaceb93c..7004bcdf5e7 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalOverAggregateRule.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalOverAggregateRule.scala
@@ -22,6 +22,7 @@ import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution
 import org.apache.flink.table.planner.plan.nodes.FlinkConventions
 import 
org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalOverAggregate
 import 
org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalOverAggregate
+import org.apache.flink.table.planner.plan.utils.OverAggregateUtil
 import org.apache.flink.table.planner.plan.utils.PythonUtil.isPythonAggregate
 
 import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
@@ -66,13 +67,20 @@ class StreamPhysicalOverAggregateRule(config: Config) extends ConverterRule(conf
       .replace(FlinkConventions.STREAM_PHYSICAL)
       .replace(requiredDistribution)
     val providedTraitSet = 
rel.getTraitSet.replace(FlinkConventions.STREAM_PHYSICAL)
-    val newInput = RelOptRule.convert(logicWindow.getInput, requiredTraitSet)
+    val input = logicWindow.getInput
+    val newInput = RelOptRule.convert(input, requiredTraitSet)
+
+    val outputRowType = OverAggregateUtil.inferOutputRowType(
+      logicWindow.getCluster,
+      input.getRowType,
+      // only supports one group now
+      logicWindow.groups.get(0).getAggregateCalls(logicWindow).asScala)
 
     new StreamPhysicalOverAggregate(
       rel.getCluster,
       providedTraitSet,
       newInput,
-      rel.getRowType,
+      outputRowType,
       logicWindow)
   }
 }
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/OverAggregateUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/OverAggregateUtil.scala
index c68d6abe100..b054d2af886 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/OverAggregateUtil.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/OverAggregateUtil.scala
@@ -19,16 +19,21 @@ package org.apache.flink.table.planner.plan.utils
 
 import org.apache.flink.table.api.TableException
 import org.apache.flink.table.planner.JArrayList
+import org.apache.flink.table.planner.calcite.FlinkTypeFactory
 import org.apache.flink.table.planner.plan.nodes.exec.spec.{OverSpec, 
PartitionSpec}
 import org.apache.flink.table.planner.plan.nodes.exec.spec.OverSpec.GroupSpec
+import org.apache.flink.table.planner.typeutils.RowTypeUtils
 
+import org.apache.calcite.plan.RelOptCluster
+import org.apache.calcite.rel.`type`.RelDataType
 import org.apache.calcite.rel.{RelCollation, RelCollations, RelFieldCollation}
 import org.apache.calcite.rel.RelFieldCollation.{Direction, NullDirection}
-import org.apache.calcite.rel.core.Window
+import org.apache.calcite.rel.core.{AggregateCall, Window}
 import org.apache.calcite.rex.{RexInputRef, RexLiteral, RexWindowBound}
 import org.apache.calcite.sql.`type`.SqlTypeName
 
 import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
 
 object OverAggregateUtil {
@@ -219,4 +224,20 @@ object OverAggregateUtil {
       }
     }
   }
+
+  def inferOutputRowType(
+      cluster: RelOptCluster,
+      inputType: RelDataType,
+      aggCalls: Seq[AggregateCall]): RelDataType = {
+
+    val inputNameList = inputType.getFieldNames
+    val inputTypeList = inputType.getFieldList.asScala.map(_.getType)
+
+    // we should avoid duplicated names with input column names
+    val aggNames = RowTypeUtils.getUniqueName(aggCalls.map(_.getName), 
inputNameList)
+    val aggTypes = aggCalls.map(_.getType)
+
+    val typeFactory = cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]
+    typeFactory.createStructType(inputTypeList ++ aggTypes, inputNameList ++ 
aggNames)
+  }
 }
diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.xml
index 0ca5ec28442..ed6b45f01fc 100644
--- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.xml
+++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.xml
@@ -735,6 +735,43 @@ Calc(select=[a, w0$o0 AS $1, w1$o0 AS $2])
             +- Sort(orderBy=[b ASC, c ASC, a DESC])
                +- Exchange(distribution=[hash[b]])
                   +- LegacyTableSourceScan(table=[[default_catalog, 
default_database, MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, 
c])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testNestedOverAgg">
+    <Resource name="sql">
+      <![CDATA[
+SELECT *
+FROM (
+ SELECT
+    *, count(*) OVER (PARTITION BY a ORDER BY ts) AS c2
+  FROM (
+    SELECT
+      *, count(*) OVER (PARTITION BY a,b ORDER BY ts) AS c1
+    FROM src
+  )
+)
+]]>
+    </Resource>
+    <Resource name="ast">
+      <![CDATA[
+LogicalProject(a=[$0], b=[$1], ts=[$2], c1=[$3], c2=[$4])
++- LogicalProject(a=[$0], b=[$1], ts=[$2], c1=[$3], c2=[COUNT() OVER 
(PARTITION BY $0 ORDER BY $2 NULLS FIRST)])
+   +- LogicalProject(a=[$0], b=[$1], ts=[$2], c1=[COUNT() OVER (PARTITION BY 
$0, $1 ORDER BY $2 NULLS FIRST)])
+      +- LogicalTableScan(table=[[default_catalog, default_database, src]])
+]]>
+    </Resource>
+    <Resource name="optimized exec plan">
+      <![CDATA[
+OverAggregate(partitionBy=[a], orderBy=[ts ASC], window#0=[COUNT(*) AS w0$o0_0 
RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, ts, w0$o0, 
w0$o0_0])
++- Exchange(distribution=[forward])
+   +- Sort(orderBy=[a ASC, ts ASC])
+      +- Exchange(distribution=[hash[a]])
+         +- OverAggregate(partitionBy=[a, b], orderBy=[ts ASC], 
window#0=[COUNT(*) AS w0$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], 
select=[a, b, ts, w0$o0])
+            +- Exchange(distribution=[forward])
+               +- Sort(orderBy=[a ASC, b ASC, ts ASC])
+                  +- Exchange(distribution=[hash[a, b]])
+                     +- TableSourceScan(table=[[default_catalog, 
default_database, src]], fields=[a, b, ts])
 ]]>
     </Resource>
   </TestCase>
diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/OverAggregateTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/OverAggregateTest.xml
index bab37227900..72e3bffb228 100644
--- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/OverAggregateTest.xml
+++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/OverAggregateTest.xml
@@ -16,6 +16,41 @@ See the License for the specific language governing permissions and
 limitations under the License.
 -->
 <Root>
+  <TestCase name="testNestedOverAgg">
+    <Resource name="sql">
+      <![CDATA[
+SELECT *
+FROM (
+ SELECT
+    *, count(*) OVER (PARTITION BY a ORDER BY ts) AS c2
+  FROM (
+    SELECT
+      *, count(*) OVER (PARTITION BY a,b ORDER BY ts) AS c1
+    FROM src
+  )
+)
+]]>
+    </Resource>
+    <Resource name="ast">
+      <![CDATA[
+LogicalProject(a=[$0], b=[$1], ts=[$2], c1=[$3], c2=[$4])
++- LogicalProject(a=[$0], b=[$1], ts=[$2], c1=[$3], c2=[COUNT() OVER 
(PARTITION BY $0 ORDER BY $2 NULLS FIRST)])
+   +- LogicalProject(a=[$0], b=[$1], ts=[$2], c1=[COUNT() OVER (PARTITION BY 
$0, $1 ORDER BY $2 NULLS FIRST)])
+      +- LogicalWatermarkAssigner(rowtime=[ts], watermark=[$2])
+         +- LogicalTableScan(table=[[default_catalog, default_database, src]])
+]]>
+    </Resource>
+    <Resource name="optimized exec plan">
+      <![CDATA[
+OverAggregate(partitionBy=[a], orderBy=[ts ASC], window=[ RANG BETWEEN 
UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, ts, w0$o0, COUNT(*) AS 
w0$o0_0])
++- Exchange(distribution=[hash[a]])
+   +- OverAggregate(partitionBy=[a, b], orderBy=[ts ASC], window=[ RANG 
BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, ts, COUNT(*) AS 
w0$o0])
+      +- Exchange(distribution=[hash[a, b]])
+         +- WatermarkAssigner(rowtime=[ts], watermark=[ts])
+            +- TableSourceScan(table=[[default_catalog, default_database, 
src]], fields=[a, b, ts])
+]]>
+    </Resource>
+  </TestCase>
   <TestCase name="testProctimeBoundedDistinctPartitionedRowOver">
     <Resource name="sql">
       <![CDATA[
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.scala
index f71325beb57..fb95adbf319 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.scala
@@ -367,4 +367,32 @@ class OverAggregateTest extends TableTestBase {
         () =>
           util.verifyExecPlan("SELECT overAgg(b, a) FROM T GROUP BY TUMBLE(ts, 
INTERVAL '2' HOUR)"))
   }
+
+  @Test
+  def testNestedOverAgg(): Unit = {
+    util.addTable(s"""
+                     |CREATE TEMPORARY TABLE src (
+                     |  a STRING,
+                     |  b STRING,
+                     |  ts TIMESTAMP_LTZ(3),
+                     |  watermark FOR ts as ts
+                     |) WITH (
+                     |  'connector' = 'values'
+                     |  ,'bounded' = 'true'
+                     |)
+                     |""".stripMargin)
+
+    util.verifyExecPlan(s"""
+                           |SELECT *
+                           |FROM (
+                           | SELECT
+                           |    *, count(*) OVER (PARTITION BY a ORDER BY ts) 
AS c2
+                           |  FROM (
+                           |    SELECT
+                           |      *, count(*) OVER (PARTITION BY a,b ORDER BY 
ts) AS c1
+                           |    FROM src
+                           |  )
+                           |)
+                           |""".stripMargin)
+  }
 }
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/agg/OverAggregateTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/agg/OverAggregateTest.scala
index 65e6fb40eb9..e290bbce225 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/agg/OverAggregateTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/agg/OverAggregateTest.scala
@@ -441,4 +441,31 @@ class OverAggregateTest extends TableTestBase {
 
     util.verifyExecPlan(sqlQuery)
   }
+
+  @Test
+  def testNestedOverAgg(): Unit = {
+    util.addTable(s"""
+                     |CREATE TEMPORARY TABLE src (
+                     |  a STRING,
+                     |  b STRING,
+                     |  ts TIMESTAMP_LTZ(3),
+                     |  watermark FOR ts as ts
+                     |) WITH (
+                     |  'connector' = 'values'
+                     |)
+                     |""".stripMargin)
+
+    util.verifyExecPlan(s"""
+                           |SELECT *
+                           |FROM (
+                           | SELECT
+                           |    *, count(*) OVER (PARTITION BY a ORDER BY ts) 
AS c2
+                           |  FROM (
+                           |    SELECT
+                           |      *, count(*) OVER (PARTITION BY a,b ORDER BY 
ts) AS c1
+                           |    FROM src
+                           |  )
+                           |)
+                           |""".stripMargin)
+  }
 }

Reply via email to