On Wed, Dec 27, 2017 at 09:27:40AM +0900, Michael Paquier wrote:
> On Tue, Dec 26, 2017 at 03:28:09PM -0500, Peter Eisentraut wrote:
>> On 12/22/17 03:10, Michael Paquier wrote:
>>> Second thoughts on 0002 as there is actually no need to move around
>>> errorMessage if the PGconn* pointer is saved in the SCRAM status data
>>> as both are linked. The attached simplifies the logic even more.
>>> 
>> 
>> That all looks pretty reasonable.
> 
> Thanks for the review. Don't you think that the refactoring
> simplifications should be done first, though? That would mean
> producing the patch set in reverse order. I'm happy to produce them
> if need be.

Well, here is a patch set doing the reverse: the refactoring comes
first in 0001, and support for tls-server-end-point is in 0002. Hope this
helps.
--
Michael
From 551f9a037d6e38036998337f703758b41d2e1c72 Mon Sep 17 00:00:00 2001
From: Michael Paquier <mich...@paquier.xyz>
Date: Fri, 22 Dec 2017 17:04:19 +0900
Subject: [PATCH 1/2] Refactor channel binding code to fetch cbind_data only
 when necessary

As things stand now, channel binding data is fetched from OpenSSL and
saved into the SASL exchange context for every SSL connection attempting
SCRAM authentication. As a result, the data is fetched but never used if
no channel binding is used, or if the channel binding type in use is
different from the one the data was fetched for.

Refactor the code so that binding data is fetched from the SSL stack
only when a specific channel binding type is in use, for both the
frontend and the backend. To achieve that, save the libpq connection
context directly in the SCRAM exchange state, and add a dependency on
SSL in the low-level SCRAM routines.

This makes the interface in charge of initializing the SCRAM context
cleaner, as all its data now comes from either PGconn* (for the frontend)
or Port* (for the backend).
---
 src/backend/libpq/auth-scram.c       |  34 +++-----
 src/backend/libpq/auth.c             |  19 +----
 src/include/libpq/scram.h            |   6 +-
 src/interfaces/libpq/fe-auth-scram.c | 159 +++++++++++++++++------------------
 src/interfaces/libpq/fe-auth.c       |  27 +-----
 src/interfaces/libpq/fe-auth.h       |  10 +--
 6 files changed, 104 insertions(+), 151 deletions(-)

diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c
index d52a763457..72973d3789 100644
--- a/src/backend/libpq/auth-scram.c
+++ b/src/backend/libpq/auth-scram.c
@@ -110,10 +110,8 @@ typedef struct
 
        const char *username;           /* username from startup packet */
 
+       Port       *port;
        char            cbind_flag;
-       bool            ssl_in_use;
-       const char *tls_finished_message;
-       size_t          tls_finished_len;
        char       *channel_binding_type;
 
        int                     iterations;
@@ -172,21 +170,15 @@ static char *scram_mock_salt(const char *username);
  * it will fail, as if an incorrect password was given.
  */
 void *
-pg_be_scram_init(const char *username,
-                                const char *shadow_pass,
-                                bool ssl_in_use,
-                                const char *tls_finished_message,
-                                size_t tls_finished_len)
+pg_be_scram_init(Port *port,
+                                const char *shadow_pass)
 {
        scram_state *state;
        bool            got_verifier;
 
        state = (scram_state *) palloc0(sizeof(scram_state));
+       state->port = port;
        state->state = SCRAM_AUTH_INIT;
-       state->username = username;
-       state->ssl_in_use = ssl_in_use;
-       state->tls_finished_message = tls_finished_message;
-       state->tls_finished_len = tls_finished_len;
        state->channel_binding_type = NULL;
 
        /*
@@ -209,7 +201,7 @@ pg_be_scram_init(const char *username,
                                 */
                                ereport(LOG,
                                                (errmsg("invalid SCRAM verifier 
for user \"%s\"",
-                                                               username)));
+                                                               
state->port->user_name)));
                                got_verifier = false;
                        }
                }
@@ -220,7 +212,7 @@ pg_be_scram_init(const char *username,
                         * authentication with an MD5 hash.)
                         */
                        state->logdetail = psprintf(_("User \"%s\" does not 
have a valid SCRAM verifier."),
-                                                                               
state->username);
+                                                                               
state->port->user_name);
                        got_verifier = false;
                }
        }
@@ -242,8 +234,8 @@ pg_be_scram_init(const char *username,
         */
        if (!got_verifier)
        {
-               mock_scram_verifier(username, &state->iterations, &state->salt,
-                                                       state->StoredKey, 
state->ServerKey);
+               mock_scram_verifier(state->port->user_name, &state->iterations,
+                                                       &state->salt, 
state->StoredKey, state->ServerKey);
                state->doomed = true;
        }
 
@@ -815,7 +807,7 @@ read_client_first_message(scram_state *state, char *input)
                         * it supports channel binding, which in this 
implementation is
                         * the case if a connection is using SSL.
                         */
-                       if (state->ssl_in_use)
+                       if (state->port->ssl_in_use)
                                ereport(ERROR,
                                                
(errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION),
                                                 errmsg("SCRAM channel binding 
negotiation error"),
@@ -839,7 +831,7 @@ read_client_first_message(scram_state *state, char *input)
                        {
                                char       *channel_binding_type;
 
-                               if (!state->ssl_in_use)
+                               if (!state->port->ssl_in_use)
                                {
                                        /*
                                         * Without SSL, we don't support 
channel binding.
@@ -1120,8 +1112,10 @@ read_client_final_message(scram_state *state, char 
*input)
                 */
                if (strcmp(state->channel_binding_type, 
SCRAM_CHANNEL_BINDING_TLS_UNIQUE) == 0)
                {
-                       cbind_data = state->tls_finished_message;
-                       cbind_data_len = state->tls_finished_len;
+                       /* Fetch data from TLS finished message */
+#ifdef USE_SSL
+                       cbind_data = be_tls_get_peer_finished(state->port, 
&cbind_data_len);
+#endif
                }
                else
                {
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index b7f9bb1669..bd91e1cd18 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -873,8 +873,6 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char 
**logdetail)
        int                     inputlen;
        int                     result;
        bool            initial;
-       char       *tls_finished = NULL;
-       size_t          tls_finished_len = 0;
 
        /*
         * SASL auth is not supported for protocol versions before 3, because it
@@ -915,17 +913,6 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char 
**logdetail)
        sendAuthRequest(port, AUTH_REQ_SASL, sasl_mechs, p - sasl_mechs + 1);
        pfree(sasl_mechs);
 
-#ifdef USE_SSL
-
-       /*
-        * Get data for channel binding.
-        */
-       if (port->ssl_in_use)
-       {
-               tls_finished = be_tls_get_peer_finished(port, 
&tls_finished_len);
-       }
-#endif
-
        /*
         * Initialize the status tracker for message exchanges.
         *
@@ -937,11 +924,7 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char 
**logdetail)
         * This is because we don't want to reveal to an attacker what usernames
         * are valid, nor which users have a valid password.
         */
-       scram_opaq = pg_be_scram_init(port->user_name,
-                                                                 shadow_pass,
-                                                                 
port->ssl_in_use,
-                                                                 tls_finished,
-                                                                 
tls_finished_len);
+       scram_opaq = pg_be_scram_init(port, shadow_pass);
 
        /*
         * Loop through SASL message exchange.  This exchange can consist of
diff --git a/src/include/libpq/scram.h b/src/include/libpq/scram.h
index 2c245813d6..f404f57253 100644
--- a/src/include/libpq/scram.h
+++ b/src/include/libpq/scram.h
@@ -13,15 +13,15 @@
 #ifndef PG_SCRAM_H
 #define PG_SCRAM_H
 
+#include "libpq/libpq-be.h"
+
 /* Status codes for message exchange */
 #define SASL_EXCHANGE_CONTINUE         0
 #define SASL_EXCHANGE_SUCCESS          1
 #define SASL_EXCHANGE_FAILURE          2
 
 /* Routines dedicated to authentication */
-extern void *pg_be_scram_init(const char *username, const char *shadow_pass,
-                                bool ssl_in_use, const char 
*tls_finished_message,
-                                size_t tls_finished_len);
+extern void *pg_be_scram_init(Port *port, const char *shadow_pass);
 extern int pg_be_scram_exchange(void *opaq, char *input, int inputlen,
                                         char **output, int *outputlen, char 
**logdetail);
 
diff --git a/src/interfaces/libpq/fe-auth-scram.c 
b/src/interfaces/libpq/fe-auth-scram.c
index b8f7a6b5be..e8fc33c72f 100644
--- a/src/interfaces/libpq/fe-auth-scram.c
+++ b/src/interfaces/libpq/fe-auth-scram.c
@@ -42,13 +42,9 @@ typedef struct
        fe_scram_state_enum state;
 
        /* These are supplied by the user */
-       const char *username;
+       PGconn     *conn;
        char       *password;
-       bool            ssl_in_use;
-       char       *tls_finished_message;
-       size_t          tls_finished_len;
        char       *sasl_mechanism;
-       const char *channel_binding_type;
 
        /* We construct these */
        uint8           SaltedPassword[SCRAM_KEY_LEN];
@@ -68,14 +64,10 @@ typedef struct
        char            ServerSignature[SCRAM_KEY_LEN];
 } fe_scram_state;
 
-static bool read_server_first_message(fe_scram_state *state, char *input,
-                                                 PQExpBuffer errormessage);
-static bool read_server_final_message(fe_scram_state *state, char *input,
-                                                 PQExpBuffer errormessage);
-static char *build_client_first_message(fe_scram_state *state,
-                                                  PQExpBuffer errormessage);
-static char *build_client_final_message(fe_scram_state *state,
-                                                  PQExpBuffer errormessage);
+static bool read_server_first_message(fe_scram_state *state, char *input);
+static bool read_server_final_message(fe_scram_state *state, char *input);
+static char *build_client_first_message(fe_scram_state *state);
+static char *build_client_final_message(fe_scram_state *state);
 static bool verify_server_signature(fe_scram_state *state);
 static void calculate_client_proof(fe_scram_state *state,
                                           const char 
*client_final_message_without_proof,
@@ -89,13 +81,9 @@ static bool pg_frontend_random(char *dst, int len);
  * freed by pg_fe_scram_free().
  */
 void *
-pg_fe_scram_init(const char *username,
+pg_fe_scram_init(PGconn *conn,
                                 const char *password,
-                                bool ssl_in_use,
-                                const char *sasl_mechanism,
-                                const char *channel_binding_type,
-                                char *tls_finished_message,
-                                size_t tls_finished_len)
+                                const char *sasl_mechanism)
 {
        fe_scram_state *state;
        char       *prep_password;
@@ -107,13 +95,9 @@ pg_fe_scram_init(const char *username,
        if (!state)
                return NULL;
        memset(state, 0, sizeof(fe_scram_state));
+       state->conn = conn;
        state->state = FE_SCRAM_INIT;
-       state->username = username;
-       state->ssl_in_use = ssl_in_use;
-       state->tls_finished_message = tls_finished_message;
-       state->tls_finished_len = tls_finished_len;
        state->sasl_mechanism = strdup(sasl_mechanism);
-       state->channel_binding_type = channel_binding_type;
 
        if (!state->sasl_mechanism)
        {
@@ -154,10 +138,6 @@ pg_fe_scram_free(void *opaq)
 
        if (state->password)
                free(state->password);
-       if (state->tls_finished_message)
-               free(state->tls_finished_message);
-       if (state->sasl_mechanism)
-               free(state->sasl_mechanism);
 
        /* client messages */
        if (state->client_nonce)
@@ -188,9 +168,10 @@ pg_fe_scram_free(void *opaq)
 void
 pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
                                         char **output, int *outputlen,
-                                        bool *done, bool *success, PQExpBuffer 
errorMessage)
+                                        bool *done, bool *success)
 {
        fe_scram_state *state = (fe_scram_state *) opaq;
+       PGconn     *conn = state->conn;
 
        *done = false;
        *success = false;
@@ -205,13 +186,13 @@ pg_fe_scram_exchange(void *opaq, char *input, int 
inputlen,
        {
                if (inputlen == 0)
                {
-                       printfPQExpBuffer(errorMessage,
+                       printfPQExpBuffer(&conn->errorMessage,
                                                          
libpq_gettext("malformed SCRAM message (empty message)\n"));
                        goto error;
                }
                if (inputlen != strlen(input))
                {
-                       printfPQExpBuffer(errorMessage,
+                       printfPQExpBuffer(&conn->errorMessage,
                                                          
libpq_gettext("malformed SCRAM message (length mismatch)\n"));
                        goto error;
                }
@@ -221,7 +202,7 @@ pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
        {
                case FE_SCRAM_INIT:
                        /* Begin the SCRAM handshake, by sending client nonce */
-                       *output = build_client_first_message(state, 
errorMessage);
+                       *output = build_client_first_message(state);
                        if (*output == NULL)
                                goto error;
 
@@ -232,10 +213,10 @@ pg_fe_scram_exchange(void *opaq, char *input, int 
inputlen,
 
                case FE_SCRAM_NONCE_SENT:
                        /* Receive salt and server nonce, send response. */
-                       if (!read_server_first_message(state, input, 
errorMessage))
+                       if (!read_server_first_message(state, input))
                                goto error;
 
-                       *output = build_client_final_message(state, 
errorMessage);
+                       *output = build_client_final_message(state);
                        if (*output == NULL)
                                goto error;
 
@@ -246,7 +227,7 @@ pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
 
                case FE_SCRAM_PROOF_SENT:
                        /* Receive server signature */
-                       if (!read_server_final_message(state, input, 
errorMessage))
+                       if (!read_server_final_message(state, input))
                                goto error;
 
                        /*
@@ -260,7 +241,7 @@ pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
                        else
                        {
                                *success = false;
-                               printfPQExpBuffer(errorMessage,
+                               printfPQExpBuffer(&conn->errorMessage,
                                                                  
libpq_gettext("incorrect server signature\n"));
                        }
                        *done = true;
@@ -269,7 +250,7 @@ pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
 
                default:
                        /* shouldn't happen */
-                       printfPQExpBuffer(errorMessage,
+                       printfPQExpBuffer(&conn->errorMessage,
                                                          
libpq_gettext("invalid SCRAM exchange state\n"));
                        goto error;
        }
@@ -327,8 +308,9 @@ read_attr_value(char **input, char attr, PQExpBuffer 
errorMessage)
  * Build the first exchange message sent by the client.
  */
 static char *
-build_client_first_message(fe_scram_state *state, PQExpBuffer errormessage)
+build_client_first_message(fe_scram_state *state)
 {
+       PGconn     *conn = state->conn;
        char            raw_nonce[SCRAM_RAW_NONCE_LEN + 1];
        char       *result;
        int                     channel_info_len;
@@ -341,7 +323,7 @@ build_client_first_message(fe_scram_state *state, 
PQExpBuffer errormessage)
         */
        if (!pg_frontend_random(raw_nonce, SCRAM_RAW_NONCE_LEN))
        {
-               printfPQExpBuffer(errormessage,
+               printfPQExpBuffer(&conn->errorMessage,
                                                  libpq_gettext("could not 
generate nonce\n"));
                return NULL;
        }
@@ -349,7 +331,7 @@ build_client_first_message(fe_scram_state *state, 
PQExpBuffer errormessage)
        state->client_nonce = malloc(pg_b64_enc_len(SCRAM_RAW_NONCE_LEN) + 1);
        if (state->client_nonce == NULL)
        {
-               printfPQExpBuffer(errormessage,
+               printfPQExpBuffer(&conn->errorMessage,
                                                  libpq_gettext("out of 
memory\n"));
                return NULL;
        }
@@ -370,11 +352,11 @@ build_client_first_message(fe_scram_state *state, 
PQExpBuffer errormessage)
         */
        if (strcmp(state->sasl_mechanism, SCRAM_SHA256_PLUS_NAME) == 0)
        {
-               Assert(state->ssl_in_use);
-               appendPQExpBuffer(&buf, "p=%s", state->channel_binding_type);
+               Assert(conn->ssl_in_use);
+               appendPQExpBuffer(&buf, "p=%s", conn->scram_channel_binding);
        }
-       else if (state->channel_binding_type == NULL ||
-                        strlen(state->channel_binding_type) == 0)
+       else if (conn->scram_channel_binding == NULL ||
+                        strlen(conn->scram_channel_binding) == 0)
        {
                /*
                 * Client has chosen to not show to server that it supports 
channel
@@ -382,7 +364,7 @@ build_client_first_message(fe_scram_state *state, 
PQExpBuffer errormessage)
                 */
                appendPQExpBuffer(&buf, "n");
        }
-       else if (state->ssl_in_use)
+       else if (conn->ssl_in_use)
        {
                /*
                 * Client supports channel binding, but thinks the server does 
not.
@@ -423,7 +405,7 @@ build_client_first_message(fe_scram_state *state, 
PQExpBuffer errormessage)
 
 oom_error:
        termPQExpBuffer(&buf);
-       printfPQExpBuffer(errormessage,
+       printfPQExpBuffer(&conn->errorMessage,
                                          libpq_gettext("out of memory\n"));
        return NULL;
 }
@@ -432,9 +414,10 @@ oom_error:
  * Build the final exchange message sent from the client.
  */
 static char *
-build_client_final_message(fe_scram_state *state, PQExpBuffer errormessage)
+build_client_final_message(fe_scram_state *state)
 {
        PQExpBufferData buf;
+       PGconn     *conn = state->conn;
        uint8           client_proof[SCRAM_KEY_LEN];
        char       *result;
 
@@ -450,22 +433,26 @@ build_client_final_message(fe_scram_state *state, 
PQExpBuffer errormessage)
         */
        if (strcmp(state->sasl_mechanism, SCRAM_SHA256_PLUS_NAME) == 0)
        {
-               char       *cbind_data;
-               size_t          cbind_data_len;
+               char       *cbind_data = NULL;
+               size_t          cbind_data_len = 0;
                size_t          cbind_header_len;
                char       *cbind_input;
                size_t          cbind_input_len;
 
-               if (strcmp(state->channel_binding_type, 
SCRAM_CHANNEL_BINDING_TLS_UNIQUE) == 0)
+               if (strcmp(conn->scram_channel_binding, 
SCRAM_CHANNEL_BINDING_TLS_UNIQUE) == 0)
                {
-                       cbind_data = state->tls_finished_message;
-                       cbind_data_len = state->tls_finished_len;
+                       /* Fetch data from TLS finished message */
+#ifdef USE_SSL
+                       cbind_data = pgtls_get_finished(state->conn, 
&cbind_data_len);
+                       if (cbind_data == NULL)
+                               goto oom_error;
+#endif
                }
                else
                {
                        /* should not happen */
                        termPQExpBuffer(&buf);
-                       printfPQExpBuffer(errormessage,
+                       printfPQExpBuffer(&conn->errorMessage,
                                                          
libpq_gettext("invalid channel binding type\n"));
                        return NULL;
                }
@@ -473,37 +460,46 @@ build_client_final_message(fe_scram_state *state, 
PQExpBuffer errormessage)
                /* should not happen */
                if (cbind_data == NULL || cbind_data_len == 0)
                {
+                       if (cbind_data != NULL)
+                               free(cbind_data);
                        termPQExpBuffer(&buf);
-                       printfPQExpBuffer(errormessage,
+                       printfPQExpBuffer(&conn->errorMessage,
                                                          libpq_gettext("empty 
channel binding data for channel binding type \"%s\"\n"),
-                                                         
state->channel_binding_type);
+                                                         
conn->scram_channel_binding);
                        return NULL;
                }
 
                appendPQExpBuffer(&buf, "c=");
 
-               cbind_header_len = 4 + strlen(state->channel_binding_type); /* 
p=type,, */
+               /* p=type,, */
+               cbind_header_len = 4 + strlen(conn->scram_channel_binding);
                cbind_input_len = cbind_header_len + cbind_data_len;
                cbind_input = malloc(cbind_input_len);
                if (!cbind_input)
+               {
+                       free(cbind_data);
                        goto oom_error;
-               snprintf(cbind_input, cbind_input_len, "p=%s,,", 
state->channel_binding_type);
+               }
+               snprintf(cbind_input, cbind_input_len, "p=%s,,",
+                                conn->scram_channel_binding);
                memcpy(cbind_input + cbind_header_len, cbind_data, 
cbind_data_len);
 
                if (!enlargePQExpBuffer(&buf, pg_b64_enc_len(cbind_input_len)))
                {
+                       free(cbind_data);
                        free(cbind_input);
                        goto oom_error;
                }
                buf.len += pg_b64_encode(cbind_input, cbind_input_len, buf.data 
+ buf.len);
                buf.data[buf.len] = '\0';
 
+               free(cbind_data);
                free(cbind_input);
        }
-       else if (state->channel_binding_type == NULL ||
-                        strlen(state->channel_binding_type) == 0)
+       else if (conn->scram_channel_binding == NULL ||
+                        strlen(conn->scram_channel_binding) == 0)
                appendPQExpBuffer(&buf, "c=biws");      /* base64 of "n,," */
-       else if (state->ssl_in_use)
+       else if (conn->ssl_in_use)
                appendPQExpBuffer(&buf, "c=eSws");      /* base64 of "y,," */
        else
                appendPQExpBuffer(&buf, "c=biws");      /* base64 of "n,," */
@@ -541,7 +537,7 @@ build_client_final_message(fe_scram_state *state, 
PQExpBuffer errormessage)
 
 oom_error:
        termPQExpBuffer(&buf);
-       printfPQExpBuffer(errormessage,
+       printfPQExpBuffer(&conn->errorMessage,
                                          libpq_gettext("out of memory\n"));
        return NULL;
 }
@@ -550,9 +546,9 @@ oom_error:
  * Read the first exchange message coming from the server.
  */
 static bool
-read_server_first_message(fe_scram_state *state, char *input,
-                                                 PQExpBuffer errormessage)
+read_server_first_message(fe_scram_state *state, char *input)
 {
+       PGconn     *conn = state->conn;
        char       *iterations_str;
        char       *endptr;
        char       *encoded_salt;
@@ -561,13 +557,14 @@ read_server_first_message(fe_scram_state *state, char 
*input,
        state->server_first_message = strdup(input);
        if (state->server_first_message == NULL)
        {
-               printfPQExpBuffer(errormessage,
+               printfPQExpBuffer(&conn->errorMessage,
                                                  libpq_gettext("out of 
memory\n"));
                return false;
        }
 
        /* parse the message */
-       nonce = read_attr_value(&input, 'r', errormessage);
+       nonce = read_attr_value(&input, 'r',
+                                                       &conn->errorMessage);
        if (nonce == NULL)
        {
                /* read_attr_value() has generated an error string */
@@ -578,7 +575,7 @@ read_server_first_message(fe_scram_state *state, char 
*input,
        if (strlen(nonce) < strlen(state->client_nonce) ||
                memcmp(nonce, state->client_nonce, strlen(state->client_nonce)) 
!= 0)
        {
-               printfPQExpBuffer(errormessage,
+               printfPQExpBuffer(&conn->errorMessage,
                                                  libpq_gettext("invalid SCRAM 
response (nonce mismatch)\n"));
                return false;
        }
@@ -586,12 +583,12 @@ read_server_first_message(fe_scram_state *state, char 
*input,
        state->nonce = strdup(nonce);
        if (state->nonce == NULL)
        {
-               printfPQExpBuffer(errormessage,
+               printfPQExpBuffer(&conn->errorMessage,
                                                  libpq_gettext("out of 
memory\n"));
                return false;
        }
 
-       encoded_salt = read_attr_value(&input, 's', errormessage);
+       encoded_salt = read_attr_value(&input, 's', &conn->errorMessage);
        if (encoded_salt == NULL)
        {
                /* read_attr_value() has generated an error string */
@@ -600,7 +597,7 @@ read_server_first_message(fe_scram_state *state, char 
*input,
        state->salt = malloc(pg_b64_dec_len(strlen(encoded_salt)));
        if (state->salt == NULL)
        {
-               printfPQExpBuffer(errormessage,
+               printfPQExpBuffer(&conn->errorMessage,
                                                  libpq_gettext("out of 
memory\n"));
                return false;
        }
@@ -608,7 +605,7 @@ read_server_first_message(fe_scram_state *state, char 
*input,
                                                                   
strlen(encoded_salt),
                                                                   state->salt);
 
-       iterations_str = read_attr_value(&input, 'i', errormessage);
+       iterations_str = read_attr_value(&input, 'i', &conn->errorMessage);
        if (iterations_str == NULL)
        {
                /* read_attr_value() has generated an error string */
@@ -617,13 +614,13 @@ read_server_first_message(fe_scram_state *state, char 
*input,
        state->iterations = strtol(iterations_str, &endptr, 10);
        if (*endptr != '\0' || state->iterations < 1)
        {
-               printfPQExpBuffer(errormessage,
+               printfPQExpBuffer(&conn->errorMessage,
                                                  libpq_gettext("malformed 
SCRAM message (invalid iteration count)\n"));
                return false;
        }
 
        if (*input != '\0')
-               printfPQExpBuffer(errormessage,
+               printfPQExpBuffer(&conn->errorMessage,
                                                  libpq_gettext("malformed 
SCRAM message (garbage at end of server-first-message)\n"));
 
        return true;
@@ -633,16 +630,16 @@ read_server_first_message(fe_scram_state *state, char 
*input,
  * Read the final exchange message coming from the server.
  */
 static bool
-read_server_final_message(fe_scram_state *state, char *input,
-                                                 PQExpBuffer errormessage)
+read_server_final_message(fe_scram_state *state, char *input)
 {
+       PGconn     *conn = state->conn;
        char       *encoded_server_signature;
        int                     server_signature_len;
 
        state->server_final_message = strdup(input);
        if (!state->server_final_message)
        {
-               printfPQExpBuffer(errormessage,
+               printfPQExpBuffer(&conn->errorMessage,
                                                  libpq_gettext("out of 
memory\n"));
                return false;
        }
@@ -650,16 +647,18 @@ read_server_final_message(fe_scram_state *state, char 
*input,
        /* Check for error result. */
        if (*input == 'e')
        {
-               char       *errmsg = read_attr_value(&input, 'e', errormessage);
+               char       *errmsg = read_attr_value(&input, 'e',
+                                                                               
         &conn->errorMessage);
 
-               printfPQExpBuffer(errormessage,
+               printfPQExpBuffer(&conn->errorMessage,
                                                  libpq_gettext("error received 
from server in SCRAM exchange: %s\n"),
                                                  errmsg);
                return false;
        }
 
        /* Parse the message. */
-       encoded_server_signature = read_attr_value(&input, 'v', errormessage);
+       encoded_server_signature = read_attr_value(&input, 'v',
+                                                                               
           &conn->errorMessage);
        if (encoded_server_signature == NULL)
        {
                /* read_attr_value() has generated an error message */
@@ -667,7 +666,7 @@ read_server_final_message(fe_scram_state *state, char 
*input,
        }
 
        if (*input != '\0')
-               printfPQExpBuffer(errormessage,
+               printfPQExpBuffer(&conn->errorMessage,
                                                  libpq_gettext("malformed 
SCRAM message (garbage at end of server-final-message)\n"));
 
        server_signature_len = pg_b64_decode(encoded_server_signature,
@@ -675,7 +674,7 @@ read_server_final_message(fe_scram_state *state, char 
*input,
                                                                                
 state->ServerSignature);
        if (server_signature_len != SCRAM_KEY_LEN)
        {
-               printfPQExpBuffer(errormessage,
+               printfPQExpBuffer(&conn->errorMessage,
                                                  libpq_gettext("malformed 
SCRAM message (invalid server signature)\n"));
                return false;
        }
diff --git a/src/interfaces/libpq/fe-auth.c b/src/interfaces/libpq/fe-auth.c
index 3340a9ad93..9c3524e553 100644
--- a/src/interfaces/libpq/fe-auth.c
+++ b/src/interfaces/libpq/fe-auth.c
@@ -491,8 +491,6 @@ pg_SASL_init(PGconn *conn, int payloadlen)
        bool            success;
        const char *selected_mechanism;
        PQExpBufferData mechanism_buf;
-       char       *tls_finished = NULL;
-       size_t          tls_finished_len = 0;
        char       *password;
 
        initPQExpBuffer(&mechanism_buf);
@@ -570,32 +568,15 @@ pg_SASL_init(PGconn *conn, int payloadlen)
                goto error;
        }
 
-#ifdef USE_SSL
-
-       /*
-        * Get data for channel binding.
-        */
-       if (strcmp(selected_mechanism, SCRAM_SHA256_PLUS_NAME) == 0)
-       {
-               tls_finished = pgtls_get_finished(conn, &tls_finished_len);
-               if (tls_finished == NULL)
-                       goto oom_error;
-       }
-#endif
-
        /*
         * Initialize the SASL state information with all the information 
gathered
         * during the initial exchange.
         *
         * Note: Only tls-unique is supported for the moment.
         */
-       conn->sasl_state = pg_fe_scram_init(conn->pguser,
+       conn->sasl_state = pg_fe_scram_init(conn,
                                                                                
password,
-                                                                               
conn->ssl_in_use,
-                                                                               
selected_mechanism,
-                                                                               
conn->scram_channel_binding,
-                                                                               
tls_finished,
-                                                                               
tls_finished_len);
+                                                                               
selected_mechanism);
        if (!conn->sasl_state)
                goto oom_error;
 
@@ -603,7 +584,7 @@ pg_SASL_init(PGconn *conn, int payloadlen)
        pg_fe_scram_exchange(conn->sasl_state,
                                                 NULL, -1,
                                                 &initialresponse, 
&initialresponselen,
-                                                &done, &success, 
&conn->errorMessage);
+                                                &done, &success);
 
        if (done && !success)
                goto error;
@@ -684,7 +665,7 @@ pg_SASL_continue(PGconn *conn, int payloadlen, bool final)
        pg_fe_scram_exchange(conn->sasl_state,
                                                 challenge, payloadlen,
                                                 &output, &outputlen,
-                                                &done, &success, 
&conn->errorMessage);
+                                                &done, &success);
        free(challenge);                        /* don't need the input anymore 
*/
 
        if (final && !done)
diff --git a/src/interfaces/libpq/fe-auth.h b/src/interfaces/libpq/fe-auth.h
index db319ac071..1265d0d2f7 100644
--- a/src/interfaces/libpq/fe-auth.h
+++ b/src/interfaces/libpq/fe-auth.h
@@ -23,17 +23,13 @@ extern int  pg_fe_sendauth(AuthRequest areq, int 
payloadlen, PGconn *conn);
 extern char *pg_fe_getauthname(PQExpBuffer errorMessage);
 
 /* Prototypes for functions in fe-auth-scram.c */
-extern void *pg_fe_scram_init(const char *username,
+extern void *pg_fe_scram_init(PGconn *conn,
                                 const char *password,
-                                bool ssl_in_use,
-                                const char *sasl_mechanism,
-                                const char *channel_binding_type,
-                                char *tls_finished_message,
-                                size_t tls_finished_len);
+                                const char *sasl_mechanism);
 extern void pg_fe_scram_free(void *opaq);
 extern void pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
                                         char **output, int *outputlen,
-                                        bool *done, bool *success, PQExpBuffer 
errorMessage);
+                                        bool *done, bool *success);
 extern char *pg_fe_scram_build_verifier(const char *password);
 
 #endif                                                 /* FE_AUTH_H */
-- 
2.15.1

From db3c278177ded3a37431585bc85f60f646b976a8 Mon Sep 17 00:00:00 2001
From: Michael Paquier <mich...@paquier.xyz>
Date: Thu, 28 Dec 2017 16:12:54 +0900
Subject: [PATCH 2/2] Implement channel binding tls-server-end-point for SCRAM

As referenced in RFC 5929, this channel binding type is not the default
and uses a hash of the server's certificate as binding data. On the
frontend, this boils down to getting the certificate from
SSL_get_peer_certificate(), and on the backend from SSL_get_certificate().

The hashing algorithm also needs to switch to SHA-256 if the certificate's
signature algorithm uses MD5 or SHA-1, so let's be careful about that.
---
 doc/src/sgml/protocol.sgml               |  5 +-
 src/backend/libpq/auth-scram.c           | 21 +++++++--
 src/backend/libpq/be-secure-openssl.c    | 61 +++++++++++++++++++++++++
 src/include/common/scram-common.h        |  1 +
 src/include/libpq/libpq-be.h             |  1 +
 src/interfaces/libpq/fe-auth-scram.c     | 15 ++++++
 src/interfaces/libpq/fe-secure-openssl.c | 78 ++++++++++++++++++++++++++++++++
 src/interfaces/libpq/libpq-int.h         |  1 +
 src/test/ssl/t/002_scram.pl              |  5 +-
 9 files changed, 180 insertions(+), 8 deletions(-)

diff --git a/doc/src/sgml/protocol.sgml b/doc/src/sgml/protocol.sgml
index 8174e3defa..365f72b51d 100644
--- a/doc/src/sgml/protocol.sgml
+++ b/doc/src/sgml/protocol.sgml
@@ -1576,8 +1576,9 @@ the password is in.
   <para>
 <firstterm>Channel binding</firstterm> is supported in PostgreSQL builds with
 SSL support. The SASL mechanism name for SCRAM with channel binding
-is <literal>SCRAM-SHA-256-PLUS</literal>.  The only channel binding type
-supported at the moment is <literal>tls-unique</literal>, defined in RFC 5929.
+is <literal>SCRAM-SHA-256-PLUS</literal>.  Two channel binding types are
+supported at the moment: <literal>tls-unique</literal>, which is the default,
+and <literal>tls-server-end-point</literal>, both defined in RFC 5929.
   </para>
 
 <procedure>
diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c
index 72973d3789..0a50f815ab 100644
--- a/src/backend/libpq/auth-scram.c
+++ b/src/backend/libpq/auth-scram.c
@@ -849,13 +849,15 @@ read_client_first_message(scram_state *state, char *input)
                                }
 
                                /*
-                                * Read value provided by client; only 
tls-unique is supported
-                                * for now.  (It is not safe to print the name 
of an
-                                * unsupported binding type in the error 
message.  Pranksters
-                                * could print arbitrary strings into the log 
that way.)
+                                * Read value provided by client; only 
tls-unique and
+                                * tls-server-end-point are supported for now.  
(It is
+                                * not safe to print the name of an unsupported 
binding
+                                * type in the error message.  Pranksters could 
print
+                                * arbitrary strings into the log that way.)
                                 */
                                channel_binding_type = read_attr_value(&input, 
'p');
-                               if (strcmp(channel_binding_type, 
SCRAM_CHANNEL_BINDING_TLS_UNIQUE) != 0)
+                               if (strcmp(channel_binding_type, 
SCRAM_CHANNEL_BINDING_TLS_UNIQUE) != 0 &&
+                                       strcmp(channel_binding_type, 
SCRAM_CHANNEL_BINDING_TLS_ENDPOINT) != 0)
                                        ereport(ERROR,
                                                        
(errcode(ERRCODE_PROTOCOL_VIOLATION),
                                                         (errmsg("unsupported 
SCRAM channel-binding type"))));
@@ -1115,6 +1117,15 @@ read_client_final_message(scram_state *state, char 
*input)
                        /* Fetch data from TLS finished message */
 #ifdef USE_SSL
                        cbind_data = be_tls_get_peer_finished(state->port, 
&cbind_data_len);
+#endif
+               }
+               else if (strcmp(state->channel_binding_type,
+                                               
SCRAM_CHANNEL_BINDING_TLS_ENDPOINT) == 0)
+               {
+                       /* Fetch hash data of server's SSL certificate */
+#ifdef USE_SSL
+                       cbind_data = be_tls_get_certificate_hash(state->port,
+                                                                               
                         &cbind_data_len);
 #endif
                }
                else
diff --git a/src/backend/libpq/be-secure-openssl.c 
b/src/backend/libpq/be-secure-openssl.c
index 1e3e19f5e0..e3e8a535c8 100644
--- a/src/backend/libpq/be-secure-openssl.c
+++ b/src/backend/libpq/be-secure-openssl.c
@@ -1239,6 +1239,67 @@ be_tls_get_peer_finished(Port *port, size_t *len)
        return result;
 }
 
+/*
+ * be_tls_get_certificate_hash
+ *
+ * Compute the hash of the server's own TLS certificate, used as the
+ * channel binding data for tls-server-end-point (RFC 5929, section 4.1).
+ *
+ * The certificate is hashed with the digest that its signature algorithm
+ * relies on, except that MD5 and SHA-1 are replaced by SHA-256 as the RFC
+ * requires.  On success the result is a palloc'd copy of the hash and *len
+ * receives its size.  If no server certificate is available, NULL is
+ * returned and *len is 0.  OpenSSL failures are reported with elog(ERROR).
+ */
+char *
+be_tls_get_certificate_hash(Port *port, size_t *len)
+{
+	X509	   *cert;
+	const EVP_MD *digest;
+	unsigned char digest_buf[EVP_MAX_MD_SIZE];	/* large enough for SHA-512 */
+	unsigned int digest_len;
+	int			sig_digest_nid;
+	char	   *result;
+
+	*len = 0;
+
+	cert = SSL_get_certificate(port->ssl);
+	if (cert == NULL)
+		return NULL;
+
+	/*
+	 * Look up which digest the certificate's signature algorithm is built
+	 * on; that decides how the certificate itself gets hashed.
+	 */
+	if (!OBJ_find_sigid_algs(X509_get_signature_nid(cert),
+							 &sig_digest_nid, NULL))
+		elog(ERROR, "could not find signature algorithm");
+
+	if (sig_digest_nid == NID_md5 || sig_digest_nid == NID_sha1)
+	{
+		/* RFC 5929 section 4.1: MD5 and SHA-1 are upgraded to SHA-256 */
+		digest = EVP_sha256();
+	}
+	else
+	{
+		digest = EVP_get_digestbynid(sig_digest_nid);
+		if (digest == NULL)
+			elog(ERROR, "could not find digest for NID %s",
+				 OBJ_nid2sn(sig_digest_nid));
+	}
+
+	if (!X509_digest(cert, digest, digest_buf, &digest_len))
+		elog(ERROR, "could not generate server certificate hash");
+
+	/* hand back a palloc'd copy sized to the digest actually produced */
+	result = (char *) palloc(digest_len);
+	memcpy(result, digest_buf, digest_len);
+	*len = digest_len;
+
+	return result;
+}
+
 /*
  * Convert an X509 subject name to a cstring.
  *
diff --git a/src/include/common/scram-common.h 
b/src/include/common/scram-common.h
index 857a60e71f..5aec5cadb8 100644
--- a/src/include/common/scram-common.h
+++ b/src/include/common/scram-common.h
@@ -21,6 +21,7 @@
 
 /* Channel binding types */
 #define SCRAM_CHANNEL_BINDING_TLS_UNIQUE    "tls-unique"
+#define SCRAM_CHANNEL_BINDING_TLS_ENDPOINT     "tls-server-end-point"
 
 /* Length of SCRAM keys (client and server) */
 #define SCRAM_KEY_LEN                          PG_SHA256_DIGEST_LENGTH
diff --git a/src/include/libpq/libpq-be.h b/src/include/libpq/libpq-be.h
index 856e0439d5..cf9d8b7870 100644
--- a/src/include/libpq/libpq-be.h
+++ b/src/include/libpq/libpq-be.h
@@ -210,6 +210,7 @@ extern void be_tls_get_version(Port *port, char *ptr, 
size_t len);
 extern void be_tls_get_cipher(Port *port, char *ptr, size_t len);
 extern void be_tls_get_peerdn_name(Port *port, char *ptr, size_t len);
 extern char *be_tls_get_peer_finished(Port *port, size_t *len);
+extern char *be_tls_get_certificate_hash(Port *port, size_t *len);
 #endif
 
 extern ProtocolVersion FrontendProtocol;
diff --git a/src/interfaces/libpq/fe-auth-scram.c 
b/src/interfaces/libpq/fe-auth-scram.c
index e8fc33c72f..65817411b1 100644
--- a/src/interfaces/libpq/fe-auth-scram.c
+++ b/src/interfaces/libpq/fe-auth-scram.c
@@ -446,6 +446,21 @@ build_client_final_message(fe_scram_state *state)
                        cbind_data = pgtls_get_finished(state->conn, 
&cbind_data_len);
                        if (cbind_data == NULL)
                                goto oom_error;
+#endif
+               }
+               else if (strcmp(conn->scram_channel_binding,
+                                               
SCRAM_CHANNEL_BINDING_TLS_ENDPOINT) == 0)
+               {
+                       /* Fetch hash data of server's SSL certificate */
+#ifdef USE_SSL
+                       cbind_data =
+                               pgtls_get_peer_certificate_hash(state->conn,
+                                                                               
                &cbind_data_len);
+                       if (cbind_data == NULL)
+                       {
+                               /* error message is already set on error */
+                               return NULL;
+                       }
 #endif
                }
                else
diff --git a/src/interfaces/libpq/fe-secure-openssl.c 
b/src/interfaces/libpq/fe-secure-openssl.c
index 61d161b367..99077c3d9a 100644
--- a/src/interfaces/libpq/fe-secure-openssl.c
+++ b/src/interfaces/libpq/fe-secure-openssl.c
@@ -419,6 +419,84 @@ pgtls_get_finished(PGconn *conn, size_t *len)
        return result;
 }
 
+/*
+ *	Get the hash of the server certificate
+ *
+ * This information is needed for tls-server-end-point channel binding
+ * (RFC 5929, section 4.1), where a hash of the certificate the server
+ * presented on this connection links the SCRAM exchange to the TLS session.
+ * The certificate is hashed with the digest used by its signature
+ * algorithm, except that MD5 and SHA-1 are replaced by SHA-256 as RFC 5929
+ * requires.
+ *
+ * On success, returns a malloc'd copy of the hash and stores its length in
+ * *len; the caller owns the returned buffer.  On failure -- including when
+ * no server certificate is available -- returns NULL with *len set to 0 and
+ * an error message stored in conn->errorMessage for the caller to report.
+ *
+ * NOTE(review): X509_get_signature_nid() is only provided by OpenSSL 1.0.2
+ * and later; building against older OpenSSL releases would need a
+ * configure-time check around this function.
+ */
+char *
+pgtls_get_peer_certificate_hash(PGconn *conn, size_t *len)
+{
+	X509	   *peer_cert;
+	const EVP_MD *algo_type;
+	unsigned char hash[EVP_MAX_MD_SIZE];	/* size for SHA-512 */
+	unsigned int hash_size;
+	int			algo_nid;
+	char	   *cert_hash;
+
+	*len = 0;
+
+	peer_cert = conn->peer;
+	if (peer_cert == NULL)
+	{
+		/*
+		 * Without a server certificate there is nothing to bind to.  Callers
+		 * treat NULL as an error and expect a message to be set, so do not
+		 * fail silently here.
+		 */
+		printfPQExpBuffer(&conn->errorMessage,
+						  libpq_gettext("server certificate is not available for channel binding\n"));
+		return NULL;
+	}
+
+	/*
+	 * Get the signature algorithm of the certificate to determine the hash
+	 * algorithm to use for the result.
+	 */
+	if (!OBJ_find_sigid_algs(X509_get_signature_nid(peer_cert),
+							 &algo_nid, NULL))
+	{
+		printfPQExpBuffer(&conn->errorMessage,
+						  libpq_gettext("could not find signature algorithm\n"));
+		return NULL;
+	}
+
+	switch (algo_nid)
+	{
+		case NID_md5:
+		case NID_sha1:
+			/* RFC 5929 section 4.1: upgrade MD5 and SHA-1 to SHA-256 */
+			algo_type = EVP_sha256();
+			break;
+
+		default:
+			algo_type = EVP_get_digestbynid(algo_nid);
+			if (algo_type == NULL)
+			{
+				printfPQExpBuffer(&conn->errorMessage,
+								  libpq_gettext("could not find digest for NID %s\n"),
+								  OBJ_nid2sn(algo_nid));
+				return NULL;
+			}
+			break;
+	}
+
+	if (!X509_digest(peer_cert, algo_type, hash, &hash_size))
+	{
+		printfPQExpBuffer(&conn->errorMessage,
+						  libpq_gettext("could not generate peer certificate hash\n"));
+		return NULL;
+	}
+
+	/* save result in a buffer owned by the caller */
+	cert_hash = (char *) malloc(hash_size);
+	if (cert_hash == NULL)
+	{
+		printfPQExpBuffer(&conn->errorMessage,
+						  libpq_gettext("out of memory\n"));
+		return NULL;
+	}
+	memcpy(cert_hash, hash, hash_size);
+	*len = hash_size;
+
+	return cert_hash;
+}
 
 /* ------------------------------------------------------------ */
 /*                                             OpenSSL specific code           
                        */
diff --git a/src/interfaces/libpq/libpq-int.h b/src/interfaces/libpq/libpq-int.h
index f6c1023f37..756c4d61e1 100644
--- a/src/interfaces/libpq/libpq-int.h
+++ b/src/interfaces/libpq/libpq-int.h
@@ -672,6 +672,7 @@ extern ssize_t pgtls_read(PGconn *conn, void *ptr, size_t 
len);
 extern bool pgtls_read_pending(PGconn *conn);
 extern ssize_t pgtls_write(PGconn *conn, const void *ptr, size_t len);
 extern char *pgtls_get_finished(PGconn *conn, size_t *len);
+extern char *pgtls_get_peer_certificate_hash(PGconn *conn, size_t *len);
 
 /*
  * this is so that we can check if a connection is non-blocking internally
diff --git a/src/test/ssl/t/002_scram.pl b/src/test/ssl/t/002_scram.pl
index 324b4888d4..3f425e00f0 100644
--- a/src/test/ssl/t/002_scram.pl
+++ b/src/test/ssl/t/002_scram.pl
@@ -4,7 +4,7 @@ use strict;
 use warnings;
 use PostgresNode;
 use TestLib;
-use Test::More tests => 4;
+use Test::More tests => 5;
 use ServerSetup;
 use File::Copy;
 
@@ -45,6 +45,9 @@ test_connect_ok($common_connstr,
 test_connect_ok($common_connstr,
        "scram_channel_binding=''",
        "SCRAM authentication without channel binding");
+test_connect_ok($common_connstr,
+       "scram_channel_binding=tls-server-end-point",
+       "SCRAM authentication with tls-server-end-point as channel binding");
 test_connect_fails($common_connstr,
        "scram_channel_binding=not-exists",
        "SCRAM authentication with invalid channel binding");
-- 
2.15.1

Attachment: signature.asc
Description: PGP signature

Reply via email to