From 32ddafea48d639e5fc5f82d81ff306979c300539 Mon Sep 17 00:00:00 2001
From: Nikita Malakhov <n.malakhov@postgrespro.ru>
Date: Mon, 30 Oct 2023 06:26:56 +0300
Subject: [PATCH] WIP patch - introducing 'proerrsafe' attribute into PG_PROC

Attribute 'proerrsafe' added to PG_PROC relation, to differ procedures
which have error-safe behavior (for use in SQL/JSON patches), with
support functions and passing from create/alter procedure clauses.
---
 src/backend/catalog/pg_aggregate.c  |  1 +
 src/backend/catalog/pg_proc.c       |  2 ++
 src/backend/commands/functioncmds.c | 31 +++++++++++++++++++---
 src/backend/commands/typecmds.c     |  4 +++
 src/backend/utils/cache/lsyscache.c | 40 +++++++++++++++++++++++++++++
 src/include/catalog/pg_proc.h       |  4 +++
 src/include/utils/lsyscache.h       |  4 +++
 7 files changed, 82 insertions(+), 4 deletions(-)

diff --git a/src/backend/catalog/pg_aggregate.c b/src/backend/catalog/pg_aggregate.c
index ebc4454743..4beb2e8c41 100644
--- a/src/backend/catalog/pg_aggregate.c
+++ b/src/backend/catalog/pg_aggregate.c
@@ -628,6 +628,7 @@ AggregateCreate(const char *aggName,
 									 * definable for agg) */
 							 false, /* isLeakProof */
 							 false, /* isStrict (not needed for agg) */
+							 false, /* isErrorSafe */
 							 PROVOLATILE_IMMUTABLE, /* volatility (not needed
 													 * for agg) */
 							 proparallel,
diff --git a/src/backend/catalog/pg_proc.c b/src/backend/catalog/pg_proc.c
index b5fd364003..4681d01c21 100644
--- a/src/backend/catalog/pg_proc.c
+++ b/src/backend/catalog/pg_proc.c
@@ -84,6 +84,7 @@ ProcedureCreate(const char *procedureName,
 				bool security_definer,
 				bool isLeakProof,
 				bool isStrict,
+				bool isErrorSafe,
 				char volatility,
 				char parallel,
 				oidvector *parameterTypes,
@@ -311,6 +312,7 @@ ProcedureCreate(const char *procedureName,
 	values[Anum_pg_proc_prosecdef - 1] = BoolGetDatum(security_definer);
 	values[Anum_pg_proc_proleakproof - 1] = BoolGetDatum(isLeakProof);
 	values[Anum_pg_proc_proisstrict - 1] = BoolGetDatum(isStrict);
+	values[Anum_pg_proc_proerrsafe - 1] = BoolGetDatum(isErrorSafe);
 	values[Anum_pg_proc_proretset - 1] = BoolGetDatum(returnsSet);
 	values[Anum_pg_proc_provolatile - 1] = CharGetDatum(volatility);
 	values[Anum_pg_proc_proparallel - 1] = CharGetDatum(parallel);
diff --git a/src/backend/commands/functioncmds.c b/src/backend/commands/functioncmds.c
index 7ba6a86ebe..3037ca9db3 100644
--- a/src/backend/commands/functioncmds.c
+++ b/src/backend/commands/functioncmds.c
@@ -513,7 +513,8 @@ compute_common_attribute(ParseState *pstate,
 						 DefElem **cost_item,
 						 DefElem **rows_item,
 						 DefElem **support_item,
-						 DefElem **parallel_item)
+						 DefElem **parallel_item,
+						 DefElem **errsafe_item)
 {
 	if (strcmp(defel->defname, "volatility") == 0)
 	{
@@ -589,6 +590,15 @@ compute_common_attribute(ParseState *pstate,
 
 		*parallel_item = defel;
 	}
+	else if (strcmp(defel->defname, "errorsafe") == 0)
+	{
+		if (is_procedure)
+			goto procedure_error;
+		if (*errsafe_item)
+			errorConflictingDefElem(defel, pstate);
+
+		*errsafe_item = defel;
+	}
 	else
 		return false;
 
@@ -727,6 +737,7 @@ compute_function_attributes(ParseState *pstate,
 							bool *strict_p,
 							bool *security_definer,
 							bool *leakproof_p,
+							bool *errsafe_p,
 							ArrayType **proconfig,
 							float4 *procost,
 							float4 *prorows,
@@ -747,6 +758,7 @@ compute_function_attributes(ParseState *pstate,
 	DefElem    *rows_item = NULL;
 	DefElem    *support_item = NULL;
 	DefElem    *parallel_item = NULL;
+	DefElem    *errsafe_item = NULL;
 
 	foreach(option, options)
 	{
@@ -792,7 +804,8 @@ compute_function_attributes(ParseState *pstate,
 										  &cost_item,
 										  &rows_item,
 										  &support_item,
-										  &parallel_item))
+										  &parallel_item,
+										  &errsafe_item))
 		{
 			/* recognized common option */
 			continue;
@@ -814,6 +827,8 @@ compute_function_attributes(ParseState *pstate,
 		*volatility_p = interpret_func_volatility(volatility_item);
 	if (strict_item)
 		*strict_p = boolVal(strict_item->arg);
+	if (errsafe_item)
+		*errsafe_p = boolVal(errsafe_item->arg);
 	if (security_item)
 		*security_definer = boolVal(security_item->arg);
 	if (leakproof_item)
@@ -1041,7 +1056,8 @@ CreateFunction(ParseState *pstate, CreateFunctionStmt *stmt)
 	bool		isWindowFunc,
 				isStrict,
 				security,
-				isLeakProof;
+				isLeakProof,
+				isErrorSafe;
 	char		volatility;
 	ArrayType  *proconfig;
 	float4		procost;
@@ -1067,6 +1083,7 @@ CreateFunction(ParseState *pstate, CreateFunctionStmt *stmt)
 	language = NULL;
 	isWindowFunc = false;
 	isStrict = false;
+	isErrorSafe = false;
 	security = false;
 	isLeakProof = false;
 	volatility = PROVOLATILE_VOLATILE;
@@ -1083,6 +1100,7 @@ CreateFunction(ParseState *pstate, CreateFunctionStmt *stmt)
 								&as_clause, &language, &transformDefElem,
 								&isWindowFunc, &volatility,
 								&isStrict, &security, &isLeakProof,
+								&isErrorSafe,
 								&proconfig, &procost, &prorows,
 								&prosupport, &parallel);
 
@@ -1274,6 +1292,7 @@ CreateFunction(ParseState *pstate, CreateFunctionStmt *stmt)
 						   security,
 						   isLeakProof,
 						   isStrict,
+							isErrorSafe,
 						   volatility,
 						   parallel,
 						   parameterTypes,
@@ -1362,6 +1381,7 @@ AlterFunction(ParseState *pstate, AlterFunctionStmt *stmt)
 	DefElem    *rows_item = NULL;
 	DefElem    *support_item = NULL;
 	DefElem    *parallel_item = NULL;
+	DefElem    *errsafe_item = NULL;
 	ObjectAddress address;
 
 	rel = table_open(ProcedureRelationId, RowExclusiveLock);
@@ -1405,7 +1425,8 @@ AlterFunction(ParseState *pstate, AlterFunctionStmt *stmt)
 									 &cost_item,
 									 &rows_item,
 									 &support_item,
-									 &parallel_item) == false)
+									 &parallel_item,
+									 &errsafe_item) == false)
 			elog(ERROR, "option \"%s\" not recognized", defel->defname);
 	}
 
@@ -1413,6 +1434,8 @@ AlterFunction(ParseState *pstate, AlterFunctionStmt *stmt)
 		procForm->provolatile = interpret_func_volatility(volatility_item);
 	if (strict_item)
 		procForm->proisstrict = boolVal(strict_item->arg);
+	if (errsafe_item)
+		procForm->proerrsafe = boolVal(errsafe_item->arg);
 	if (security_def_item)
 		procForm->prosecdef = boolVal(security_def_item->arg);
 	if (leakproof_item)
diff --git a/src/backend/commands/typecmds.c b/src/backend/commands/typecmds.c
index 5e97606793..0768162671 100644
--- a/src/backend/commands/typecmds.c
+++ b/src/backend/commands/typecmds.c
@@ -1766,6 +1766,7 @@ makeRangeConstructors(const char *name, Oid namespace,
 								 false, /* security_definer */
 								 false, /* leakproof */
 								 false, /* isStrict */
+								 false, /* isErrorSafe */
 								 PROVOLATILE_IMMUTABLE, /* volatility */
 								 PROPARALLEL_SAFE,	/* parallel safety */
 								 constructorArgTypesVector, /* parameterTypes */
@@ -1831,6 +1832,7 @@ makeMultirangeConstructors(const char *name, Oid namespace,
 							 false, /* security_definer */
 							 false, /* leakproof */
 							 true,	/* isStrict */
+							 false, /* isErrorSafe */
 							 PROVOLATILE_IMMUTABLE, /* volatility */
 							 PROPARALLEL_SAFE,	/* parallel safety */
 							 argtypes,	/* parameterTypes */
@@ -1875,6 +1877,7 @@ makeMultirangeConstructors(const char *name, Oid namespace,
 							 false, /* security_definer */
 							 false, /* leakproof */
 							 true,	/* isStrict */
+							 false, /* isErrorSafe */
 							 PROVOLATILE_IMMUTABLE, /* volatility */
 							 PROPARALLEL_SAFE,	/* parallel safety */
 							 argtypes,	/* parameterTypes */
@@ -1913,6 +1916,7 @@ makeMultirangeConstructors(const char *name, Oid namespace,
 							 false, /* security_definer */
 							 false, /* leakproof */
 							 true,	/* isStrict */
+							 false, /* isErrorSafe */
 							 PROVOLATILE_IMMUTABLE, /* volatility */
 							 PROPARALLEL_SAFE,	/* parallel safety */
 							 argtypes,	/* parameterTypes */
diff --git a/src/backend/utils/cache/lsyscache.c b/src/backend/utils/cache/lsyscache.c
index fc6d267e44..3fef811759 100644
--- a/src/backend/utils/cache/lsyscache.c
+++ b/src/backend/utils/cache/lsyscache.c
@@ -2494,6 +2494,46 @@ get_typdefault(Oid typid)
 	return expr;
 }
 
+bool
+procIsErrorSafe(Oid funcid)
+{
+	HeapTuple	proctup;
+	Form_pg_proc procform;
+	bool res = false;
+
+	proctup = SearchSysCache1(PROCOID, ObjectIdGetDatum(funcid));
+	if (!HeapTupleIsValid(proctup))
+		elog(ERROR, "cache lookup failed for function %u", funcid);
+	procform = (Form_pg_proc) GETSTRUCT(proctup);
+
+	if (procform->proerrsafe)
+		res = true;
+
+	ReleaseSysCache(proctup);
+	return res;
+}
+
+/*
+ * isTypeinErrorSafe
+ *		Check if type input is error safe
+  */
+bool
+isTypeinErrorSafe(Oid typid)
+{
+	HeapTuple	tup;
+	Form_pg_type typTup;
+	bool res = false;
+
+	tup = SearchSysCache1(TYPEOID, ObjectIdGetDatum(typid));
+	if (!HeapTupleIsValid(tup))
+		elog(ERROR, "cache lookup failed for type %u", typid);
+	typTup = (Form_pg_type) GETSTRUCT(tup);
+	res = procIsErrorSafe((Oid) typTup->typinput);
+	ReleaseSysCache(tup);
+
+	return res;
+}
+
 /*
  * getBaseType
  *		If the given type is a domain, return its base type;
diff --git a/src/include/catalog/pg_proc.h b/src/include/catalog/pg_proc.h
index fdb39d4001..dbae644c23 100644
--- a/src/include/catalog/pg_proc.h
+++ b/src/include/catalog/pg_proc.h
@@ -70,6 +70,9 @@ CATALOG(pg_proc,1255,ProcedureRelationId) BKI_BOOTSTRAP BKI_ROWTYPE_OID(81,Proce
 	/* returns a set? */
 	bool		proretset BKI_DEFAULT(f);
 
+	/* is procedure error safe? */
+	bool		proerrsafe BKI_DEFAULT(f);
+
 	/* see PROVOLATILE_ categories below */
 	char		provolatile BKI_DEFAULT(i);
 
@@ -200,6 +203,7 @@ extern ObjectAddress ProcedureCreate(const char *procedureName,
 									 bool security_definer,
 									 bool isLeakProof,
 									 bool isStrict,
+									 bool isErrorSafe,
 									 char volatility,
 									 char parallel,
 									 oidvector *parameterTypes,
diff --git a/src/include/utils/lsyscache.h b/src/include/utils/lsyscache.h
index f5fdbfe116..e384fe6b5f 100644
--- a/src/include/utils/lsyscache.h
+++ b/src/include/utils/lsyscache.h
@@ -65,6 +65,8 @@ typedef struct AttStatsSlot
 typedef int32 (*get_attavgwidth_hook_type) (Oid relid, AttrNumber attnum);
 extern PGDLLIMPORT get_attavgwidth_hook_type get_attavgwidth_hook;
 
+extern bool procIsErrorSafe(Oid funcid);
+
 extern bool op_in_opfamily(Oid opno, Oid opfamily);
 extern int	get_op_opfamily_strategy(Oid opno, Oid opfamily);
 extern Oid	get_op_opfamily_sortfamily(Oid opno, Oid opfamily);
@@ -182,6 +184,8 @@ extern bool type_is_collatable(Oid typid);
 extern RegProcedure get_typsubscript(Oid typid, Oid *typelemp);
 extern const struct SubscriptRoutines *getSubscriptingRoutines(Oid typid,
 															   Oid *typelemp);
+
+extern bool isTypeinErrorSafe(Oid typid);
 extern Oid	getBaseType(Oid typid);
 extern Oid	getBaseTypeAndTypmod(Oid typid, int32 *typmod);
 extern int32 get_typavgwidth(Oid typid, int32 typmod);
-- 
2.25.1

