From 02c667cc960a8bf09bf581e7b29d2167915a7ac9 Mon Sep 17 00:00:00 2001
From: Matheus Alcantara <mths.dev@pm.me>
Date: Mon, 20 Jan 2025 15:25:51 -0300
Subject: [PATCH v10 1/3] dblink: refactor get connection routines

Refactor dblink_get_conn and dblink_connect to move the logic of
actually opening the connection to the new connect_pg_server function
which them can be re-used on both functions.

This is a pre-work for a next commit that will add support for scram
pass-through authentication to dblink which will be able to implement
most of the logic into the connect_pg_server function which now already
have all necessary data information.

Note that some pfree(rconn) calls has been removed on this commit, but
the allocation of rconn was also moved after connect_pg_server which was
the only function that could force the pfree(rconn) call.
---
 contrib/dblink/dblink.c | 208 ++++++++++++++++++----------------------
 1 file changed, 94 insertions(+), 114 deletions(-)

diff --git a/contrib/dblink/dblink.c b/contrib/dblink/dblink.c
index 58c1a6221c8..4be8cfdbf74 100644
--- a/contrib/dblink/dblink.c
+++ b/contrib/dblink/dblink.c
@@ -114,10 +114,10 @@ static Relation get_rel_from_relname(text *relname_text, LOCKMODE lockmode, AclM
 static char *generate_relation_name(Relation rel);
 static void dblink_connstr_check(const char *connstr);
 static bool dblink_connstr_has_pw(const char *connstr);
-static void dblink_security_check(PGconn *conn, remoteConn *rconn, const char *connstr);
+static void dblink_security_check(PGconn *conn, const char *connstr);
 static void dblink_res_error(PGconn *conn, const char *conname, PGresult *res,
 							 bool fail, const char *fmt,...) pg_attribute_printf(5, 6);
-static char *get_connect_string(const char *servername);
+static char *get_connect_string(ForeignServer *foreign_server);
 static char *escape_param_str(const char *str);
 static void validate_pkattnums(Relation rel,
 							   int2vector *pkattnums_arg, int32 pknumatts_arg,
@@ -126,6 +126,7 @@ static bool is_valid_dblink_option(const PQconninfoOption *options,
 								   const char *option, Oid context);
 static int	applyRemoteGucs(PGconn *conn);
 static void restoreLocalGucs(int nestlevel);
+static PGconn *connect_pg_server(char *connstr_or_srvname, uint32 wait_event_info);
 
 /* Global */
 static remoteConn *pconn = NULL;
@@ -199,33 +200,11 @@ dblink_get_conn(char *conname_or_str,
 	}
 	else
 	{
-		const char *connstr;
-
-		connstr = get_connect_string(conname_or_str);
-		if (connstr == NULL)
-			connstr = conname_or_str;
-		dblink_connstr_check(connstr);
-
 		/* first time, allocate or get the custom wait event */
 		if (dblink_we_get_conn == 0)
 			dblink_we_get_conn = WaitEventExtensionNew("DblinkGetConnect");
 
-		/* OK to make connection */
-		conn = libpqsrv_connect(connstr, dblink_we_get_conn);
-
-		if (PQstatus(conn) == CONNECTION_BAD)
-		{
-			char	   *msg = pchomp(PQerrorMessage(conn));
-
-			libpqsrv_disconnect(conn);
-			ereport(ERROR,
-					(errcode(ERRCODE_SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION),
-					 errmsg("could not establish connection"),
-					 errdetail_internal("%s", msg)));
-		}
-		dblink_security_check(conn, rconn, connstr);
-		if (PQclientEncoding(conn) != GetDatabaseEncoding())
-			PQsetClientEncoding(conn, GetDatabaseEncodingName());
+		conn = connect_pg_server(conname_or_str, dblink_we_get_conn);
 		freeconn = true;
 		conname = NULL;
 	}
@@ -270,9 +249,7 @@ Datum
 dblink_connect(PG_FUNCTION_ARGS)
 {
 	char	   *conname_or_str = NULL;
-	char	   *connstr = NULL;
 	char	   *connname = NULL;
-	char	   *msg;
 	PGconn	   *conn = NULL;
 	remoteConn *rconn = NULL;
 
@@ -286,53 +263,19 @@ dblink_connect(PG_FUNCTION_ARGS)
 	else if (PG_NARGS() == 1)
 		conname_or_str = text_to_cstring(PG_GETARG_TEXT_PP(0));
 
-	if (connname)
-	{
-		rconn = (remoteConn *) MemoryContextAlloc(TopMemoryContext,
-												  sizeof(remoteConn));
-		rconn->conn = NULL;
-		rconn->openCursorCount = 0;
-		rconn->newXactForCursor = false;
-	}
-
-	/* first check for valid foreign data server */
-	connstr = get_connect_string(conname_or_str);
-	if (connstr == NULL)
-		connstr = conname_or_str;
-
-	/* check password in connection string if not superuser */
-	dblink_connstr_check(connstr);
-
 	/* first time, allocate or get the custom wait event */
 	if (dblink_we_connect == 0)
 		dblink_we_connect = WaitEventExtensionNew("DblinkConnect");
 
-	/* OK to make connection */
-	conn = libpqsrv_connect(connstr, dblink_we_connect);
-
-	if (PQstatus(conn) == CONNECTION_BAD)
-	{
-		msg = pchomp(PQerrorMessage(conn));
-		libpqsrv_disconnect(conn);
-		if (rconn)
-			pfree(rconn);
-
-		ereport(ERROR,
-				(errcode(ERRCODE_SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION),
-				 errmsg("could not establish connection"),
-				 errdetail_internal("%s", msg)));
-	}
-
-	/* check password actually used if not superuser */
-	dblink_security_check(conn, rconn, connstr);
-
-	/* attempt to set client encoding to match server encoding, if needed */
-	if (PQclientEncoding(conn) != GetDatabaseEncoding())
-		PQsetClientEncoding(conn, GetDatabaseEncodingName());
+	conn = connect_pg_server(conname_or_str, dblink_we_connect);
 
 	if (connname)
 	{
+		rconn = (remoteConn *) MemoryContextAlloc(TopMemoryContext,
+												  sizeof(remoteConn));
 		rconn->conn = conn;
+		rconn->openCursorCount = 0;
+		rconn->newXactForCursor = false;
 		createNewConnection(connname, rconn);
 	}
 	else
@@ -2602,7 +2545,7 @@ deleteConnection(const char *name)
  * used to connect and then make sure that they came from the user.
  */
 static void
-dblink_security_check(PGconn *conn, remoteConn *rconn, const char *connstr)
+dblink_security_check(PGconn *conn, const char *connstr)
 {
 	/* Superuser bypasses security check */
 	if (superuser())
@@ -2620,8 +2563,6 @@ dblink_security_check(PGconn *conn, remoteConn *rconn, const char *connstr)
 
 	/* Otherwise, fail out */
 	libpqsrv_disconnect(conn);
-	if (rconn)
-		pfree(rconn);
 
 	ereport(ERROR,
 			(errcode(ERRCODE_S_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED),
@@ -2782,15 +2723,16 @@ dblink_res_error(PGconn *conn, const char *conname, PGresult *res,
  * Obtain connection string for a foreign server
  */
 static char *
-get_connect_string(const char *servername)
+get_connect_string(ForeignServer *foreign_server)
 {
-	ForeignServer *foreign_server = NULL;
 	UserMapping *user_mapping;
 	ListCell   *cell;
 	StringInfoData buf;
 	ForeignDataWrapper *fdw;
 	AclResult	aclresult;
-	char	   *srvname;
+	Oid			serverid = foreign_server->serverid;
+	Oid			fdwid = foreign_server->fdwid;
+	Oid			userid = GetUserId();
 
 	static const PQconninfoOption *options = NULL;
 
@@ -2813,57 +2755,43 @@ get_connect_string(const char *servername)
 					 errdetail("Could not get libpq's default connection options.")));
 	}
 
-	/* first gather the server connstr options */
-	srvname = pstrdup(servername);
-	truncate_identifier(srvname, strlen(srvname), false);
-	foreign_server = GetForeignServerByName(srvname, true);
-
-	if (foreign_server)
-	{
-		Oid			serverid = foreign_server->serverid;
-		Oid			fdwid = foreign_server->fdwid;
-		Oid			userid = GetUserId();
-
-		user_mapping = GetUserMapping(userid, serverid);
-		fdw = GetForeignDataWrapper(fdwid);
-
-		/* Check permissions, user must have usage on the server. */
-		aclresult = object_aclcheck(ForeignServerRelationId, serverid, userid, ACL_USAGE);
-		if (aclresult != ACLCHECK_OK)
-			aclcheck_error(aclresult, OBJECT_FOREIGN_SERVER, foreign_server->servername);
+	user_mapping = GetUserMapping(userid, serverid);
+	fdw = GetForeignDataWrapper(fdwid);
 
-		foreach(cell, fdw->options)
-		{
-			DefElem    *def = lfirst(cell);
+	/* Check permissions, user must have usage on the server. */
+	aclresult = object_aclcheck(ForeignServerRelationId, serverid, userid, ACL_USAGE);
+	if (aclresult != ACLCHECK_OK)
+		aclcheck_error(aclresult, OBJECT_FOREIGN_SERVER, foreign_server->servername);
 
-			if (is_valid_dblink_option(options, def->defname, ForeignDataWrapperRelationId))
-				appendStringInfo(&buf, "%s='%s' ", def->defname,
-								 escape_param_str(strVal(def->arg)));
-		}
+	foreach(cell, fdw->options)
+	{
+		DefElem    *def = lfirst(cell);
 
-		foreach(cell, foreign_server->options)
-		{
-			DefElem    *def = lfirst(cell);
+		if (is_valid_dblink_option(options, def->defname, ForeignDataWrapperRelationId))
+			appendStringInfo(&buf, "%s='%s' ", def->defname,
+							 escape_param_str(strVal(def->arg)));
+	}
 
-			if (is_valid_dblink_option(options, def->defname, ForeignServerRelationId))
-				appendStringInfo(&buf, "%s='%s' ", def->defname,
-								 escape_param_str(strVal(def->arg)));
-		}
+	foreach(cell, foreign_server->options)
+	{
+		DefElem    *def = lfirst(cell);
 
-		foreach(cell, user_mapping->options)
-		{
+		if (is_valid_dblink_option(options, def->defname, ForeignServerRelationId))
+			appendStringInfo(&buf, "%s='%s' ", def->defname,
+							 escape_param_str(strVal(def->arg)));
+	}
 
-			DefElem    *def = lfirst(cell);
+	foreach(cell, user_mapping->options)
+	{
 
-			if (is_valid_dblink_option(options, def->defname, UserMappingRelationId))
-				appendStringInfo(&buf, "%s='%s' ", def->defname,
-								 escape_param_str(strVal(def->arg)));
-		}
+		DefElem    *def = lfirst(cell);
 
-		return buf.data;
+		if (is_valid_dblink_option(options, def->defname, UserMappingRelationId))
+			appendStringInfo(&buf, "%s='%s' ", def->defname,
+							 escape_param_str(strVal(def->arg)));
 	}
-	else
-		return NULL;
+
+	return buf.data;
 }
 
 /*
@@ -3085,3 +3013,55 @@ restoreLocalGucs(int nestlevel)
 	if (nestlevel > 0)
 		AtEOXact_GUC(true, nestlevel);
 }
+
+/*
+ * Connect to remote server. If connstr_or_srvname maps to a foreign server,
+ * the associated properties and user mapping properties is also used to open
+ * the connection. Otherwise a connection will be open using the raw
+ * connstr_or_srvname value.
+ */
+static PGconn *
+connect_pg_server(char *connstr_or_srvname, uint32 wait_event_info)
+{
+	PGconn	   *conn;
+	ForeignServer *foreign_server = NULL;
+	const char *connstr;
+	char	   *srvname;
+
+	/* first gather the server connstr options */
+	srvname = pstrdup(connstr_or_srvname);
+	truncate_identifier(srvname, strlen(srvname), false);
+	foreign_server = GetForeignServerByName(srvname, true);
+
+	if (foreign_server)
+		connstr = get_connect_string(foreign_server);
+	else
+		connstr = connstr_or_srvname;
+
+	/* Verify the set of connection parameters. */
+	dblink_connstr_check(connstr);
+
+	/* OK to make connection */
+	conn = libpqsrv_connect(connstr, wait_event_info);
+
+	if (PQstatus(conn) == CONNECTION_BAD)
+	{
+		char	   *msg = pchomp(PQerrorMessage(conn));
+
+		libpqsrv_disconnect(conn);
+
+		ereport(ERROR,
+				(errcode(ERRCODE_SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION),
+				 errmsg("could not establish connection"),
+				 errdetail_internal("%s", msg)));
+	}
+
+	/* Perform post-connection security checks. */
+	dblink_security_check(conn, connstr);
+
+	/* attempt to set client encoding to match server encoding, if needed */
+	if (PQclientEncoding(conn) != GetDatabaseEncoding())
+		PQsetClientEncoding(conn, GetDatabaseEncodingName());
+
+	return conn;
+}
-- 
2.39.5 (Apple Git-154)

