From 6fa7144a91ca21b0ca6bd6fcc34a57ab517359d8 Mon Sep 17 00:00:00 2001
From: Amit Langote <amitlan@postgresql.org>
Date: Fri, 1 Nov 2024 16:15:50 +0900
Subject: [PATCH v6 1/2] Disallow partitionwise grouping when collation doesn't
 match

Insist that the collation used for grouping matches exactly with the
collation used for partitioning to allow using either full or
partial partitionwise grouping.

Bug: #18568
Reported-by: Webbo Han <1105066510@qq.com>
Author: Webbo Han <1105066510@qq.com>
Reviewed-by: Tender Wang <tndrwang@gmail.com>
Reviewed-by: Aleksander Alekseev <aleksander@timescale.com>
Reviewed-by: Jian He <jian.universality@gmail.com>
Discussion: https://postgr.es/m/18568-2a9afb6b9f7e6ed3@postgresql.org
Discussion: https://postgr.es/m/tencent_9D9103CDA420C07768349CC1DFF88465F90A@qq.com
Discussion: https://postgr.es/m/CAHewXNno_HKiQ6PqyLYfuqDtwp7KKHZiH1J7Pqyz0nr+PS2Dwg@mail.gmail.com
---
 src/backend/optimizer/plan/planner.c          |  67 +++++++---
 .../regress/expected/collate.icu.utf8.out     | 114 ++++++++++++++++++
 src/test/regress/sql/collate.icu.utf8.sql     |  43 +++++++
 3 files changed, 210 insertions(+), 14 deletions(-)

diff --git a/src/backend/optimizer/plan/planner.c b/src/backend/optimizer/plan/planner.c
index 0f423e9684..24a3ff087e 100644
--- a/src/backend/optimizer/plan/planner.c
+++ b/src/backend/optimizer/plan/planner.c
@@ -253,7 +253,8 @@ static void create_partitionwise_grouping_paths(PlannerInfo *root,
 												GroupPathExtraData *extra);
 static bool group_by_has_partkey(RelOptInfo *input_rel,
 								 List *targetList,
-								 List *groupClause);
+								 List *groupClause,
+								 bool *collation_incompatible);
 static int	common_prefix_cmp(const void *a, const void *b);
 static List *generate_setop_child_grouplist(SetOperationStmt *op,
 											List *targetlist);
@@ -4090,23 +4091,30 @@ create_ordinary_grouping_paths(PlannerInfo *root, RelOptInfo *input_rel,
 	if (extra->patype != PARTITIONWISE_AGGREGATE_NONE &&
 		IS_PARTITIONED_REL(input_rel))
 	{
+		bool		collation_incompatible = false;
+		bool		group_by_contains_partkey =
+			group_by_has_partkey(input_rel, extra->targetList,
+								 root->parse->groupClause,
+								 &collation_incompatible);
+
 		/*
 		 * If this is the topmost relation or if the parent relation is doing
 		 * full partitionwise aggregation, then we can do full partitionwise
 		 * aggregation provided that the GROUP BY clause contains all of the
-		 * partitioning columns at this level.  Otherwise, we can do at most
-		 * partial partitionwise aggregation.  But if partial aggregation is
-		 * not supported in general then we can't use it for partitionwise
-		 * aggregation either.
+		 * partitioning columns at this level and the collation used by GROUP
+		 * BY matches the partitioning collation.  Otherwise, we can do at most
+		 * partial partitionwise aggregation, but again only if the collation
+		 * is compatible.  If partial aggregation is not supported in general
+		 * then we can't use it for partitionwise aggregation either.
 		 *
 		 * Check parse->groupClause not processed_groupClause, because it's
 		 * okay if some of the partitioning columns were proved redundant.
 		 */
 		if (extra->patype == PARTITIONWISE_AGGREGATE_FULL &&
-			group_by_has_partkey(input_rel, extra->targetList,
-								 root->parse->groupClause))
+			group_by_contains_partkey)
 			patype = PARTITIONWISE_AGGREGATE_FULL;
-		else if ((extra->flags & GROUPING_CAN_PARTIAL_AGG) != 0)
+		else if ((extra->flags & GROUPING_CAN_PARTIAL_AGG) != 0 &&
+				 !collation_incompatible)
 			patype = PARTITIONWISE_AGGREGATE_PARTIAL;
 		else
 			patype = PARTITIONWISE_AGGREGATE_NONE;
@@ -8105,13 +8113,18 @@ create_partitionwise_grouping_paths(PlannerInfo *root,
 /*
  * group_by_has_partkey
  *
- * Returns true, if all the partition keys of the given relation are part of
- * the GROUP BY clauses, false otherwise.
+ * Returns true if all the partition keys of the given relation are part of
+ * the GROUP BY clauses, including having matching collation, false otherwise.
+ *
+ * Returns false also if a collation mismatch is detected between a partition
+ * key and its corresponding expression in groupClause, in which case.
+ * *collation_incompatible is set to true.
  */
 static bool
 group_by_has_partkey(RelOptInfo *input_rel,
 					 List *targetList,
-					 List *groupClause)
+					 List *groupClause,
+					 bool *collation_incompatible)
 {
 	List	   *groupexprs = get_sortgrouplist_exprs(groupClause, targetList);
 	int			cnt = 0;
@@ -8134,13 +8147,39 @@ group_by_has_partkey(RelOptInfo *input_rel,
 
 		foreach(lc, partexprs)
 		{
+			ListCell   *lg;
 			Expr	   *partexpr = lfirst(lc);
+			Oid			partcoll = input_rel->part_scheme->partcollation[cnt];
 
-			if (list_member(groupexprs, partexpr))
+			foreach(lg, groupexprs)
 			{
-				found = true;
-				break;
+				Expr	   *groupexpr = lfirst(lg);
+				Oid			groupcoll = exprCollation((Node *) groupexpr);
+
+				if (IsA(groupexpr, RelabelType))
+					groupexpr = ((RelabelType *) groupexpr)->arg;
+
+				if (equal(groupexpr, partexpr))
+				{
+
+					/*
+					 * Reject a match if the grouping collation does not match
+					 * the partitioning collation.
+					 */
+					if (OidIsValid(partcoll) && OidIsValid(groupcoll) &&
+						partcoll != groupcoll)
+					{
+						*collation_incompatible = true;
+						return false;
+					}
+
+					found = true;
+					break;
+				}
 			}
+
+			if (found)
+				break;
 		}
 
 		/*
diff --git a/src/test/regress/expected/collate.icu.utf8.out b/src/test/regress/expected/collate.icu.utf8.out
index faa376e060..a4cd9ea085 100644
--- a/src/test/regress/expected/collate.icu.utf8.out
+++ b/src/test/regress/expected/collate.icu.utf8.out
@@ -2054,6 +2054,120 @@ SELECT (SELECT count(*) FROM test33_0) <> (SELECT count(*) FROM test33_1);
  t
 (1 row)
 
+--
+-- Bug #18568
+--
+-- Partitionwise aggregate (full or partial) should not be used when a
+-- partition key's collation doesn't match that of the GROUP BY column it is
+-- matched with.
+SET max_parallel_workers_per_gather TO 0;
+SET enable_incremental_sort TO off;
+CREATE TABLE pagg_tab3 (a text, c text collate case_insensitive) PARTITION BY LIST(c collate "C");
+CREATE TABLE pagg_tab3_p1 PARTITION OF pagg_tab3 FOR VALUES IN ('a', 'b');
+CREATE TABLE pagg_tab3_p2 PARTITION OF pagg_tab3 FOR VALUES IN ('B', 'A');
+INSERT INTO pagg_tab3 SELECT i % 4 + 1, substr('abAB', (i % 4) + 1 , 1) FROM generate_series(0, 11) i;
+ANALYZE pagg_tab3;
+SET enable_partitionwise_aggregate TO false;
+EXPLAIN (COSTS OFF)
+SELECT upper(c collate case_insensitive), count(c) FROM pagg_tab3 GROUP BY c collate case_insensitive ORDER BY 1;
+                        QUERY PLAN                         
+-----------------------------------------------------------
+ Sort
+   Sort Key: (upper(pagg_tab3.c)) COLLATE case_insensitive
+   ->  HashAggregate
+         Group Key: pagg_tab3.c
+         ->  Append
+               ->  Seq Scan on pagg_tab3_p2 pagg_tab3_1
+               ->  Seq Scan on pagg_tab3_p1 pagg_tab3_2
+(7 rows)
+
+SELECT upper(c collate case_insensitive), count(c) FROM pagg_tab3 GROUP BY c collate case_insensitive ORDER BY 1;
+ upper | count 
+-------+-------
+ A     |     6
+ B     |     6
+(2 rows)
+
+-- No partitionwise aggregation allowed!
+SET enable_partitionwise_aggregate TO true;
+EXPLAIN (COSTS OFF)
+SELECT upper(c collate case_insensitive), count(c) FROM pagg_tab3 GROUP BY c collate case_insensitive ORDER BY 1;
+                        QUERY PLAN                         
+-----------------------------------------------------------
+ Sort
+   Sort Key: (upper(pagg_tab3.c)) COLLATE case_insensitive
+   ->  HashAggregate
+         Group Key: pagg_tab3.c
+         ->  Append
+               ->  Seq Scan on pagg_tab3_p2 pagg_tab3_1
+               ->  Seq Scan on pagg_tab3_p1 pagg_tab3_2
+(7 rows)
+
+SELECT upper(c collate case_insensitive), count(c) FROM pagg_tab3 GROUP BY c collate case_insensitive ORDER BY 1;
+ upper | count 
+-------+-------
+ A     |     6
+ B     |     6
+(2 rows)
+
+-- OK to use full partitionwise aggregate after changing the GROUP BY column's
+-- collation to be the same as that of the partition key.
+EXPLAIN (COSTS OFF)
+SELECT c collate "C", count(c) FROM pagg_tab3 GROUP BY c collate "C" ORDER BY 1;
+                       QUERY PLAN                       
+--------------------------------------------------------
+ Sort
+   Sort Key: ((pagg_tab3.c)::text) COLLATE "C"
+   ->  Append
+         ->  HashAggregate
+               Group Key: (pagg_tab3.c)::text
+               ->  Seq Scan on pagg_tab3_p2 pagg_tab3
+         ->  HashAggregate
+               Group Key: (pagg_tab3_1.c)::text
+               ->  Seq Scan on pagg_tab3_p1 pagg_tab3_1
+(9 rows)
+
+SELECT c collate "C", count(c) FROM pagg_tab3 GROUP BY c collate "C" ORDER BY 1;
+ c | count 
+---+-------
+ A |     3
+ B |     3
+ a |     3
+ b |     3
+(4 rows)
+
+-- Also OK to use partial partitionwise aggregate when grouping columns do not
+-- include partition key columns
+EXPLAIN (COSTS OFF)
+SELECT a collate "C", count(c collate "C") FROM pagg_tab3 GROUP BY a collate "C" ORDER BY 1;
+                          QUERY PLAN                          
+--------------------------------------------------------------
+ Finalize GroupAggregate
+   Group Key: ((pagg_tab3.a)::text)
+   ->  Sort
+         Sort Key: ((pagg_tab3.a)::text) COLLATE "C"
+         ->  Append
+               ->  Partial HashAggregate
+                     Group Key: (pagg_tab3.a)::text
+                     ->  Seq Scan on pagg_tab3_p2 pagg_tab3
+               ->  Partial HashAggregate
+                     Group Key: (pagg_tab3_1.a)::text
+                     ->  Seq Scan on pagg_tab3_p1 pagg_tab3_1
+(11 rows)
+
+SELECT a collate "C", count(c collate "C") FROM pagg_tab3 GROUP BY a collate "C" ORDER BY 1;
+ a | count 
+---+-------
+ 1 |     3
+ 2 |     3
+ 3 |     3
+ 4 |     3
+(4 rows)
+
+DROP TABLE pagg_tab3;
+RESET enable_partitionwise_aggregate;
+RESET max_parallel_workers_per_gather;
+RESET enable_incremental_sort;
 -- cleanup
 RESET search_path;
 SET client_min_messages TO warning;
diff --git a/src/test/regress/sql/collate.icu.utf8.sql b/src/test/regress/sql/collate.icu.utf8.sql
index 80f28a97d7..f523dd73be 100644
--- a/src/test/regress/sql/collate.icu.utf8.sql
+++ b/src/test/regress/sql/collate.icu.utf8.sql
@@ -796,6 +796,49 @@ INSERT INTO test33 VALUES (2, 'DEF');
 -- they end up in the same partition (but it's platform-dependent which one)
 SELECT (SELECT count(*) FROM test33_0) <> (SELECT count(*) FROM test33_1);
 
+--
+-- Bug #18568
+--
+-- Partitionwise aggregate (full or partial) should not be used when a
+-- partition key's collation doesn't match that of the GROUP BY column it is
+-- matched with.
+SET max_parallel_workers_per_gather TO 0;
+SET enable_incremental_sort TO off;
+
+CREATE TABLE pagg_tab3 (a text, c text collate case_insensitive) PARTITION BY LIST(c collate "C");
+CREATE TABLE pagg_tab3_p1 PARTITION OF pagg_tab3 FOR VALUES IN ('a', 'b');
+CREATE TABLE pagg_tab3_p2 PARTITION OF pagg_tab3 FOR VALUES IN ('B', 'A');
+INSERT INTO pagg_tab3 SELECT i % 4 + 1, substr('abAB', (i % 4) + 1 , 1) FROM generate_series(0, 11) i;
+ANALYZE pagg_tab3;
+
+SET enable_partitionwise_aggregate TO false;
+EXPLAIN (COSTS OFF)
+SELECT upper(c collate case_insensitive), count(c) FROM pagg_tab3 GROUP BY c collate case_insensitive ORDER BY 1;
+SELECT upper(c collate case_insensitive), count(c) FROM pagg_tab3 GROUP BY c collate case_insensitive ORDER BY 1;
+
+-- No partitionwise aggregation allowed!
+SET enable_partitionwise_aggregate TO true;
+EXPLAIN (COSTS OFF)
+SELECT upper(c collate case_insensitive), count(c) FROM pagg_tab3 GROUP BY c collate case_insensitive ORDER BY 1;
+SELECT upper(c collate case_insensitive), count(c) FROM pagg_tab3 GROUP BY c collate case_insensitive ORDER BY 1;
+
+-- OK to use full partitionwise aggregate after changing the GROUP BY column's
+-- collation to be the same as that of the partition key.
+EXPLAIN (COSTS OFF)
+SELECT c collate "C", count(c) FROM pagg_tab3 GROUP BY c collate "C" ORDER BY 1;
+SELECT c collate "C", count(c) FROM pagg_tab3 GROUP BY c collate "C" ORDER BY 1;
+
+-- Also OK to use partial partitionwise aggregate when grouping columns do not
+-- include partition key columns
+EXPLAIN (COSTS OFF)
+SELECT a collate "C", count(c collate "C") FROM pagg_tab3 GROUP BY a collate "C" ORDER BY 1;
+SELECT a collate "C", count(c collate "C") FROM pagg_tab3 GROUP BY a collate "C" ORDER BY 1;
+
+DROP TABLE pagg_tab3;
+
+RESET enable_partitionwise_aggregate;
+RESET max_parallel_workers_per_gather;
+RESET enable_incremental_sort;
 
 -- cleanup
 RESET search_path;
-- 
2.43.0

