From a742037f22b190bdc95a6fffb1c9de80655d5549 Mon Sep 17 00:00:00 2001
From: Ashutosh Sharma <asharma@microsoft.com>
Date: Tue, 11 Jun 2024 09:11:52 +0000
Subject: [PATCH] Implement implicit search_path assignment for
 extension-created functions upon request.

It primarily implements the following enhancements to enable this:

- Extends the CREATE EXTENSION command to support a new option, SET
  SEARCH_PATH.

- If the SET SEARCH_PATH option is specified with the CREATE EXTENSION
  command, the implicit search_path for functions created by an
  extension is set, if not already configured.

- Upon execution of ALTER EXTENSION SET SCHEMA command, if the
  function's search_path contains the old schema of the extension, it
  is updated with the new schema.
---
 src/backend/commands/extension.c    |  59 ++++++---
 src/backend/commands/functioncmds.c | 181 +++++++++++++++++++++++++++-
 src/backend/parser/gram.y           |  13 ++
 src/include/commands/defrem.h       |   2 +
 src/include/commands/extension.h    |   2 +
 5 files changed, 241 insertions(+), 16 deletions(-)

diff --git a/src/backend/commands/extension.c b/src/backend/commands/extension.c
index 1643c8c69a..5134deb49b 100644
--- a/src/backend/commands/extension.c
+++ b/src/backend/commands/extension.c
@@ -45,6 +45,7 @@
 #include "catalog/pg_depend.h"
 #include "catalog/pg_extension.h"
 #include "catalog/pg_namespace.h"
+#include "catalog/pg_proc.h"
 #include "catalog/pg_type.h"
 #include "commands/alter.h"
 #include "commands/comment.h"
@@ -70,6 +71,8 @@
 /* Globally visible state variables */
 bool		creating_extension = false;
 Oid			CurrentExtensionObject = InvalidOid;
+char	   *create_extension_search_path = NULL;
+bool		create_extension_set_search_path = false;
 
 /*
  * Internal data structure to hold the results of parsing a control file
@@ -117,7 +120,8 @@ static Oid	get_required_extension(char *reqExtensionName,
 								   char *origSchemaName,
 								   bool cascade,
 								   List *parents,
-								   bool is_create);
+								   bool is_create,
+								   bool set_search_path);
 static void get_available_versions_for_extension(ExtensionControlFile *pcontrol,
 												 Tuplestorestate *tupstore,
 												 TupleDesc tupdesc);
@@ -128,7 +132,8 @@ static void ApplyExtensionUpdates(Oid extensionOid,
 								  List *updateVersions,
 								  char *origSchemaName,
 								  bool cascade,
-								  bool is_create);
+								  bool is_create,
+								  bool set_search_path);
 static void ExecAlterExtensionContentsRecurse(AlterExtensionContentsStmt *stmt,
 											  ObjectAddress extension,
 											  ObjectAddress object);
@@ -871,7 +876,7 @@ execute_extension_script(Oid extensionOid, ExtensionControlFile *control,
 						 const char *from_version,
 						 const char *version,
 						 List *requiredSchemas,
-						 const char *schemaName, Oid schemaOid)
+						 const char *schemaName, Oid schemaOid, bool set_search_path)
 {
 	bool		switch_to_superuser = false;
 	char	   *filename;
@@ -992,6 +997,8 @@ execute_extension_script(Oid extensionOid, ExtensionControlFile *control,
 	 */
 	creating_extension = true;
 	CurrentExtensionObject = extensionOid;
+	create_extension_search_path = pstrdup(namespace_search_path);
+	create_extension_set_search_path = set_search_path;
 	PG_TRY();
 	{
 		char	   *c_sql = read_extension_script_file(control, filename);
@@ -1116,6 +1123,8 @@ execute_extension_script(Oid extensionOid, ExtensionControlFile *control,
 	{
 		creating_extension = false;
 		CurrentExtensionObject = InvalidOid;
+		pfree(create_extension_search_path);
+		create_extension_set_search_path = false;
 	}
 	PG_END_TRY();
 
@@ -1460,7 +1469,8 @@ CreateExtensionInternal(char *extensionName,
 						const char *versionName,
 						bool cascade,
 						List *parents,
-						bool is_create)
+						bool is_create,
+						bool set_search_path)
 {
 	char	   *origSchemaName = schemaName;
 	Oid			schemaOid = InvalidOid;
@@ -1648,7 +1658,8 @@ CreateExtensionInternal(char *extensionName,
 										origSchemaName,
 										cascade,
 										parents,
-										is_create);
+										is_create,
+										set_search_path);
 		reqschema = get_extension_schema(reqext);
 		requiredExtensions = lappend_oid(requiredExtensions, reqext);
 		requiredSchemas = lappend_oid(requiredSchemas, reqschema);
@@ -1677,7 +1688,7 @@ CreateExtensionInternal(char *extensionName,
 	execute_extension_script(extensionOid, control,
 							 NULL, versionName,
 							 requiredSchemas,
-							 schemaName, schemaOid);
+							 schemaName, schemaOid, set_search_path);
 
 	/*
 	 * If additional update scripts have to be executed, apply the updates as
@@ -1685,7 +1696,7 @@ CreateExtensionInternal(char *extensionName,
 	 */
 	ApplyExtensionUpdates(extensionOid, pcontrol,
 						  versionName, updateVersions,
-						  origSchemaName, cascade, is_create);
+						  origSchemaName, cascade, is_create, set_search_path);
 
 	return address;
 }
@@ -1699,7 +1710,8 @@ get_required_extension(char *reqExtensionName,
 					   char *origSchemaName,
 					   bool cascade,
 					   List *parents,
-					   bool is_create)
+					   bool is_create,
+					   bool set_search_path)
 {
 	Oid			reqExtensionOid;
 
@@ -1744,7 +1756,8 @@ get_required_extension(char *reqExtensionName,
 										   NULL,
 										   cascade,
 										   cascade_parents,
-										   is_create);
+										   is_create,
+										   set_search_path);
 
 			/* Get its newly-assigned OID. */
 			reqExtensionOid = addr.objectId;
@@ -1770,9 +1783,11 @@ CreateExtension(ParseState *pstate, CreateExtensionStmt *stmt)
 	DefElem    *d_schema = NULL;
 	DefElem    *d_new_version = NULL;
 	DefElem    *d_cascade = NULL;
+	DefElem    *d_search_path = NULL;
 	char	   *schemaName = NULL;
 	char	   *versionName = NULL;
 	bool		cascade = false;
+	bool		set_search_path = false;
 	ListCell   *lc;
 
 	/* Check extension name validity before any filesystem access */
@@ -1836,6 +1851,13 @@ CreateExtension(ParseState *pstate, CreateExtensionStmt *stmt)
 			d_cascade = defel;
 			cascade = defGetBoolean(d_cascade);
 		}
+		else if (strcmp(defel->defname, "search_path") == 0)
+		{
+			if (d_search_path)
+				errorConflictingDefElem(defel, pstate);
+			d_search_path = defel;
+			set_search_path = defGetBoolean(d_search_path);
+		}
 		else
 			elog(ERROR, "unrecognized option: %s", defel->defname);
 	}
@@ -1846,7 +1868,8 @@ CreateExtension(ParseState *pstate, CreateExtensionStmt *stmt)
 								   versionName,
 								   cascade,
 								   NIL,
-								   true);
+								   true,
+								   set_search_path);
 }
 
 /*
@@ -2950,6 +2973,12 @@ AlterExtensionNamespace(const char *extensionName, const char *newschema, Oid *o
 					 errdetail("%s is not in the extension's schema \"%s\"",
 							   getObjectDescription(&dep, false),
 							   get_namespace_name(oldNspOid))));
+
+		/*
+		 * If the function has search_path set in its proconfig, update it, if needed.
+		 */
+		if (dep.classId == ProcedureRelationId)
+			AlterProcSearchPathIfNeeded(dep.objectId, dep_oldNspOid, nspOid);
 	}
 
 	/* report old schema, if caller wants it */
@@ -3115,7 +3144,7 @@ ExecAlterExtensionStmt(ParseState *pstate, AlterExtensionStmt *stmt)
 	 */
 	ApplyExtensionUpdates(extensionOid, control,
 						  oldVersionName, updateVersions,
-						  NULL, false, false);
+						  NULL, false, false, false);
 
 	ObjectAddressSet(address, ExtensionRelationId, extensionOid);
 
@@ -3137,7 +3166,8 @@ ApplyExtensionUpdates(Oid extensionOid,
 					  List *updateVersions,
 					  char *origSchemaName,
 					  bool cascade,
-					  bool is_create)
+					  bool is_create,
+					  bool set_search_path)
 {
 	const char *oldVersionName = initialVersion;
 	ListCell   *lcv;
@@ -3232,7 +3262,8 @@ ApplyExtensionUpdates(Oid extensionOid,
 											origSchemaName,
 											cascade,
 											NIL,
-											is_create);
+											is_create,
+											set_search_path);
 			reqschema = get_extension_schema(reqext);
 			requiredExtensions = lappend_oid(requiredExtensions, reqext);
 			requiredSchemas = lappend_oid(requiredSchemas, reqschema);
@@ -3269,7 +3300,7 @@ ApplyExtensionUpdates(Oid extensionOid,
 		execute_extension_script(extensionOid, control,
 								 oldVersionName, versionName,
 								 requiredSchemas,
-								 schemaName, schemaOid);
+								 schemaName, schemaOid, set_search_path);
 
 		/*
 		 * Update prior-version name and loop around.  Since
diff --git a/src/backend/commands/functioncmds.c b/src/backend/commands/functioncmds.c
index 6593fd7d81..aab39676c0 100644
--- a/src/backend/commands/functioncmds.c
+++ b/src/backend/commands/functioncmds.c
@@ -34,6 +34,7 @@
 
 #include "access/htup_details.h"
 #include "access/table.h"
+#include "access/xact.h"
 #include "catalog/catalog.h"
 #include "catalog/dependency.h"
 #include "catalog/indexing.h"
@@ -52,6 +53,7 @@
 #include "executor/functions.h"
 #include "funcapi.h"
 #include "miscadmin.h"
+#include "nodes/makefuncs.h"
 #include "nodes/nodeFuncs.h"
 #include "optimizer/optimizer.h"
 #include "parser/analyze.h"
@@ -71,6 +73,7 @@
 #include "utils/snapmgr.h"
 #include "utils/syscache.h"
 #include "utils/typcache.h"
+#include "utils/varlena.h"
 
 /*
  *	 Examine the RETURNS clause of the CREATE FUNCTION statement
@@ -705,6 +708,69 @@ interpret_func_support(DefElem *defel)
 	return procOid;
 }
 
+static bool
+IsSearchPathSet(List *set_items)
+{
+	ListCell   *l;
+
+	foreach(l, set_items)
+	{
+		VariableSetStmt *sstmt = lfirst_node(VariableSetStmt, l);
+
+		if (pg_strcasecmp(sstmt->name, "search_path") == 0 &&
+			sstmt->kind == VAR_SET_VALUE)
+			return true;
+	}
+
+	return false;
+}
+
+/*
+ * Prepare set_items from searchPath, replacing old schema with new schema
+ * if needed.
+ *
+ * Set need_update flag to true, if old schema got replaced with a
+ * new one.
+ */
+static List *
+prepare_sp_set_items(char *searchPath, Oid old_nspOid, Oid new_nspOid,
+					 bool *need_update)
+{
+	VariableSetStmt *sp = makeNode(VariableSetStmt);
+	List	*sp_set_items = NIL;
+	List    *schemaList;
+	ListCell    *lc;
+	bool	replace_schema = false;
+
+	sp->kind = VAR_SET_VALUE;
+	sp->name = "search_path";
+
+	(void) SplitIdentifierString(searchPath, ',', &schemaList);
+
+	foreach(lc, schemaList)
+	{
+		char *schema_name = lfirst(lc);
+
+		if (OidIsValid(old_nspOid) && OidIsValid(new_nspOid) &&
+			old_nspOid == LookupNamespaceNoError(schema_name))
+		{
+			replace_schema = true;
+			if (need_update)
+				*need_update = true;
+		}
+
+		sp->args = list_append_unique(sp->args,
+						   makeStringConst(replace_schema ?
+										   get_namespace_name(new_nspOid) :
+										   pstrdup(schema_name), -1));
+
+		replace_schema = false;
+	}
+
+	sp_set_items = lappend(sp_set_items, sp);
+
+	return sp_set_items;
+}
 
 /*
  * Dissect the list of options assembled in gram.y into function
@@ -726,7 +792,8 @@ compute_function_attributes(ParseState *pstate,
 							float4 *procost,
 							float4 *prorows,
 							Oid *prosupport,
-							char *parallel_p)
+							char *parallel_p,
+							Oid namespaceId)
 {
 	ListCell   *option;
 	DefElem    *as_item = NULL;
@@ -813,6 +880,32 @@ compute_function_attributes(ParseState *pstate,
 		*security_definer = boolVal(security_item->arg);
 	if (leakproof_item)
 		*leakproof_p = boolVal(leakproof_item->arg);
+
+	/*
+	 * If user has specified SET SEARCH_PATH option with CREATE EXTENSION command,
+	 * we ensure that the function(s) created by an extension have search_path set
+	 * to trusted schemas, which includes the schema where the function is being
+	 * created and the search_path set by the extension. See
+	 * execute_extension_script() for search_path set by the extension.
+	 */
+	if (creating_extension && create_extension_set_search_path)
+	{
+		/* If the search_path is already set, there is nothing to do. */
+		if (!set_items || !IsSearchPathSet(set_items))
+		{
+			StringInfoData sp;
+
+			initStringInfo(&sp);
+			appendStringInfoString(&sp, get_namespace_name(namespaceId));
+			appendStringInfoString(&sp, ", ");
+			appendStringInfoString(&sp, create_extension_search_path);
+
+			set_items = prepare_sp_set_items(sp.data, InvalidOid, InvalidOid,
+											 NULL);
+			pfree(sp.data);
+		}
+	}
+
 	if (set_items)
 		*proconfig = update_proconfig_value(NULL, set_items);
 	if (cost_item)
@@ -1079,7 +1172,7 @@ CreateFunction(ParseState *pstate, CreateFunctionStmt *stmt)
 								&isWindowFunc, &volatility,
 								&isStrict, &security, &isLeakProof,
 								&proconfig, &procost, &prorows,
-								&prosupport, &parallel);
+								&prosupport, &parallel, namespaceId);
 
 	if (!language)
 	{
@@ -1513,6 +1606,90 @@ AlterFunction(ParseState *pstate, AlterFunctionStmt *stmt)
 	return address;
 }
 
+/*
+ * Updates proconfig search_path parameter for the functions and procedures
+ * created by an extension when ALTER EXTENSION SET SCHEMA is executed.
+ */
+void
+AlterProcSearchPathIfNeeded(Oid procOid, Oid old_nspOid, Oid new_nspOid)
+{
+	HeapTuple   tup;
+	Form_pg_proc procForm;
+	Relation    rel;
+	Datum       datum;
+	bool        isnull;
+	ArrayType  *array;
+	List	   *configNames;
+	List	   *configValues;
+	ListCell   *lc1,
+			   *lc2;
+
+	/* Advance command counter so that new tuple can be visible */
+	CommandCounterIncrement();
+
+	rel = table_open(ProcedureRelationId, RowExclusiveLock);
+	tup = SearchSysCacheCopy1(PROCOID, ObjectIdGetDatum(procOid));
+	if (!HeapTupleIsValid(tup))
+		elog(ERROR, "cache lookup failed for function %u", procOid);
+
+	procForm = (Form_pg_proc) GETSTRUCT(tup);
+
+	if (procForm->prokind == PROKIND_AGGREGATE ||
+		procForm->prokind == PROKIND_WINDOW)
+	{
+		table_close(rel, NoLock);
+		heap_freetuple(tup);
+		return;
+	}
+
+	/* extract existing proconfig setting */
+	datum = SysCacheGetAttr(PROCOID, tup, Anum_pg_proc_proconfig, &isnull);
+	array = isnull ? NULL : DatumGetArrayTypeP(datum);
+
+	if (!isnull)
+	{
+		TransformGUCArray(array, &configNames, &configValues);
+
+		forboth(lc1, configNames, lc2, configValues)
+		{
+			char       *name = lfirst(lc1);
+			char       *value = lfirst(lc2);
+			List	   *new_sp_set_items = NIL;
+			bool		need_update = false;
+
+			if (pg_strcasecmp(name, "search_path") == 0)
+				new_sp_set_items =  prepare_sp_set_items(value, old_nspOid, new_nspOid,
+														 &need_update);
+
+			if (need_update && new_sp_set_items)
+			{
+				Datum       repl_val[Natts_pg_proc];
+				bool        repl_null[Natts_pg_proc];
+				bool        repl_repl[Natts_pg_proc];
+
+				array = update_proconfig_value(array, new_sp_set_items);
+
+				/* update the tuple */
+				memset(repl_repl, false, sizeof(repl_repl));
+				repl_repl[Anum_pg_proc_proconfig - 1] = true;
+
+				repl_val[Anum_pg_proc_proconfig - 1] = PointerGetDatum(array);
+				repl_null[Anum_pg_proc_proconfig - 1] = false;
+
+				tup = heap_modify_tuple(tup, RelationGetDescr(rel),
+										repl_val, repl_null, repl_repl);
+
+				CatalogTupleUpdate(rel, &tup->t_self, tup);
+
+				InvokeObjectPostAlterHook(ProcedureRelationId, procOid, 0);
+
+			}
+		}
+	}
+
+	table_close(rel, NoLock);
+	heap_freetuple(tup);
+}
 
 /*
  * CREATE CAST
diff --git a/src/backend/parser/gram.y b/src/backend/parser/gram.y
index 4d582950b7..dc08e62ed3 100644
--- a/src/backend/parser/gram.y
+++ b/src/backend/parser/gram.y
@@ -5200,6 +5200,19 @@ create_extension_opt_item:
 				{
 					$$ = makeDefElem("cascade", (Node *) makeBoolean(true), @1);
 				}
+			| SET IDENT
+				{
+					char *ident_name = $2;
+
+					if (strcmp(ident_name, "search_path") != 0)
+						ereport(ERROR,
+								(errcode(ERRCODE_SYNTAX_ERROR),
+								 errmsg("unrecognized option \"%s\"", $2),
+								 errhint("Only SEARCH_PATH is supported currently."),
+								 parser_errposition(@2)));
+
+					$$ = makeDefElem("search_path", (Node *) makeBoolean(true), @2);
+				}
 		;
 
 /*****************************************************************************
diff --git a/src/include/commands/defrem.h b/src/include/commands/defrem.h
index 29c511e319..29502746cd 100644
--- a/src/include/commands/defrem.h
+++ b/src/include/commands/defrem.h
@@ -53,6 +53,8 @@ extern Oid	ResolveOpClass(const List *opclass, Oid attrType,
 extern ObjectAddress CreateFunction(ParseState *pstate, CreateFunctionStmt *stmt);
 extern void RemoveFunctionById(Oid funcOid);
 extern ObjectAddress AlterFunction(ParseState *pstate, AlterFunctionStmt *stmt);
+extern void AlterProcSearchPathIfNeeded(Oid procOid, Oid old_nspOid,
+										Oid new_nspOid);
 extern ObjectAddress CreateCast(CreateCastStmt *stmt);
 extern ObjectAddress CreateTransform(CreateTransformStmt *stmt);
 extern void IsThereFunctionInNamespace(const char *proname, int pronargs,
diff --git a/src/include/commands/extension.h b/src/include/commands/extension.h
index c6f3f867eb..b1ded703de 100644
--- a/src/include/commands/extension.h
+++ b/src/include/commands/extension.h
@@ -29,6 +29,8 @@
  */
 extern PGDLLIMPORT bool creating_extension;
 extern PGDLLIMPORT Oid CurrentExtensionObject;
+extern PGDLLIMPORT char *create_extension_search_path;
+extern PGDLLIMPORT bool create_extension_set_search_path;
 
 
 extern ObjectAddress CreateExtension(ParseState *pstate, CreateExtensionStmt *stmt);
-- 
2.17.1

