From e4a9c66fb34270bff62eaded37042bc5521c905c Mon Sep 17 00:00:00 2001
From: Jelte Fennema <jelte.fennema@microsoft.com>
Date: Wed, 22 Jun 2022 09:39:13 +0200
Subject: [PATCH] Support load balancing in libpq

Load balancing connections across multiple read replicas is a pretty
common way of scaling out read queries. There are two main ways of doing
so, both with their own advantages and disadvantages:
1. Load balancing at the client level
2. Load balancing by connecting to an intermediary load balancer

Option 1 has been supported by JDBC (Java) for 8 years and Npgsql (C#)
merged support about a year ago. This patch adds the same functionality
to libpq. The way it's implemented is the same as the implementation of
JDBC, and contains two levels of load balancing:
1. The given hosts are randomly shuffled, before resolving them
    one-by-one.
2. Once a host's addresses get resolved, those addresses are shuffled,
    before trying to connect to them one-by-one.
---
 doc/src/sgml/libpq.sgml           |  17 +++
 src/include/libpq/pqcomm.h        |   6 +
 src/interfaces/libpq/fe-connect.c | 204 +++++++++++++++++++++++++-----
 src/interfaces/libpq/libpq-int.h  |   7 +-
 4 files changed, 202 insertions(+), 32 deletions(-)

diff --git a/doc/src/sgml/libpq.sgml b/doc/src/sgml/libpq.sgml
index 37ec3cb4e5..149cf25854 100644
--- a/doc/src/sgml/libpq.sgml
+++ b/doc/src/sgml/libpq.sgml
@@ -1317,6 +1317,23 @@ postgresql://%2Fvar%2Flib%2Fpostgresql/dbname
       </listitem>
      </varlistentry>
 
+     <varlistentry id="libpq-load-balance-hosts" xreflabel="load_balance_hosts">
+      <term><literal>load_balance_hosts</literal></term>
+      <listitem>
+       <para>
+        Controls whether the client load balances connections across hosts and
+        addresses. The default value is 0, meaning off: hosts are tried in the
+        order they are provided and addresses are tried in the order they are
+        received from DNS or a hosts file. If this value is set to 1, meaning
+        on, the hosts and addresses are tried in random order.
+        Subsequent queries once connected will still be sent to the same server.
+        Setting this to 1 is mostly useful when opening multiple connections at
+        the same time, possibly from different machines. This way connections
+        can be load balanced across multiple Postgres servers.
+       </para>
+      </listitem>
+     </varlistentry>
+
      <varlistentry id="libpq-keepalives" xreflabel="keepalives">
       <term><literal>keepalives</literal></term>
       <listitem>
diff --git a/src/include/libpq/pqcomm.h b/src/include/libpq/pqcomm.h
index b418283d5f..f67b334887 100644
--- a/src/include/libpq/pqcomm.h
+++ b/src/include/libpq/pqcomm.h
@@ -65,6 +65,12 @@ typedef struct
 	socklen_t	salen;
 } SockAddr;
 
+typedef struct				/* one candidate address for a connection host */
+{
+	int			family;		/* address family, copied from ai_family */
+	SockAddr	addr;		/* socket address bytes plus their length */
+}			AddrInfo;
+
 /* Configure the UNIX socket location for the well known port. */
 
 #define UNIXSOCK_PATH(path, port, sockdir) \
diff --git a/src/interfaces/libpq/fe-connect.c b/src/interfaces/libpq/fe-connect.c
index 6e936bbff3..6dcbbecc39 100644
--- a/src/interfaces/libpq/fe-connect.c
+++ b/src/interfaces/libpq/fe-connect.c
@@ -23,6 +23,7 @@
 
 #include "common/ip.h"
 #include "common/link-canary.h"
+#include "common/pg_prng.h"
 #include "common/scram-common.h"
 #include "common/string.h"
 #include "fe-auth.h"
@@ -244,6 +245,10 @@ static const internalPQconninfoOption PQconninfoOptions[] = {
 		"Fallback-Application-Name", "", 64,
 	offsetof(struct pg_conn, fbappname)},
 
+	{"load_balance_hosts", NULL, NULL, NULL,
+		"Load-Balance", "", 0,	/* should be just '0' or '1' */
+	offsetof(struct pg_conn, loadbalance)},
+
 	{"keepalives", NULL, NULL, NULL,
 		"TCP-Keepalives", "", 1,	/* should be just '0' or '1' */
 	offsetof(struct pg_conn, keepalives)},
@@ -367,6 +372,9 @@ static const PQEnvironmentOption EnvironmentOptions[] =
 	}
 };
 
+static bool libpq_prng_initialized = false;
+static pg_prng_state libpq_prng_state;
+
 /* The connection URI must start with either of the following designators: */
 static const char uri_designator[] = "postgresql://";
 static const char short_uri_designator[] = "postgres://";
@@ -382,6 +390,7 @@ static bool fillPGconn(PGconn *conn, PQconninfoOption *connOptions);
 static void freePGconn(PGconn *conn);
 static void closePGconn(PGconn *conn);
 static void release_conn_addrinfo(PGconn *conn);
+static bool store_conn_addrinfo(PGconn *conn, struct addrinfo *addrlist);
 static void sendTerminateConn(PGconn *conn);
 static PQconninfoOption *conninfo_init(PQExpBuffer errorMessage);
 static PQconninfoOption *parse_connection_string(const char *conninfo,
@@ -427,6 +436,7 @@ static void pgpassfileWarning(PGconn *conn);
 static void default_threadlock(int acquire);
 static bool sslVerifyProtocolVersion(const char *version);
 static bool sslVerifyProtocolRange(const char *min, const char *max);
+static int	loadBalance(PGconn *conn);
 
 
 /* global variable because fe-auth.c needs to access it */
@@ -1015,6 +1025,41 @@ parse_comma_separated_list(char **startptr, bool *more)
 	return p;
 }
 
+/*
+ * Seed the file-level PRNG used to shuffle hosts and addresses.  Only the
+ * first call does any work; later calls return immediately.
+ *
+ * NOTE(review): the flag and state are unsynchronized statics; confirm this
+ * is safe when several threads open connections concurrently.
+ */
+static void
+libpq_prng_init(void)
+{
+	if (libpq_prng_initialized)
+		return;
+
+	/*
+	 * Set a different global seed in every process.  We want something
+	 * unpredictable, so if possible, use high-quality random bits for the
+	 * seed.  Otherwise, fall back to a seed based on timestamp and PID.
+	 */
+	if (unlikely(!pg_prng_strong_seed(&libpq_prng_state)))
+	{
+		time_t		now = time(NULL);
+
+		/*
+		 * PIDs and timestamps vary mostly in their low bits, so shift the
+		 * timestamp left to spread the seeds, and also mix in higher bits.
+		 */
+		pg_prng_seed(&libpq_prng_state,
+					 ((uint64) getpid()) ^
+					 ((uint64) now << 12) ^
+					 ((uint64) now >> 20));
+	}
+	libpq_prng_initialized = true;
+}
+
+
 /*
  *		connectOptions2
  *
@@ -1027,6 +1072,7 @@ static bool
 connectOptions2(PGconn *conn)
 {
 	int			i;
+	int			loadbalancehosts = loadBalance(conn);
 
 	/*
 	 * Allocate memory for details about each host to which we might possibly
@@ -1178,6 +1224,31 @@ connectOptions2(PGconn *conn)
 			return false;
 		}
 	}
+	if (loadbalancehosts < 0)
+	{
+		appendPQExpBufferStr(&conn->errorMessage,
+							 libpq_gettext("loadbalance parameter must be an integer\n"));
+		return false;
+	}
+
+	if (loadbalancehosts)
+	{
+		/*
+		 * Shuffle connhost with a Durstenfeld/Knuth version of the
+		 * Fisher-Yates shuffle. Source:
+		 * https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm
+		 */
+		libpq_prng_init();
+		for (i = conn->nconnhost - 1; i > 0; i--)
+		{
+			int			j = pg_prng_double(&libpq_prng_state) * (i + 1);
+			pg_conn_host temp = conn->connhost[j];
+
+			conn->connhost[j] = conn->connhost[i];
+			conn->connhost[i] = temp;
+		}
+	}
+
 
 	/*
 	 * If user name was not given, fetch it.  (Most likely, the fetch will
@@ -1772,6 +1843,27 @@ connectFailureMessage(PGconn *conn, int errorno)
 							 libpq_gettext("\tIs the server running on that host and accepting TCP/IP connections?\n"));
 }
 
+/*
+ * Should we load balance across hosts? Returns 1 if yes, 0 if no, and -1 if
+ * conn->loadbalance is set to a value which is not parseable as an integer.
+ * An unset (NULL) option means no.
+ */
+static int
+loadBalance(PGconn *conn)
+{
+	char	   *ep;
+	long		val;		/* long, so big nonzero inputs can't wrap to 0 */
+
+	if (conn->loadbalance == NULL)
+		return 0;
+
+	val = strtol(conn->loadbalance, &ep, 10);
+	if (*ep)
+		return -1;			/* trailing garbage after the number */
+	return val != 0 ? 1 : 0;
+}
+
+
 /*
  * Should we use keepalives?  Returns 1 if yes, 0 if no, and -1 if
  * conn->keepalives is set to a value which is not parseable as an
@@ -2129,7 +2221,7 @@ connectDBComplete(PGconn *conn)
 	time_t		finish_time = ((time_t) -1);
 	int			timeout = 0;
 	int			last_whichhost = -2;	/* certainly different from whichhost */
-	struct addrinfo *last_addr_cur = NULL;
+	int			last_whichaddr = -2;	/* certainly different from whichaddr */
 
 	if (conn == NULL || conn->status == CONNECTION_BAD)
 		return 0;
@@ -2173,11 +2265,11 @@ connectDBComplete(PGconn *conn)
 		if (flag != PGRES_POLLING_OK &&
 			timeout > 0 &&
 			(conn->whichhost != last_whichhost ||
-			 conn->addr_cur != last_addr_cur))
+			 conn->whichaddr != last_whichaddr))
 		{
 			finish_time = time(NULL) + timeout;
 			last_whichhost = conn->whichhost;
-			last_addr_cur = conn->addr_cur;
+			last_whichaddr = conn->whichaddr;
 		}
 
 		/*
@@ -2325,9 +2417,9 @@ keep_going:						/* We will come back to here until there is
 	/* Time to advance to next address, or next host if no more addresses? */
 	if (conn->try_next_addr)
 	{
-		if (conn->addr_cur && conn->addr_cur->ai_next)
+		if (conn->whichaddr < conn->naddr)
 		{
-			conn->addr_cur = conn->addr_cur->ai_next;
+			conn->whichaddr++;
 			reset_connection_state_machine = true;
 		}
 		else
@@ -2340,6 +2432,7 @@ keep_going:						/* We will come back to here until there is
 	{
 		pg_conn_host *ch;
 		struct addrinfo hint;
+		struct addrinfo *addrlist;
 		int			thisport;
 		int			ret;
 		char		portstr[MAXPGPATH];
@@ -2380,7 +2473,6 @@ keep_going:						/* We will come back to here until there is
 		/* Initialize hint structure */
 		MemSet(&hint, 0, sizeof(hint));
 		hint.ai_socktype = SOCK_STREAM;
-		conn->addrlist_family = hint.ai_family = AF_UNSPEC;
 
 		/* Figure out the port number we're going to use. */
 		if (ch->port == NULL || ch->port[0] == '\0')
@@ -2405,8 +2497,8 @@ keep_going:						/* We will come back to here until there is
 		{
 			case CHT_HOST_NAME:
 				ret = pg_getaddrinfo_all(ch->host, portstr, &hint,
-										 &conn->addrlist);
-				if (ret || !conn->addrlist)
+										 &addrlist);
+				if (ret || !addrlist)
 				{
 					appendPQExpBuffer(&conn->errorMessage,
 									  libpq_gettext("could not translate host name \"%s\" to address: %s\n"),
@@ -2418,8 +2510,8 @@ keep_going:						/* We will come back to here until there is
 			case CHT_HOST_ADDRESS:
 				hint.ai_flags = AI_NUMERICHOST;
 				ret = pg_getaddrinfo_all(ch->hostaddr, portstr, &hint,
-										 &conn->addrlist);
-				if (ret || !conn->addrlist)
+										 &addrlist);
+				if (ret || !addrlist)
 				{
 					appendPQExpBuffer(&conn->errorMessage,
 									  libpq_gettext("could not parse network address \"%s\": %s\n"),
@@ -2430,7 +2522,6 @@ keep_going:						/* We will come back to here until there is
 
 			case CHT_UNIX_SOCKET:
 #ifdef HAVE_UNIX_SOCKETS
-				conn->addrlist_family = hint.ai_family = AF_UNIX;
 				UNIXSOCK_PATH(portstr, thisport, ch->host);
 				if (strlen(portstr) >= UNIXSOCK_PATH_BUFLEN)
 				{
@@ -2446,8 +2537,8 @@ keep_going:						/* We will come back to here until there is
 				 * name as a Unix-domain socket path.
 				 */
 				ret = pg_getaddrinfo_all(NULL, portstr, &hint,
-										 &conn->addrlist);
-				if (ret || !conn->addrlist)
+										 &addrlist);
+				if (ret || !addrlist)
 				{
 					appendPQExpBuffer(&conn->errorMessage,
 									  libpq_gettext("could not translate Unix-domain socket path \"%s\" to address: %s\n"),
@@ -2460,8 +2551,15 @@ keep_going:						/* We will come back to here until there is
 				break;
 		}
 
+		if (!store_conn_addrinfo(conn, addrlist))
+		{
+			appendPQExpBufferStr(&conn->errorMessage,
+								 libpq_gettext("out of memory\n"));
+			goto error_return;
+		}
+		pg_freeaddrinfo_all(hint.ai_family, addrlist);
+
 		/* OK, scan this addrlist for a working server address */
-		conn->addr_cur = conn->addrlist;
 		reset_connection_state_machine = true;
 		conn->try_next_host = false;
 	}
@@ -2518,30 +2616,29 @@ keep_going:						/* We will come back to here until there is
 			{
 				/*
 				 * Try to initiate a connection to one of the addresses
-				 * returned by pg_getaddrinfo_all().  conn->addr_cur is the
+				 * returned by pg_getaddrinfo_all().  conn->whichaddr is the
 				 * next one to try.
 				 *
 				 * The extra level of braces here is historical.  It's not
 				 * worth reindenting this whole switch case to remove 'em.
 				 */
 				{
-					struct addrinfo *addr_cur = conn->addr_cur;
 					char		host_addr[NI_MAXHOST];
+					AddrInfo   *addr_cur;
 
 					/*
 					 * Advance to next possible host, if we've tried all of
 					 * the addresses for the current host.
 					 */
-					if (addr_cur == NULL)
+					if (conn->whichaddr == conn->naddr)
 					{
 						conn->try_next_host = true;
 						goto keep_going;
 					}
+					addr_cur = &conn->addr[conn->whichaddr];
 
 					/* Remember current address for possible use later */
-					memcpy(&conn->raddr.addr, addr_cur->ai_addr,
-						   addr_cur->ai_addrlen);
-					conn->raddr.salen = addr_cur->ai_addrlen;
+					memcpy(&conn->raddr, &addr_cur->addr, sizeof(SockAddr));
 
 					/*
 					 * Set connip, too.  Note we purposely ignore strdup
@@ -2557,7 +2654,7 @@ keep_going:						/* We will come back to here until there is
 						conn->connip = strdup(host_addr);
 
 					/* Try to create the socket */
-					conn->sock = socket(addr_cur->ai_family, SOCK_STREAM, 0);
+					conn->sock = socket(addr_cur->family, SOCK_STREAM, 0);
 					if (conn->sock == PGINVALID_SOCKET)
 					{
 						int			errorno = SOCK_ERRNO;
@@ -2568,7 +2665,7 @@ keep_going:						/* We will come back to here until there is
 						 * cases where the address list includes both IPv4 and
 						 * IPv6 but kernel only accepts one family.
 						 */
-						if (addr_cur->ai_next != NULL ||
+						if (conn->whichaddr < conn->naddr ||
 							conn->whichhost + 1 < conn->nconnhost)
 						{
 							conn->try_next_addr = true;
@@ -2595,7 +2692,7 @@ keep_going:						/* We will come back to here until there is
 					 * TCP sockets, nonblock mode, close-on-exec.  Try the
 					 * next address if any of this fails.
 					 */
-					if (addr_cur->ai_family != AF_UNIX)
+					if (addr_cur->family != AF_UNIX)
 					{
 						if (!connectNoDelay(conn))
 						{
@@ -2624,7 +2721,7 @@ keep_going:						/* We will come back to here until there is
 					}
 #endif							/* F_SETFD */
 
-					if (addr_cur->ai_family != AF_UNIX)
+					if (addr_cur->family != AF_UNIX)
 					{
 #ifndef WIN32
 						int			on = 1;
@@ -2718,8 +2815,8 @@ keep_going:						/* We will come back to here until there is
 					 * Start/make connection.  This should not block, since we
 					 * are in nonblock mode.  If it does, well, too bad.
 					 */
-					if (connect(conn->sock, addr_cur->ai_addr,
-								addr_cur->ai_addrlen) < 0)
+					if (connect(conn->sock, (struct sockaddr *) &addr_cur->addr.addr,
+								addr_cur->addr.salen) < 0)
 					{
 						if (SOCK_ERRNO == EINPROGRESS ||
 #ifdef WIN32
@@ -4172,6 +4269,54 @@ freePGconn(PGconn *conn)
 	free(conn);
 }
 
+
+/*
+ * store_conn_addrinfo
+ *	 - Copy an addrinfo list into conn->addr, shuffled if load balancing.
+ *
+ * The caller still owns addrlist.  Returns false on out-of-memory, in which
+ * case conn->addr is NULL and conn->naddr is 0.
+ */
+static bool
+store_conn_addrinfo(PGconn *conn, struct addrinfo *addrlist)
+{
+	struct addrinfo *ai;
+	int			count = 0;
+
+	conn->whichaddr = 0;
+	conn->naddr = 0;
+	for (ai = addrlist; ai != NULL; ai = ai->ai_next)
+		count++;
+	conn->addr = calloc(count, sizeof(AddrInfo));
+	if (conn->addr == NULL)
+		return false;		/* naddr stays 0, matching the NULL array */
+	conn->naddr = count;
+	ai = addrlist;
+	for (int i = 0; i < count; i++, ai = ai->ai_next)
+	{
+		AddrInfo   *dst = &conn->addr[i];
+
+		dst->family = ai->ai_family;
+		memcpy(&dst->addr.addr, ai->ai_addr, ai->ai_addrlen);
+		dst->addr.salen = ai->ai_addrlen;
+	}
+
+	if (loadBalance(conn))
+	{
+		/* Durstenfeld/Knuth variant of the Fisher-Yates shuffle */
+		libpq_prng_init();
+		for (int i = count - 1; i > 0; i--)
+		{
+			int			j = pg_prng_double(&libpq_prng_state) * (i + 1);
+			AddrInfo	tmp = conn->addr[j];
+
+			conn->addr[j] = conn->addr[i];
+			conn->addr[i] = tmp;
+		}
+	}
+	return true;
+}
+
 /*
  * release_conn_addrinfo
  *	 - Free any addrinfo list in the PGconn.
@@ -4179,11 +4324,10 @@ freePGconn(PGconn *conn)
 static void
 release_conn_addrinfo(PGconn *conn)
 {
-	if (conn->addrlist)
+	if (conn->addr)
 	{
-		pg_freeaddrinfo_all(conn->addrlist_family, conn->addrlist);
-		conn->addrlist = NULL;
-		conn->addr_cur = NULL;	/* for safety */
+		free(conn->addr);
+		conn->addr = NULL;
 	}
 }
 
diff --git a/src/interfaces/libpq/libpq-int.h b/src/interfaces/libpq/libpq-int.h
index 3db6a17db4..3f60bfa2a0 100644
--- a/src/interfaces/libpq/libpq-int.h
+++ b/src/interfaces/libpq/libpq-int.h
@@ -370,6 +370,7 @@ struct pg_conn
 	char	   *pgpassfile;		/* path to a file containing password(s) */
 	char	   *channel_binding;	/* channel binding mode
 									 * (require,prefer,disable) */
+	char	   *loadbalance;	/* load balance over hosts */
 	char	   *keepalives;		/* use TCP keepalives? */
 	char	   *keepalives_idle;	/* time between TCP keepalives */
 	char	   *keepalives_interval;	/* time between TCP keepalive
@@ -458,8 +459,10 @@ struct pg_conn
 	PGTargetServerType target_server_type;	/* desired session properties */
 	bool		try_next_addr;	/* time to advance to next address/host? */
 	bool		try_next_host;	/* time to advance to next connhost[]? */
-	struct addrinfo *addrlist;	/* list of addresses for current connhost */
-	struct addrinfo *addr_cur;	/* the one currently being tried */
+	int			naddr;			/* # of addrs returned by getaddrinfo */
+	int			whichaddr;		/* the addr currently being tried */
+	AddrInfo   *addr;			/* the array of addresses for the currently
+								 * tried host */
 	int			addrlist_family;	/* needed to know how to free addrlist */
 	bool		send_appname;	/* okay to send application_name? */
 
-- 
2.17.1

