From 413a66bf030365ac192f566f5e9ef81c3cf80cf2 Mon Sep 17 00:00:00 2001
From: Samay Sharma <smilingsamay@gmail.com>
Date: Mon, 14 Mar 2022 14:54:08 -0700
Subject: [PATCH v3 4/4] Add support for "map" and custom auth options

This commit allows extensions to now specify, validate and use
custom options for their custom auth methods. This is done by
exposing a validation function hook which can be defined by
extensions. The valid options are then stored as key / value
pairs which can be used while checking authentication. We also
allow custom auth providers to use the "map" option to use
usermaps.

The test module was updated to use custom options and new tests
were added.
---
 src/backend/libpq/auth.c                      |  4 +-
 src/backend/libpq/hba.c                       | 76 +++++++++++++++----
 src/include/libpq/auth.h                      | 17 +++--
 src/include/libpq/hba.h                       |  8 ++
 .../test_auth_provider/t/001_custom_auth.pl   | 22 ++++++
 .../test_auth_provider/test_auth_provider.c   | 50 +++++++++++-
 6 files changed, 157 insertions(+), 20 deletions(-)

diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index 375ee33892..4a8a63922a 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -3364,7 +3364,8 @@ PerformRadiusTransaction(const char *server, const char *secret, const char *por
  */
 void RegisterAuthProvider(const char *provider_name,
 		CustomAuthenticationCheck_hook_type AuthenticationCheckFunction,
-		CustomAuthenticationError_hook_type AuthenticationErrorFunction)
+		CustomAuthenticationError_hook_type AuthenticationErrorFunction,
+		CustomAuthenticationValidateOptions_hook_type AuthenticationOptionsFunction)
 {
 	CustomAuthProvider *provider = NULL;
 	MemoryContext old_context;
@@ -3390,6 +3391,7 @@ void RegisterAuthProvider(const char *provider_name,
 	provider->name = MemoryContextStrdup(TopMemoryContext,provider_name);
 	provider->auth_check_hook = AuthenticationCheckFunction;
 	provider->auth_error_hook = AuthenticationErrorFunction;
+	provider->auth_options_hook = AuthenticationOptionsFunction;
 	custom_auth_providers = lappend(custom_auth_providers, provider);
 	MemoryContextSwitchTo(old_context);
 }
diff --git a/src/backend/libpq/hba.c b/src/backend/libpq/hba.c
index 9f15252789..42cb1ce51d 100644
--- a/src/backend/libpq/hba.c
+++ b/src/backend/libpq/hba.c
@@ -1729,8 +1729,9 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
 			hbaline->auth_method != uaPeer &&
 			hbaline->auth_method != uaGSS &&
 			hbaline->auth_method != uaSSPI &&
-			hbaline->auth_method != uaCert)
-			INVALID_AUTH_OPTION("map", gettext_noop("ident, peer, gssapi, sspi, and cert"));
+			hbaline->auth_method != uaCert &&
+			hbaline->auth_method != uaCustom)
+			INVALID_AUTH_OPTION("map", gettext_noop("ident, peer, gssapi, sspi, cert and custom"));
 		hbaline->usermap = pstrdup(val);
 	}
 	else if (strcmp(name, "clientcert") == 0)
@@ -2121,7 +2122,6 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
 		/*
 		 * Verify that the provider mentioned is loaded via shared_preload_libraries.
 		 */
-
 		if (get_provider_by_name(val) == NULL)
 		{
 			ereport(elevel,
@@ -2129,7 +2129,7 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
 					 errmsg("cannot use authentication provider %s",val),
 					 errhint("Load authentication provider via shared_preload_libraries."),
 					 errcontext("line %d of configuration file \"%s\"",
-							line_num, HbaFileName)));
+								line_num, HbaFileName)));
 			*err_msg = psprintf("cannot use authentication provider %s", val);
 
 			return false;
@@ -2141,15 +2141,55 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
 	}
 	else
 	{
-		ereport(elevel,
-				(errcode(ERRCODE_CONFIG_FILE_ERROR),
-				 errmsg("unrecognized authentication option name: \"%s\"",
-						name),
-				 errcontext("line %d of configuration file \"%s\"",
-							line_num, HbaFileName)));
-		*err_msg = psprintf("unrecognized authentication option name: \"%s\"",
-							name);
-		return false;
+		/*
+		 * Allow custom providers to validate their options if they have an
+		 * option validation function defined.
+		 */
+		if (hbaline->auth_method == uaCustom && (hbaline->custom_provider != NULL))
+		{
+			bool valid_option = false;
+			CustomAuthProvider *provider = get_provider_by_name(hbaline->custom_provider);
+			if (provider->auth_options_hook)
+			{
+				valid_option = provider->auth_options_hook(name, val, hbaline, err_msg);
+				if (valid_option)
+				{
+					CustomOption *option = palloc(sizeof(CustomOption));
+					option->name = pstrdup(name);
+					option->value = pstrdup(val);
+					hbaline->custom_auth_options = lappend(hbaline->custom_auth_options,
+														   option);
+				}
+			}
+			else
+			{
+				*err_msg = psprintf("unrecognized authentication option name: \"%s\"",
+									name);
+			}
+
+			/* Report the error returned by the provider as it is */
+			if (!valid_option)
+			{
+				ereport(elevel,
+						(errcode(ERRCODE_CONFIG_FILE_ERROR),
+						 errmsg("%s", *err_msg),
+						 errcontext("line %d of configuration file \"%s\"",
+									line_num, HbaFileName)));
+				return false;
+			}
+		}
+		else
+		{
+			ereport(elevel,
+					(errcode(ERRCODE_CONFIG_FILE_ERROR),
+					 errmsg("unrecognized authentication option name: \"%s\"",
+							name),
+					 errcontext("line %d of configuration file \"%s\"",
+								line_num, HbaFileName)));
+			*err_msg = psprintf("unrecognized authentication option name: \"%s\"",
+								name);
+			return false;
+		}
 	}
 	return true;
 }
@@ -2484,6 +2524,16 @@ gethba_options(HbaLine *hba)
 		if (hba->custom_provider)
 			options[noptions++] =
 				CStringGetTextDatum(psprintf("provider=%s",hba->custom_provider));
+		if (hba->custom_auth_options)
+		{
+			ListCell *lc;
+			foreach(lc, hba->custom_auth_options)
+			{
+				CustomOption *option = (CustomOption *)lfirst(lc);
+				options[noptions++] =
+					CStringGetTextDatum(psprintf("%s=%s",option->name, option->value));
+			}
+		}
 	}
 
 	/* If you add more options, consider increasing MAX_HBA_OPTIONS. */
diff --git a/src/include/libpq/auth.h b/src/include/libpq/auth.h
index 7aff98d919..cbdc63b4df 100644
--- a/src/include/libpq/auth.h
+++ b/src/include/libpq/auth.h
@@ -31,22 +31,29 @@ typedef int (*CustomAuthenticationCheck_hook_type) (Port *);
 typedef void (*ClientAuthentication_hook_type) (Port *, int);
 extern PGDLLIMPORT ClientAuthentication_hook_type ClientAuthentication_hook;
 
+/* Declarations for custom authentication providers */
+
 /* Hook for plugins to report error messages in auth_failed() */
 typedef const char * (*CustomAuthenticationError_hook_type) (Port *);
 
-extern void RegisterAuthProvider
-		(const char *provider_name,
-		 CustomAuthenticationCheck_hook_type CustomAuthenticationCheck_hook,
-		 CustomAuthenticationError_hook_type CustomAuthenticationError_hook);
+/* Hook for plugins to validate custom authentication options */
+typedef bool (*CustomAuthenticationValidateOptions_hook_type)
+			 (char *, char *, HbaLine *, char **);
 
-/* Declarations for custom authentication providers */
 typedef struct CustomAuthProvider
 {
 	const char *name;
 	CustomAuthenticationCheck_hook_type auth_check_hook;
 	CustomAuthenticationError_hook_type auth_error_hook;
+	CustomAuthenticationValidateOptions_hook_type auth_options_hook;
 } CustomAuthProvider;
 
+extern void RegisterAuthProvider
+		(const char *provider_name,
+		 CustomAuthenticationCheck_hook_type CustomAuthenticationCheck_hook,
+		 CustomAuthenticationError_hook_type CustomAuthenticationError_hook,
+		 CustomAuthenticationValidateOptions_hook_type CustomAuthenticationOptions_hook);
+
 extern CustomAuthProvider *get_provider_by_name(const char *name);
 
 /*
diff --git a/src/include/libpq/hba.h b/src/include/libpq/hba.h
index 48490c44ed..31a00c4b71 100644
--- a/src/include/libpq/hba.h
+++ b/src/include/libpq/hba.h
@@ -78,6 +78,13 @@ typedef enum ClientCertName
 	clientCertDN
 } ClientCertName;
 
+/* Struct for custom options defined by custom auth plugins */
+typedef struct CustomOption
+{
+	char	*name;
+	char	*value;
+}CustomOption;
+
 typedef struct HbaLine
 {
 	int			linenumber;
@@ -122,6 +129,7 @@ typedef struct HbaLine
 	List	   *radiusports;
 	char	   *radiusports_s;
 	char	   *custom_provider;
+	List	   *custom_auth_options;
 } HbaLine;
 
 typedef struct IdentLine
diff --git a/src/test/modules/test_auth_provider/t/001_custom_auth.pl b/src/test/modules/test_auth_provider/t/001_custom_auth.pl
index 3b7472dc7f..e964c2f723 100644
--- a/src/test/modules/test_auth_provider/t/001_custom_auth.pl
+++ b/src/test/modules/test_auth_provider/t/001_custom_auth.pl
@@ -109,6 +109,28 @@ test_hba_reload($node, 'custom', 1);
 # Test that correct provider name allows reload to succeed.
 test_hba_reload($node, 'custom provider=test', 0);
 
+# Tests for custom auth options
+
+# Test that a custom option doesn't work without a provider.
+test_hba_reload($node, 'custom allow=bob', 1);
+
+# Test that options other than allowed ones are not accepted.
+test_hba_reload($node, 'custom provider=test wrong=true', 1);
+
+# Test that only valid values are accepted for allowed options.
+test_hba_reload($node, 'custom provider=test allow=wrong', 1);
+
+# Test that setting allow option for a user doesn't look at the password.
+test_hba_reload($node, 'custom provider=test allow=bob', 0);
+$ENV{"PGPASSWORD"} = 'bad123';
+test_role($node, 'bob', 'custom', 0, log_like => [qr/connection authorized: user=bob/]);
+
+# Password is still checked for other users.
+test_role($node, 'alice', 'custom', 2, log_unlike => [qr/connection authorized:/]);
+
+# Reset the password for future tests.
+$ENV{"PGPASSWORD"} = 'bob123';
+
 # Custom auth modules require mentioning extension in shared_preload_libraries.
 
 # Remove extension from shared_preload_libraries and try to restart.
diff --git a/src/test/modules/test_auth_provider/test_auth_provider.c b/src/test/modules/test_auth_provider/test_auth_provider.c
index 7c4b1f3500..5ac425f5b6 100644
--- a/src/test/modules/test_auth_provider/test_auth_provider.c
+++ b/src/test/modules/test_auth_provider/test_auth_provider.c
@@ -39,7 +39,27 @@ static int TestAuthenticationCheck(Port *port)
 	int result = STATUS_ERROR;
 	char *real_pass;
 	const char *logdetail = NULL;
+	ListCell *lc;
 
+	/*
+	 * If user's name is in the the "allow" list, do not request password
+	 * for them and allow them to authenticate.
+	 */
+	foreach(lc,port->hba->custom_auth_options)
+	{
+		CustomOption *option = (CustomOption *) lfirst(lc);
+		if (strcmp(option->name, "allow") == 0 &&
+			strcmp(option->value, port->user_name) == 0)
+		{
+			set_authn_id(port, port->user_name);
+			return STATUS_OK;
+		}
+	}
+
+	/*
+	 * Encrypt the password and validate that it's the same as the one
+	 * returned by the client.
+	 */
 	real_pass = get_encrypted_password_for_user(port->user_name);
 	if (real_pass)
 	{
@@ -79,8 +99,36 @@ static const char *TestAuthenticationError(Port *port)
 	return error_message;
 }
 
+/*
+ * Returns if the options passed are supported by the extension
+ * and are valid. Currently only "allow" is supported.
+ */
+static bool TestAuthenticationOptions(char *name, char *val, HbaLine *hbaline, char **err_msg)
+{
+	/* Validate that an actual user is in the "allow" list. */
+	if (strcmp(name,"allow") == 0)
+	{
+		for (int i=0;i<3;i++)
+		{
+			if (strcmp(val,credentials[0][i]) == 0)
+			{
+				return true;
+			}
+		}
+
+		*err_msg = psprintf("\"%s\" is not valid value for option \"%s\"", val, name);
+		return false;
+	}
+	else
+	{
+		*err_msg = psprintf("option \"%s\" not recognized by \"%s\" provider", val, hbaline->custom_provider);
+		return false;
+	}
+}
+
 void
 _PG_init(void)
 {
-	RegisterAuthProvider("test", TestAuthenticationCheck, TestAuthenticationError);
+	RegisterAuthProvider("test", TestAuthenticationCheck,
+						 TestAuthenticationError,TestAuthenticationOptions);
 }
-- 
2.34.1

