From 07784159aea4de7b5614fd7a39bb6eeafe07cb22 Mon Sep 17 00:00:00 2001
From: Amit Langote <amitlan@postgresql.org>
Date: Sat, 15 Feb 2025 16:39:54 +0900
Subject: [PATCH] Fix an oversight in cbc127917 for MERGE handling

ExecInitModifyTable() should also trim MERGE-related lists to exclude
result relations pruned during initial pruning.

Reported-by: Alexander Lakhin <exclusion@gmail.com> (via sqlsmith)
Discussion: https://postgr.es/m/e72c94d9-e5f9-4753-9bc1-69d72bd54b8a@gmail.com
---
 src/backend/executor/nodeModifyTable.c        | 24 ++++++++++---
 src/include/nodes/execnodes.h                 |  7 ++--
 src/test/regress/expected/partition_prune.out | 34 +++++++++++++++++++
 src/test/regress/sql/partition_prune.sql      | 13 +++++++
 4 files changed, 72 insertions(+), 6 deletions(-)

diff --git a/src/backend/executor/nodeModifyTable.c b/src/backend/executor/nodeModifyTable.c
index a15e7863b0d..e0f859ba966 100644
--- a/src/backend/executor/nodeModifyTable.c
+++ b/src/backend/executor/nodeModifyTable.c
@@ -3667,14 +3667,14 @@ ExecInitMerge(ModifyTableState *mtstate, EState *estate)
 	 * anything here, do so there too.
 	 */
 	i = 0;
-	foreach(lc, node->mergeActionLists)
+	foreach(lc, mtstate->mt_mergeActionLists)
 	{
 		List	   *mergeActionList = lfirst(lc);
 		Node	   *joinCondition;
 		TupleDesc	relationDesc;
 		ListCell   *l;
 
-		joinCondition = (Node *) list_nth(node->mergeJoinConditions, i);
+		joinCondition = (Node *) list_nth(mtstate->mt_mergeJoinConditions, i);
 		resultRelInfo = mtstate->resultRelInfo + i;
 		i++;
 		relationDesc = RelationGetDescr(resultRelInfo->ri_RelationDesc);
@@ -4475,6 +4475,8 @@ ExecInitModifyTable(ModifyTable *node, EState *estate, int eflags)
 	List	   *withCheckOptionLists = NIL;
 	List	   *returningLists = NIL;
 	List	   *updateColnosLists = NIL;
+	List	   *mergeActionLists = NIL;
+	List	   *mergeJoinConditions = NIL;
 	ResultRelInfo *resultRelInfo;
 	List	   *arowmarks;
 	ListCell   *l;
@@ -4518,6 +4520,18 @@ ExecInitModifyTable(ModifyTable *node, EState *estate, int eflags)
 
 				updateColnosLists = lappend(updateColnosLists, updateColnosList);
 			}
+			if (node->mergeActionLists)
+			{
+				List	   *mergeActionList = list_nth(node->mergeActionLists, i);
+
+				mergeActionLists = lappend(mergeActionLists, mergeActionList);
+			}
+			if (node->mergeJoinConditions)
+			{
+				List	   *mergeJoinCondition = list_nth(node->mergeJoinConditions, i);
+
+				mergeJoinConditions = lappend(mergeJoinConditions, mergeJoinCondition);
+			}
 		}
 		i++;
 	}
@@ -4544,6 +4558,8 @@ ExecInitModifyTable(ModifyTable *node, EState *estate, int eflags)
 	mtstate->mt_merge_updated = 0;
 	mtstate->mt_merge_deleted = 0;
 	mtstate->mt_updateColnosLists = updateColnosLists;
+	mtstate->mt_mergeActionLists = mergeActionLists;
+	mtstate->mt_mergeJoinConditions = mergeJoinConditions;
 
 	/*----------
 	 * Resolve the target relation. This is the same as:
@@ -4599,8 +4615,8 @@ ExecInitModifyTable(ModifyTable *node, EState *estate, int eflags)
 		Index		resultRelation = lfirst_int(l);
 		List	   *mergeActions = NIL;
 
-		if (node->mergeActionLists)
-			mergeActions = list_nth(node->mergeActionLists, i);
+		if (mergeActionLists)
+			mergeActions = list_nth(mergeActionLists, i);
 
 		if (resultRelInfo != mtstate->rootResultRelInfo)
 		{
diff --git a/src/include/nodes/execnodes.h b/src/include/nodes/execnodes.h
index e2d1dc1e067..66fa6133343 100644
--- a/src/include/nodes/execnodes.h
+++ b/src/include/nodes/execnodes.h
@@ -1448,10 +1448,13 @@ typedef struct ModifyTableState
 	double		mt_merge_deleted;
 
 	/*
-	 * List of valid updateColnosLists.  Contains only those belonging to
-	 * unpruned relations from ModifyTable.updateColnosLists.
+	 * Lists of valid updateColnosListsm, mergeActionLists, and
+	 * mergeJoinConditions.  These contain only those belonging to unpruned
+	 * relations from the respective Lists in the ModifyTable.
 	 */
 	List	   *mt_updateColnosLists;
+	List	   *mt_mergeActionLists;
+	List	   *mt_mergeJoinConditions;
 } ModifyTableState;
 
 /* ----------------
diff --git a/src/test/regress/expected/partition_prune.out b/src/test/regress/expected/partition_prune.out
index e667503c961..3261da28219 100644
--- a/src/test/regress/expected/partition_prune.out
+++ b/src/test/regress/expected/partition_prune.out
@@ -4513,5 +4513,39 @@ execute update_part_abc_view (2, 'a');
 ERROR:  new row violates check option for view "part_abc_view"
 DETAIL:  Failing row contains (2, a, t).
 deallocate update_part_abc_view;
+-- Runtime pruning on MERGE using a stable function
+create function stable_one() returns int as $$ begin return 1; end; $$ language plpgsql stable;
+explain (costs off)
+merge into part_abc_view pt
+using (select stable_one() as pid) as q join part_abc_1 pt1 on (q.pid = pt1.a) on pt.a = pt1.a
+when matched then delete returning pt.a;
+                              QUERY PLAN                               
+-----------------------------------------------------------------------
+ Merge on part_abc
+   Merge on part_abc_1
+   ->  Nested Loop
+         ->  Append
+               Subplans Removed: 1
+               ->  Seq Scan on part_abc_1
+                     Filter: ((b <> 'a'::text) AND (a = stable_one()))
+         ->  Materialize
+               ->  Seq Scan on part_abc_1 pt1
+                     Filter: (a = stable_one())
+(10 rows)
+
+merge into part_abc_view pt
+using (select stable_one() as pid) as q join part_abc_1 pt1 on (q.pid = pt1.a) on pt.a = pt1.a
+when matched then delete returning pt.a;
+ a 
+---
+ 1
+(1 row)
+
+table part_abc_view;
+ a | b | c 
+---+---+---
+ 2 | c | t
+(1 row)
+
 drop view part_abc_view;
 drop table part_abc;
diff --git a/src/test/regress/sql/partition_prune.sql b/src/test/regress/sql/partition_prune.sql
index 730545e86a7..b27f3ace73c 100644
--- a/src/test/regress/sql/partition_prune.sql
+++ b/src/test/regress/sql/partition_prune.sql
@@ -1372,5 +1372,18 @@ execute update_part_abc_view (1, 'd');
 explain (costs off) execute update_part_abc_view (2, 'a');
 execute update_part_abc_view (2, 'a');
 deallocate update_part_abc_view;
+
+-- Runtime pruning on MERGE using a stable function
+create function stable_one() returns int as $$ begin return 1; end; $$ language plpgsql stable;
+explain (costs off)
+merge into part_abc_view pt
+using (select stable_one() as pid) as q join part_abc_1 pt1 on (q.pid = pt1.a) on pt.a = pt1.a
+when matched then delete returning pt.a;
+
+merge into part_abc_view pt
+using (select stable_one() as pid) as q join part_abc_1 pt1 on (q.pid = pt1.a) on pt.a = pt1.a
+when matched then delete returning pt.a;
+table part_abc_view;
+
 drop view part_abc_view;
 drop table part_abc;
-- 
2.43.0

