kirktrue commented on a change in pull request #11284:
URL: https://github.com/apache/kafka/pull/11284#discussion_r733882914



##########
File path: 
clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/HttpAccessTokenRetriever.java
##########
@@ -0,0 +1,325 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.common.security.oauthbearer.secured;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.io.UnsupportedEncodingException;
+import java.net.HttpURLConnection;
+import java.net.URL;
+import java.net.URLEncoder;
+import java.nio.charset.StandardCharsets;
+import java.util.Base64;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import javax.net.ssl.HttpsURLConnection;
+import javax.net.ssl.SSLSocketFactory;
+import org.apache.kafka.common.config.SaslConfigs;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.common.utils.Utils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * <code>HttpAccessTokenRetriever</code> is an {@link AccessTokenRetriever} 
that will
+ * communicate with an OAuth/OIDC provider directly via HTTP to post client 
credentials
+ * ({@link OAuthBearerLoginCallbackHandler#CLIENT_ID_CONFIG}/{@link 
OAuthBearerLoginCallbackHandler#CLIENT_SECRET_CONFIG})
+ * to a publicized token endpoint URL
+ * ({@link SaslConfigs#SASL_OAUTHBEARER_TOKEN_ENDPOINT_URI}).
+ *
+ * @see AccessTokenRetriever
+ * @see OAuthBearerLoginCallbackHandler#CLIENT_ID_CONFIG
+ * @see OAuthBearerLoginCallbackHandler#CLIENT_SECRET_CONFIG
+ * @see OAuthBearerLoginCallbackHandler#SCOPE_CONFIG
+ * @see SaslConfigs#SASL_OAUTHBEARER_TOKEN_ENDPOINT_URI
+ */
+
+public class HttpAccessTokenRetriever implements AccessTokenRetriever {
+
+    private static final Logger log = 
LoggerFactory.getLogger(HttpAccessTokenRetriever.class);
+
+    private static final Set<Integer> UNRETRYABLE_HTTP_CODES;
+
+    public static final String AUTHORIZATION_HEADER = "Authorization";
+
+    static {
+        // This does not have to be an exhaustive list. There are other HTTP 
codes that
+        // are defined in different RFCs (e.g. 
https://datatracker.ietf.org/doc/html/rfc6585)
+        // that we won't worry about yet. The worst case if a status code is 
missing from
+        // this set is that the request will be retried.
+        UNRETRYABLE_HTTP_CODES = new HashSet<>();
+        UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_BAD_REQUEST);
+        UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_UNAUTHORIZED);
+        UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_PAYMENT_REQUIRED);
+        UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_FORBIDDEN);
+        UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_NOT_FOUND);
+        UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_BAD_METHOD);
+        UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_NOT_ACCEPTABLE);
+        UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_PROXY_AUTH);
+        UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_CONFLICT);
+        UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_GONE);
+        UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_LENGTH_REQUIRED);
+        UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_PRECON_FAILED);
+        UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_ENTITY_TOO_LARGE);
+        UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_REQ_TOO_LONG);
+        UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_UNSUPPORTED_TYPE);
+        UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_NOT_IMPLEMENTED);
+        UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_VERSION);
+    }
+
+    private final String clientId;
+
+    private final String clientSecret;
+
+    private final String scope;
+
+    private final SSLSocketFactory sslSocketFactory;
+
+    private final String tokenEndpointUri;
+
+    private final long loginRetryBackoffMs;
+
+    private final long loginRetryBackoffMaxMs;
+
+    private final Integer loginConnectTimeoutMs;
+
+    private final Integer loginReadTimeoutMs;
+
+    public HttpAccessTokenRetriever(String clientId,
+        String clientSecret,
+        String scope,
+        SSLSocketFactory sslSocketFactory,
+        String tokenEndpointUri,
+        long loginRetryBackoffMs,
+        long loginRetryBackoffMaxMs,
+        Integer loginConnectTimeoutMs,
+        Integer loginReadTimeoutMs) {
+        this.clientId = Objects.requireNonNull(clientId);
+        this.clientSecret = Objects.requireNonNull(clientSecret);
+        this.scope = scope;
+        this.sslSocketFactory = sslSocketFactory;
+        this.tokenEndpointUri = Objects.requireNonNull(tokenEndpointUri);
+        this.loginRetryBackoffMs = loginRetryBackoffMs;
+        this.loginRetryBackoffMaxMs = loginRetryBackoffMaxMs;
+        this.loginConnectTimeoutMs = loginConnectTimeoutMs;
+        this.loginReadTimeoutMs = loginReadTimeoutMs;
+    }
+
+    /**
+     * Retrieves a JWT access token in its serialized three-part form. The 
implementation
+     * is free to determine how it should be retrieved but should not perform 
validation
+     * on the result.
+     *
+     * <b>Note</b>: This is a blocking function and callers should be aware 
that the
+     * implementation communicates over a network. The facility in the
+     * {@link javax.security.auth.spi.LoginModule} from which this is 
ultimately called
+     * does not provide an asynchronous approach.
+     *
+     * @return Non-<code>null</code> JWT access token string
+     *
+     * @throws IOException Thrown on errors related to IO during retrieval
+     */
+
+    @Override
+    public String retrieve() throws IOException {
+        String authorizationHeader = formatAuthorizationHeader(clientId, 
clientSecret);
+        String requestBody = formatRequestBody(scope);
+
+        Retry<String> retry = new Retry<>(Time.SYSTEM,
+            loginRetryBackoffMs,
+            loginRetryBackoffMaxMs);
+
+        Map<String, String> headers = 
Collections.singletonMap(AUTHORIZATION_HEADER, authorizationHeader);
+
+        String responseBody = retry.execute(() -> {
+            HttpURLConnection con = (HttpURLConnection) new 
URL(tokenEndpointUri).openConnection();
+
+            if (sslSocketFactory != null && con instanceof HttpsURLConnection)
+                ((HttpsURLConnection) 
con).setSSLSocketFactory(sslSocketFactory);
+
+            try {
+                return post(con, headers, requestBody, loginConnectTimeoutMs, 
loginReadTimeoutMs);
+            } finally {
+                con.disconnect();
+            }
+        });
+        log.debug("retrieve - responseBody: {}", responseBody);
+
+        return parseAccessToken(responseBody);
+    }
+
+    public static String post(HttpURLConnection con,
+        Map<String, String> headers,
+        String requestBody,
+        Integer connectTimeoutMs,
+        Integer readTimeoutMs)
+        throws IOException, UnretryableException {
+        handleInput(con, headers, requestBody, connectTimeoutMs, 
readTimeoutMs);
+        return handleOutput(con);
+    }
+
+    private static void handleInput(HttpURLConnection con,
+        Map<String, String> headers,
+        String requestBody,
+        Integer connectTimeoutMs,
+        Integer readTimeoutMs)
+        throws IOException, UnretryableException {
+        log.debug("handleInput - starting post for {}", con.getURL());
+        con.setRequestMethod("POST");
+        con.setRequestProperty("Accept", "application/json");
+
+        if (headers != null) {
+            for (Map.Entry<String, String> header : headers.entrySet())
+                con.setRequestProperty(header.getKey(), header.getValue());
+        }
+
+        con.setRequestProperty("Cache-Control", "no-cache");
+
+        if (requestBody != null) {
+            con.setRequestProperty("Content-Length", 
String.valueOf(requestBody.length()));
+            con.setDoOutput(true);
+        }
+
+        con.setUseCaches(false);
+
+        if (connectTimeoutMs != null)
+            con.setConnectTimeout(connectTimeoutMs);
+
+        if (readTimeoutMs != null)
+            con.setReadTimeout(readTimeoutMs);
+
+        log.debug("handleInput - preparing to connect to {}", con.getURL());
+        con.connect();
+
+        if (requestBody != null) {
+            try (OutputStream os = con.getOutputStream()) {
+                ByteArrayInputStream is = new 
ByteArrayInputStream(requestBody.getBytes(
+                    StandardCharsets.UTF_8));
+                log.debug("handleInput - preparing to write request body to 
{}", con.getURL());
+                copy(is, os);
+            }
+        }
+    }
+
+    static String handleOutput(final HttpURLConnection con) throws IOException 
{
+        int responseCode = con.getResponseCode();
+        log.debug("handleOutput - responseCode: {}", responseCode);
+
+        String responseBody = null;
+
+        try (InputStream is = con.getInputStream()) {
+            ByteArrayOutputStream os = new ByteArrayOutputStream();
+            log.debug("handleOutput - preparing to read response body from 
{}", con.getURL());
+            copy(is, os);
+            responseBody = os.toString(StandardCharsets.UTF_8.name());
+        } catch (Exception e) {
+            log.warn("handleOutput - error retrieving data", e);
+        }
+
+        if (responseCode == HttpURLConnection.HTTP_OK || responseCode == 
HttpURLConnection.HTTP_CREATED) {
+            if (responseBody == null || responseBody.isEmpty())
+                throw new IOException(String.format("The token endpoint 
response was unexpectedly empty despite response code %s from %s", 
responseCode, con.getURL()));
+
+            log.debug("handleOutput - responseCode: {}, response: {}", 
responseCode, responseBody);
+
+            return responseBody;
+        } else {
+            log.warn("handleOutput - error response code: {}, error response 
body: {}", responseCode, responseBody);
+
+            if (UNRETRYABLE_HTTP_CODES.contains(responseCode)) {
+                // We know that this is a non-transient error, so let's not 
keep retrying the
+                // request unnecessarily.
+                throw new UnretryableException(new 
IOException(String.format("The response code %s was encountered reading the 
token endpoint response; will not attempt further retries", responseCode)));
+            } else {
+                // We don't know if this is a transient (retryable) error or 
not, so let's assume
+                // it is.
+                throw new IOException(String.format("The unexpected response 
code %s was encountered reading the token endpoint response", responseCode));
+            }
+        }
+    }
+
+    static void copy(InputStream is, OutputStream os) throws IOException {
+        byte[] buf = new byte[4096];
+        int b;
+
+        while ((b = is.read(buf)) != -1)
+            os.write(buf, 0, b);
+    }
+
+    static String parseAccessToken(String responseBody) throws IOException {
+        ObjectMapper mapper = new ObjectMapper();
+        JsonNode rootNode = mapper.readTree(responseBody);
+        JsonNode accessTokenNode = rootNode.at("/access_token");
+
+        if (accessTokenNode == null)
+            throw new IOException("The token endpoint response did not contain 
an access_token value");

Review comment:
       I was (and still am) a _little_ leery of including the response verbatim, because the response could be huge, especially if the URL provided in the configuration is incorrect.
   
   The code now logs the entire response, but truncates it to an arbitrarily chosen limit of 1000 characters in the exception message.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: jira-unsubscr...@kafka.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to