From 688c60a4e29b426cfe713030226836bcd1cc37ff Mon Sep 17 00:00:00 2001
From: Alexey Grishchenko <agrishchenko@pivotal.io>
Date: Thu, 10 Mar 2016 12:56:06 +0000
Subject: [PATCH] Fix endless loop in plpython set-returning function

The issue occurs when the same set-returning function is called twice within a single
query. The issue is caused by the fact that result iterator is stored in the procedure
data structure, which is reused for different calls of the same procedure. But in fact,
when you execute the same SRF twice in the same query, the result iterator is being
continuously reinitialized by two different instances of the same function, the
execution never ends and process dies with OOM

Also this fix changes the PLy_function_delete_args, because with calling two SRF in a
single query the old implementation would attempt to delete the same input parameter
twice, which would lead to key error

The fix implementation allows to fix another issue with recursive SPI function calls

Regression test added for these cases
---
 src/pl/plpython/expected/plpython_setof.out |  12 +++
 src/pl/plpython/expected/plpython_spi.out   |  14 +++
 src/pl/plpython/plpy_exec.c                 | 139 ++++++++++++++++++++++++++--
 src/pl/plpython/plpy_procedure.c            |   3 +-
 src/pl/plpython/plpy_procedure.h            |   9 +-
 src/pl/plpython/sql/plpython_setof.sql      |   3 +
 src/pl/plpython/sql/plpython_spi.sql        |  10 +-
 7 files changed, 177 insertions(+), 13 deletions(-)

diff --git a/src/pl/plpython/expected/plpython_setof.out b/src/pl/plpython/expected/plpython_setof.out
index 62b8a45..d35f9c2 100644
--- a/src/pl/plpython/expected/plpython_setof.out
+++ b/src/pl/plpython/expected/plpython_setof.out
@@ -124,6 +124,18 @@ SELECT test_setof_spi_in_iterator();
  World
 (4 rows)
 
+-- Calling the same set-returning function twice in a single query
+select test_setof_as_list(2, 'list'), test_setof_as_list(3, 'list');
+ test_setof_as_list | test_setof_as_list 
+--------------------+--------------------
+ list               | list
+ list               | list
+ list               | list
+ list               | list
+ list               | list
+ list               | list
+(6 rows)
+
 -- returns set of named-composite-type tuples
 CREATE OR REPLACE FUNCTION get_user_records()
 RETURNS SETOF users
diff --git a/src/pl/plpython/expected/plpython_spi.out b/src/pl/plpython/expected/plpython_spi.out
index e715ee5..f0dd376 100644
--- a/src/pl/plpython/expected/plpython_spi.out
+++ b/src/pl/plpython/expected/plpython_spi.out
@@ -57,6 +57,14 @@ for r in rv:
 return seq
 '
 	LANGUAGE plpythonu;
+CREATE FUNCTION spi_recursive_sum(a int) RETURNS int
+	AS
+'r = 0
+if a > 1:
+    r = plpy.execute("SELECT spi_recursive_sum(%d) as a" % (a-1))[0]["a"]
+return a + r
+'
+	LANGUAGE plpythonu;
 -- spi and nested calls
 --
 select nested_call_one('pass this along');
@@ -112,6 +120,12 @@ SELECT join_sequences(sequences) FROM sequences
 ----------------
 (0 rows)
 
+SELECT spi_recursive_sum(10);
+ spi_recursive_sum 
+-------------------
+                55
+(1 row)
+
 --
 -- plan and result objects
 --
diff --git a/src/pl/plpython/plpy_exec.c b/src/pl/plpython/plpy_exec.c
index 24aed01..536046d 100644
--- a/src/pl/plpython/plpy_exec.c
+++ b/src/pl/plpython/plpy_exec.c
@@ -39,6 +39,8 @@ static void plpython_trigger_error_callback(void *arg);
 static PyObject *PLy_procedure_call(PLyProcedure *proc, char *kargs, PyObject *vargs);
 static void PLy_abort_open_subtransactions(int save_subxact_level);
 
+static void PLy_global_args_push(PLyProcedure *proc);
+static void PLy_global_args_pop(PLyProcedure *proc);
 
 /* function subhandler */
 Datum
@@ -47,17 +49,33 @@ PLy_exec_function(FunctionCallInfo fcinfo, PLyProcedure *proc)
 	Datum		rv;
 	PyObject   *volatile plargs = NULL;
 	PyObject   *volatile plrv = NULL;
+	FuncCallContext	*volatile funcctx = NULL;
 	ErrorContextCallback plerrcontext;
 
 	PG_TRY();
 	{
-		if (!proc->is_setof || proc->setof == NULL)
+		if (proc->is_setof)
+		{
+			/* First Call setup */
+			if (SRF_IS_FIRSTCALL())
+				funcctx = SRF_FIRSTCALL_INIT();
+			/* Every call setup */
+			funcctx = SRF_PERCALL_SETUP();
+			Assert(funcctx != NULL);
+		}
+
+		if (!proc->is_setof || funcctx->user_fctx == NULL)
 		{
 			/*
 			 * Simple type returning function or first time for SETOF
 			 * function: actually execute the function.
 			 */
 			plargs = PLy_function_build_args(fcinfo, proc);
+			/*
+			 * In case of recursive call or SRF we might need to push old version
+			 * of arguments into the stack
+			 */
+			PLy_global_args_push(proc);
 			plrv = PLy_procedure_call(proc, "args", plargs);
 			if (!proc->is_setof)
 			{
@@ -80,7 +98,7 @@ PLy_exec_function(FunctionCallInfo fcinfo, PLyProcedure *proc)
 			bool		has_error = false;
 			ReturnSetInfo *rsi = (ReturnSetInfo *) fcinfo->resultinfo;
 
-			if (proc->setof == NULL)
+			if (funcctx->user_fctx == NULL)
 			{
 				/* first time -- do checks and setup */
 				if (!rsi || !IsA(rsi, ReturnSetInfo) ||
@@ -94,11 +112,11 @@ PLy_exec_function(FunctionCallInfo fcinfo, PLyProcedure *proc)
 				rsi->returnMode = SFRM_ValuePerCall;
 
 				/* Make iterator out of returned object */
-				proc->setof = PyObject_GetIter(plrv);
+				funcctx->user_fctx = (void*) PyObject_GetIter(plrv);
 				Py_DECREF(plrv);
 				plrv = NULL;
 
-				if (proc->setof == NULL)
+				if (funcctx->user_fctx == NULL)
 					ereport(ERROR,
 							(errcode(ERRCODE_DATATYPE_MISMATCH),
 							 errmsg("returned object cannot be iterated"),
@@ -106,7 +124,7 @@ PLy_exec_function(FunctionCallInfo fcinfo, PLyProcedure *proc)
 			}
 
 			/* Fetch next from iterator */
-			plrv = PyIter_Next(proc->setof);
+			plrv = PyIter_Next(funcctx->user_fctx);
 			if (plrv)
 				rsi->isDone = ExprMultipleResult;
 			else
@@ -118,8 +136,8 @@ PLy_exec_function(FunctionCallInfo fcinfo, PLyProcedure *proc)
 			if (rsi->isDone == ExprEndResult)
 			{
 				/* Iterator is exhausted or error happened */
-				Py_DECREF(proc->setof);
-				proc->setof = NULL;
+				Py_DECREF( (PyObject*) funcctx->user_fctx);
+				funcctx->user_fctx = NULL;
 
 				Py_XDECREF(plargs);
 				Py_XDECREF(plrv);
@@ -134,7 +152,7 @@ PLy_exec_function(FunctionCallInfo fcinfo, PLyProcedure *proc)
 					elog(ERROR, "SPI_finish failed");
 
 				fcinfo->isnull = true;
-				return (Datum) NULL;
+				SRF_RETURN_DONE(funcctx);
 			}
 		}
 
@@ -213,8 +231,11 @@ PLy_exec_function(FunctionCallInfo fcinfo, PLyProcedure *proc)
 		 * yet. Set it to NULL so the next invocation of the function will
 		 * start the iteration again.
 		 */
-		Py_XDECREF(proc->setof);
-		proc->setof = NULL;
+		if (proc->is_setof && funcctx->user_fctx != NULL)
+		{
+			Py_XDECREF( (PyObject*) funcctx->user_fctx );
+			funcctx->user_fctx = NULL;
+		}
 
 		PG_RE_THROW();
 	}
@@ -442,6 +463,9 @@ PLy_function_delete_args(PLyProcedure *proc)
 	for (i = 0; i < proc->nargs; i++)
 		if (proc->argnames[i])
 			PyDict_DelItemString(proc->globals, proc->argnames[i]);
+
+	/* Pop global arguments from the stack if they were pushed there before */
+	PLy_global_args_pop(proc);
 }
 
 static void
@@ -855,3 +879,98 @@ PLy_abort_open_subtransactions(int save_subxact_level)
 		pfree(subtransactiondata);
 	}
 }
+
+static void
+PLy_global_args_push(PLyProcedure *proc)
+{
+	PLyArgsStack *node;
+	PyObject     *arglist;
+	MemoryContext oldcxt;
+
+	/*
+	 * Action is needed only if the same function was already called either
+	 * with SPI (recursive call) or with multiple set-returning functions in
+	 * a single query
+	 */
+	proc->calldepth += 1;
+	if (proc->calldepth > 1) {
+		/* Nothing to do if the function has no input parameters */
+		if (!proc->argnames)
+			return;
+
+		/* Fetch the argument list (Python list object) from the globals */
+		arglist = PyDict_GetItemString(proc->globals, "args");
+		if (arglist == NULL)
+		{
+			ereport(ERROR,
+					(errcode(ERRCODE_UNDEFINED_OBJECT),
+					 errmsg("\"args\" object is not defined prior to function call"),
+					 errdetail("PL/Python function does not allow removal of global \"args\" object")));
+		}
+		else
+		{
+			Py_INCREF(arglist);
+
+			/*
+			 * Push the function argument list into the stack in procedure
+			 * memory context
+			 */
+			oldcxt = MemoryContextSwitchTo(proc->mcxt);
+			node = (PLyArgsStack *) palloc0(sizeof(PLyArgsStack));
+			node->args = arglist;
+			node->next = proc->argstack;
+			proc->argstack = node;
+			MemoryContextSwitchTo(oldcxt);
+		}
+	}
+}
+
+static void
+PLy_global_args_pop(PLyProcedure *proc)
+{
+	int			  i;
+	PLyArgsStack *ptr;
+	PyObject     *arg;
+
+	/*
+	 * Action is needed only if the same function was already called either
+	 * with SPI (recursive call) or with multiple set-returning functions in
+	 * a single query
+	 */
+	if (proc->calldepth > 1) {
+		/* Nothing to do if the function has no input parameters */
+		if (!proc->argnames)
+			return;
+
+		/* If stack entry exist we do the pop */
+		if (proc->argstack != NULL) {
+			for (i = 0; i < proc->nargs; i++)
+				if (proc->argnames[i]) {
+
+					/* Get the argument object from saved argument list */
+					arg = PyList_GetItem(proc->argstack->args, i);
+					if (arg == NULL)
+						PLy_elog(ERROR, "PyList_GetItem() failed, while processing "
+										"arguments from the call cache");
+
+					/* Push the argument object back to the function globals */
+					if (PyDict_SetItemString(proc->globals, proc->argnames[i], arg) == -1)
+						PLy_elog(ERROR, "PyDict_SetItemString() failed, while setting up "
+										"arguments from the call cache");
+				}
+
+			/* Set the list of arguments back to globals */
+			if (PyDict_SetItemString(proc->globals, "args", proc->argstack->args) == -1)
+				PLy_elog(ERROR, "PyDict_SetItemString() failed, while setting up "
+								"arguments list object from the call cache");
+
+			/* Free stack entry */
+			ptr = proc->argstack->next;
+			Py_DECREF(proc->argstack->args);
+			pfree(proc->argstack);
+			proc->argstack = ptr;
+		}
+	}
+
+	proc->calldepth -= 1;
+}
\ No newline at end of file
diff --git a/src/pl/plpython/plpy_procedure.c b/src/pl/plpython/plpy_procedure.c
index a0d0792..948b976 100644
--- a/src/pl/plpython/plpy_procedure.c
+++ b/src/pl/plpython/plpy_procedure.c
@@ -203,9 +203,10 @@ PLy_procedure_create(HeapTuple procTup, Oid fn_oid, bool is_trigger)
 		proc->code = proc->statics = NULL;
 		proc->globals = NULL;
 		proc->is_setof = procStruct->proretset;
-		proc->setof = NULL;
 		proc->src = NULL;
 		proc->argnames = NULL;
+		proc->argstack = NULL;
+		proc->calldepth = 0;
 
 		/*
 		 * get information required for output conversion of the return value,
diff --git a/src/pl/plpython/plpy_procedure.h b/src/pl/plpython/plpy_procedure.h
index 9fc8db0..8a8b4cb 100644
--- a/src/pl/plpython/plpy_procedure.h
+++ b/src/pl/plpython/plpy_procedure.h
@@ -10,6 +10,12 @@
 
 extern void init_procedure_caches(void);
 
+/* stack of function call arguments */
+typedef struct PLyArgsStack
+{
+	PyObject			*args;	/* Python "list" object with call arguments */
+	struct PLyArgsStack *next;	/* pointer to the next stack element */
+} PLyArgsStack;
 
 /* cached procedure data */
 typedef struct PLyProcedure
@@ -24,7 +30,6 @@ typedef struct PLyProcedure
 	PLyTypeInfo result;			/* also used to store info for trigger tuple
 								 * type */
 	bool		is_setof;		/* true, if procedure returns result set */
-	PyObject   *setof;			/* contents of result set. */
 	char	   *src;			/* textual procedure code, after mangling */
 	char	  **argnames;		/* Argument names */
 	PLyTypeInfo args[FUNC_MAX_ARGS];
@@ -34,6 +39,8 @@ typedef struct PLyProcedure
 	PyObject   *code;			/* compiled procedure code */
 	PyObject   *statics;		/* data saved across calls, local scope */
 	PyObject   *globals;		/* data saved across calls, global scope */
+	int         calldepth;		/* depth of function recursive calls */
+	PLyArgsStack *argstack;		/* stack of function call arguments */
 } PLyProcedure;
 
 /* the procedure cache key */
diff --git a/src/pl/plpython/sql/plpython_setof.sql b/src/pl/plpython/sql/plpython_setof.sql
index fe034fb..63ca466 100644
--- a/src/pl/plpython/sql/plpython_setof.sql
+++ b/src/pl/plpython/sql/plpython_setof.sql
@@ -63,6 +63,9 @@ SELECT test_setof_as_iterator(2, null);
 
 SELECT test_setof_spi_in_iterator();
 
+-- Calling the same set-returning function twice in a single query
+select test_setof_as_list(2, 'list'), test_setof_as_list(3, 'list');
+
 
 -- returns set of named-composite-type tuples
 CREATE OR REPLACE FUNCTION get_user_records()
diff --git a/src/pl/plpython/sql/plpython_spi.sql b/src/pl/plpython/sql/plpython_spi.sql
index a882738..61fb614 100644
--- a/src/pl/plpython/sql/plpython_spi.sql
+++ b/src/pl/plpython/sql/plpython_spi.sql
@@ -69,7 +69,14 @@ return seq
 	LANGUAGE plpythonu;
 
 
-
+CREATE FUNCTION spi_recursive_sum(a int) RETURNS int
+	AS
+'r = 0
+if a > 1:
+    r = plpy.execute("SELECT spi_recursive_sum(%d) as a" % (a-1))[0]["a"]
+return a + r
+'
+	LANGUAGE plpythonu;
 
 
 -- spi and nested calls
@@ -88,6 +95,7 @@ SELECT join_sequences(sequences) FROM sequences
 SELECT join_sequences(sequences) FROM sequences
 	WHERE join_sequences(sequences) ~* '^B';
 
+SELECT spi_recursive_sum(10);
 
 --
 -- plan and result objects
-- 
1.9.5 (Apple Git-50.3)

