From ccf9f6cfe2a8078961ce2250ec56884ac12b6e01 Mon Sep 17 00:00:00 2001
From: Amit Langote <amitlan@postgresql.org>
Date: Thu, 31 Oct 2024 21:58:30 +0900
Subject: [PATCH v3 2/2] Disallow partitionwise join when collation doesn't
 match

Insist that the collation used for joining matches exactly with the
collation used for partitioning to allow using partitionwise join.

Reported-by: Tender Wang <tndrwang@gmail.com>
Author: Jian He <jian.universality@gmail.com>
Author: Amit Langote <amitlangote09@gmail.com>
Reviewed-by: Tender Wang <tndrwang@gmail.com>
Reviewed-by: Junwang Zhao <zhjwpku@gmail.com>
Discussion: https://postgr.es/m/18568-2a9afb6b9f7e6ed3@postgresql.org
Discussion: https://postgr.es/m/tencent_9D9103CDA420C07768349CC1DFF88465F90A@qq.com
Discussion: https://postgr.es/m/CAHewXNno_HKiQ6PqyLYfuqDtwp7KKHZiH1J7Pqyz0nr+PS2Dwg@mail.gmail.com
---
 src/backend/optimizer/util/relnode.c          | 26 +++++-
 .../regress/expected/collate.icu.utf8.out     | 89 +++++++++++++++++++
 src/test/regress/sql/collate.icu.utf8.sql     | 17 ++++
 3 files changed, 130 insertions(+), 2 deletions(-)

diff --git a/src/backend/optimizer/util/relnode.c b/src/backend/optimizer/util/relnode.c
index d7266e4cdb..0dfcd3a556 100644
--- a/src/backend/optimizer/util/relnode.c
+++ b/src/backend/optimizer/util/relnode.c
@@ -2258,6 +2258,8 @@ have_partkey_equi_join(PlannerInfo *root, RelOptInfo *joinrel,
 		{
 			Node	   *expr1 = (Node *) lfirst(lc);
 			ListCell   *lc2;
+			Oid			partcoll1 = rel1->part_scheme->partcollation[ipk];
+			Oid			exprcoll1 = exprCollation(expr1);
 
 			foreach(lc2, rel2->partexprs[ipk])
 			{
@@ -2265,8 +2267,18 @@ have_partkey_equi_join(PlannerInfo *root, RelOptInfo *joinrel,
 
 				if (exprs_known_equal(root, expr1, expr2, btree_opfamily))
 				{
-					pk_known_equal[ipk] = true;
-					break;
+					Oid		partcoll2 = rel1->part_scheme->partcollation[ipk];
+					Oid		exprcoll2 = exprCollation(expr2);
+
+					/*
+					 * Ensure that the collations match those of the partition
+					 * keys.
+					 */
+					if (partcoll1 == exprcoll1 && partcoll2 == exprcoll2)
+					{
+						pk_known_equal[ipk] = true;
+						break;
+					}
 				}
 			}
 			if (pk_known_equal[ipk])
@@ -2301,6 +2313,7 @@ static int
 match_expr_to_partition_keys(Expr *expr, RelOptInfo *rel, bool strict_op)
 {
 	int			cnt;
+	Oid			exprcoll = exprCollation((Node *) expr);
 
 	/* This function should be called only for partitioned relations. */
 	Assert(rel->part_scheme);
@@ -2314,10 +2327,15 @@ match_expr_to_partition_keys(Expr *expr, RelOptInfo *rel, bool strict_op)
 	for (cnt = 0; cnt < rel->part_scheme->partnatts; cnt++)
 	{
 		ListCell   *lc;
+		Oid			partcoll = rel->part_scheme->partcollation[cnt];
 
 		/* We can always match to the non-nullable partition keys. */
 		foreach(lc, rel->partexprs[cnt])
 		{
+			if (OidIsValid(partcoll) && OidIsValid(exprcoll) &&
+				partcoll != exprcoll)
+				return -1;
+
 			if (equal(lfirst(lc), expr))
 				return cnt;
 		}
@@ -2334,6 +2352,10 @@ match_expr_to_partition_keys(Expr *expr, RelOptInfo *rel, bool strict_op)
 		 */
 		foreach(lc, rel->nullable_partexprs[cnt])
 		{
+			if (OidIsValid(partcoll) && OidIsValid(exprcoll) &&
+				partcoll != exprcoll)
+				return -1;
+
 			if (equal(lfirst(lc), expr))
 				return cnt;
 		}
diff --git a/src/test/regress/expected/collate.icu.utf8.out b/src/test/regress/expected/collate.icu.utf8.out
index 3d6d8f9a20..56239737cd 100644
--- a/src/test/regress/expected/collate.icu.utf8.out
+++ b/src/test/regress/expected/collate.icu.utf8.out
@@ -2136,6 +2136,95 @@ SELECT c collate "C", count(c) FROM pagg_tab3 GROUP BY c collate "C" ORDER BY 1;
  b |     2
 (4 rows)
 
+-- Partitionwise join should not be allowed too when the collation used by the
+-- join keys doesn't match the partition key
+SET enable_partitionwise_join TO false;
+EXPLAIN (COSTS OFF)
+SELECT t1.c, count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab3 t2 ON t1.c = t2.c GROUP BY 1 ORDER BY 1;
+                         QUERY PLAN                          
+-------------------------------------------------------------
+ Sort
+   Sort Key: t1.c COLLATE case_insensitive
+   ->  HashAggregate
+         Group Key: t1.c
+         ->  Hash Join
+               Hash Cond: (t1.c = t2.c)
+               ->  Append
+                     ->  Seq Scan on pagg_tab3_p2 t1_1
+                     ->  Seq Scan on pagg_tab3_p1 t1_2
+               ->  Hash
+                     ->  Append
+                           ->  Seq Scan on pagg_tab3_p2 t2_1
+                           ->  Seq Scan on pagg_tab3_p1 t2_2
+(13 rows)
+
+SELECT t1.c, count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab3 t2 ON t1.c = t2.c GROUP BY 1 ORDER BY 1;
+ c | count 
+---+-------
+ A |    16
+ B |    16
+(2 rows)
+
+SET enable_partitionwise_join TO true;
+EXPLAIN (COSTS OFF)
+SELECT t1.c, count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab3 t2 ON t1.c = t2.c GROUP BY 1 ORDER BY 1;
+                         QUERY PLAN                          
+-------------------------------------------------------------
+ Sort
+   Sort Key: t1.c COLLATE case_insensitive
+   ->  HashAggregate
+         Group Key: t1.c
+         ->  Hash Join
+               Hash Cond: (t1.c = t2.c)
+               ->  Append
+                     ->  Seq Scan on pagg_tab3_p2 t1_1
+                     ->  Seq Scan on pagg_tab3_p1 t1_2
+               ->  Hash
+                     ->  Append
+                           ->  Seq Scan on pagg_tab3_p2 t2_1
+                           ->  Seq Scan on pagg_tab3_p1 t2_2
+(13 rows)
+
+SELECT t1.c, count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab3 t2 ON t1.c = t2.c GROUP BY 1 ORDER BY 1;
+ c | count 
+---+-------
+ A |    16
+ B |    16
+(2 rows)
+
+-- OK when the join key uses the same collation.
+EXPLAIN (COSTS OFF)
+SELECT t1.c COLLATE "C", count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab3 t2 ON t1.c = t2.c COLLATE "C" GROUP BY t1.c COLLATE "C" ORDER BY 1;
+                            QUERY PLAN                            
+------------------------------------------------------------------
+ Sort
+   Sort Key: ((t1.c)::text) COLLATE "C"
+   ->  Append
+         ->  HashAggregate
+               Group Key: (t1.c)::text
+               ->  Hash Join
+                     Hash Cond: ((t1.c)::text = (t2.c)::text)
+                     ->  Seq Scan on pagg_tab3_p2 t1
+                     ->  Hash
+                           ->  Seq Scan on pagg_tab3_p2 t2
+         ->  HashAggregate
+               Group Key: (t1_1.c)::text
+               ->  Hash Join
+                     Hash Cond: ((t1_1.c)::text = (t2_1.c)::text)
+                     ->  Seq Scan on pagg_tab3_p1 t1_1
+                     ->  Hash
+                           ->  Seq Scan on pagg_tab3_p1 t2_1
+(17 rows)
+
+SELECT t1.c COLLATE "C", count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab3 t2 ON t1.c = t2.c COLLATE "C" GROUP BY t1.c COLLATE "C" ORDER BY 1;
+ c | count 
+---+-------
+ A |     4
+ B |     4
+ a |     4
+ b |     4
+(4 rows)
+
 RESET enable_partitionwise_aggregate;
 RESET max_parallel_workers_per_gather;
 RESET enable_incremental_sort;
diff --git a/src/test/regress/sql/collate.icu.utf8.sql b/src/test/regress/sql/collate.icu.utf8.sql
index eaf9a99be7..b5f872742b 100644
--- a/src/test/regress/sql/collate.icu.utf8.sql
+++ b/src/test/regress/sql/collate.icu.utf8.sql
@@ -828,6 +828,23 @@ EXPLAIN (COSTS OFF)
 SELECT c collate "C", count(c) FROM pagg_tab3 GROUP BY c collate "C" ORDER BY 1;
 SELECT c collate "C", count(c) FROM pagg_tab3 GROUP BY c collate "C" ORDER BY 1;
 
+-- Partitionwise join should not be allowed too when the collation used by the
+-- join keys doesn't match the partition key
+SET enable_partitionwise_join TO false;
+EXPLAIN (COSTS OFF)
+SELECT t1.c, count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab3 t2 ON t1.c = t2.c GROUP BY 1 ORDER BY 1;
+SELECT t1.c, count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab3 t2 ON t1.c = t2.c GROUP BY 1 ORDER BY 1;
+
+SET enable_partitionwise_join TO true;
+EXPLAIN (COSTS OFF)
+SELECT t1.c, count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab3 t2 ON t1.c = t2.c GROUP BY 1 ORDER BY 1;
+SELECT t1.c, count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab3 t2 ON t1.c = t2.c GROUP BY 1 ORDER BY 1;
+
+-- OK when the join key uses the same collation.
+EXPLAIN (COSTS OFF)
+SELECT t1.c COLLATE "C", count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab3 t2 ON t1.c = t2.c COLLATE "C" GROUP BY t1.c COLLATE "C" ORDER BY 1;
+SELECT t1.c COLLATE "C", count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab3 t2 ON t1.c = t2.c COLLATE "C" GROUP BY t1.c COLLATE "C" ORDER BY 1;
+
 RESET enable_partitionwise_aggregate;
 RESET max_parallel_workers_per_gather;
 RESET enable_incremental_sort;
-- 
2.43.0

