From 11f190270ec3cf8c51e58bb02de671c7a9d966e2 Mon Sep 17 00:00:00 2001
From: Samay Sharma <smilingsamay@gmail.com>
Date: Tue, 15 Feb 2022 22:23:29 -0800
Subject: [PATCH v3 1/4] Add support for custom authentication methods

Currently, PostgreSQL supports only a set of pre-defined authentication
methods. This patch adds support for 2 hooks which allow users to add
their custom authentication methods by defining a check function and an
error function. Users can then use these methods by using a new "custom"
keyword in pg_hba.conf and specifying the authentication provider they
want to use.
---
 src/backend/libpq/auth.c | 108 ++++++++++++++++++++++++++++++++-------
 src/backend/libpq/hba.c  |  44 ++++++++++++++++
 src/include/libpq/auth.h |  37 ++++++++++++++
 src/include/libpq/hba.h  |   2 +
 4 files changed, 172 insertions(+), 19 deletions(-)

diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index efc53f3135..375ee33892 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -47,8 +47,6 @@
  *----------------------------------------------------------------
  */
 static void auth_failed(Port *port, int status, const char *logdetail);
-static char *recv_password_packet(Port *port);
-static void set_authn_id(Port *port, const char *id);
 
 
 /*----------------------------------------------------------------
@@ -206,22 +204,11 @@ static int	pg_SSPI_make_upn(char *accountname,
 static int	CheckRADIUSAuth(Port *port);
 static int	PerformRadiusTransaction(const char *server, const char *secret, const char *portstr, const char *identifier, const char *user_name, const char *passwd);
 
-
-/*
- * Maximum accepted size of GSS and SSPI authentication tokens.
- * We also use this as a limit on ordinary password packet lengths.
- *
- * Kerberos tickets are usually quite small, but the TGTs issued by Windows
- * domain controllers include an authorization field known as the Privilege
- * Attribute Certificate (PAC), which contains the user's Windows permissions
- * (group memberships etc.). The PAC is copied into all tickets obtained on
- * the basis of this TGT (even those issued by Unix realms which the Windows
- * realm trusts), and can be several kB in size. The maximum token size
- * accepted by Windows systems is determined by the MaxAuthToken Windows
- * registry setting. Microsoft recommends that it is not set higher than
- * 65535 bytes, so that seems like a reasonable limit for us as well.
+/*----------------------------------------------------------------
+ * Custom Authentication
+ *----------------------------------------------------------------
  */
-#define PG_MAX_AUTH_TOKEN_LENGTH	65535
+static List *custom_auth_providers = NIL;
 
 /*----------------------------------------------------------------
  * Global authentication functions
@@ -311,6 +298,15 @@ auth_failed(Port *port, int status, const char *logdetail)
 		case uaRADIUS:
 			errstr = gettext_noop("RADIUS authentication failed for user \"%s\"");
 			break;
+		case uaCustom:
+			{
+				CustomAuthProvider *provider = get_provider_by_name(port->hba->custom_provider);
+				if (provider->auth_error_hook)
+					errstr = provider->auth_error_hook(port);
+				else
+					errstr = gettext_noop("Custom authentication failed for user \"%s\"");
+				break;
+			}
 		default:
 			errstr = gettext_noop("authentication failed for user \"%s\": invalid authentication method");
 			break;
@@ -345,7 +341,7 @@ auth_failed(Port *port, int status, const char *logdetail)
  * lifetime of the Port, so it is safe to pass a string that is managed by an
  * external library.
  */
-static void
+void
 set_authn_id(Port *port, const char *id)
 {
 	Assert(id);
@@ -630,6 +626,13 @@ ClientAuthentication(Port *port)
 		case uaTrust:
 			status = STATUS_OK;
 			break;
+		case uaCustom:
+			{
+				CustomAuthProvider *provider = get_provider_by_name(port->hba->custom_provider);
+				if (provider->auth_check_hook)
+					status = provider->auth_check_hook(port);
+				break;
+			}
 	}
 
 	if ((status == STATUS_OK && port->hba->clientcert == clientCertFull)
@@ -689,7 +692,7 @@ sendAuthRequest(Port *port, AuthRequest areq, const char *extradata, int extrale
  *
  * Returns NULL if couldn't get password, else palloc'd string.
  */
-static char *
+char *
 recv_password_packet(Port *port)
 {
 	StringInfoData buf;
@@ -3343,3 +3346,70 @@ PerformRadiusTransaction(const char *server, const char *secret, const char *por
 		}
 	}							/* while (true) */
 }
+
+/*----------------------------------------------------------------
+ * Custom authentication
+ *----------------------------------------------------------------
+ */
+
+/*
+ * RegisterAuthProvider registers a custom authentication provider to be
+ * used for authentication. It validates the inputs and adds the provider
+ * name and it's hooks to a list of loaded providers. The right provider's
+ * hooks can then be called based on the provider name specified in
+ * pg_hba.conf.
+ *
+ * This function should be called in _PG_init() by any extension looking to
+ * add a custom authentication method.
+ */
+void RegisterAuthProvider(const char *provider_name,
+		CustomAuthenticationCheck_hook_type AuthenticationCheckFunction,
+		CustomAuthenticationError_hook_type AuthenticationErrorFunction)
+{
+	CustomAuthProvider *provider = NULL;
+	MemoryContext old_context;
+
+	if (provider_name == NULL)
+	{
+		ereport(ERROR,
+				(errmsg("cannot register authentication provider without name")));
+	}
+
+	if (AuthenticationCheckFunction == NULL)
+	{
+		ereport(ERROR,
+				(errmsg("cannot register authentication provider without a check function")));
+	}
+
+	/*
+	 * Allocate in top memory context as we need to read this whenever
+	 * we parse pg_hba.conf
+	 */
+	old_context = MemoryContextSwitchTo(TopMemoryContext);
+	provider = palloc(sizeof(CustomAuthProvider));
+	provider->name = MemoryContextStrdup(TopMemoryContext,provider_name);
+	provider->auth_check_hook = AuthenticationCheckFunction;
+	provider->auth_error_hook = AuthenticationErrorFunction;
+	custom_auth_providers = lappend(custom_auth_providers, provider);
+	MemoryContextSwitchTo(old_context);
+}
+
+/*
+ * Returns the authentication provider (which includes it's
+ * callback functions) based on name specified.
+ */
+CustomAuthProvider *get_provider_by_name(const char *name)
+{
+	ListCell *lc;
+
+	foreach(lc, custom_auth_providers)
+	{
+		CustomAuthProvider *provider = (CustomAuthProvider *) lfirst(lc);
+		if (strcmp(provider->name,name) == 0)
+		{
+			return provider;
+		}
+	}
+
+	return NULL;
+}
diff --git a/src/backend/libpq/hba.c b/src/backend/libpq/hba.c
index 90953c38f3..9f15252789 100644
--- a/src/backend/libpq/hba.c
+++ b/src/backend/libpq/hba.c
@@ -31,6 +31,7 @@
 #include "common/ip.h"
 #include "common/string.h"
 #include "funcapi.h"
+#include "libpq/auth.h"
 #include "libpq/ifaddr.h"
 #include "libpq/libpq.h"
 #include "miscadmin.h"
@@ -134,6 +135,7 @@ static const char *const UserAuthName[] =
 	"ldap",
 	"cert",
 	"radius",
+	"custom",
 	"peer"
 };
 
@@ -1399,6 +1401,8 @@ parse_hba_line(TokenizedLine *tok_line, int elevel)
 #endif
 	else if (strcmp(token->string, "radius") == 0)
 		parsedline->auth_method = uaRADIUS;
+	else if (strcmp(token->string, "custom") == 0)
+		parsedline->auth_method = uaCustom;
 	else
 	{
 		ereport(elevel,
@@ -1691,6 +1695,14 @@ parse_hba_line(TokenizedLine *tok_line, int elevel)
 		parsedline->clientcert = clientCertFull;
 	}
 
+	/*
+	 * Ensure that the provider name is specified for custom authentication method.
+	 */
+	if (parsedline->auth_method == uaCustom)
+	{
+		MANDATORY_AUTH_ARG(parsedline->custom_provider, "provider", "custom");
+	}
+
 	return parsedline;
 }
 
@@ -2102,6 +2114,31 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
 		hbaline->radiusidentifiers = parsed_identifiers;
 		hbaline->radiusidentifiers_s = pstrdup(val);
 	}
+	else if (strcmp(name, "provider") == 0)
+	{
+		REQUIRE_AUTH_OPTION(uaCustom, "provider", "custom");
+
+		/*
+		 * Verify that the provider mentioned is loaded via shared_preload_libraries.
+		 */
+
+		if (get_provider_by_name(val) == NULL)
+		{
+			ereport(elevel,
+					(errcode(ERRCODE_CONFIG_FILE_ERROR),
+					 errmsg("cannot use authentication provider %s",val),
+					 errhint("Load authentication provider via shared_preload_libraries."),
+					 errcontext("line %d of configuration file \"%s\"",
+							line_num, HbaFileName)));
+			*err_msg = psprintf("cannot use authentication provider %s", val);
+
+			return false;
+		}
+		else
+		{
+			hbaline->custom_provider = pstrdup(val);
+		}
+	}
 	else
 	{
 		ereport(elevel,
@@ -2442,6 +2479,13 @@ gethba_options(HbaLine *hba)
 				CStringGetTextDatum(psprintf("radiusports=%s", hba->radiusports_s));
 	}
 
+	if (hba->auth_method == uaCustom)
+	{
+		if (hba->custom_provider)
+			options[noptions++] =
+				CStringGetTextDatum(psprintf("provider=%s",hba->custom_provider));
+	}
+
 	/* If you add more options, consider increasing MAX_HBA_OPTIONS. */
 	Assert(noptions <= MAX_HBA_OPTIONS);
 
diff --git a/src/include/libpq/auth.h b/src/include/libpq/auth.h
index 6d7ee1acb9..7aff98d919 100644
--- a/src/include/libpq/auth.h
+++ b/src/include/libpq/auth.h
@@ -23,9 +23,46 @@ extern char *pg_krb_realm;
 extern void ClientAuthentication(Port *port);
 extern void sendAuthRequest(Port *port, AuthRequest areq, const char *extradata,
 							int extralen);
+extern void set_authn_id(Port *port, const char *id);
+extern char *recv_password_packet(Port *port);
 
 /* Hook for plugins to get control in ClientAuthentication() */
+typedef int (*CustomAuthenticationCheck_hook_type) (Port *);
 typedef void (*ClientAuthentication_hook_type) (Port *, int);
 extern PGDLLIMPORT ClientAuthentication_hook_type ClientAuthentication_hook;
 
+/* Hook for plugins to report error messages in auth_failed() */
+typedef const char * (*CustomAuthenticationError_hook_type) (Port *);
+
+extern void RegisterAuthProvider
+		(const char *provider_name,
+		 CustomAuthenticationCheck_hook_type CustomAuthenticationCheck_hook,
+		 CustomAuthenticationError_hook_type CustomAuthenticationError_hook);
+
+/* Declarations for custom authentication providers */
+typedef struct CustomAuthProvider
+{
+	const char *name;
+	CustomAuthenticationCheck_hook_type auth_check_hook;
+	CustomAuthenticationError_hook_type auth_error_hook;
+} CustomAuthProvider;
+
+extern CustomAuthProvider *get_provider_by_name(const char *name);
+
+/*
+ * Maximum accepted size of GSS and SSPI authentication tokens.
+ * We also use this as a limit on ordinary password packet lengths.
+ *
+ * Kerberos tickets are usually quite small, but the TGTs issued by Windows
+ * domain controllers include an authorization field known as the Privilege
+ * Attribute Certificate (PAC), which contains the user's Windows permissions
+ * (group memberships etc.). The PAC is copied into all tickets obtained on
+ * the basis of this TGT (even those issued by Unix realms which the Windows
+ * realm trusts), and can be several kB in size. The maximum token size
+ * accepted by Windows systems is determined by the MaxAuthToken Windows
+ * registry setting. Microsoft recommends that it is not set higher than
+ * 65535 bytes, so that seems like a reasonable limit for us as well.
+ */
+#define PG_MAX_AUTH_TOKEN_LENGTH	65535
+
 #endif							/* AUTH_H */
diff --git a/src/include/libpq/hba.h b/src/include/libpq/hba.h
index 8d9f3821b1..48490c44ed 100644
--- a/src/include/libpq/hba.h
+++ b/src/include/libpq/hba.h
@@ -38,6 +38,7 @@ typedef enum UserAuth
 	uaLDAP,
 	uaCert,
 	uaRADIUS,
+	uaCustom,
 	uaPeer
 #define USER_AUTH_LAST uaPeer	/* Must be last value of this enum */
 } UserAuth;
@@ -120,6 +121,7 @@ typedef struct HbaLine
 	char	   *radiusidentifiers_s;
 	List	   *radiusports;
 	char	   *radiusports_s;
+	char	   *custom_provider;
 } HbaLine;
 
 typedef struct IdentLine
-- 
2.34.1

