On Tue, Aug 16, 2022 at 2:02 AM Drouvot, Bertrand <bdrou...@amazon.com> wrote:
> On 8/14/22 11:57 AM, Michael Paquier wrote:
> >    One thing was itching me about the serialization and
> > deserialization logic though: could it be more readable if we used an
> > intermediate structure to store the length of the serialized strings?
> > We use this approach in other areas, like for the snapshot data in
> > snapmgr.c.  This would handle the case of an empty and NULL string, by
> > storing -1 as length for NULL and >= 0 for the string length if there
> > is something set, while making the addition of more fields a
> > no-brainer.
>
> I think that's a good idea and I think that would be more readable (as
> compared to storing a "hint" in the first byte).

Sounds good. v3, attached, should make the requested changes:
- declare `struct ClientConnectionInfo`
- use an intermediate serialization struct
- switch to length-"prefixing" for the string

I do like the way this reads compared to before.

Thanks,
--Jacob
commit 753c46352adc967a903a60ea65a3068252d685e6
Author: Jacob Champion <jchamp...@timescale.com>
Date:   Tue Aug 16 09:14:58 2022 -0700

    squash! Allow parallel workers to read authn_id
    
    Per review,
    - add an intermediate struct for serialization,
    - switch to length-prefixing for the authn_id string, and
    - make sure `struct ClientConnectionInfo` is declared for use elsewhere.

diff --git a/src/backend/utils/init/miscinit.c b/src/backend/utils/init/miscinit.c
index 155ba92c67..58772d0a4a 100644
--- a/src/backend/utils/init/miscinit.c
+++ b/src/backend/utils/init/miscinit.c
@@ -943,19 +943,29 @@ GetUserNameFromId(Oid roleid, bool noerr)
 
 ClientConnectionInfo MyClientConnectionInfo;
 
+/*
+ * Intermediate representation of ClientConnectionInfo for easier 
serialization.
+ * Variable-length fields are allocated right after this header.
+ */
+typedef struct SerializedClientConnectionInfo
+{
+       int32           authn_id_len; /* strlen(authn_id), or -1 if NULL */
+       UserAuth        auth_method;
+} SerializedClientConnectionInfo;
+
 /*
  * Calculate the space needed to serialize MyClientConnectionInfo.
  */
 Size
 EstimateClientConnectionInfoSpace(void)
 {
-       Size            size = 1;
+       Size            size = 0;
+
+       size = add_size(size, sizeof(SerializedClientConnectionInfo));
 
        if (MyClientConnectionInfo.authn_id)
                size = add_size(size, strlen(MyClientConnectionInfo.authn_id) + 
1);
 
-       size = add_size(size, sizeof(UserAuth));
-
        return size;
 }
 
@@ -965,32 +975,29 @@ EstimateClientConnectionInfoSpace(void)
 void
 SerializeClientConnectionInfo(Size maxsize, char *start_address)
 {
-       /*
-        * First byte is an indication of whether or not authn_id has been set 
to
-        * non-NULL, to differentiate that case from the empty string.
-        */
-       Assert(maxsize > 0);
-       start_address[0] = MyClientConnectionInfo.authn_id ? 1 : 0;
-       start_address++;
-       maxsize--;
+       SerializedClientConnectionInfo serialized = {0};
+
+       serialized.authn_id_len = -1;
+       serialized.auth_method = MyClientConnectionInfo.auth_method;
 
        if (MyClientConnectionInfo.authn_id)
-       {
-               Size len;
+               serialized.authn_id_len = 
strlen(MyClientConnectionInfo.authn_id);
 
-               len = strlcpy(start_address, MyClientConnectionInfo.authn_id, 
maxsize) + 1;
-               Assert(len <= maxsize);
-               maxsize -= len;
-               start_address += len;
-       }
+       /* Copy serialized representation to buffer */
+       Assert(maxsize >= sizeof(serialized));
+       memcpy(start_address, &serialized, sizeof(serialized));
 
-       {
-               UserAuth           *auth_method = (UserAuth*) start_address;
+       maxsize -= sizeof(serialized);
+       start_address += sizeof(serialized);
 
-               Assert(sizeof(*auth_method) <= maxsize);
-               *auth_method = MyClientConnectionInfo.auth_method;
-               maxsize -= sizeof(*auth_method);
-               start_address += sizeof(*auth_method);
+       /* Copy authn_id into the space after the struct. */
+       if (serialized.authn_id_len >= 0)
+       {
+               Assert(maxsize >= (serialized.authn_id_len + 1));
+               memcpy(start_address,
+                          MyClientConnectionInfo.authn_id,
+                          /* include the NULL terminator to ease 
deserialization */
+                          serialized.authn_id_len + 1);
        }
 }
 
@@ -1000,25 +1007,19 @@ SerializeClientConnectionInfo(Size maxsize, char *start_address)
 void
 RestoreClientConnectionInfo(char *conninfo)
 {
-       if (conninfo[0] == 0)
-       {
-               MyClientConnectionInfo.authn_id = NULL;
-               conninfo++;
-       }
-       else
-       {
-               conninfo++;
-               MyClientConnectionInfo.authn_id = 
MemoryContextStrdup(TopMemoryContext,
-                                                                               
                                          conninfo);
-               conninfo += strlen(conninfo) + 1;
-       }
+       SerializedClientConnectionInfo serialized;
+       char       *authn_id;
 
-       {
-               UserAuth           *auth_method = (UserAuth*) conninfo;
+       memcpy(&serialized, conninfo, sizeof(serialized));
+       authn_id = conninfo + sizeof(serialized);
 
-               MyClientConnectionInfo.auth_method = *auth_method;
-               conninfo += sizeof(*auth_method);
-       }
+       /* Copy the fields back into place. */
+       MyClientConnectionInfo.authn_id = NULL;
+       MyClientConnectionInfo.auth_method = serialized.auth_method;
+
+       if (serialized.authn_id_len >= 0)
+               MyClientConnectionInfo.authn_id = 
MemoryContextStrdup(TopMemoryContext,
+                                                                               
                                          authn_id);
 }
 
 
diff --git a/src/include/libpq/libpq-be.h b/src/include/libpq/libpq-be.h
index 0643733765..84a6bdea6f 100644
--- a/src/include/libpq/libpq-be.h
+++ b/src/include/libpq/libpq-be.h
@@ -107,7 +107,7 @@ typedef struct
  * If you add a struct member here, remember to also handle serialization in
  * SerializeClientConnectionInfo() et al.
  */
-typedef struct
+typedef struct ClientConnectionInfo
 {
        /*
         * Authenticated identity.  The meaning of this identifier is dependent 
on
From 2eea3ef097bbeee5323f78c827e56b42480b5c81 Mon Sep 17 00:00:00 2001
From: Jacob Champion <pchamp...@vmware.com>
Date: Wed, 23 Mar 2022 15:07:05 -0700
Subject: [PATCH v3 1/3] Allow parallel workers to read authn_id

Move authn_id into a new global, MyClientConnectionInfo, which is
intended to hold all the client information that needs to be shared
between the backend and any parallel workers. MyClientConnectionInfo is
serialized and restored using a new parallel key.

Additionally, make a copy of hba->auth_method in ClientConnectionInfo
when set_authn_id() is called, for use by SYSTEM_USER.
---
 src/backend/access/transam/parallel.c | 19 +++++-
 src/backend/libpq/auth.c              | 25 ++++----
 src/backend/utils/init/miscinit.c     | 91 +++++++++++++++++++++++++++
 src/include/libpq/libpq-be.h          | 45 +++++++++----
 src/include/miscadmin.h               |  4 ++
 5 files changed, 159 insertions(+), 25 deletions(-)

diff --git a/src/backend/access/transam/parallel.c b/src/backend/access/transam/parallel.c
index df0cd77558..bc93101ff7 100644
--- a/src/backend/access/transam/parallel.c
+++ b/src/backend/access/transam/parallel.c
@@ -76,6 +76,7 @@
 #define PARALLEL_KEY_REINDEX_STATE			UINT64CONST(0xFFFFFFFFFFFF000C)
 #define PARALLEL_KEY_RELMAPPER_STATE		UINT64CONST(0xFFFFFFFFFFFF000D)
 #define PARALLEL_KEY_UNCOMMITTEDENUMS		UINT64CONST(0xFFFFFFFFFFFF000E)
+#define PARALLEL_KEY_CLIENTCONNINFO			UINT64CONST(0xFFFFFFFFFFFF000F)
 
 /* Fixed-size parallel state. */
 typedef struct FixedParallelState
@@ -212,6 +213,7 @@ InitializeParallelDSM(ParallelContext *pcxt)
 	Size		reindexlen = 0;
 	Size		relmapperlen = 0;
 	Size		uncommittedenumslen = 0;
+	Size		clientconninfolen = 0;
 	Size		segsize = 0;
 	int			i;
 	FixedParallelState *fps;
@@ -272,8 +274,10 @@ InitializeParallelDSM(ParallelContext *pcxt)
 		shm_toc_estimate_chunk(&pcxt->estimator, relmapperlen);
 		uncommittedenumslen = EstimateUncommittedEnumsSpace();
 		shm_toc_estimate_chunk(&pcxt->estimator, uncommittedenumslen);
+		clientconninfolen = EstimateClientConnectionInfoSpace();
+		shm_toc_estimate_chunk(&pcxt->estimator, clientconninfolen);
 		/* If you add more chunks here, you probably need to add keys. */
-		shm_toc_estimate_keys(&pcxt->estimator, 11);
+		shm_toc_estimate_keys(&pcxt->estimator, 12);
 
 		/* Estimate space need for error queues. */
 		StaticAssertStmt(BUFFERALIGN(PARALLEL_ERROR_QUEUE_SIZE) ==
@@ -352,6 +356,7 @@ InitializeParallelDSM(ParallelContext *pcxt)
 		char	   *session_dsm_handle_space;
 		char	   *entrypointstate;
 		char	   *uncommittedenumsspace;
+		char	   *clientconninfospace;
 		Size		lnamelen;
 
 		/* Serialize shared libraries we have loaded. */
@@ -422,6 +427,12 @@ InitializeParallelDSM(ParallelContext *pcxt)
 		shm_toc_insert(pcxt->toc, PARALLEL_KEY_UNCOMMITTEDENUMS,
 					   uncommittedenumsspace);
 
+		/* Serialize our ClientConnectionInfo. */
+		clientconninfospace = shm_toc_allocate(pcxt->toc, clientconninfolen);
+		SerializeClientConnectionInfo(clientconninfolen, clientconninfospace);
+		shm_toc_insert(pcxt->toc, PARALLEL_KEY_CLIENTCONNINFO,
+					   clientconninfospace);
+
 		/* Allocate space for worker information. */
 		pcxt->worker = palloc0(sizeof(ParallelWorkerInfo) * pcxt->nworkers);
 
@@ -1270,6 +1281,7 @@ ParallelWorkerMain(Datum main_arg)
 	char	   *reindexspace;
 	char	   *relmapperspace;
 	char	   *uncommittedenumsspace;
+	char	   *clientconninfospace;
 	StringInfoData msgbuf;
 	char	   *session_dsm_handle_space;
 	Snapshot	tsnapshot;
@@ -1479,6 +1491,11 @@ ParallelWorkerMain(Datum main_arg)
 										   false);
 	RestoreUncommittedEnums(uncommittedenumsspace);
 
+	/* Restore the ClientConnectionInfo. */
+	clientconninfospace = shm_toc_lookup(toc, PARALLEL_KEY_CLIENTCONNINFO,
+										 false);
+	RestoreClientConnectionInfo(clientconninfospace);
+
 	/* Attach to the leader's serializable transaction, if SERIALIZABLE. */
 	AttachSerializableXact(fps->serializable_xact_handle);
 
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index 2d9ab7edce..9113f04189 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -333,24 +333,24 @@ auth_failed(Port *port, int status, const char *logdetail)
 
 
 /*
- * Sets the authenticated identity for the current user.  The provided string
- * will be copied into the TopMemoryContext.  The ID will be logged if
- * log_connections is enabled.
+ * Sets the authenticated identity for the current user. The provided string
+ * will be stored into MyClientConnectionInfo, alongside the current HBA method
+ * in use. The ID will be logged if log_connections is enabled.
  *
  * Auth methods should call this routine exactly once, as soon as the user is
  * successfully authenticated, even if they have reasons to know that
  * authorization will fail later.
  *
  * The provided string will be copied into TopMemoryContext, to match the
- * lifetime of the Port, so it is safe to pass a string that is managed by an
- * external library.
+ * lifetime of MyClientConnectionInfo, so it is safe to pass a string that is
+ * managed by an external library.
  */
 static void
 set_authn_id(Port *port, const char *id)
 {
 	Assert(id);
 
-	if (port->authn_id)
+	if (MyClientConnectionInfo.authn_id)
 	{
 		/*
 		 * An existing authn_id should never be overwritten; that means two
@@ -361,18 +361,20 @@ set_authn_id(Port *port, const char *id)
 		ereport(FATAL,
 				(errmsg("authentication identifier set more than once"),
 				 errdetail_log("previous identifier: \"%s\"; new identifier: \"%s\"",
-							   port->authn_id, id)));
+							   MyClientConnectionInfo.authn_id, id)));
 	}
 
-	port->authn_id = MemoryContextStrdup(TopMemoryContext, id);
+	MyClientConnectionInfo.authn_id = MemoryContextStrdup(TopMemoryContext, id);
+	MyClientConnectionInfo.auth_method = port->hba->auth_method;
 
 	if (Log_connections)
 	{
 		ereport(LOG,
 				errmsg("connection authenticated: identity=\"%s\" method=%s "
 					   "(%s:%d)",
-					   port->authn_id, hba_authname(port->hba->auth_method), HbaFileName,
-					   port->hba->linenumber));
+					   MyClientConnectionInfo.authn_id,
+					   hba_authname(MyClientConnectionInfo.auth_method),
+					   HbaFileName, port->hba->linenumber));
 	}
 }
 
@@ -1908,7 +1910,8 @@ auth_peer(hbaPort *port)
 	 */
 	set_authn_id(port, pw->pw_name);
 
-	ret = check_usermap(port->hba->usermap, port->user_name, port->authn_id, false);
+	ret = check_usermap(port->hba->usermap, port->user_name,
+						MyClientConnectionInfo.authn_id, false);
 
 	return ret;
 #else
diff --git a/src/backend/utils/init/miscinit.c b/src/backend/utils/init/miscinit.c
index eb43b2c5e5..58772d0a4a 100644
--- a/src/backend/utils/init/miscinit.c
+++ b/src/backend/utils/init/miscinit.c
@@ -931,6 +931,97 @@ GetUserNameFromId(Oid roleid, bool noerr)
 	return result;
 }
 
+/* ------------------------------------------------------------------------
+ *				Parallel connection state
+ *
+ * ClientConnectionInfo contains pieces of information about the client that
+ * need to be synced to parallel workers when they initialize. Over time, this
+ * list will probably grow, and may subsume some of the "user state" variables
+ * above.
+ *-------------------------------------------------------------------------
+ */
+
+ClientConnectionInfo MyClientConnectionInfo;
+
+/*
+ * Intermediate representation of ClientConnectionInfo for easier serialization.
+ * Variable-length fields are allocated right after this header.
+ */
+typedef struct SerializedClientConnectionInfo
+{
+	int32		authn_id_len; /* strlen(authn_id), or -1 if NULL */
+	UserAuth	auth_method;
+} SerializedClientConnectionInfo;
+
+/*
+ * Calculate the space needed to serialize MyClientConnectionInfo.
+ */
+Size
+EstimateClientConnectionInfoSpace(void)
+{
+	Size		size = 0;
+
+	size = add_size(size, sizeof(SerializedClientConnectionInfo));
+
+	if (MyClientConnectionInfo.authn_id)
+		size = add_size(size, strlen(MyClientConnectionInfo.authn_id) + 1);
+
+	return size;
+}
+
+/*
+ * Serialize MyClientConnectionInfo for use by parallel workers.
+ */
+void
+SerializeClientConnectionInfo(Size maxsize, char *start_address)
+{
+	SerializedClientConnectionInfo serialized = {0};
+
+	serialized.authn_id_len = -1;
+	serialized.auth_method = MyClientConnectionInfo.auth_method;
+
+	if (MyClientConnectionInfo.authn_id)
+		serialized.authn_id_len = strlen(MyClientConnectionInfo.authn_id);
+
+	/* Copy serialized representation to buffer */
+	Assert(maxsize >= sizeof(serialized));
+	memcpy(start_address, &serialized, sizeof(serialized));
+
+	maxsize -= sizeof(serialized);
+	start_address += sizeof(serialized);
+
+	/* Copy authn_id into the space after the struct. */
+	if (serialized.authn_id_len >= 0)
+	{
+		Assert(maxsize >= (serialized.authn_id_len + 1));
+		memcpy(start_address,
+			   MyClientConnectionInfo.authn_id,
+			   /* include the NULL terminator to ease deserialization */
+			   serialized.authn_id_len + 1);
+	}
+}
+
+/*
+ * Restore MyClientConnectionInfo from its serialized representation.
+ */
+void
+RestoreClientConnectionInfo(char *conninfo)
+{
+	SerializedClientConnectionInfo serialized;
+	char	   *authn_id;
+
+	memcpy(&serialized, conninfo, sizeof(serialized));
+	authn_id = conninfo + sizeof(serialized);
+
+	/* Copy the fields back into place. */
+	MyClientConnectionInfo.authn_id = NULL;
+	MyClientConnectionInfo.auth_method = serialized.auth_method;
+
+	if (serialized.authn_id_len >= 0)
+		MyClientConnectionInfo.authn_id = MemoryContextStrdup(TopMemoryContext,
+															  authn_id);
+}
+
 
 /*-------------------------------------------------------------------------
  *				Interlock-file support
diff --git a/src/include/libpq/libpq-be.h b/src/include/libpq/libpq-be.h
index 90c20da22b..84a6bdea6f 100644
--- a/src/include/libpq/libpq-be.h
+++ b/src/include/libpq/libpq-be.h
@@ -98,6 +98,37 @@ typedef struct
 } pg_gssinfo;
 #endif
 
+/*
+ * Fields describing the client connection, that also need to be copied over to
+ * parallel workers, go into the ClientConnectionInfo rather than Port. The same
+ * rules apply for allocations here as for Port (must be malloc'd or palloc'd in
+ * TopMemoryContext).
+ *
+ * If you add a struct member here, remember to also handle serialization in
+ * SerializeClientConnectionInfo() et al.
+ */
+typedef struct ClientConnectionInfo
+{
+	/*
+	 * Authenticated identity.  The meaning of this identifier is dependent on
+	 * auth_method; it is the identity (if any) that the user presented
+	 * during the authentication cycle, before they were assigned a database
+	 * role.  (It is effectively the "SYSTEM-USERNAME" of a pg_ident usermap
+	 * -- though the exact string in use may be different, depending on pg_hba
+	 * options.)
+	 *
+	 * authn_id is NULL if the user has not actually been authenticated, for
+	 * example if the "trust" auth method is in use.
+	 */
+	const char *authn_id;
+
+	/*
+	 * The HBA method that determined the above authn_id. This only has meaning
+	 * if authn_id is not NULL; otherwise it's undefined.
+	 */
+	UserAuth	auth_method;
+} ClientConnectionInfo;
+
 /*
  * This is used by the postmaster in its communication with frontends.  It
  * contains all state information needed during this communication before the
@@ -158,19 +189,6 @@ typedef struct Port
 	 */
 	HbaLine    *hba;
 
-	/*
-	 * Authenticated identity.  The meaning of this identifier is dependent on
-	 * hba->auth_method; it is the identity (if any) that the user presented
-	 * during the authentication cycle, before they were assigned a database
-	 * role.  (It is effectively the "SYSTEM-USERNAME" of a pg_ident usermap
-	 * -- though the exact string in use may be different, depending on pg_hba
-	 * options.)
-	 *
-	 * authn_id is NULL if the user has not actually been authenticated, for
-	 * example if the "trust" auth method is in use.
-	 */
-	const char *authn_id;
-
 	/*
 	 * TCP keepalive and user timeout settings.
 	 *
@@ -327,6 +345,7 @@ extern ssize_t be_gssapi_write(Port *port, void *ptr, size_t len);
 #endif							/* ENABLE_GSS */
 
 extern PGDLLIMPORT ProtocolVersion FrontendProtocol;
+extern PGDLLIMPORT ClientConnectionInfo MyClientConnectionInfo;
 
 /* TCP keepalives configuration. These are no-ops on an AF_UNIX socket. */
 
diff --git a/src/include/miscadmin.h b/src/include/miscadmin.h
index 067b729d5a..3e9297e399 100644
--- a/src/include/miscadmin.h
+++ b/src/include/miscadmin.h
@@ -481,6 +481,10 @@ extern bool has_rolreplication(Oid roleid);
 typedef void (*shmem_request_hook_type) (void);
 extern PGDLLIMPORT shmem_request_hook_type shmem_request_hook;
 
+extern Size EstimateClientConnectionInfoSpace(void);
+extern void SerializeClientConnectionInfo(Size maxsize, char *start_address);
+extern void RestoreClientConnectionInfo(char *procinfo);
+
 /* in executor/nodeHash.c */
 extern size_t get_hash_memory_limit(void);
 
-- 
2.25.1

Reply via email to