From c958582ec89ec065101f61a9b61df4cd800cb3e9 Mon Sep 17 00:00:00 2001
From: Dave Cramer <davecramer@gmail.com>
Date: Sun, 17 Jul 2022 21:40:32 -0400
Subject: [PATCH] add format_binary

fix error

working version of setting binary oids

rename binary_formats, comments

added error checking in check_format_binary

allocate the correct amount of memory in permanent storage
---
 src/backend/tcop/postgres.c      |  2 +
 src/backend/tcop/pquery.c        | 54 ++++++++++++++++---
 src/backend/utils/init/globals.c |  1 +
 src/backend/utils/misc/guc.c     | 91 ++++++++++++++++++++++++++++++++
 src/include/miscadmin.h          |  1 +
 5 files changed, 142 insertions(+), 7 deletions(-)

diff --git a/src/backend/tcop/postgres.c b/src/backend/tcop/postgres.c
index 7bec4e4ff5..8a47bc2733 100644
--- a/src/backend/tcop/postgres.c
+++ b/src/backend/tcop/postgres.c
@@ -97,6 +97,8 @@ int			PostAuthDelay = 0;
 /* Time between checks that the client is still connected. */
 int			client_connection_check_interval = 0;
 
+Oid			*binary_format_oids = NULL;
+
 /* ----------------
  *		private typedefs etc
  * ----------------
diff --git a/src/backend/tcop/pquery.c b/src/backend/tcop/pquery.c
index 5aa5a350f3..e11f55b36e 100644
--- a/src/backend/tcop/pquery.c
+++ b/src/backend/tcop/pquery.c
@@ -58,6 +58,8 @@ static uint64 DoPortalRunFetch(Portal portal,
 							   long count,
 							   DestReceiver *dest);
 static void DoPortalRewind(Portal portal);
+static int findOid( Oid *binary_oids, Oid oid);
+extern Oid  *binary_format_oids;
 
 
 /*
@@ -644,18 +646,56 @@ PortalSetResultFormat(Portal portal, int nFormats, int16 *formats)
 	}
 	else if (nFormats > 0)
 	{
-		/* single format specified, use for all columns */
-		int16		format1 = formats[0];
+		// The client has requested binary formats for some types
+		if ( binary_format_oids != NULL )
+		{
+			Oid targetOid;
 
-		for (i = 0; i < natts; i++)
-			portal->formats[i] = format1;
+			for (i = 0; i < natts; i++){
+				targetOid = portal->tupDesc->attrs[i].atttypid;
+				portal->formats[i] = findOid(binary_format_oids, targetOid);
+			}
+		}
+		else
+		{
+			/* single format specified, use for all columns */
+			int16		format1 = formats[0];
+
+			for (i = 0; i < natts; i++)
+				portal->formats[i] = format1;
+		}
 	}
 	else
 	{
-		/* use default format for all columns */
-		for (i = 0; i < natts; i++)
-			portal->formats[i] = 0;
+		if ( binary_format_oids != NULL )
+		{
+			Oid targetOid;
+
+			for (i = 0; i < natts; i++){
+				targetOid = portal->tupDesc->attrs[i].atttypid;
+				portal->formats[i] = findOid(binary_format_oids, targetOid);
+			}
+		}
+		else {
+			/* use default format for all columns */
+			for (i = 0; i < natts; i++)
+				portal->formats[i] = 0;
+		}
+	}
+}
+
+/*
+* Linear search through the array of oids.
+* I don't expect this to ever be a large array
+*/
+static int findOid( Oid *binary_oids, Oid oid)
+{
+	Oid *tmp = binary_oids;
+	while (tmp && *tmp != InvalidOid)
+	{
+		if (*tmp++ == oid) return 1;
 	}
+	return 0;
 }
 
 /*
diff --git a/src/backend/utils/init/globals.c b/src/backend/utils/init/globals.c
index 1a5d29ac9b..31188462ec 100644
--- a/src/backend/utils/init/globals.c
+++ b/src/backend/utils/init/globals.c
@@ -123,6 +123,7 @@ int			IntervalStyle = INTSTYLE_POSTGRES;
 bool		enableFsync = true;
 bool		allowSystemTableMods = false;
 int			work_mem = 4096;
+char         *format_binary = NULL;
 double		hash_mem_multiplier = 2.0;
 int			maintenance_work_mem = 65536;
 int			max_parallel_maintenance_workers = 2;
diff --git a/src/backend/utils/misc/guc.c b/src/backend/utils/misc/guc.c
index c336698ad5..beb91c5664 100644
--- a/src/backend/utils/misc/guc.c
+++ b/src/backend/utils/misc/guc.c
@@ -144,6 +144,7 @@ extern char *temp_tablespaces;
 extern bool ignore_checksum_failure;
 extern bool ignore_invalid_pages;
 extern bool synchronize_seqscans;
+extern Oid  *binary_format_oids;
 
 #ifdef TRACE_SYNCSCAN
 extern bool trace_syncscan;
@@ -243,6 +244,8 @@ static bool check_recovery_target_lsn(char **newval, void **extra, GucSource sou
 static void assign_recovery_target_lsn(const char *newval, void *extra);
 static bool check_primary_slot_name(char **newval, void **extra, GucSource source);
 static bool check_default_with_oids(bool *newval, void **extra, GucSource source);
+static bool check_format_binary(char **newval, void **extra, GucSource source);
+static void assign_format_binary(const char*newval, void *extra);
 
 /*
  * Track whether there were any deferred checks for custom resource managers
@@ -4295,6 +4298,17 @@ static struct config_string ConfigureNamesString[] =
 		"",
 		NULL, NULL, NULL
 	},
+	{
+		{"format_binary", PGC_USERSET, CLIENT_CONN_STATEMENT,
+			gettext_noop("Sets the type Oid's to be returned in binary format"),
+			gettext_noop("Set by the client to indicate which types are to be "
+						 "returned in binary format. "),
+			GUC_NOT_IN_SAMPLE | GUC_DISALLOW_IN_FILE
+		},
+		&format_binary,
+		"",
+		check_format_binary, assign_format_binary, NULL
+	},
 
 	{
 		{"search_path", PGC_USERSET, CLIENT_CONN_STATEMENT,
@@ -13333,3 +13347,80 @@ check_default_with_oids(bool *newval, void **extra, GucSource source)
 
 	return true;
 }
+
+static bool
+check_format_binary( char **newval, void **extra, GucSource source)
+{
+	// sanity check
+	if (*newval == NULL)
+		return false;
+	
+	if (strcmp(*newval,"") == 0)
+		return true;
+
+	char *tmp = palloc(strlen(*newval));
+	strcpy(tmp, *newval);
+	char *token = strtok(tmp, ",");
+
+	while(token != NULL)
+	{
+		Oid candidate = atooid(token);
+		if (candidate > OID_MAX) 
+			GUC_check_errdetail("OID out of range found in %s, %s", *newval, token);
+
+		// atooid will return 0 aka InvalidOid if it can't convert the string or 0 
+		// if it's really 0	
+		if (candidate == InvalidOid)
+		{	
+			if (errno == EINVAL)
+				GUC_check_errdetail("%s has invalid characters at %s",
+					*newval, token);
+			else
+				GUC_check_errdetail("InvalidOid (0) found in %s", *newval);
+			return false;
+		}
+		else
+			token = strtok(NULL, ",");
+	}
+	return true;
+}
+
+static void
+assign_format_binary(const char *newval, void *extra)
+{
+	// check for errors or nothing to do
+	if (newval == NULL || strcmp(newval, "") == 0)
+		return;
+
+	char *tmp = palloc(strlen(newval));
+	strcpy(tmp, newval);
+
+	/* Must save OID list in permanent storage. */
+	MemoryContext oldcxt = MemoryContextSwitchTo(TopMemoryContext);
+
+	// unlikely to have more than 16
+	int length = 16;
+	// +1 for the InvalidOid marker at the end
+	Oid *tmpOids = palloc(sizeof(Oid)*(length+1));
+	int i = 0;
+	
+	char *token = strtok(tmp, ",");
+
+	while(token != NULL)
+	{
+		tmpOids[i++] = atooid(token);
+		if (i > length)
+		{
+			length += 16;
+			tmpOids = repalloc(tmpOids, sizeof(Oid)*(length+1));
+		}
+		token = strtok(NULL, ",");
+	}
+	tmpOids[i] = InvalidOid;
+	binary_format_oids = tmpOids;
+	MemoryContextSwitchTo(oldcxt);
+
+}
+
+#include "guc-file.c"
+
diff --git a/src/include/miscadmin.h b/src/include/miscadmin.h
index 65cf4ba50f..0f9fc3562c 100644
--- a/src/include/miscadmin.h
+++ b/src/include/miscadmin.h
@@ -275,6 +275,7 @@ extern PGDLLIMPORT int64 VacuumPageDirty;
 
 extern PGDLLIMPORT int VacuumCostBalance;
 extern PGDLLIMPORT bool VacuumCostActive;
+extern PGDLLIMPORT char *format_binary;
 
 
 /* in tcop/postgres.c */
-- 
2.32.1 (Apple Git-133)

