From 429e75f1ebc5a88f81c8c3702b43432ec4489c26 Mon Sep 17 00:00:00 2001
From: Richard Guo <guofenglinux@gmail.com>
Date: Fri, 30 Dec 2022 10:35:36 +0800
Subject: [PATCH v2] Check lateral refs within PHVs for memoize cache keys

---
 .../postgres_fdw/expected/postgres_fdw.out    | 12 +++--
 src/backend/optimizer/plan/initsplan.c        | 51 +++++++++++++++++++
 src/test/regress/expected/memoize.out         | 31 +++++++++++
 src/test/regress/sql/memoize.sql              | 11 ++++
 4 files changed, 101 insertions(+), 4 deletions(-)

diff --git a/contrib/postgres_fdw/expected/postgres_fdw.out b/contrib/postgres_fdw/expected/postgres_fdw.out
index 1a2c2a665c..145ae37b1c 100644
--- a/contrib/postgres_fdw/expected/postgres_fdw.out
+++ b/contrib/postgres_fdw/expected/postgres_fdw.out
@@ -3650,15 +3650,19 @@ ORDER BY ref_0."C 1";
          ->  Index Scan using t1_pkey on "S 1"."T 1" ref_0
                Output: ref_0."C 1", ref_0.c2, ref_0.c3, ref_0.c4, ref_0.c5, ref_0.c6, ref_0.c7, ref_0.c8
                Index Cond: (ref_0."C 1" < 10)
-         ->  Foreign Scan on public.ft1 ref_1
-               Output: ref_1.c3, ref_0.c2
-               Remote SQL: SELECT c3 FROM "S 1"."T 1" WHERE ((c3 = '00001'))
+         ->  Memoize
+               Output: ref_1.c3, (ref_0.c2)
+               Cache Key: ref_0.c2
+               Cache Mode: binary
+               ->  Foreign Scan on public.ft1 ref_1
+                     Output: ref_1.c3, ref_0.c2
+                     Remote SQL: SELECT c3 FROM "S 1"."T 1" WHERE ((c3 = '00001'))
    ->  Materialize
          Output: ref_3.c3
          ->  Foreign Scan on public.ft2 ref_3
                Output: ref_3.c3
                Remote SQL: SELECT c3 FROM "S 1"."T 1" WHERE ((c3 = '00001'))
-(15 rows)
+(19 rows)
 
 SELECT ref_0.c2, subq_1.*
 FROM
diff --git a/src/backend/optimizer/plan/initsplan.c b/src/backend/optimizer/plan/initsplan.c
index fd8cbb1dc7..68ff0975d1 100644
--- a/src/backend/optimizer/plan/initsplan.c
+++ b/src/backend/optimizer/plan/initsplan.c
@@ -523,12 +523,49 @@ create_lateral_join_info(PlannerInfo *root)
 		PlaceHolderInfo *phinfo = (PlaceHolderInfo *) lfirst(lc);
 		Relids		eval_at = phinfo->ph_eval_at;
 		int			varno;
+		List	   *vars;
+		List	   *ph_lateral_vars;
+		ListCell   *cell;
 
 		if (phinfo->ph_lateral == NULL)
 			continue;			/* PHV is uninteresting if no lateral refs */
 
 		found_laterals = true;
 
+		/* Fetch Vars and PHVs of lateral references within PlaceHolderVars */
+		vars = pull_vars_of_level((Node *) phinfo->ph_var->phexpr, 0);
+
+		ph_lateral_vars = NIL;
+		foreach(cell, vars)
+		{
+			Node	   *node = (Node *) lfirst(cell);
+
+			node = copyObject(node);
+			if (IsA(node, Var))
+			{
+				Var		   *var = (Var *) node;
+
+				Assert(var->varlevelsup == 0);
+
+				if (bms_is_member(var->varno, phinfo->ph_lateral))
+					ph_lateral_vars = lappend(ph_lateral_vars, node);
+			}
+			else if (IsA(node, PlaceHolderVar))
+			{
+				PlaceHolderVar *phv = (PlaceHolderVar *) node;
+
+				Assert(phv->phlevelsup == 0);
+
+				if (bms_is_subset(find_placeholder_info(root, phv)->ph_eval_at,
+								  phinfo->ph_lateral))
+					ph_lateral_vars = lappend(ph_lateral_vars, node);
+			}
+			else
+				Assert(false);
+		}
+
+		list_free(vars);
+
 		if (bms_get_singleton_member(eval_at, &varno))
 		{
 			/* Evaluation site is a baserel */
@@ -540,6 +577,13 @@ create_lateral_join_info(PlannerInfo *root)
 			brel->lateral_relids =
 				bms_add_members(brel->lateral_relids,
 								phinfo->ph_lateral);
+
+			/*
+			 * Remember the lateral references. We need this info for searching
+			 * memoize cache keys.
+			 */
+			brel->lateral_vars =
+				list_concat(brel->lateral_vars, ph_lateral_vars);
 		}
 		else
 		{
@@ -551,6 +595,13 @@ create_lateral_join_info(PlannerInfo *root)
 
 				brel->lateral_relids = bms_add_members(brel->lateral_relids,
 													   phinfo->ph_lateral);
+
+				/*
+				 * Remember the lateral references. We need this info for
+				 * searching memoize cache keys.
+				 */
+				brel->lateral_vars = list_concat(brel->lateral_vars,
+												 ph_lateral_vars);
 			}
 		}
 	}
diff --git a/src/test/regress/expected/memoize.out b/src/test/regress/expected/memoize.out
index de43afa76e..cf238661d2 100644
--- a/src/test/regress/expected/memoize.out
+++ b/src/test/regress/expected/memoize.out
@@ -90,6 +90,37 @@ WHERE t1.unique1 < 1000;
   1000 | 9.5000000000000000
 (1 row)
 
+-- Try with LATERAL references within PlaceHolderVars
+SELECT explain_memoize('
+SELECT COUNT(*),AVG(t1.twenty) FROM tenk1 t1 LEFT JOIN
+LATERAL (SELECT t1.twenty as c1, t2.unique1 as c2 FROM tenk1 t2) s on true
+WHERE s.c1 = s.c2 and t1.unique1 < 1000;', false);
+                                      explain_memoize                                      
+-------------------------------------------------------------------------------------------
+ Aggregate (actual rows=1 loops=N)
+   ->  Nested Loop (actual rows=1000 loops=N)
+         ->  Seq Scan on tenk1 t1 (actual rows=1000 loops=N)
+               Filter: (unique1 < 1000)
+               Rows Removed by Filter: 9000
+         ->  Memoize (actual rows=1 loops=N)
+               Cache Key: t1.twenty
+               Cache Mode: binary
+               Hits: 980  Misses: 20  Evictions: Zero  Overflows: 0  Memory Usage: NkB
+               ->  Index Only Scan using tenk1_unique1 on tenk1 t2 (actual rows=1 loops=N)
+                     Filter: (t1.twenty = unique1)
+                     Rows Removed by Filter: 9999
+                     Heap Fetches: N
+(13 rows)
+
+-- And check we get the expected results.
+SELECT COUNT(*),AVG(t1.twenty) FROM tenk1 t1 LEFT JOIN
+LATERAL (SELECT t1.twenty as c1, t2.unique1 as c2 FROM tenk1 t2) s on true
+WHERE s.c1 = s.c2 and t1.unique1 < 1000;
+ count |        avg         
+-------+--------------------
+  1000 | 9.5000000000000000
+(1 row)
+
 -- Reduce work_mem and hash_mem_multiplier so that we see some cache evictions
 SET work_mem TO '64kB';
 SET hash_mem_multiplier TO 1.0;
diff --git a/src/test/regress/sql/memoize.sql b/src/test/regress/sql/memoize.sql
index 17c5b4bfab..512c6dfeba 100644
--- a/src/test/regress/sql/memoize.sql
+++ b/src/test/regress/sql/memoize.sql
@@ -55,6 +55,17 @@ SELECT COUNT(*),AVG(t2.unique1) FROM tenk1 t1,
 LATERAL (SELECT t2.unique1 FROM tenk1 t2 WHERE t1.twenty = t2.unique1) t2
 WHERE t1.unique1 < 1000;
 
+-- Try with LATERAL references within PlaceHolderVars
+SELECT explain_memoize('
+SELECT COUNT(*),AVG(t1.twenty) FROM tenk1 t1 LEFT JOIN
+LATERAL (SELECT t1.twenty as c1, t2.unique1 as c2 FROM tenk1 t2) s on true
+WHERE s.c1 = s.c2 and t1.unique1 < 1000;', false);
+
+-- And check we get the expected results.
+SELECT COUNT(*),AVG(t1.twenty) FROM tenk1 t1 LEFT JOIN
+LATERAL (SELECT t1.twenty as c1, t2.unique1 as c2 FROM tenk1 t2) s on true
+WHERE s.c1 = s.c2 and t1.unique1 < 1000;
+
 -- Reduce work_mem and hash_mem_multiplier so that we see some cache evictions
 SET work_mem TO '64kB';
 SET hash_mem_multiplier TO 1.0;
-- 
2.31.0

