From 03cb00c55fef0f194c3c585f43f004d0a74542f5 Mon Sep 17 00:00:00 2001
From: Amit Langote <amitlan@postgresql.org>
Date: Thu, 31 Oct 2024 21:59:28 +0900
Subject: [PATCH v3 1/2] Disallow partitionwise grouping when collation doesn't
 match

Insist that the collation used for grouping matches exactly with the
collation used for partitioning to allow using either full or
partial partitionwise grouping.

Bug: #18568
Reported-by: Webbo Han <1105066510@qq.com>
Author: Webbo Han <1105066510@qq.com>
Reviewed-by: Tender Wang <tndrwang@gmail.com>
Reviewed-by: Aleksander Alekseev <aleksander@timescale.com>
Reviewed-by: Jian He <jian.universality@gmail.com>
Discussion: https://postgr.es/m/18568-2a9afb6b9f7e6ed3@postgresql.org
Discussion: https://postgr.es/m/tencent_9D9103CDA420C07768349CC1DFF88465F90A@qq.com
Discussion: https://postgr.es/m/CAHewXNno_HKiQ6PqyLYfuqDtwp7KKHZiH1J7Pqyz0nr+PS2Dwg@mail.gmail.com
---
 src/backend/optimizer/plan/planner.c          | 53 +++++++++---
 .../regress/expected/collate.icu.utf8.out     | 85 +++++++++++++++++++
 .../regress/expected/partition_aggregate.out  | 11 ---
 src/test/regress/sql/collate.icu.utf8.sql     | 35 ++++++++
 src/test/regress/sql/partition_aggregate.sql  |  1 -
 5 files changed, 163 insertions(+), 22 deletions(-)

diff --git a/src/backend/optimizer/plan/planner.c b/src/backend/optimizer/plan/planner.c
index 0f423e9684..60bb86a99c 100644
--- a/src/backend/optimizer/plan/planner.c
+++ b/src/backend/optimizer/plan/planner.c
@@ -253,7 +253,8 @@ static void create_partitionwise_grouping_paths(PlannerInfo *root,
 												GroupPathExtraData *extra);
 static bool group_by_has_partkey(RelOptInfo *input_rel,
 								 List *targetList,
-								 List *groupClause);
+								 List *groupClause,
+								 bool *collation_incompatible);
 static int	common_prefix_cmp(const void *a, const void *b);
 static List *generate_setop_child_grouplist(SetOperationStmt *op,
 											List *targetlist);
@@ -4090,6 +4091,12 @@ create_ordinary_grouping_paths(PlannerInfo *root, RelOptInfo *input_rel,
 	if (extra->patype != PARTITIONWISE_AGGREGATE_NONE &&
 		IS_PARTITIONED_REL(input_rel))
 	{
+		bool	collation_incompatible = false;
+		bool	group_by_contains_partkey =
+			group_by_has_partkey(input_rel, extra->targetList,
+								 root->parse->groupClause,
+								 &collation_incompatible);
+
 		/*
 		 * If this is the topmost relation or if the parent relation is doing
 		 * full partitionwise aggregation, then we can do full partitionwise
@@ -4103,10 +4110,10 @@ create_ordinary_grouping_paths(PlannerInfo *root, RelOptInfo *input_rel,
 		 * okay if some of the partitioning columns were proved redundant.
 		 */
 		if (extra->patype == PARTITIONWISE_AGGREGATE_FULL &&
-			group_by_has_partkey(input_rel, extra->targetList,
-								 root->parse->groupClause))
+			group_by_contains_partkey)
 			patype = PARTITIONWISE_AGGREGATE_FULL;
-		else if ((extra->flags & GROUPING_CAN_PARTIAL_AGG) != 0)
+		else if ((extra->flags & GROUPING_CAN_PARTIAL_AGG) != 0 &&
+				 !collation_incompatible)
 			patype = PARTITIONWISE_AGGREGATE_PARTIAL;
 		else
 			patype = PARTITIONWISE_AGGREGATE_NONE;
@@ -8105,13 +8112,14 @@ create_partitionwise_grouping_paths(PlannerInfo *root,
 /*
  * group_by_has_partkey
  *
- * Returns true, if all the partition keys of the given relation are part of
- * the GROUP BY clauses, false otherwise.
+ * Returns true if all the partition keys of the given relation are part of
+ * the GROUP BY clauses, including having matching collation, false otherwise.
  */
 static bool
 group_by_has_partkey(RelOptInfo *input_rel,
 					 List *targetList,
-					 List *groupClause)
+					 List *groupClause,
+					 bool *collation_incompatible)
 {
 	List	   *groupexprs = get_sortgrouplist_exprs(groupClause, targetList);
 	int			cnt = 0;
@@ -8134,13 +8142,38 @@ group_by_has_partkey(RelOptInfo *input_rel,
 
 		foreach(lc, partexprs)
 		{
+			ListCell   *lg;
 			Expr	   *partexpr = lfirst(lc);
+			Oid			partcoll = input_rel->part_scheme->partcollation[cnt];
 
-			if (list_member(groupexprs, partexpr))
+			foreach (lg, groupexprs)
 			{
-				found = true;
-				break;
+				Expr *groupexpr = lfirst(lg);
+				Oid   groupcoll = exprCollation((Node *) groupexpr);
+
+				if (IsA(groupexpr, RelabelType))
+					groupexpr = ((RelabelType *) groupexpr)->arg;
+
+				if (OidIsValid(partcoll) && OidIsValid(groupcoll) &&
+					partcoll != groupcoll)
+				{
+					*collation_incompatible = true;
+					return false;
+				}
+
+				/*
+				 * Ensure that the grouping collation is compatible with the
+				 * partitioning collation.
+				 */
+				if (equal(groupexpr, partexpr))
+				{
+					found = true;
+					break;
+				}
 			}
+
+			if (found)
+				break;
 		}
 
 		/*
diff --git a/src/test/regress/expected/collate.icu.utf8.out b/src/test/regress/expected/collate.icu.utf8.out
index faa376e060..3d6d8f9a20 100644
--- a/src/test/regress/expected/collate.icu.utf8.out
+++ b/src/test/regress/expected/collate.icu.utf8.out
@@ -2054,6 +2054,91 @@ SELECT (SELECT count(*) FROM test33_0) <> (SELECT count(*) FROM test33_1);
  t
 (1 row)
 
+--
+-- Bug #18568
+--
+-- Partitionwise aggregate (full or partial) should not be used when a
+-- partition key's collation doesn't match that of the GROUP BY column it is
+-- matched with.
+SET max_parallel_workers_per_gather TO 0;
+SET enable_incremental_sort TO off;
+CREATE TABLE pagg_tab3 (c text collate case_insensitive) PARTITION BY LIST(c collate "C");
+CREATE TABLE pagg_tab3_p1 PARTITION OF pagg_tab3 FOR VALUES IN ('a', 'b');
+CREATE TABLE pagg_tab3_p2 PARTITION OF pagg_tab3 FOR VALUES IN ('B', 'A');
+INSERT INTO pagg_tab3 SELECT substr('abAB', (i % 4) +1 , 1) FROM generate_series(0, 7) i;
+ANALYZE pagg_tab3;
+SET enable_partitionwise_aggregate TO false;
+EXPLAIN (COSTS OFF)
+SELECT upper(c collate case_insensitive), count(c) FROM pagg_tab3 GROUP BY c collate case_insensitive ORDER BY 1;
+                        QUERY PLAN                         
+-----------------------------------------------------------
+ Sort
+   Sort Key: (upper(pagg_tab3.c)) COLLATE case_insensitive
+   ->  HashAggregate
+         Group Key: pagg_tab3.c
+         ->  Append
+               ->  Seq Scan on pagg_tab3_p2 pagg_tab3_1
+               ->  Seq Scan on pagg_tab3_p1 pagg_tab3_2
+(7 rows)
+
+SELECT upper(c collate case_insensitive), count(c) FROM pagg_tab3 GROUP BY c collate case_insensitive ORDER BY 1;
+ upper | count 
+-------+-------
+ A     |     4
+ B     |     4
+(2 rows)
+
+-- No partitionwise aggregation allowed!
+SET enable_partitionwise_aggregate TO true;
+EXPLAIN (COSTS OFF)
+SELECT upper(c collate case_insensitive), count(c) FROM pagg_tab3 GROUP BY c collate case_insensitive ORDER BY 1;
+                        QUERY PLAN                         
+-----------------------------------------------------------
+ Sort
+   Sort Key: (upper(pagg_tab3.c)) COLLATE case_insensitive
+   ->  HashAggregate
+         Group Key: pagg_tab3.c
+         ->  Append
+               ->  Seq Scan on pagg_tab3_p2 pagg_tab3_1
+               ->  Seq Scan on pagg_tab3_p1 pagg_tab3_2
+(7 rows)
+
+SELECT upper(c collate case_insensitive), count(c) FROM pagg_tab3 GROUP BY c collate case_insensitive ORDER BY 1;
+ upper | count 
+-------+-------
+ A     |     4
+ B     |     4
+(2 rows)
+
+-- OK to use full partitionwise aggregate after changing the GROUP BY column's
+-- collation to be the same as that of the partition key.
+EXPLAIN (COSTS OFF)
+SELECT c collate "C", count(c) FROM pagg_tab3 GROUP BY c collate "C" ORDER BY 1;
+                       QUERY PLAN                       
+--------------------------------------------------------
+ Sort
+   Sort Key: ((pagg_tab3.c)::text) COLLATE "C"
+   ->  Append
+         ->  HashAggregate
+               Group Key: (pagg_tab3.c)::text
+               ->  Seq Scan on pagg_tab3_p2 pagg_tab3
+         ->  HashAggregate
+               Group Key: (pagg_tab3_1.c)::text
+               ->  Seq Scan on pagg_tab3_p1 pagg_tab3_1
+(9 rows)
+
+SELECT c collate "C", count(c) FROM pagg_tab3 GROUP BY c collate "C" ORDER BY 1;
+ c | count 
+---+-------
+ A |     2
+ B |     2
+ a |     2
+ b |     2
+(4 rows)
+
+RESET enable_partitionwise_aggregate;
+RESET max_parallel_workers_per_gather;
+RESET enable_incremental_sort;
 -- cleanup
 RESET search_path;
 SET client_min_messages TO warning;
diff --git a/src/test/regress/expected/partition_aggregate.out b/src/test/regress/expected/partition_aggregate.out
index 5f2c0cf578..670eb98906 100644
--- a/src/test/regress/expected/partition_aggregate.out
+++ b/src/test/regress/expected/partition_aggregate.out
@@ -1507,14 +1507,3 @@ SELECT x, sum(y), avg(y), count(*) FROM pagg_tab_para GROUP BY x HAVING avg(y) <
                ->  Seq Scan on pagg_tab_para_p3 pagg_tab_para_2
 (15 rows)
 
-SELECT x, sum(y), avg(y), count(*) FROM pagg_tab_para GROUP BY x HAVING avg(y) < 7 ORDER BY 1, 2, 3;
- x  | sum  |        avg         | count 
-----+------+--------------------+-------
-  0 | 5000 | 5.0000000000000000 |  1000
-  1 | 6000 | 6.0000000000000000 |  1000
- 10 | 5000 | 5.0000000000000000 |  1000
- 11 | 6000 | 6.0000000000000000 |  1000
- 20 | 5000 | 5.0000000000000000 |  1000
- 21 | 6000 | 6.0000000000000000 |  1000
-(6 rows)
-
diff --git a/src/test/regress/sql/collate.icu.utf8.sql b/src/test/regress/sql/collate.icu.utf8.sql
index 80f28a97d7..eaf9a99be7 100644
--- a/src/test/regress/sql/collate.icu.utf8.sql
+++ b/src/test/regress/sql/collate.icu.utf8.sql
@@ -796,6 +796,41 @@ INSERT INTO test33 VALUES (2, 'DEF');
 -- they end up in the same partition (but it's platform-dependent which one)
 SELECT (SELECT count(*) FROM test33_0) <> (SELECT count(*) FROM test33_1);
 
+--
+-- Bug #18568
+--
+-- Partitionwise aggregate (full or partial) should not be used when a
+-- partition key's collation doesn't match that of the GROUP BY column it is
+-- matched with.
+SET max_parallel_workers_per_gather TO 0;
+SET enable_incremental_sort TO off;
+
+CREATE TABLE pagg_tab3 (c text collate case_insensitive) PARTITION BY LIST(c collate "C");
+CREATE TABLE pagg_tab3_p1 PARTITION OF pagg_tab3 FOR VALUES IN ('a', 'b');
+CREATE TABLE pagg_tab3_p2 PARTITION OF pagg_tab3 FOR VALUES IN ('B', 'A');
+INSERT INTO pagg_tab3 SELECT substr('abAB', (i % 4) +1 , 1) FROM generate_series(0, 7) i;
+ANALYZE pagg_tab3;
+
+SET enable_partitionwise_aggregate TO false;
+EXPLAIN (COSTS OFF)
+SELECT upper(c collate case_insensitive), count(c) FROM pagg_tab3 GROUP BY c collate case_insensitive ORDER BY 1;
+SELECT upper(c collate case_insensitive), count(c) FROM pagg_tab3 GROUP BY c collate case_insensitive ORDER BY 1;
+
+-- No partitionwise aggregation allowed!
+SET enable_partitionwise_aggregate TO true;
+EXPLAIN (COSTS OFF)
+SELECT upper(c collate case_insensitive), count(c) FROM pagg_tab3 GROUP BY c collate case_insensitive ORDER BY 1;
+SELECT upper(c collate case_insensitive), count(c) FROM pagg_tab3 GROUP BY c collate case_insensitive ORDER BY 1;
+
+-- OK to use full partitionwise aggregate after changing the GROUP BY column's
+-- collation to be the same as that of the partition key.
+EXPLAIN (COSTS OFF)
+SELECT c collate "C", count(c) FROM pagg_tab3 GROUP BY c collate "C" ORDER BY 1;
+SELECT c collate "C", count(c) FROM pagg_tab3 GROUP BY c collate "C" ORDER BY 1;
+
+RESET enable_partitionwise_aggregate;
+RESET max_parallel_workers_per_gather;
+RESET enable_incremental_sort;
 
 -- cleanup
 RESET search_path;
diff --git a/src/test/regress/sql/partition_aggregate.sql b/src/test/regress/sql/partition_aggregate.sql
index ab070fee24..1e263c1caf 100644
--- a/src/test/regress/sql/partition_aggregate.sql
+++ b/src/test/regress/sql/partition_aggregate.sql
@@ -333,4 +333,3 @@ RESET parallel_setup_cost;
 
 EXPLAIN (COSTS OFF)
 SELECT x, sum(y), avg(y), count(*) FROM pagg_tab_para GROUP BY x HAVING avg(y) < 7 ORDER BY 1, 2, 3;
-SELECT x, sum(y), avg(y), count(*) FROM pagg_tab_para GROUP BY x HAVING avg(y) < 7 ORDER BY 1, 2, 3;
-- 
2.43.0

