From 56245d7acd25d1c75a3ae2196271bc7932c4dadd Mon Sep 17 00:00:00 2001
From: Ajit Awekar <ajit.awekar@enterprisedb.com>
Date: Tue, 16 Jun 2026 14:13:49 +0530
Subject: [PATCH 3/3] Add OAuth token expiry to credential validation

Register a method-specific validator for OAuth-authenticated sessions
(CVT_OAUTH) that asks the loaded OAuth validator module whether the bearer
token has expired.

This requires a new optional callback, expire_cb, in OAuthValidatorCallbacks.
To add it without breaking existing modules, the validator ABI magic is
versioned: PG_OAUTH_VALIDATOR_MAGIC_V1 (the original layout, without
expire_cb) and PG_OAUTH_VALIDATOR_MAGIC_V2 (which adds expire_cb).  The server
accepts both; expire_cb is only consulted for V2 modules that provide it, and
a token is assumed still valid otherwise.
---
 doc/src/sgml/oauth-validators.sgml            | 32 +++++++++++++++
 src/backend/libpq/auth-oauth.c                | 41 ++++++++++++++++---
 src/backend/libpq/auth-validate-methods.c     | 21 ++++++++++
 src/include/libpq/oauth.h                     | 10 ++++-
 .../modules/oauth_validator/fail_validator.c  |  1 +
 .../modules/oauth_validator/magic_validator.c |  1 +
 src/test/modules/oauth_validator/validator.c  |  3 +-
 7 files changed, 102 insertions(+), 7 deletions(-)

diff --git a/doc/src/sgml/oauth-validators.sgml b/doc/src/sgml/oauth-validators.sgml
index d69b6cf98ad..723ba8d8b1b 100644
--- a/doc/src/sgml/oauth-validators.sgml
+++ b/doc/src/sgml/oauth-validators.sgml
@@ -326,6 +326,7 @@ typedef struct OAuthValidatorCallbacks
     ValidatorStartupCB startup_cb;
     ValidatorShutdownCB shutdown_cb;
     ValidatorValidateCB validate_cb;
+    ValidatorExpireCB expire_cb;    /* Optional: check token expiration */
 } OAuthValidatorCallbacks;
 
 typedef const OAuthValidatorCallbacks *(*OAuthValidatorModuleInit) (void);
@@ -334,6 +335,15 @@ typedef const OAuthValidatorCallbacks *(*OAuthValidatorModuleInit) (void);
    Only the <function>validate_cb</function> callback is required, the others
    are optional.
   </para>
+  <para>
+   The <literal>magic</literal> field identifies the ABI version of the module.
+   The server supports both <literal>PG_OAUTH_VALIDATOR_MAGIC_V1</literal> (the
+   original version without <function>expire_cb</function>) and
+   <literal>PG_OAUTH_VALIDATOR_MAGIC_V2</literal> (which adds
+   <function>expire_cb</function>).  New modules should use
+   <literal>PG_OAUTH_VALIDATOR_MAGIC</literal>, which always refers to the
+   latest version.
+  </para>
  </sect1>
 
  <sect1 id="oauth-validator-callbacks">
@@ -443,6 +453,28 @@ typedef void (*ValidatorShutdownCB) (ValidatorModuleState *state);
    </para>
   </sect2>
 
+  <sect2 id="oauth-validator-callback-expire">
+   <title>Expire Callback</title>
+   <para>
+    The <function>expire_cb</function> callback is an optional callback that
+    can be used to check whether the OAuth token has expired. This is called
+    during credential validation to verify that the token is still valid.
+<programlisting>
+typedef bool (*ValidatorExpireCB) (const ValidatorModuleState *state);
+</programlisting>
+    The callback should return <literal>true</literal> if the token is still
+    valid, or <literal>false</literal> if the token has expired. If this
+    callback is not provided (set to NULL), the server assumes the token
+    remains valid.
+   </para>
+   <para>
+    This callback was added in <literal>PG_OAUTH_VALIDATOR_MAGIC_V2</literal>.
+    Modules compiled against the older <literal>PG_OAUTH_VALIDATOR_MAGIC_V1</literal>
+    do not have this field, and the server will not attempt to call it for
+    such modules.
+   </para>
+  </sect2>
+
  </sect1>
 
  <sect1 id="oauth-validator-hba">
diff --git a/src/backend/libpq/auth-oauth.c b/src/backend/libpq/auth-oauth.c
index b769931ca4f..0b77b515df8 100644
--- a/src/backend/libpq/auth-oauth.c
+++ b/src/backend/libpq/auth-oauth.c
@@ -45,6 +45,7 @@ static bool check_validator_hba_options(Port *port, const char **logdetail);
 
 static ValidatorModuleState *validator_module_state;
 static const OAuthValidatorCallbacks *ValidatorCallbacks;
+static int ValidatorABIVersion;		/* tracks V1 vs V2 module ABI */
 
 static MemoryContext ValidatorMemoryContext;
 static List *ValidatorOptions;
@@ -801,13 +802,22 @@ load_validator_library(const char *libname)
 	 * Check the magic number, to protect against break-glass scenarios where
 	 * the ABI must change within a major version. load_external_function()
 	 * already checks for compatibility across major versions.
+	 *
+	 * We accept both V1 and V2 magic numbers for backward compatibility.
+	 * V1 modules don't have the expire_cb field, so we track the version
+	 * to avoid accessing non-existent struct members.
 	 */
-	if (ValidatorCallbacks->magic != PG_OAUTH_VALIDATOR_MAGIC)
+	if (ValidatorCallbacks->magic == PG_OAUTH_VALIDATOR_MAGIC_V2)
+		ValidatorABIVersion = 2;
+	else if (ValidatorCallbacks->magic == PG_OAUTH_VALIDATOR_MAGIC_V1)
+		ValidatorABIVersion = 1;
+	else
 		ereport(ERROR,
-				errmsg("OAuth validator module \"%s\": magic number mismatch",
-					   libname),
-				errdetail("Server has magic number 0x%08X, module has 0x%08X.",
-						  PG_OAUTH_VALIDATOR_MAGIC, ValidatorCallbacks->magic));
+				errmsg("%s module \"%s\": magic number mismatch",
+					   "OAuth validator", libname),
+				errdetail("Server expects magic number 0x%08X or 0x%08X, module has 0x%08X.",
+						  PG_OAUTH_VALIDATOR_MAGIC_V2, PG_OAUTH_VALIDATOR_MAGIC_V1,
+						  ValidatorCallbacks->magic));
 
 	/*
 	 * Make sure all required callbacks are present in the ValidatorCallbacks
@@ -1134,3 +1144,24 @@ GetOAuthHBAOption(const ValidatorModuleState *state, const char *optname)
 
 	return ret;
 }
+
+/*
+ * Check if an OAuth token has expired.
+ * This is called from credential validation to check token validity.
+ */
+bool
+CheckOAuthValidatorExpiration(void)
+{
+	/*
+	 * Delegate to validator's expire_cb if available.  Only V2+ modules have
+	 * the expire_cb field, so we must check the ABI version before accessing
+	 * it to maintain backward compatibility with V1 modules.
+	 */
+	if (ValidatorCallbacks != NULL &&
+		ValidatorABIVersion >= 2 &&
+		ValidatorCallbacks->expire_cb != NULL)
+		return ValidatorCallbacks->expire_cb(validator_module_state);
+
+	/* V1 module or no expire_cb, assume token is valid */
+	return true;
+}
diff --git a/src/backend/libpq/auth-validate-methods.c b/src/backend/libpq/auth-validate-methods.c
index 180d37263ad..0b74057fffd 100644
--- a/src/backend/libpq/auth-validate-methods.c
+++ b/src/backend/libpq/auth-validate-methods.c
@@ -23,11 +23,13 @@
 #include "libpq/auth-validate-methods.h"
 #include "libpq/auth-validate.h"
 #include "libpq/libpq-be.h"
+#include "libpq/oauth.h"
 #include "miscadmin.h"
 #include "utils/syscache.h"
 #include "utils/timestamp.h"
 
 /* Function declarations for internal use */
+static bool validate_oauth_credentials(void);
 static bool validate_cert_credentials(void);
 
 /*
@@ -42,6 +44,7 @@ InitializeValidationMethods(void)
 	 * session by ValidateRoleValidity(), so password-based methods need no
 	 * separate validator of their own.
 	 */
+	RegisterCredentialValidator(CVT_OAUTH, validate_oauth_credentials);
 	RegisterCredentialValidator(CVT_CERT, validate_cert_credentials);
 }
 
@@ -85,6 +88,24 @@ ValidateRoleValidity(void)
 	return result;
 }
 
+/*
+ * Check if an OAuth token has expired.
+ *
+ * Returns true if the token is still valid, false if it has expired.
+ *
+ * Calls wrapper CheckOAuthValidatorExpiration() function
+ * to verify that the token hasn't expired.
+ */
+static bool
+validate_oauth_credentials(void)
+{
+	/* Call the validator's expire_cb to check token expiration */
+	if (!CheckOAuthValidatorExpiration())
+		return false;
+
+	return true;
+}
+
 /*
  * Validate TLS client certificate credentials.
  *
diff --git a/src/include/libpq/oauth.h b/src/include/libpq/oauth.h
index 86f463a284e..bc483148ed9 100644
--- a/src/include/libpq/oauth.h
+++ b/src/include/libpq/oauth.h
@@ -78,6 +78,7 @@ typedef void (*ValidatorShutdownCB) (ValidatorModuleState *state);
 typedef bool (*ValidatorValidateCB) (const ValidatorModuleState *state,
 									 const char *token, const char *role,
 									 ValidatorModuleResult *result);
+typedef bool (*ValidatorExpireCB) (const ValidatorModuleState *state);
 
 /*
  * Identifies the compiled ABI version of the validator module. Since the server
@@ -85,7 +86,9 @@ typedef bool (*ValidatorValidateCB) (const ValidatorModuleState *state,
  * versions, this is reserved for emergency use within a stable release line.
  * May it never need to change.
  */
-#define PG_OAUTH_VALIDATOR_MAGIC 0x20250220
+#define PG_OAUTH_VALIDATOR_MAGIC_V1 0x20250220
+#define PG_OAUTH_VALIDATOR_MAGIC_V2 0x20260326
+#define PG_OAUTH_VALIDATOR_MAGIC PG_OAUTH_VALIDATOR_MAGIC_V2
 
 typedef struct OAuthValidatorCallbacks
 {
@@ -94,6 +97,7 @@ typedef struct OAuthValidatorCallbacks
 	ValidatorStartupCB startup_cb;
 	ValidatorShutdownCB shutdown_cb;
 	ValidatorValidateCB validate_cb;
+	ValidatorExpireCB expire_cb;  /* Optional: Check token expiration */
 } OAuthValidatorCallbacks;
 
 /*
@@ -121,4 +125,8 @@ extern PGDLLIMPORT const pg_be_sasl_mech pg_be_oauth_mech;
 extern bool check_oauth_validator(HbaLine *hbaline, int elevel, char **err_msg);
 extern bool valid_oauth_hba_option_name(const char *name);
 
+/*
+ * Check OAuth token expiration using validator's expire_cb if available.
+ */
+extern bool CheckOAuthValidatorExpiration(void);
 #endif							/* PG_OAUTH_H */
diff --git a/src/test/modules/oauth_validator/fail_validator.c b/src/test/modules/oauth_validator/fail_validator.c
index 3de0470a541..8754e1e8f85 100644
--- a/src/test/modules/oauth_validator/fail_validator.c
+++ b/src/test/modules/oauth_validator/fail_validator.c
@@ -29,6 +29,7 @@ static const OAuthValidatorCallbacks validator_callbacks = {
 	PG_OAUTH_VALIDATOR_MAGIC,
 
 	.validate_cb = fail_token,
+	.expire_cb = NULL,
 };
 
 const OAuthValidatorCallbacks *
diff --git a/src/test/modules/oauth_validator/magic_validator.c b/src/test/modules/oauth_validator/magic_validator.c
index 550da41d11b..6e4d72fde30 100644
--- a/src/test/modules/oauth_validator/magic_validator.c
+++ b/src/test/modules/oauth_validator/magic_validator.c
@@ -30,6 +30,7 @@ static const OAuthValidatorCallbacks validator_callbacks = {
 	0xdeadbeef,
 
 	.validate_cb = validate_token,
+	.expire_cb = NULL,
 };
 
 const OAuthValidatorCallbacks *
diff --git a/src/test/modules/oauth_validator/validator.c b/src/test/modules/oauth_validator/validator.c
index 85fb4c08bf2..7a1625cef8a 100644
--- a/src/test/modules/oauth_validator/validator.c
+++ b/src/test/modules/oauth_validator/validator.c
@@ -34,7 +34,8 @@ static const OAuthValidatorCallbacks validator_callbacks = {
 
 	.startup_cb = validator_startup,
 	.shutdown_cb = validator_shutdown,
-	.validate_cb = validate_token
+	.validate_cb = validate_token,
+	.expire_cb = NULL,			/* Optional: not implemented */
 };
 
 /* GUCs */
-- 
2.52.0

