From 090ddaeaa87313ea90c22da43cbe34ac0bee4a12 Mon Sep 17 00:00:00 2001
From: Denis Hirn <denis.hirn@uni-tuebingen.de>
Date: Tue, 23 Mar 2021 12:58:36 +0100
Subject: [PATCH] Allow multiple linear recursive self-references

---
 doc/src/sgml/queries.sgml                | 38 +++++++++++++++
 src/backend/executor/nodeWorktablescan.c | 21 +++++++--
 src/backend/parser/parse_cte.c           | 60 ++++++++++++++++++++----
 src/include/nodes/execnodes.h            |  1 +
 src/test/regress/expected/with.out       | 54 ++++++++++++++++++---
 src/test/regress/sql/with.sql            | 22 ++++++++-
 6 files changed, 176 insertions(+), 20 deletions(-)

diff --git a/doc/src/sgml/queries.sgml b/doc/src/sgml/queries.sgml
index 834b83b509..7b97604c85 100644
--- a/doc/src/sgml/queries.sgml
+++ b/doc/src/sgml/queries.sgml
@@ -2172,6 +2172,44 @@ GROUP BY sub_part
 </programlisting>
   </para>
 
+ <sect3 id="queries-recursive-terms">
+  <title>Multiple Recursive Terms</title>
+
+  <para>
+   The <firstterm>recursive term</firstterm> of a recursive <literal>WITH</literal>
+   query may contain exactly one self-reference. Multiple such self-references
+   are only allowed if the <literal>WITH</literal> query features multiple recursive
+   terms, all of which are joined by <literal>UNION</literal> (or <literal>UNION ALL</literal>).
+   Consider the example query below which computes the transitive closure of travel connections
+   from <literal>'my_city'</literal>. Flight and train routes are found in separate tables
+   <literal>flights(source,dest,carrier)</literal> and <literal>trains(source,dest)</literal>.
+   Two recursive terms, joined by <literal>UNION ALL</literal>, use two recursive
+   self-references to <literal>connections</literal> to find viable onward journeys,
+   in both the <literal>flights</literal> and <literal>trains</literal> tables.
+
+<programlisting>
+WITH RECURSIVE connections(source, dest, carrier) AS (
+     SELECT f.source, f.dest, f.carrier
+      FROM flights f
+      WHERE f.source = 'my_city'
+    UNION ALL
+      SELECT r.source, r.dest, 'Rail' AS carrier
+      FROM trains r
+      WHERE r.source = 'my_city'
+  UNION ALL -- two recursive terms below
+     SELECT c.source, f.dest, f.carrier
+      FROM <emphasis>connections c</emphasis>, flights f
+      WHERE c.dest = f.source
+    UNION ALL
+      SELECT c.source, r.dest, 'Rail' AS carrier
+      FROM <emphasis>connections c</emphasis>, trains r
+      WHERE c.dest = r.source
+)
+SELECT * FROM connections;
+</programlisting>
+  </para>
+ </sect3>
+
   <sect3 id="queries-with-search">
    <title>Search Order</title>
 
diff --git a/src/backend/executor/nodeWorktablescan.c b/src/backend/executor/nodeWorktablescan.c
index 91d3bf376b..c73d23d1de 100644
--- a/src/backend/executor/nodeWorktablescan.c
+++ b/src/backend/executor/nodeWorktablescan.c
@@ -42,10 +42,6 @@ WorkTableScanNext(WorkTableScanState *node)
 	 * worktable plan node, since it cannot appear high enough in the plan
 	 * tree of a scrollable cursor to be exposed to a backward-scan
 	 * requirement.  So it's not worth expending effort to support it.
-	 *
-	 * Note: we are also assuming that this node is the only reader of the
-	 * worktable.  Therefore, we don't need a private read pointer for the
-	 * tuplestore, nor do we need to tell tuplestore_gettupleslot to copy.
 	 */
 	Assert(ScanDirectionIsForward(node->ss.ps.state->es_direction));
 
@@ -55,6 +51,7 @@ WorkTableScanNext(WorkTableScanState *node)
 	 * Get the next tuple from tuplestore. Return NULL if no more tuples.
 	 */
 	slot = node->ss.ss_ScanTupleSlot;
+	tuplestore_select_read_pointer(tuplestorestate, node->readptr);
 	(void) tuplestore_gettupleslot(tuplestorestate, true, false, slot);
 	return slot;
 }
@@ -99,7 +96,14 @@ ExecWorkTableScan(PlanState *pstate)
 		Assert(!param->isnull);
 		node->rustate = castNode(RecursiveUnionState, DatumGetPointer(param->value));
 		Assert(node->rustate);
+		/*
+		 * Allocate a unique read pointer for each worktable scan.
+		 */
+		node->readptr = tuplestore_alloc_read_pointer(node->rustate->working_table, 0);
+		tuplestore_copy_read_pointer(node->rustate->working_table, 0, node->readptr);
+		tuplestore_rescan(node->rustate->working_table);
 
+		Assert(node->readptr != -1);
 		/*
 		 * The scan tuple type (ie, the rowtype we expect to find in the work
 		 * table) is the same as the result rowtype of the ancestor
@@ -147,6 +151,7 @@ ExecInitWorkTableScan(WorkTableScan *node, EState *estate, int eflags)
 	scanstate->ss.ps.plan = (Plan *) node;
 	scanstate->ss.ps.state = estate;
 	scanstate->ss.ps.ExecProcNode = ExecWorkTableScan;
+	scanstate->readptr = -1;	/* we'll set this later */
 	scanstate->rustate = NULL;	/* we'll set this later */
 
 	/*
@@ -219,5 +224,13 @@ ExecReScanWorkTableScan(WorkTableScanState *node)
 
 	/* No need (or way) to rescan if ExecWorkTableScan not called yet */
 	if (node->rustate)
+	{
+		/* Make sure to select the initial read pointer */
+		tuplestore_select_read_pointer(node->rustate->working_table, 0);
+		tuplestore_rescan(node->rustate->working_table);
+		/* Reallocate a unique read pointer. */
+		node->readptr = tuplestore_alloc_read_pointer(node->rustate->working_table, 0);
+		tuplestore_copy_read_pointer(node->rustate->working_table, 0, node->readptr);
 		tuplestore_rescan(node->rustate->working_table);
+	}
 }
diff --git a/src/backend/parser/parse_cte.c b/src/backend/parser/parse_cte.c
index f6ae96333a..a2a496472a 100644
--- a/src/backend/parser/parse_cte.c
+++ b/src/backend/parser/parse_cte.c
@@ -80,6 +80,8 @@ typedef struct CteState
 	/* working state for checkWellFormedRecursion walk only: */
 	int			selfrefcount;	/* number of self-references detected */
 	RecursionContext context;	/* context to allow or disallow self-ref */
+	bool		root_is_union;	/* root of non-recursive term is SETOP_UNION */
+	bool		rotate;			/* self-reference in non-recursive term detected */
 } CteState;
 
 
@@ -853,11 +855,27 @@ checkWellFormedRecursion(CteState *cstate)
 					 parser_errposition(cstate->pstate, cte->location)));
 
 		/* The left-hand operand mustn't contain self-reference at all */
-		cstate->curitem = i;
-		cstate->innerwiths = NIL;
-		cstate->selfrefcount = 0;
-		cstate->context = RECURSION_NONRECURSIVETERM;
-		checkWellFormedRecursionWalker((Node *) stmt->larg, cstate);
+		do {
+			cstate->curitem = i;
+			cstate->innerwiths = NIL;
+			cstate->selfrefcount = 0;
+			cstate->context = RECURSION_NONRECURSIVETERM;
+			cstate->root_is_union = stmt->larg->op == SETOP_UNION;
+			cstate->rotate = false;
+			checkWellFormedRecursionWalker((Node *) stmt->larg, cstate);
+
+			/* Check if rotate flag has been set. */
+			if (cstate->root_is_union && cstate->rotate)
+			{
+				/* Rotate stmt UNION tree to the right. */
+				SelectStmt *rarg = stmt->rarg;
+				stmt->rarg = stmt->larg;
+				stmt->larg = stmt->larg->larg;
+				stmt->rarg->larg = stmt->rarg->rarg;
+				stmt->rarg->rarg = rarg;
+			}
+		} while (cstate->root_is_union && cstate->rotate);
+
 		Assert(cstate->innerwiths == NIL);
 
 		/* Right-hand operand should contain one reference in a valid place */
@@ -955,6 +973,12 @@ checkWellFormedRecursionWalker(Node *node, CteState *cstate)
 			mycte = cstate->items[cstate->curitem].cte;
 			if (strcmp(rv->relname, mycte->ctename) == 0)
 			{
+				/* Found a recursive reference, but it might be fixable */
+				if (cstate->context == RECURSION_NONRECURSIVETERM && cstate->root_is_union)
+				{
+					cstate->rotate = true;
+					return false;
+				}
 				/* Found a recursive reference to the active query */
 				if (cstate->context != RECURSION_OK)
 					ereport(ERROR,
@@ -1117,9 +1141,29 @@ checkWellFormedSelectStmt(SelectStmt *stmt, CteState *cstate)
 		{
 			case SETOP_NONE:
 			case SETOP_UNION:
-				raw_expression_tree_walker((Node *) stmt,
-										   checkWellFormedRecursionWalker,
-										   (void *) cstate);
+				/* check selfrefcount for each recursive member individually */
+				if (stmt->larg != NULL && stmt->rarg != NULL)
+				{
+					int curr_selfrefcount;
+					int selfrefcount = cstate->selfrefcount;
+
+					checkWellFormedRecursionWalker((Node *) stmt->larg, cstate);
+
+					/* Restore selfrefcount to allow multiple linear recursive references */
+					curr_selfrefcount = cstate->selfrefcount;
+					cstate->selfrefcount = selfrefcount;
+
+					checkWellFormedRecursionWalker((Node *) stmt->rarg, cstate);
+					/* Either UNION branch may hold a self-reference, but need not; keep the larger count. */
+					if (cstate->selfrefcount < curr_selfrefcount)
+						cstate->selfrefcount = curr_selfrefcount;
+				}
+				else
+				{
+					raw_expression_tree_walker((Node *) stmt,
+											   checkWellFormedRecursionWalker,
+											   (void *) cstate);
+				}
 				break;
 			case SETOP_INTERSECT:
 				if (stmt->all)
diff --git a/src/include/nodes/execnodes.h b/src/include/nodes/execnodes.h
index 37cb4f3d59..b53844437c 100644
--- a/src/include/nodes/execnodes.h
+++ b/src/include/nodes/execnodes.h
@@ -1843,6 +1843,7 @@ typedef struct NamedTuplestoreScanState
 typedef struct WorkTableScanState
 {
 	ScanState	ss;				/* its first field is NodeTag */
+	int		readptr;			/* index of this scan's private read pointer in the worktable tuplestore; -1 until allocated */
 	RecursiveUnionState *rustate;
 } WorkTableScanState;
 
diff --git a/src/test/regress/expected/with.out b/src/test/regress/expected/with.out
index 3523a7dcc1..8f83582d7b 100644
--- a/src/test/regress/expected/with.out
+++ b/src/test/regress/expected/with.out
@@ -235,6 +235,48 @@ WITH RECURSIVE outermost(x) AS (
  7
 (7 rows)
 
+-- Test multiple self-references in different recursive anchors
+WITH RECURSIVE foo(i) AS
+    (values (1)
+    UNION ALL
+       (SELECT i+1 FROM foo WHERE i < 5
+          UNION ALL
+       SELECT i+1 FROM foo WHERE i < 2)
+) SELECT * FROM foo;
+ i 
+---
+ 1
+ 2
+ 2
+ 3
+ 3
+ 4
+ 4
+ 5
+ 5
+(9 rows)
+
+WITH RECURSIVE foo(i) AS
+    (values (1)
+    UNION ALL
+	   SELECT * FROM
+       (SELECT i+1 FROM foo WHERE i < 5
+          UNION ALL
+       SELECT i+1 FROM foo WHERE i < 2) AS t
+) SELECT * FROM foo;
+ i 
+---
+ 1
+ 2
+ 2
+ 3
+ 3
+ 4
+ 4
+ 5
+ 5
+(9 rows)
+
 --
 -- Some examples with a tree
 --
@@ -1795,24 +1837,24 @@ LINE 2:   x (id) AS (SELECT 1 UNION ALL SELECT id+1 FROM y WHERE id ...
 WITH RECURSIVE foo(i) AS
     (values (1)
     UNION ALL
-       (SELECT i+1 FROM foo WHERE i < 10
+       (SELECT x.i+y.i FROM foo AS x, foo AS y WHERE x.i < 10
           UNION ALL
        SELECT i+1 FROM foo WHERE i < 5)
 ) SELECT * FROM foo;
 ERROR:  recursive reference to query "foo" must not appear more than once
-LINE 6:        SELECT i+1 FROM foo WHERE i < 5)
-                               ^
+LINE 4:        (SELECT x.i+y.i FROM foo AS x, foo AS y WHERE x.i < 1...
+                                              ^
 WITH RECURSIVE foo(i) AS
     (values (1)
     UNION ALL
 	   SELECT * FROM
-       (SELECT i+1 FROM foo WHERE i < 10
+       (SELECT x.i+y.i FROM foo AS x, foo as y WHERE x.i < 10
           UNION ALL
        SELECT i+1 FROM foo WHERE i < 5) AS t
 ) SELECT * FROM foo;
 ERROR:  recursive reference to query "foo" must not appear more than once
-LINE 7:        SELECT i+1 FROM foo WHERE i < 5) AS t
-                               ^
+LINE 5:        (SELECT x.i+y.i FROM foo AS x, foo as y WHERE x.i < 1...
+                                              ^
 WITH RECURSIVE foo(i) AS
     (values (1)
     UNION ALL
diff --git a/src/test/regress/sql/with.sql b/src/test/regress/sql/with.sql
index 8b213ee408..fab1933e3d 100644
--- a/src/test/regress/sql/with.sql
+++ b/src/test/regress/sql/with.sql
@@ -139,6 +139,24 @@ WITH RECURSIVE outermost(x) AS (
  )
  SELECT * FROM outermost ORDER BY 1;
 
+-- Test multiple self-references in different recursive anchors
+WITH RECURSIVE foo(i) AS
+    (values (1)
+    UNION ALL
+       (SELECT i+1 FROM foo WHERE i < 5
+          UNION ALL
+       SELECT i+1 FROM foo WHERE i < 2)
+) SELECT * FROM foo;
+
+WITH RECURSIVE foo(i) AS
+    (values (1)
+    UNION ALL
+	   SELECT * FROM
+       (SELECT i+1 FROM foo WHERE i < 5
+          UNION ALL
+       SELECT i+1 FROM foo WHERE i < 2) AS t
+) SELECT * FROM foo;
+
 --
 -- Some examples with a tree
 --
@@ -847,7 +865,7 @@ SELECT * FROM x;
 WITH RECURSIVE foo(i) AS
     (values (1)
     UNION ALL
-       (SELECT i+1 FROM foo WHERE i < 10
+       (SELECT x.i+y.i FROM foo AS x, foo AS y WHERE x.i < 10
           UNION ALL
        SELECT i+1 FROM foo WHERE i < 5)
 ) SELECT * FROM foo;
@@ -856,7 +874,7 @@ WITH RECURSIVE foo(i) AS
     (values (1)
     UNION ALL
 	   SELECT * FROM
-       (SELECT i+1 FROM foo WHERE i < 10
+       (SELECT x.i+y.i FROM foo AS x, foo as y WHERE x.i < 10
           UNION ALL
        SELECT i+1 FROM foo WHERE i < 5) AS t
 ) SELECT * FROM foo;
-- 
2.32.0

