From b104b8562adbd19c785a5c2d5140931f844b1a55 Mon Sep 17 00:00:00 2001
From: Denis Hirn <denis.hirn@uni-tuebingen.de>
Date: Tue, 23 Mar 2021 12:58:36 +0100
Subject: [PATCH] Add SQL-standard multiple linear self-references in WITH
 RECURSIVE

The recursive term of a WITH RECURSIVE query must contain exactly one
self-reference. Multiple such self-references are only allowed if the WITH query
features multiple recursive terms, all of which are joined by UNION [ALL].

Without using explicit parentheses to change the precedence of the parsing,
a left-deep UNION [ALL] tree is created. This allows for multiple non-recursive
terms, and exactly one recursive term. To allow multiple non-recursive terms,
as well as multiple recursive terms, a different grouping of terms is required.
To fix this, tree rotation is added to checkWellFormedRecursion.

Example:
A, B, and C are arbitrary SelectStmt nodes and can be deeper nested UNION nodes.

A is a non-recursive term in the WITH RECURSIVE query. B and C both contain a
recursive self-reference. The planner expects the top UNION node to contain
the non-recursive term in the larg, and the recursive term in the rarg.
Therefore, the tree shape on the left is invalid and would result in an error
message at the parsing stage. However, by rotating the tree to the right, this
problem can be solved so that the valid tree shape on the right side is created.

      UNION   --->   UNION
     /     \        /     \
   UNION    C      A    UNION
  /     \              /     \
 A       B            B       C

Effectively this re-parenthesizes the expression:
(A UNION B) UNION C ---> A UNION (B UNION C)
---
 doc/src/sgml/queries.sgml                |  38 +++++
 src/backend/executor/nodeWorktablescan.c |  21 ++-
 src/backend/parser/parse_cte.c           |  84 +++++++++-
 src/include/nodes/execnodes.h            |   1 +
 src/test/regress/expected/with.out       | 205 ++++++++++++++++++++++-
 src/test/regress/sql/with.sql            | 121 ++++++++++++-
 6 files changed, 450 insertions(+), 20 deletions(-)

diff --git a/doc/src/sgml/queries.sgml b/doc/src/sgml/queries.sgml
index 834b83b509..fcc9131006 100644
--- a/doc/src/sgml/queries.sgml
+++ b/doc/src/sgml/queries.sgml
@@ -2172,6 +2172,44 @@ GROUP BY sub_part
 </programlisting>
   </para>
 
+ <sect3 id="queries-recursive-terms">
+  <title>Multiple Recursive Terms</title>
+
+  <para>
+   The <firstterm>recursive term</firstterm> of a recursive <literal>WITH</literal>
+   query must contain exactly one self-reference. Multiple such self-references
+   are only allowed if the <literal>WITH</literal> query features multiple recursive
+   terms, all of which are joined by <literal>UNION</literal> (or <literal>UNION ALL</literal>).
+   Consider the example query below which computes the transitive closure of travel connections
+   from <literal>'my_city'</literal>. Flight and train routes are found in separate tables
+   <literal>flights(source,dest,carrier)</literal> and <literal>trains(source,dest)</literal>.
+   Two recursive terms, joined by <literal>UNION ALL</literal>, use two recursive
+   self-references to <literal>connections</literal> to find viable onward journeys,
+   in both the <literal>flights</literal> and <literal>trains</literal> tables.
+
+<programlisting>
+WITH RECURSIVE connections(source, dest, carrier) AS (
+     SELECT f.source, f.dest, f.carrier
+      FROM flights f
+      WHERE f.source = 'my_city'
+    UNION ALL
+      SELECT r.source, r.dest, 'Rail' AS carrier
+      FROM trains r
+      WHERE r.source = 'my_city'
+  UNION ALL -- two recursive terms below
+     SELECT c.source, f.dest, f.carrier
+      FROM <emphasis>connections c</emphasis>, flights f
+      WHERE c.dest = f.source
+    UNION ALL
+      SELECT c.source, r.dest, 'Rail' AS carrier
+      FROM <emphasis>connections c</emphasis>, trains r
+      WHERE c.dest = r.source
+)
+SELECT * FROM connections;
+</programlisting>
+  </para>
+ </sect3>
+
   <sect3 id="queries-with-search">
    <title>Search Order</title>
 
diff --git a/src/backend/executor/nodeWorktablescan.c b/src/backend/executor/nodeWorktablescan.c
index 91d3bf376b..c73d23d1de 100644
--- a/src/backend/executor/nodeWorktablescan.c
+++ b/src/backend/executor/nodeWorktablescan.c
@@ -42,10 +42,6 @@ WorkTableScanNext(WorkTableScanState *node)
 	 * worktable plan node, since it cannot appear high enough in the plan
 	 * tree of a scrollable cursor to be exposed to a backward-scan
 	 * requirement.  So it's not worth expending effort to support it.
-	 *
-	 * Note: we are also assuming that this node is the only reader of the
-	 * worktable.  Therefore, we don't need a private read pointer for the
-	 * tuplestore, nor do we need to tell tuplestore_gettupleslot to copy.
 	 */
 	Assert(ScanDirectionIsForward(node->ss.ps.state->es_direction));
 
@@ -55,6 +51,7 @@ WorkTableScanNext(WorkTableScanState *node)
 	 * Get the next tuple from tuplestore. Return NULL if no more tuples.
 	 */
 	slot = node->ss.ss_ScanTupleSlot;
+	tuplestore_select_read_pointer(tuplestorestate, node->readptr);
 	(void) tuplestore_gettupleslot(tuplestorestate, true, false, slot);
 	return slot;
 }
@@ -99,7 +96,14 @@ ExecWorkTableScan(PlanState *pstate)
 		Assert(!param->isnull);
 		node->rustate = castNode(RecursiveUnionState, DatumGetPointer(param->value));
 		Assert(node->rustate);
+		/*
+		 * Allocate a unique read pointer for each worktable scan.
+		 */
+		node->readptr = tuplestore_alloc_read_pointer(node->rustate->working_table, 0);
+		tuplestore_copy_read_pointer(node->rustate->working_table, 0, node->readptr);
+		tuplestore_rescan(node->rustate->working_table);
 
+		Assert(node->readptr != -1);
 		/*
 		 * The scan tuple type (ie, the rowtype we expect to find in the work
 		 * table) is the same as the result rowtype of the ancestor
@@ -147,6 +151,7 @@ ExecInitWorkTableScan(WorkTableScan *node, EState *estate, int eflags)
 	scanstate->ss.ps.plan = (Plan *) node;
 	scanstate->ss.ps.state = estate;
 	scanstate->ss.ps.ExecProcNode = ExecWorkTableScan;
+	scanstate->readptr = -1;	/* we'll set this later */
 	scanstate->rustate = NULL;	/* we'll set this later */
 
 	/*
@@ -219,5 +224,13 @@ ExecReScanWorkTableScan(WorkTableScanState *node)
 
 	/* No need (or way) to rescan if ExecWorkTableScan not called yet */
 	if (node->rustate)
+	{
+		/* Make sure to select the initial read pointer */
+		tuplestore_select_read_pointer(node->rustate->working_table, 0);
+		tuplestore_rescan(node->rustate->working_table);
+		/* Reallocate a unique read pointer. */
+		node->readptr = tuplestore_alloc_read_pointer(node->rustate->working_table, 0);
+		tuplestore_copy_read_pointer(node->rustate->working_table, 0, node->readptr);
 		tuplestore_rescan(node->rustate->working_table);
+	}
 }
diff --git a/src/backend/parser/parse_cte.c b/src/backend/parser/parse_cte.c
index f6ae96333a..b78a630367 100644
--- a/src/backend/parser/parse_cte.c
+++ b/src/backend/parser/parse_cte.c
@@ -80,6 +80,8 @@ typedef struct CteState
 	/* working state for checkWellFormedRecursion walk only: */
 	int			selfrefcount;	/* number of self-references detected */
 	RecursionContext context;	/* context to allow or disallow self-ref */
+	bool		root_is_union;	/* root of non-recursive term is SETOP_UNION */
+	bool		rotate;			/* self-reference in non-recursive term detected */
 } CteState;
 
 
@@ -853,11 +855,51 @@ checkWellFormedRecursion(CteState *cstate)
 					 parser_errposition(cstate->pstate, cte->location)));
 
 		/* The left-hand operand mustn't contain self-reference at all */
-		cstate->curitem = i;
-		cstate->innerwiths = NIL;
-		cstate->selfrefcount = 0;
-		cstate->context = RECURSION_NONRECURSIVETERM;
-		checkWellFormedRecursionWalker((Node *) stmt->larg, cstate);
+		do {
+			cstate->curitem = i;
+			cstate->innerwiths = NIL;
+			cstate->selfrefcount = 0;
+			cstate->context = RECURSION_NONRECURSIVETERM;
+			cstate->root_is_union = stmt->larg->op == SETOP_UNION;
+			cstate->rotate = false;
+			checkWellFormedRecursionWalker((Node *) stmt->larg, cstate);
+
+			/* The walker sets the rotate flag if it found a self-reference in
+			 * the non-recursive term; in that case, rotate the tree and retry.
+			 */
+			if (cstate->root_is_union && cstate->rotate)
+			{
+				/* By default, the parser creates a left-deep UNION [ALL] tree.
+				 * Usage of multiple recursive terms requires a different
+				 * grouping of the terms. We have to perform a tree rotation to
+				 * the right to fix the structure. Example:
+				 *
+				 * A is a non-recursive term. B and C both contain a recursive
+				 * reference.
+				 *
+				 *      UNION   --->   UNION
+				 *     /     \        /     \
+				 *   UNION    C      A    UNION
+				 *  /     \              /     \
+				 * A       B            B       C
+				 *
+				 * NOTE: A, B, and C are arbitrary SelectStmt nodes.
+				 */
+				SelectStmt *rarg = stmt->rarg;
+				bool all = stmt->larg->all;
+
+				stmt->rarg = stmt->larg;
+				stmt->larg = stmt->larg->larg;
+				stmt->rarg->larg = stmt->rarg->rarg;
+				stmt->rarg->rarg = rarg;
+				/*
+				 * Make sure that the UNION [ALL] flag is set correctly.
+				 */
+				stmt->rarg->all = stmt->all;
+				stmt->all = all;
+			}
+		} while (cstate->root_is_union && cstate->rotate);
+
 		Assert(cstate->innerwiths == NIL);
 
 		/* Right-hand operand should contain one reference in a valid place */
@@ -955,6 +997,12 @@ checkWellFormedRecursionWalker(Node *node, CteState *cstate)
 			mycte = cstate->items[cstate->curitem].cte;
 			if (strcmp(rv->relname, mycte->ctename) == 0)
 			{
+				/* Found a recursive reference, but it might be fixable */
+				if (cstate->context == RECURSION_NONRECURSIVETERM && cstate->root_is_union)
+				{
+					cstate->rotate = true;
+					return false;
+				}
 				/* Found a recursive reference to the active query */
 				if (cstate->context != RECURSION_OK)
 					ereport(ERROR,
@@ -1117,9 +1165,29 @@ checkWellFormedSelectStmt(SelectStmt *stmt, CteState *cstate)
 		{
 			case SETOP_NONE:
 			case SETOP_UNION:
-				raw_expression_tree_walker((Node *) stmt,
-										   checkWellFormedRecursionWalker,
-										   (void *) cstate);
+				/* check selfrefcount for each recursive member individually */
+				if (stmt->larg != NULL && stmt->rarg != NULL)
+				{
+					int curr_selfrefcount;
+					int selfrefcount = cstate->selfrefcount;
+
+					checkWellFormedRecursionWalker((Node *) stmt->larg, cstate);
+
+					/* Restore selfrefcount to allow multiple linear recursive references */
+					curr_selfrefcount = cstate->selfrefcount;
+					cstate->selfrefcount = selfrefcount;
+
+					checkWellFormedRecursionWalker((Node *) stmt->rarg, cstate);
+					/* Recursive anchors can contain recursive references, but don't have to. */
+					if (cstate->selfrefcount < curr_selfrefcount)
+						cstate->selfrefcount = curr_selfrefcount;
+				}
+				else
+				{
+					raw_expression_tree_walker((Node *) stmt,
+											   checkWellFormedRecursionWalker,
+											   (void *) cstate);
+				}
 				break;
 			case SETOP_INTERSECT:
 				if (stmt->all)
diff --git a/src/include/nodes/execnodes.h b/src/include/nodes/execnodes.h
index 37cb4f3d59..b53844437c 100644
--- a/src/include/nodes/execnodes.h
+++ b/src/include/nodes/execnodes.h
@@ -1843,6 +1843,7 @@ typedef struct NamedTuplestoreScanState
 typedef struct WorkTableScanState
 {
 	ScanState	ss;				/* its first field is NodeTag */
+	int		readptr;			/* index of my tuplestore read pointer */
 	RecursiveUnionState *rustate;
 } WorkTableScanState;
 
diff --git a/src/test/regress/expected/with.out b/src/test/regress/expected/with.out
index 3523a7dcc1..008106b5ea 100644
--- a/src/test/regress/expected/with.out
+++ b/src/test/regress/expected/with.out
@@ -235,6 +235,199 @@ WITH RECURSIVE outermost(x) AS (
  7
 (7 rows)
 
+-- Test multiple self-references in different recursive anchors
+WITH RECURSIVE foo(i) AS
+    (values (1)
+    UNION ALL
+       (SELECT i+1 FROM foo WHERE i < 5
+          UNION ALL
+       SELECT i+1 FROM foo WHERE i < 2)
+) SELECT * FROM foo;
+ i 
+---
+ 1
+ 2
+ 2
+ 3
+ 3
+ 4
+ 4
+ 5
+ 5
+(9 rows)
+
+-- No explicit parentheses
+WITH RECURSIVE foo(i) AS
+    (values (1)
+      UNION ALL
+    SELECT i+1 FROM foo WHERE i < 5
+      UNION ALL
+    SELECT i+1 FROM foo WHERE i < 2
+) SELECT * FROM foo;
+ i 
+---
+ 1
+ 2
+ 2
+ 3
+ 3
+ 4
+ 4
+ 5
+ 5
+(9 rows)
+
+WITH RECURSIVE foo(i) AS
+    (values (1)
+    UNION ALL
+	   SELECT * FROM
+       (SELECT i+1 FROM foo WHERE i < 5
+          UNION ALL
+       SELECT i+1 FROM foo WHERE i < 2) AS t
+) SELECT * FROM foo;
+ i 
+---
+ 1
+ 2
+ 2
+ 3
+ 3
+ 4
+ 4
+ 5
+ 5
+(9 rows)
+
+CREATE TEMP TABLE flights (
+  source  TEXT,
+  dest    TEXT,
+  carrier TEXT
+);
+INSERT INTO flights VALUES
+('A', 'B', 'C1'),
+('A', 'C', 'C2'),
+('A', 'D', 'C1'),
+('B', 'D', 'C3'),
+('C', 'E', 'C3')
+;
+CREATE TEMP TABLE trains (
+  source TEXT,
+  dest   TEXT
+);
+INSERT INTO trains VALUES
+('B', 'C'),
+('A', 'E'),
+('C', 'E')
+;
+WITH RECURSIVE connections(source, dest, carrier) AS (
+     SELECT f.source, f.dest, f.carrier
+      FROM flights f
+      WHERE f.source = 'A'
+    UNION ALL
+      SELECT r.source, r.dest, 'Rail' AS carrier
+      FROM trains r
+      WHERE r.source = 'A'
+  UNION ALL -- two recursive terms below
+     SELECT c.source, f.dest, f.carrier
+      FROM connections c, flights f
+      WHERE c.dest = f.source
+    UNION ALL
+      SELECT c.source, r.dest, 'Rail' AS carrier
+      FROM connections c, trains r
+      WHERE c.dest = r.source
+)
+SELECT * FROM connections;
+ source | dest | carrier 
+--------+------+---------
+ A      | B    | C1
+ A      | C    | C2
+ A      | D    | C1
+ A      | E    | Rail
+ A      | D    | C3
+ A      | E    | C3
+ A      | C    | Rail
+ A      | E    | Rail
+ A      | E    | C3
+ A      | E    | Rail
+(10 rows)
+
+-- Test mixed UNION and UNION ALL with multiple recursive anchors
+WITH RECURSIVE t(x) AS
+(
+  SELECT 2
+    UNION
+  SELECT 1
+    UNION ALL
+  SELECT x+1
+  FROM   t
+  WHERE  x < 4
+    UNION
+  SELECT x*2
+  FROM   t
+  WHERE  x >= 4 AND x < 8
+    UNION ALL
+  SELECT x+1
+  FROM   t
+  WHERE  x >= 4 AND x < 8
+) SELECT * FROM t;
+ x  
+----
+  1
+  2
+  2
+  3
+  4
+  3
+  5
+  4
+  8
+  8
+ 10
+  5
+  6
+  7
+ 10
+  6
+ 12
+  8
+  7
+ 12
+ 14
+ 14
+  8
+(23 rows)
+
+WITH RECURSIVE foo(i) AS
+    (values (1)
+      UNION
+    SELECT i+1 FROM foo WHERE i < 2
+      UNION ALL
+    SELECT i+1 FROM foo WHERE i < 2
+      UNION
+    SELECT i+1 FROM foo WHERE i < 2
+      UNION ALL
+    SELECT i+1 FROM foo WHERE i < 2
+      UNION
+    SELECT i+1 FROM foo WHERE i < 2
+      UNION ALL
+    SELECT i+1 FROM foo WHERE i < 2
+      UNION
+    SELECT i+1 FROM foo WHERE i < 2
+      UNION ALL
+    SELECT i+1 FROM foo WHERE i < 2
+      UNION
+    SELECT i+1 FROM foo WHERE i < 2
+      UNION ALL
+    SELECT i+1 FROM foo WHERE i < 2
+      UNION
+    SELECT i+1 FROM foo WHERE i < 2
+) SELECT * FROM foo;
+ i 
+---
+ 1
+ 2
+(2 rows)
+
 --
 -- Some examples with a tree
 --
@@ -1795,24 +1988,24 @@ LINE 2:   x (id) AS (SELECT 1 UNION ALL SELECT id+1 FROM y WHERE id ...
 WITH RECURSIVE foo(i) AS
     (values (1)
     UNION ALL
-       (SELECT i+1 FROM foo WHERE i < 10
+       (SELECT x.i+y.i FROM foo AS x, foo AS y WHERE x.i < 10
           UNION ALL
        SELECT i+1 FROM foo WHERE i < 5)
 ) SELECT * FROM foo;
 ERROR:  recursive reference to query "foo" must not appear more than once
-LINE 6:        SELECT i+1 FROM foo WHERE i < 5)
-                               ^
+LINE 4:        (SELECT x.i+y.i FROM foo AS x, foo AS y WHERE x.i < 1...
+                                              ^
 WITH RECURSIVE foo(i) AS
     (values (1)
     UNION ALL
 	   SELECT * FROM
-       (SELECT i+1 FROM foo WHERE i < 10
+       (SELECT x.i+y.i FROM foo AS x, foo as y WHERE x.i < 10
           UNION ALL
        SELECT i+1 FROM foo WHERE i < 5) AS t
 ) SELECT * FROM foo;
 ERROR:  recursive reference to query "foo" must not appear more than once
-LINE 7:        SELECT i+1 FROM foo WHERE i < 5) AS t
-                               ^
+LINE 5:        (SELECT x.i+y.i FROM foo AS x, foo as y WHERE x.i < 1...
+                                              ^
 WITH RECURSIVE foo(i) AS
     (values (1)
     UNION ALL
diff --git a/src/test/regress/sql/with.sql b/src/test/regress/sql/with.sql
index 8b213ee408..235726e15b 100644
--- a/src/test/regress/sql/with.sql
+++ b/src/test/regress/sql/with.sql
@@ -139,6 +139,123 @@ WITH RECURSIVE outermost(x) AS (
  )
  SELECT * FROM outermost ORDER BY 1;
 
+-- Test multiple self-references in different recursive anchors
+WITH RECURSIVE foo(i) AS
+    (values (1)
+    UNION ALL
+       (SELECT i+1 FROM foo WHERE i < 5
+          UNION ALL
+       SELECT i+1 FROM foo WHERE i < 2)
+) SELECT * FROM foo;
+
+-- No explicit parentheses
+WITH RECURSIVE foo(i) AS
+    (values (1)
+      UNION ALL
+    SELECT i+1 FROM foo WHERE i < 5
+      UNION ALL
+    SELECT i+1 FROM foo WHERE i < 2
+) SELECT * FROM foo;
+
+WITH RECURSIVE foo(i) AS
+    (values (1)
+    UNION ALL
+	   SELECT * FROM
+       (SELECT i+1 FROM foo WHERE i < 5
+          UNION ALL
+       SELECT i+1 FROM foo WHERE i < 2) AS t
+) SELECT * FROM foo;
+
+CREATE TEMP TABLE flights (
+  source  TEXT,
+  dest    TEXT,
+  carrier TEXT
+);
+
+INSERT INTO flights VALUES
+('A', 'B', 'C1'),
+('A', 'C', 'C2'),
+('A', 'D', 'C1'),
+('B', 'D', 'C3'),
+('C', 'E', 'C3')
+;
+
+CREATE TEMP TABLE trains (
+  source TEXT,
+  dest   TEXT
+);
+
+INSERT INTO trains VALUES
+('B', 'C'),
+('A', 'E'),
+('C', 'E')
+;
+
+WITH RECURSIVE connections(source, dest, carrier) AS (
+     SELECT f.source, f.dest, f.carrier
+      FROM flights f
+      WHERE f.source = 'A'
+    UNION ALL
+      SELECT r.source, r.dest, 'Rail' AS carrier
+      FROM trains r
+      WHERE r.source = 'A'
+  UNION ALL -- two recursive terms below
+     SELECT c.source, f.dest, f.carrier
+      FROM connections c, flights f
+      WHERE c.dest = f.source
+    UNION ALL
+      SELECT c.source, r.dest, 'Rail' AS carrier
+      FROM connections c, trains r
+      WHERE c.dest = r.source
+)
+SELECT * FROM connections;
+
+-- Test mixed UNION and UNION ALL with multiple recursive anchors
+WITH RECURSIVE t(x) AS
+(
+  SELECT 2
+    UNION
+  SELECT 1
+    UNION ALL
+  SELECT x+1
+  FROM   t
+  WHERE  x < 4
+    UNION
+  SELECT x*2
+  FROM   t
+  WHERE  x >= 4 AND x < 8
+    UNION ALL
+  SELECT x+1
+  FROM   t
+  WHERE  x >= 4 AND x < 8
+) SELECT * FROM t;
+
+WITH RECURSIVE foo(i) AS
+    (values (1)
+      UNION
+    SELECT i+1 FROM foo WHERE i < 2
+      UNION ALL
+    SELECT i+1 FROM foo WHERE i < 2
+      UNION
+    SELECT i+1 FROM foo WHERE i < 2
+      UNION ALL
+    SELECT i+1 FROM foo WHERE i < 2
+      UNION
+    SELECT i+1 FROM foo WHERE i < 2
+      UNION ALL
+    SELECT i+1 FROM foo WHERE i < 2
+      UNION
+    SELECT i+1 FROM foo WHERE i < 2
+      UNION ALL
+    SELECT i+1 FROM foo WHERE i < 2
+      UNION
+    SELECT i+1 FROM foo WHERE i < 2
+      UNION ALL
+    SELECT i+1 FROM foo WHERE i < 2
+      UNION
+    SELECT i+1 FROM foo WHERE i < 2
+) SELECT * FROM foo;
+
 --
 -- Some examples with a tree
 --
@@ -847,7 +964,7 @@ SELECT * FROM x;
 WITH RECURSIVE foo(i) AS
     (values (1)
     UNION ALL
-       (SELECT i+1 FROM foo WHERE i < 10
+       (SELECT x.i+y.i FROM foo AS x, foo AS y WHERE x.i < 10
           UNION ALL
        SELECT i+1 FROM foo WHERE i < 5)
 ) SELECT * FROM foo;
@@ -856,7 +973,7 @@ WITH RECURSIVE foo(i) AS
     (values (1)
     UNION ALL
 	   SELECT * FROM
-       (SELECT i+1 FROM foo WHERE i < 10
+       (SELECT x.i+y.i FROM foo AS x, foo as y WHERE x.i < 10
           UNION ALL
        SELECT i+1 FROM foo WHERE i < 5) AS t
 ) SELECT * FROM foo;
-- 
2.32.0

