From 36274d54a2367ef2c249ac4c4d2762387d36175e Mon Sep 17 00:00:00 2001
From: Amit Langote <amitlan@postgresql.org>
Date: Fri, 1 Nov 2024 15:32:33 +0900
Subject: [PATCH v6 2/2] Disallow partitionwise join when collation doesn't
 match

Insist that the collation used for joining matches exactly with the
collation used for partitioning to allow using partitionwise join.

Reported-by: Tender Wang <tndrwang@gmail.com>
Author: Jian He <jian.universality@gmail.com>
Author: Amit Langote <amitlangote09@gmail.com>
Reviewed-by: Tender Wang <tndrwang@gmail.com>
Reviewed-by: Junwang Zhao <zhjwpku@gmail.com>
Discussion: https://postgr.es/m/18568-2a9afb6b9f7e6ed3@postgresql.org
Discussion: https://postgr.es/m/tencent_9D9103CDA420C07768349CC1DFF88465F90A@qq.com
Discussion: https://postgr.es/m/CAHewXNno_HKiQ6PqyLYfuqDtwp7KKHZiH1J7Pqyz0nr+PS2Dwg@mail.gmail.com
---
 src/backend/optimizer/util/relnode.c          |  28 ++-
 .../regress/expected/collate.icu.utf8.out     | 209 ++++++++++++++++++
 src/test/regress/sql/collate.icu.utf8.sql     |  58 +++++
 3 files changed, 293 insertions(+), 2 deletions(-)

diff --git a/src/backend/optimizer/util/relnode.c b/src/backend/optimizer/util/relnode.c
index d7266e4cdb..489fe6c26f 100644
--- a/src/backend/optimizer/util/relnode.c
+++ b/src/backend/optimizer/util/relnode.c
@@ -2185,6 +2185,10 @@ have_partkey_equi_join(PlannerInfo *root, RelOptInfo *joinrel,
 		if (pk_known_equal[ipk1])
 			continue;
 
+		/* Reject if the partition key collation differs from the clause's. */
+		if (rel1->part_scheme->partcollation[ipk1] != opexpr->inputcollid)
+			return false;
+
 		/*
 		 * The clause allows partitionwise join only if it uses the same
 		 * operator family as that specified by the partition key.
@@ -2258,6 +2262,8 @@ have_partkey_equi_join(PlannerInfo *root, RelOptInfo *joinrel,
 		{
 			Node	   *expr1 = (Node *) lfirst(lc);
 			ListCell   *lc2;
+			Oid			partcoll1 = rel1->part_scheme->partcollation[ipk];
+			Oid			exprcoll1 = exprCollation(expr1);
 
 			foreach(lc2, rel2->partexprs[ipk])
 			{
@@ -2265,8 +2271,26 @@ have_partkey_equi_join(PlannerInfo *root, RelOptInfo *joinrel,
 
 				if (exprs_known_equal(root, expr1, expr2, btree_opfamily))
 				{
-					pk_known_equal[ipk] = true;
-					break;
+					/*
+					 * Ensure that the collation of the expression matches
+					 * that of the partition key. Checking just one collation
+					 * (partcoll1 and exprcoll1) suffices because partcoll1
+					 * and partcoll2, as well as exprcoll1 and exprcoll2,
+					 * should be identical. This holds because both rel1 and
+					 * rel2 use the same PartitionScheme and expr1 and expr2
+					 * are equal.
+					 */
+					if (partcoll1 == exprcoll1)
+					{
+						Oid			partcoll2 PG_USED_FOR_ASSERTS_ONLY =
+							rel1->part_scheme->partcollation[ipk];
+						Oid			exprcoll2 PG_USED_FOR_ASSERTS_ONLY =
+							exprCollation(expr2);
+
+						Assert(partcoll2 == exprcoll2);
+						pk_known_equal[ipk] = true;
+						break;
+					}
 				}
 			}
 			if (pk_known_equal[ipk])
diff --git a/src/test/regress/expected/collate.icu.utf8.out b/src/test/regress/expected/collate.icu.utf8.out
index a4cd9ea085..e025c2518c 100644
--- a/src/test/regress/expected/collate.icu.utf8.out
+++ b/src/test/regress/expected/collate.icu.utf8.out
@@ -2164,7 +2164,216 @@ SELECT a collate "C", count(c collate "C") FROM pagg_tab3 GROUP BY a collate "C"
  4 |     3
 (4 rows)
 
+-- Partitionwise join should not be allowed too when the collation used by the
+-- join keys doesn't match the partition key collation.
+SET enable_partitionwise_join TO false;
+EXPLAIN (COSTS OFF)
+SELECT t1.c, count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab3 t2 ON t1.c = t2.c GROUP BY 1 ORDER BY t1.c COLLATE "C";
+                         QUERY PLAN                          
+-------------------------------------------------------------
+ Sort
+   Sort Key: t1.c COLLATE "C"
+   ->  HashAggregate
+         Group Key: t1.c
+         ->  Hash Join
+               Hash Cond: (t1.c = t2.c)
+               ->  Append
+                     ->  Seq Scan on pagg_tab3_p2 t1_1
+                     ->  Seq Scan on pagg_tab3_p1 t1_2
+               ->  Hash
+                     ->  Append
+                           ->  Seq Scan on pagg_tab3_p2 t2_1
+                           ->  Seq Scan on pagg_tab3_p1 t2_2
+(13 rows)
+
+SELECT t1.c, count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab3 t2 ON t1.c = t2.c GROUP BY 1 ORDER BY t1.c COLLATE "C";
+ c | count 
+---+-------
+ A |    36
+ B |    36
+(2 rows)
+
+SET enable_partitionwise_join TO true;
+EXPLAIN (COSTS OFF)
+SELECT t1.c, count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab3 t2 ON t1.c = t2.c GROUP BY 1 ORDER BY t1.c COLLATE "C";
+                         QUERY PLAN                          
+-------------------------------------------------------------
+ Sort
+   Sort Key: t1.c COLLATE "C"
+   ->  HashAggregate
+         Group Key: t1.c
+         ->  Hash Join
+               Hash Cond: (t1.c = t2.c)
+               ->  Append
+                     ->  Seq Scan on pagg_tab3_p2 t1_1
+                     ->  Seq Scan on pagg_tab3_p1 t1_2
+               ->  Hash
+                     ->  Append
+                           ->  Seq Scan on pagg_tab3_p2 t2_1
+                           ->  Seq Scan on pagg_tab3_p1 t2_2
+(13 rows)
+
+SELECT t1.c, count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab3 t2 ON t1.c = t2.c GROUP BY 1 ORDER BY t1.c COLLATE "C";
+ c | count 
+---+-------
+ A |    36
+ B |    36
+(2 rows)
+
+-- OK when the join clause uses the same collation as the partition key.
+EXPLAIN (COSTS OFF)
+SELECT t1.c COLLATE "C", count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab3 t2 ON t1.c = t2.c COLLATE "C" GROUP BY t1.c COLLATE "C" ORDER BY t1.c COLLATE "C";
+                            QUERY PLAN                            
+------------------------------------------------------------------
+ Sort
+   Sort Key: ((t1.c)::text) COLLATE "C"
+   ->  Append
+         ->  HashAggregate
+               Group Key: (t1.c)::text
+               ->  Hash Join
+                     Hash Cond: ((t1.c)::text = (t2.c)::text)
+                     ->  Seq Scan on pagg_tab3_p2 t1
+                     ->  Hash
+                           ->  Seq Scan on pagg_tab3_p2 t2
+         ->  HashAggregate
+               Group Key: (t1_1.c)::text
+               ->  Hash Join
+                     Hash Cond: ((t1_1.c)::text = (t2_1.c)::text)
+                     ->  Seq Scan on pagg_tab3_p1 t1_1
+                     ->  Hash
+                           ->  Seq Scan on pagg_tab3_p1 t2_1
+(17 rows)
+
+SELECT t1.c COLLATE "C", count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab3 t2 ON t1.c = t2.c COLLATE "C" GROUP BY t1.c COLLATE "C" ORDER BY t1.c COLLATE "C";
+ c | count 
+---+-------
+ A |     9
+ B |     9
+ a |     9
+ b |     9
+(4 rows)
+
+-- Few other cases where the joined partition keys are matched via equivalence
+-- class, not a join restriction clause.
+-- Collations of joined columns match, but the partition keys collation is different
+CREATE TABLE pagg_tab4 (c text collate case_insensitive, b text collate case_insensitive) PARTITION BY LIST (b collate "C");
+CREATE TABLE pagg_tab4_p1 PARTITION OF pagg_tab4 FOR VALUES IN ('a', 'b');
+CREATE TABLE pagg_tab4_p2 PARTITION OF pagg_tab4 FOR VALUES IN ('B', 'A');
+INSERT INTO pagg_tab4 (b, c) SELECT substr('abAB', (i % 4) + 1 , 1), substr('abAB', (i % 2) + 1 , 1) FROM generate_series(0, 11) i;
+ANALYZE pagg_tab4;
+EXPLAIN (COSTS OFF)
+SELECT t1.c, count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab4 t2 ON t1.c = t2.c AND t1.c = t2.b GROUP BY 1 ORDER BY t1.c COLLATE "C";
+                         QUERY PLAN                          
+-------------------------------------------------------------
+ Sort
+   Sort Key: t1.c COLLATE "C"
+   ->  HashAggregate
+         Group Key: t1.c
+         ->  Hash Join
+               Hash Cond: (t1.c = t2.c)
+               ->  Append
+                     ->  Seq Scan on pagg_tab3_p2 t1_1
+                     ->  Seq Scan on pagg_tab3_p1 t1_2
+               ->  Hash
+                     ->  Append
+                           ->  Seq Scan on pagg_tab4_p2 t2_1
+                                 Filter: (c = b)
+                           ->  Seq Scan on pagg_tab4_p1 t2_2
+                                 Filter: (c = b)
+(15 rows)
+
+SELECT t1.c, count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab4 t2 ON t1.c = t2.c AND t1.c = t2.b GROUP BY 1 ORDER BY t1.c COLLATE "C";
+ c | count 
+---+-------
+ A |    36
+ B |    36
+(2 rows)
+
+-- OK when the partition key collation is same as that of the join columns
+CREATE TABLE pagg_tab5 (c text collate case_insensitive, b text collate case_insensitive) PARTITION BY LIST (c collate case_insensitive);
+CREATE TABLE pagg_tab5_p1 PARTITION OF pagg_tab5 FOR VALUES IN ('a', 'b');
+CREATE TABLE pagg_tab5_p2 PARTITION OF pagg_tab5 FOR VALUES IN ('c', 'd');
+INSERT INTO pagg_tab5 (b, c) SELECT substr('abAB', (i % 4) + 1 , 1), substr('abAB', (i % 2) + 1 , 1) FROM generate_series(0, 5) i;
+INSERT INTO pagg_tab5 (b, c) SELECT substr('cdCD', (i % 4) + 1 , 1), substr('cdCD', (i % 2) + 1 , 1) FROM generate_series(0, 5) i;
+ANALYZE pagg_tab5;
+CREATE TABLE pagg_tab6 (c text collate case_insensitive, b text collate case_insensitive) PARTITION BY LIST (b collate case_insensitive);
+CREATE TABLE pagg_tab6_p1 PARTITION OF pagg_tab6 FOR VALUES IN ('a', 'b');
+CREATE TABLE pagg_tab6_p2 PARTITION OF pagg_tab6 FOR VALUES IN ('c', 'd');
+INSERT INTO pagg_tab6 (b, c) SELECT substr('abAB', (i % 4) + 1 , 1), substr('abAB', (i % 2) + 1 , 1) FROM generate_series(0, 5) i;
+INSERT INTO pagg_tab6 (b, c) SELECT substr('cdCD', (i % 4) + 1 , 1), substr('cdCD', (i % 2) + 1 , 1) FROM generate_series(0, 5) i;
+ANALYZE pagg_tab5;
+EXPLAIN (COSTS OFF)
+SELECT t1.c, count(t2.c) FROM pagg_tab5 t1 JOIN pagg_tab6 t2 ON t1.c = t2.c AND t1.c = t2.b GROUP BY 1 ORDER BY t1.c COLLATE "C";
+                            QUERY PLAN                             
+-------------------------------------------------------------------
+ Sort
+   Sort Key: t1.c COLLATE "C"
+   ->  GroupAggregate
+         Group Key: t1.c
+         ->  Sort
+               Sort Key: t1.c COLLATE case_insensitive
+               ->  Append
+                     ->  Hash Join
+                           Hash Cond: (t2_1.c = t1_1.c)
+                           ->  Seq Scan on pagg_tab6_p1 t2_1
+                                 Filter: (c = b)
+                           ->  Hash
+                                 ->  Seq Scan on pagg_tab5_p1 t1_1
+                     ->  Hash Join
+                           Hash Cond: (t2_2.c = t1_2.c)
+                           ->  Seq Scan on pagg_tab6_p2 t2_2
+                                 Filter: (c = b)
+                           ->  Hash
+                                 ->  Seq Scan on pagg_tab5_p2 t1_2
+(19 rows)
+
+SELECT t1.c, count(t2.c) FROM pagg_tab5 t1 JOIN pagg_tab6 t2 ON t1.c = t2.c AND t1.c = t2.b GROUP BY 1 ORDER BY t1.c COLLATE "C";
+ c | count 
+---+-------
+ a |     9
+ b |     9
+ c |     9
+ d |     9
+(4 rows)
+
+SET enable_partitionwise_join TO false;
+EXPLAIN (COSTS OFF)
+SELECT t1.c, count(t2.c) FROM pagg_tab5 t1 JOIN pagg_tab6 t2 ON t1.c = t2.c AND t1.c = t2.b GROUP BY 1 ORDER BY t1.c COLLATE "C";
+                         QUERY PLAN                          
+-------------------------------------------------------------
+ Sort
+   Sort Key: t1.c COLLATE "C"
+   ->  GroupAggregate
+         Group Key: t1.c
+         ->  Merge Join
+               Merge Cond: (t2.c = t1.c)
+               ->  Sort
+                     Sort Key: t2.c COLLATE case_insensitive
+                     ->  Append
+                           ->  Seq Scan on pagg_tab6_p1 t2_1
+                                 Filter: (c = b)
+                           ->  Seq Scan on pagg_tab6_p2 t2_2
+                                 Filter: (c = b)
+               ->  Sort
+                     Sort Key: t1.c COLLATE case_insensitive
+                     ->  Append
+                           ->  Seq Scan on pagg_tab5_p1 t1_1
+                           ->  Seq Scan on pagg_tab5_p2 t1_2
+(18 rows)
+
+SELECT t1.c, count(t2.c) FROM pagg_tab5 t1 JOIN pagg_tab6 t2 ON t1.c = t2.c AND t1.c = t2.b GROUP BY 1 ORDER BY t1.c COLLATE "C";
+ c | count 
+---+-------
+ a |     9
+ b |     9
+ c |     9
+ d |     9
+(4 rows)
+
 DROP TABLE pagg_tab3;
+DROP TABLE pagg_tab4;
+DROP TABLE pagg_tab5;
+DROP TABLE pagg_tab6;
 RESET enable_partitionwise_aggregate;
 RESET max_parallel_workers_per_gather;
 RESET enable_incremental_sort;
diff --git a/src/test/regress/sql/collate.icu.utf8.sql b/src/test/regress/sql/collate.icu.utf8.sql
index f523dd73be..bad92b2ab7 100644
--- a/src/test/regress/sql/collate.icu.utf8.sql
+++ b/src/test/regress/sql/collate.icu.utf8.sql
@@ -834,7 +834,65 @@ EXPLAIN (COSTS OFF)
 SELECT a collate "C", count(c collate "C") FROM pagg_tab3 GROUP BY a collate "C" ORDER BY 1;
 SELECT a collate "C", count(c collate "C") FROM pagg_tab3 GROUP BY a collate "C" ORDER BY 1;
 
+-- Partitionwise join should not be allowed too when the collation used by the
+-- join keys doesn't match the partition key collation.
+SET enable_partitionwise_join TO false;
+EXPLAIN (COSTS OFF)
+SELECT t1.c, count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab3 t2 ON t1.c = t2.c GROUP BY 1 ORDER BY t1.c COLLATE "C";
+SELECT t1.c, count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab3 t2 ON t1.c = t2.c GROUP BY 1 ORDER BY t1.c COLLATE "C";
+
+SET enable_partitionwise_join TO true;
+EXPLAIN (COSTS OFF)
+SELECT t1.c, count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab3 t2 ON t1.c = t2.c GROUP BY 1 ORDER BY t1.c COLLATE "C";
+SELECT t1.c, count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab3 t2 ON t1.c = t2.c GROUP BY 1 ORDER BY t1.c COLLATE "C";
+
+-- OK when the join clause uses the same collation as the partition key.
+EXPLAIN (COSTS OFF)
+SELECT t1.c COLLATE "C", count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab3 t2 ON t1.c = t2.c COLLATE "C" GROUP BY t1.c COLLATE "C" ORDER BY t1.c COLLATE "C";
+SELECT t1.c COLLATE "C", count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab3 t2 ON t1.c = t2.c COLLATE "C" GROUP BY t1.c COLLATE "C" ORDER BY t1.c COLLATE "C";
+
+-- Few other cases where the joined partition keys are matched via equivalence
+-- class, not a join restriction clause.
+
+-- Collations of joined columns match, but the partition keys collation is different
+CREATE TABLE pagg_tab4 (c text collate case_insensitive, b text collate case_insensitive) PARTITION BY LIST (b collate "C");
+CREATE TABLE pagg_tab4_p1 PARTITION OF pagg_tab4 FOR VALUES IN ('a', 'b');
+CREATE TABLE pagg_tab4_p2 PARTITION OF pagg_tab4 FOR VALUES IN ('B', 'A');
+INSERT INTO pagg_tab4 (b, c) SELECT substr('abAB', (i % 4) + 1 , 1), substr('abAB', (i % 2) + 1 , 1) FROM generate_series(0, 11) i;
+ANALYZE pagg_tab4;
+
+EXPLAIN (COSTS OFF)
+SELECT t1.c, count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab4 t2 ON t1.c = t2.c AND t1.c = t2.b GROUP BY 1 ORDER BY t1.c COLLATE "C";
+SELECT t1.c, count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab4 t2 ON t1.c = t2.c AND t1.c = t2.b GROUP BY 1 ORDER BY t1.c COLLATE "C";
+
+-- OK when the partition key collation is same as that of the join columns
+CREATE TABLE pagg_tab5 (c text collate case_insensitive, b text collate case_insensitive) PARTITION BY LIST (c collate case_insensitive);
+CREATE TABLE pagg_tab5_p1 PARTITION OF pagg_tab5 FOR VALUES IN ('a', 'b');
+CREATE TABLE pagg_tab5_p2 PARTITION OF pagg_tab5 FOR VALUES IN ('c', 'd');
+INSERT INTO pagg_tab5 (b, c) SELECT substr('abAB', (i % 4) + 1 , 1), substr('abAB', (i % 2) + 1 , 1) FROM generate_series(0, 5) i;
+INSERT INTO pagg_tab5 (b, c) SELECT substr('cdCD', (i % 4) + 1 , 1), substr('cdCD', (i % 2) + 1 , 1) FROM generate_series(0, 5) i;
+ANALYZE pagg_tab5;
+
+CREATE TABLE pagg_tab6 (c text collate case_insensitive, b text collate case_insensitive) PARTITION BY LIST (b collate case_insensitive);
+CREATE TABLE pagg_tab6_p1 PARTITION OF pagg_tab6 FOR VALUES IN ('a', 'b');
+CREATE TABLE pagg_tab6_p2 PARTITION OF pagg_tab6 FOR VALUES IN ('c', 'd');
+INSERT INTO pagg_tab6 (b, c) SELECT substr('abAB', (i % 4) + 1 , 1), substr('abAB', (i % 2) + 1 , 1) FROM generate_series(0, 5) i;
+INSERT INTO pagg_tab6 (b, c) SELECT substr('cdCD', (i % 4) + 1 , 1), substr('cdCD', (i % 2) + 1 , 1) FROM generate_series(0, 5) i;
+ANALYZE pagg_tab5;
+
+EXPLAIN (COSTS OFF)
+SELECT t1.c, count(t2.c) FROM pagg_tab5 t1 JOIN pagg_tab6 t2 ON t1.c = t2.c AND t1.c = t2.b GROUP BY 1 ORDER BY t1.c COLLATE "C";
+SELECT t1.c, count(t2.c) FROM pagg_tab5 t1 JOIN pagg_tab6 t2 ON t1.c = t2.c AND t1.c = t2.b GROUP BY 1 ORDER BY t1.c COLLATE "C";
+
+SET enable_partitionwise_join TO false;
+EXPLAIN (COSTS OFF)
+SELECT t1.c, count(t2.c) FROM pagg_tab5 t1 JOIN pagg_tab6 t2 ON t1.c = t2.c AND t1.c = t2.b GROUP BY 1 ORDER BY t1.c COLLATE "C";
+SELECT t1.c, count(t2.c) FROM pagg_tab5 t1 JOIN pagg_tab6 t2 ON t1.c = t2.c AND t1.c = t2.b GROUP BY 1 ORDER BY t1.c COLLATE "C";
+
 DROP TABLE pagg_tab3;
+DROP TABLE pagg_tab4;
+DROP TABLE pagg_tab5;
+DROP TABLE pagg_tab6;
 
 RESET enable_partitionwise_aggregate;
 RESET max_parallel_workers_per_gather;
-- 
2.43.0

