From 68c8838061213da8abdd54a602da40c7ed5c98a8 Mon Sep 17 00:00:00 2001
From: Richard Guo <guofenglinux@gmail.com>
Date: Mon, 17 Jun 2024 12:32:00 +0900
Subject: [PATCH v7] Check lateral refs within PHVs for memoize cache keys

If we intend to generate a memoize node on top of a path, we need cache
keys of some sort.  Currently we search for the cache keys in the
parameterized clauses of the path as well as the lateral_vars of its
parent.  However, it turns out that this is not sufficient because their
might be lateral references derived from PlaceHolderVars, which we fail
to take into consideration.

This oversight can cause us to miss opportunities to utilize the Memoize
node.  Moreover, in some plans, the failure to recognize all the cache
keys could result in performance regressions.  This is because without
identifying all the cache keys, we would need to purge the entire cache
every time we get a new outer tuple during execution.

This patch fixes this issue by extracting lateral Vars from within
PlaceHolderVars and subsequently including them in the cache keys.

This patch also includes a comment clarifying that Memoize nodes are
currently not added on top of join relation paths.  This explains why
this patch only considers PHVs that are due to be evaluated at baserels.
---
 .../postgres_fdw/expected/postgres_fdw.out    | 12 ++-
 src/backend/optimizer/path/joinpath.c         | 98 ++++++++++++++++++-
 src/test/regress/expected/memoize.out         | 63 ++++++++++++
 src/test/regress/sql/memoize.sql              | 24 +++++
 4 files changed, 189 insertions(+), 8 deletions(-)

diff --git a/contrib/postgres_fdw/expected/postgres_fdw.out b/contrib/postgres_fdw/expected/postgres_fdw.out
index ea566d5034..8cf222caca 100644
--- a/contrib/postgres_fdw/expected/postgres_fdw.out
+++ b/contrib/postgres_fdw/expected/postgres_fdw.out
@@ -3774,15 +3774,19 @@ ORDER BY ref_0."C 1";
          ->  Index Scan using t1_pkey on "S 1"."T 1" ref_0
                Output: ref_0."C 1", ref_0.c2, ref_0.c3, ref_0.c4, ref_0.c5, ref_0.c6, ref_0.c7, ref_0.c8
                Index Cond: (ref_0."C 1" < 10)
-         ->  Foreign Scan on public.ft1 ref_1
-               Output: ref_1.c3, ref_0.c2
-               Remote SQL: SELECT c3 FROM "S 1"."T 1" WHERE ((c3 = '00001'))
+         ->  Memoize
+               Output: ref_1.c3, (ref_0.c2)
+               Cache Key: ref_0.c2
+               Cache Mode: binary
+               ->  Foreign Scan on public.ft1 ref_1
+                     Output: ref_1.c3, ref_0.c2
+                     Remote SQL: SELECT c3 FROM "S 1"."T 1" WHERE ((c3 = '00001'))
    ->  Materialize
          Output: ref_3.c3
          ->  Foreign Scan on public.ft2 ref_3
                Output: ref_3.c3
                Remote SQL: SELECT c3 FROM "S 1"."T 1" WHERE ((c3 = '00001'))
-(15 rows)
+(19 rows)
 
 SELECT ref_0.c2, subq_1.*
 FROM
diff --git a/src/backend/optimizer/path/joinpath.c b/src/backend/optimizer/path/joinpath.c
index 5be8da9e09..eb61bb096d 100644
--- a/src/backend/optimizer/path/joinpath.c
+++ b/src/backend/optimizer/path/joinpath.c
@@ -23,6 +23,7 @@
 #include "optimizer/optimizer.h"
 #include "optimizer/pathnode.h"
 #include "optimizer/paths.h"
+#include "optimizer/placeholder.h"
 #include "optimizer/planmain.h"
 #include "utils/typcache.h"
 
@@ -438,10 +439,11 @@ have_unsafe_outer_join_ref(PlannerInfo *root,
 static bool
 paraminfo_get_equal_hashops(PlannerInfo *root, ParamPathInfo *param_info,
 							RelOptInfo *outerrel, RelOptInfo *innerrel,
-							List **param_exprs, List **operators,
-							bool *binary_mode)
+							List *ph_lateral_vars, List **param_exprs,
+							List **operators, bool *binary_mode)
 
 {
+	List	   *lateral_vars;
 	ListCell   *lc;
 
 	*param_exprs = NIL;
@@ -521,7 +523,8 @@ paraminfo_get_equal_hashops(PlannerInfo *root, ParamPathInfo *param_info,
 	}
 
 	/* Now add any lateral vars to the cache key too */
-	foreach(lc, innerrel->lateral_vars)
+	lateral_vars = list_concat(ph_lateral_vars, innerrel->lateral_vars);
+	foreach(lc, lateral_vars)
 	{
 		Node	   *expr = (Node *) lfirst(lc);
 		TypeCacheEntry *typentry;
@@ -572,10 +575,88 @@ paraminfo_get_equal_hashops(PlannerInfo *root, ParamPathInfo *param_info,
 	return true;
 }
 
+/*
+ * extract_lateral_vars_from_PHVs
+ *	  Extract lateral references within PlaceHolderVars that are due to be
+ *	  evaluated at 'innerrelids'.
+ */
+static List *
+extract_lateral_vars_from_PHVs(PlannerInfo *root, Relids innerrelids)
+{
+	List	   *ph_lateral_vars = NIL;
+	ListCell   *lc;
+
+	/* Nothing would be found if the query contains no LATERAL RTEs */
+	if (!root->hasLateralRTEs)
+		return NIL;
+
+	/*
+	 * No need to consider PHVs that are due to be evaluated at joinrels, since
+	 * we do not add Memoize nodes on top of joinrel paths.
+	 */
+	if (bms_membership(innerrelids) == BMS_MULTIPLE)
+		return NIL;
+
+	foreach(lc, root->placeholder_list)
+	{
+		PlaceHolderInfo *phinfo = (PlaceHolderInfo *) lfirst(lc);
+		List	   *vars;
+		ListCell   *cell;
+
+		/* PHV is uninteresting if no lateral refs */
+		if (phinfo->ph_lateral == NULL)
+			continue;
+
+		/* PHV is uninteresting if not due to be evaluated at innerrelids */
+		if (!bms_equal(phinfo->ph_eval_at, innerrelids))
+			continue;
+
+		/* Fetch Vars and PHVs of lateral references within PlaceHolderVars */
+		vars = pull_vars_of_level((Node *) phinfo->ph_var->phexpr, 0);
+		foreach(cell, vars)
+		{
+			Node	   *node = (Node *) lfirst(cell);
+
+			node = copyObject(node);
+			if (IsA(node, Var))
+			{
+				Var		   *var = (Var *) node;
+
+				Assert(var->varlevelsup == 0);
+
+				if (bms_is_member(var->varno, phinfo->ph_lateral))
+					ph_lateral_vars = lappend(ph_lateral_vars, node);
+			}
+			else if (IsA(node, PlaceHolderVar))
+			{
+				PlaceHolderVar *phv = (PlaceHolderVar *) node;
+
+				Assert(phv->phlevelsup == 0);
+
+				if (bms_is_subset(find_placeholder_info(root, phv)->ph_eval_at,
+								  phinfo->ph_lateral))
+					ph_lateral_vars = lappend(ph_lateral_vars, node);
+			}
+			else
+				Assert(false);
+		}
+
+		list_free(vars);
+	}
+
+	return ph_lateral_vars;
+}
+
 /*
  * get_memoize_path
  *		If possible, make and return a Memoize path atop of 'inner_path'.
  *		Otherwise return NULL.
+ *
+ * Note that currently we do not add Memoize nodes on top of join relation
+ * paths.  This is because the ParamPathInfos for join relation paths do not
+ * maintain ppi_clauses, as the set of relevant clauses varies depending on how
+ * the join is formed.  In addition, joinrels do not maintain lateral_vars.  So
+ * we do not have a way to extract cache keys from joinrels.
  */
 static Path *
 get_memoize_path(PlannerInfo *root, RelOptInfo *innerrel,
@@ -587,6 +668,7 @@ get_memoize_path(PlannerInfo *root, RelOptInfo *innerrel,
 	List	   *hash_operators;
 	ListCell   *lc;
 	bool		binary_mode;
+	List	   *ph_lateral_vars;
 
 	/* Obviously not if it's disabled */
 	if (!enable_memoize)
@@ -601,6 +683,12 @@ get_memoize_path(PlannerInfo *root, RelOptInfo *innerrel,
 	if (outer_path->parent->rows < 2)
 		return NULL;
 
+	/*
+	 * Extract lateral Vars within PlaceHolderVars that are due to be evaluated
+	 * at innerrel.  These lateral Vars could be used as memoize cache keys.
+	 */
+	ph_lateral_vars = extract_lateral_vars_from_PHVs(root, innerrel->relids);
+
 	/*
 	 * We can only have a memoize node when there's some kind of cache key,
 	 * either parameterized path clauses or lateral Vars.  No cache key sounds
@@ -608,7 +696,8 @@ get_memoize_path(PlannerInfo *root, RelOptInfo *innerrel,
 	 */
 	if ((inner_path->param_info == NULL ||
 		 inner_path->param_info->ppi_clauses == NIL) &&
-		innerrel->lateral_vars == NIL)
+		innerrel->lateral_vars == NIL &&
+		ph_lateral_vars == NIL)
 		return NULL;
 
 	/*
@@ -695,6 +784,7 @@ get_memoize_path(PlannerInfo *root, RelOptInfo *innerrel,
 									outerrel->top_parent ?
 									outerrel->top_parent : outerrel,
 									innerrel,
+									ph_lateral_vars,
 									&param_exprs,
 									&hash_operators,
 									&binary_mode))
diff --git a/src/test/regress/expected/memoize.out b/src/test/regress/expected/memoize.out
index 0fd103c06b..4c2a37bf97 100644
--- a/src/test/regress/expected/memoize.out
+++ b/src/test/regress/expected/memoize.out
@@ -129,6 +129,69 @@ WHERE t1.unique1 < 10;
     20 | 0.50000000000000000000
 (1 row)
 
+-- Try with LATERAL references within PlaceHolderVars at a baserel
+SELECT explain_memoize('
+SELECT COUNT(*), AVG(t1.twenty) FROM tenk1 t1 LEFT JOIN
+LATERAL (SELECT t1.twenty AS c1, t2.unique1 AS c2 FROM tenk1 t2) s ON TRUE
+WHERE s.c1 = s.c2 AND t1.unique1 < 1000;', false);
+                                      explain_memoize                                      
+-------------------------------------------------------------------------------------------
+ Aggregate (actual rows=1 loops=N)
+   ->  Nested Loop (actual rows=1000 loops=N)
+         ->  Seq Scan on tenk1 t1 (actual rows=1000 loops=N)
+               Filter: (unique1 < 1000)
+               Rows Removed by Filter: 9000
+         ->  Memoize (actual rows=1 loops=N)
+               Cache Key: t1.twenty
+               Cache Mode: binary
+               Hits: 980  Misses: 20  Evictions: Zero  Overflows: 0  Memory Usage: NkB
+               ->  Index Only Scan using tenk1_unique1 on tenk1 t2 (actual rows=1 loops=N)
+                     Filter: (t1.twenty = unique1)
+                     Rows Removed by Filter: 9999
+                     Heap Fetches: N
+(13 rows)
+
+-- And check we get the expected results.
+SELECT COUNT(*), AVG(t1.twenty) FROM tenk1 t1 LEFT JOIN
+LATERAL (SELECT t1.twenty AS c1, t2.unique1 AS c2 FROM tenk1 t2) s ON TRUE
+WHERE s.c1 = s.c2 AND t1.unique1 < 1000;
+ count |        avg         
+-------+--------------------
+  1000 | 9.5000000000000000
+(1 row)
+
+-- Ensure we do not omit the cache keys from PlaceHolderVars
+SELECT explain_memoize('
+SELECT COUNT(*), AVG(t1.twenty) FROM tenk1 t1 LEFT JOIN
+LATERAL (SELECT t1.twenty AS c1, t2.unique1 AS c2, t2.two FROM tenk1 t2) s
+ON t1.two = s.two
+WHERE s.c1 = s.c2 AND t1.unique1 < 1000;', false);
+                                    explain_memoize                                    
+---------------------------------------------------------------------------------------
+ Aggregate (actual rows=1 loops=N)
+   ->  Nested Loop (actual rows=1000 loops=N)
+         ->  Seq Scan on tenk1 t1 (actual rows=1000 loops=N)
+               Filter: (unique1 < 1000)
+               Rows Removed by Filter: 9000
+         ->  Memoize (actual rows=1 loops=N)
+               Cache Key: t1.two, t1.twenty
+               Cache Mode: binary
+               Hits: 980  Misses: 20  Evictions: Zero  Overflows: 0  Memory Usage: NkB
+               ->  Seq Scan on tenk1 t2 (actual rows=1 loops=N)
+                     Filter: ((t1.twenty = unique1) AND (t1.two = two))
+                     Rows Removed by Filter: 9999
+(12 rows)
+
+-- And check we get the expected results.
+SELECT COUNT(*), AVG(t1.twenty) FROM tenk1 t1 LEFT JOIN
+LATERAL (SELECT t1.twenty AS c1, t2.unique1 AS c2, t2.two FROM tenk1 t2) s
+ON t1.two = s.two
+WHERE s.c1 = s.c2 AND t1.unique1 < 1000;
+ count |        avg         
+-------+--------------------
+  1000 | 9.5000000000000000
+(1 row)
+
 SET enable_mergejoin TO off;
 -- Test for varlena datatype with expr evaluation
 CREATE TABLE expr_key (x numeric, t text);
diff --git a/src/test/regress/sql/memoize.sql b/src/test/regress/sql/memoize.sql
index e00e1a94a8..ff1d771bd2 100644
--- a/src/test/regress/sql/memoize.sql
+++ b/src/test/regress/sql/memoize.sql
@@ -74,6 +74,30 @@ LATERAL (
 ON t1.two = t2.two
 WHERE t1.unique1 < 10;
 
+-- Try with LATERAL references within PlaceHolderVars at a baserel
+SELECT explain_memoize('
+SELECT COUNT(*), AVG(t1.twenty) FROM tenk1 t1 LEFT JOIN
+LATERAL (SELECT t1.twenty AS c1, t2.unique1 AS c2 FROM tenk1 t2) s ON TRUE
+WHERE s.c1 = s.c2 AND t1.unique1 < 1000;', false);
+
+-- And check we get the expected results.
+SELECT COUNT(*), AVG(t1.twenty) FROM tenk1 t1 LEFT JOIN
+LATERAL (SELECT t1.twenty AS c1, t2.unique1 AS c2 FROM tenk1 t2) s ON TRUE
+WHERE s.c1 = s.c2 AND t1.unique1 < 1000;
+
+-- Ensure we do not omit the cache keys from PlaceHolderVars
+SELECT explain_memoize('
+SELECT COUNT(*), AVG(t1.twenty) FROM tenk1 t1 LEFT JOIN
+LATERAL (SELECT t1.twenty AS c1, t2.unique1 AS c2, t2.two FROM tenk1 t2) s
+ON t1.two = s.two
+WHERE s.c1 = s.c2 AND t1.unique1 < 1000;', false);
+
+-- And check we get the expected results.
+SELECT COUNT(*), AVG(t1.twenty) FROM tenk1 t1 LEFT JOIN
+LATERAL (SELECT t1.twenty AS c1, t2.unique1 AS c2, t2.two FROM tenk1 t2) s
+ON t1.two = s.two
+WHERE s.c1 = s.c2 AND t1.unique1 < 1000;
+
 SET enable_mergejoin TO off;
 
 -- Test for varlena datatype with expr evaluation
-- 
2.43.0

