From 0927c2c5df986b406a6ba954018429e951b1405a Mon Sep 17 00:00:00 2001
From: Ashutosh Sharma <asharma@microsoft.com>
Date: Wed, 5 Jun 2024 08:59:31 +0000
Subject: [PATCH] Ensure security definer functions created by extensions use
 trusted schemas.

This commit addresses an issue where security definer functions
created by extensions did not have their search_path set to trusted
schemas, potentially exposing them to security risks. The provided
patch rectifies this by configuring the search_path for such functions
to include the schema where both the extension and the function are
created. This ensures that security definer functions operate within
trusted environments, enhancing the overall system security.
---
 src/backend/commands/extension.c    |  3 ++
 src/backend/commands/functioncmds.c | 62 ++++++++++++++++++++++++++++-
 src/include/commands/extension.h    |  1 +
 3 files changed, 64 insertions(+), 2 deletions(-)

diff --git a/src/backend/commands/extension.c b/src/backend/commands/extension.c
index 1643c8c69a..bb8d1f7f8d 100644
--- a/src/backend/commands/extension.c
+++ b/src/backend/commands/extension.c
@@ -70,6 +70,7 @@
 /* Globally visible state variables */
 bool		creating_extension = false;
 Oid			CurrentExtensionObject = InvalidOid;
+char	   *create_extension_search_path = NULL;
 
 /*
  * Internal data structure to hold the results of parsing a control file
@@ -992,6 +993,7 @@ execute_extension_script(Oid extensionOid, ExtensionControlFile *control,
 	 */
 	creating_extension = true;
 	CurrentExtensionObject = extensionOid;
+	create_extension_search_path = pstrdup(namespace_search_path);
 	PG_TRY();
 	{
 		char	   *c_sql = read_extension_script_file(control, filename);
@@ -1116,6 +1118,7 @@ execute_extension_script(Oid extensionOid, ExtensionControlFile *control,
 	{
 		creating_extension = false;
 		CurrentExtensionObject = InvalidOid;
+		pfree(create_extension_search_path);
 	}
 	PG_END_TRY();
 
diff --git a/src/backend/commands/functioncmds.c b/src/backend/commands/functioncmds.c
index 6593fd7d81..886e325cd4 100644
--- a/src/backend/commands/functioncmds.c
+++ b/src/backend/commands/functioncmds.c
@@ -52,6 +52,7 @@
 #include "executor/functions.h"
 #include "funcapi.h"
 #include "miscadmin.h"
+#include "nodes/makefuncs.h"
 #include "nodes/nodeFuncs.h"
 #include "optimizer/optimizer.h"
 #include "parser/analyze.h"
@@ -71,6 +72,7 @@
 #include "utils/snapmgr.h"
 #include "utils/syscache.h"
 #include "utils/typcache.h"
+#include "utils/varlena.h"
 
 /*
  *	 Examine the RETURNS clause of the CREATE FUNCTION statement
@@ -705,6 +707,22 @@ interpret_func_support(DefElem *defel)
 	return procOid;
 }
 
+static bool
+IsSearchPathSet(List *set_items)
+{
+	ListCell   *l;
+
+	foreach(l, set_items)
+	{
+		VariableSetStmt *sstmt = lfirst_node(VariableSetStmt, l);
+
+		if (pg_strcasecmp(sstmt->name, "search_path") == 0 &&
+			sstmt->kind == VAR_SET_VALUE)
+			return true;
+	}
+
+	return false;
+}
 
 /*
  * Dissect the list of options assembled in gram.y into function
@@ -726,7 +744,8 @@ compute_function_attributes(ParseState *pstate,
 							float4 *procost,
 							float4 *prorows,
 							Oid *prosupport,
-							char *parallel_p)
+							char *parallel_p,
+							Oid namespaceId)
 {
 	ListCell   *option;
 	DefElem    *as_item = NULL;
@@ -813,6 +832,45 @@ compute_function_attributes(ParseState *pstate,
 		*security_definer = boolVal(security_item->arg);
 	if (leakproof_item)
 		*leakproof_p = boolVal(leakproof_item->arg);
+
+	/*
+	 * Ensure that security definer functions created by an extension have the
+	 * search_path set to trusted schemas, which includes the schema where the
+	 * function is being created and the search_path set by the extension. See
+	 * execute_extension_script() for search_path set by the extension.
+	 */
+	if (creating_extension && *security_definer)
+	{
+		/* If the search_path is already set, there is nothing to do. */
+		if (!set_items || !IsSearchPathSet(set_items))
+		{
+			VariableSetStmt *sp = makeNode(VariableSetStmt);
+			char    *rawstring;
+			List    *schemalist;
+			ListCell    *lc;
+
+			sp->kind = VAR_SET_VALUE;
+			sp->name = "search_path";
+			/* Start with the schema where the function is getting created. */
+			sp->args = lappend(sp->args, makeStringConst(get_namespace_name(namespaceId), -1));
+
+			/* Append the schema(s) set by the extension in search_path */
+			rawstring = pstrdup(create_extension_search_path);
+			(void) SplitIdentifierString(rawstring, ',', &schemalist);
+
+			foreach(lc, schemalist)
+			{
+				char *schema_name = lfirst(lc);
+
+				sp->args = lappend(sp->args,
+								   makeStringConst(pstrdup(schema_name), -1));
+			}
+
+			set_items = lappend(set_items, sp);
+			pfree(rawstring);
+		}
+	}
+
 	if (set_items)
 		*proconfig = update_proconfig_value(NULL, set_items);
 	if (cost_item)
@@ -1079,7 +1137,7 @@ CreateFunction(ParseState *pstate, CreateFunctionStmt *stmt)
 								&isWindowFunc, &volatility,
 								&isStrict, &security, &isLeakProof,
 								&proconfig, &procost, &prorows,
-								&prosupport, &parallel);
+								&prosupport, &parallel, namespaceId);
 
 	if (!language)
 	{
diff --git a/src/include/commands/extension.h b/src/include/commands/extension.h
index c6f3f867eb..e9c77f8491 100644
--- a/src/include/commands/extension.h
+++ b/src/include/commands/extension.h
@@ -29,6 +29,7 @@
  */
 extern PGDLLIMPORT bool creating_extension;
 extern PGDLLIMPORT Oid CurrentExtensionObject;
+extern PGDLLIMPORT char *create_extension_search_path;
 
 
 extern ObjectAddress CreateExtension(ParseState *pstate, CreateExtensionStmt *stmt);
-- 
2.17.1

