From 80e19af369d730722e531f9ea3d6e32ec71558e5 Mon Sep 17 00:00:00 2001
From: Ashutosh Sharma <asharma@microsoft.com>
Date: Wed, 24 Jul 2024 11:48:15 +0000
Subject: [PATCH] Introduce new control file parameter 'protected' to define
 implicit search_path for extension functions.

When enabled, this parameter defines the implicit search_path for
functions and procedures created by extensions if no explicit
search_path is specified. It includes $extension_schema, pg_temp, and
function_schema (if different from the extension's schema). Here
$extension_schema is a special name that dynamically resolves to all
schemas on which the extension depends. This resolution occurs at the
time function or procedure execution.
---
 doc/src/sgml/extend.sgml            | 23 +++++++--
 src/backend/catalog/namespace.c     | 32 +++++++++++++
 src/backend/catalog/pg_depend.c     | 43 +++++++++++++++++
 src/backend/commands/extension.c    | 65 ++++++++++++++++++++-----
 src/backend/commands/functioncmds.c | 73 ++++++++++++++++++++++++++++-
 src/backend/utils/fmgr/fmgr.c       | 16 +++++++
 src/include/catalog/dependency.h    |  1 +
 src/include/commands/extension.h    |  5 ++
 8 files changed, 241 insertions(+), 17 deletions(-)

diff --git a/doc/src/sgml/extend.sgml b/doc/src/sgml/extend.sgml
index 218940ee5c..a319119a02 100644
--- a/doc/src/sgml/extend.sgml
+++ b/doc/src/sgml/extend.sgml
@@ -822,6 +822,21 @@ RETURNS anycompatible AS ...
        </para>
       </listitem>
      </varlistentry>
+
+    <varlistentry id="extend-extensions-files-protected">
+      <term><varname>protected</varname> (<type>boolean</type>)</term>
+      <listitem>
+       <para>
+        This parameter, if set to true (which is not the default), defines the
+        implicit search_path for functions and procedures created by the
+        extension. It sets the <varname>search_path</varname> to
+        <literal>$extension_schema</literal>, <literal>pg_temp</literal>, where
+        <literal>$extension_schema</literal> is a special name that dynamically
+        resolves to all schemas on which the extension depends. This resolution
+        occurs at the time of function or procedure execution.
+       </para>
+      </listitem>
+     </varlistentry>
     </variablelist>
 
     <para>
@@ -1288,10 +1303,10 @@ SELECT * FROM pg_extension_update_paths('<replaceable>extension_name</replaceabl
           PostgreSQL contained no such defect. -->
      <para>
       If you cannot set the <varname>search_path</varname> to contain only
-      secure schemas, assume that each unqualified name could resolve to an
-      object that a malicious user has defined.  Beware of constructs that
-      depend on <varname>search_path</varname> implicitly; for
-      example, <token>IN</token>
+      secure schemas, or mark the extension as protected, then assume that each
+      unqualified name could resolve to an object that a malicious user has
+      defined.  Beware of constructs that depend on
+      <varname>search_path</varname> implicitly; for example, <token>IN</token>
       and <literal>CASE <replaceable>expression</replaceable> WHEN</literal>
       always select an operator using the search path.  In their place, use
       <literal>OPERATOR(<replaceable>schema</replaceable>.=) ANY</literal>
diff --git a/src/backend/catalog/namespace.c b/src/backend/catalog/namespace.c
index 43b707699d..05fca3354c 100644
--- a/src/backend/catalog/namespace.c
+++ b/src/backend/catalog/namespace.c
@@ -42,6 +42,7 @@
 #include "catalog/pg_ts_template.h"
 #include "catalog/pg_type.h"
 #include "commands/dbcommands.h"
+#include "commands/extension.h"
 #include "common/hashfn_unstable.h"
 #include "funcapi.h"
 #include "mb/pg_wchar.h"
@@ -4152,6 +4153,37 @@ preprocessNamespacePath(const char *searchPath, Oid roleid,
 					*temp_missing = true;
 			}
 		}
+		else if (strcmp(curname, "$extension_schema") == 0)
+		{
+			/*
+			 * $extension_schema --- substitute namespace on which the extension
+			 * depends, if executing functions or procedures related to an
+			 * extension that has search_path set in its proconfig to
+			 * $extension_schema; otherwise, skip.
+			 */
+			Oid			extOid = GetCurrentExtensionId();
+			List	   *extList;
+			ListCell   *lc;
+
+			if (!OidIsValid(extOid))
+				continue;
+
+			extList = getExtensionsOfExtension(extOid);
+			extList = lappend_oid(extList, extOid);
+
+			foreach(lc, extList)
+			{
+				extOid = lfirst_oid(lc);
+
+				namespaceId = get_extension_schema(extOid);
+				if (OidIsValid(namespaceId) &&
+					object_aclcheck(NamespaceRelationId, namespaceId, roleid,
+									ACL_USAGE) == ACLCHECK_OK)
+					oidlist = lappend_oid(oidlist, namespaceId);
+			}
+
+			list_free(extList);
+		}
 		else
 		{
 			/* normal namespace reference */
diff --git a/src/backend/catalog/pg_depend.c b/src/backend/catalog/pg_depend.c
index cfd7ef51df..8a7f071c00 100644
--- a/src/backend/catalog/pg_depend.c
+++ b/src/backend/catalog/pg_depend.c
@@ -814,6 +814,49 @@ getAutoExtensionsOfObject(Oid classId, Oid objectId)
 	return result;
 }
 
+/*
+ * Return (possibly NIL) list of extensions that the given extension depends on
+ * in DEPENDENCY_NORMAL mode.
+ */
+List *
+getExtensionsOfExtension(Oid objectId)
+{
+	List	   *result = NIL;
+	Relation	depRel;
+	ScanKeyData key[2];
+	SysScanDesc scan;
+	HeapTuple	tup;
+
+	depRel = table_open(DependRelationId, AccessShareLock);
+
+	ScanKeyInit(&key[0],
+				Anum_pg_depend_classid,
+				BTEqualStrategyNumber, F_OIDEQ,
+				ObjectIdGetDatum(ExtensionRelationId));
+	ScanKeyInit(&key[1],
+				Anum_pg_depend_objid,
+				BTEqualStrategyNumber, F_OIDEQ,
+				ObjectIdGetDatum(objectId));
+
+	scan = systable_beginscan(depRel, DependDependerIndexId, true,
+							  NULL, 2, key);
+
+	while (HeapTupleIsValid((tup = systable_getnext(scan))))
+	{
+		Form_pg_depend depform = (Form_pg_depend) GETSTRUCT(tup);
+
+		if (depform->refclassid == ExtensionRelationId &&
+			depform->deptype == DEPENDENCY_NORMAL)
+			result = lappend_oid(result, depform->refobjid);
+	}
+
+	systable_endscan(scan);
+
+	table_close(depRel, AccessShareLock);
+
+	return result;
+}
+
 /*
  * Detect whether a sequence is marked as "owned" by a column
  *
diff --git a/src/backend/commands/extension.c b/src/backend/commands/extension.c
index 1643c8c69a..2b4f52d8be 100644
--- a/src/backend/commands/extension.c
+++ b/src/backend/commands/extension.c
@@ -70,6 +70,8 @@
 /* Globally visible state variables */
 bool		creating_extension = false;
 Oid			CurrentExtensionObject = InvalidOid;
+bool		create_extension_set_search_path = false;
+Oid			CurrentExtensionId = InvalidOid;
 
 /*
  * Internal data structure to hold the results of parsing a control file
@@ -86,6 +88,8 @@ typedef struct ExtensionControlFile
 	bool		relocatable;	/* is ALTER EXTENSION SET SCHEMA supported? */
 	bool		superuser;		/* must be superuser to install? */
 	bool		trusted;		/* allow becoming superuser on the fly? */
+	bool		protected;		/* should we protect extension by setting implicit
+								 * search_path for functions and procedures? */
 	int			encoding;		/* encoding of the script file, or -1 */
 	List	   *requires;		/* names of prerequisite extensions */
 	List	   *no_relocate;	/* names of prerequisite extensions that
@@ -117,7 +121,8 @@ static Oid	get_required_extension(char *reqExtensionName,
 								   char *origSchemaName,
 								   bool cascade,
 								   List *parents,
-								   bool is_create);
+								   bool is_create,
+								   bool set_search_path);
 static void get_available_versions_for_extension(ExtensionControlFile *pcontrol,
 												 Tuplestorestate *tupstore,
 												 TupleDesc tupdesc);
@@ -128,12 +133,30 @@ static void ApplyExtensionUpdates(Oid extensionOid,
 								  List *updateVersions,
 								  char *origSchemaName,
 								  bool cascade,
-								  bool is_create);
+								  bool is_create,
+								  bool set_search_path);
 static void ExecAlterExtensionContentsRecurse(AlterExtensionContentsStmt *stmt,
 											  ObjectAddress extension,
 											  ObjectAddress object);
 static char *read_whole_file(const char *filename, int *length);
 
+/*
+ * SetCurrentExtensionId - Set the current extension Oid.
+ */
+void
+SetCurrentExtensionId(Oid extensionOid)
+{
+	CurrentExtensionId = extensionOid;
+}
+
+/*
+ * GetCurrentExtensionId - Get the current extension Oid.
+ */
+Oid
+GetCurrentExtensionId()
+{
+	return CurrentExtensionId;
+}
 
 /*
  * get_extension_oid - given an extension name, look up the OID
@@ -585,6 +608,14 @@ parse_extension_control_file(ExtensionControlFile *control,
 						 errmsg("parameter \"%s\" requires a Boolean value",
 								item->name)));
 		}
+		else if (strcmp(item->name, "protected") == 0)
+		{
+			if (!parse_bool(item->value, &control->protected))
+				ereport(ERROR,
+						(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
+						 errmsg("parameter \"%s\" requires a Boolean value",
+								item->name)));
+		}
 		else if (strcmp(item->name, "encoding") == 0)
 		{
 			control->encoding = pg_valid_server_encoding(item->value);
@@ -871,7 +902,8 @@ execute_extension_script(Oid extensionOid, ExtensionControlFile *control,
 						 const char *from_version,
 						 const char *version,
 						 List *requiredSchemas,
-						 const char *schemaName, Oid schemaOid)
+						 const char *schemaName, Oid schemaOid,
+						 bool set_search_path)
 {
 	bool		switch_to_superuser = false;
 	char	   *filename;
@@ -992,6 +1024,7 @@ execute_extension_script(Oid extensionOid, ExtensionControlFile *control,
 	 */
 	creating_extension = true;
 	CurrentExtensionObject = extensionOid;
+	create_extension_set_search_path = set_search_path;
 	PG_TRY();
 	{
 		char	   *c_sql = read_extension_script_file(control, filename);
@@ -1116,6 +1149,7 @@ execute_extension_script(Oid extensionOid, ExtensionControlFile *control,
 	{
 		creating_extension = false;
 		CurrentExtensionObject = InvalidOid;
+		create_extension_set_search_path = false;
 	}
 	PG_END_TRY();
 
@@ -1475,6 +1509,7 @@ CreateExtensionInternal(char *extensionName,
 	Oid			extensionOid;
 	ObjectAddress address;
 	ListCell   *lc;
+	bool		set_search_path = false;
 
 	/*
 	 * Read the primary control file.  Note we assume that it does not contain
@@ -1542,6 +1577,10 @@ CreateExtensionInternal(char *extensionName,
 	 */
 	control = read_extension_aux_control_file(pcontrol, versionName);
 
+	/* Check if this extension requires protection */
+	if (control->protected)
+		set_search_path = true;
+
 	/*
 	 * Determine the target schema to install the extension into
 	 */
@@ -1648,7 +1687,8 @@ CreateExtensionInternal(char *extensionName,
 										origSchemaName,
 										cascade,
 										parents,
-										is_create);
+										is_create,
+										set_search_path);
 		reqschema = get_extension_schema(reqext);
 		requiredExtensions = lappend_oid(requiredExtensions, reqext);
 		requiredSchemas = lappend_oid(requiredSchemas, reqschema);
@@ -1677,7 +1717,7 @@ CreateExtensionInternal(char *extensionName,
 	execute_extension_script(extensionOid, control,
 							 NULL, versionName,
 							 requiredSchemas,
-							 schemaName, schemaOid);
+							 schemaName, schemaOid, set_search_path);
 
 	/*
 	 * If additional update scripts have to be executed, apply the updates as
@@ -1685,7 +1725,7 @@ CreateExtensionInternal(char *extensionName,
 	 */
 	ApplyExtensionUpdates(extensionOid, pcontrol,
 						  versionName, updateVersions,
-						  origSchemaName, cascade, is_create);
+						  origSchemaName, cascade, is_create, set_search_path);
 
 	return address;
 }
@@ -1699,7 +1739,8 @@ get_required_extension(char *reqExtensionName,
 					   char *origSchemaName,
 					   bool cascade,
 					   List *parents,
-					   bool is_create)
+					   bool is_create,
+					   bool set_search_path)
 {
 	Oid			reqExtensionOid;
 
@@ -3115,7 +3156,7 @@ ExecAlterExtensionStmt(ParseState *pstate, AlterExtensionStmt *stmt)
 	 */
 	ApplyExtensionUpdates(extensionOid, control,
 						  oldVersionName, updateVersions,
-						  NULL, false, false);
+						  NULL, false, false, false);
 
 	ObjectAddressSet(address, ExtensionRelationId, extensionOid);
 
@@ -3137,7 +3178,8 @@ ApplyExtensionUpdates(Oid extensionOid,
 					  List *updateVersions,
 					  char *origSchemaName,
 					  bool cascade,
-					  bool is_create)
+					  bool is_create,
+					  bool set_search_path)
 {
 	const char *oldVersionName = initialVersion;
 	ListCell   *lcv;
@@ -3232,7 +3274,8 @@ ApplyExtensionUpdates(Oid extensionOid,
 											origSchemaName,
 											cascade,
 											NIL,
-											is_create);
+											is_create,
+											set_search_path);
 			reqschema = get_extension_schema(reqext);
 			requiredExtensions = lappend_oid(requiredExtensions, reqext);
 			requiredSchemas = lappend_oid(requiredSchemas, reqschema);
@@ -3269,7 +3312,7 @@ ApplyExtensionUpdates(Oid extensionOid,
 		execute_extension_script(extensionOid, control,
 								 oldVersionName, versionName,
 								 requiredSchemas,
-								 schemaName, schemaOid);
+								 schemaName, schemaOid, set_search_path);
 
 		/*
 		 * Update prior-version name and loop around.  Since
diff --git a/src/backend/commands/functioncmds.c b/src/backend/commands/functioncmds.c
index 6593fd7d81..79764f2996 100644
--- a/src/backend/commands/functioncmds.c
+++ b/src/backend/commands/functioncmds.c
@@ -52,6 +52,7 @@
 #include "executor/functions.h"
 #include "funcapi.h"
 #include "miscadmin.h"
+#include "nodes/makefuncs.h"
 #include "nodes/nodeFuncs.h"
 #include "optimizer/optimizer.h"
 #include "parser/analyze.h"
@@ -71,6 +72,7 @@
 #include "utils/snapmgr.h"
 #include "utils/syscache.h"
 #include "utils/typcache.h"
+#include "utils/varlena.h"
 
 /*
  *	 Examine the RETURNS clause of the CREATE FUNCTION statement
@@ -705,6 +707,25 @@ interpret_func_support(DefElem *defel)
 	return procOid;
 }
 
+/*
+ * Returns true if search_path is set in set_items list.
+ */
+static bool
+IsSearchPathSet(List *set_items)
+{
+	ListCell   *l;
+
+	foreach(l, set_items)
+	{
+		VariableSetStmt *sstmt = lfirst_node(VariableSetStmt, l);
+
+		if (pg_strcasecmp(sstmt->name, "search_path") == 0 &&
+			sstmt->kind == VAR_SET_VALUE)
+			return true;
+	}
+
+	return false;
+}
 
 /*
  * Dissect the list of options assembled in gram.y into function
@@ -726,7 +747,8 @@ compute_function_attributes(ParseState *pstate,
 							float4 *procost,
 							float4 *prorows,
 							Oid *prosupport,
-							char *parallel_p)
+							char *parallel_p,
+							Oid namespaceId)
 {
 	ListCell   *option;
 	DefElem    *as_item = NULL;
@@ -813,6 +835,53 @@ compute_function_attributes(ParseState *pstate,
 		*security_definer = boolVal(security_item->arg);
 	if (leakproof_item)
 		*leakproof_p = boolVal(leakproof_item->arg);
+
+	/*
+	 * If "create_extension_set_search_path" is enabled, it indicates that the
+	 * user has set "protected" flag inside the extension control file.
+	 * Therefore, we must ensure that the function(s) created by an extension
+	 * have their search_path set to trusted schema(s), which includes the
+	 * schema where the function is being created and the search_path set by the
+	 * extension. See execute_extension_script() for details on search_path set
+	 * by the extension.
+	 */
+	if (creating_extension && create_extension_set_search_path)
+	{
+		/* If the search_path is already set, there is nothing to do. */
+		if (!set_items || !IsSearchPathSet(set_items))
+		{
+			StringInfoData sp_string;
+			VariableSetStmt *sp_node = makeNode(VariableSetStmt);
+			List	 *schemaList;
+			ListCell *lc;
+
+			sp_node->kind = VAR_SET_VALUE;
+			sp_node->name = "search_path";
+
+			initStringInfo(&sp_string);
+
+			if (namespaceId != get_extension_schema(CurrentExtensionObject))
+			{
+				appendStringInfoString(&sp_string, get_namespace_name(namespaceId));
+				appendStringInfoString(&sp_string, ", ");
+			}
+			appendStringInfoString(&sp_string, "$extension_schema, pg_temp");
+
+			(void) SplitIdentifierString(sp_string.data, ',', &schemaList);
+
+			foreach(lc, schemaList)
+			{
+				char *schema_name = lfirst(lc);
+
+				sp_node->args = lappend(sp_node->args,
+										makeStringConst(pstrdup(schema_name), -1));
+			}
+
+			set_items = lappend(set_items, sp_node);
+			pfree(sp_string.data);
+		}
+	}
+
 	if (set_items)
 		*proconfig = update_proconfig_value(NULL, set_items);
 	if (cost_item)
@@ -1079,7 +1148,7 @@ CreateFunction(ParseState *pstate, CreateFunctionStmt *stmt)
 								&isWindowFunc, &volatility,
 								&isStrict, &security, &isLeakProof,
 								&proconfig, &procost, &prorows,
-								&prosupport, &parallel);
+								&prosupport, &parallel, namespaceId);
 
 	if (!language)
 	{
diff --git a/src/backend/utils/fmgr/fmgr.c b/src/backend/utils/fmgr/fmgr.c
index e48a86be54..e2211c82f3 100644
--- a/src/backend/utils/fmgr/fmgr.c
+++ b/src/backend/utils/fmgr/fmgr.c
@@ -16,6 +16,8 @@
 #include "postgres.h"
 
 #include "access/detoast.h"
+#include "commands/extension.h"
+#include "catalog/dependency.h"
 #include "catalog/pg_language.h"
 #include "catalog/pg_proc.h"
 #include "catalog/pg_type.h"
@@ -641,6 +643,15 @@ fmgr_security_definer(PG_FUNCTION_ARGS)
 			   *lc3;
 	volatile int save_nestlevel;
 	PgStat_FunctionCallUsage fcusage;
+	Oid			extensionOid = InvalidOid;
+
+	/*
+	 * Let's check if this is an extension created function. If it is, we'll set
+	 * the CurrentExtensionId before calling it, so that preprocessNamespacePath
+	 * can handle $extension_schema correctly.
+	 */
+	extensionOid = getExtensionOfObject(ProcedureRelationId,
+										fcinfo->flinfo->fn_oid);
 
 	if (!fcinfo->flinfo->fn_extra)
 	{
@@ -737,6 +748,9 @@ fmgr_security_definer(PG_FUNCTION_ARGS)
 	 */
 	save_flinfo = fcinfo->flinfo;
 
+	if (OidIsValid(extensionOid))
+		SetCurrentExtensionId(extensionOid);
+
 	PG_TRY();
 	{
 		fcinfo->flinfo = &fcache->flinfo;
@@ -758,6 +772,7 @@ fmgr_security_definer(PG_FUNCTION_ARGS)
 	PG_CATCH();
 	{
 		fcinfo->flinfo = save_flinfo;
+		SetCurrentExtensionId(InvalidOid);
 		if (fmgr_hook)
 			(*fmgr_hook) (FHET_ABORT, &fcache->flinfo, &fcache->arg);
 		PG_RE_THROW();
@@ -765,6 +780,7 @@ fmgr_security_definer(PG_FUNCTION_ARGS)
 	PG_END_TRY();
 
 	fcinfo->flinfo = save_flinfo;
+	SetCurrentExtensionId(InvalidOid);
 
 	if (fcache->configNames != NIL)
 		AtEOXact_GUC(true, save_nestlevel);
diff --git a/src/include/catalog/dependency.h b/src/include/catalog/dependency.h
index 6908ca7180..1055c2f784 100644
--- a/src/include/catalog/dependency.h
+++ b/src/include/catalog/dependency.h
@@ -174,6 +174,7 @@ extern long changeDependenciesOn(Oid refClassId, Oid oldRefObjectId,
 
 extern Oid	getExtensionOfObject(Oid classId, Oid objectId);
 extern List *getAutoExtensionsOfObject(Oid classId, Oid objectId);
+extern List *getExtensionsOfExtension(Oid objectId);
 
 extern bool sequenceIsOwned(Oid seqId, char deptype, Oid *tableId, int32 *colId);
 extern List *getOwnedSequences(Oid relid);
diff --git a/src/include/commands/extension.h b/src/include/commands/extension.h
index c6f3f867eb..9512e8109c 100644
--- a/src/include/commands/extension.h
+++ b/src/include/commands/extension.h
@@ -29,6 +29,8 @@
  */
 extern PGDLLIMPORT bool creating_extension;
 extern PGDLLIMPORT Oid CurrentExtensionObject;
+extern PGDLLIMPORT bool create_extension_set_search_path;
+extern PGDLLIMPORT Oid CurrentExtensionId;
 
 
 extern ObjectAddress CreateExtension(ParseState *pstate, CreateExtensionStmt *stmt);
@@ -53,4 +55,7 @@ extern bool extension_file_exists(const char *extensionName);
 extern ObjectAddress AlterExtensionNamespace(const char *extensionName, const char *newschema,
 											 Oid *oldschema);
 
+extern void SetCurrentExtensionId(Oid extensionOid);
+extern Oid GetCurrentExtensionId(void);
+
 #endif							/* EXTENSION_H */
-- 
2.17.1

