Hi all,

This is a follow-up of the work done in b69aba7 for cryptohashes, but
this time for HMAC.  The main issue here is related to SCRAM, where a
lot of code paths, in both the frontend and the backend, have no idea
what kind of failure is happening when an error occurs.  For some of
them, this has been the case since v10, where SCRAM was introduced.
\password is one example.

The set of errors improved here would only trigger in scenarios that
are unlikely to happen, like an OOM or an internal OpenSSL error.  It
would be possible to create a HMAC from MD5, which would cause an
error when compiling with OpenSSL and FIPS enabled, but the only
callers of the pg_hmac_* routines in core involve SHA-256 through
SCRAM, so I don't see much of a point in backpatching any of the
things proposed here.

The attached patch creates a new routine called pg_hmac_error() that one
can use to grab details about the error that happened, in the same
fashion as what has been done for cryptohashes.  The logic is not that
complicated, but note that the fallback HMAC implementation relies
itself on cryptohashes, so there are cases where we need to look at
the error from pg_cryptohash_error() and store it in the HMAC private
context.

Thoughts?
--
Michael
From 687dd48a8150fae4597b126d68f6758b52ff67cb Mon Sep 17 00:00:00 2001
From: Michael Paquier <mich...@paquier.xyz>
Date: Tue, 11 Jan 2022 13:47:06 +0900
Subject: [PATCH] Improve HMAC error handling

---
 src/include/common/hmac.h            |  1 +
 src/include/common/scram-common.h    | 14 +++--
 src/backend/libpq/auth-scram.c       | 22 ++++---
 src/common/hmac.c                    | 64 ++++++++++++++++++++
 src/common/hmac_openssl.c            | 90 ++++++++++++++++++++++++++++
 src/common/scram-common.c            | 47 +++++++++++----
 src/interfaces/libpq/fe-auth-scram.c | 63 +++++++++++++------
 src/interfaces/libpq/fe-auth.c       | 17 ++++--
 src/interfaces/libpq/fe-auth.h       |  3 +-
 9 files changed, 271 insertions(+), 50 deletions(-)

diff --git a/src/include/common/hmac.h b/src/include/common/hmac.h
index cf7aa17be4..c18783fe11 100644
--- a/src/include/common/hmac.h
+++ b/src/include/common/hmac.h
@@ -25,5 +25,6 @@ extern int	pg_hmac_init(pg_hmac_ctx *ctx, const uint8 *key, size_t len);
 extern int	pg_hmac_update(pg_hmac_ctx *ctx, const uint8 *data, size_t len);
 extern int	pg_hmac_final(pg_hmac_ctx *ctx, uint8 *dest, size_t len);
 extern void pg_hmac_free(pg_hmac_ctx *ctx);
+extern const char *pg_hmac_error(pg_hmac_ctx *ctx);
 
 #endif							/* PG_HMAC_H */
diff --git a/src/include/common/scram-common.h b/src/include/common/scram-common.h
index d53b4fa7f5..d1f840c11c 100644
--- a/src/include/common/scram-common.h
+++ b/src/include/common/scram-common.h
@@ -47,12 +47,16 @@
 #define SCRAM_DEFAULT_ITERATIONS	4096
 
 extern int	scram_SaltedPassword(const char *password, const char *salt,
-								 int saltlen, int iterations, uint8 *result);
-extern int	scram_H(const uint8 *str, int len, uint8 *result);
-extern int	scram_ClientKey(const uint8 *salted_password, uint8 *result);
-extern int	scram_ServerKey(const uint8 *salted_password, uint8 *result);
+								 int saltlen, int iterations, uint8 *result,
+								 const char **errstr);
+extern int	scram_H(const uint8 *str, int len, uint8 *result,
+					const char **errstr);
+extern int	scram_ClientKey(const uint8 *salted_password, uint8 *result,
+							const char **errstr);
+extern int	scram_ServerKey(const uint8 *salted_password, uint8 *result,
+							const char **errstr);
 
 extern char *scram_build_secret(const char *salt, int saltlen, int iterations,
-								const char *password);
+								const char *password, const char **errstr);
 
 #endif							/* SCRAM_COMMON_H */
diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c
index 7c9dee70ce..ee7f52218a 100644
--- a/src/backend/libpq/auth-scram.c
+++ b/src/backend/libpq/auth-scram.c
@@ -465,6 +465,7 @@ pg_be_scram_build_secret(const char *password)
 	pg_saslprep_rc rc;
 	char		saltbuf[SCRAM_DEFAULT_SALT_LEN];
 	char	   *result;
+	const char *errstr = NULL;
 
 	/*
 	 * Normalize the password with SASLprep.  If that doesn't work, because
@@ -482,7 +483,8 @@ pg_be_scram_build_secret(const char *password)
 				 errmsg("could not generate random salt")));
 
 	result = scram_build_secret(saltbuf, SCRAM_DEFAULT_SALT_LEN,
-								SCRAM_DEFAULT_ITERATIONS, password);
+								SCRAM_DEFAULT_ITERATIONS, password,
+								&errstr);
 
 	if (prep_password)
 		pfree(prep_password);
@@ -509,6 +511,7 @@ scram_verify_plain_password(const char *username, const char *password,
 	uint8		computed_key[SCRAM_KEY_LEN];
 	char	   *prep_password;
 	pg_saslprep_rc rc;
+	const char *errstr = NULL;
 
 	if (!parse_scram_secret(secret, &iterations, &encoded_salt,
 							stored_key, server_key))
@@ -539,10 +542,10 @@ scram_verify_plain_password(const char *username, const char *password,
 
 	/* Compute Server Key based on the user-supplied plaintext password */
 	if (scram_SaltedPassword(password, salt, saltlen, iterations,
-							 salted_password) < 0 ||
-		scram_ServerKey(salted_password, computed_key) < 0)
+							 salted_password, &errstr) < 0 ||
+		scram_ServerKey(salted_password, computed_key, &errstr) < 0)
 	{
-		elog(ERROR, "could not compute server key");
+		elog(ERROR, "could not compute server key: %s", errstr);
 	}
 
 	if (prep_password)
@@ -1113,6 +1116,7 @@ verify_client_proof(scram_state *state)
 	uint8		client_StoredKey[SCRAM_KEY_LEN];
 	pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
 	int			i;
+	const char *errstr = NULL;
 
 	/*
 	 * Calculate ClientSignature.  Note that we don't log directly a failure
@@ -1133,7 +1137,8 @@ verify_client_proof(scram_state *state)
 					   strlen(state->client_final_message_without_proof)) < 0 ||
 		pg_hmac_final(ctx, ClientSignature, sizeof(ClientSignature)) < 0)
 	{
-		elog(ERROR, "could not calculate client signature");
+		elog(ERROR, "could not calculate client signature: %s",
+			 pg_hmac_error(ctx));
 	}
 
 	pg_hmac_free(ctx);
@@ -1143,8 +1148,8 @@ verify_client_proof(scram_state *state)
 		ClientKey[i] = state->ClientProof[i] ^ ClientSignature[i];
 
 	/* Hash it one more time, and compare with StoredKey */
-	if (scram_H(ClientKey, SCRAM_KEY_LEN, client_StoredKey) < 0)
-		elog(ERROR, "could not hash stored key");
+	if (scram_H(ClientKey, SCRAM_KEY_LEN, client_StoredKey, &errstr) < 0)
+		elog(ERROR, "could not hash stored key: %s", errstr);
 
 	if (memcmp(client_StoredKey, state->StoredKey, SCRAM_KEY_LEN) != 0)
 		return false;
@@ -1389,7 +1394,8 @@ build_server_final_message(scram_state *state)
 					   strlen(state->client_final_message_without_proof)) < 0 ||
 		pg_hmac_final(ctx, ServerSignature, sizeof(ServerSignature)) < 0)
 	{
-		elog(ERROR, "could not calculate server signature");
+		elog(ERROR, "could not calculate server signature: %s",
+			 pg_hmac_error(ctx));
 	}
 
 	pg_hmac_free(ctx);
diff --git a/src/common/hmac.c b/src/common/hmac.c
index 6e46dc28a1..8a09dad585 100644
--- a/src/common/hmac.c
+++ b/src/common/hmac.c
@@ -38,6 +38,14 @@
 #define FREE(ptr) free(ptr)
 #endif
 
+/* Set of error states */
+typedef enum pg_hmac_errno
+{
+	PG_HMAC_ERROR_NONE = 0,
+	PG_HMAC_ERROR_OOM,
+	PG_HMAC_ERROR_INTERNAL
+} pg_hmac_errno;
+
 /*
  * Internal structure for pg_hmac_ctx->data with this implementation.
  */
@@ -45,6 +53,8 @@ struct pg_hmac_ctx
 {
 	pg_cryptohash_ctx *hash;
 	pg_cryptohash_type type;
+	pg_hmac_errno error;
+	const char *errreason;
 	int			block_size;
 	int			digest_size;
 
@@ -75,6 +85,8 @@ pg_hmac_create(pg_cryptohash_type type)
 		return NULL;
 	memset(ctx, 0, sizeof(pg_hmac_ctx));
 	ctx->type = type;
+	ctx->error = PG_HMAC_ERROR_NONE;
+	ctx->errreason = NULL;
 
 	/*
 	 * Initialize the context data.  This requires to know the digest and
@@ -152,12 +164,16 @@ pg_hmac_init(pg_hmac_ctx *ctx, const uint8 *key, size_t len)
 		/* temporary buffer for one-time shrink */
 		shrinkbuf = ALLOC(digest_size);
 		if (shrinkbuf == NULL)
+		{
+			ctx->error = PG_HMAC_ERROR_OOM;
 			return -1;
+		}
 		memset(shrinkbuf, 0, digest_size);
 
 		hash_ctx = pg_cryptohash_create(ctx->type);
 		if (hash_ctx == NULL)
 		{
+			ctx->error = PG_HMAC_ERROR_OOM;
 			FREE(shrinkbuf);
 			return -1;
 		}
@@ -166,6 +182,8 @@ pg_hmac_init(pg_hmac_ctx *ctx, const uint8 *key, size_t len)
 			pg_cryptohash_update(hash_ctx, key, len) < 0 ||
 			pg_cryptohash_final(hash_ctx, shrinkbuf, digest_size) < 0)
 		{
+			ctx->error = PG_HMAC_ERROR_INTERNAL;
+			ctx->errreason = pg_cryptohash_error(hash_ctx);
 			pg_cryptohash_free(hash_ctx);
 			FREE(shrinkbuf);
 			return -1;
@@ -186,6 +204,8 @@ pg_hmac_init(pg_hmac_ctx *ctx, const uint8 *key, size_t len)
 	if (pg_cryptohash_init(ctx->hash) < 0 ||
 		pg_cryptohash_update(ctx->hash, ctx->k_ipad, ctx->block_size) < 0)
 	{
+		ctx->error = PG_HMAC_ERROR_INTERNAL;
+		ctx->errreason = pg_cryptohash_error(ctx->hash);
 		if (shrinkbuf)
 			FREE(shrinkbuf);
 		return -1;
@@ -208,7 +228,11 @@ pg_hmac_update(pg_hmac_ctx *ctx, const uint8 *data, size_t len)
 		return -1;
 
 	if (pg_cryptohash_update(ctx->hash, data, len) < 0)
+	{
+		ctx->error = PG_HMAC_ERROR_INTERNAL;
+		ctx->errreason = pg_cryptohash_error(ctx->hash);
 		return -1;
+	}
 
 	return 0;
 }
@@ -228,11 +252,16 @@ pg_hmac_final(pg_hmac_ctx *ctx, uint8 *dest, size_t len)
 
 	h = ALLOC(ctx->digest_size);
 	if (h == NULL)
+	{
+		ctx->error = PG_HMAC_ERROR_OOM;
 		return -1;
+	}
 	memset(h, 0, ctx->digest_size);
 
 	if (pg_cryptohash_final(ctx->hash, h, ctx->digest_size) < 0)
 	{
+		ctx->error = PG_HMAC_ERROR_INTERNAL;
+		ctx->errreason = pg_cryptohash_error(ctx->hash);
 		FREE(h);
 		return -1;
 	}
@@ -243,6 +272,8 @@ pg_hmac_final(pg_hmac_ctx *ctx, uint8 *dest, size_t len)
 		pg_cryptohash_update(ctx->hash, h, ctx->digest_size) < 0 ||
 		pg_cryptohash_final(ctx->hash, dest, len) < 0)
 	{
+		ctx->error = PG_HMAC_ERROR_INTERNAL;
+		ctx->errreason = pg_cryptohash_error(ctx->hash);
 		FREE(h);
 		return -1;
 	}
@@ -266,3 +297,36 @@ pg_hmac_free(pg_hmac_ctx *ctx)
 	explicit_bzero(ctx, sizeof(pg_hmac_ctx));
 	FREE(ctx);
 }
+
+/*
+ * pg_hmac_error
+ *
+ * Returns a static string providing details about an error that happened
+ * during a HMAC computation.
+ */
+const char *
+pg_hmac_error(pg_hmac_ctx *ctx)
+{
+	if (ctx == NULL)
+		return _("out of memory");
+
+	/*
+	 * If a reason is provided, rely on it, else fallback to any error code
+	 * set.
+	 */
+	if (ctx->errreason)
+		return ctx->errreason;
+
+	switch (ctx->error)
+	{
+		case PG_HMAC_ERROR_NONE:
+			return _("success");
+		case PG_HMAC_ERROR_INTERNAL:
+			return _("internal error");
+		case PG_HMAC_ERROR_OOM:
+			return _("out of memory");
+	}
+
+	Assert(false);				/* cannot be reached */
+	return _("success");
+}
diff --git a/src/common/hmac_openssl.c b/src/common/hmac_openssl.c
index d2cb5474bb..47489ac704 100644
--- a/src/common/hmac_openssl.c
+++ b/src/common/hmac_openssl.c
@@ -20,6 +20,8 @@
 #include "postgres_fe.h"
 #endif
 
+
+#include <openssl/err.h>
 #include <openssl/hmac.h>
 
 #include "common/hmac.h"
@@ -50,6 +52,14 @@
 #define FREE(ptr) free(ptr)
 #endif							/* FRONTEND */
 
+/* Set of error states */
+typedef enum pg_hmac_errno
+{
+	PG_HMAC_ERROR_NONE = 0,
+	PG_HMAC_ERROR_DEST_LEN,
+	PG_HMAC_ERROR_OPENSSL
+} pg_hmac_errno;
+
 /*
  * Internal structure for pg_hmac_ctx->data with this implementation.
  */
@@ -57,12 +67,27 @@ struct pg_hmac_ctx
 {
 	HMAC_CTX   *hmacctx;
 	pg_cryptohash_type type;
+	pg_hmac_errno error;
+	const char *errreason;
 
 #ifndef FRONTEND
 	ResourceOwner resowner;
 #endif
 };
 
+static const char *
+SSLerrmessage(unsigned long ecode)
+{
+	if (ecode == 0)
+		return NULL;
+
+	/*
+	 * This may return NULL, but we would fall back to a default error path if
+	 * that were the case.
+	 */
+	return ERR_reason_error_string(ecode);
+}
+
 /*
  * pg_hmac_create
  *
@@ -80,6 +105,8 @@ pg_hmac_create(pg_cryptohash_type type)
 	memset(ctx, 0, sizeof(pg_hmac_ctx));
 
 	ctx->type = type;
+	ctx->error = PG_HMAC_ERROR_NONE;
+	ctx->errreason = NULL;
 
 	/*
 	 * Initialization takes care of assigning the correct type for OpenSSL.
@@ -154,7 +181,11 @@ pg_hmac_init(pg_hmac_ctx *ctx, const uint8 *key, size_t len)
 
 	/* OpenSSL internals return 1 on success, 0 on failure */
 	if (status <= 0)
+	{
+		ctx->errreason = SSLerrmessage(ERR_get_error());
+		ctx->error = PG_HMAC_ERROR_OPENSSL;
 		return -1;
+	}
 
 	return 0;
 }
@@ -176,7 +207,11 @@ pg_hmac_update(pg_hmac_ctx *ctx, const uint8 *data, size_t len)
 
 	/* OpenSSL internals return 1 on success, 0 on failure */
 	if (status <= 0)
+	{
+		ctx->errreason = SSLerrmessage(ERR_get_error());
+		ctx->error = PG_HMAC_ERROR_OPENSSL;
 		return -1;
+	}
 	return 0;
 }
 
@@ -198,27 +233,45 @@ pg_hmac_final(pg_hmac_ctx *ctx, uint8 *dest, size_t len)
 	{
 		case PG_MD5:
 			if (len < MD5_DIGEST_LENGTH)
+			{
+				ctx->error = PG_HMAC_ERROR_DEST_LEN;
 				return -1;
+			}
 			break;
 		case PG_SHA1:
 			if (len < SHA1_DIGEST_LENGTH)
+			{
+				ctx->error = PG_HMAC_ERROR_DEST_LEN;
 				return -1;
+			}
 			break;
 		case PG_SHA224:
 			if (len < PG_SHA224_DIGEST_LENGTH)
+			{
+				ctx->error = PG_HMAC_ERROR_DEST_LEN;
 				return -1;
+			}
 			break;
 		case PG_SHA256:
 			if (len < PG_SHA256_DIGEST_LENGTH)
+			{
+				ctx->error = PG_HMAC_ERROR_DEST_LEN;
 				return -1;
+			}
 			break;
 		case PG_SHA384:
 			if (len < PG_SHA384_DIGEST_LENGTH)
+			{
+				ctx->error = PG_HMAC_ERROR_DEST_LEN;
 				return -1;
+			}
 			break;
 		case PG_SHA512:
 			if (len < PG_SHA512_DIGEST_LENGTH)
+			{
+				ctx->error = PG_HMAC_ERROR_DEST_LEN;
 				return -1;
+			}
 			break;
 	}
 
@@ -226,7 +279,11 @@ pg_hmac_final(pg_hmac_ctx *ctx, uint8 *dest, size_t len)
 
 	/* OpenSSL internals return 1 on success, 0 on failure */
 	if (status <= 0)
+	{
+		ctx->errreason = SSLerrmessage(ERR_get_error());
+		ctx->error = PG_HMAC_ERROR_OPENSSL;
 		return -1;
+	}
 	return 0;
 }
 
@@ -254,3 +311,36 @@ pg_hmac_free(pg_hmac_ctx *ctx)
 	explicit_bzero(ctx, sizeof(pg_hmac_ctx));
 	FREE(ctx);
 }
+
+/*
+ * pg_hmac_error
+ *
+ * Returns a static string providing details about an error that happened
+ * during a HMAC computation.
+ */
+const char *
+pg_hmac_error(pg_hmac_ctx *ctx)
+{
+	if (ctx == NULL)
+		return _("out of memory");
+
+	/*
+	 * If a reason is provided, rely on it, else fallback to any error code
+	 * set.
+	 */
+	if (ctx->errreason)
+		return ctx->errreason;
+
+	switch (ctx->error)
+	{
+		case PG_HMAC_ERROR_NONE:
+			return _("success");
+		case PG_HMAC_ERROR_DEST_LEN:
+			return _("destination buffer too small");
+		case PG_HMAC_ERROR_OPENSSL:
+			return _("OpenSSL failure");
+	}
+
+	Assert(false);				/* cannot be reached */
+	return _("success");
+}
diff --git a/src/common/scram-common.c b/src/common/scram-common.c
index 23b68b14da..5f90397c66 100644
--- a/src/common/scram-common.c
+++ b/src/common/scram-common.c
@@ -33,7 +33,7 @@
 int
 scram_SaltedPassword(const char *password,
 					 const char *salt, int saltlen, int iterations,
-					 uint8 *result)
+					 uint8 *result, const char **errstr)
 {
 	int			password_len = strlen(password);
 	uint32		one = pg_hton32(1);
@@ -58,6 +58,7 @@ scram_SaltedPassword(const char *password,
 		pg_hmac_update(hmac_ctx, (uint8 *) &one, sizeof(uint32)) < 0 ||
 		pg_hmac_final(hmac_ctx, Ui_prev, sizeof(Ui_prev)) < 0)
 	{
+		*errstr = pg_hmac_error(hmac_ctx);
 		pg_hmac_free(hmac_ctx);
 		return -1;
 	}
@@ -71,6 +72,7 @@ scram_SaltedPassword(const char *password,
 			pg_hmac_update(hmac_ctx, (uint8 *) Ui_prev, SCRAM_KEY_LEN) < 0 ||
 			pg_hmac_final(hmac_ctx, Ui, sizeof(Ui)) < 0)
 		{
+			*errstr = pg_hmac_error(hmac_ctx);
 			pg_hmac_free(hmac_ctx);
 			return -1;
 		}
@@ -90,18 +92,22 @@ scram_SaltedPassword(const char *password,
  * not included in the hash).  Returns 0 on success, -1 on failure.
  */
 int
-scram_H(const uint8 *input, int len, uint8 *result)
+scram_H(const uint8 *input, int len, uint8 *result, const char **errstr)
 {
 	pg_cryptohash_ctx *ctx;
 
 	ctx = pg_cryptohash_create(PG_SHA256);
 	if (ctx == NULL)
+	{
+		*errstr = pg_cryptohash_error(NULL);	/* returns OOM */
 		return -1;
+	}
 
 	if (pg_cryptohash_init(ctx) < 0 ||
 		pg_cryptohash_update(ctx, input, len) < 0 ||
 		pg_cryptohash_final(ctx, result, SCRAM_KEY_LEN) < 0)
 	{
+		*errstr = pg_cryptohash_error(ctx);
 		pg_cryptohash_free(ctx);
 		return -1;
 	}
@@ -114,7 +120,8 @@ scram_H(const uint8 *input, int len, uint8 *result)
  * Calculate ClientKey.  Returns 0 on success, -1 on failure.
  */
 int
-scram_ClientKey(const uint8 *salted_password, uint8 *result)
+scram_ClientKey(const uint8 *salted_password, uint8 *result,
+				const char **errstr)
 {
 	pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
 
@@ -125,6 +132,7 @@ scram_ClientKey(const uint8 *salted_password, uint8 *result)
 		pg_hmac_update(ctx, (uint8 *) "Client Key", strlen("Client Key")) < 0 ||
 		pg_hmac_final(ctx, result, SCRAM_KEY_LEN) < 0)
 	{
+		*errstr = pg_hmac_error(ctx);
 		pg_hmac_free(ctx);
 		return -1;
 	}
@@ -137,17 +145,22 @@ scram_ClientKey(const uint8 *salted_password, uint8 *result)
  * Calculate ServerKey.  Returns 0 on success, -1 on failure.
  */
 int
-scram_ServerKey(const uint8 *salted_password, uint8 *result)
+scram_ServerKey(const uint8 *salted_password, uint8 *result,
+				const char **errstr)
 {
 	pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
 
 	if (ctx == NULL)
+	{
+		*errstr = pg_hmac_error(NULL);	/* returns OOM */
 		return -1;
+	}
 
 	if (pg_hmac_init(ctx, salted_password, SCRAM_KEY_LEN) < 0 ||
 		pg_hmac_update(ctx, (uint8 *) "Server Key", strlen("Server Key")) < 0 ||
 		pg_hmac_final(ctx, result, SCRAM_KEY_LEN) < 0)
 	{
+		*errstr = pg_hmac_error(ctx);
 		pg_hmac_free(ctx);
 		return -1;
 	}
@@ -167,7 +180,7 @@ scram_ServerKey(const uint8 *salted_password, uint8 *result)
  */
 char *
 scram_build_secret(const char *salt, int saltlen, int iterations,
-				   const char *password)
+				   const char *password, const char **errstr)
 {
 	uint8		salted_password[SCRAM_KEY_LEN];
 	uint8		stored_key[SCRAM_KEY_LEN];
@@ -185,15 +198,17 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
 
 	/* Calculate StoredKey and ServerKey */
 	if (scram_SaltedPassword(password, salt, saltlen, iterations,
-							 salted_password) < 0 ||
-		scram_ClientKey(salted_password, stored_key) < 0 ||
-		scram_H(stored_key, SCRAM_KEY_LEN, stored_key) < 0 ||
-		scram_ServerKey(salted_password, server_key) < 0)
+							 salted_password, errstr) < 0 ||
+		scram_ClientKey(salted_password, stored_key, errstr) < 0 ||
+		scram_H(stored_key, SCRAM_KEY_LEN, stored_key, errstr) < 0 ||
+		scram_ServerKey(salted_password, server_key, errstr) < 0)
 	{
+		/* errstr is filled already here */
 #ifdef FRONTEND
 		return NULL;
 #else
-		elog(ERROR, "could not calculate stored key and server key");
+		elog(ERROR, "could not calculate stored key and server key: %s",
+			 *errstr);
 #endif
 	}
 
@@ -215,7 +230,10 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
 #ifdef FRONTEND
 	result = malloc(maxlen);
 	if (!result)
+	{
+		*errstr = _("out of memory");
 		return NULL;
+	}
 #else
 	result = palloc(maxlen);
 #endif
@@ -226,11 +244,12 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
 	encoded_result = pg_b64_encode(salt, saltlen, p, encoded_salt_len);
 	if (encoded_result < 0)
 	{
+		*errstr = _("could not encode salt");
 #ifdef FRONTEND
 		free(result);
 		return NULL;
 #else
-		elog(ERROR, "could not encode salt");
+		elog(ERROR, "%s", *errstr);
 #endif
 	}
 	p += encoded_result;
@@ -241,11 +260,12 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
 								   encoded_stored_len);
 	if (encoded_result < 0)
 	{
+		*errstr = _("could not encode stored key");
 #ifdef FRONTEND
 		free(result);
 		return NULL;
 #else
-		elog(ERROR, "could not encode stored key");
+		elog(ERROR, "%s", *errstr);
 #endif
 	}
 
@@ -257,11 +277,12 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
 								   encoded_server_len);
 	if (encoded_result < 0)
 	{
+		*errstr = _("could not encode server key");
 #ifdef FRONTEND
 		free(result);
 		return NULL;
 #else
-		elog(ERROR, "could not encode server key");
+		elog(ERROR, "%s", *errstr);
 #endif
 	}
 
diff --git a/src/interfaces/libpq/fe-auth-scram.c b/src/interfaces/libpq/fe-auth-scram.c
index cc41440c4e..96b44c5207 100644
--- a/src/interfaces/libpq/fe-auth-scram.c
+++ b/src/interfaces/libpq/fe-auth-scram.c
@@ -80,10 +80,11 @@ static bool read_server_first_message(fe_scram_state *state, char *input);
 static bool read_server_final_message(fe_scram_state *state, char *input);
 static char *build_client_first_message(fe_scram_state *state);
 static char *build_client_final_message(fe_scram_state *state);
-static bool verify_server_signature(fe_scram_state *state, bool *match);
+static bool verify_server_signature(fe_scram_state *state, bool *match,
+									const char **errstr);
 static bool calculate_client_proof(fe_scram_state *state,
 								   const char *client_final_message_without_proof,
-								   uint8 *result);
+								   uint8 *result, const char **errstr);
 
 /*
  * Initialize SCRAM exchange status.
@@ -211,6 +212,7 @@ scram_exchange(void *opaq, char *input, int inputlen,
 {
 	fe_scram_state *state = (fe_scram_state *) opaq;
 	PGconn	   *conn = state->conn;
+	const char *errstr = NULL;
 
 	*done = false;
 	*success = false;
@@ -273,10 +275,10 @@ scram_exchange(void *opaq, char *input, int inputlen,
 			 * Verify server signature, to make sure we're talking to the
 			 * genuine server.
 			 */
-			if (!verify_server_signature(state, success))
+			if (!verify_server_signature(state, success, &errstr))
 			{
-				appendPQExpBufferStr(&conn->errorMessage,
-									 libpq_gettext("could not verify server signature\n"));
+				appendPQExpBuffer(&conn->errorMessage,
+								  libpq_gettext("could not verify server signature: %s\n"), errstr);
 				goto error;
 			}
 
@@ -469,6 +471,7 @@ build_client_final_message(fe_scram_state *state)
 	uint8		client_proof[SCRAM_KEY_LEN];
 	char	   *result;
 	int			encoded_len;
+	const char *errstr = NULL;
 
 	initPQExpBuffer(&buf);
 
@@ -572,11 +575,12 @@ build_client_final_message(fe_scram_state *state)
 	/* Append proof to it, to form client-final-message. */
 	if (!calculate_client_proof(state,
 								state->client_final_message_without_proof,
-								client_proof))
+								client_proof, &errstr))
 	{
 		termPQExpBuffer(&buf);
-		appendPQExpBufferStr(&conn->errorMessage,
-							 libpq_gettext("could not calculate client proof\n"));
+		appendPQExpBuffer(&conn->errorMessage,
+						  libpq_gettext("could not calculate client proof: %s\n"),
+						  errstr);
 		return NULL;
 	}
 
@@ -787,7 +791,7 @@ read_server_final_message(fe_scram_state *state, char *input)
 static bool
 calculate_client_proof(fe_scram_state *state,
 					   const char *client_final_message_without_proof,
-					   uint8 *result)
+					   uint8 *result, const char **errstr)
 {
 	uint8		StoredKey[SCRAM_KEY_LEN];
 	uint8		ClientKey[SCRAM_KEY_LEN];
@@ -797,17 +801,27 @@ calculate_client_proof(fe_scram_state *state,
 
 	ctx = pg_hmac_create(PG_SHA256);
 	if (ctx == NULL)
+	{
+		*errstr = pg_hmac_error(NULL);	/* returns OOM */
 		return false;
+	}
 
 	/*
 	 * Calculate SaltedPassword, and store it in 'state' so that we can reuse
 	 * it later in verify_server_signature.
 	 */
 	if (scram_SaltedPassword(state->password, state->salt, state->saltlen,
-							 state->iterations, state->SaltedPassword) < 0 ||
-		scram_ClientKey(state->SaltedPassword, ClientKey) < 0 ||
-		scram_H(ClientKey, SCRAM_KEY_LEN, StoredKey) < 0 ||
-		pg_hmac_init(ctx, StoredKey, SCRAM_KEY_LEN) < 0 ||
+							 state->iterations, state->SaltedPassword,
+							 errstr) < 0 ||
+		scram_ClientKey(state->SaltedPassword, ClientKey, errstr) < 0 ||
+		scram_H(ClientKey, SCRAM_KEY_LEN, StoredKey, errstr) < 0)
+	{
+		/* errstr is already filled here */
+		pg_hmac_free(ctx);
+		return false;
+	}
+
+	if (pg_hmac_init(ctx, StoredKey, SCRAM_KEY_LEN) < 0 ||
 		pg_hmac_update(ctx,
 					   (uint8 *) state->client_first_message_bare,
 					   strlen(state->client_first_message_bare)) < 0 ||
@@ -821,6 +835,7 @@ calculate_client_proof(fe_scram_state *state,
 					   strlen(client_final_message_without_proof)) < 0 ||
 		pg_hmac_final(ctx, ClientSignature, sizeof(ClientSignature)) < 0)
 	{
+		*errstr = pg_hmac_error(ctx);
 		pg_hmac_free(ctx);
 		return false;
 	}
@@ -839,7 +854,8 @@ calculate_client_proof(fe_scram_state *state,
  * false for a processing error.
  */
 static bool
-verify_server_signature(fe_scram_state *state, bool *match)
+verify_server_signature(fe_scram_state *state, bool *match,
+						const char **errstr)
 {
 	uint8		expected_ServerSignature[SCRAM_KEY_LEN];
 	uint8		ServerKey[SCRAM_KEY_LEN];
@@ -847,11 +863,20 @@ verify_server_signature(fe_scram_state *state, bool *match)
 
 	ctx = pg_hmac_create(PG_SHA256);
 	if (ctx == NULL)
+	{
+		*errstr = pg_hmac_error(NULL);	/* returns OOM */
 		return false;
+	}
+
+	if (scram_ServerKey(state->SaltedPassword, ServerKey, errstr) < 0)
+	{
+		/* errstr is filled already */
+		pg_hmac_free(ctx);
+		return false;
+	}
 
-	if (scram_ServerKey(state->SaltedPassword, ServerKey) < 0 ||
 	/* calculate ServerSignature */
-		pg_hmac_init(ctx, ServerKey, SCRAM_KEY_LEN) < 0 ||
+	if (pg_hmac_init(ctx, ServerKey, SCRAM_KEY_LEN) < 0 ||
 		pg_hmac_update(ctx,
 					   (uint8 *) state->client_first_message_bare,
 					   strlen(state->client_first_message_bare)) < 0 ||
@@ -866,6 +891,7 @@ verify_server_signature(fe_scram_state *state, bool *match)
 		pg_hmac_final(ctx, expected_ServerSignature,
 					  sizeof(expected_ServerSignature)) < 0)
 	{
+		*errstr = pg_hmac_error(ctx);
 		pg_hmac_free(ctx);
 		return false;
 	}
@@ -885,7 +911,7 @@ verify_server_signature(fe_scram_state *state, bool *match)
  * Build a new SCRAM secret.
  */
 char *
-pg_fe_scram_build_secret(const char *password)
+pg_fe_scram_build_secret(const char *password, const char **errstr)
 {
 	char	   *prep_password;
 	pg_saslprep_rc rc;
@@ -912,7 +938,8 @@ pg_fe_scram_build_secret(const char *password)
 	}
 
 	result = scram_build_secret(saltbuf, SCRAM_DEFAULT_SALT_LEN,
-								SCRAM_DEFAULT_ITERATIONS, password);
+								SCRAM_DEFAULT_ITERATIONS, password,
+								errstr);
 
 	if (prep_password)
 		free(prep_password);
diff --git a/src/interfaces/libpq/fe-auth.c b/src/interfaces/libpq/fe-auth.c
index 2e6b2e8f04..83d2a5e24d 100644
--- a/src/interfaces/libpq/fe-auth.c
+++ b/src/interfaces/libpq/fe-auth.c
@@ -1289,7 +1289,15 @@ PQencryptPasswordConn(PGconn *conn, const char *passwd, const char *user,
 	 */
 	if (strcmp(algorithm, "scram-sha-256") == 0)
 	{
-		crypt_pwd = pg_fe_scram_build_secret(passwd);
+		const char *errstr = NULL;
+
+		crypt_pwd = pg_fe_scram_build_secret(passwd, &errstr);
+		if (!crypt_pwd)
+		{
+			appendPQExpBuffer(&conn->errorMessage,
+							  libpq_gettext("could not encrypt password: %s\n"),
+							  errstr);
+		}
 	}
 	else if (strcmp(algorithm, "md5") == 0)
 	{
@@ -1307,6 +1315,9 @@ PQencryptPasswordConn(PGconn *conn, const char *passwd, const char *user,
 				crypt_pwd = NULL;
 			}
 		}
+		else
+			appendPQExpBufferStr(&conn->errorMessage,
+								 libpq_gettext("out of memory\n"));
 	}
 	else
 	{
@@ -1316,9 +1327,5 @@ PQencryptPasswordConn(PGconn *conn, const char *passwd, const char *user,
 		return NULL;
 	}
 
-	if (!crypt_pwd)
-		appendPQExpBufferStr(&conn->errorMessage,
-							 libpq_gettext("out of memory\n"));
-
 	return crypt_pwd;
 }
diff --git a/src/interfaces/libpq/fe-auth.h b/src/interfaces/libpq/fe-auth.h
index 16d5e1da0f..4cd0f82de1 100644
--- a/src/interfaces/libpq/fe-auth.h
+++ b/src/interfaces/libpq/fe-auth.h
@@ -24,6 +24,7 @@ extern char *pg_fe_getauthname(PQExpBuffer errorMessage);
 
 /* Mechanisms in fe-auth-scram.c */
 extern const pg_fe_sasl_mech pg_scram_mech;
-extern char *pg_fe_scram_build_secret(const char *password);
+extern char *pg_fe_scram_build_secret(const char *password,
+									  const char **errstr);
 
 #endif							/* FE_AUTH_H */
-- 
2.34.1

Attachment: signature.asc
Description: PGP signature

Reply via email to