PFA a simple patch that implements support for the PROXY protocol.

This is a common and very lightweight protocol used by proxies and load
balancers (haproxy is one common example, the AWS cloud load balancers
are another).  Basically this protocol prefixes the normal
connection with a header and a specification of what the original host
was, allowing the server to unwrap that and get the correct client
address instead of just the proxy ip address. It is a one-way protocol
in that there is no response from the server, it's just purely a
prefix of the IP information.

Using this when PostgreSQL is behind a proxy allows us to keep using
pg_hba.conf rules based on the original ip address, as well as track
the original address in log messages and pg_stat_activity etc.

The implementation adds a parameter named proxy_servers which lists
the IP addresses or IP+CIDR masks to be trusted. Since a proxy can decide what
the origin is, and this is used for security decisions, it's very
important to not just trust any server, only those that are
intentionally used. By default, no servers are listed, and thus the
protocol is disabled.

When specified, and the connection on the normal port has the proxy
prefix on it, and the connection comes in from one of the addresses
listed as valid proxy servers, we will replace the actual IP address
of the client with the one specified in the proxy packet.

Currently there is no information about the proxy server in the
pg_stat_activity view, it's only available as a log message. But maybe
it should go in pg_stat_activity as well? Or in a separate
pg_stat_proxy view?

(In passing, I note that pq_discardbytes was in pqcomm.h, yet listed
as static in pqcomm.c -- it is now made non-static)

-- 
 Magnus Hagander
 Me: https://www.hagander.net/
 Work: https://www.redpill-linpro.com/
diff --git a/doc/src/sgml/client-auth.sgml b/doc/src/sgml/client-auth.sgml
index b420486a0a..d4f6fad5b0 100644
--- a/doc/src/sgml/client-auth.sgml
+++ b/doc/src/sgml/client-auth.sgml
@@ -353,6 +353,15 @@ hostnogssenc  <replaceable>database</replaceable>  <replaceable>user</replaceabl
        the client's host name instead of the IP address in the log.
       </para>
 
+      <para>
+       If <xref linkend="guc-proxy-servers"/> is enabled and the
+       connection is made through a proxy server using the PROXY
+       protocol, the actual IP address of the client will be used
+       for matching. If a connection is made through a proxy server
+       not using the PROXY protocol, the IP address of the
+       proxy server will be used.
+      </para>
+
       <para>
        These fields do not apply to <literal>local</literal> records.
       </para>
diff --git a/doc/src/sgml/config.sgml b/doc/src/sgml/config.sgml
index b5718fc136..fc7de25378 100644
--- a/doc/src/sgml/config.sgml
+++ b/doc/src/sgml/config.sgml
@@ -682,6 +682,30 @@ include_dir 'conf.d'
       </listitem>
      </varlistentry>
 
+     <varlistentry id="guc-proxy-servers" xreflabel="proxy_servers">
+      <term><varname>proxy_servers</varname> (<type>string</type>)
+      <indexterm>
+       <primary><varname>proxy_servers</varname> configuration parameter</primary>
+      </indexterm>
+      </term>
+      <listitem>
+       <para>
+        A comma-separated list of one or more IP addresses or CIDR specifications
+        of proxy servers to trust. If a connection using the PROXY protocol is made
+        from one of these IP addresses, <productname>PostgreSQL</productname> will
+        read the client IP address from the PROXY header and consider that the
+        address of the client, instead of listing all connections as coming from
+        the proxy server.
+       </para>
+       <para>
+        If a proxy connection is made from an IP address not covered by this
+        list, the connection will be rejected. By default no proxy is trusted
+        and all proxy connections will be rejected.  This parameter can only
+        be set at server start.
+       </para>
+      </listitem>
+     </varlistentry>
+
      <varlistentry id="guc-max-connections" xreflabel="max_connections">
       <term><varname>max_connections</varname> (<type>integer</type>)
       <indexterm>
diff --git a/src/backend/libpq/pqcomm.c b/src/backend/libpq/pqcomm.c
index 27a298f110..9163761cc2 100644
--- a/src/backend/libpq/pqcomm.c
+++ b/src/backend/libpq/pqcomm.c
@@ -53,6 +53,7 @@
  *		pq_getmessage	- get a message with length word from connection
  *		pq_getbyte		- get next byte from connection
  *		pq_peekbyte		- peek at next byte from connection
+ *		pq_peekbytes    - peek at a known number of bytes from connection
  *		pq_putbytes		- send bytes to connection (not flushed until pq_flush)
  *		pq_flush		- flush pending output
  *		pq_flush_if_writable - flush pending output if writable without blocking
@@ -1039,6 +1040,27 @@ pq_peekbyte(void)
 	return (unsigned char) PqRecvBuffer[PqRecvPointer];
 }
 
+
+/* --------------------------------
+ *		pq_peekbytes		- copy the next len buffered bytes into s without
+ *							  consuming them.  Never waits for more data.
+ *
+ *		returns 0 if OK, EOF if fewer than len bytes are buffered
+ * --------------------------------
+ */
+int
+pq_peekbytes(char *s, size_t len)
+{
+	Assert(PqCommReadingMsg);
+
+	if (PqRecvLength - PqRecvPointer >= len)
+	{
+		memcpy(s, PqRecvBuffer + PqRecvPointer, len);
+		return 0;
+	}
+	return EOF;
+}
+
 /* --------------------------------
  *		pq_getbyte_if_available - get a single byte from connection,
  *			if available
@@ -1135,7 +1157,7 @@ pq_getbytes(char *s, size_t len)
  *		returns 0 if OK, EOF if trouble
  * --------------------------------
  */
-static int
+int
 pq_discardbytes(size_t len)
 {
 	size_t		amount;
diff --git a/src/backend/postmaster/postmaster.c b/src/backend/postmaster/postmaster.c
index 3f1ce135a8..0473129cb4 100644
--- a/src/backend/postmaster/postmaster.c
+++ b/src/backend/postmaster/postmaster.c
@@ -102,6 +102,7 @@
 #include "common/string.h"
 #include "lib/ilist.h"
 #include "libpq/auth.h"
+#include "libpq/ifaddr.h"
 #include "libpq/libpq.h"
 #include "libpq/pqformat.h"
 #include "libpq/pqsignal.h"
@@ -204,6 +205,10 @@ char	   *Unix_socket_directories;
 /* The TCP listen address(es) */
 char	   *ListenAddresses;
 
+/* Trusted proxy servers */
+char	   *TrustedProxyServersString = NULL;
+struct sockaddr_storage *TrustedProxyServers = NULL;
+
 /*
  * ReservedBackends is the number of backends reserved for superuser use.
  * This number is taken out of the pool size given by MaxConnections so
@@ -1911,6 +1916,203 @@ initMasks(fd_set *rmask)
 	return maxsock + 1;
 }
 
+/*
+ * UnwrapProxyConnection -- consume a PROXY protocol v2 header, if present.
+ * On a header from a trusted proxy, port->raddr becomes the client address it
+ * carries. Returns STATUS_ERROR to drop the connection, else STATUS_OK.
+ */
+static int
+UnwrapProxyConnection(Port *port)
+{
+	char		proxyver;
+	char		proxycmd;
+	uint16		proxyaddrlen;
+	uint16		needed;
+	SockAddr	raddr_save;
+	int			i;
+	bool		allowed = false;
+
+	/*
+	 * These structs are from the PROXY protocol docs at
+	 * http://www.haproxy.org/download/1.8/doc/proxy-protocol.txt
+	 */
+	union
+	{
+		struct
+		{						/* for TCP/UDP over IPv4, len = 12 */
+			uint32		src_addr;
+			uint32		dst_addr;
+			uint16		src_port;
+			uint16		dst_port;
+		}			ip4;
+		struct
+		{						/* for TCP/UDP over IPv6, len = 36 */
+			uint8		src_addr[16];
+			uint8		dst_addr[16];
+			uint16		src_port;
+			uint16		dst_port;
+		}			ip6;
+	}			proxyaddr;
+	struct
+	{
+		uint8		sig[12];	/* hex 0D 0A 0D 0A 00 0D 0A 51 55 49 54 0A */
+		uint8		ver_cmd;	/* protocol version and command */
+		uint8		fam;		/* protocol family and address */
+		uint16		len;		/* number of following bytes part of the
+								 * header */
+	}			proxyheader;
+
+	/* Save the original address, including salen, for checks and logging */
+	raddr_save = port->raddr;
+	pq_startmsgread();
+
+	/* Peek at the very first byte just to trigger a read */
+	if (pq_peekbyte() == EOF)
+	{
+		ereport(COMMERROR,
+				(errcode(ERRCODE_PROTOCOL_VIOLATION),
+				 errmsg("incomplete startup packet")));
+		return STATUS_ERROR;
+	}
+
+	/*
+	 * PROXY requests always start with: \x0D \x0A \x0D \x0A \x00 \x0D \x0A
+	 * \x51 \x55 \x49 \x54 \x0A.  pq_peekbytes() does not wait for more data,
+	 * so a header split across reads is treated as a non-proxy connection.
+	 */
+	if (pq_peekbytes((char *) &proxyheader, sizeof(proxyheader)) != 0)
+	{
+		/* Not enough bytes for a proxy header: normal processing */
+		pq_endmsgread();
+		return STATUS_OK;
+	}
+
+	if (memcmp(proxyheader.sig, "\x0d\x0a\x0d\x0a\x00\x0d\x0a\x51\x55\x49\x54\x0a", sizeof(proxyheader.sig)) != 0)
+	{
+		/* Data is there but it isn't a proxy header: normal processing */
+		pq_endmsgread();
+		return STATUS_OK;
+	}
+
+	/*
+	 * Header is valid. Verify that the proxy is actually authorized!  Slot 0
+	 * holds the entry count; entry i is slots 2i+1 (address), 2i+2 (mask).
+	 */
+	for (i = 0; i < *((int *) TrustedProxyServers); i++)
+	{
+		if (raddr_save.addr.ss_family == TrustedProxyServers[i * 2 + 1].ss_family &&
+			pg_range_sockaddr(&raddr_save.addr,
+							  &TrustedProxyServers[i * 2 + 1],
+							  &TrustedProxyServers[i * 2 + 2]))
+		{
+			allowed = true;
+			break;
+		}
+	}
+	if (!allowed)
+	{
+		ereport(COMMERROR,
+				(errcode(ERRCODE_PROTOCOL_VIOLATION),
+				 errmsg("proxy connection from unauthorized server")));
+		return STATUS_ERROR;
+	}
+
+	/* This is a trusted proxy header, so unwrap it. First, skip past it. */
+	pq_discardbytes(sizeof(proxyheader));
+
+	/* Version is in the high 4 bits of ver_cmd, command in the low 4 bits */
+	proxyver = (proxyheader.ver_cmd & 0xF0) >> 4;
+	proxycmd = proxyheader.ver_cmd & 0x0F;
+	if (proxyver != 2)
+	{
+		ereport(COMMERROR,
+				(errcode(ERRCODE_PROTOCOL_VIOLATION),
+				 errmsg("invalid proxy protocol version: %x", proxyver)));
+		return STATUS_ERROR;
+	}
+	if (proxycmd != 0 && proxycmd != 1)
+	{
+		ereport(COMMERROR,
+				(errcode(ERRCODE_PROTOCOL_VIOLATION),
+				 errmsg("invalid proxy protocol command: %x", proxycmd)));
+		return STATUS_ERROR;
+	}
+
+	proxyaddrlen = pg_ntoh16(proxyheader.len);
+
+	/* LOCAL command (0) and unspecified family (0) carry no usable address */
+	if (proxycmd == 0 || proxyheader.fam == 0)
+		needed = 0;
+	else if (proxyheader.fam == 0x11)
+		needed = sizeof(proxyaddr.ip4); /* TCPv4 */
+	else if (proxyheader.fam == 0x21)
+		needed = sizeof(proxyaddr.ip6); /* TCPv6 */
+	else
+	{
+		ereport(COMMERROR,
+				(errcode(ERRCODE_PROTOCOL_VIOLATION),
+				 errmsg("invalid proxy protocol connection type: %x", proxyheader.fam)));
+		return STATUS_ERROR;
+	}
+
+	/*
+	 * The announced length must cover the address block. Anything after it
+	 * (unused address data or TLVs) is valid but ignored, so skip past it.
+	 */
+	if (proxyaddrlen < needed ||
+		(needed > 0 && pq_getbytes((char *) &proxyaddr, needed) == EOF) ||
+		(proxyaddrlen > needed && pq_discardbytes(proxyaddrlen - needed) == EOF))
+	{
+		ereport(COMMERROR,
+				(errcode(ERRCODE_PROTOCOL_VIOLATION),
+				 errmsg("incomplete proxy packet")));
+		return STATUS_ERROR;
+	}
+
+	/* Install the client address; LOCAL/unspecified keeps the real peer */
+	if (needed == sizeof(proxyaddr.ip4))
+	{
+		port->raddr.addr.ss_family = AF_INET;
+		port->raddr.salen = sizeof(struct sockaddr_in);
+		((struct sockaddr_in *) &port->raddr.addr)->sin_addr.s_addr = proxyaddr.ip4.src_addr;
+		((struct sockaddr_in *) &port->raddr.addr)->sin_port = proxyaddr.ip4.src_port;
+	}
+	else if (needed == sizeof(proxyaddr.ip6))
+	{
+		port->raddr.addr.ss_family = AF_INET6;
+		port->raddr.salen = sizeof(struct sockaddr_in6);
+		((struct sockaddr_in6 *) &port->raddr.addr)->sin6_flowinfo = 0;
+		((struct sockaddr_in6 *) &port->raddr.addr)->sin6_scope_id = 0;
+		memcpy(&((struct sockaddr_in6 *) &port->raddr.addr)->sin6_addr, proxyaddr.ip6.src_addr, 16);
+		((struct sockaddr_in6 *) &port->raddr.addr)->sin6_port = proxyaddr.ip6.src_port;
+	}
+
+	pq_endmsgread();
+
+	/* Log the proxy itself; normal connection logging covers the client */
+	if (Log_connections)
+	{
+		char		remote_host[NI_MAXHOST];
+		char		remote_port[NI_MAXSERV];
+		int			ret;
+
+		remote_host[0] = '\0';
+		remote_port[0] = '\0';
+		if ((ret = pg_getnameinfo_all(&raddr_save.addr, raddr_save.salen,
+									  remote_host, sizeof(remote_host),
+									  remote_port, sizeof(remote_port),
+									  (log_hostname ? 0 : NI_NUMERICHOST) | NI_NUMERICSERV)) != 0)
+			ereport(WARNING,
+					(errmsg_internal("pg_getnameinfo_all() failed: %s",
+									 gai_strerror(ret))));
+
+		ereport(LOG,
+				(errmsg("proxy connection from: host=%s port=%s",
+						remote_host, remote_port)));
+	}
+
+	return STATUS_OK;
+}
 
 /*
  * Read a client's startup packet and do something according to it.
@@ -4344,6 +4546,33 @@ BackendInitialize(Port *port)
 	InitializeTimeouts();		/* establishes SIGALRM handler */
 	PG_SETMASK(&StartupBlockSig);
 
+	/*
+	 * Ready to begin client interaction.  We will give up and _exit(1) after
+	 * a time delay, so that a broken client can't hog a connection
+	 * indefinitely.  PreAuthDelay and any DNS interactions above don't count
+	 * against the time limit.
+	 *
+	 * Note: AuthenticationTimeout is applied here while waiting for the
+	 * startup packet, and then again in InitPostgres for the duration of any
+	 * authentication operations.  So a hostile client could tie up the
+	 * process for nearly twice AuthenticationTimeout before we kick him off.
+	 *
+	 * Note: because PostgresMain will call InitializeTimeouts again, the
+	 * registration of STARTUP_PACKET_TIMEOUT will be lost.  This is okay
+	 * since we never use it again after this function.
+	 */
+	RegisterTimeout(STARTUP_PACKET_TIMEOUT, StartupPacketTimeoutHandler);
+	enable_timeout_after(STARTUP_PACKET_TIMEOUT, AuthenticationTimeout * 1000);
+
+	/* Check if this is a proxy connection and if so unwrap the proxying */
+	if (TrustedProxyServers)
+	{
+		if (UnwrapProxyConnection(port) != STATUS_OK)
+			proc_exit(0);
+	}
+
+	disable_timeout(STARTUP_PACKET_TIMEOUT, false);
+
 	/*
 	 * Get the remote host name and port for logging and status display.
 	 */
@@ -4395,28 +4624,11 @@ BackendInitialize(Port *port)
 		strspn(remote_host, "0123456789ABCDEFabcdef:") < strlen(remote_host))
 		port->remote_hostname = strdup(remote_host);
 
-	/*
-	 * Ready to begin client interaction.  We will give up and _exit(1) after
-	 * a time delay, so that a broken client can't hog a connection
-	 * indefinitely.  PreAuthDelay and any DNS interactions above don't count
-	 * against the time limit.
-	 *
-	 * Note: AuthenticationTimeout is applied here while waiting for the
-	 * startup packet, and then again in InitPostgres for the duration of any
-	 * authentication operations.  So a hostile client could tie up the
-	 * process for nearly twice AuthenticationTimeout before we kick him off.
-	 *
-	 * Note: because PostgresMain will call InitializeTimeouts again, the
-	 * registration of STARTUP_PACKET_TIMEOUT will be lost.  This is okay
-	 * since we never use it again after this function.
-	 */
-	RegisterTimeout(STARTUP_PACKET_TIMEOUT, StartupPacketTimeoutHandler);
-	enable_timeout_after(STARTUP_PACKET_TIMEOUT, AuthenticationTimeout * 1000);
-
 	/*
 	 * Receive the startup packet (which might turn out to be a cancel request
 	 * packet).
 	 */
+	enable_timeout_after(STARTUP_PACKET_TIMEOUT, AuthenticationTimeout * 1000);
 	status = ProcessStartupPacket(port, false, false);
 
 	/*
diff --git a/src/backend/utils/misc/guc.c b/src/backend/utils/misc/guc.c
index d626731723..381067b737 100644
--- a/src/backend/utils/misc/guc.c
+++ b/src/backend/utils/misc/guc.c
@@ -46,10 +46,12 @@
 #include "commands/user.h"
 #include "commands/vacuum.h"
 #include "commands/variable.h"
+#include "common/ip.h"
 #include "common/string.h"
 #include "funcapi.h"
 #include "jit/jit.h"
 #include "libpq/auth.h"
+#include "libpq/ifaddr.h"
 #include "libpq/libpq.h"
 #include "libpq/pqformat.h"
 #include "miscadmin.h"
@@ -227,6 +229,8 @@ static bool check_recovery_target_lsn(char **newval, void **extra, GucSource sou
 static void assign_recovery_target_lsn(const char *newval, void *extra);
 static bool check_primary_slot_name(char **newval, void **extra, GucSource source);
 static bool check_default_with_oids(bool *newval, void **extra, GucSource source);
+static bool check_proxy_servers(char **newval, void **extra, GucSource source);
+static void assign_proxy_servers(const char *newval, void *extra);
 
 /* Private functions in guc-file.l that need to be called from guc.c */
 static ConfigVariable *ProcessConfigFileInternal(GucContext context,
@@ -4241,6 +4245,17 @@ static struct config_string ConfigureNamesString[] =
 		NULL, NULL, NULL
 	},
 
+	{
+		{"proxy_servers", PGC_POSTMASTER, CONN_AUTH_SETTINGS,
+			gettext_noop("Sets the addresses for trusted proxy servers."),
+			NULL,
+			GUC_LIST_INPUT
+		},
+		&TrustedProxyServersString,
+		"",
+		check_proxy_servers, assign_proxy_servers, NULL
+	},
+
 	{
 		/*
 		 * Can't be set by ALTER SYSTEM as it can lead to recursive definition
@@ -12228,4 +12243,108 @@ check_default_with_oids(bool *newval, void **extra, GucSource source)
 	return true;
 }
 
+/*
+ * GUC check hook for proxy_servers: parse a comma-separated list of numeric
+ * IP addresses with optional /cidr masks. *extra is NULL for an empty list,
+ * else a guc_malloc'd sockaddr_storage array: slot 0 is overloaded as an int
+ * entry count, entry i is slots 2i+1 (address) and 2i+2 (netmask).
+ */
+static bool
+check_proxy_servers(char **newval, void **extra, GucSource source)
+{
+	char	   *rawstring;
+	List	   *elemlist;
+	ListCell   *l;
+	struct sockaddr_storage *myextra;
+
+	/* Special case when it's empty */
+	if (**newval == '\0')
+	{
+		*extra = NULL;
+		return true;
+	}
+
+	/* Need a modifiable copy of string */
+	rawstring = pstrdup(*newval);
+
+	/* Parse string into list of identifiers */
+	if (!SplitIdentifierString(rawstring, ',', &elemlist))
+	{
+		GUC_check_errdetail("List syntax is invalid.");
+		pfree(rawstring);
+		list_free(elemlist);
+		return false;
+	}
+
+	/* If it had only whitespace, treat it like the empty string */
+	if (list_length(elemlist) == 0)
+	{
+		pfree(rawstring);
+		list_free(elemlist);
+		*extra = NULL;
+		return true;
+	}
+
+	/* Slot 0 is just an overloaded int which holds the number of entries */
+	myextra = (struct sockaddr_storage *) guc_malloc(ERROR, sizeof(struct sockaddr_storage) * (list_length(elemlist) * 2 + 1));
+	*((int *) &myextra[0]) = list_length(elemlist);
+
+	foreach(l, elemlist)
+	{
+		char	   *tok = (char *) lfirst(l);
+		char	   *netmasktok = strchr(tok, '/');
+		int			ret;
+		struct addrinfo *gai_result;
+		struct addrinfo hints;
+		struct sockaddr_storage *addr = &myextra[foreach_current_index(l) * 2 + 1];
+		struct sockaddr_storage *mask = addr + 1;
+
+		if (netmasktok)
+			*netmasktok++ = '\0';
+
+		/* Only numeric addresses are accepted; no name lookups happen here */
+		memset((char *) &hints, 0, sizeof(hints));
+		hints.ai_flags = AI_NUMERICHOST;
+		hints.ai_family = AF_UNSPEC;
+
+		ret = pg_getaddrinfo_all(tok, NULL, &hints, &gai_result);
+		if (ret != 0 || gai_result == NULL)
+		{
+			GUC_check_errdetail("Invalid IP address %s", tok);
+			pfree(rawstring);
+			list_free(elemlist);
+			free(myextra);
+			return false;
+		}
+
+		memcpy((char *) addr, gai_result->ai_addr, gai_result->ai_addrlen);
+		pg_freeaddrinfo_all(hints.ai_family, gai_result);
+
+		/* A NULL netmasktok means the fully set hostmask */
+		if (pg_sockaddr_cidr_mask(mask, netmasktok, addr->ss_family) != 0)
+		{
+			if (netmasktok)
+				GUC_check_errdetail("Invalid netmask %s", netmasktok);
+			else
+				GUC_check_errdetail("Could not create netmask");
+			pfree(rawstring);
+			list_free(elemlist);
+			free(myextra);
+			return false;
+		}
+	}
+
+	pfree(rawstring);
+	list_free(elemlist);
+	*extra = (void *) myextra;
+
+	return true;
+}
+
+static void
+assign_proxy_servers(const char *newval, void *extra)
+{
+	TrustedProxyServers = extra;	/* array from check_proxy_servers, or NULL */
+}
+
 #include "guc-file.c"
diff --git a/src/backend/utils/misc/postgresql.conf.sample b/src/backend/utils/misc/postgresql.conf.sample
index ee06528bb0..aa7ac35f67 100644
--- a/src/backend/utils/misc/postgresql.conf.sample
+++ b/src/backend/utils/misc/postgresql.conf.sample
@@ -61,6 +61,8 @@
 					# defaults to 'localhost'; use '*' for all
 					# (change requires restart)
 #port = 5432				# (change requires restart)
+#proxy_servers = ''			# what IP/netmasks of proxy servers to trust
+					# (change requires restart)
 #max_connections = 100			# (change requires restart)
 #superuser_reserved_connections = 3	# (change requires restart)
 #unix_socket_directories = '/tmp'	# comma-separated list of directories
diff --git a/src/include/libpq/libpq.h b/src/include/libpq/libpq.h
index e4e5c21565..6125b93e86 100644
--- a/src/include/libpq/libpq.h
+++ b/src/include/libpq/libpq.h
@@ -74,6 +74,8 @@ extern bool pq_is_reading_msg(void);
 extern int	pq_getmessage(StringInfo s, int maxlen);
 extern int	pq_getbyte(void);
 extern int	pq_peekbyte(void);
+extern int	pq_peekbytes(char *s, size_t len);
+extern int	pq_discardbytes(size_t len);
 extern int	pq_getbyte_if_available(unsigned char *c);
 extern int	pq_putbytes(const char *s, size_t len);
 
diff --git a/src/include/postmaster/postmaster.h b/src/include/postmaster/postmaster.h
index cfa59c4dc0..38e7644371 100644
--- a/src/include/postmaster/postmaster.h
+++ b/src/include/postmaster/postmaster.h
@@ -21,6 +21,8 @@ extern int	Unix_socket_permissions;
 extern char *Unix_socket_group;
 extern char *Unix_socket_directories;
 extern char *ListenAddresses;
+extern char *TrustedProxyServersString;
+extern struct sockaddr_storage *TrustedProxyServers;
 extern bool ClientAuthInProgress;
 extern int	PreAuthDelay;
 extern int	AuthenticationTimeout;

Reply via email to