From a46d0b0d4dbf4f8474cbb8a5047ff955ceca5759 Mon Sep 17 00:00:00 2001
From: amit <amitlangote09@gmail.com>
Date: Thu, 18 Jul 2019 10:33:20 +0900
Subject: [PATCH 3/4] Fix partitionwise join to handle FULL JOINs correctly

---
 src/backend/optimizer/util/relnode.c         | 104 +++++++++++++++++----
 src/test/regress/expected/partition_join.out | 129 +++++++++++++++++++++++++++
 src/test/regress/sql/partition_join.sql      |  24 +++++
 3 files changed, 241 insertions(+), 16 deletions(-)

diff --git a/src/backend/optimizer/util/relnode.c b/src/backend/optimizer/util/relnode.c
index 07ece2f870..ac34aed0e0 100644
--- a/src/backend/optimizer/util/relnode.c
+++ b/src/backend/optimizer/util/relnode.c
@@ -72,6 +72,7 @@ static bool have_partkey_equi_join(RelOptInfo *joinrel,
 								   JoinType jointype, List *restrictlist);
 static int match_expr_to_partition_keys(Expr *expr, RelOptInfo *rel,
 										bool strict_op);
+static List *extract_coalesce_args(Expr *expr);
 
 
 /*
@@ -1964,6 +1965,8 @@ static int
 match_expr_to_partition_keys(Expr *expr, RelOptInfo *rel, bool strict_op)
 {
 	int			cnt;
+	int			matched = -1;
+	List	   *nullable_exprs;
 
 	/* This function should be called only for partitioned relations. */
 	Assert(rel->part_scheme);
@@ -1972,34 +1975,103 @@ match_expr_to_partition_keys(Expr *expr, RelOptInfo *rel, bool strict_op)
 	while (IsA(expr, RelabelType))
 		expr = (Expr *) (castNode(RelabelType, expr))->arg;
 
+	/* For PlaceHolderVars, refer to contained expression. */
+	if (IsA(expr, PlaceHolderVar))
+		expr = (castNode(PlaceHolderVar, expr))->phexpr;
+
+	/*
+	 * Extract the arguments from possibly nested COALESCE expressions.  Each
+	 * of these arguments could be null when joining, so these expressions are
+	 * called as such and are to be matched only with the nullable partition
+	 * keys.
+	 */
+	if (IsA(expr, CoalesceExpr))
+		nullable_exprs = extract_coalesce_args(expr);
+	else
+		/*
+		 * expr may or may not be nullable but add to the list anyway to
+		 * simplify the coding below.
+		 */
+		nullable_exprs = list_make1(expr);
+
 	for (cnt = 0; cnt < rel->part_scheme->partnatts; cnt++)
 	{
-		ListCell   *lc;
-
 		Assert(rel->partexprs);
-		foreach(lc, rel->partexprs[cnt])
+
+		/* Is the expression one of the non-nullable partition keys? */
+		if (list_member(rel->partexprs[cnt], expr))
 		{
-			if (equal(lfirst(lc), expr))
-				return cnt;
+			matched = cnt;
+			break;
 		}
 
+		/*
+		 * Nope, so check if it is one of the nullable keys.  Allowing
+		 * nullable keys won't work if the join operator is not strict,
+		 * because null partition keys may then join with rows from other
+		 * partitions.  XXX - would that ever be true if the operator is
+		 * already determined to be mergejoin- and hashjoin-able?
+		 */
 		if (!strict_op)
 			continue;
 
-		/*
-		 * If it's a strict equi-join a NULL partition key on one side will
-		 * not join a NULL partition key on the other side. So, rows with NULL
-		 * partition key from a partition on one side can not join with those
-		 * from a non-matching partition on the other side. So, search the
-		 * nullable partition keys as well.
-		 */
+		/* OK to match with nullable keys. */
 		Assert(rel->nullable_partexprs);
-		foreach(lc, rel->nullable_partexprs[cnt])
+
+		/* First rule out nullable_exprs containing non-key expressions. */
+		if (list_difference(nullable_exprs,
+							rel->nullable_partexprs[cnt]) != NIL)
+			continue;
+
+		if (list_intersection(rel->nullable_partexprs[cnt],
+							  nullable_exprs) != NIL)
 		{
-			if (equal(lfirst(lc), expr))
-				return cnt;
+			matched = cnt;
+			break;
 		}
 	}
 
-	return -1;
+	Assert(list_length(nullable_exprs) >= 1);
+	list_free(nullable_exprs);
+
+	return matched;
+}
+
+/*
+ * extract_coalesce_args
+ *		Extract all arguments from arbitrarily nested CoalesceExpr's
+ *
+ * Note: caller should free the List structure when done using it.
+ */
+static List *
+extract_coalesce_args(Expr *expr)
+{
+	List   *coalesce_args = NIL;
+
+	while (expr && IsA(expr, CoalesceExpr))
+	{
+		CoalesceExpr *cexpr = (CoalesceExpr *) expr;
+		ListCell *lc;
+
+		expr = NULL;
+		foreach(lc, cexpr->args)
+		{
+			Expr   *expr = lfirst(lc);
+
+			/* Remove any relabel decorations. */
+			while (IsA(expr, RelabelType))
+				expr = (Expr *) (castNode(RelabelType, expr))->arg;
+
+			/* For PlaceHolderVars, refer to contained expression. */
+			if (IsA(expr, PlaceHolderVar))
+				expr = (castNode(PlaceHolderVar, expr))->phexpr;
+
+			if (!IsA(expr, CoalesceExpr))
+				coalesce_args = lappend(coalesce_args, expr);
+		}
+
+		Assert(expr == NULL || IsA(expr, CoalesceExpr));
+	}
+
+	return coalesce_args;
 }
diff --git a/src/test/regress/expected/partition_join.out b/src/test/regress/expected/partition_join.out
index 975bf6765c..e8388eedb6 100644
--- a/src/test/regress/expected/partition_join.out
+++ b/src/test/regress/expected/partition_join.out
@@ -750,6 +750,135 @@ SELECT t1.a, t1.c, t2.b, t2.c, t3.a + t3.b, t3.c FROM (prt1 t1 LEFT JOIN prt2 t2
  550 | 0550 |     |      |     1100 | 0
 (12 rows)
 
+-- FULL JOIN with COALESCE expression
+SET enable_partitionwise_aggregate TO true;
+EXPLAIN (COSTS OFF)
+SELECT a, b FROM prt1 FULL JOIN prt2 p2(b,a,c) USING(a,b) FULL JOIN prt2 p3(b,a,c) USING (a, b)
+  WHERE a BETWEEN 490 AND 510
+  GROUP BY 1, 2 ORDER BY 1, 2;
+                                                                      QUERY PLAN                                                                       
+-------------------------------------------------------------------------------------------------------------------------------------------------------
+ Group
+   Group Key: (COALESCE(COALESCE(prt1_p1.a, p2.a), p3.a)), (COALESCE(COALESCE(prt1_p1.b, p2.b), p3.b))
+   ->  Merge Append
+         Sort Key: (COALESCE(COALESCE(prt1_p1.a, p2.a), p3.a)), (COALESCE(COALESCE(prt1_p1.b, p2.b), p3.b))
+         ->  Group
+               Group Key: (COALESCE(COALESCE(prt1_p1.a, p2.a), p3.a)), (COALESCE(COALESCE(prt1_p1.b, p2.b), p3.b))
+               ->  Sort
+                     Sort Key: (COALESCE(COALESCE(prt1_p1.a, p2.a), p3.a)), (COALESCE(COALESCE(prt1_p1.b, p2.b), p3.b))
+                     ->  Hash Full Join
+                           Hash Cond: ((COALESCE(prt1_p1.a, p2.a) = p3.a) AND (COALESCE(prt1_p1.b, p2.b) = p3.b))
+                           Filter: ((COALESCE(COALESCE(prt1_p1.a, p2.a), p3.a) >= 490) AND (COALESCE(COALESCE(prt1_p1.a, p2.a), p3.a) <= 510))
+                           ->  Hash Full Join
+                                 Hash Cond: ((prt1_p1.a = p2.a) AND (prt1_p1.b = p2.b))
+                                 ->  Seq Scan on prt1_p1
+                                 ->  Hash
+                                       ->  Seq Scan on prt2_p1 p2
+                           ->  Hash
+                                 ->  Seq Scan on prt2_p1 p3
+         ->  Group
+               Group Key: (COALESCE(COALESCE(prt1_p2.a, p2_1.a), p3_1.a)), (COALESCE(COALESCE(prt1_p2.b, p2_1.b), p3_1.b))
+               ->  Sort
+                     Sort Key: (COALESCE(COALESCE(prt1_p2.a, p2_1.a), p3_1.a)), (COALESCE(COALESCE(prt1_p2.b, p2_1.b), p3_1.b))
+                     ->  Hash Full Join
+                           Hash Cond: ((COALESCE(prt1_p2.a, p2_1.a) = p3_1.a) AND (COALESCE(prt1_p2.b, p2_1.b) = p3_1.b))
+                           Filter: ((COALESCE(COALESCE(prt1_p2.a, p2_1.a), p3_1.a) >= 490) AND (COALESCE(COALESCE(prt1_p2.a, p2_1.a), p3_1.a) <= 510))
+                           ->  Hash Full Join
+                                 Hash Cond: ((prt1_p2.a = p2_1.a) AND (prt1_p2.b = p2_1.b))
+                                 ->  Seq Scan on prt1_p2
+                                 ->  Hash
+                                       ->  Seq Scan on prt2_p2 p2_1
+                           ->  Hash
+                                 ->  Seq Scan on prt2_p2 p3_1
+         ->  Group
+               Group Key: (COALESCE(COALESCE(prt1_p3.a, p2_2.a), p3_2.a)), (COALESCE(COALESCE(prt1_p3.b, p2_2.b), p3_2.b))
+               ->  Sort
+                     Sort Key: (COALESCE(COALESCE(prt1_p3.a, p2_2.a), p3_2.a)), (COALESCE(COALESCE(prt1_p3.b, p2_2.b), p3_2.b))
+                     ->  Hash Full Join
+                           Hash Cond: ((COALESCE(prt1_p3.a, p2_2.a) = p3_2.a) AND (COALESCE(prt1_p3.b, p2_2.b) = p3_2.b))
+                           Filter: ((COALESCE(COALESCE(prt1_p3.a, p2_2.a), p3_2.a) >= 490) AND (COALESCE(COALESCE(prt1_p3.a, p2_2.a), p3_2.a) <= 510))
+                           ->  Hash Full Join
+                                 Hash Cond: ((prt1_p3.a = p2_2.a) AND (prt1_p3.b = p2_2.b))
+                                 ->  Seq Scan on prt1_p3
+                                 ->  Hash
+                                       ->  Seq Scan on prt2_p3 p2_2
+                           ->  Hash
+                                 ->  Seq Scan on prt2_p3 p3_2
+(46 rows)
+
+SELECT a, b FROM prt1 FULL JOIN prt2 p2(b,a,c) USING(a,b) FULL JOIN prt2 p3(b,a,c) USING (a, b)
+  WHERE a BETWEEN 490 AND 510
+  GROUP BY 1, 2 ORDER BY 1, 2;
+  a  | b  
+-----+----
+ 490 | 15
+ 492 | 17
+ 494 | 19
+ 495 | 20
+ 496 | 21
+ 498 | 23
+ 500 |  0
+ 501 |  1
+ 502 |  2
+ 504 |  4
+ 506 |  6
+ 507 |  7
+ 508 |  8
+ 510 | 10
+(14 rows)
+
+-- Manually written COALESCE expression containing non-key expression
+EXPLAIN (COSTS OFF)
+SELECT p1.a, p1.b FROM prt1 p1 FULL JOIN prt2 p2(b,a,c) USING(a,b) FULL JOIN prt2 p3(b,a,c) ON COALESCE(p2.b, p3.a) = p3.a
+  WHERE p1.a BETWEEN 490 AND 510
+  GROUP BY 1, 2 ORDER BY 1, 2;
+                                   QUERY PLAN                                   
+--------------------------------------------------------------------------------
+ Group
+   Group Key: p1.a, p1.b
+   ->  Sort
+         Sort Key: p1.a, p1.b
+         ->  Nested Loop Left Join
+               Join Filter: (COALESCE(p2.b, p3.a) = p3.a)
+               ->  Append
+                     ->  Hash Right Join
+                           Hash Cond: ((p2.a = p1.a) AND (p2.b = p1.b))
+                           ->  Seq Scan on prt2_p2 p2
+                           ->  Hash
+                                 ->  Seq Scan on prt1_p2 p1
+                                       Filter: ((a >= 490) AND (a <= 510))
+                     ->  Hash Right Join
+                           Hash Cond: ((p2_1.a = p1_1.a) AND (p2_1.b = p1_1.b))
+                           ->  Seq Scan on prt2_p3 p2_1
+                           ->  Hash
+                                 ->  Seq Scan on prt1_p3 p1_1
+                                       Filter: ((a >= 490) AND (a <= 510))
+               ->  Materialize
+                     ->  Append
+                           ->  Seq Scan on prt2_p1 p3
+                           ->  Seq Scan on prt2_p2 p3_1
+                           ->  Seq Scan on prt2_p3 p3_2
+(24 rows)
+
+SELECT p1.a, p1.b FROM prt1 p1 FULL JOIN prt2 p2(b,a,c) USING(a,b) FULL JOIN prt2 p3(b,a,c) ON COALESCE(p2.b, p3.a) = p3.a
+  WHERE p1.a BETWEEN 490 AND 510
+  GROUP BY 1, 2 ORDER BY 1, 2;
+  a  | b  
+-----+----
+ 490 | 15
+ 492 | 17
+ 494 | 19
+ 496 | 21
+ 498 | 23
+ 500 |  0
+ 502 |  2
+ 504 |  4
+ 506 |  6
+ 508 |  8
+ 510 | 10
+(11 rows)
+
+RESET enable_partitionwise_aggregate;
 -- Cases with non-nullable expressions in subquery results;
 -- make sure these go to null as expected
 EXPLAIN (COSTS OFF)
diff --git a/src/test/regress/sql/partition_join.sql b/src/test/regress/sql/partition_join.sql
index 92994b479b..9f68c5074c 100644
--- a/src/test/regress/sql/partition_join.sql
+++ b/src/test/regress/sql/partition_join.sql
@@ -145,6 +145,29 @@ EXPLAIN (COSTS OFF)
 SELECT t1.a, t1.c, t2.b, t2.c, t3.a + t3.b, t3.c FROM (prt1 t1 LEFT JOIN prt2 t2 ON t1.a = t2.b) RIGHT JOIN prt1_e t3 ON (t1.a = (t3.a + t3.b)/2) WHERE t3.c = 0 ORDER BY t1.a, t2.b, t3.a + t3.b;
 SELECT t1.a, t1.c, t2.b, t2.c, t3.a + t3.b, t3.c FROM (prt1 t1 LEFT JOIN prt2 t2 ON t1.a = t2.b) RIGHT JOIN prt1_e t3 ON (t1.a = (t3.a + t3.b)/2) WHERE t3.c = 0 ORDER BY t1.a, t2.b, t3.a + t3.b;
 
+-- FULL JOIN with COALESCE expression
+
+SET enable_partitionwise_aggregate TO true;
+
+EXPLAIN (COSTS OFF)
+SELECT a, b FROM prt1 FULL JOIN prt2 p2(b,a,c) USING(a,b) FULL JOIN prt2 p3(b,a,c) USING (a, b)
+  WHERE a BETWEEN 490 AND 510
+  GROUP BY 1, 2 ORDER BY 1, 2;
+SELECT a, b FROM prt1 FULL JOIN prt2 p2(b,a,c) USING(a,b) FULL JOIN prt2 p3(b,a,c) USING (a, b)
+  WHERE a BETWEEN 490 AND 510
+  GROUP BY 1, 2 ORDER BY 1, 2;
+
+-- Manually written COALESCE expression containing non-key expression
+EXPLAIN (COSTS OFF)
+SELECT p1.a, p1.b FROM prt1 p1 FULL JOIN prt2 p2(b,a,c) USING(a,b) FULL JOIN prt2 p3(b,a,c) ON COALESCE(p2.b, p3.a) = p3.a
+  WHERE p1.a BETWEEN 490 AND 510
+  GROUP BY 1, 2 ORDER BY 1, 2;
+SELECT p1.a, p1.b FROM prt1 p1 FULL JOIN prt2 p2(b,a,c) USING(a,b) FULL JOIN prt2 p3(b,a,c) ON COALESCE(p2.b, p3.a) = p3.a
+  WHERE p1.a BETWEEN 490 AND 510
+  GROUP BY 1, 2 ORDER BY 1, 2;
+
+RESET enable_partitionwise_aggregate;
+
 -- Cases with non-nullable expressions in subquery results;
 -- make sure these go to null as expected
 EXPLAIN (COSTS OFF)
@@ -285,6 +308,7 @@ EXPLAIN (COSTS OFF)
 SELECT avg(t1.a), avg(t2.b), avg(t3.a + t3.b), t1.c, t2.c, t3.c FROM pht1 t1, pht2 t2, pht1_e t3 WHERE t1.b = t2.b AND t1.c = t2.c AND ltrim(t3.c, 'A') = t1.c GROUP BY t1.c, t2.c, t3.c ORDER BY t1.c, t2.c, t3.c;
 SELECT avg(t1.a), avg(t2.b), avg(t3.a + t3.b), t1.c, t2.c, t3.c FROM pht1 t1, pht2 t2, pht1_e t3 WHERE t1.b = t2.b AND t1.c = t2.c AND ltrim(t3.c, 'A') = t1.c GROUP BY t1.c, t2.c, t3.c ORDER BY t1.c, t2.c, t3.c;
 
+
 -- test default partition behavior for range
 ALTER TABLE prt1 DETACH PARTITION prt1_p3;
 ALTER TABLE prt1 ATTACH PARTITION prt1_p3 DEFAULT;
-- 
2.11.0

