This is an automated email from the ASF dual-hosted git repository.
BewareMyPower pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/pulsar-client-cpp.git
The following commit(s) were added to refs/heads/main by this push:
new 3ea3ff1 [improve][client] Implement tls_client_auth for AuthOauth2
(#575)
3ea3ff1 is described below
commit 3ea3ff19ca3c57b0cf185e30e8d115c48437da91
Author: Hideaki Oguni <[email protected]>
AuthorDate: Tue May 26 21:25:52 2026 +0900
[improve][client] Implement tls_client_auth for AuthOauth2 (#575)
---
include/pulsar/Authentication.h | 20 ++-
lib/auth/AuthOauth2.cc | 352 +++++++++++++++++++++++++-----------
lib/auth/AuthOauth2.h | 30 ++++
tests/AuthPluginTest.cc | 382 +++++++++++++++++++++++++++++++++++++++-
4 files changed, 673 insertions(+), 111 deletions(-)
diff --git a/include/pulsar/Authentication.h b/include/pulsar/Authentication.h
index 34b70cd..a6a02b8 100644
--- a/include/pulsar/Authentication.h
+++ b/include/pulsar/Authentication.h
@@ -515,11 +515,22 @@ typedef std::shared_ptr<CachedToken> CachedTokenPtr;
* Passed in parameter would be like:
* ```
* "type": "client_credentials",
+ * "tokenEndpointAuthMethod": "client_secret_post",
* "issuer_url": "https://accounts.google.com",
* "client_id": "d9ZyX97q1ef8Cr81WHVC4hFQ64vSlDK3",
* "client_secret": "on1uJ...k6F6R",
* "audience": "https://broker.example.com"
* ```
+ *
+ * For `tokenEndpointAuthMethod = "tls_client_auth"`:
+ * ```
+ * "type": "client_credentials",
+ * "tokenEndpointAuthMethod": "tls_client_auth",
+ * "issuer_url": "https://accounts.google.com",
+ * "client_id": "d9ZyX97q1ef8Cr81WHVC4hFQ64vSlDK3",
+ * "tls_cert_file": "/path/to/cert.pem",
+ * "tls_key_file": "/path/to/key.pem"
+ * ```
* If passed in as std::string, it should be in Json format.
*/
class PULSAR_PUBLIC AuthOauth2 : public Authentication {
@@ -530,7 +541,14 @@ class PULSAR_PUBLIC AuthOauth2 : public Authentication {
/**
* Create an AuthOauth2 with a ParamMap
*
- * The required parameter keys are “issuer_url”, “private_key”, and
“audience”
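+ * For `tokenEndpointAuthMethod = "client_secret_post"` (the default when the key is absent
+ * or empty), the required parameter keys are `issuer_url`, `private_key`, and `audience`.
+ * Optional keys: `scope`, `tls_cert_file`, `tls_key_file`. The last two enable a client
+ * certificate only when both are given; a lone one is ignored with a warning.
+ *
+ * For `tokenEndpointAuthMethod = "tls_client_auth"`, the required parameter keys are
+ * `issuer_url`, `tls_cert_file`, and `tls_key_file`.
+ * Optional keys: `client_id`, `audience`, `scope`. If `client_id` is omitted or empty,
+ * the client uses `pulsar-client`.
+ *
+ * Any other `tokenEndpointAuthMethod` value is rejected with std::invalid_argument.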
*
* @param parameters the key-value to create OAuth 2.0 client credentials
* @see
http://pulsar.apache.org/docs/en/security-oauth2/#client-credentials
diff --git a/lib/auth/AuthOauth2.cc b/lib/auth/AuthOauth2.cc
index 9573496..94d9fc6 100644
--- a/lib/auth/AuthOauth2.cc
+++ b/lib/auth/AuthOauth2.cc
@@ -20,6 +20,7 @@
#include <boost/property_tree/json_parser.hpp>
#include <boost/property_tree/ptree.hpp>
+#include <cstdint>
#include <sstream>
#include <stdexcept>
@@ -31,6 +32,36 @@ DECLARE_LOG_OBJECT()
namespace pulsar {
+const std::string TlsClientAuthFlow::DEFAULT_CLIENT_ID = "pulsar-client";
+namespace {
+// Client authentication methods accepted in the "tokenEndpointAuthMethod" parameter.
+// Unknown stands for any value this client does not implement.
+enum class OAuth2TokenEndpointAuthMethod : std::uint8_t
+{
+    ClientSecretPost,
+    TlsClientAuth,
+    Unknown,
+};
+
+// Converts the textual method name into the enum; unrecognized names map to Unknown.
+OAuth2TokenEndpointAuthMethod parseTokenEndpointAuthMethod(const std::string& methodName) {
+    if (methodName == "client_secret_post") {
+        return OAuth2TokenEndpointAuthMethod::ClientSecretPost;
+    }
+    return methodName == "tls_client_auth" ? OAuth2TokenEndpointAuthMethod::TlsClientAuth
+                                           : OAuth2TokenEndpointAuthMethod::Unknown;
+}
+
+// Flow name used in log output. Every method other than TlsClientAuth is
+// reported as "ClientCredentialFlow".
+std::string toFlowName(OAuth2TokenEndpointAuthMethod method) {
+    if (method == OAuth2TokenEndpointAuthMethod::TlsClientAuth) {
+        return "TlsClientAuthFlow";
+    }
+    return "ClientCredentialFlow";
+}
+} // namespace
+
// AuthDataOauth2
AuthDataOauth2::AuthDataOauth2(const std::string& accessToken) { accessToken_
= accessToken; }
@@ -111,6 +142,8 @@ bool Oauth2CachedToken::isExpired() { return expiresAt_ <
Clock::now(); }
Oauth2Flow::Oauth2Flow() {}
Oauth2Flow::~Oauth2Flow() {}
+// Forward declaration: the definition appears later in this file, but fetchOauth2Token()
+// below needs it to build the URL-encoded token request body.
+static std::string buildClientCredentialsBody(CurlWrapper& curl, const
ParamMap& params);
+
KeyFile KeyFile::fromParamMap(ParamMap& params) {
const auto it = params.find("private_key");
if (it == params.cend()) {
@@ -199,80 +232,186 @@ KeyFile KeyFile::fromBase64(const std::string& encoded) {
}
}
-ClientCredentialFlow::ClientCredentialFlow(ParamMap& params)
- : issuerUrl_(params["issuer_url"]),
- keyFile_(KeyFile::fromParamMap(params)),
- audience_(params["audience"]),
- scope_(params["scope"]) {}
+// Returns the OpenID Connect discovery URL for `issuerUrl`, dropping one trailing '/'
+// before appending the well-known path.
+static std::string getWellKnownUrl(const std::string& issuerUrl) {
+    const bool endsWithSlash = !issuerUrl.empty() && issuerUrl.back() == '/';
+    const std::string base = endsWithSlash ? issuerUrl.substr(0, issuerUrl.size() - 1) : issuerUrl;
+    return base + "/.well-known/openid-configuration";
+}
-std::string ClientCredentialFlow::getTokenEndPoint() const { return
tokenEndPoint_; }
+// Builds the curl TLS settings shared by both OAuth2 flows. Returns nullptr when neither a
+// trust-certs file nor a complete client cert/key pair is configured. A cert or key given
+// without its counterpart is ignored with a LOG_WARN rather than rejected.
+// NOTE(review): both authenticate() implementations call this on every token request,
+// so that warning repeats each time.
+static std::unique_ptr<CurlWrapper::TlsContext> createTlsContext(const
std::string& tlsTrustCertsFilePath,
+ const
std::string& tlsCertFilePath,
+ const
std::string& tlsKeyFilePath) {
+ const bool hasTrustCerts = !tlsTrustCertsFilePath.empty();
+ const bool hasClientCertPair = !tlsCertFilePath.empty() &&
!tlsKeyFilePath.empty();
-void ClientCredentialFlow::initialize() {
- if (issuerUrl_.empty()) {
- LOG_ERROR("Failed to initialize ClientCredentialFlow: issuer_url is
not set");
- return;
+ // True when exactly one of tls_cert_file / tls_key_file is set.
+ if (!tlsCertFilePath.empty() != !tlsKeyFilePath.empty()) {
+ LOG_WARN("Ignore incomplete mTLS settings: both tls_cert_file and
tls_key_file are required");
}
- if (!keyFile_.isValid()) {
- return;
+ if (!hasTrustCerts && !hasClientCertPair) {
+ return nullptr;
}
- // set URL: well-know endpoint
- std::string wellKnownUrl = issuerUrl_;
- if (wellKnownUrl.back() == '/') {
- wellKnownUrl.pop_back();
+ auto tlsContext = std::unique_ptr<CurlWrapper::TlsContext>(new
CurlWrapper::TlsContext);
+ if (hasTrustCerts) {
+ tlsContext->trustCertsFilePath = tlsTrustCertsFilePath;
}
- wellKnownUrl.append("/.well-known/openid-configuration");
+ if (hasClientCertPair) {
+ tlsContext->certPath = tlsCertFilePath;
+ tlsContext->keyPath = tlsKeyFilePath;
+ }
+ return tlsContext;
+}
+// Fetches the issuer's well-known OpenID configuration and returns its "token_endpoint".
+// Returns an empty string on any curl, HTTP, JSON or missing-field failure; each is logged.
+static std::string fetchTokenEndpoint(const std::string& issuerUrl,
+ const CurlWrapper::TlsContext*
tlsContext) {
+ const auto wellKnownUrl = getWellKnownUrl(issuerUrl);
CurlWrapper curl;
if (!curl.init()) {
LOG_ERROR("Failed to initialize curl");
- return;
- }
- std::unique_ptr<CurlWrapper::TlsContext> tlsContext;
- if (!tlsTrustCertsFilePath_.empty()) {
- tlsContext.reset(new CurlWrapper::TlsContext);
- tlsContext->trustCertsFilePath = tlsTrustCertsFilePath_;
+ return "";
}
- auto result = curl.get(wellKnownUrl, "Accept: application/json", {},
tlsContext.get());
+ auto result = curl.get(wellKnownUrl, "Accept: application/json", {},
tlsContext);
if (!result.error.empty()) {
- LOG_ERROR("Failed to get the well-known configuration " << issuerUrl_
<< ": " << result.error);
- return;
+ LOG_ERROR("Failed to get the well-known configuration " << issuerUrl
<< ": " << result.error);
+ return "";
}
const auto res = result.code;
- const auto response_code = result.responseCode;
+ const auto responseCode = result.responseCode;
const auto& responseData = result.responseData;
const auto& errorBuffer = result.serverError;
switch (res) {
case CURLE_OK:
- LOG_DEBUG("Received well-known configuration data " << issuerUrl_
<< " code " << response_code);
- if (response_code == 200) {
+ LOG_DEBUG("Received well-known configuration data " << issuerUrl
<< " code " << responseCode);
+ if (responseCode == 200) {
boost::property_tree::ptree root;
std::stringstream stream;
stream << responseData;
try {
boost::property_tree::read_json(stream, root);
+ return root.get<std::string>("token_endpoint");
} catch (boost::property_tree::json_parser_error& e) {
LOG_ERROR("Failed to parse well-known configuration data
response: "
<< e.what() << "\nInput Json = " <<
responseData);
+ return "";
+ } catch (const boost::property_tree::ptree_error& e) {
+ // Valid JSON but "token_endpoint" is missing (ptree_bad_path) or unusable.
+ // Without this handler the exception escaped into the calling initialize().
+ LOG_ERROR("Invalid well-known configuration of " << issuerUrl << ": " << e.what()
+ << "\nInput Json = " << responseData);
+ return "";
+ }
+ } else {
+ LOG_ERROR("Response failed for getting the well-known
configuration "
+ << issuerUrl << ". response Code " << responseCode);
+ }
+ break;
+ default:
+ LOG_ERROR("Response failed for getting the well-known
configuration "
+ << issuerUrl << ". Error Code " << res << ": " <<
errorBuffer);
+ break;
+ }
+ return "";
+}
+
+// Sends the URL-encoded client_credentials body built from `params` to `tokenEndpoint` and
+// parses the JSON reply. Always returns a non-null result. Its access token stays empty
+// when the endpoint is unknown, the request fails, or the reply is not a usable 200 response.
+static Oauth2TokenResultPtr fetchOauth2Token(const std::string& tokenEndpoint,
const ParamMap& params,
+ const CurlWrapper::TlsContext*
tlsContext,
+ OAuth2TokenEndpointAuthMethod
authMethod) {
+ Oauth2TokenResultPtr resultPtr = Oauth2TokenResultPtr(new
Oauth2TokenResult());
+ if (tokenEndpoint.empty()) {
+ return resultPtr;
+ }
+
+ CurlWrapper curl;
+ if (!curl.init()) {
+ LOG_ERROR("Failed to initialize curl");
+ return resultPtr;
+ }
+
+ auto postData = buildClientCredentialsBody(curl, params);
+ if (postData.empty()) {
+ return resultPtr;
+ }
+ // The body can carry credentials such as client_secret, so only its size is logged.
+ LOG_DEBUG("Generate URL encoded body for " << toFlowName(authMethod) << ": " << postData.size() << " bytes");
+
+ CurlWrapper::Options options;
+ options.postFields = std::move(postData);
+ auto result =
+ curl.get(tokenEndpoint, "Content-Type:
application/x-www-form-urlencoded", options, tlsContext);
+ if (!result.error.empty()) {
+ LOG_ERROR("Failed to fetch OAuth2 token from " << tokenEndpoint << ":
" << result.error);
+ return resultPtr;
+ }
+
+ const auto res = result.code;
+ const auto responseCode = result.responseCode;
+ const auto& responseData = result.responseData;
+ const auto& errorBuffer = result.serverError;
+
+ switch (res) {
+ case CURLE_OK:
+ LOG_DEBUG("Response received for token endpoint " << tokenEndpoint
<< " code " << responseCode);
+ if (responseCode == 200) {
+ boost::property_tree::ptree root;
+ std::stringstream stream;
+ stream << responseData;
+ try {
+ boost::property_tree::read_json(stream, root);
+ } catch (boost::property_tree::json_parser_error& e) {
+ LOG_ERROR("Failed to parse json of Oauth2 response: " <<
e.what() << "\nInput Json = "
+ <<
responseData);
break;
}
- this->tokenEndPoint_ = root.get<std::string>("token_endpoint");
+
resultPtr->setAccessToken(root.get<std::string>("access_token", ""));
+ resultPtr->setExpiresIn(
+ root.get<uint32_t>("expires_in",
Oauth2TokenResult::undefined_expiration));
+
resultPtr->setRefreshToken(root.get<std::string>("refresh_token", ""));
+ resultPtr->setIdToken(root.get<std::string>("id_token", ""));
- LOG_DEBUG("Get token endpoint: " << this->tokenEndPoint_);
+ if (!resultPtr->getAccessToken().empty()) {
+ // The token is a bearer credential; log only that one arrived.
+ LOG_DEBUG("Received access_token from " << tokenEndpoint << " expires_in: " << resultPtr->getExpiresIn());
+ } else {
+ LOG_ERROR("Response doesn't contain access_token, the
response is: " << responseData);
+ }
} else {
- LOG_ERROR("Response failed for getting the well-known
configuration "
- << issuerUrl_ << ". response Code " <<
response_code);
+ LOG_ERROR("Response failed for token endpoint " <<
tokenEndpoint << ". response Code "
+ <<
responseCode);
}
break;
default:
- LOG_ERROR("Response failed for getting the well-known
configuration "
- << issuerUrl_ << ". Error Code " << res << ": " <<
errorBuffer);
+ LOG_ERROR("Response failed for token endpoint " << tokenEndpoint
<< ". ErrorCode " << res << ": "
+ << errorBuffer);
break;
}
+
+ return resultPtr;
+}
+
+// Reads the issuer, credentials, audience, scope and optional mTLS cert/key paths from `params`.
+ClientCredentialFlow::ClientCredentialFlow(ParamMap& params)
+    : issuerUrl_(params["issuer_url"]),
+      keyFile_(KeyFile::fromParamMap(params)),
+      audience_(params["audience"]),
+      scope_(params["scope"]),
+      tlsCertFilePath_(params["tls_cert_file"]),
+      tlsKeyFilePath_(params["tls_key_file"]) {}
+
+std::string ClientCredentialFlow::getTokenEndPoint() const {
+    return tokenEndPoint_;
+}
+
+// Resolves the token endpoint; run once through std::call_once from authenticate().
+void ClientCredentialFlow::initialize() {
+    if (issuerUrl_.empty()) {
+        LOG_ERROR("Failed to initialize ClientCredentialFlow: issuer_url is not set");
+        return;
+    }
+    if (keyFile_.isValid()) {
+        const auto context = createTlsContext(tlsTrustCertsFilePath_, tlsCertFilePath_, tlsKeyFilePath_);
+        tokenEndPoint_ = fetchTokenEndpoint(issuerUrl_, context.get());
+        if (!tokenEndPoint_.empty()) {
+            LOG_DEBUG("Get token endpoint: " << tokenEndPoint_);
+        }
+    }
}
void ClientCredentialFlow::close() {}
@@ -324,85 +463,90 @@ static std::string
buildClientCredentialsBody(CurlWrapper& curl, const ParamMap&
Oauth2TokenResultPtr ClientCredentialFlow::authenticate() {
std::call_once(initializeOnce_, &ClientCredentialFlow::initialize, this);
- Oauth2TokenResultPtr resultPtr = Oauth2TokenResultPtr(new
Oauth2TokenResult());
- if (tokenEndPoint_.empty()) {
- return resultPtr;
+    // If initialize() could not resolve the endpoint, fetchOauth2Token() returns an empty result.
+    const auto requestParams = generateParamMap();
+    const auto clientTls = createTlsContext(tlsTrustCertsFilePath_, tlsCertFilePath_, tlsKeyFilePath_);
+    return fetchOauth2Token(tokenEndPoint_, requestParams, clientTls.get(),
+                            OAuth2TokenEndpointAuthMethod::ClientSecretPost);
+}
+
+// Reads the issuer, client_id (DEFAULT_CLIENT_ID when empty), audience, scope and the
+// mTLS certificate/key paths that this flow requires.
+TlsClientAuthFlow::TlsClientAuthFlow(ParamMap& params)
+    : issuerUrl_(params["issuer_url"]),
+      clientId_([&params]() -> std::string {
+          const std::string& configured = params["client_id"];
+          return configured.empty() ? DEFAULT_CLIENT_ID : configured;
+      }()),
+      audience_(params["audience"]),
+      scope_(params["scope"]),
+      tlsCertFilePath_(params["tls_cert_file"]),
+      tlsKeyFilePath_(params["tls_key_file"]) {}
+
+std::string TlsClientAuthFlow::getTokenEndPoint() const {
+    return tokenEndPoint_;
+}
+
+// Resolves the token endpoint; run once through std::call_once from authenticate().
+// The well-known discovery request is itself sent with the configured client cert/key.
+void TlsClientAuthFlow::initialize() {
+ if (issuerUrl_.empty()) {
+ LOG_ERROR("Failed to initialize TlsClientAuthFlow: issuer_url is not
set");
+ return;
+ }
+ if (tlsCertFilePath_.empty() || tlsKeyFilePath_.empty()) {
+ LOG_ERROR("Failed to initialize TlsClientAuthFlow: tls_cert_file or
tls_key_file is not set");
+ return;
}
- CurlWrapper curl;
- if (!curl.init()) {
- LOG_ERROR("Failed to initialize curl");
- return resultPtr;
+ const auto tlsContext = createTlsContext(tlsTrustCertsFilePath_,
tlsCertFilePath_, tlsKeyFilePath_);
+ // NOTE(review): after the check above, createTlsContext() always returns a context with
+ // both certPath and keyPath set, so this branch looks unreachable (defensive only).
+ if (!tlsContext || tlsContext->certPath.empty() ||
tlsContext->keyPath.empty()) {
+ LOG_ERROR("Failed to initialize TlsClientAuthFlow: tls_cert_file or
tls_key_file is not set");
+ return;
}
- auto postData = buildClientCredentialsBody(curl, generateParamMap());
- if (postData.empty()) {
- return resultPtr;
+ this->tokenEndPoint_ = fetchTokenEndpoint(issuerUrl_, tlsContext.get());
+ if (!this->tokenEndPoint_.empty()) {
+ LOG_DEBUG("Get token endpoint: " << this->tokenEndPoint_);
}
- LOG_DEBUG("Generate URL encoded body for ClientCredentialFlow: " <<
postData);
+}
+void TlsClientAuthFlow::close() {}
- CurlWrapper::Options options;
- options.postFields = std::move(postData);
- std::unique_ptr<CurlWrapper::TlsContext> tlsContext;
- if (!tlsTrustCertsFilePath_.empty()) {
- tlsContext.reset(new CurlWrapper::TlsContext);
- tlsContext->trustCertsFilePath = tlsTrustCertsFilePath_;
+// Form parameters of the token request. No client secret is included: the client is
+// authenticated by its TLS certificate, and RFC 8705 requires client_id in the request.
+ParamMap TlsClientAuthFlow::generateParamMap() const {
+ ParamMap params;
+ params.emplace("grant_type", "client_credentials");
+ params.emplace("client_id", clientId_);
+ // audience and scope are optional and omitted when empty.
+ if (!audience_.empty()) {
+ params.emplace("audience", audience_);
}
- auto result = curl.get(tokenEndPoint_, "Content-Type:
application/x-www-form-urlencoded", options,
- tlsContext.get());
- if (!result.error.empty()) {
- LOG_ERROR("Failed to get the well-known configuration " << issuerUrl_
<< ": " << result.error);
- return resultPtr;
+ if (!scope_.empty()) {
+ params.emplace("scope", scope_);
}
- const auto res = result.code;
- const auto response_code = result.responseCode;
- const auto& responseData = result.responseData;
- const auto& errorBuffer = result.serverError;
+ return params;
+}
- switch (res) {
- case CURLE_OK:
- LOG_DEBUG("Response received for issuerurl " << issuerUrl_ << "
code " << response_code);
- if (response_code == 200) {
- boost::property_tree::ptree root;
- std::stringstream stream;
- stream << responseData;
- try {
- boost::property_tree::read_json(stream, root);
- } catch (boost::property_tree::json_parser_error& e) {
- LOG_ERROR("Failed to parse json of Oauth2 response: "
- << e.what() << "\nInput Json = " << responseData
<< " passedin: " << postData);
- break;
- }
+// Requests a token over mutual TLS. Without a complete client cert/key pair it returns a
+// result with no access token and does not contact the token endpoint.
+Oauth2TokenResultPtr TlsClientAuthFlow::authenticate() {
+    std::call_once(initializeOnce_, &TlsClientAuthFlow::initialize, this);
+    const auto requestParams = generateParamMap();
+    const auto clientTls = createTlsContext(tlsTrustCertsFilePath_, tlsCertFilePath_, tlsKeyFilePath_);
+    const bool hasClientCert =
+        clientTls != nullptr && !clientTls->certPath.empty() && !clientTls->keyPath.empty();
+    if (!hasClientCert) {
+        return Oauth2TokenResultPtr(new Oauth2TokenResult());
+    }
+    return fetchOauth2Token(tokenEndPoint_, requestParams, clientTls.get(),
+                            OAuth2TokenEndpointAuthMethod::TlsClientAuth);
+}
-
resultPtr->setAccessToken(root.get<std::string>("access_token", ""));
- resultPtr->setExpiresIn(
- root.get<uint32_t>("expires_in",
Oauth2TokenResult::undefined_expiration));
-
resultPtr->setRefreshToken(root.get<std::string>("refresh_token", ""));
- resultPtr->setIdToken(root.get<std::string>("id_token", ""));
+// AuthOauth2
- if (!resultPtr->getAccessToken().empty()) {
- LOG_DEBUG("access_token: " << resultPtr->getAccessToken()
- << " expires_in: " <<
resultPtr->getExpiresIn());
- } else {
- LOG_ERROR("Response doesn't contain access_token, the
response is: " << responseData);
- }
- } else {
- LOG_ERROR("Response failed for issuerurl " << issuerUrl_ << ".
response Code "
- << response_code <<
" passedin: " << postData);
- }
+// Chooses the OAuth2 flow from "tokenEndpointAuthMethod". "tls_client_auth" selects
+// TlsClientAuthFlow. An empty value or "client_secret_post" selects ClientCredentialFlow.
+// Any other value throws std::invalid_argument.
+AuthOauth2::AuthOauth2(ParamMap& params) {
+ std::string tokenEndpointAuthMethodName =
params["tokenEndpointAuthMethod"];
+ if (tokenEndpointAuthMethodName.empty()) {
+ tokenEndpointAuthMethodName = "client_secret_post";
+ }
+ const auto tokenEndpointAuthMethod =
parseTokenEndpointAuthMethod(tokenEndpointAuthMethodName);
+ switch (tokenEndpointAuthMethod) {
+ case OAuth2TokenEndpointAuthMethod::TlsClientAuth:
+ flowPtr_ = FlowPtr(new TlsClientAuthFlow(params));
break;
- default:
- LOG_ERROR("Response failed for issuerurl " << issuerUrl_ << ".
ErrorCode " << res << ": "
- << errorBuffer << "
passedin: " << postData);
+ case OAuth2TokenEndpointAuthMethod::ClientSecretPost:
+ flowPtr_ = FlowPtr(new ClientCredentialFlow(params));
break;
+ case OAuth2TokenEndpointAuthMethod::Unknown:
+ default:
+ throw std::invalid_argument("Unknown tokenEndpointAuthMethod: " +
tokenEndpointAuthMethodName);
}
-
- return resultPtr;
}
-// AuthOauth2
-
-AuthOauth2::AuthOauth2(ParamMap& params) : flowPtr_(new
ClientCredentialFlow(params)) {}
-
AuthOauth2::~AuthOauth2() {}
ParamMap parseJsonAuthParamsString(const std::string& authParamsString) {
@@ -436,11 +580,13 @@ const std::string AuthOauth2::getAuthMethodName() const {
return "token"; }
Result AuthOauth2::getAuthData(AuthenticationDataPtr& authDataContent) {
auto initialAuthData =
std::dynamic_pointer_cast<InitialAuthData>(authDataContent);
if (initialAuthData) {
- auto flowPtr =
std::dynamic_pointer_cast<ClientCredentialFlow>(flowPtr_);
- if (!flowPtr_) {
- throw std::invalid_argument("AuthOauth2::flowPtr_ is not a
ClientCredentialFlow");
+ if (auto clientCredentialFlow =
std::dynamic_pointer_cast<ClientCredentialFlow>(flowPtr_)) {
+
clientCredentialFlow->setTlsTrustCertsFilePath(initialAuthData->tlsTrustCertsFilePath_);
+ } else if (auto tlsClientAuthFlow =
std::dynamic_pointer_cast<TlsClientAuthFlow>(flowPtr_)) {
+
tlsClientAuthFlow->setTlsTrustCertsFilePath(initialAuthData->tlsTrustCertsFilePath_);
+ } else {
+ throw std::invalid_argument("AuthOauth2::flowPtr_ is not an OAuth2
flow implementation");
}
-
flowPtr->setTlsTrustCertsFilePath(initialAuthData->tlsTrustCertsFilePath_);
}
if (cachedTokenPtr_ == nullptr || cachedTokenPtr_->isExpired()) {
diff --git a/lib/auth/AuthOauth2.h b/lib/auth/AuthOauth2.h
index 035ad08..b402f37 100644
--- a/lib/auth/AuthOauth2.h
+++ b/lib/auth/AuthOauth2.h
@@ -71,6 +71,36 @@ class ClientCredentialFlow : public Oauth2Flow {
const KeyFile keyFile_;
const std::string audience_;
const std::string scope_;
+ const std::string tlsCertFilePath_;
+ const std::string tlsKeyFilePath_;
+ std::string tlsTrustCertsFilePath_;
+ std::once_flag initializeOnce_;
+};
+
+// OAuth2 client_credentials flow that authenticates to the token endpoint with a TLS client
+// certificate ("tls_client_auth") instead of a client secret.
+class TlsClientAuthFlow : public Oauth2Flow {
+   public:
+    // client_id sent when the "client_id" parameter is empty.
+    static const std::string DEFAULT_CLIENT_ID;
+
+    // explicit: a ParamMap should never convert implicitly into a flow object.
+    explicit TlsClientAuthFlow(ParamMap& params);
+    // Resolves the token endpoint from the issuer's well-known configuration.
+    void initialize();
+    // Requests a token; the first call runs initialize() exactly once.
+    Oauth2TokenResultPtr authenticate();
+    void close();
+
+    // Token request form fields: grant_type, client_id and optional audience/scope.
+    ParamMap generateParamMap() const;
+    std::string getTokenEndPoint() const;
+
+    // Trust-certificates file used when building the TLS context for HTTPS requests.
+    void setTlsTrustCertsFilePath(const std::string& tlsTrustCertsFilePath) {
+        tlsTrustCertsFilePath_ = tlsTrustCertsFilePath;
+    }
+
+   private:
+    std::string tokenEndPoint_;
+    const std::string issuerUrl_;
+    const std::string clientId_;
+    const std::string audience_;
+    const std::string scope_;
+    const std::string tlsCertFilePath_;
+    const std::string tlsKeyFilePath_;
std::string tlsTrustCertsFilePath_;
std::once_flag initializeOnce_;
};
diff --git a/tests/AuthPluginTest.cc b/tests/AuthPluginTest.cc
index 6c6b898..7ab151c 100644
--- a/tests/AuthPluginTest.cc
+++ b/tests/AuthPluginTest.cc
@@ -22,11 +22,17 @@
#include <array>
#include <boost/algorithm/string.hpp>
+#include <chrono>
+#include <future>
+#include <mutex>
#include <sstream>
+#include <stdexcept>
#ifdef USE_ASIO
#include <asio.hpp>
+#include <asio/ssl.hpp>
#else
#include <boost/asio.hpp>
+#include <boost/asio/ssl.hpp>
#endif
#include <thread>
@@ -36,6 +42,7 @@
#include "lib/LogUtils.h"
#include "lib/Utils.h"
#include "lib/auth/AuthOauth2.h"
+#include "lib/auth/InitialAuthData.h"
DECLARE_LOG_OBJECT()
using namespace pulsar;
@@ -59,6 +66,8 @@ static const std::string mimServiceUrlTls =
"pulsar+ssl://localhost:6653";
static const std::string mimServiceUrlHttps = "https://localhost:8444";
static const std::string mimCaPath = TEST_CONF_DIR
"/hn-verification/cacert.pem";
+// Server certificate and private key presented by testOauth2Tls::MockOauth2Server.
+static const std::string brokerPublicKeyPath = TEST_CONF_DIR
"/broker-cert.pem";
+static const std::string brokerPrivateKeyPath = TEST_CONF_DIR
"/broker-key.pem";
static void sendCallBackTls(Result r, const MessageId& msgId) {
ASSERT_EQ(r, ResultOk);
@@ -324,14 +333,12 @@ static std::vector<std::string> split(const std::string&
s, char separator) {
return tokens;
}
-namespace testAthenz {
-std::string principalToken;
-
// ASIO::ip::tcp::iostream could call a virtual function during destruction,
so the clang-tidy will fail by
// clang-analyzer-optin.cplusplus.VirtualCall. Here we write a simple stream
to read lines from socket.
+template <typename Stream>
class SocketStream {
public:
- SocketStream(ASIO::ip::tcp::socket& socket) : socket_(socket) {}
+ explicit SocketStream(Stream& stream) : stream_(stream) {}
bool getline(std::string& line) {
auto pos = buffer_.find('\n', bufferPos_);
@@ -343,7 +350,7 @@ class SocketStream {
std::array<char, 1024> buffer;
ASIO_ERROR error;
- auto length = socket_.read_some(ASIO::buffer(buffer.data(),
buffer.size()), error);
+ auto length = stream_.read_some(ASIO::buffer(buffer.data(),
buffer.size()), error);
if (error == ASIO::error::eof) {
return false;
} else if (error) {
@@ -362,12 +369,29 @@ class SocketStream {
return true;
}
+    // Reads exactly `size` bytes into `out`, consuming already-buffered data first.
+    // Returns false if the stream errors or ends before enough bytes arrive.
+    bool readBytes(size_t size, std::string& out) {
+        std::array<char, 1024> chunk;
+        for (;;) {
+            const size_t buffered = buffer_.size() - bufferPos_;
+            if (buffered >= size) {
+                break;
+            }
+            ASIO_ERROR error;
+            const auto received = stream_.read_some(ASIO::buffer(chunk.data(), chunk.size()), error);
+            if (error) {
+                return false;  // also covers ASIO::error::eof
+            }
+            buffer_.append(chunk.data(), received);
+        }
+        out = buffer_.substr(bufferPos_, size);
+        bufferPos_ += size;
+        return true;
+    }
+
private:
- ASIO::ip::tcp::socket& socket_;
+ Stream& stream_;
std::string buffer_;
size_t bufferPos_{0};
};
+namespace testAthenz {
+std::string principalToken;
+
void mockZTS(Latch& latch, int port) {
LOG_INFO("-- MockZTS started");
ASIO::io_context io;
@@ -380,7 +404,7 @@ void mockZTS(Latch& latch, int port) {
LOG_INFO("-- MockZTS got connection");
std::string headerLine;
- SocketStream stream(socket);
+ SocketStream<ASIO::ip::tcp::socket> stream(socket);
while (stream.getline(headerLine)) {
if (headerLine.empty()) {
continue;
@@ -518,6 +542,142 @@ TEST(AuthPluginTest, testAuthFactoryAthenz) {
}
}
+namespace testOauth2Tls {
+static const auto mockServerTimeout = std::chrono::seconds(10);
+
+// Single-request HTTPS server standing in for an OAuth2 issuer or token endpoint. It accepts
+// one connection, optionally requires a client certificate verified against caPath, records
+// the request (headers and body) and answers with a fixed 200 response.
+class MockOauth2Server {
+   public:
+    MockOauth2Server(const std::string& responseBody, const std::string& responseContentType, int listenPort,
+                     bool requireClientCert = true)
+        : responseBody_(responseBody),
+          responseContentType_(responseContentType),
+          acceptor_(io_, ASIO::ip::tcp::endpoint(ASIO::ip::tcp::v4(), static_cast<uint16_t>(listenPort))),
+          sslCtx_(ASIO::ssl::context::sslv23) {
+        sslCtx_.set_options(ASIO::ssl::context::default_workarounds | ASIO::ssl::context::no_sslv2 |
+                            ASIO::ssl::context::no_sslv3);
+        sslCtx_.use_certificate_chain_file(brokerPublicKeyPath);
+        sslCtx_.use_private_key_file(brokerPrivateKeyPath, ASIO::ssl::context::pem);
+        sslCtx_.load_verify_file(caPath);
+        sslCtx_.set_verify_mode(requireClientCert
+                                    ? (ASIO::ssl::verify_peer | ASIO::ssl::verify_fail_if_no_peer_cert)
+                                    : ASIO::ssl::verify_none);
+    }
+
+    // Request text captured by mockServe().
+    const std::string& request() const { return request_; }
+
+    // Serves exactly one request. Returns false on accept, handshake, read or write failure.
+    bool mockServe() {
+        ASIO_ERROR error;
+        auto socket = std::make_shared<ASIO::ip::tcp::socket>(io_);
+        {
+            std::lock_guard<std::mutex> lock(mutex_);
+            activeSocket_ = socket;
+        }
+
+        acceptor_.accept(*socket, error);
+        if (error) {
+            clearActiveSocket();
+            return false;
+        }
+
+        ASIO::ssl::stream<ASIO::ip::tcp::socket&> sslStream(*socket, sslCtx_);
+        sslStream.handshake(ASIO::ssl::stream_base::server, error);
+        if (error || !readRequest(sslStream)) {
+            clearActiveSocket();
+            return false;
+        }
+
+        const std::string response = "HTTP/1.1 200 OK\r\nContent-Type: " + responseContentType_ +
+                                     "\r\nContent-Length: " + std::to_string(responseBody_.size()) +
+                                     "\r\nConnection: close\r\n\r\n" + responseBody_;
+        ASIO::write(sslStream, ASIO::buffer(response.data(), response.size()), error);
+        clearActiveSocket();
+        return !error;
+    }
+
+    // Intended to abort a pending mockServe() by closing the acceptor and any active socket.
+    void stop() {
+        ASIO_ERROR error;
+        {
+            std::lock_guard<std::mutex> lock(mutex_);
+            if (acceptor_.is_open()) {
+                acceptor_.close(error);
+            }
+            if (activeSocket_ && activeSocket_->is_open()) {
+                activeSocket_->cancel(error);
+                activeSocket_->shutdown(ASIO::ip::tcp::socket::shutdown_both, error);
+                activeSocket_->close(error);
+            }
+        }
+        io_.stop();
+    }
+
+   private:
+    void clearActiveSocket() {
+        std::lock_guard<std::mutex> lock(mutex_);
+        activeSocket_.reset();
+    }
+
+    // Reads the request line, headers and (when Content-Length is present) the body into
+    // request_. Returns false if the connection ends early or Content-Length is malformed.
+    bool readRequest(ASIO::ssl::stream<ASIO::ip::tcp::socket&>& sslStream) {
+        SocketStream<ASIO::ssl::stream<ASIO::ip::tcp::socket&>> stream(sslStream);
+        request_.clear();
+        int contentLength = 0;
+        const std::string prefix = "Content-Length:";
+        std::string headerLine;
+        while (stream.getline(headerLine)) {
+            if (headerLine.empty()) {
+                continue;
+            }
+            request_.append(headerLine).append("\n");
+            // HTTP header field names are case-insensitive.
+            if (boost::algorithm::istarts_with(headerLine, prefix)) {
+                try {
+                    contentLength = std::stoi(headerLine.substr(prefix.size()));
+                } catch (const std::exception&) {
+                    // Must not propagate: an exception leaving the server thread calls std::terminate.
+                    return false;
+                }
+            }
+            if (headerLine == "\r") {
+                break;
+            }
+        }
+        if (headerLine != "\r") return false;
+
+        if (contentLength > 0) {
+            std::string body;
+            if (!stream.readBytes(static_cast<size_t>(contentLength), body)) return false;
+            request_ += body;
+        }
+        return true;
+    }
+
+    const std::string responseBody_;
+    const std::string responseContentType_;
+    std::string request_;
+
+    ASIO::io_context io_;
+    ASIO::ip::tcp::acceptor acceptor_;
+    ASIO::ssl::context sslCtx_;
+    std::shared_ptr<ASIO::ip::tcp::socket> activeSocket_;
+    std::mutex mutex_;
+};
+
+// Waits up to mockServerTimeout for a mock server thread's result. On timeout it stops the
+// server, joins the thread and records a gtest failure. Otherwise it returns the future's value.
+static bool awaitMockServeResult(std::future<bool>& future, MockOauth2Server& server, std::thread& thread,
+                                 const char* serverName) {
+    if (future.wait_for(mockServerTimeout) != std::future_status::ready) {
+        server.stop();
+        if (thread.joinable()) {
+            thread.join();
+        }
+        ADD_FAILURE() << serverName << " did not complete within "
+                      << std::chrono::duration_cast<std::chrono::seconds>(mockServerTimeout).count()
+                      << " seconds";
+        return false;
+    }
+
+    const bool result = future.get();
+    if (thread.joinable()) {
+        thread.join();
+    }
+    return result;
+}
+
+} // namespace testOauth2Tls
+
TEST(AuthPluginTest, testOauth2) {
// test success get token from oauth2 server.
pulsar::AuthenticationDataPtr data;
@@ -584,11 +744,15 @@ TEST(AuthPluginTest, testOauth2RequestBody) {
params["client_id"] = "Xd23RHsUnvUlP7wchjNYOaIfazgeHd9x";
params["client_secret"] =
"rT7ps7WY8uhdVuBTKWZkttwLdQotmdEliaM5rLfmgNibvqziZ-g07ZH52N_poGAb";
params["audience"] = "https://dev-kt-aa9ne.us.auth0.com/api/v2/";
+ params["tls_cert_file"] = "/path/to/cert.pem";
+ params["tls_key_file"] = "/path/to/key.pem";
auto createExpectedResult = [&] {
auto paramsCopy = params;
paramsCopy.emplace("grant_type", "client_credentials");
paramsCopy.erase("issuer_url");
+ paramsCopy.erase("tls_cert_file");
+ paramsCopy.erase("tls_key_file");
return paramsCopy;
};
@@ -668,6 +832,210 @@ TEST(AuthPluginTest, testOauth2Failure) {
client5.close();
}
+// Exercises tokenEndpointAuthMethod = "tls_client_auth" end to end against two local mock servers: the
+// issuer's well-known discovery endpoint and the token endpoint it advertises. The well-known server is
+// constructed with a trailing `false` argument (presumably disabling the client-certificate requirement;
+// confirm against MockOauth2Server), while the token server uses the constructor default.
+TEST(AuthPluginTest, testOauth2TlsClientAuth) {
+    const int tokenServerPort = 58081;
+    const int wellKnownServerPort = 58082;
+    const std::string tokenBody = R"({"access_token":"mockToken","expires_in":3600,"token_type":"Bearer"})";
+    std::unique_ptr<testOauth2Tls::MockOauth2Server> tokenServer;
+    try {
+        tokenServer =
+            std::make_unique<testOauth2Tls::MockOauth2Server>(tokenBody, "application/json", tokenServerPort);
+    } catch (const std::exception& e) {
+        FAIL() << "Failed to bind local mock token server: " << e.what();
+    }
+
+    std::promise<bool> tokenPromise;
+    auto tokenFuture = tokenPromise.get_future();
+    std::thread tokenThread(
+        [&tokenServer, &tokenPromise]() { tokenPromise.set_value(tokenServer->mockServe()); });
+
+    std::ostringstream wellKnownBody;
+    wellKnownBody << R"({"token_endpoint":"https://localhost:)" << tokenServerPort << R"(/oauth/token"})";
+    std::unique_ptr<testOauth2Tls::MockOauth2Server> wellKnownServer;
+    try {
+        wellKnownServer = std::make_unique<testOauth2Tls::MockOauth2Server>(
+            wellKnownBody.str(), "application/json", wellKnownServerPort, false);
+    } catch (const std::exception& e) {
+        // No client will ever connect to the token server now, and mockServe() may block waiting for one:
+        // stop it before joining (as awaitMockServeResult does on timeout) so the test cannot hang.
+        tokenServer->stop();
+        tokenThread.join();
+        FAIL() << "Failed to bind local mock well-known server: " << e.what();
+    }
+
+    std::promise<bool> wellKnownPromise;
+    auto wellKnownFuture = wellKnownPromise.get_future();
+    std::thread wellKnownThread([&wellKnownServer, &wellKnownPromise]() {
+        wellKnownPromise.set_value(wellKnownServer->mockServe());
+    });
+
+    ParamMap params;
+    params["tokenEndpointAuthMethod"] = "tls_client_auth";
+    params["issuer_url"] = "https://localhost:" + std::to_string(wellKnownServerPort);
+    params["client_id"] = "test-client";
+    params["tls_cert_file"] = clientPublicKeyPath;
+    params["tls_key_file"] = clientPrivateKeyPath;
+
+    AuthenticationDataPtr data =
+        std::static_pointer_cast<AuthenticationDataProvider>(std::make_shared<InitialAuthData>(caPath));
+    AuthenticationPtr auth = AuthOauth2::create(params);
+    const Result authResult = auth->getAuthData(data);
+
+    // Reap both server threads before any fatal assertion. An ASSERT_* that fails returns from the test
+    // body, and destroying a still-joinable std::thread calls std::terminate, aborting the whole binary.
+    const bool wellKnownServed = testOauth2Tls::awaitMockServeResult(wellKnownFuture, *wellKnownServer,
+                                                                     wellKnownThread, "Well-known mock server");
+    const bool tokenServed =
+        testOauth2Tls::awaitMockServeResult(tokenFuture, *tokenServer, tokenThread, "Token mock server");
+
+    ASSERT_EQ(authResult, ResultOk);
+    ASSERT_TRUE(data->hasDataFromCommand());
+    ASSERT_EQ(data->getCommandData(), "mockToken");
+    ASSERT_TRUE(wellKnownServed);
+    ASSERT_TRUE(tokenServed);
+    ASSERT_NE(wellKnownServer->request().find("GET /.well-known/openid-configuration "), std::string::npos);
+    ASSERT_NE(tokenServer->request().find("POST /oauth/token "), std::string::npos);
+    ASSERT_NE(tokenServer->request().find("grant_type=client_credentials"), std::string::npos);
+}
+
+// Same setup as testOauth2TlsClientAuth, but the client presents a certificate/key pair other than the
+// expected client pair (the hn-verification broker cert). Discovery via the well-known server should
+// still succeed, while the token server's mockServe() is expected to report failure and the plugin to
+// return ResultAuthenticationError.
+TEST(AuthPluginTest, testOauth2TlsClientAuthWrongCert) {
+    const int tokenServerPort = 58083;
+    const int wellKnownServerPort = 58084;
+    const std::string tokenBody = R"({"access_token":"mockToken","expires_in":3600,"token_type":"Bearer"})";
+
+    std::unique_ptr<testOauth2Tls::MockOauth2Server> tokenServer;
+    try {
+        tokenServer =
+            std::make_unique<testOauth2Tls::MockOauth2Server>(tokenBody, "application/json", tokenServerPort);
+    } catch (const std::exception& e) {
+        FAIL() << "Failed to bind local mock token server: " << e.what();
+    }
+
+    std::promise<bool> tokenPromise;
+    auto tokenFuture = tokenPromise.get_future();
+    std::thread tokenThread(
+        [&tokenServer, &tokenPromise]() { tokenPromise.set_value(tokenServer->mockServe()); });
+
+    std::ostringstream wellKnownBody;
+    wellKnownBody << R"({"token_endpoint":"https://localhost:)" << tokenServerPort << R"(/oauth/token"})";
+    std::unique_ptr<testOauth2Tls::MockOauth2Server> wellKnownServer;
+    try {
+        wellKnownServer = std::make_unique<testOauth2Tls::MockOauth2Server>(
+            wellKnownBody.str(), "application/json", wellKnownServerPort, false);
+    } catch (const std::exception& e) {
+        // No client will ever connect to the token server now, and mockServe() may block waiting for one:
+        // stop it before joining (as awaitMockServeResult does on timeout) so the test cannot hang.
+        tokenServer->stop();
+        tokenThread.join();
+        FAIL() << "Failed to bind local mock well-known server: " << e.what();
+    }
+
+    std::promise<bool> wellKnownPromise;
+    auto wellKnownFuture = wellKnownPromise.get_future();
+    std::thread wellKnownThread([&wellKnownServer, &wellKnownPromise]() {
+        wellKnownPromise.set_value(wellKnownServer->mockServe());
+    });
+
+    ParamMap params;
+    params["tokenEndpointAuthMethod"] = "tls_client_auth";
+    params["issuer_url"] = "https://localhost:" + std::to_string(wellKnownServerPort);
+    params["client_id"] = "test-client";
+    // set wrong cert and key
+    params["tls_cert_file"] = TEST_CONF_DIR "/hn-verification/broker-cert.pem";
+    params["tls_key_file"] = TEST_CONF_DIR "/hn-verification/broker-key.pem";
+
+    AuthenticationDataPtr data =
+        std::static_pointer_cast<AuthenticationDataProvider>(std::make_shared<InitialAuthData>(caPath));
+    AuthenticationPtr auth = AuthOauth2::create(params);
+    const Result authResult = auth->getAuthData(data);
+
+    // Reap both server threads before any fatal assertion. An ASSERT_* that fails returns from the test
+    // body, and destroying a still-joinable std::thread calls std::terminate, aborting the whole binary.
+    const bool wellKnownServed = testOauth2Tls::awaitMockServeResult(wellKnownFuture, *wellKnownServer,
+                                                                     wellKnownThread, "Well-known mock server");
+    const bool tokenServed =
+        testOauth2Tls::awaitMockServeResult(tokenFuture, *tokenServer, tokenThread, "Token mock server");
+
+    ASSERT_EQ(authResult, ResultAuthenticationError);
+    ASSERT_TRUE(wellKnownServed);
+    ASSERT_FALSE(tokenServed);
+    ASSERT_NE(wellKnownServer->request().find("GET /.well-known/openid-configuration "), std::string::npos);
+}
+
+// Verifies the form parameters TlsClientAuthFlow puts in the token request. Local configuration keys
+// (auth method, issuer URL, TLS file paths) must not be sent. grant_type is fixed to client_credentials.
+// client_id falls back to TlsClientAuthFlow::DEFAULT_CLIENT_ID when absent, and optional keys such as
+// audience and scope are forwarded only when present.
+TEST(AuthPluginTest, testOauth2TlsClientAuthRequestBody) {
+    ParamMap params{
+        {"tokenEndpointAuthMethod", "tls_client_auth"},
+        {"issuer_url", "https://dev-kt-aa9ne.us.auth0.com"},
+        {"client_id", "Xd23RHsUnvUlP7wchjNYOaIfazgeHd9x"},
+        {"audience", "https://dev-kt-aa9ne.us.auth0.com/api/v2/"},
+        {"tls_cert_file", "/path/to/cert.pem"},
+        {"tls_key_file", "/path/to/key.pem"},
+    };
+
+    // Expected request body for the current contents of `params`.
+    const auto expectedBody = [&params] {
+        ParamMap body = params;
+        for (const char* localKey : {"tokenEndpointAuthMethod", "issuer_url", "tls_cert_file", "tls_key_file"}) {
+            body.erase(localKey);
+        }
+        body.emplace("grant_type", "client_credentials");
+        return body;
+    };
+
+    // All keys supplied.
+    const ParamMap fullBody = expectedBody();
+    TlsClientAuthFlow fullFlow(params);
+    ASSERT_EQ(fullFlow.generateParamMap(), fullBody);
+
+    // An extra optional key is forwarded as-is.
+    params["scope"] = "test-scope";
+    const ParamMap scopedBody = expectedBody();
+    TlsClientAuthFlow scopedFlow(params);
+    ASSERT_EQ(scopedFlow.generateParamMap(), scopedBody);
+
+    // Missing client_id is replaced by the default one.
+    params.erase("client_id");
+    ParamMap defaultIdBody = expectedBody();
+    defaultIdBody["client_id"] = TlsClientAuthFlow::DEFAULT_CLIENT_ID;
+    TlsClientAuthFlow defaultIdFlow(params);
+    ASSERT_EQ(defaultIdFlow.generateParamMap(), defaultIdBody);
+
+    // Missing audience is simply omitted.
+    params.erase("audience");
+    ParamMap noAudienceBody = expectedBody();
+    noAudienceBody["client_id"] = TlsClientAuthFlow::DEFAULT_CLIENT_ID;
+    TlsClientAuthFlow noAudienceFlow(params);
+    ASSERT_EQ(noAudienceFlow.generateParamMap(), noAudienceBody);
+}
+
+// Misconfigurations of tokenEndpointAuthMethod = "tls_client_auth" must make getAuthData() return
+// ResultAuthenticationError (rather than throw from create()).
+TEST(AuthPluginTest, testOauth2TlsClientAuthFailure) {
+    ParamMap params;
+    // Builds a fresh plugin from the current `params` and returns the result of fetching auth data.
+    auto getAuthDataResult = [&]() -> Result {
+        AuthenticationDataPtr data =
+            std::static_pointer_cast<AuthenticationDataProvider>(std::make_shared<InitialAuthData>(caPath));
+        AuthenticationPtr auth = AuthOauth2::create(params);
+        return auth->getAuthData(data);
+    };
+
+    params["tokenEndpointAuthMethod"] = "tls_client_auth";
+    params["tls_cert_file"] = clientPublicKeyPath;
+    params["tls_key_file"] = clientPrivateKeyPath;
+
+    // No issuer_url
+    params.erase("issuer_url");
+    ASSERT_EQ(getAuthDataResult(), ResultAuthenticationError);
+
+    // Invalid issuer_url
+    params["issuer_url"] = "hello";
+    ASSERT_EQ(getAuthDataResult(), ResultAuthenticationError);
+
+    // NOTE(review): no mock server is started on this port, so the cases below may fail on connection
+    // rather than on cert/key validation. Confirm the plugin checks the TLS files before connecting.
+    params["issuer_url"] = "https://localhost:58086";
+
+    // No cert and key
+    params.erase("tls_cert_file");
+    params.erase("tls_key_file");
+    ASSERT_EQ(getAuthDataResult(), ResultAuthenticationError);
+
+    // Only cert
+    params["tls_cert_file"] = clientPublicKeyPath;
+    params.erase("tls_key_file");
+    ASSERT_EQ(getAuthDataResult(), ResultAuthenticationError);
+
+    // Only key
+    params.erase("tls_cert_file");
+    params["tls_key_file"] = clientPrivateKeyPath;
+    ASSERT_EQ(getAuthDataResult(), ResultAuthenticationError);
+
+    // Invalid cert and key
+    params["tls_cert_file"] = TEST_CONF_DIR "/not-exist-cert.pem";
+    params["tls_key_file"] = TEST_CONF_DIR "/not-exist-key.pem";
+    ASSERT_EQ(getAuthDataResult(), ResultAuthenticationError);
+}
+
+// "client_secret_get" is not a recognized tokenEndpointAuthMethod, so building the plugin from this JSON
+// configuration must be rejected up front with std::invalid_argument.
+TEST(AuthPluginTest, testOauth2UnknownTokenEndpointAuthMethod) {
+    const std::string jsonParams = R"({
+        "type": "client_credentials",
+        "tokenEndpointAuthMethod": "client_secret_get",
+        "issuer_url": "https://dev-kt-aa9ne.us.auth0.com",
+        "client_id": "Xd23RHsUnvUlP7wchjNYOaIfazgeHd9x",
+        "client_secret": "rT7ps7WY8uhdVuBTKWZkttwLdQotmdEliaM5rLfmgNibvqziZ-g07ZH52N_poGAb",
+        "audience": "https://dev-kt-aa9ne.us.auth0.com/api/v2/"})";
+
+    LOG_INFO("PARAMS: " << jsonParams);
+    ASSERT_THROW(AuthOauth2::create(jsonParams), std::invalid_argument);
+}
+
TEST(AuthPluginTest, testInvalidPlugin) {
Client client("pulsar://localhost:6650",
ClientConfiguration{}.setAuth(AuthFactory::create("invalid")));
Producer producer;