On Wed, Jan 12, 2022 at 12:56:17PM +0900, Michael Paquier wrote:
> Attached is a rebased patch for the HMAC portions, with a couple of
> fixes I noticed while going through this stuff again (mostly around
> SASLprep and pg_fe_scram_build_secret), and a fix for a conflict
> coming from 9cb5518.  psql's \password (through PQencryptPasswordConn())
> is wrong to assume that the only error that can happen for scram-sha-256
> is an OOM, but we'll get there.

With an attachment, that's even better.  (Thanks, Daniel.)
--
Michael
From a6bcfefa9a8dd98bdc6f0e105f7b55dc8739c49e Mon Sep 17 00:00:00 2001
From: Michael Paquier <mich...@paquier.xyz>
Date: Wed, 12 Jan 2022 12:46:27 +0900
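Subject: [PATCH v2] Improve HMAC error handling

Add pg_hmac_error() to report a string describing why an HMAC computation
failed.  Error states are tracked in pg_hmac_ctx for both the fallback and
the OpenSSL implementations.  The SCRAM routines (scram_SaltedPassword,
scram_H, scram_ClientKey, scram_ServerKey and scram_build_secret) gain an
errstr argument, so that callers in the backend and in libpq can include
the failure reason in their error messages.  Previously they assumed that
the only possible failure was an out-of-memory condition.
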
---
 src/include/common/hmac.h            |  1 +
 src/include/common/scram-common.h    | 14 +++--
 src/backend/libpq/auth-scram.c       | 22 ++++---
 src/common/hmac.c                    | 64 ++++++++++++++++++++
 src/common/hmac_openssl.c            | 90 ++++++++++++++++++++++++++++
 src/common/scram-common.c            | 47 +++++++++++----
 src/interfaces/libpq/fe-auth-scram.c | 67 +++++++++++++++------
 src/interfaces/libpq/fe-auth.c       | 10 ++--
 src/interfaces/libpq/fe-auth.h       |  3 +-
 9 files changed, 269 insertions(+), 49 deletions(-)

diff --git a/src/include/common/hmac.h b/src/include/common/hmac.h
index cf7aa17be4..c18783fe11 100644
--- a/src/include/common/hmac.h
+++ b/src/include/common/hmac.h
@@ -25,5 +25,6 @@ extern int	pg_hmac_init(pg_hmac_ctx *ctx, const uint8 *key, size_t len);
 extern int	pg_hmac_update(pg_hmac_ctx *ctx, const uint8 *data, size_t len);
 extern int	pg_hmac_final(pg_hmac_ctx *ctx, uint8 *dest, size_t len);
 extern void pg_hmac_free(pg_hmac_ctx *ctx);
+extern const char *pg_hmac_error(pg_hmac_ctx *ctx);
 
 #endif							/* PG_HMAC_H */
diff --git a/src/include/common/scram-common.h b/src/include/common/scram-common.h
index d53b4fa7f5..d1f840c11c 100644
--- a/src/include/common/scram-common.h
+++ b/src/include/common/scram-common.h
@@ -47,12 +47,16 @@
 #define SCRAM_DEFAULT_ITERATIONS	4096
 
 extern int	scram_SaltedPassword(const char *password, const char *salt,
-								 int saltlen, int iterations, uint8 *result);
-extern int	scram_H(const uint8 *str, int len, uint8 *result);
-extern int	scram_ClientKey(const uint8 *salted_password, uint8 *result);
-extern int	scram_ServerKey(const uint8 *salted_password, uint8 *result);
+								 int saltlen, int iterations, uint8 *result,
+								 const char **errstr);
+extern int	scram_H(const uint8 *str, int len, uint8 *result,
+					const char **errstr);
+extern int	scram_ClientKey(const uint8 *salted_password, uint8 *result,
+							const char **errstr);
+extern int	scram_ServerKey(const uint8 *salted_password, uint8 *result,
+							const char **errstr);
 
 extern char *scram_build_secret(const char *salt, int saltlen, int iterations,
-								const char *password);
+								const char *password, const char **errstr);
 
 #endif							/* SCRAM_COMMON_H */
diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c
index 7c9dee70ce..ee7f52218a 100644
--- a/src/backend/libpq/auth-scram.c
+++ b/src/backend/libpq/auth-scram.c
@@ -465,6 +465,7 @@ pg_be_scram_build_secret(const char *password)
 	pg_saslprep_rc rc;
 	char		saltbuf[SCRAM_DEFAULT_SALT_LEN];
 	char	   *result;
+	const char *errstr = NULL;
 
 	/*
 	 * Normalize the password with SASLprep.  If that doesn't work, because
@@ -482,7 +483,8 @@ pg_be_scram_build_secret(const char *password)
 				 errmsg("could not generate random salt")));
 
 	result = scram_build_secret(saltbuf, SCRAM_DEFAULT_SALT_LEN,
-								SCRAM_DEFAULT_ITERATIONS, password);
+								SCRAM_DEFAULT_ITERATIONS, password,
+								&errstr);
 
 	if (prep_password)
 		pfree(prep_password);
@@ -509,6 +511,7 @@ scram_verify_plain_password(const char *username, const char *password,
 	uint8		computed_key[SCRAM_KEY_LEN];
 	char	   *prep_password;
 	pg_saslprep_rc rc;
+	const char *errstr = NULL;
 
 	if (!parse_scram_secret(secret, &iterations, &encoded_salt,
 							stored_key, server_key))
@@ -539,10 +542,10 @@ scram_verify_plain_password(const char *username, const char *password,
 
 	/* Compute Server Key based on the user-supplied plaintext password */
 	if (scram_SaltedPassword(password, salt, saltlen, iterations,
-							 salted_password) < 0 ||
-		scram_ServerKey(salted_password, computed_key) < 0)
+							 salted_password, &errstr) < 0 ||
+		scram_ServerKey(salted_password, computed_key, &errstr) < 0)
 	{
-		elog(ERROR, "could not compute server key");
+		elog(ERROR, "could not compute server key: %s", errstr);
 	}
 
 	if (prep_password)
@@ -1113,6 +1116,7 @@ verify_client_proof(scram_state *state)
 	uint8		client_StoredKey[SCRAM_KEY_LEN];
 	pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
 	int			i;
+	const char *errstr = NULL;
 
 	/*
 	 * Calculate ClientSignature.  Note that we don't log directly a failure
@@ -1133,7 +1137,8 @@ verify_client_proof(scram_state *state)
 					   strlen(state->client_final_message_without_proof)) < 0 ||
 		pg_hmac_final(ctx, ClientSignature, sizeof(ClientSignature)) < 0)
 	{
-		elog(ERROR, "could not calculate client signature");
+		elog(ERROR, "could not calculate client signature: %s",
+			 pg_hmac_error(ctx));
 	}
 
 	pg_hmac_free(ctx);
@@ -1143,8 +1148,8 @@ verify_client_proof(scram_state *state)
 		ClientKey[i] = state->ClientProof[i] ^ ClientSignature[i];
 
 	/* Hash it one more time, and compare with StoredKey */
-	if (scram_H(ClientKey, SCRAM_KEY_LEN, client_StoredKey) < 0)
-		elog(ERROR, "could not hash stored key");
+	if (scram_H(ClientKey, SCRAM_KEY_LEN, client_StoredKey, &errstr) < 0)
+		elog(ERROR, "could not hash stored key: %s", errstr);
 
 	if (memcmp(client_StoredKey, state->StoredKey, SCRAM_KEY_LEN) != 0)
 		return false;
@@ -1389,7 +1394,8 @@ build_server_final_message(scram_state *state)
 					   strlen(state->client_final_message_without_proof)) < 0 ||
 		pg_hmac_final(ctx, ServerSignature, sizeof(ServerSignature)) < 0)
 	{
-		elog(ERROR, "could not calculate server signature");
+		elog(ERROR, "could not calculate server signature: %s",
+			 pg_hmac_error(ctx));
 	}
 
 	pg_hmac_free(ctx);
diff --git a/src/common/hmac.c b/src/common/hmac.c
index 6e46dc28a1..592f9b20a3 100644
--- a/src/common/hmac.c
+++ b/src/common/hmac.c
@@ -38,6 +38,14 @@
 #define FREE(ptr) free(ptr)
 #endif
 
+/* Set of error states */
+typedef enum pg_hmac_errno
+{
+	PG_HMAC_ERROR_NONE = 0,
+	PG_HMAC_ERROR_OOM,
+	PG_HMAC_ERROR_INTERNAL
+} pg_hmac_errno;
+
 /*
  * Internal structure for pg_hmac_ctx->data with this implementation.
  */
@@ -45,6 +53,8 @@ struct pg_hmac_ctx
 {
 	pg_cryptohash_ctx *hash;
 	pg_cryptohash_type type;
+	pg_hmac_errno error;
+	const char *errreason;
 	int			block_size;
 	int			digest_size;
 
@@ -75,6 +85,8 @@ pg_hmac_create(pg_cryptohash_type type)
 		return NULL;
 	memset(ctx, 0, sizeof(pg_hmac_ctx));
 	ctx->type = type;
+	ctx->error = PG_HMAC_ERROR_NONE;
+	ctx->errreason = NULL;
 
 	/*
 	 * Initialize the context data.  This requires to know the digest and
@@ -152,12 +164,16 @@ pg_hmac_init(pg_hmac_ctx *ctx, const uint8 *key, size_t len)
 		/* temporary buffer for one-time shrink */
 		shrinkbuf = ALLOC(digest_size);
 		if (shrinkbuf == NULL)
+		{
+			ctx->error = PG_HMAC_ERROR_OOM;
 			return -1;
+		}
 		memset(shrinkbuf, 0, digest_size);
 
 		hash_ctx = pg_cryptohash_create(ctx->type);
 		if (hash_ctx == NULL)
 		{
+			ctx->error = PG_HMAC_ERROR_OOM;
 			FREE(shrinkbuf);
 			return -1;
 		}
@@ -166,6 +182,8 @@ pg_hmac_init(pg_hmac_ctx *ctx, const uint8 *key, size_t len)
 			pg_cryptohash_update(hash_ctx, key, len) < 0 ||
 			pg_cryptohash_final(hash_ctx, shrinkbuf, digest_size) < 0)
 		{
+			ctx->error = PG_HMAC_ERROR_INTERNAL;
+			ctx->errreason = pg_cryptohash_error(hash_ctx);
 			pg_cryptohash_free(hash_ctx);
 			FREE(shrinkbuf);
 			return -1;
@@ -186,6 +204,8 @@ pg_hmac_init(pg_hmac_ctx *ctx, const uint8 *key, size_t len)
 	if (pg_cryptohash_init(ctx->hash) < 0 ||
 		pg_cryptohash_update(ctx->hash, ctx->k_ipad, ctx->block_size) < 0)
 	{
+		ctx->error = PG_HMAC_ERROR_INTERNAL;
+		ctx->errreason = pg_cryptohash_error(ctx->hash);
 		if (shrinkbuf)
 			FREE(shrinkbuf);
 		return -1;
@@ -208,7 +228,11 @@ pg_hmac_update(pg_hmac_ctx *ctx, const uint8 *data, size_t len)
 		return -1;
 
 	if (pg_cryptohash_update(ctx->hash, data, len) < 0)
+	{
+		ctx->error = PG_HMAC_ERROR_INTERNAL;
+		ctx->errreason = pg_cryptohash_error(ctx->hash);
 		return -1;
+	}
 
 	return 0;
 }
@@ -228,11 +252,16 @@ pg_hmac_final(pg_hmac_ctx *ctx, uint8 *dest, size_t len)
 
 	h = ALLOC(ctx->digest_size);
 	if (h == NULL)
+	{
+		ctx->error = PG_HMAC_ERROR_OOM;
 		return -1;
+	}
 	memset(h, 0, ctx->digest_size);
 
 	if (pg_cryptohash_final(ctx->hash, h, ctx->digest_size) < 0)
 	{
+		ctx->error = PG_HMAC_ERROR_INTERNAL;
+		ctx->errreason = pg_cryptohash_error(ctx->hash);
 		FREE(h);
 		return -1;
 	}
@@ -243,6 +272,8 @@ pg_hmac_final(pg_hmac_ctx *ctx, uint8 *dest, size_t len)
 		pg_cryptohash_update(ctx->hash, h, ctx->digest_size) < 0 ||
 		pg_cryptohash_final(ctx->hash, dest, len) < 0)
 	{
+		ctx->error = PG_HMAC_ERROR_INTERNAL;
+		ctx->errreason = pg_cryptohash_error(ctx->hash);
 		FREE(h);
 		return -1;
 	}
@@ -266,3 +297,36 @@ pg_hmac_free(pg_hmac_ctx *ctx)
 	explicit_bzero(ctx, sizeof(pg_hmac_ctx));
 	FREE(ctx);
 }
+
+/*
+ * pg_hmac_error
+ *
+ * Returns a static string providing details about an error that happened
+ * during a HMAC computation.
+ */
+const char *
+pg_hmac_error(pg_hmac_ctx *ctx)
+{
+	if (ctx == NULL)
+		return _("out of memory");
+
+	/*
+	 * If a reason is provided, rely on it, else fallback to any error code
+	 * set.
+	 */
+	if (ctx->errreason)
+		return ctx->errreason;
+
+	switch (ctx->error)
+	{
+		case PG_HMAC_ERROR_NONE:
+			return _("success");
+		case PG_HMAC_ERROR_INTERNAL:
+			return _("internal error");
+		case PG_HMAC_ERROR_OOM:
+			return _("out of memory");
+	}
+
+	Assert(false);				/* cannot be reached */
+	return _("success");
+}
diff --git a/src/common/hmac_openssl.c b/src/common/hmac_openssl.c
index d2cb5474bb..c352f9db9e 100644
--- a/src/common/hmac_openssl.c
+++ b/src/common/hmac_openssl.c
@@ -20,6 +20,8 @@
 #include "postgres_fe.h"
 #endif
 
+
+#include <openssl/err.h>
 #include <openssl/hmac.h>
 
 #include "common/hmac.h"
@@ -50,6 +52,14 @@
 #define FREE(ptr) free(ptr)
 #endif							/* FRONTEND */
 
+/* Set of error states */
+typedef enum pg_hmac_errno
+{
+	PG_HMAC_ERROR_NONE = 0,
+	PG_HMAC_ERROR_DEST_LEN,
+	PG_HMAC_ERROR_OPENSSL
+} pg_hmac_errno;
+
 /*
  * Internal structure for pg_hmac_ctx->data with this implementation.
  */
@@ -57,12 +67,27 @@ struct pg_hmac_ctx
 {
 	HMAC_CTX   *hmacctx;
 	pg_cryptohash_type type;
+	pg_hmac_errno error;
+	const char *errreason;
 
 #ifndef FRONTEND
 	ResourceOwner resowner;
 #endif
 };
 
+static const char *
+SSLerrmessage(unsigned long ecode)
+{
+	if (ecode == 0)
+		return NULL;
+
+	/*
+	 * This may return NULL, but we would fall back to a default error path if
+	 * that were the case.
+	 */
+	return ERR_reason_error_string(ecode);
+}
+
 /*
  * pg_hmac_create
  *
@@ -80,6 +105,8 @@ pg_hmac_create(pg_cryptohash_type type)
 	memset(ctx, 0, sizeof(pg_hmac_ctx));
 
 	ctx->type = type;
+	ctx->error = PG_HMAC_ERROR_NONE;
+	ctx->errreason = NULL;
 
 	/*
 	 * Initialization takes care of assigning the correct type for OpenSSL.
@@ -154,7 +181,11 @@ pg_hmac_init(pg_hmac_ctx *ctx, const uint8 *key, size_t len)
 
 	/* OpenSSL internals return 1 on success, 0 on failure */
 	if (status <= 0)
+	{
+		ctx->errreason = SSLerrmessage(ERR_get_error());
+		ctx->error = PG_HMAC_ERROR_OPENSSL;
 		return -1;
+	}
 
 	return 0;
 }
@@ -176,7 +207,11 @@ pg_hmac_update(pg_hmac_ctx *ctx, const uint8 *data, size_t len)
 
 	/* OpenSSL internals return 1 on success, 0 on failure */
 	if (status <= 0)
+	{
+		ctx->errreason = SSLerrmessage(ERR_get_error());
+		ctx->error = PG_HMAC_ERROR_OPENSSL;
 		return -1;
+	}
 	return 0;
 }
 
@@ -198,27 +233,45 @@ pg_hmac_final(pg_hmac_ctx *ctx, uint8 *dest, size_t len)
 	{
 		case PG_MD5:
 			if (len < MD5_DIGEST_LENGTH)
+			{
+				ctx->error = PG_HMAC_ERROR_DEST_LEN;
 				return -1;
+			}
 			break;
 		case PG_SHA1:
 			if (len < SHA1_DIGEST_LENGTH)
+			{
+				ctx->error = PG_HMAC_ERROR_DEST_LEN;
 				return -1;
+			}
 			break;
 		case PG_SHA224:
 			if (len < PG_SHA224_DIGEST_LENGTH)
+			{
+				ctx->error = PG_HMAC_ERROR_DEST_LEN;
 				return -1;
+			}
 			break;
 		case PG_SHA256:
 			if (len < PG_SHA256_DIGEST_LENGTH)
+			{
+				ctx->error = PG_HMAC_ERROR_DEST_LEN;
 				return -1;
+			}
 			break;
 		case PG_SHA384:
 			if (len < PG_SHA384_DIGEST_LENGTH)
+			{
+				ctx->error = PG_HMAC_ERROR_DEST_LEN;
 				return -1;
+			}
 			break;
 		case PG_SHA512:
 			if (len < PG_SHA512_DIGEST_LENGTH)
+			{
+				ctx->error = PG_HMAC_ERROR_DEST_LEN;
 				return -1;
+			}
 			break;
 	}
 
@@ -226,7 +279,11 @@ pg_hmac_final(pg_hmac_ctx *ctx, uint8 *dest, size_t len)
 
 	/* OpenSSL internals return 1 on success, 0 on failure */
 	if (status <= 0)
+	{
+		ctx->errreason = SSLerrmessage(ERR_get_error());
+		ctx->error = PG_HMAC_ERROR_OPENSSL;
 		return -1;
+	}
 	return 0;
 }
 
@@ -254,3 +311,36 @@ pg_hmac_free(pg_hmac_ctx *ctx)
 	explicit_bzero(ctx, sizeof(pg_hmac_ctx));
 	FREE(ctx);
 }
+
+/*
+ * pg_hmac_error
+ *
+ * Returns a static string providing details about an error that happened
+ * during a HMAC computation.
+ */
+const char *
+pg_hmac_error(pg_hmac_ctx *ctx)
+{
+	if (ctx == NULL)
+		return _("out of memory");
+
+	/*
+	 * If a reason is provided, rely on it, else fallback to any error code
+	 * set.
+	 */
+	if (ctx->errreason)
+		return ctx->errreason;
+
+	switch (ctx->error)
+	{
+		case PG_HMAC_ERROR_NONE:
+			return _("success");
+		case PG_HMAC_ERROR_DEST_LEN:
+			return _("destination buffer too small");
+		case PG_HMAC_ERROR_OPENSSL:
+			return _("OpenSSL failure");
+	}
+
+	Assert(false);				/* cannot be reached */
+	return _("success");
+}
diff --git a/src/common/scram-common.c b/src/common/scram-common.c
index 23b68b14da..5f90397c66 100644
--- a/src/common/scram-common.c
+++ b/src/common/scram-common.c
@@ -33,7 +33,7 @@
 int
 scram_SaltedPassword(const char *password,
 					 const char *salt, int saltlen, int iterations,
-					 uint8 *result)
+					 uint8 *result, const char **errstr)
 {
 	int			password_len = strlen(password);
 	uint32		one = pg_hton32(1);
@@ -58,6 +58,7 @@ scram_SaltedPassword(const char *password,
 		pg_hmac_update(hmac_ctx, (uint8 *) &one, sizeof(uint32)) < 0 ||
 		pg_hmac_final(hmac_ctx, Ui_prev, sizeof(Ui_prev)) < 0)
 	{
+		*errstr = pg_hmac_error(hmac_ctx);
 		pg_hmac_free(hmac_ctx);
 		return -1;
 	}
@@ -71,6 +72,7 @@ scram_SaltedPassword(const char *password,
 			pg_hmac_update(hmac_ctx, (uint8 *) Ui_prev, SCRAM_KEY_LEN) < 0 ||
 			pg_hmac_final(hmac_ctx, Ui, sizeof(Ui)) < 0)
 		{
+			*errstr = pg_hmac_error(hmac_ctx);
 			pg_hmac_free(hmac_ctx);
 			return -1;
 		}
@@ -90,18 +92,22 @@ scram_SaltedPassword(const char *password,
  * not included in the hash).  Returns 0 on success, -1 on failure.
  */
 int
-scram_H(const uint8 *input, int len, uint8 *result)
+scram_H(const uint8 *input, int len, uint8 *result, const char **errstr)
 {
 	pg_cryptohash_ctx *ctx;
 
 	ctx = pg_cryptohash_create(PG_SHA256);
 	if (ctx == NULL)
+	{
+		*errstr = pg_cryptohash_error(NULL);	/* returns OOM */
 		return -1;
+	}
 
 	if (pg_cryptohash_init(ctx) < 0 ||
 		pg_cryptohash_update(ctx, input, len) < 0 ||
 		pg_cryptohash_final(ctx, result, SCRAM_KEY_LEN) < 0)
 	{
+		*errstr = pg_cryptohash_error(ctx);
 		pg_cryptohash_free(ctx);
 		return -1;
 	}
@@ -114,7 +120,8 @@ scram_H(const uint8 *input, int len, uint8 *result)
  * Calculate ClientKey.  Returns 0 on success, -1 on failure.
  */
 int
-scram_ClientKey(const uint8 *salted_password, uint8 *result)
+scram_ClientKey(const uint8 *salted_password, uint8 *result,
+				const char **errstr)
 {
 	pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
 
@@ -125,6 +132,7 @@ scram_ClientKey(const uint8 *salted_password, uint8 *result)
 		pg_hmac_update(ctx, (uint8 *) "Client Key", strlen("Client Key")) < 0 ||
 		pg_hmac_final(ctx, result, SCRAM_KEY_LEN) < 0)
 	{
+		*errstr = pg_hmac_error(ctx);
 		pg_hmac_free(ctx);
 		return -1;
 	}
@@ -137,17 +145,22 @@ scram_ClientKey(const uint8 *salted_password, uint8 *result)
  * Calculate ServerKey.  Returns 0 on success, -1 on failure.
  */
 int
-scram_ServerKey(const uint8 *salted_password, uint8 *result)
+scram_ServerKey(const uint8 *salted_password, uint8 *result,
+				const char **errstr)
 {
 	pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
 
 	if (ctx == NULL)
+	{
+		*errstr = pg_hmac_error(NULL);	/* returns OOM */
 		return -1;
+	}
 
 	if (pg_hmac_init(ctx, salted_password, SCRAM_KEY_LEN) < 0 ||
 		pg_hmac_update(ctx, (uint8 *) "Server Key", strlen("Server Key")) < 0 ||
 		pg_hmac_final(ctx, result, SCRAM_KEY_LEN) < 0)
 	{
+		*errstr = pg_hmac_error(ctx);
 		pg_hmac_free(ctx);
 		return -1;
 	}
@@ -167,7 +180,7 @@ scram_ServerKey(const uint8 *salted_password, uint8 *result)
  */
 char *
 scram_build_secret(const char *salt, int saltlen, int iterations,
-				   const char *password)
+				   const char *password, const char **errstr)
 {
 	uint8		salted_password[SCRAM_KEY_LEN];
 	uint8		stored_key[SCRAM_KEY_LEN];
@@ -185,15 +198,17 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
 
 	/* Calculate StoredKey and ServerKey */
 	if (scram_SaltedPassword(password, salt, saltlen, iterations,
-							 salted_password) < 0 ||
-		scram_ClientKey(salted_password, stored_key) < 0 ||
-		scram_H(stored_key, SCRAM_KEY_LEN, stored_key) < 0 ||
-		scram_ServerKey(salted_password, server_key) < 0)
+							 salted_password, errstr) < 0 ||
+		scram_ClientKey(salted_password, stored_key, errstr) < 0 ||
+		scram_H(stored_key, SCRAM_KEY_LEN, stored_key, errstr) < 0 ||
+		scram_ServerKey(salted_password, server_key, errstr) < 0)
 	{
+		/* errstr is filled already here */
 #ifdef FRONTEND
 		return NULL;
 #else
-		elog(ERROR, "could not calculate stored key and server key");
+		elog(ERROR, "could not calculate stored key and server key: %s",
+			 *errstr);
 #endif
 	}
 
@@ -215,7 +230,10 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
 #ifdef FRONTEND
 	result = malloc(maxlen);
 	if (!result)
+	{
+		*errstr = _("out of memory");
 		return NULL;
+	}
 #else
 	result = palloc(maxlen);
 #endif
@@ -226,11 +244,12 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
 	encoded_result = pg_b64_encode(salt, saltlen, p, encoded_salt_len);
 	if (encoded_result < 0)
 	{
+		*errstr = _("could not encode salt");
 #ifdef FRONTEND
 		free(result);
 		return NULL;
 #else
-		elog(ERROR, "could not encode salt");
+		elog(ERROR, "%s", *errstr);
 #endif
 	}
 	p += encoded_result;
@@ -241,11 +260,12 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
 								   encoded_stored_len);
 	if (encoded_result < 0)
 	{
+		*errstr = _("could not encode stored key");
 #ifdef FRONTEND
 		free(result);
 		return NULL;
 #else
-		elog(ERROR, "could not encode stored key");
+		elog(ERROR, "%s", *errstr);
 #endif
 	}
 
@@ -257,11 +277,12 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
 								   encoded_server_len);
 	if (encoded_result < 0)
 	{
+		*errstr = _("could not encode server key");
 #ifdef FRONTEND
 		free(result);
 		return NULL;
 #else
-		elog(ERROR, "could not encode server key");
+		elog(ERROR, "%s", *errstr);
 #endif
 	}
 
diff --git a/src/interfaces/libpq/fe-auth-scram.c b/src/interfaces/libpq/fe-auth-scram.c
index cc41440c4e..b173d7d502 100644
--- a/src/interfaces/libpq/fe-auth-scram.c
+++ b/src/interfaces/libpq/fe-auth-scram.c
@@ -80,10 +80,11 @@ static bool read_server_first_message(fe_scram_state *state, char *input);
 static bool read_server_final_message(fe_scram_state *state, char *input);
 static char *build_client_first_message(fe_scram_state *state);
 static char *build_client_final_message(fe_scram_state *state);
-static bool verify_server_signature(fe_scram_state *state, bool *match);
+static bool verify_server_signature(fe_scram_state *state, bool *match,
+									const char **errstr);
 static bool calculate_client_proof(fe_scram_state *state,
 								   const char *client_final_message_without_proof,
-								   uint8 *result);
+								   uint8 *result, const char **errstr);
 
 /*
  * Initialize SCRAM exchange status.
@@ -211,6 +212,7 @@ scram_exchange(void *opaq, char *input, int inputlen,
 {
 	fe_scram_state *state = (fe_scram_state *) opaq;
 	PGconn	   *conn = state->conn;
+	const char *errstr = NULL;
 
 	*done = false;
 	*success = false;
@@ -273,10 +275,10 @@ scram_exchange(void *opaq, char *input, int inputlen,
 			 * Verify server signature, to make sure we're talking to the
 			 * genuine server.
 			 */
-			if (!verify_server_signature(state, success))
+			if (!verify_server_signature(state, success, &errstr))
 			{
-				appendPQExpBufferStr(&conn->errorMessage,
-									 libpq_gettext("could not verify server signature\n"));
+				appendPQExpBuffer(&conn->errorMessage,
+								  libpq_gettext("could not verify server signature: %s\n"), errstr);
 				goto error;
 			}
 
@@ -469,6 +471,7 @@ build_client_final_message(fe_scram_state *state)
 	uint8		client_proof[SCRAM_KEY_LEN];
 	char	   *result;
 	int			encoded_len;
+	const char *errstr = NULL;
 
 	initPQExpBuffer(&buf);
 
@@ -572,11 +575,12 @@ build_client_final_message(fe_scram_state *state)
 	/* Append proof to it, to form client-final-message. */
 	if (!calculate_client_proof(state,
 								state->client_final_message_without_proof,
-								client_proof))
+								client_proof, &errstr))
 	{
 		termPQExpBuffer(&buf);
-		appendPQExpBufferStr(&conn->errorMessage,
-							 libpq_gettext("could not calculate client proof\n"));
+		appendPQExpBuffer(&conn->errorMessage,
+						  libpq_gettext("could not calculate client proof: %s\n"),
+						  errstr);
 		return NULL;
 	}
 
@@ -787,7 +791,7 @@ read_server_final_message(fe_scram_state *state, char *input)
 static bool
 calculate_client_proof(fe_scram_state *state,
 					   const char *client_final_message_without_proof,
-					   uint8 *result)
+					   uint8 *result, const char **errstr)
 {
 	uint8		StoredKey[SCRAM_KEY_LEN];
 	uint8		ClientKey[SCRAM_KEY_LEN];
@@ -797,17 +801,27 @@ calculate_client_proof(fe_scram_state *state,
 
 	ctx = pg_hmac_create(PG_SHA256);
 	if (ctx == NULL)
+	{
+		*errstr = pg_hmac_error(NULL);	/* returns OOM */
 		return false;
+	}
 
 	/*
 	 * Calculate SaltedPassword, and store it in 'state' so that we can reuse
 	 * it later in verify_server_signature.
 	 */
 	if (scram_SaltedPassword(state->password, state->salt, state->saltlen,
-							 state->iterations, state->SaltedPassword) < 0 ||
-		scram_ClientKey(state->SaltedPassword, ClientKey) < 0 ||
-		scram_H(ClientKey, SCRAM_KEY_LEN, StoredKey) < 0 ||
-		pg_hmac_init(ctx, StoredKey, SCRAM_KEY_LEN) < 0 ||
+							 state->iterations, state->SaltedPassword,
+							 errstr) < 0 ||
+		scram_ClientKey(state->SaltedPassword, ClientKey, errstr) < 0 ||
+		scram_H(ClientKey, SCRAM_KEY_LEN, StoredKey, errstr) < 0)
+	{
+		/* errstr is already filled here */
+		pg_hmac_free(ctx);
+		return false;
+	}
+
+	if (pg_hmac_init(ctx, StoredKey, SCRAM_KEY_LEN) < 0 ||
 		pg_hmac_update(ctx,
 					   (uint8 *) state->client_first_message_bare,
 					   strlen(state->client_first_message_bare)) < 0 ||
@@ -821,6 +835,7 @@ calculate_client_proof(fe_scram_state *state,
 					   strlen(client_final_message_without_proof)) < 0 ||
 		pg_hmac_final(ctx, ClientSignature, sizeof(ClientSignature)) < 0)
 	{
+		*errstr = pg_hmac_error(ctx);
 		pg_hmac_free(ctx);
 		return false;
 	}
@@ -839,7 +854,8 @@ calculate_client_proof(fe_scram_state *state,
  * false for a processing error.
  */
 static bool
-verify_server_signature(fe_scram_state *state, bool *match)
+verify_server_signature(fe_scram_state *state, bool *match,
+						const char **errstr)
 {
 	uint8		expected_ServerSignature[SCRAM_KEY_LEN];
 	uint8		ServerKey[SCRAM_KEY_LEN];
@@ -847,11 +863,20 @@ verify_server_signature(fe_scram_state *state, bool *match)
 
 	ctx = pg_hmac_create(PG_SHA256);
 	if (ctx == NULL)
+	{
+		*errstr = pg_hmac_error(NULL);	/* returns OOM */
 		return false;
+	}
+
+	if (scram_ServerKey(state->SaltedPassword, ServerKey, errstr) < 0)
+	{
+		/* errstr is filled already */
+		pg_hmac_free(ctx);
+		return false;
+	}
 
-	if (scram_ServerKey(state->SaltedPassword, ServerKey) < 0 ||
 	/* calculate ServerSignature */
-		pg_hmac_init(ctx, ServerKey, SCRAM_KEY_LEN) < 0 ||
+	if (pg_hmac_init(ctx, ServerKey, SCRAM_KEY_LEN) < 0 ||
 		pg_hmac_update(ctx,
 					   (uint8 *) state->client_first_message_bare,
 					   strlen(state->client_first_message_bare)) < 0 ||
@@ -866,6 +891,7 @@ verify_server_signature(fe_scram_state *state, bool *match)
 		pg_hmac_final(ctx, expected_ServerSignature,
 					  sizeof(expected_ServerSignature)) < 0)
 	{
+		*errstr = pg_hmac_error(ctx);
 		pg_hmac_free(ctx);
 		return false;
 	}
@@ -885,7 +911,7 @@ verify_server_signature(fe_scram_state *state, bool *match)
  * Build a new SCRAM secret.
  */
 char *
-pg_fe_scram_build_secret(const char *password)
+pg_fe_scram_build_secret(const char *password, const char **errstr)
 {
 	char	   *prep_password;
 	pg_saslprep_rc rc;
@@ -899,20 +925,25 @@ pg_fe_scram_build_secret(const char *password)
 	 */
 	rc = pg_saslprep(password, &prep_password);
 	if (rc == SASLPREP_OOM)
+	{
+		*errstr = _("out of memory");
 		return NULL;
+	}
 	if (rc == SASLPREP_SUCCESS)
 		password = (const char *) prep_password;
 
 	/* Generate a random salt */
 	if (!pg_strong_random(saltbuf, SCRAM_DEFAULT_SALT_LEN))
 	{
+		*errstr = _("failed to generate random salt");
 		if (prep_password)
 			free(prep_password);
 		return NULL;
 	}
 
 	result = scram_build_secret(saltbuf, SCRAM_DEFAULT_SALT_LEN,
-								SCRAM_DEFAULT_ITERATIONS, password);
+								SCRAM_DEFAULT_ITERATIONS, password,
+								errstr);
 
 	if (prep_password)
 		free(prep_password);
diff --git a/src/interfaces/libpq/fe-auth.c b/src/interfaces/libpq/fe-auth.c
index 2edc3f48e2..f8f4111fef 100644
--- a/src/interfaces/libpq/fe-auth.c
+++ b/src/interfaces/libpq/fe-auth.c
@@ -1293,11 +1293,13 @@ PQencryptPasswordConn(PGconn *conn, const char *passwd, const char *user,
 	 */
 	if (strcmp(algorithm, "scram-sha-256") == 0)
 	{
-		crypt_pwd = pg_fe_scram_build_secret(passwd);
-		/* We assume the only possible failure is OOM */
+		const char *errstr = NULL;
+
+		crypt_pwd = pg_fe_scram_build_secret(passwd, &errstr);
 		if (!crypt_pwd)
-			appendPQExpBufferStr(&conn->errorMessage,
-								 libpq_gettext("out of memory\n"));
+			appendPQExpBuffer(&conn->errorMessage,
+							  libpq_gettext("could not encrypt password: %s\n"),
+							  errstr);
 	}
 	else if (strcmp(algorithm, "md5") == 0)
 	{
diff --git a/src/interfaces/libpq/fe-auth.h b/src/interfaces/libpq/fe-auth.h
index f22b3fe648..049a8bb1a1 100644
--- a/src/interfaces/libpq/fe-auth.h
+++ b/src/interfaces/libpq/fe-auth.h
@@ -25,6 +25,7 @@ extern char *pg_fe_getauthname(PQExpBuffer errorMessage);
 
 /* Mechanisms in fe-auth-scram.c */
 extern const pg_fe_sasl_mech pg_scram_mech;
-extern char *pg_fe_scram_build_secret(const char *password);
+extern char *pg_fe_scram_build_secret(const char *password,
+									  const char **errstr);
 
 #endif							/* FE_AUTH_H */
-- 
2.34.1

Attachment: signature.asc
Description: PGP signature

Reply via email to