From 52972c8933772115cb72933e0dc24e441c4534eb Mon Sep 17 00:00:00 2001
From: Richard Guo <guofenglinux@gmail.com>
Date: Wed, 17 Jan 2024 14:01:27 +0800
Subject: [PATCH v10] Avoid reparameterizing Paths when it's not suitable

When creating a nestloop join path, if the inner path is parameterized,
it is parameterized by the topmost parent of the outer rel, not the
outer rel itself.  Therefore, we need to translate the parameterization
so that the inner path is parameterized by the given outer rel itself.
Currently, this is done in reparameterize_path_by_child() while
generating join paths.

However, reparameterize_path_by_child() does not perform the translation
for the sampling info of sample scan paths, or for the restriction
clauses of scan paths.  This omission can lead to executor crashes,
wrong results, or planning errors, as we have already observed.

Please note that the sampling info is contained in RangeTblEntries (as
TableSampleClause objects), and the scan restriction clauses are
contained in RelOptInfos or IndexOptInfos (as lists of RestrictInfo
objects).  We cannot just modify them on the fly while generating join
paths: doing so would break things if we end up using a
non-partitionwise join.

So this commit chooses not to reparameterize the path if there are
lateral references to the other relation in the sampling info or
restriction clauses associated with the path.  In that case
reparameterize_path_by_child() returns NULL.
---
 src/backend/optimizer/util/pathnode.c        | 108 +++++++++++-
 src/test/regress/expected/partition_join.out | 167 +++++++++++++++++++
 src/test/regress/sql/partition_join.sql      |  48 ++++++
 3 files changed, 322 insertions(+), 1 deletion(-)

diff --git a/src/backend/optimizer/util/pathnode.c b/src/backend/optimizer/util/pathnode.c
index 2185fc35a3..f11bcd1f10 100644
--- a/src/backend/optimizer/util/pathnode.c
+++ b/src/backend/optimizer/util/pathnode.c
@@ -56,6 +56,8 @@ static int	append_startup_cost_compare(const ListCell *a, const ListCell *b);
 static List *reparameterize_pathlist_by_child(PlannerInfo *root,
 											  List *pathlist,
 											  RelOptInfo *child_rel);
+static bool contain_references_to(PlannerInfo *root, List *restrictinfo_list,
+								  Relids relids);
 
 
 /*****************************************************************************
@@ -4103,13 +4105,60 @@ do { \
 	switch (nodeTag(path))
 	{
 		case T_Path:
-			FLAT_COPY_PATH(new_path, path, Path);
+			{
+				/*
+				 * If the path's restriction clauses contain lateral references
+				 * to the other relation, we can't reparameterize, because we
+				 * must not change the RelOptInfo's contents here.  (Doing so
+				 * would break things if we end up using a non-partitionwise
+				 * join.)
+				 */
+				if (contain_references_to(root, path->parent->baserestrictinfo,
+										  child_rel->top_parent_relids))
+					return NULL;
+
+				/*
+				 * If it's a SampleScan with tablesample parameters referencing
+				 * the other relation, we can't reparameterize, because we must
+				 * not change the RTE's contents here.  (Doing so would break
+				 * things if we end up using a non-partitionwise join.)
+				 */
+				if (path->pathtype == T_SampleScan)
+				{
+					Index		scan_relid = path->parent->relid;
+					RangeTblEntry *rte;
+
+					/* it should be a base rel with a tablesample clause... */
+					Assert(scan_relid > 0);
+					rte = planner_rt_fetch(scan_relid, root);
+					Assert(rte->rtekind == RTE_RELATION);
+					Assert(rte->tablesample != NULL);
+
+					if (bms_overlap(pull_varnos(root, (Node *) rte->tablesample),
+									child_rel->top_parent_relids))
+						return NULL;
+				}
+
+				FLAT_COPY_PATH(new_path, path, Path);
+			}
 			break;
 
 		case T_IndexPath:
 			{
 				IndexPath  *ipath;
 
+				/*
+				 * If the path's restriction clauses contain lateral references
+				 * to the other relation, we can't reparameterize, because we
+				 * must not change the IndexOptInfo's contents here.  (Doing so
+				 * would break things if we end up using a non-partitionwise
+				 * join.)
+				 */
+				if (contain_references_to(root,
+										  ((IndexPath *) path)->indexinfo->indrestrictinfo,
+										  child_rel->top_parent_relids))
+					return NULL;
+
 				FLAT_COPY_PATH(ipath, path, IndexPath);
 				ADJUST_CHILD_ATTRS(ipath->indexclauses);
 				new_path = (Path *) ipath;
@@ -4120,6 +4169,17 @@ do { \
 			{
 				BitmapHeapPath *bhpath;
 
+				/*
+				 * If the path's restriction clauses contain lateral references
+				 * to the other relation, we can't reparameterize, because we
+				 * must not change the RelOptInfo's contents here.  (Doing so
+				 * would break things if we end up using a non-partitionwise
+				 * join.)
+				 */
+				if (contain_references_to(root, path->parent->baserestrictinfo,
+										  child_rel->top_parent_relids))
+					return NULL;
+
 				FLAT_COPY_PATH(bhpath, path, BitmapHeapPath);
 				REPARAMETERIZE_CHILD_PATH(bhpath->bitmapqual);
 				new_path = (Path *) bhpath;
@@ -4151,6 +4211,17 @@ do { \
 				ForeignPath *fpath;
 				ReparameterizeForeignPathByChild_function rfpc_func;
 
+				/*
+				 * If the path's restriction clauses contain lateral references
+				 * to the other relation, we can't reparameterize, because we
+				 * must not change the RelOptInfo's contents here.  (Doing so
+				 * would break things if we end up using a non-partitionwise
+				 * join.)
+				 */
+				if (contain_references_to(root, path->parent->baserestrictinfo,
+										  child_rel->top_parent_relids))
+					return NULL;
+
 				FLAT_COPY_PATH(fpath, path, ForeignPath);
 				if (fpath->fdw_outerpath)
 					REPARAMETERIZE_CHILD_PATH(fpath->fdw_outerpath);
@@ -4169,6 +4240,17 @@ do { \
 			{
 				CustomPath *cpath;
 
+				/*
+				 * If the path's restriction clauses contain lateral references
+				 * to the other relation, we can't reparameterize, because we
+				 * must not change the RelOptInfo's contents here.  (Doing so
+				 * would break things if we end up using a non-partitionwise
+				 * join.)
+				 */
+				if (contain_references_to(root, path->parent->baserestrictinfo,
+										  child_rel->top_parent_relids))
+					return NULL;
+
 				FLAT_COPY_PATH(cpath, path, CustomPath);
 				REPARAMETERIZE_CHILD_PATH_LIST(cpath->custom_paths);
 				if (cpath->methods &&
@@ -4358,3 +4440,27 @@ reparameterize_pathlist_by_child(PlannerInfo *root,
 
 	return result;
 }
+
+/*
+ * contain_references_to
+ *		Return true if any Var in the clauses of 'restrictinfo_list' (including
+ *		inside aggregates, window functions, PlaceHolderVars) is from 'relids'.
+ */
+static bool
+contain_references_to(PlannerInfo *root, List *restrictinfo_list,
+					  Relids relids)
+{
+	List	   *scan_clauses;
+	List	   *vars;
+	bool		ret;
+
+	scan_clauses = extract_actual_clauses(restrictinfo_list, false);
+	vars = pull_var_clause((Node *) scan_clauses,
+						   PVC_RECURSE_AGGREGATES |
+						   PVC_RECURSE_WINDOWFUNCS |
+						   PVC_RECURSE_PLACEHOLDERS);
+	ret = bms_overlap(pull_varnos(root, (Node *) vars), relids);
+	list_free(vars);
+
+	return ret;
+}
diff --git a/src/test/regress/expected/partition_join.out b/src/test/regress/expected/partition_join.out
index 6560fe2416..72f286b893 100644
--- a/src/test/regress/expected/partition_join.out
+++ b/src/test/regress/expected/partition_join.out
@@ -505,6 +505,99 @@ SELECT t1.a, ss.t2a, ss.t2c FROM prt1 t1 LEFT JOIN LATERAL
  550 |     | 
 (12 rows)
 
+SET max_parallel_workers_per_gather = 0;
+-- If there are lateral references to the other relation in sample scan, we'd
+-- fail to generate a partitionwise join.
+EXPLAIN (COSTS OFF)
+SELECT * FROM prt1 t1 JOIN LATERAL
+			  (SELECT * FROM prt1 t2 TABLESAMPLE SYSTEM (t1.a) REPEATABLE(t1.b)) s
+			  ON t1.a = s.a;
+                       QUERY PLAN                        
+---------------------------------------------------------
+ Nested Loop
+   ->  Append
+         ->  Seq Scan on prt1_p1 t1_1
+         ->  Seq Scan on prt1_p2 t1_2
+         ->  Seq Scan on prt1_p3 t1_3
+   ->  Append
+         ->  Sample Scan on prt1_p1 t2_1
+               Sampling: system (t1.a) REPEATABLE (t1.b)
+               Filter: (t1.a = a)
+         ->  Sample Scan on prt1_p2 t2_2
+               Sampling: system (t1.a) REPEATABLE (t1.b)
+               Filter: (t1.a = a)
+         ->  Sample Scan on prt1_p3 t2_3
+               Sampling: system (t1.a) REPEATABLE (t1.b)
+               Filter: (t1.a = a)
+(15 rows)
+
+-- If there are lateral references to the other relation in scan's restriction
+-- clauses, we'd fail to generate a partitionwise join.
+EXPLAIN (COSTS OFF)
+SELECT count(*) FROM prt1 t1 LEFT JOIN LATERAL
+			  (SELECT t1.b AS t1b, t2.* FROM prt2 t2) s
+			  ON t1.a = s.b WHERE s.t1b = s.a;
+                          QUERY PLAN                           
+---------------------------------------------------------------
+ Aggregate
+   ->  Nested Loop
+         ->  Append
+               ->  Seq Scan on prt1_p1 t1_1
+               ->  Seq Scan on prt1_p2 t1_2
+               ->  Seq Scan on prt1_p3 t1_3
+         ->  Append
+               ->  Index Scan using iprt2_p1_b on prt2_p1 t2_1
+                     Index Cond: (b = t1.a)
+                     Filter: (t1.b = a)
+               ->  Index Scan using iprt2_p2_b on prt2_p2 t2_2
+                     Index Cond: (b = t1.a)
+                     Filter: (t1.b = a)
+               ->  Index Scan using iprt2_p3_b on prt2_p3 t2_3
+                     Index Cond: (b = t1.a)
+                     Filter: (t1.b = a)
+(16 rows)
+
+SELECT count(*) FROM prt1 t1 LEFT JOIN LATERAL
+			  (SELECT t1.b AS t1b, t2.* FROM prt2 t2) s
+			  ON t1.a = s.b WHERE s.t1b = s.a;
+ count 
+-------
+   100
+(1 row)
+
+EXPLAIN (COSTS OFF)
+SELECT count(*) FROM prt1 t1 LEFT JOIN LATERAL
+			  (SELECT t1.b AS t1b, t2.* FROM prt2 t2) s
+			  ON t1.a = s.b WHERE s.t1b = s.b;
+                             QUERY PLAN                             
+--------------------------------------------------------------------
+ Aggregate
+   ->  Nested Loop
+         ->  Append
+               ->  Seq Scan on prt1_p1 t1_1
+               ->  Seq Scan on prt1_p2 t1_2
+               ->  Seq Scan on prt1_p3 t1_3
+         ->  Append
+               ->  Index Only Scan using iprt2_p1_b on prt2_p1 t2_1
+                     Index Cond: (b = t1.a)
+                     Filter: (b = t1.b)
+               ->  Index Only Scan using iprt2_p2_b on prt2_p2 t2_2
+                     Index Cond: (b = t1.a)
+                     Filter: (b = t1.b)
+               ->  Index Only Scan using iprt2_p3_b on prt2_p3 t2_3
+                     Index Cond: (b = t1.a)
+                     Filter: (b = t1.b)
+(16 rows)
+
+SELECT count(*) FROM prt1 t1 LEFT JOIN LATERAL
+			  (SELECT t1.b AS t1b, t2.* FROM prt2 t2) s
+			  ON t1.a = s.b WHERE s.t1b = s.b;
+ count 
+-------
+     5
+(1 row)
+
+RESET max_parallel_workers_per_gather;
 -- bug with inadequate sort key representation
 SET enable_partitionwise_aggregate TO true;
 SET enable_hashjoin TO false;
@@ -1944,6 +2037,80 @@ SELECT * FROM prt1_l t1 LEFT JOIN LATERAL
  550 | 0 | 0002 |     |      |     |     |      
 (12 rows)
 
+SET max_parallel_workers_per_gather = 0;
+-- If there are lateral references to the other relation in sample scan, we'd
+-- fail to generate a partitionwise join.
+EXPLAIN (COSTS OFF)
+SELECT * FROM prt1_l t1 JOIN LATERAL
+			  (SELECT * FROM prt1_l t2 TABLESAMPLE SYSTEM (t1.a) REPEATABLE(t1.b)) s
+			  ON t1.a = s.a AND t1.b = s.b AND t1.c = s.c;
+                                    QUERY PLAN                                    
+----------------------------------------------------------------------------------
+ Nested Loop
+   ->  Append
+         ->  Seq Scan on prt1_l_p1 t1_1
+         ->  Seq Scan on prt1_l_p2_p1 t1_2
+         ->  Seq Scan on prt1_l_p2_p2 t1_3
+         ->  Seq Scan on prt1_l_p3_p1 t1_4
+         ->  Seq Scan on prt1_l_p3_p2 t1_5
+   ->  Append
+         ->  Sample Scan on prt1_l_p1 t2_1
+               Sampling: system (t1.a) REPEATABLE (t1.b)
+               Filter: ((t1.a = a) AND (t1.b = b) AND ((t1.c)::text = (c)::text))
+         ->  Sample Scan on prt1_l_p2_p1 t2_2
+               Sampling: system (t1.a) REPEATABLE (t1.b)
+               Filter: ((t1.a = a) AND (t1.b = b) AND ((t1.c)::text = (c)::text))
+         ->  Sample Scan on prt1_l_p2_p2 t2_3
+               Sampling: system (t1.a) REPEATABLE (t1.b)
+               Filter: ((t1.a = a) AND (t1.b = b) AND ((t1.c)::text = (c)::text))
+         ->  Sample Scan on prt1_l_p3_p1 t2_4
+               Sampling: system (t1.a) REPEATABLE (t1.b)
+               Filter: ((t1.a = a) AND (t1.b = b) AND ((t1.c)::text = (c)::text))
+         ->  Sample Scan on prt1_l_p3_p2 t2_5
+               Sampling: system (t1.a) REPEATABLE (t1.b)
+               Filter: ((t1.a = a) AND (t1.b = b) AND ((t1.c)::text = (c)::text))
+(23 rows)
+
+-- If there are lateral references to the other relation in scan's restriction
+-- clauses, we'd fail to generate a partitionwise join.
+EXPLAIN (COSTS OFF)
+SELECT COUNT(*) FROM prt1_l t1 LEFT JOIN LATERAL
+			  (SELECT t1.b AS t1b, t2.* FROM prt2_l t2) s
+			  ON t1.a = s.b AND t1.b = s.a AND t1.c = s.c
+			  WHERE s.t1b = s.a;
+                                              QUERY PLAN                                               
+-------------------------------------------------------------------------------------------------------
+ Aggregate
+   ->  Nested Loop
+         ->  Append
+               ->  Seq Scan on prt1_l_p1 t1_1
+               ->  Seq Scan on prt1_l_p2_p1 t1_2
+               ->  Seq Scan on prt1_l_p2_p2 t1_3
+               ->  Seq Scan on prt1_l_p3_p1 t1_4
+               ->  Seq Scan on prt1_l_p3_p2 t1_5
+         ->  Append
+               ->  Seq Scan on prt2_l_p1 t2_1
+                     Filter: ((a = t1.b) AND (t1.a = b) AND (t1.b = a) AND ((t1.c)::text = (c)::text))
+               ->  Seq Scan on prt2_l_p2_p1 t2_2
+                     Filter: ((a = t1.b) AND (t1.a = b) AND (t1.b = a) AND ((t1.c)::text = (c)::text))
+               ->  Seq Scan on prt2_l_p2_p2 t2_3
+                     Filter: ((a = t1.b) AND (t1.a = b) AND (t1.b = a) AND ((t1.c)::text = (c)::text))
+               ->  Seq Scan on prt2_l_p3_p1 t2_4
+                     Filter: ((a = t1.b) AND (t1.a = b) AND (t1.b = a) AND ((t1.c)::text = (c)::text))
+               ->  Seq Scan on prt2_l_p3_p2 t2_5
+                     Filter: ((a = t1.b) AND (t1.a = b) AND (t1.b = a) AND ((t1.c)::text = (c)::text))
+(19 rows)
+
+SELECT COUNT(*) FROM prt1_l t1 LEFT JOIN LATERAL
+			  (SELECT t1.b AS t1b, t2.* FROM prt2_l t2) s
+			  ON t1.a = s.b AND t1.b = s.a AND t1.c = s.c
+			  WHERE s.t1b = s.a;
+ count 
+-------
+   100
+(1 row)
+
+RESET max_parallel_workers_per_gather;
 -- join with one side empty
 EXPLAIN (COSTS OFF)
 SELECT t1.a, t1.c, t2.b, t2.c FROM (SELECT * FROM prt1_l WHERE a = 1 AND a = 2) t1 RIGHT JOIN prt2_l t2 ON t1.a = t2.b AND t1.b = t2.a AND t1.c = t2.c;
diff --git a/src/test/regress/sql/partition_join.sql b/src/test/regress/sql/partition_join.sql
index 48daf3aee3..0554a6568b 100644
--- a/src/test/regress/sql/partition_join.sql
+++ b/src/test/regress/sql/partition_join.sql
@@ -100,6 +100,33 @@ SELECT t1.a, ss.t2a, ss.t2c FROM prt1 t1 LEFT JOIN LATERAL
 			  (SELECT t2.a AS t2a, t3.a AS t3a, t2.b t2b, t2.c t2c, least(t1.a,t2.a,t3.a) FROM prt1 t2 JOIN prt2 t3 ON (t2.a = t3.b)) ss
 			  ON t1.c = ss.t2c WHERE (t1.b + coalesce(ss.t2b, 0)) = 0 ORDER BY t1.a;
 
+SET max_parallel_workers_per_gather = 0;
+-- If there are lateral references to the other relation in sample scan, we'd
+-- fail to generate a partitionwise join.
+EXPLAIN (COSTS OFF)
+SELECT * FROM prt1 t1 JOIN LATERAL
+			  (SELECT * FROM prt1 t2 TABLESAMPLE SYSTEM (t1.a) REPEATABLE(t1.b)) s
+			  ON t1.a = s.a;
+
+-- If there are lateral references to the other relation in scan's restriction
+-- clauses, we'd fail to generate a partitionwise join.
+EXPLAIN (COSTS OFF)
+SELECT count(*) FROM prt1 t1 LEFT JOIN LATERAL
+			  (SELECT t1.b AS t1b, t2.* FROM prt2 t2) s
+			  ON t1.a = s.b WHERE s.t1b = s.a;
+SELECT count(*) FROM prt1 t1 LEFT JOIN LATERAL
+			  (SELECT t1.b AS t1b, t2.* FROM prt2 t2) s
+			  ON t1.a = s.b WHERE s.t1b = s.a;
+
+EXPLAIN (COSTS OFF)
+SELECT count(*) FROM prt1 t1 LEFT JOIN LATERAL
+			  (SELECT t1.b AS t1b, t2.* FROM prt2 t2) s
+			  ON t1.a = s.b WHERE s.t1b = s.b;
+SELECT count(*) FROM prt1 t1 LEFT JOIN LATERAL
+			  (SELECT t1.b AS t1b, t2.* FROM prt2 t2) s
+			  ON t1.a = s.b WHERE s.t1b = s.b;
+RESET max_parallel_workers_per_gather;
+
 -- bug with inadequate sort key representation
 SET enable_partitionwise_aggregate TO true;
 SET enable_hashjoin TO false;
@@ -387,6 +414,27 @@ SELECT * FROM prt1_l t1 LEFT JOIN LATERAL
 			  (SELECT t2.a AS t2a, t2.c AS t2c, t2.b AS t2b, t3.b AS t3b, least(t1.a,t2.a,t3.b) FROM prt1_l t2 JOIN prt2_l t3 ON (t2.a = t3.b AND t2.c = t3.c)) ss
 			  ON t1.a = ss.t2a AND t1.c = ss.t2c WHERE t1.b = 0 ORDER BY t1.a;
 
+SET max_parallel_workers_per_gather = 0;
+-- If there are lateral references to the other relation in sample scan, we'd
+-- fail to generate a partitionwise join.
+EXPLAIN (COSTS OFF)
+SELECT * FROM prt1_l t1 JOIN LATERAL
+			  (SELECT * FROM prt1_l t2 TABLESAMPLE SYSTEM (t1.a) REPEATABLE(t1.b)) s
+			  ON t1.a = s.a AND t1.b = s.b AND t1.c = s.c;
+
+-- If there are lateral references to the other relation in scan's restriction
+-- clauses, we'd fail to generate a partitionwise join.
+EXPLAIN (COSTS OFF)
+SELECT COUNT(*) FROM prt1_l t1 LEFT JOIN LATERAL
+			  (SELECT t1.b AS t1b, t2.* FROM prt2_l t2) s
+			  ON t1.a = s.b AND t1.b = s.a AND t1.c = s.c
+			  WHERE s.t1b = s.a;
+SELECT COUNT(*) FROM prt1_l t1 LEFT JOIN LATERAL
+			  (SELECT t1.b AS t1b, t2.* FROM prt2_l t2) s
+			  ON t1.a = s.b AND t1.b = s.a AND t1.c = s.c
+			  WHERE s.t1b = s.a;
+RESET max_parallel_workers_per_gather;
+
 -- join with one side empty
 EXPLAIN (COSTS OFF)
 SELECT t1.a, t1.c, t2.b, t2.c FROM (SELECT * FROM prt1_l WHERE a = 1 AND a = 2) t1 RIGHT JOIN prt2_l t2 ON t1.a = t2.b AND t1.b = t2.a AND t1.c = t2.c;
-- 
2.31.0

