From 8875db29dc6e1504efbe14d7cdf010ec441b3c53 Mon Sep 17 00:00:00 2001
From: Russell Foster <russell.foster.coding@gmail.com>
Date: Tue, 13 Oct 2020 09:02:47 -0400
Subject: [PATCH] * Add support for Windows groups in SSPI authentication

---
 src/backend/libpq/auth.c |  85 ++++++++++++++++------------------
 src/backend/libpq/hba.c  | 118 +++++++++++++++++++++++++++++++++++++++++++++--
 src/include/libpq/hba.h  |   2 +-
 3 files changed, 156 insertions(+), 49 deletions(-)

diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index 36565df4fc..562c2ec278 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -1281,7 +1281,7 @@ pg_GSS_checkauth(Port *port)
 	}
 
 	ret = check_usermap(port->hba->usermap, port->user_name, gbuf.value,
-						pg_krb_caseins_users);
+						pg_krb_caseins_users, NULL);
 
 	gss_release_buffer(&lmin_s, &gbuf);
 
@@ -1321,6 +1321,7 @@ pg_SSPI_error(int severity, const char *errmsg, SECURITY_STATUS r)
 static int
 pg_SSPI_recvauth(Port *port)
 {
+	int			retval = STATUS_ERROR;
 	int			mtype;
 	StringInfoData buf;
 	SECURITY_STATUS r;
@@ -1398,7 +1399,7 @@ pg_SSPI_recvauth(Port *port)
 						(errcode(ERRCODE_PROTOCOL_VIOLATION),
 						 errmsg("expected SSPI response, got message type %d",
 								mtype)));
-			return STATUS_ERROR;
+			return retval;
 		}
 
 		/* Get the actual SSPI token */
@@ -1413,7 +1414,7 @@ pg_SSPI_recvauth(Port *port)
 				free(sspictx);
 			}
 			FreeCredentialsHandle(&sspicred);
-			return STATUS_ERROR;
+			return retval;
 		}
 
 		/* Map to SSPI style buffer */
@@ -1563,8 +1564,6 @@ pg_SSPI_recvauth(Port *port)
 				(errmsg_internal("could not get token information: error code %lu",
 								 GetLastError())));
 
-	CloseHandle(token);
-
 	if (!LookupAccountSid(NULL, tokenuser->User.Sid, accountname, &accountnamesize,
 						  domainname, &domainnamesize, &accountnameuse))
 		ereport(ERROR,
@@ -1573,51 +1572,47 @@ pg_SSPI_recvauth(Port *port)
 
 	free(tokenuser);
 
-	if (!port->hba->compat_realm)
-	{
-		int			status = pg_SSPI_make_upn(accountname, sizeof(accountname),
-											  domainname, sizeof(domainname),
-											  port->hba->upn_username);
-
-		if (status != STATUS_OK)
-			/* Error already reported from pg_SSPI_make_upn */
-			return status;
-	}
-
-	/*
-	 * Compare realm/domain if requested. In SSPI, always compare case
-	 * insensitive.
-	 */
-	if (port->hba->krb_realm && strlen(port->hba->krb_realm))
+	if (port->hba->compat_realm ||
+		(pg_SSPI_make_upn(accountname, sizeof(accountname), domainname,
+						  sizeof(domainname), port->hba->upn_username) == STATUS_OK))
 	{
-		if (pg_strcasecmp(port->hba->krb_realm, domainname) != 0)
+		/*
+		 * Compare realm/domain if requested. In SSPI, always compare case
+		 * insensitive.
+		 */
+		if (port->hba->krb_realm && strlen(port->hba->krb_realm) &&
+			(pg_strcasecmp(port->hba->krb_realm, domainname) != 0))
 		{
 			elog(DEBUG2,
-				 "SSPI domain (%s) and configured domain (%s) don't match",
-				 domainname, port->hba->krb_realm);
+					"SSPI domain (%s) and configured domain (%s) don't match",
+					domainname, port->hba->krb_realm);
+		}
+		else
+		{
+			/*
+			 * We have the username (without domain/realm) in accountname, compare to
+			 * the supplied value. In SSPI, always compare case insensitive.
+			 *
+			 * If set to include realm, append it in <username>@<realm> format.
+			 */
+			if (port->hba->include_realm)
+			{
+				char	   *namebuf;
 
-			return STATUS_ERROR;
+				namebuf = psprintf("%s@%s", accountname, domainname);
+				retval = check_usermap(port->hba->usermap, port->user_name, namebuf, true, token);
+				pfree(namebuf);
+			}
+			else
+			{
+				retval = check_usermap(port->hba->usermap, port->user_name, accountname, true, token);
+			}
 		}
 	}
 
-	/*
-	 * We have the username (without domain/realm) in accountname, compare to
-	 * the supplied value. In SSPI, always compare case insensitive.
-	 *
-	 * If set to include realm, append it in <username>@<realm> format.
-	 */
-	if (port->hba->include_realm)
-	{
-		char	   *namebuf;
-		int			retval;
+	CloseHandle(token);
 
-		namebuf = psprintf("%s@%s", accountname, domainname);
-		retval = check_usermap(port->hba->usermap, port->user_name, namebuf, true);
-		pfree(namebuf);
-		return retval;
-	}
-	else
-		return check_usermap(port->hba->usermap, port->user_name, accountname, true);
+	return retval;
 }
 
 /*
@@ -1972,7 +1967,7 @@ ident_inet_done:
 
 	if (ident_return)
 		/* Success! Check the usermap */
-		return check_usermap(port->hba->usermap, port->user_name, ident_user, false);
+		return check_usermap(port->hba->usermap, port->user_name, ident_user, false, NULL);
 	return STATUS_ERROR;
 }
 
@@ -2031,7 +2026,7 @@ auth_peer(hbaPort *port)
 	/* Make a copy of static getpw*() result area. */
 	peer_user = pstrdup(pw->pw_name);
 
-	ret = check_usermap(port->hba->usermap, port->user_name, peer_user, false);
+	ret = check_usermap(port->hba->usermap, port->user_name, peer_user, false, NULL);
 
 	pfree(peer_user);
 
@@ -2883,7 +2878,7 @@ CheckCertAuth(Port *port)
 	}
 
 	/* Just pass the certificate cn to the usermap check */
-	status_check_usermap = check_usermap(port->hba->usermap, port->user_name, port->peer_cn, false);
+	status_check_usermap = check_usermap(port->hba->usermap, port->user_name, port->peer_cn, false, NULL);
 	if (status_check_usermap != STATUS_OK)
 	{
 		/*
diff --git a/src/backend/libpq/hba.c b/src/backend/libpq/hba.c
index 4c86fb6087..b5a38cf26e 100644
--- a/src/backend/libpq/hba.c
+++ b/src/backend/libpq/hba.c
@@ -2797,6 +2797,101 @@ parse_ident_line(TokenizedLine *tok_line)
 	return parsedline;
 }
 
+#ifdef ENABLE_SSPI
+
+/*
+ * Get the sid for an account name
+ */
+static PSID
+lookup_account_name(LPCTSTR accountName, LPDWORD accountSidSize, LPDWORD domainNameCharCount, LPDWORD lastError)
+{
+	PSID			accountSid;
+	LPCTSTR			domainName;
+	SID_NAME_USE	sidType;
+	BOOL			lookupResult;
+
+	accountSid = malloc(*accountSidSize);
+
+	if (accountSid == NULL)
+	{
+		ereport(ERROR, (errcode(ERRCODE_OUT_OF_MEMORY), errmsg("out of memory")));
+	}
+
+	domainName = malloc(*domainNameCharCount * sizeof(TCHAR));
+
+	if (domainName == NULL)
+	{
+		ereport(ERROR, (errcode(ERRCODE_OUT_OF_MEMORY), errmsg("out of memory")));
+	}
+
+	lookupResult = LookupAccountName(NULL, accountName, accountSid, accountSidSize, domainName, domainNameCharCount, &sidType);
+	*lastError = GetLastError();
+
+	free(domainName);
+
+	if (!lookupResult && (accountSid != NULL))
+	{
+		free(accountSid);
+
+		accountSid = NULL;
+	}
+
+	return accountSid;
+}
+
+/*
+ * Check if the user (sspiToken) is a member of the specified group
+ */
+static bool
+sspi_user_is_in_group(HANDLE sspiToken, LPCTSTR groupName)
+{
+	BOOL		isMember = FALSE;
+	DWORD		groupSidSize = 1024;
+	DWORD		domainNameCharCount = 1024;
+	PSID		groupSid;
+	DWORD		lastError;
+
+	// try a default buffer size. if this doesn't work, use the returned buffer size
+	groupSid = lookup_account_name(groupName, &groupSidSize, &domainNameCharCount, &lastError);
+
+	if (groupSid == NULL)
+	{
+		if (lastError == 122)
+		{
+			elog(DEBUG2, "larger buffer required to get sid, groupSidSize=%lu, domainNameCharCount=%lu", groupSidSize, domainNameCharCount);
+
+			groupSid = lookup_account_name(groupName, &groupSidSize, &domainNameCharCount, &lastError);
+
+			if (groupSid == NULL)
+			{
+				elog(DEBUG2, "could not get sid on second attempt: error=%lu", lastError);
+			}
+		}
+		else
+		{
+			elog(DEBUG2, "could not get sid on first attempt: error=%lu", lastError);
+		}
+	}
+
+	if (groupSid != NULL)
+	{
+		if (CheckTokenMembership(sspiToken, groupSid, &isMember))
+		{
+			elog(DEBUG4, "check group membership groupName=%s, isMember=%i", groupName, isMember);
+		}
+		else
+		{
+			elog(DEBUG2, "could not check group membership: error=%lu", lastError);
+		}
+
+		free(groupSid);
+	}
+
+	return (isMember == TRUE);
+}
+
+#endif		/* ENABLE_SSPI */
+
 /*
  *	Process one line from the parsed ident config lines.
  *
@@ -2806,7 +2901,8 @@ parse_ident_line(TokenizedLine *tok_line)
 static void
 check_ident_usermap(IdentLine *identLine, const char *usermap_name,
 					const char *pg_role, const char *ident_user,
-					bool case_insensitive, bool *found_p, bool *error_p)
+					bool case_insensitive, bool *found_p, bool *error_p,
+					void *sspi_token)
 {
 	*found_p = false;
 	*error_p = false;
@@ -2906,6 +3002,21 @@ check_ident_usermap(IdentLine *identLine, const char *usermap_name,
 
 		return;
 	}
+#ifdef ENABLE_SSPI
+	else if (identLine->ident_user[0] == '+')
+	{
+		if (case_insensitive)
+		{
+			if (pg_strcasecmp(identLine->pg_role, pg_role) == 0)
+				*found_p = sspi_user_is_in_group(sspi_token, identLine->ident_user + 1);
+		}
+		else
+		{
+			if (strcmp(identLine->pg_role, pg_role) == 0)
+				*found_p = sspi_user_is_in_group(sspi_token, identLine->ident_user + 1);
+		}
+	}
+#endif   /* ENABLE_SSPI */
 	else
 	{
 		/* Not regular expression, so make complete match */
@@ -2942,7 +3053,8 @@ int
 check_usermap(const char *usermap_name,
 			  const char *pg_role,
 			  const char *auth_user,
-			  bool case_insensitive)
+			  bool case_insensitive,
+			  void *sspi_token)
 {
 	bool		found_entry = false,
 				error = false;
@@ -2972,7 +3084,7 @@ check_usermap(const char *usermap_name,
 		{
 			check_ident_usermap(lfirst(line_cell), usermap_name,
 								pg_role, auth_user, case_insensitive,
-								&found_entry, &error);
+								&found_entry, &error, sspi_token);
 			if (found_entry || error)
 				break;
 		}
diff --git a/src/include/libpq/hba.h b/src/include/libpq/hba.h
index d638479d88..5faf4deca7 100644
--- a/src/include/libpq/hba.h
+++ b/src/include/libpq/hba.h
@@ -128,7 +128,7 @@ extern bool load_ident(void);
 extern void hba_getauthmethod(hbaPort *port);
 extern int	check_usermap(const char *usermap_name,
 						  const char *pg_role, const char *auth_user,
-						  bool case_sensitive);
+						  bool case_sensitive, void *sspi_token);
 extern bool pg_isblank(const char c);
 
 #endif							/* HBA_H */
-- 
2.16.1.windows.4

