From 3476856ac3989ff596ba70066a20f8ea46762295 Mon Sep 17 00:00:00 2001
From: Richard Guo <guofenglinux@gmail.com>
Date: Mon, 5 Dec 2022 17:34:40 +0800
Subject: [PATCH v1] Fix MemoizePath for partitionwise join

---
 src/backend/optimizer/path/joinpath.c |  3 ++-
 src/backend/optimizer/util/pathnode.c |  1 +
 src/test/regress/expected/memoize.out | 38 +++++++++++++++++++++++++++
 src/test/regress/sql/memoize.sql      | 17 ++++++++++++
 4 files changed, 58 insertions(+), 1 deletion(-)

diff --git a/src/backend/optimizer/path/joinpath.c b/src/backend/optimizer/path/joinpath.c
index 2a3f0ab7bf..4d09881259 100644
--- a/src/backend/optimizer/path/joinpath.c
+++ b/src/backend/optimizer/path/joinpath.c
@@ -597,7 +597,8 @@ get_memoize_path(PlannerInfo *root, RelOptInfo *innerrel,
 	/* Check if we have hash ops for each parameter to the path */
 	if (paraminfo_get_equal_hashops(root,
 									inner_path->param_info,
-									outerrel,
+									outerrel->top_parent ?
+									outerrel->top_parent : outerrel,
 									innerrel,
 									&param_exprs,
 									&hash_operators,
diff --git a/src/backend/optimizer/util/pathnode.c b/src/backend/optimizer/util/pathnode.c
index 5379c087a1..55deee555a 100644
--- a/src/backend/optimizer/util/pathnode.c
+++ b/src/backend/optimizer/util/pathnode.c
@@ -4246,6 +4246,7 @@ do { \
 
 				FLAT_COPY_PATH(mpath, path, MemoizePath);
 				REPARAMETERIZE_CHILD_PATH(mpath->subpath);
+				ADJUST_CHILD_ATTRS(mpath->param_exprs);
 				new_path = (Path *) mpath;
 			}
 			break;
diff --git a/src/test/regress/expected/memoize.out b/src/test/regress/expected/memoize.out
index 00438eb1ea..84957d6587 100644
--- a/src/test/regress/expected/memoize.out
+++ b/src/test/regress/expected/memoize.out
@@ -157,6 +157,44 @@ SELECT * FROM flt f1 INNER JOIN flt f2 ON f1.f >= f2.f;', false);
 (10 rows)
 
 DROP TABLE flt;
+CREATE TABLE prt (a int) PARTITION BY RANGE(a);
+CREATE TABLE prt_p1 PARTITION OF prt FOR VALUES FROM (0) TO (10);
+CREATE TABLE prt_p2 PARTITION OF prt FOR VALUES FROM (10) TO (20);
+INSERT INTO prt VALUES (0), (0), (0), (0);
+INSERT INTO prt VALUES (10), (10), (10), (10);
+CREATE INDEX iprt_p1_a ON prt_p1 (a);
+CREATE INDEX iprt_p2_a ON prt_p2 (a);
+ANALYZE prt;
+SET enable_partitionwise_join TO on;
+-- Ensure memoize works for partitionwise join
+SELECT explain_memoize('
+SELECT * FROM prt t1 INNER JOIN prt t2 ON t1.a = t2.a;', false);
+                                     explain_memoize                                      
+------------------------------------------------------------------------------------------
+ Append (actual rows=32 loops=N)
+   ->  Nested Loop (actual rows=16 loops=N)
+         ->  Index Only Scan using iprt_p1_a on prt_p1 t1_1 (actual rows=4 loops=N)
+               Heap Fetches: N
+         ->  Memoize (actual rows=4 loops=N)
+               Cache Key: t1_1.a
+               Cache Mode: logical
+               Hits: 3  Misses: 1  Evictions: Zero  Overflows: 0  Memory Usage: NkB
+               ->  Index Only Scan using iprt_p1_a on prt_p1 t2_1 (actual rows=4 loops=N)
+                     Index Cond: (a = t1_1.a)
+                     Heap Fetches: N
+   ->  Nested Loop (actual rows=16 loops=N)
+         ->  Index Only Scan using iprt_p2_a on prt_p2 t1_2 (actual rows=4 loops=N)
+               Heap Fetches: N
+         ->  Memoize (actual rows=4 loops=N)
+               Cache Key: t1_2.a
+               Cache Mode: logical
+               Hits: 3  Misses: 1  Evictions: Zero  Overflows: 0  Memory Usage: NkB
+               ->  Index Only Scan using iprt_p2_a on prt_p2 t2_2 (actual rows=4 loops=N)
+                     Index Cond: (a = t1_2.a)
+                     Heap Fetches: N
+(21 rows)
+
+DROP TABLE prt;
 -- Exercise Memoize in binary mode with a large fixed width type and a
 -- varlena type.
 CREATE TABLE strtest (n name, t text);
diff --git a/src/test/regress/sql/memoize.sql b/src/test/regress/sql/memoize.sql
index 0979bcdf76..6d11d18eda 100644
--- a/src/test/regress/sql/memoize.sql
+++ b/src/test/regress/sql/memoize.sql
@@ -84,6 +84,23 @@ SELECT * FROM flt f1 INNER JOIN flt f2 ON f1.f >= f2.f;', false);
 
 DROP TABLE flt;
 
+CREATE TABLE prt (a int) PARTITION BY RANGE(a);
+CREATE TABLE prt_p1 PARTITION OF prt FOR VALUES FROM (0) TO (10);
+CREATE TABLE prt_p2 PARTITION OF prt FOR VALUES FROM (10) TO (20);
+INSERT INTO prt VALUES (0), (0), (0), (0);
+INSERT INTO prt VALUES (10), (10), (10), (10);
+CREATE INDEX iprt_p1_a ON prt_p1 (a);
+CREATE INDEX iprt_p2_a ON prt_p2 (a);
+ANALYZE prt;
+
+SET enable_partitionwise_join TO on;
+
+-- Ensure memoize works for partitionwise join
+SELECT explain_memoize('
+SELECT * FROM prt t1 INNER JOIN prt t2 ON t1.a = t2.a;', false);
+
+DROP TABLE prt;
+
 -- Exercise Memoize in binary mode with a large fixed width type and a
 -- varlena type.
 CREATE TABLE strtest (n name, t text);
-- 
2.31.0

