From d8f6753eeb70349ccb45d21dc741dbe6fbaaf2a5 Mon Sep 17 00:00:00 2001
From: Amit Langote <amitlan@postgresql.org>
Date: Fri, 1 Nov 2024 16:15:50 +0900
Subject: [PATCH v4 1/2] Disallow partitionwise grouping when collation doesn't
 match

Insist that the collation used for grouping matches exactly with the
collation used for partitioning to allow using either full or
partial partitionwise grouping.

Bug: #18568
Reported-by: Webbo Han <1105066510@qq.com>
Author: Webbo Han <1105066510@qq.com>
Reviewed-by: Tender Wang <tndrwang@gmail.com>
Reviewed-by: Aleksander Alekseev <aleksander@timescale.com>
Reviewed-by: Jian He <jian.universality@gmail.com>
Discussion: https://postgr.es/m/18568-2a9afb6b9f7e6ed3@postgresql.org
Discussion: https://postgr.es/m/tencent_9D9103CDA420C07768349CC1DFF88465F90A@qq.com
Discussion: https://postgr.es/m/CAHewXNno_HKiQ6PqyLYfuqDtwp7KKHZiH1J7Pqyz0nr+PS2Dwg@mail.gmail.com
---
 src/backend/optimizer/plan/planner.c          | 66 +++++++++++---
 .../regress/expected/collate.icu.utf8.out     | 86 +++++++++++++++++++
 .../regress/expected/partition_aggregate.out  | 11 ---
 src/test/regress/sql/collate.icu.utf8.sql     | 37 ++++++++
 src/test/regress/sql/partition_aggregate.sql  |  1 -
 5 files changed, 175 insertions(+), 26 deletions(-)

diff --git a/src/backend/optimizer/plan/planner.c b/src/backend/optimizer/plan/planner.c
index 0f423e9684..4044b13071 100644
--- a/src/backend/optimizer/plan/planner.c
+++ b/src/backend/optimizer/plan/planner.c
@@ -253,7 +253,8 @@ static void create_partitionwise_grouping_paths(PlannerInfo *root,
 												GroupPathExtraData *extra);
 static bool group_by_has_partkey(RelOptInfo *input_rel,
 								 List *targetList,
-								 List *groupClause);
+								 List *groupClause,
+								 bool *collation_incompatible);
 static int	common_prefix_cmp(const void *a, const void *b);
 static List *generate_setop_child_grouplist(SetOperationStmt *op,
 											List *targetlist);
@@ -4090,23 +4091,30 @@ create_ordinary_grouping_paths(PlannerInfo *root, RelOptInfo *input_rel,
 	if (extra->patype != PARTITIONWISE_AGGREGATE_NONE &&
 		IS_PARTITIONED_REL(input_rel))
 	{
+		bool		collation_incompatible = false;
+		bool		group_by_contains_partkey =
+			group_by_has_partkey(input_rel, extra->targetList,
+								 root->parse->groupClause,
+								 &collation_incompatible);
+
 		/*
 		 * If this is the topmost relation or if the parent relation is doing
 		 * full partitionwise aggregation, then we can do full partitionwise
 		 * aggregation provided that the GROUP BY clause contains all of the
-		 * partitioning columns at this level.  Otherwise, we can do at most
-		 * partial partitionwise aggregation.  But if partial aggregation is
-		 * not supported in general then we can't use it for partitionwise
-		 * aggregation either.
+		 * partitioning columns at this level and the collation used by GROUP
+		 * BY matches the partitioning collation.  Otherwise, we can do at most
+		 * partial partitionwise aggregation, but again only if the collation
+		 * is compatible.  If partial aggregation is not supported in general
+		 * then we can't use it for partitionwise aggregation either.
 		 *
 		 * Check parse->groupClause not processed_groupClause, because it's
 		 * okay if some of the partitioning columns were proved redundant.
 		 */
 		if (extra->patype == PARTITIONWISE_AGGREGATE_FULL &&
-			group_by_has_partkey(input_rel, extra->targetList,
-								 root->parse->groupClause))
+			group_by_contains_partkey)
 			patype = PARTITIONWISE_AGGREGATE_FULL;
-		else if ((extra->flags & GROUPING_CAN_PARTIAL_AGG) != 0)
+		else if ((extra->flags & GROUPING_CAN_PARTIAL_AGG) != 0 &&
+				 !collation_incompatible)
 			patype = PARTITIONWISE_AGGREGATE_PARTIAL;
 		else
 			patype = PARTITIONWISE_AGGREGATE_NONE;
@@ -8105,13 +8113,18 @@ create_partitionwise_grouping_paths(PlannerInfo *root,
 /*
  * group_by_has_partkey
  *
- * Returns true, if all the partition keys of the given relation are part of
- * the GROUP BY clauses, false otherwise.
+ * Returns true if all the partition keys of the given relation are part of
+ * the GROUP BY clauses, including having matching collation, false otherwise.
+ *
+ * Returns false also if a collation mismatch is detected between a partition
+ * key and its corresponding expression in groupClause, in which case.
+ * *collation_incompatible is set to true.
  */
 static bool
 group_by_has_partkey(RelOptInfo *input_rel,
 					 List *targetList,
-					 List *groupClause)
+					 List *groupClause,
+					 bool *collation_incompatible)
 {
 	List	   *groupexprs = get_sortgrouplist_exprs(groupClause, targetList);
 	int			cnt = 0;
@@ -8134,13 +8147,38 @@ group_by_has_partkey(RelOptInfo *input_rel,
 
 		foreach(lc, partexprs)
 		{
+			ListCell   *lg;
 			Expr	   *partexpr = lfirst(lc);
+			Oid			partcoll = input_rel->part_scheme->partcollation[cnt];
 
-			if (list_member(groupexprs, partexpr))
+			foreach(lg, groupexprs)
 			{
-				found = true;
-				break;
+				Expr	   *groupexpr = lfirst(lg);
+				Oid			groupcoll = exprCollation((Node *) groupexpr);
+
+				if (IsA(groupexpr, RelabelType))
+					groupexpr = ((RelabelType *) groupexpr)->arg;
+
+				/*
+				 * Ensure that the grouping collation is compatible with the
+				 * partitioning collation.
+				 */
+				if (OidIsValid(partcoll) && OidIsValid(groupcoll) &&
+					partcoll != groupcoll)
+				{
+					*collation_incompatible = true;
+					return false;
+				}
+
+				if (equal(groupexpr, partexpr))
+				{
+					found = true;
+					break;
+				}
 			}
+
+			if (found)
+				break;
 		}
 
 		/*
diff --git a/src/test/regress/expected/collate.icu.utf8.out b/src/test/regress/expected/collate.icu.utf8.out
index faa376e060..737cf363a2 100644
--- a/src/test/regress/expected/collate.icu.utf8.out
+++ b/src/test/regress/expected/collate.icu.utf8.out
@@ -2054,6 +2054,92 @@ SELECT (SELECT count(*) FROM test33_0) <> (SELECT count(*) FROM test33_1);
  t
 (1 row)
 
+--
+-- Bug #18568
+--
+-- Partitionwise aggregate (full or partial) should not be used when a
+-- partition key's collation doesn't match that of the GROUP BY column it is
+-- matched with.
+SET max_parallel_workers_per_gather TO 0;
+SET enable_incremental_sort TO off;
+CREATE TABLE pagg_tab3 (c text collate case_insensitive) PARTITION BY LIST(c collate "C");
+CREATE TABLE pagg_tab3_p1 PARTITION OF pagg_tab3 FOR VALUES IN ('a', 'b');
+CREATE TABLE pagg_tab3_p2 PARTITION OF pagg_tab3 FOR VALUES IN ('B', 'A');
+INSERT INTO pagg_tab3 SELECT substr('abAB', (i % 4) +1 , 1) FROM generate_series(0, 7) i;
+ANALYZE pagg_tab3;
+SET enable_partitionwise_aggregate TO false;
+EXPLAIN (COSTS OFF)
+SELECT upper(c collate case_insensitive), count(c) FROM pagg_tab3 GROUP BY c collate case_insensitive ORDER BY 1;
+                        QUERY PLAN                         
+-----------------------------------------------------------
+ Sort
+   Sort Key: (upper(pagg_tab3.c)) COLLATE case_insensitive
+   ->  HashAggregate
+         Group Key: pagg_tab3.c
+         ->  Append
+               ->  Seq Scan on pagg_tab3_p2 pagg_tab3_1
+               ->  Seq Scan on pagg_tab3_p1 pagg_tab3_2
+(7 rows)
+
+SELECT upper(c collate case_insensitive), count(c) FROM pagg_tab3 GROUP BY c collate case_insensitive ORDER BY 1;
+ upper | count 
+-------+-------
+ A     |     4
+ B     |     4
+(2 rows)
+
+-- No partitionwise aggregation allowed!
+SET enable_partitionwise_aggregate TO true;
+EXPLAIN (COSTS OFF)
+SELECT upper(c collate case_insensitive), count(c) FROM pagg_tab3 GROUP BY c collate case_insensitive ORDER BY 1;
+                        QUERY PLAN                         
+-----------------------------------------------------------
+ Sort
+   Sort Key: (upper(pagg_tab3.c)) COLLATE case_insensitive
+   ->  HashAggregate
+         Group Key: pagg_tab3.c
+         ->  Append
+               ->  Seq Scan on pagg_tab3_p2 pagg_tab3_1
+               ->  Seq Scan on pagg_tab3_p1 pagg_tab3_2
+(7 rows)
+
+SELECT upper(c collate case_insensitive), count(c) FROM pagg_tab3 GROUP BY c collate case_insensitive ORDER BY 1;
+ upper | count 
+-------+-------
+ A     |     4
+ B     |     4
+(2 rows)
+
+-- OK to use full partitionwise aggregate after changing the GROUP BY column's
+-- collation to be the same as that of the partition key.
+EXPLAIN (COSTS OFF)
+SELECT c collate "C", count(c) FROM pagg_tab3 GROUP BY c collate "C" ORDER BY 1;
+                       QUERY PLAN                       
+--------------------------------------------------------
+ Sort
+   Sort Key: ((pagg_tab3.c)::text) COLLATE "C"
+   ->  Append
+         ->  HashAggregate
+               Group Key: (pagg_tab3.c)::text
+               ->  Seq Scan on pagg_tab3_p2 pagg_tab3
+         ->  HashAggregate
+               Group Key: (pagg_tab3_1.c)::text
+               ->  Seq Scan on pagg_tab3_p1 pagg_tab3_1
+(9 rows)
+
+SELECT c collate "C", count(c) FROM pagg_tab3 GROUP BY c collate "C" ORDER BY 1;
+ c | count 
+---+-------
+ A |     2
+ B |     2
+ a |     2
+ b |     2
+(4 rows)
+
+DROP TABLE pagg_tab3;
+RESET enable_partitionwise_aggregate;
+RESET max_parallel_workers_per_gather;
+RESET enable_incremental_sort;
 -- cleanup
 RESET search_path;
 SET client_min_messages TO warning;
diff --git a/src/test/regress/expected/partition_aggregate.out b/src/test/regress/expected/partition_aggregate.out
index 5f2c0cf578..670eb98906 100644
--- a/src/test/regress/expected/partition_aggregate.out
+++ b/src/test/regress/expected/partition_aggregate.out
@@ -1507,14 +1507,3 @@ SELECT x, sum(y), avg(y), count(*) FROM pagg_tab_para GROUP BY x HAVING avg(y) <
                ->  Seq Scan on pagg_tab_para_p3 pagg_tab_para_2
 (15 rows)
 
-SELECT x, sum(y), avg(y), count(*) FROM pagg_tab_para GROUP BY x HAVING avg(y) < 7 ORDER BY 1, 2, 3;
- x  | sum  |        avg         | count 
-----+------+--------------------+-------
-  0 | 5000 | 5.0000000000000000 |  1000
-  1 | 6000 | 6.0000000000000000 |  1000
- 10 | 5000 | 5.0000000000000000 |  1000
- 11 | 6000 | 6.0000000000000000 |  1000
- 20 | 5000 | 5.0000000000000000 |  1000
- 21 | 6000 | 6.0000000000000000 |  1000
-(6 rows)
-
diff --git a/src/test/regress/sql/collate.icu.utf8.sql b/src/test/regress/sql/collate.icu.utf8.sql
index 80f28a97d7..ca1bb7c1f2 100644
--- a/src/test/regress/sql/collate.icu.utf8.sql
+++ b/src/test/regress/sql/collate.icu.utf8.sql
@@ -796,6 +796,43 @@ INSERT INTO test33 VALUES (2, 'DEF');
 -- they end up in the same partition (but it's platform-dependent which one)
 SELECT (SELECT count(*) FROM test33_0) <> (SELECT count(*) FROM test33_1);
 
+--
+-- Bug #18568
+--
+-- Partitionwise aggregate (full or partial) should not be used when a
+-- partition key's collation doesn't match that of the GROUP BY column it is
+-- matched with.
+SET max_parallel_workers_per_gather TO 0;
+SET enable_incremental_sort TO off;
+
+CREATE TABLE pagg_tab3 (c text collate case_insensitive) PARTITION BY LIST(c collate "C");
+CREATE TABLE pagg_tab3_p1 PARTITION OF pagg_tab3 FOR VALUES IN ('a', 'b');
+CREATE TABLE pagg_tab3_p2 PARTITION OF pagg_tab3 FOR VALUES IN ('B', 'A');
+INSERT INTO pagg_tab3 SELECT substr('abAB', (i % 4) +1 , 1) FROM generate_series(0, 7) i;
+ANALYZE pagg_tab3;
+
+SET enable_partitionwise_aggregate TO false;
+EXPLAIN (COSTS OFF)
+SELECT upper(c collate case_insensitive), count(c) FROM pagg_tab3 GROUP BY c collate case_insensitive ORDER BY 1;
+SELECT upper(c collate case_insensitive), count(c) FROM pagg_tab3 GROUP BY c collate case_insensitive ORDER BY 1;
+
+-- No partitionwise aggregation allowed!
+SET enable_partitionwise_aggregate TO true;
+EXPLAIN (COSTS OFF)
+SELECT upper(c collate case_insensitive), count(c) FROM pagg_tab3 GROUP BY c collate case_insensitive ORDER BY 1;
+SELECT upper(c collate case_insensitive), count(c) FROM pagg_tab3 GROUP BY c collate case_insensitive ORDER BY 1;
+
+-- OK to use full partitionwise aggregate after changing the GROUP BY column's
+-- collation to be the same as that of the partition key.
+EXPLAIN (COSTS OFF)
+SELECT c collate "C", count(c) FROM pagg_tab3 GROUP BY c collate "C" ORDER BY 1;
+SELECT c collate "C", count(c) FROM pagg_tab3 GROUP BY c collate "C" ORDER BY 1;
+
+DROP TABLE pagg_tab3;
+
+RESET enable_partitionwise_aggregate;
+RESET max_parallel_workers_per_gather;
+RESET enable_incremental_sort;
 
 -- cleanup
 RESET search_path;
diff --git a/src/test/regress/sql/partition_aggregate.sql b/src/test/regress/sql/partition_aggregate.sql
index ab070fee24..1e263c1caf 100644
--- a/src/test/regress/sql/partition_aggregate.sql
+++ b/src/test/regress/sql/partition_aggregate.sql
@@ -333,4 +333,3 @@ RESET parallel_setup_cost;
 
 EXPLAIN (COSTS OFF)
 SELECT x, sum(y), avg(y), count(*) FROM pagg_tab_para GROUP BY x HAVING avg(y) < 7 ORDER BY 1, 2, 3;
-SELECT x, sum(y), avg(y), count(*) FROM pagg_tab_para GROUP BY x HAVING avg(y) < 7 ORDER BY 1, 2, 3;
-- 
2.43.0

