This is an automated email from the ASF dual-hosted git repository.

BewareMyPower pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/pulsar-client-cpp.git


The following commit(s) were added to refs/heads/main by this push:
     new 3ea3ff1  [improve][client] Implement tls_client_auth for AuthOauth2  (#575)
3ea3ff1 is described below

commit 3ea3ff19ca3c57b0cf185e30e8d115c48437da91
Author: Hideaki Oguni <[email protected]>
AuthorDate: Tue May 26 21:25:52 2026 +0900

    [improve][client] Implement tls_client_auth for AuthOauth2  (#575)
---
 include/pulsar/Authentication.h |  20 ++-
 lib/auth/AuthOauth2.cc          | 352 +++++++++++++++++++++++++-----------
 lib/auth/AuthOauth2.h           |  30 ++++
 tests/AuthPluginTest.cc         | 382 +++++++++++++++++++++++++++++++++++++++-
 4 files changed, 673 insertions(+), 111 deletions(-)

diff --git a/include/pulsar/Authentication.h b/include/pulsar/Authentication.h
index 34b70cd..a6a02b8 100644
--- a/include/pulsar/Authentication.h
+++ b/include/pulsar/Authentication.h
@@ -515,11 +515,22 @@ typedef std::shared_ptr<CachedToken> CachedTokenPtr;
  * Passed in parameter would be like:
  * ```
  *   "type": "client_credentials",
+ *   "tokenEndpointAuthMethod": "client_secret_post",
  *   "issuer_url": "https://accounts.google.com",
  *   "client_id": "d9ZyX97q1ef8Cr81WHVC4hFQ64vSlDK3",
  *   "client_secret": "on1uJ...k6F6R",
  *   "audience": "https://broker.example.com"
  *  ```
+ *
+ * For `tokenEndpointAuthMethod = "tls_client_auth"`:
+ * ```
+ *   "type": "client_credentials",
+ *   "tokenEndpointAuthMethod": "tls_client_auth",
+ *   "issuer_url": "https://accounts.google.com",
+ *   "client_id": "d9ZyX97q1ef8Cr81WHVC4hFQ64vSlDK3",
+ *   "tls_cert_file": "/path/to/cert.pem",
+ *   "tls_key_file": "/path/to/key.pem"
+ * ```
  *  If passed in as std::string, it should be in Json format.
  */
 class PULSAR_PUBLIC AuthOauth2 : public Authentication {
@@ -530,7 +541,14 @@ class PULSAR_PUBLIC AuthOauth2 : public Authentication {
     /**
      * Create an AuthOauth2 with a ParamMap
      *
-     * The required parameter keys are “issuer_url”, “private_key”, and “audience”
+     * For `tokenEndpointAuthMethod = "client_secret_post"` (default), the required parameter
+     * keys are “issuer_url”, “private_key”, and “audience”.
+     * Optional keys: `scope`, `tls_cert_file`, `tls_key_file`.
+     *
+     * For `tokenEndpointAuthMethod = "tls_client_auth"`, the required parameter keys are
+     * `issuer_url`, `tls_cert_file`, and `tls_key_file`.
+     * Optional keys: `client_id`, `audience`, `scope`. If `client_id` is omitted, the client
+     * uses `pulsar-client`.
      *
      * @param parameters the key-value to create OAuth 2.0 client credentials
      * @see http://pulsar.apache.org/docs/en/security-oauth2/#client-credentials
diff --git a/lib/auth/AuthOauth2.cc b/lib/auth/AuthOauth2.cc
index 9573496..94d9fc6 100644
--- a/lib/auth/AuthOauth2.cc
+++ b/lib/auth/AuthOauth2.cc
@@ -20,6 +20,7 @@
 
 #include <boost/property_tree/json_parser.hpp>
 #include <boost/property_tree/ptree.hpp>
+#include <cstdint>
 #include <sstream>
 #include <stdexcept>
 
@@ -31,6 +32,36 @@ DECLARE_LOG_OBJECT()
 
 namespace pulsar {
 
+const std::string TlsClientAuthFlow::DEFAULT_CLIENT_ID = "pulsar-client";
+namespace {
+enum class OAuth2TokenEndpointAuthMethod : std::uint8_t
+{
+    ClientSecretPost,
+    TlsClientAuth,
+    Unknown,
+};
+
+OAuth2TokenEndpointAuthMethod parseTokenEndpointAuthMethod(const std::string& 
authMethod) {
+    if (authMethod == "tls_client_auth") {
+        return OAuth2TokenEndpointAuthMethod::TlsClientAuth;
+    }
+    if (authMethod == "client_secret_post") {
+        return OAuth2TokenEndpointAuthMethod::ClientSecretPost;
+    }
+    return OAuth2TokenEndpointAuthMethod::Unknown;
+}
+
+std::string toFlowName(OAuth2TokenEndpointAuthMethod authMethod) {
+    switch (authMethod) {
+        case OAuth2TokenEndpointAuthMethod::TlsClientAuth:
+            return "TlsClientAuthFlow";
+        case OAuth2TokenEndpointAuthMethod::ClientSecretPost:
+        default:
+            return "ClientCredentialFlow";
+    }
+}
+}  // namespace
+
 // AuthDataOauth2
 
 AuthDataOauth2::AuthDataOauth2(const std::string& accessToken) { accessToken_ 
= accessToken; }
@@ -111,6 +142,8 @@ bool Oauth2CachedToken::isExpired() { return expiresAt_ < 
Clock::now(); }
 Oauth2Flow::Oauth2Flow() {}
 Oauth2Flow::~Oauth2Flow() {}
 
+static std::string buildClientCredentialsBody(CurlWrapper& curl, const 
ParamMap& params);
+
 KeyFile KeyFile::fromParamMap(ParamMap& params) {
     const auto it = params.find("private_key");
     if (it == params.cend()) {
@@ -199,80 +232,186 @@ KeyFile KeyFile::fromBase64(const std::string& encoded) {
     }
 }
 
-ClientCredentialFlow::ClientCredentialFlow(ParamMap& params)
-    : issuerUrl_(params["issuer_url"]),
-      keyFile_(KeyFile::fromParamMap(params)),
-      audience_(params["audience"]),
-      scope_(params["scope"]) {}
+static std::string getWellKnownUrl(const std::string& issuerUrl) {
+    std::string wellKnownUrl = issuerUrl;
+    if (!wellKnownUrl.empty() && wellKnownUrl.back() == '/') {
+        wellKnownUrl.pop_back();
+    }
+    wellKnownUrl.append("/.well-known/openid-configuration");
+    return wellKnownUrl;
+}
 
-std::string ClientCredentialFlow::getTokenEndPoint() const { return 
tokenEndPoint_; }
+static std::unique_ptr<CurlWrapper::TlsContext> createTlsContext(const 
std::string& tlsTrustCertsFilePath,
+                                                                 const 
std::string& tlsCertFilePath,
+                                                                 const 
std::string& tlsKeyFilePath) {
+    const bool hasTrustCerts = !tlsTrustCertsFilePath.empty();
+    const bool hasClientCertPair = !tlsCertFilePath.empty() && 
!tlsKeyFilePath.empty();
 
-void ClientCredentialFlow::initialize() {
-    if (issuerUrl_.empty()) {
-        LOG_ERROR("Failed to initialize ClientCredentialFlow: issuer_url is 
not set");
-        return;
+    if (!tlsCertFilePath.empty() != !tlsKeyFilePath.empty()) {
+        LOG_WARN("Ignore incomplete mTLS settings: both tls_cert_file and 
tls_key_file are required");
     }
-    if (!keyFile_.isValid()) {
-        return;
+    if (!hasTrustCerts && !hasClientCertPair) {
+        return nullptr;
     }
 
-    // set URL: well-know endpoint
-    std::string wellKnownUrl = issuerUrl_;
-    if (wellKnownUrl.back() == '/') {
-        wellKnownUrl.pop_back();
+    auto tlsContext = std::unique_ptr<CurlWrapper::TlsContext>(new 
CurlWrapper::TlsContext);
+    if (hasTrustCerts) {
+        tlsContext->trustCertsFilePath = tlsTrustCertsFilePath;
     }
-    wellKnownUrl.append("/.well-known/openid-configuration");
+    if (hasClientCertPair) {
+        tlsContext->certPath = tlsCertFilePath;
+        tlsContext->keyPath = tlsKeyFilePath;
+    }
+    return tlsContext;
+}
 
+static std::string fetchTokenEndpoint(const std::string& issuerUrl,
+                                      const CurlWrapper::TlsContext* 
tlsContext) {
+    const auto wellKnownUrl = getWellKnownUrl(issuerUrl);
     CurlWrapper curl;
     if (!curl.init()) {
         LOG_ERROR("Failed to initialize curl");
-        return;
-    }
-    std::unique_ptr<CurlWrapper::TlsContext> tlsContext;
-    if (!tlsTrustCertsFilePath_.empty()) {
-        tlsContext.reset(new CurlWrapper::TlsContext);
-        tlsContext->trustCertsFilePath = tlsTrustCertsFilePath_;
+        return "";
     }
 
-    auto result = curl.get(wellKnownUrl, "Accept: application/json", {}, 
tlsContext.get());
+    auto result = curl.get(wellKnownUrl, "Accept: application/json", {}, 
tlsContext);
     if (!result.error.empty()) {
-        LOG_ERROR("Failed to get the well-known configuration " << issuerUrl_ 
<< ": " << result.error);
-        return;
+        LOG_ERROR("Failed to get the well-known configuration " << issuerUrl 
<< ": " << result.error);
+        return "";
     }
 
     const auto res = result.code;
-    const auto response_code = result.responseCode;
+    const auto responseCode = result.responseCode;
     const auto& responseData = result.responseData;
     const auto& errorBuffer = result.serverError;
 
     switch (res) {
         case CURLE_OK:
-            LOG_DEBUG("Received well-known configuration data " << issuerUrl_ 
<< " code " << response_code);
-            if (response_code == 200) {
+            LOG_DEBUG("Received well-known configuration data " << issuerUrl 
<< " code " << responseCode);
+            if (responseCode == 200) {
                 boost::property_tree::ptree root;
                 std::stringstream stream;
                 stream << responseData;
                 try {
                     boost::property_tree::read_json(stream, root);
+                    return root.get<std::string>("token_endpoint");
                 } catch (boost::property_tree::json_parser_error& e) {
                     LOG_ERROR("Failed to parse well-known configuration data 
response: "
                               << e.what() << "\nInput Json = " << 
responseData);
+                    return "";
+                }
+            } else {
+                LOG_ERROR("Response failed for getting the well-known 
configuration "
+                          << issuerUrl << ". response Code " << responseCode);
+            }
+            break;
+        default:
+            LOG_ERROR("Response failed for getting the well-known 
configuration "
+                      << issuerUrl << ". Error Code " << res << ": " << 
errorBuffer);
+            break;
+    }
+    return "";
+}
+
+static Oauth2TokenResultPtr fetchOauth2Token(const std::string& tokenEndpoint, 
const ParamMap& params,
+                                             const CurlWrapper::TlsContext* 
tlsContext,
+                                             OAuth2TokenEndpointAuthMethod 
authMethod) {
+    Oauth2TokenResultPtr resultPtr = Oauth2TokenResultPtr(new 
Oauth2TokenResult());
+    if (tokenEndpoint.empty()) {
+        return resultPtr;
+    }
+
+    CurlWrapper curl;
+    if (!curl.init()) {
+        LOG_ERROR("Failed to initialize curl");
+        return resultPtr;
+    }
+
+    auto postData = buildClientCredentialsBody(curl, params);
+    if (postData.empty()) {
+        return resultPtr;
+    }
+    LOG_DEBUG("Generate URL encoded body for " << toFlowName(authMethod) << ": 
" << postData);
+
+    CurlWrapper::Options options;
+    options.postFields = std::move(postData);
+    auto result =
+        curl.get(tokenEndpoint, "Content-Type: 
application/x-www-form-urlencoded", options, tlsContext);
+    if (!result.error.empty()) {
+        LOG_ERROR("Failed to fetch OAuth2 token from " << tokenEndpoint << ": 
" << result.error);
+        return resultPtr;
+    }
+
+    const auto res = result.code;
+    const auto responseCode = result.responseCode;
+    const auto& responseData = result.responseData;
+    const auto& errorBuffer = result.serverError;
+
+    switch (res) {
+        case CURLE_OK:
+            LOG_DEBUG("Response received for token endpoint " << tokenEndpoint 
<< " code " << responseCode);
+            if (responseCode == 200) {
+                boost::property_tree::ptree root;
+                std::stringstream stream;
+                stream << responseData;
+                try {
+                    boost::property_tree::read_json(stream, root);
+                } catch (boost::property_tree::json_parser_error& e) {
+                    LOG_ERROR("Failed to parse json of Oauth2 response: " << 
e.what() << "\nInput Json = "
+                                                                          << 
responseData);
                     break;
                 }
 
-                this->tokenEndPoint_ = root.get<std::string>("token_endpoint");
+                
resultPtr->setAccessToken(root.get<std::string>("access_token", ""));
+                resultPtr->setExpiresIn(
+                    root.get<uint32_t>("expires_in", 
Oauth2TokenResult::undefined_expiration));
+                
resultPtr->setRefreshToken(root.get<std::string>("refresh_token", ""));
+                resultPtr->setIdToken(root.get<std::string>("id_token", ""));
 
-                LOG_DEBUG("Get token endpoint: " << this->tokenEndPoint_);
+                if (!resultPtr->getAccessToken().empty()) {
+                    LOG_DEBUG("access_token: " << resultPtr->getAccessToken()
+                                               << " expires_in: " << 
resultPtr->getExpiresIn());
+                } else {
+                    LOG_ERROR("Response doesn't contain access_token, the 
response is: " << responseData);
+                }
             } else {
-                LOG_ERROR("Response failed for getting the well-known 
configuration "
-                          << issuerUrl_ << ". response Code " << 
response_code);
+                LOG_ERROR("Response failed for token endpoint " << 
tokenEndpoint << ". response Code "
+                                                                << 
responseCode);
             }
             break;
         default:
-            LOG_ERROR("Response failed for getting the well-known 
configuration "
-                      << issuerUrl_ << ". Error Code " << res << ": " << 
errorBuffer);
+            LOG_ERROR("Response failed for token endpoint " << tokenEndpoint 
<< ". ErrorCode " << res << ": "
+                                                            << errorBuffer);
             break;
     }
+
+    return resultPtr;
+}
+
+ClientCredentialFlow::ClientCredentialFlow(ParamMap& params)
+    : issuerUrl_(params["issuer_url"]),
+      keyFile_(KeyFile::fromParamMap(params)),
+      audience_(params["audience"]),
+      scope_(params["scope"]),
+      tlsCertFilePath_(params["tls_cert_file"]),
+      tlsKeyFilePath_(params["tls_key_file"]) {}
+
+std::string ClientCredentialFlow::getTokenEndPoint() const { return 
tokenEndPoint_; }
+
+void ClientCredentialFlow::initialize() {
+    if (issuerUrl_.empty()) {
+        LOG_ERROR("Failed to initialize ClientCredentialFlow: issuer_url is 
not set");
+        return;
+    }
+    if (!keyFile_.isValid()) {
+        return;
+    }
+
+    const auto tlsContext = createTlsContext(tlsTrustCertsFilePath_, 
tlsCertFilePath_, tlsKeyFilePath_);
+    this->tokenEndPoint_ = fetchTokenEndpoint(issuerUrl_, tlsContext.get());
+    if (!this->tokenEndPoint_.empty()) {
+        LOG_DEBUG("Get token endpoint: " << this->tokenEndPoint_);
+    }
 }
 void ClientCredentialFlow::close() {}
 
@@ -324,85 +463,90 @@ static std::string 
buildClientCredentialsBody(CurlWrapper& curl, const ParamMap&
 
 Oauth2TokenResultPtr ClientCredentialFlow::authenticate() {
     std::call_once(initializeOnce_, &ClientCredentialFlow::initialize, this);
-    Oauth2TokenResultPtr resultPtr = Oauth2TokenResultPtr(new 
Oauth2TokenResult());
-    if (tokenEndPoint_.empty()) {
-        return resultPtr;
+    const auto params = generateParamMap();
+    const auto tlsContext = createTlsContext(tlsTrustCertsFilePath_, 
tlsCertFilePath_, tlsKeyFilePath_);
+    return fetchOauth2Token(tokenEndPoint_, params, tlsContext.get(),
+                            OAuth2TokenEndpointAuthMethod::ClientSecretPost);
+}
+
+TlsClientAuthFlow::TlsClientAuthFlow(ParamMap& params)
+    : issuerUrl_(params["issuer_url"]),
+      clientId_(params["client_id"].empty() ? DEFAULT_CLIENT_ID : 
params["client_id"]),
+      audience_(params["audience"]),
+      scope_(params["scope"]),
+      tlsCertFilePath_(params["tls_cert_file"]),
+      tlsKeyFilePath_(params["tls_key_file"]) {}
+
+std::string TlsClientAuthFlow::getTokenEndPoint() const { return 
tokenEndPoint_; }
+
+void TlsClientAuthFlow::initialize() {
+    if (issuerUrl_.empty()) {
+        LOG_ERROR("Failed to initialize TlsClientAuthFlow: issuer_url is not 
set");
+        return;
+    }
+    if (tlsCertFilePath_.empty() || tlsKeyFilePath_.empty()) {
+        LOG_ERROR("Failed to initialize TlsClientAuthFlow: tls_cert_file or 
tls_key_file is not set");
+        return;
     }
 
-    CurlWrapper curl;
-    if (!curl.init()) {
-        LOG_ERROR("Failed to initialize curl");
-        return resultPtr;
+    const auto tlsContext = createTlsContext(tlsTrustCertsFilePath_, 
tlsCertFilePath_, tlsKeyFilePath_);
+    if (!tlsContext || tlsContext->certPath.empty() || 
tlsContext->keyPath.empty()) {
+        LOG_ERROR("Failed to initialize TlsClientAuthFlow: tls_cert_file or 
tls_key_file is not set");
+        return;
     }
-    auto postData = buildClientCredentialsBody(curl, generateParamMap());
-    if (postData.empty()) {
-        return resultPtr;
+    this->tokenEndPoint_ = fetchTokenEndpoint(issuerUrl_, tlsContext.get());
+    if (!this->tokenEndPoint_.empty()) {
+        LOG_DEBUG("Get token endpoint: " << this->tokenEndPoint_);
     }
-    LOG_DEBUG("Generate URL encoded body for ClientCredentialFlow: " << 
postData);
+}
+void TlsClientAuthFlow::close() {}
 
-    CurlWrapper::Options options;
-    options.postFields = std::move(postData);
-    std::unique_ptr<CurlWrapper::TlsContext> tlsContext;
-    if (!tlsTrustCertsFilePath_.empty()) {
-        tlsContext.reset(new CurlWrapper::TlsContext);
-        tlsContext->trustCertsFilePath = tlsTrustCertsFilePath_;
+ParamMap TlsClientAuthFlow::generateParamMap() const {
+    ParamMap params;
+    params.emplace("grant_type", "client_credentials");
+    params.emplace("client_id", clientId_);
+    if (!audience_.empty()) {
+        params.emplace("audience", audience_);
     }
-    auto result = curl.get(tokenEndPoint_, "Content-Type: 
application/x-www-form-urlencoded", options,
-                           tlsContext.get());
-    if (!result.error.empty()) {
-        LOG_ERROR("Failed to get the well-known configuration " << issuerUrl_ 
<< ": " << result.error);
-        return resultPtr;
+    if (!scope_.empty()) {
+        params.emplace("scope", scope_);
     }
-    const auto res = result.code;
-    const auto response_code = result.responseCode;
-    const auto& responseData = result.responseData;
-    const auto& errorBuffer = result.serverError;
+    return params;
+}
 
-    switch (res) {
-        case CURLE_OK:
-            LOG_DEBUG("Response received for issuerurl " << issuerUrl_ << " 
code " << response_code);
-            if (response_code == 200) {
-                boost::property_tree::ptree root;
-                std::stringstream stream;
-                stream << responseData;
-                try {
-                    boost::property_tree::read_json(stream, root);
-                } catch (boost::property_tree::json_parser_error& e) {
-                    LOG_ERROR("Failed to parse json of Oauth2 response: "
-                              << e.what() << "\nInput Json = " << responseData 
<< " passedin: " << postData);
-                    break;
-                }
+Oauth2TokenResultPtr TlsClientAuthFlow::authenticate() {
+    std::call_once(initializeOnce_, &TlsClientAuthFlow::initialize, this);
+    const auto params = generateParamMap();
+    const auto tlsContext = createTlsContext(tlsTrustCertsFilePath_, 
tlsCertFilePath_, tlsKeyFilePath_);
+    if (!tlsContext || tlsContext->certPath.empty() || 
tlsContext->keyPath.empty()) {
+        Oauth2TokenResultPtr resultPtr = Oauth2TokenResultPtr(new 
Oauth2TokenResult());
+        return resultPtr;
+    }
+    return fetchOauth2Token(tokenEndPoint_, params, tlsContext.get(),
+                            OAuth2TokenEndpointAuthMethod::TlsClientAuth);
+}
 
-                
resultPtr->setAccessToken(root.get<std::string>("access_token", ""));
-                resultPtr->setExpiresIn(
-                    root.get<uint32_t>("expires_in", 
Oauth2TokenResult::undefined_expiration));
-                
resultPtr->setRefreshToken(root.get<std::string>("refresh_token", ""));
-                resultPtr->setIdToken(root.get<std::string>("id_token", ""));
+// AuthOauth2
 
-                if (!resultPtr->getAccessToken().empty()) {
-                    LOG_DEBUG("access_token: " << resultPtr->getAccessToken()
-                                               << " expires_in: " << 
resultPtr->getExpiresIn());
-                } else {
-                    LOG_ERROR("Response doesn't contain access_token, the 
response is: " << responseData);
-                }
-            } else {
-                LOG_ERROR("Response failed for issuerurl " << issuerUrl_ << ". 
response Code "
-                                                           << response_code << 
" passedin: " << postData);
-            }
+AuthOauth2::AuthOauth2(ParamMap& params) {
+    std::string tokenEndpointAuthMethodName = 
params["tokenEndpointAuthMethod"];
+    if (tokenEndpointAuthMethodName.empty()) {
+        tokenEndpointAuthMethodName = "client_secret_post";
+    }
+    const auto tokenEndpointAuthMethod = 
parseTokenEndpointAuthMethod(tokenEndpointAuthMethodName);
+    switch (tokenEndpointAuthMethod) {
+        case OAuth2TokenEndpointAuthMethod::TlsClientAuth:
+            flowPtr_ = FlowPtr(new TlsClientAuthFlow(params));
             break;
-        default:
-            LOG_ERROR("Response failed for issuerurl " << issuerUrl_ << ". 
ErrorCode " << res << ": "
-                                                       << errorBuffer << " 
passedin: " << postData);
+        case OAuth2TokenEndpointAuthMethod::ClientSecretPost:
+            flowPtr_ = FlowPtr(new ClientCredentialFlow(params));
             break;
+        case OAuth2TokenEndpointAuthMethod::Unknown:
+        default:
+            throw std::invalid_argument("Unknown tokenEndpointAuthMethod: " + 
tokenEndpointAuthMethodName);
     }
-
-    return resultPtr;
 }
 
-// AuthOauth2
-
-AuthOauth2::AuthOauth2(ParamMap& params) : flowPtr_(new 
ClientCredentialFlow(params)) {}
-
 AuthOauth2::~AuthOauth2() {}
 
 ParamMap parseJsonAuthParamsString(const std::string& authParamsString) {
@@ -436,11 +580,13 @@ const std::string AuthOauth2::getAuthMethodName() const { 
return "token"; }
 Result AuthOauth2::getAuthData(AuthenticationDataPtr& authDataContent) {
     auto initialAuthData = 
std::dynamic_pointer_cast<InitialAuthData>(authDataContent);
     if (initialAuthData) {
-        auto flowPtr = 
std::dynamic_pointer_cast<ClientCredentialFlow>(flowPtr_);
-        if (!flowPtr_) {
-            throw std::invalid_argument("AuthOauth2::flowPtr_ is not a 
ClientCredentialFlow");
+        if (auto clientCredentialFlow = 
std::dynamic_pointer_cast<ClientCredentialFlow>(flowPtr_)) {
+            
clientCredentialFlow->setTlsTrustCertsFilePath(initialAuthData->tlsTrustCertsFilePath_);
+        } else if (auto tlsClientAuthFlow = 
std::dynamic_pointer_cast<TlsClientAuthFlow>(flowPtr_)) {
+            
tlsClientAuthFlow->setTlsTrustCertsFilePath(initialAuthData->tlsTrustCertsFilePath_);
+        } else {
+            throw std::invalid_argument("AuthOauth2::flowPtr_ is not an OAuth2 
flow implementation");
         }
-        
flowPtr->setTlsTrustCertsFilePath(initialAuthData->tlsTrustCertsFilePath_);
     }
 
     if (cachedTokenPtr_ == nullptr || cachedTokenPtr_->isExpired()) {
diff --git a/lib/auth/AuthOauth2.h b/lib/auth/AuthOauth2.h
index 035ad08..b402f37 100644
--- a/lib/auth/AuthOauth2.h
+++ b/lib/auth/AuthOauth2.h
@@ -71,6 +71,36 @@ class ClientCredentialFlow : public Oauth2Flow {
     const KeyFile keyFile_;
     const std::string audience_;
     const std::string scope_;
+    const std::string tlsCertFilePath_;
+    const std::string tlsKeyFilePath_;
+    std::string tlsTrustCertsFilePath_;
+    std::once_flag initializeOnce_;
+};
+
+class TlsClientAuthFlow : public Oauth2Flow {
+   public:
+    static const std::string DEFAULT_CLIENT_ID;
+
+    TlsClientAuthFlow(ParamMap& params);
+    void initialize();
+    Oauth2TokenResultPtr authenticate();
+    void close();
+
+    ParamMap generateParamMap() const;
+    std::string getTokenEndPoint() const;
+
+    void setTlsTrustCertsFilePath(const std::string& tlsTrustCertsFilePath) {
+        tlsTrustCertsFilePath_ = tlsTrustCertsFilePath;
+    }
+
+   private:
+    std::string tokenEndPoint_;
+    const std::string issuerUrl_;
+    const std::string clientId_;
+    const std::string audience_;
+    const std::string scope_;
+    const std::string tlsCertFilePath_;
+    const std::string tlsKeyFilePath_;
     std::string tlsTrustCertsFilePath_;
     std::once_flag initializeOnce_;
 };
diff --git a/tests/AuthPluginTest.cc b/tests/AuthPluginTest.cc
index 6c6b898..7ab151c 100644
--- a/tests/AuthPluginTest.cc
+++ b/tests/AuthPluginTest.cc
@@ -22,11 +22,17 @@
 
 #include <array>
 #include <boost/algorithm/string.hpp>
+#include <chrono>
+#include <future>
+#include <mutex>
 #include <sstream>
+#include <stdexcept>
 #ifdef USE_ASIO
 #include <asio.hpp>
+#include <asio/ssl.hpp>
 #else
 #include <boost/asio.hpp>
+#include <boost/asio/ssl.hpp>
 #endif
 #include <thread>
 
@@ -36,6 +42,7 @@
 #include "lib/LogUtils.h"
 #include "lib/Utils.h"
 #include "lib/auth/AuthOauth2.h"
+#include "lib/auth/InitialAuthData.h"
 DECLARE_LOG_OBJECT()
 
 using namespace pulsar;
@@ -59,6 +66,8 @@ static const std::string mimServiceUrlTls = 
"pulsar+ssl://localhost:6653";
 static const std::string mimServiceUrlHttps = "https://localhost:8444";;
 
 static const std::string mimCaPath = TEST_CONF_DIR 
"/hn-verification/cacert.pem";
+static const std::string brokerPublicKeyPath = TEST_CONF_DIR 
"/broker-cert.pem";
+static const std::string brokerPrivateKeyPath = TEST_CONF_DIR 
"/broker-key.pem";
 
 static void sendCallBackTls(Result r, const MessageId& msgId) {
     ASSERT_EQ(r, ResultOk);
@@ -324,14 +333,12 @@ static std::vector<std::string> split(const std::string& 
s, char separator) {
     return tokens;
 }
 
-namespace testAthenz {
-std::string principalToken;
-
 // ASIO::ip::tcp::iostream could call a virtual function during destruction, 
so the clang-tidy will fail by
 // clang-analyzer-optin.cplusplus.VirtualCall. Here we write a simple stream 
to read lines from socket.
+template <typename Stream>
 class SocketStream {
    public:
-    SocketStream(ASIO::ip::tcp::socket& socket) : socket_(socket) {}
+    explicit SocketStream(Stream& stream) : stream_(stream) {}
 
     bool getline(std::string& line) {
         auto pos = buffer_.find('\n', bufferPos_);
@@ -343,7 +350,7 @@ class SocketStream {
 
         std::array<char, 1024> buffer;
         ASIO_ERROR error;
-        auto length = socket_.read_some(ASIO::buffer(buffer.data(), 
buffer.size()), error);
+        auto length = stream_.read_some(ASIO::buffer(buffer.data(), 
buffer.size()), error);
         if (error == ASIO::error::eof) {
             return false;
         } else if (error) {
@@ -362,12 +369,29 @@ class SocketStream {
         return true;
     }
 
+    bool readBytes(size_t size, std::string& out) {
+        while (buffer_.size() - bufferPos_ < size) {
+            std::array<char, 1024> buffer;
+            ASIO_ERROR error;
+            auto length = stream_.read_some(ASIO::buffer(buffer.data(), 
buffer.size()), error);
+            if (error == ASIO::error::eof) return false;
+            if (error) return false;
+            buffer_.append(buffer.data(), length);
+        }
+        out.assign(buffer_.data() + bufferPos_, size);
+        bufferPos_ += size;
+        return true;
+    }
+
    private:
-    ASIO::ip::tcp::socket& socket_;
+    Stream& stream_;
     std::string buffer_;
     size_t bufferPos_{0};
 };
 
+namespace testAthenz {
+std::string principalToken;
+
 void mockZTS(Latch& latch, int port) {
     LOG_INFO("-- MockZTS started");
     ASIO::io_context io;
@@ -380,7 +404,7 @@ void mockZTS(Latch& latch, int port) {
     LOG_INFO("-- MockZTS got connection");
 
     std::string headerLine;
-    SocketStream stream(socket);
+    SocketStream<ASIO::ip::tcp::socket> stream(socket);
     while (stream.getline(headerLine)) {
         if (headerLine.empty()) {
             continue;
@@ -518,6 +542,142 @@ TEST(AuthPluginTest, testAuthFactoryAthenz) {
     }
 }
 
+namespace testOauth2Tls {
+static const auto mockServerTimeout = std::chrono::seconds(10);
+
+class MockOauth2Server {
+   public:
+    MockOauth2Server(const std::string& responseBody, const std::string& 
responseContentType, int listenPort,
+                     bool requireClientCert = true)
+        : responseBody_(responseBody),
+          responseContentType_(responseContentType),
+          acceptor_(io_, ASIO::ip::tcp::endpoint(ASIO::ip::tcp::v4(), 
static_cast<uint16_t>(listenPort))),
+          sslCtx_(ASIO::ssl::context::sslv23) {
+        sslCtx_.set_options(ASIO::ssl::context::default_workarounds | 
ASIO::ssl::context::no_sslv2 |
+                            ASIO::ssl::context::no_sslv3);
+        sslCtx_.use_certificate_chain_file(brokerPublicKeyPath);
+        sslCtx_.use_private_key_file(brokerPrivateKeyPath, 
ASIO::ssl::context::pem);
+        sslCtx_.load_verify_file(caPath);
+        sslCtx_.set_verify_mode(requireClientCert
+                                    ? (ASIO::ssl::verify_peer | 
ASIO::ssl::verify_fail_if_no_peer_cert)
+                                    : ASIO::ssl::verify_none);
+    }
+
+    const std::string& request() const { return request_; }
+
+    bool mockServe() {
+        ASIO_ERROR error;
+        auto socket = std::make_shared<ASIO::ip::tcp::socket>(io_);
+        {
+            std::lock_guard<std::mutex> lock(mutex_);
+            activeSocket_ = socket;
+        }
+
+        acceptor_.accept(*socket, error);
+        if (error) {
+            clearActiveSocket();
+            return false;
+        }
+
+        ASIO::ssl::stream<ASIO::ip::tcp::socket&> sslStream(*socket, sslCtx_);
+        sslStream.handshake(ASIO::ssl::stream_base::server, error);
+        if (error || !readRequest(sslStream)) {
+            clearActiveSocket();
+            return false;
+        }
+
+        const std::string response = "HTTP/1.1 200 OK\r\nContent-Type: " + 
responseContentType_ +
+                                     "\r\nContent-Length: " + 
std::to_string(responseBody_.size()) +
+                                     "\r\nConnection: close\r\n\r\n" + 
responseBody_;
+        ASIO::write(sslStream, ASIO::buffer(response.data(), response.size()), 
error);
+        clearActiveSocket();
+        if (error) return false;
+        return true;
+    }
+
+    void stop() {
+        ASIO_ERROR error;
+        {
+            std::lock_guard<std::mutex> lock(mutex_);
+            if (acceptor_.is_open()) {
+                acceptor_.close(error);
+            }
+            if (activeSocket_ && activeSocket_->is_open()) {
+                activeSocket_->cancel(error);
+                activeSocket_->shutdown(ASIO::ip::tcp::socket::shutdown_both, 
error);
+                activeSocket_->close(error);
+            }
+        }
+        io_.stop();
+    }
+
+   private:
+    void clearActiveSocket() {
+        std::lock_guard<std::mutex> lock(mutex_);
+        activeSocket_.reset();
+    }
+
+    bool readRequest(ASIO::ssl::stream<ASIO::ip::tcp::socket&>& sslStream) {
+        SocketStream<ASIO::ssl::stream<ASIO::ip::tcp::socket&>> 
stream(sslStream);
+        request_.clear();
+        int contentLength = 0;
+        const std::string prefix = "Content-Length:";
+        std::string headerLine;
+        while (stream.getline(headerLine)) {
+            if (headerLine.empty()) {
+                continue;
+            }
+            request_.append(headerLine).append("\n");
+            if (headerLine.rfind(prefix, 0) == 0) {
+                contentLength = std::stoi(headerLine.substr(prefix.size()));
+            }
+            if (headerLine == "\r") {
+                break;
+            }
+        }
+        if (headerLine != "\r") return false;
+
+        if (contentLength > 0) {
+            std::string body;
+            if (!stream.readBytes(static_cast<size_t>(contentLength), body)) 
return false;
+            request_ += body;
+        }
+        return true;
+    }
+
+    const std::string responseBody_;
+    const std::string responseContentType_;
+    std::string request_;
+
+    ASIO::io_context io_;
+    ASIO::ip::tcp::acceptor acceptor_;
+    ASIO::ssl::context sslCtx_;
+    std::shared_ptr<ASIO::ip::tcp::socket> activeSocket_;
+    std::mutex mutex_;
+};
+
+static bool awaitMockServeResult(std::future<bool>& future, MockOauth2Server& 
server, std::thread& thread,
+                                 const char* serverName) {
+    if (future.wait_for(mockServerTimeout) != std::future_status::ready) {
+        server.stop();
+        if (thread.joinable()) {
+            thread.join();
+        }
+        ADD_FAILURE() << serverName << " did not complete within "
+                      << 
std::chrono::duration_cast<std::chrono::seconds>(mockServerTimeout).count()
+                      << " seconds";
+        return false;
+    }
+
+    const bool result = future.get();
+    if (thread.joinable()) {
+        thread.join();
+    }
+    return result;
+}
+
+}  // namespace testOauth2Tls
+
 TEST(AuthPluginTest, testOauth2) {
     // test success get token from oauth2 server.
     pulsar::AuthenticationDataPtr data;
@@ -584,11 +744,15 @@ TEST(AuthPluginTest, testOauth2RequestBody) {
     params["client_id"] = "Xd23RHsUnvUlP7wchjNYOaIfazgeHd9x";
     params["client_secret"] = 
"rT7ps7WY8uhdVuBTKWZkttwLdQotmdEliaM5rLfmgNibvqziZ-g07ZH52N_poGAb";
     params["audience"] = "https://dev-kt-aa9ne.us.auth0.com/api/v2/";;
+    params["tls_cert_file"] = "/path/to/cert.pem";
+    params["tls_key_file"] = "/path/to/key.pem";
 
     auto createExpectedResult = [&] {
         auto paramsCopy = params;
         paramsCopy.emplace("grant_type", "client_credentials");
         paramsCopy.erase("issuer_url");
+        paramsCopy.erase("tls_cert_file");
+        paramsCopy.erase("tls_key_file");
         return paramsCopy;
     };
 
@@ -668,6 +832,210 @@ TEST(AuthPluginTest, testOauth2Failure) {
     client5.close();
 }
 
+// tls_client_auth happy path: the client discovers the token endpoint through a local well-known
+// mock server, then requests a token from a local token mock server presenting the client cert/key.
+TEST(AuthPluginTest, testOauth2TlsClientAuth) {
+    const int tokenServerPort = 58081;
+    const int wellKnownServerPort = 58082;
+    const std::string tokenBody = R"({"access_token":"mockToken","expires_in":3600,"token_type":"Bearer"})";
+    std::unique_ptr<testOauth2Tls::MockOauth2Server> tokenServer;
+    try {
+        tokenServer =
+            std::make_unique<testOauth2Tls::MockOauth2Server>(tokenBody, "application/json", tokenServerPort);
+    } catch (const std::exception& e) {
+        FAIL() << "Failed to bind local mock token server: " << e.what();
+    }
+
+    std::promise<bool> tokenPromise;
+    auto tokenFuture = tokenPromise.get_future();
+    std::thread tokenThread(
+        [&tokenServer, &tokenPromise]() { tokenPromise.set_value(tokenServer->mockServe()); });
+
+    std::ostringstream wellKnownBody;
+    wellKnownBody << R"({"token_endpoint":"https://localhost:)" << tokenServerPort << R"(/oauth/token"})";
+    std::unique_ptr<testOauth2Tls::MockOauth2Server> wellKnownServer;
+    try {
+        wellKnownServer = std::make_unique<testOauth2Tls::MockOauth2Server>(
+            wellKnownBody.str(), "application/json", wellKnownServerPort, false);
+    } catch (const std::exception& e) {
+        // Nothing will ever connect to the token server on this path, so stop it (as
+        // awaitMockServeResult does on timeout) before joining; a bare join could hang forever.
+        tokenServer->stop();
+        tokenThread.join();
+        FAIL() << "Failed to bind local mock well-known server: " << e.what();
+    }
+
+    std::promise<bool> wellKnownPromise;
+    auto wellKnownFuture = wellKnownPromise.get_future();
+    std::thread wellKnownThread([&wellKnownServer, &wellKnownPromise]() {
+        wellKnownPromise.set_value(wellKnownServer->mockServe());
+    });
+
+    ParamMap params;
+    params["tokenEndpointAuthMethod"] = "tls_client_auth";
+    params["issuer_url"] = "https://localhost:" + std::to_string(wellKnownServerPort);
+    params["client_id"] = "test-client";
+    params["tls_cert_file"] = clientPublicKeyPath;
+    params["tls_key_file"] = clientPrivateKeyPath;
+
+    AuthenticationDataPtr data =
+        std::static_pointer_cast<AuthenticationDataProvider>(std::make_shared<InitialAuthData>(caPath));
+    AuthenticationPtr auth = AuthOauth2::create(params);
+    const Result authResult = auth->getAuthData(data);
+
+    // Collect both server results (joining their threads) before any ASSERT_* can return early:
+    // destroying a still-joinable std::thread calls std::terminate().
+    const bool wellKnownServed = testOauth2Tls::awaitMockServeResult(wellKnownFuture, *wellKnownServer,
+                                                                     wellKnownThread, "Well-known mock server");
+    const bool tokenServed =
+        testOauth2Tls::awaitMockServeResult(tokenFuture, *tokenServer, tokenThread, "Token mock server");
+
+    ASSERT_EQ(authResult, ResultOk);
+    ASSERT_TRUE(data->hasDataFromCommand());
+    ASSERT_EQ(data->getCommandData(), "mockToken");
+    ASSERT_TRUE(wellKnownServed);
+    ASSERT_TRUE(tokenServed);
+    ASSERT_NE(wellKnownServer->request().find("GET /.well-known/openid-configuration "), std::string::npos);
+    ASSERT_NE(tokenServer->request().find("POST /oauth/token "), std::string::npos);
+    ASSERT_NE(tokenServer->request().find("grant_type=client_credentials"), std::string::npos);
+}
+
+// tls_client_auth with a certificate/key pair other than the expected client one: discovery still
+// succeeds, but fetching the token must fail with ResultAuthenticationError.
+TEST(AuthPluginTest, testOauth2TlsClientAuthWrongCert) {
+    const int tokenServerPort = 58083;
+    const int wellKnownServerPort = 58084;
+    const std::string tokenBody = R"({"access_token":"mockToken","expires_in":3600,"token_type":"Bearer"})";
+
+    std::unique_ptr<testOauth2Tls::MockOauth2Server> tokenServer;
+    try {
+        tokenServer =
+            std::make_unique<testOauth2Tls::MockOauth2Server>(tokenBody, "application/json", tokenServerPort);
+    } catch (const std::exception& e) {
+        FAIL() << "Failed to bind local mock token server: " << e.what();
+    }
+
+    std::promise<bool> tokenPromise;
+    auto tokenFuture = tokenPromise.get_future();
+    std::thread tokenThread(
+        [&tokenServer, &tokenPromise]() { tokenPromise.set_value(tokenServer->mockServe()); });
+
+    std::ostringstream wellKnownBody;
+    wellKnownBody << R"({"token_endpoint":"https://localhost:)" << tokenServerPort << R"(/oauth/token"})";
+    std::unique_ptr<testOauth2Tls::MockOauth2Server> wellKnownServer;
+    try {
+        wellKnownServer = std::make_unique<testOauth2Tls::MockOauth2Server>(
+            wellKnownBody.str(), "application/json", wellKnownServerPort, false);
+    } catch (const std::exception& e) {
+        // Nothing will ever connect to the token server on this path, so stop it (as
+        // awaitMockServeResult does on timeout) before joining; a bare join could hang forever.
+        tokenServer->stop();
+        tokenThread.join();
+        FAIL() << "Failed to bind local mock well-known server: " << e.what();
+    }
+
+    std::promise<bool> wellKnownPromise;
+    auto wellKnownFuture = wellKnownPromise.get_future();
+    std::thread wellKnownThread([&wellKnownServer, &wellKnownPromise]() {
+        wellKnownPromise.set_value(wellKnownServer->mockServe());
+    });
+
+    ParamMap params;
+    params["tokenEndpointAuthMethod"] = "tls_client_auth";
+    params["issuer_url"] = "https://localhost:" + std::to_string(wellKnownServerPort);
+    params["client_id"] = "test-client";
+    // set wrong cert and key
+    params["tls_cert_file"] = TEST_CONF_DIR "/hn-verification/broker-cert.pem";
+    params["tls_key_file"] = TEST_CONF_DIR "/hn-verification/broker-key.pem";
+
+    AuthenticationDataPtr data =
+        std::static_pointer_cast<AuthenticationDataProvider>(std::make_shared<InitialAuthData>(caPath));
+    AuthenticationPtr auth = AuthOauth2::create(params);
+    const Result authResult = auth->getAuthData(data);
+
+    // Collect both server results (joining their threads) before any ASSERT_* can return early:
+    // destroying a still-joinable std::thread calls std::terminate().
+    const bool wellKnownServed = testOauth2Tls::awaitMockServeResult(wellKnownFuture, *wellKnownServer,
+                                                                     wellKnownThread, "Well-known mock server");
+    const bool tokenServed =
+        testOauth2Tls::awaitMockServeResult(tokenFuture, *tokenServer, tokenThread, "Token mock server");
+
+    ASSERT_EQ(authResult, ResultAuthenticationError);
+    ASSERT_TRUE(wellKnownServed);
+    // The token server is expected to report failure for a client presenting the wrong certificate.
+    ASSERT_FALSE(tokenServed);
+    ASSERT_NE(wellKnownServer->request().find("GET /.well-known/openid-configuration "), std::string::npos);
+}
+
+// Checks the token-request parameters produced by TlsClientAuthFlow as optional inputs are added
+// (scope) or removed (client_id, audience).
+TEST(AuthPluginTest, testOauth2TlsClientAuthRequestBody) {
+    ParamMap params;
+    params["tokenEndpointAuthMethod"] = "tls_client_auth";
+    params["issuer_url"] = "https://dev-kt-aa9ne.us.auth0.com";
+    params["client_id"] = "Xd23RHsUnvUlP7wchjNYOaIfazgeHd9x";
+    params["audience"] = "https://dev-kt-aa9ne.us.auth0.com/api/v2/";
+    params["tls_cert_file"] = "/path/to/cert.pem";
+    params["tls_key_file"] = "/path/to/key.pem";
+
+    // Expected body: the configured params plus the grant type, minus the keys that only configure
+    // the client (endpoint auth method, issuer, TLS cert/key paths).
+    auto bodyFor = [](ParamMap source) {
+        source.emplace("grant_type", "client_credentials");
+        source.erase("tokenEndpointAuthMethod");
+        source.erase("issuer_url");
+        source.erase("tls_cert_file");
+        source.erase("tls_key_file");
+        return source;
+    };
+
+    const ParamMap withAllParams = bodyFor(params);
+    TlsClientAuthFlow allParamsFlow(params);
+    ASSERT_EQ(allParamsFlow.generateParamMap(), withAllParams);
+
+    // An added scope is forwarded as-is.
+    params["scope"] = "test-scope";
+    const ParamMap withScope = bodyFor(params);
+    TlsClientAuthFlow scopeFlow(params);
+    ASSERT_EQ(scopeFlow.generateParamMap(), withScope);
+
+    // Without client_id, the flow falls back to TlsClientAuthFlow::DEFAULT_CLIENT_ID.
+    params.erase("client_id");
+    ParamMap withDefaultClientId = withScope;
+    withDefaultClientId["client_id"] = TlsClientAuthFlow::DEFAULT_CLIENT_ID;
+    TlsClientAuthFlow defaultClientIdFlow(params);
+    ASSERT_EQ(defaultClientIdFlow.generateParamMap(), withDefaultClientId);
+
+    // Without audience, no audience entry is produced.
+    params.erase("audience");
+    ParamMap withoutAudience = withDefaultClientId;
+    withoutAudience.erase("audience");
+    TlsClientAuthFlow noAudienceFlow(params);
+    ASSERT_EQ(noAudienceFlow.generateParamMap(), withoutAudience);
+}
+
+// Every misconfiguration of tls_client_auth below must make getAuthData fail with
+// ResultAuthenticationError.
+TEST(AuthPluginTest, testOauth2TlsClientAuthFailure) {
+    ParamMap params;
+
+    // Creates a fresh AuthOauth2 from the current `params` and returns the getAuthData result.
+    auto authResultFor = [&params]() -> Result {
+        AuthenticationDataPtr authData =
+            std::static_pointer_cast<AuthenticationDataProvider>(std::make_shared<InitialAuthData>(caPath));
+        AuthenticationPtr plugin = AuthOauth2::create(params);
+        return plugin->getAuthData(authData);
+    };
+
+    params["tokenEndpointAuthMethod"] = "tls_client_auth";
+    params["tls_cert_file"] = clientPublicKeyPath;
+    params["tls_key_file"] = clientPrivateKeyPath;
+
+    // Missing issuer_url.
+    params.erase("issuer_url");
+    ASSERT_EQ(authResultFor(), ResultAuthenticationError);
+
+    // Malformed issuer_url.
+    params["issuer_url"] = "hello";
+    ASSERT_EQ(authResultFor(), ResultAuthenticationError);
+
+    // Neither a certificate nor a key configured.
+    params["issuer_url"] = "https://localhost:58086";
+    params.erase("tls_cert_file");
+    params.erase("tls_key_file");
+    ASSERT_EQ(authResultFor(), ResultAuthenticationError);
+
+    // A certificate without its key.
+    params["tls_cert_file"] = clientPublicKeyPath;
+    params.erase("tls_key_file");
+    ASSERT_EQ(authResultFor(), ResultAuthenticationError);
+
+    // Certificate and key paths that point at non-existent files.
+    params["tls_cert_file"] = TEST_CONF_DIR "/not-exist-cert.pem";
+    params["tls_key_file"] = TEST_CONF_DIR "/not-exist-key.pem";
+    ASSERT_EQ(authResultFor(), ResultAuthenticationError);
+}
+
+// An otherwise-valid client_credentials JSON config whose tokenEndpointAuthMethod is not a supported
+// value must be rejected by AuthOauth2::create with std::invalid_argument. The JSON must stay
+// well-formed so the failure is attributable to the auth method, not to a parse error.
+TEST(AuthPluginTest, testOauth2UnknownTokenEndpointAuthMethod) {
+    std::string params = R"({
+        "type": "client_credentials",
+        "tokenEndpointAuthMethod": "client_secret_get",
+        "issuer_url": "https://dev-kt-aa9ne.us.auth0.com",
+        "client_id": "Xd23RHsUnvUlP7wchjNYOaIfazgeHd9x",
+        "client_secret": "rT7ps7WY8uhdVuBTKWZkttwLdQotmdEliaM5rLfmgNibvqziZ-g07ZH52N_poGAb",
+        "audience": "https://dev-kt-aa9ne.us.auth0.com/api/v2/"})";
+
+    LOG_INFO("PARAMS: " << params);
+    ASSERT_THROW(AuthOauth2::create(params), std::invalid_argument);
+}
+
 TEST(AuthPluginTest, testInvalidPlugin) {
     Client client("pulsar://localhost:6650", 
ClientConfiguration{}.setAuth(AuthFactory::create("invalid")));
     Producer producer;


Reply via email to