From 2d8148c3cf62e0e155553fb65efaf5b4c508f8f9 Mon Sep 17 00:00:00 2001
From: Michael Paquier <michael@paquier.xyz>
Date: Tue, 20 Jun 2017 11:41:10 +0900
Subject: [PATCH 3/4] Add connection parameters "saslname" and
 "saslchannelbinding"

Those parameters can be used to respectively enforce the value of the
SASL mechanism name and the channel binding name sent to server during
a SASL message exchange.

A set of tests dedicated to SASL and channel binding is added as well
to the SSL test suite, which is handy to check the validity of a patch.
---
 doc/src/sgml/libpq.sgml              | 24 ++++++++++++++++
 src/backend/libpq/auth-scram.c       | 41 +++++++++++++++++++++-------
 src/interfaces/libpq/fe-auth-scram.c | 53 ++++++++++++++++++++++++++++++++----
 src/interfaces/libpq/fe-auth.c       |  8 ++++++
 src/interfaces/libpq/fe-auth.h       |  4 +--
 src/interfaces/libpq/fe-connect.c    | 13 +++++++++
 src/interfaces/libpq/libpq-int.h     |  2 ++
 src/test/ssl/ServerSetup.pm          | 19 +++++++++++--
 src/test/ssl/t/001_ssltests.pl       |  2 +-
 src/test/ssl/t/002_sasl.pl           | 52 +++++++++++++++++++++++++++++++++++
 10 files changed, 197 insertions(+), 21 deletions(-)
 create mode 100644 src/test/ssl/t/002_sasl.pl

diff --git a/doc/src/sgml/libpq.sgml b/doc/src/sgml/libpq.sgml
index 096a8be605..cfcf6ee7c2 100644
--- a/doc/src/sgml/libpq.sgml
+++ b/doc/src/sgml/libpq.sgml
@@ -1220,6 +1220,30 @@ postgresql://%2Fvar%2Flib%2Fpostgresql/dbname
       </listitem>
      </varlistentry>
 
+     <varlistentry id="libpq-saslname" xreflabel="saslname">
+      <term><literal>saslname</literal></term>
+      <listitem>
+       <para>
+        Controls the name of the SASL mechanism name sent to server when doing
+        a message exchange for a SASL authentication. The list of SASL
+        mechanisms supported by server are listed in
+        <xref linkend="sasl-authentication">.
+       </para>
+      </listitem>
+     </varlistentry>
+
+     <varlistentry id="libpq-saslchannelbinding" xreflabel="saslchannelbinding">
+      <term><literal>saslchannelbinding</literal></term>
+      <listitem>
+       <para>
+        Controls the name of the channel binding name sent to server when doing
+        a message exchange for a SASL authentication. The list of channel
+        binding names supported by server are listed in
+        <xref linkend="sasl-authentication">.
+       </para>
+      </listitem>
+     </varlistentry>
+     
      <varlistentry id="libpq-connect-sslmode" xreflabel="sslmode">
       <term><literal>sslmode</literal></term>
       <listitem>
diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c
index c90fa2981a..d5371a2708 100644
--- a/src/backend/libpq/auth-scram.c
+++ b/src/backend/libpq/auth-scram.c
@@ -113,6 +113,7 @@ typedef struct
 	bool		ssl_in_use;
 	char	   *tls_finish_message;
 	int			tls_finish_len;
+	char	   *channel_binding;
 
 	int			iterations;
 	char	   *salt;			/* base64-encoded */
@@ -185,6 +186,7 @@ pg_be_scram_init(const char *username,
 	state->ssl_in_use = ssl_in_use;
 	state->tls_finish_message = tls_finish_message;
 	state->tls_finish_len = tls_finish_len;
+	state->channel_binding = NULL;
 
 	/*
 	 * Parse the stored password verifier.
@@ -853,6 +855,9 @@ read_client_first_message(scram_state *state, char *input)
 					ereport(ERROR,
 						(errcode(ERRCODE_PROTOCOL_VIOLATION),
 						 (errmsg("unexpected SCRAM channel-binding type"))));
+
+				/* Save the name for handling of subsequent messages */
+				state->channel_binding = pstrdup(channel_name);
 #else
 				/*
 				 * Client requires channel binding.  We don't support it.
@@ -1106,20 +1111,36 @@ read_client_final_message(scram_state *state, char *input)
 #ifdef USE_SSL
 	if (state->ssl_in_use)
 	{
-		char	   *enc_tls_message;
-		int			enc_tls_len;
+		char	   *b64_message, *raw_data;
+		int			b64_message_len, raw_data_len;
+
+		/* Fetch data for each channel binding type */
+		if (strcmp(state->channel_binding, SCRAM_CHANNEL_TLS_UNIQUE) == 0)
+		{
+			raw_data = state->tls_finish_message;
+			raw_data_len = state->tls_finish_len;
+		}
+		else
+		{
+			/* should not happen */
+			elog(ERROR, "invalid channel binding type");
+		}
+
+		/* should not happen, but better safe than sorry */
+		if (raw_data == NULL)
+			elog(ERROR, "empty binding data for channel name \"%s\"",
+				 state->channel_binding);
 
-		enc_tls_message = palloc(pg_b64_enc_len(state->tls_finish_len) + 1);
-		enc_tls_len = pg_b64_encode(state->tls_finish_message,
-									state->tls_finish_len,
-									enc_tls_message);
-		enc_tls_message[enc_tls_len] = '\0';
+		b64_message = palloc(pg_b64_enc_len(raw_data_len) + 1);
+		b64_message_len = pg_b64_encode(raw_data, raw_data_len,
+										b64_message);
+		b64_message[b64_message_len] = '\0';
 
 		/*
-		 * Compare the value sent by the client with the TLS finish message
-		 * expected by the server.
+		 * Compare the value sent by the client with the value expected by
+		 * the server.
 		 */
-		if (strcmp(channel_binding, enc_tls_message) != 0)
+		if (strcmp(channel_binding, b64_message) != 0)
 			ereport(ERROR,
 					(errcode(ERRCODE_PROTOCOL_VIOLATION),
 					 (errmsg("no match for SCRAM channel-binding attribute in client-final-message"))));
diff --git a/src/interfaces/libpq/fe-auth-scram.c b/src/interfaces/libpq/fe-auth-scram.c
index e37a0cce27..5b8522391d 100644
--- a/src/interfaces/libpq/fe-auth-scram.c
+++ b/src/interfaces/libpq/fe-auth-scram.c
@@ -48,6 +48,8 @@ typedef struct
 	bool		ssl_in_use;
 	char	   *tls_finish_message;
 	int			tls_finish_len;
+	/* enforceable user parameters */
+	char	   *saslchannelbinding;	/* name of channel binding to use */
 
 	/* We construct these */
 	uint8		SaltedPassword[SCRAM_KEY_LEN];
@@ -88,6 +90,7 @@ void *
 pg_fe_scram_init(const char *username,
 				 const char *password,
 				 bool ssl_in_use,
+				 char *saslchannelbinding,
 				 char *tls_finish_message,
 				 int tls_finish_len)
 {
@@ -105,6 +108,15 @@ pg_fe_scram_init(const char *username,
 	state->tls_finish_message = tls_finish_message;
 	state->tls_finish_len = tls_finish_len;
 
+	/*
+	 * If user has specified a channel binding to use, enforce the
+	 * channel binding sent to it.  The default is "tls-unique".
+	 */
+	if (saslchannelbinding && strlen(saslchannelbinding) > 0)
+		state->saslchannelbinding = strdup(saslchannelbinding);
+	else
+		state->saslchannelbinding = strdup(SCRAM_CHANNEL_TLS_UNIQUE);
+
 	/* Normalize the password with SASLprep, if possible */
 	rc = pg_saslprep(password, &prep_password);
 	if (rc == SASLPREP_OOM)
@@ -136,6 +148,10 @@ pg_fe_scram_free(void *opaq)
 
 	if (state->password)
 		free(state->password);
+	if (state->tls_finish_message)
+		free(state->tls_finish_message);
+	if (state->saslchannelbinding)
+		free(state->saslchannelbinding);
 
 	/* client messages */
 	if (state->client_nonce)
@@ -353,7 +369,7 @@ build_client_first_message(fe_scram_state *state, PQExpBuffer errormessage)
 	 */
 #ifdef USE_SSL
 	if (state->ssl_in_use)
-		appendPQExpBuffer(&buf, "p=%s", SCRAM_CHANNEL_TLS_UNIQUE);
+		appendPQExpBuffer(&buf, "p=%s", state->saslchannelbinding);
 	else
 		appendPQExpBuffer(&buf, "y");
 #else
@@ -412,12 +428,39 @@ build_client_final_message(fe_scram_state *state, PQExpBuffer errormessage)
 #ifdef USE_SSL
 	if (state->ssl_in_use)
 	{
+		char	   *raw_data;
+		int			raw_data_len;
+
+		if (strcmp(state->saslchannelbinding, SCRAM_CHANNEL_TLS_UNIQUE) == 0)
+		{
+			raw_data = state->tls_finish_message;
+			raw_data_len = state->tls_finish_len;
+		}
+		else
+		{
+			/* should not happen */
+			termPQExpBuffer(&buf);
+			printfPQExpBuffer(errormessage,
+							  libpq_gettext("incorrect channel binding name\n"));
+			return NULL;
+		}
+
+		/* should not happen, but better safe than sorry */
+		if (raw_data == NULL)
+		{
+			/* should not happen */
+			termPQExpBuffer(&buf);
+			printfPQExpBuffer(errormessage,
+							  libpq_gettext("empty binding data for channel name \"%s\"\n"),
+							  state->saslchannelbinding);
+			return NULL;
+		}
+
 		appendPQExpBuffer(&buf, "c=");
-		if (!enlargePQExpBuffer(&buf, pg_b64_enc_len(state->tls_finish_len)))
+
+		if (!enlargePQExpBuffer(&buf, pg_b64_enc_len(raw_data_len)))
 			goto oom_error;
-		buf.len += pg_b64_encode(state->tls_finish_message,
-								 state->tls_finish_len,
-								 buf.data + buf.len);
+		buf.len += pg_b64_encode(raw_data, raw_data_len, buf.data + buf.len);
 		buf.data[buf.len] = '\0';
 	}
 	else
diff --git a/src/interfaces/libpq/fe-auth.c b/src/interfaces/libpq/fe-auth.c
index 05f04be10a..4b26018f65 100644
--- a/src/interfaces/libpq/fe-auth.c
+++ b/src/interfaces/libpq/fe-auth.c
@@ -564,6 +564,7 @@ pg_SASL_init(PGconn *conn, int payloadlen)
 			conn->sasl_state = pg_fe_scram_init(conn->pguser,
 												password,
 												conn->ssl_in_use,
+												conn->saslchannelbinding,
 												tls_finish,
 												tls_finish_len);
 			if (!conn->sasl_state)
@@ -589,6 +590,13 @@ pg_SASL_init(PGconn *conn, int payloadlen)
 		}
 	}
 
+	/*
+	 * If user has asked for a specific mechanism name, enforce the chosen
+	 * name to it.
+	 */
+	if (conn->saslname && strlen(conn->saslname) > 0)
+		selected_mechanism = conn->saslname;
+
 	if (!selected_mechanism)
 	{
 		printfPQExpBuffer(&conn->errorMessage,
diff --git a/src/interfaces/libpq/fe-auth.h b/src/interfaces/libpq/fe-auth.h
index ee3dd7f96c..3c699959e0 100644
--- a/src/interfaces/libpq/fe-auth.h
+++ b/src/interfaces/libpq/fe-auth.h
@@ -24,8 +24,8 @@ extern char *pg_fe_getauthname(PQExpBuffer errorMessage);
 
 /* Prototypes for functions in fe-auth-scram.c */
 extern void *pg_fe_scram_init(const char *username, const char *password,
-					 bool ssl_in_use, char *tls_finish_message,
-					 int tls_finish_len);
+					 bool ssl_in_use, char *saslchannelbinding,
+					 char *tls_finish_message, int tls_finish_len);
 extern void pg_fe_scram_free(void *opaq);
 extern void pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
 					 char **output, int *outputlen,
diff --git a/src/interfaces/libpq/fe-connect.c b/src/interfaces/libpq/fe-connect.c
index c580d91135..27139579e9 100644
--- a/src/interfaces/libpq/fe-connect.c
+++ b/src/interfaces/libpq/fe-connect.c
@@ -262,6 +262,15 @@ static const internalPQconninfoOption PQconninfoOptions[] = {
 		"TCP-Keepalives-Count", "", 10, /* strlen(INT32_MAX) == 10 */
 	offsetof(struct pg_conn, keepalives_count)},
 
+	/* Set of options proper to SASL */
+	{"saslname", NULL, NULL, NULL,
+		"SASL-Name", "", 21,	/* maximum name size per IANA == 21 */
+	offsetof(struct pg_conn, saslname)},
+
+	{"saslchannelbinding", NULL, NULL, NULL,
+		"SASL-Channel", "", 22,	/* sizeof("tls-unique-for-telnet") == 22 */
+	offsetof(struct pg_conn, saslchannelbinding)},
+
 	/*
 	 * ssl options are allowed even without client SSL support because the
 	 * client can still handle SSL modes "disable" and "allow". Other
@@ -3470,6 +3479,10 @@ freePGconn(PGconn *conn)
 		free(conn->keepalives_interval);
 	if (conn->keepalives_count)
 		free(conn->keepalives_count);
+	if (conn->saslname)
+		free(conn->saslname);
+	if (conn->saslchannelbinding)
+		free(conn->saslchannelbinding);
 	if (conn->sslmode)
 		free(conn->sslmode);
 	if (conn->sslcert)
diff --git a/src/interfaces/libpq/libpq-int.h b/src/interfaces/libpq/libpq-int.h
index 0eb8b60c95..6d500aa5db 100644
--- a/src/interfaces/libpq/libpq-int.h
+++ b/src/interfaces/libpq/libpq-int.h
@@ -349,6 +349,8 @@ struct pg_conn
 										 * retransmits */
 	char	   *keepalives_count;	/* maximum number of TCP keepalive
 									 * retransmits */
+	char	   *saslname;			/* SASL mechanism name */
+	char	   *saslchannelbinding;	/* channel binding used in SASL */
 	char	   *sslmode;		/* SSL mode (require,prefer,allow,disable) */
 	char	   *sslcompression; /* SSL compression (0 or 1) */
 	char	   *sslkey;			/* client key filename */
diff --git a/src/test/ssl/ServerSetup.pm b/src/test/ssl/ServerSetup.pm
index ad2e036602..b71969ac75 100644
--- a/src/test/ssl/ServerSetup.pm
+++ b/src/test/ssl/ServerSetup.pm
@@ -91,6 +91,9 @@ sub configure_test_server_for_ssl
 {
 	my $node       = $_[0];
 	my $serverhost = $_[1];
+	my $sslmethod  = $_[2];
+	my $passwd     = $_[3];
+	my $passwdhash = $_[4];
 
 	my $pgdata = $node->data_dir;
 
@@ -100,6 +103,15 @@ sub configure_test_server_for_ssl
 	$node->psql('postgres', "CREATE DATABASE trustdb");
 	$node->psql('postgres', "CREATE DATABASE certdb");
 
+	# Update password of each user as needed.
+	if (defined($passwd))
+	{
+		$node->psql('postgres',
+"SET password_encryption='$passwdhash'; ALTER USER ssltestuser PASSWORD '$passwd';");
+		$node->psql('postgres',
+"SET password_encryption='$passwdhash'; ALTER USER anotheruser PASSWORD '$passwd';");
+	}
+
 	# enable logging etc.
 	open my $conf, '>>', "$pgdata/postgresql.conf";
 	print $conf "fsync=off\n";
@@ -129,7 +141,7 @@ sub configure_test_server_for_ssl
 	$node->restart;
 
 	# Change pg_hba after restart because hostssl requires ssl=on
-	configure_hba_for_ssl($node, $serverhost);
+	configure_hba_for_ssl($node, $serverhost, $sslmethod);
 }
 
 # Change the configuration to use given server cert file, and reload
@@ -159,6 +171,7 @@ sub configure_hba_for_ssl
 {
 	my $node       = $_[0];
 	my $serverhost = $_[1];
+	my $sslmethod  = $_[2];
 	my $pgdata     = $node->data_dir;
 
   # Only accept SSL connections from localhost. Our tests don't depend on this
@@ -169,9 +182,9 @@ sub configure_hba_for_ssl
 	print $hba
 "# TYPE  DATABASE        USER            ADDRESS                 METHOD\n";
 	print $hba
-"hostssl trustdb         ssltestuser     $serverhost/32            trust\n";
+"hostssl trustdb         ssltestuser     $serverhost/32          $sslmethod\n";
 	print $hba
-"hostssl trustdb         ssltestuser     ::1/128                 trust\n";
+"hostssl trustdb         ssltestuser     ::1/128                 $sslmethod\n";
 	print $hba
 "hostssl certdb          ssltestuser     $serverhost/32            cert\n";
 	print $hba
diff --git a/src/test/ssl/t/001_ssltests.pl b/src/test/ssl/t/001_ssltests.pl
index 890e3051a2..e690c1fa15 100644
--- a/src/test/ssl/t/001_ssltests.pl
+++ b/src/test/ssl/t/001_ssltests.pl
@@ -32,7 +32,7 @@ $node->init;
 $ENV{PGHOST} = $node->host;
 $ENV{PGPORT} = $node->port;
 $node->start;
-configure_test_server_for_ssl($node, $SERVERHOSTADDR);
+configure_test_server_for_ssl($node, $SERVERHOSTADDR, "trust");
 switch_server_cert($node, 'server-cn-only');
 
 ### Part 1. Run client-side tests.
diff --git a/src/test/ssl/t/002_sasl.pl b/src/test/ssl/t/002_sasl.pl
new file mode 100644
index 0000000000..a625f0d473
--- /dev/null
+++ b/src/test/ssl/t/002_sasl.pl
@@ -0,0 +1,52 @@
+use strict;
+use warnings;
+use PostgresNode;
+use TestLib;
+use Test::More tests => 6;
+use ServerSetup;
+use File::Copy;
+
+# test combinations of SASL authentication for SCRAM mechanism:
+# - SCRAM-SHA-256 and SCRAM-SHA-256-PLUS
+# - Channel bindings
+
+# This is the hostname used to connect to the server.
+my $SERVERHOSTADDR = '127.0.0.1';
+
+# Allocation of base connection string shared among multiple tests.
+my $common_connstr;
+
+#### Part 0. Set up the server.
+
+note "setting up data directory";
+my $node = get_new_node('master');
+$node->init;
+
+# PGHOST is enforced here to set up the node, subsequent connections
+# will use a dedicated connection string.
+$ENV{PGHOST} = $node->host;
+$ENV{PGPORT} = $node->port;
+$node->start;
+
+# Configure server for SSL connections, with password handling.
+configure_test_server_for_ssl($node, $SERVERHOSTADDR, "scram-sha-256",
+							  "pass", "scram-sha-256");
+switch_server_cert($node, 'server-cn-only');
+$ENV{PGPASSWORD} = "pass";
+$common_connstr =
+"user=ssltestuser dbname=trustdb sslmode=require hostaddr=$SERVERHOSTADDR";
+
+# Tests with default channel binding and SASL mechanism names.
+# tls-unique is used here
+test_connect_ok($common_connstr, "saslname=SCRAM-SHA-256-PLUS");
+test_connect_fails($common_connstr, "saslname=not-exists");
+# Downgrade attack.
+test_connect_fails($common_connstr, "saslname=SCRAM-SHA-256");
+test_connect_fails($common_connstr,
+		"saslname=SCRAM-SHA-256 saslchannelbinding=tls-unique");
+
+# Channel bindings
+test_connect_ok($common_connstr,
+		"saslname=SCRAM-SHA-256-PLUS saslchannelbinding=tls-unique");
+test_connect_fails($common_connstr,
+		"saslname=SCRAM-SHA-256-PLUS saslchannelbinding=not-exists");
-- 
2.14.1

