kirktrue commented on a change in pull request #11284:
URL: https://github.com/apache/kafka/pull/11284#discussion_r733887473



##########
File path: clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/HttpAccessTokenRetriever.java
##########
@@ -0,0 +1,325 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.common.security.oauthbearer.secured;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.io.UnsupportedEncodingException;
+import java.net.HttpURLConnection;
+import java.net.URL;
+import java.net.URLEncoder;
+import java.nio.charset.StandardCharsets;
+import java.util.Base64;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import javax.net.ssl.HttpsURLConnection;
+import javax.net.ssl.SSLSocketFactory;
+import org.apache.kafka.common.config.SaslConfigs;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.common.utils.Utils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * <code>HttpAccessTokenRetriever</code> is an {@link AccessTokenRetriever} that will
+ * communicate with an OAuth/OIDC provider directly via HTTP to post client credentials
+ * ({@link OAuthBearerLoginCallbackHandler#CLIENT_ID_CONFIG}/{@link OAuthBearerLoginCallbackHandler#CLIENT_SECRET_CONFIG})
+ * to a published token endpoint URL
+ * ({@link SaslConfigs#SASL_OAUTHBEARER_TOKEN_ENDPOINT_URI}).
+ *
+ * @see AccessTokenRetriever
+ * @see OAuthBearerLoginCallbackHandler#CLIENT_ID_CONFIG
+ * @see OAuthBearerLoginCallbackHandler#CLIENT_SECRET_CONFIG
+ * @see OAuthBearerLoginCallbackHandler#SCOPE_CONFIG
+ * @see SaslConfigs#SASL_OAUTHBEARER_TOKEN_ENDPOINT_URI
+ */
+
+public class HttpAccessTokenRetriever implements AccessTokenRetriever {
+
+    private static final Logger log = LoggerFactory.getLogger(HttpAccessTokenRetriever.class);
+
+    /**
+     * HTTP status codes for which retrying the same token request is pointless. When one of
+     * these is returned, {@link #handleOutput(HttpURLConnection)} throws an
+     * {@link UnretryableException} instead of a plain {@link IOException}. The set is
+     * read-only after class initialization.
+     */
+    private static final Set<Integer> UNRETRYABLE_HTTP_CODES;
+
+    /** Name of the HTTP request header that carries the encoded client credentials. */
+    public static final String AUTHORIZATION_HEADER = "Authorization";
+
+    static {
+        // This does not have to be an exhaustive list. There are other HTTP codes that
+        // are defined in different RFCs (e.g. https://datatracker.ietf.org/doc/html/rfc6585)
+        // that we won't worry about yet. The worst case if a status code is missing from
+        // this set is that the request will be retried.
+        Set<Integer> codes = new HashSet<>();
+        codes.add(HttpURLConnection.HTTP_BAD_REQUEST);
+        codes.add(HttpURLConnection.HTTP_UNAUTHORIZED);
+        codes.add(HttpURLConnection.HTTP_PAYMENT_REQUIRED);
+        codes.add(HttpURLConnection.HTTP_FORBIDDEN);
+        codes.add(HttpURLConnection.HTTP_NOT_FOUND);
+        codes.add(HttpURLConnection.HTTP_BAD_METHOD);
+        codes.add(HttpURLConnection.HTTP_NOT_ACCEPTABLE);
+        codes.add(HttpURLConnection.HTTP_PROXY_AUTH);
+        codes.add(HttpURLConnection.HTTP_CONFLICT);
+        codes.add(HttpURLConnection.HTTP_GONE);
+        codes.add(HttpURLConnection.HTTP_LENGTH_REQUIRED);
+        codes.add(HttpURLConnection.HTTP_PRECON_FAILED);
+        codes.add(HttpURLConnection.HTTP_ENTITY_TOO_LARGE);
+        codes.add(HttpURLConnection.HTTP_REQ_TOO_LONG);
+        codes.add(HttpURLConnection.HTTP_UNSUPPORTED_TYPE);
+        codes.add(HttpURLConnection.HTTP_NOT_IMPLEMENTED);
+        codes.add(HttpURLConnection.HTTP_VERSION);
+
+        // Publish a read-only view so this shared constant cannot be mutated later.
+        UNRETRYABLE_HTTP_CODES = Collections.unmodifiableSet(codes);
+    }
+
+    /** OAuth client ID, sent in the HTTP Basic {@code Authorization} header; never null. */
+    private final String clientId;
+
+    /** OAuth client secret, sent in the HTTP Basic {@code Authorization} header; never null. */
+    private final String clientSecret;
+
+    /** Optional scope for the request body; null or blank means no scope parameter is sent. */
+    private final String scope;
+
+    /** Socket factory applied to HTTPS connections; null leaves the connection's default. */
+    private final SSLSocketFactory sslSocketFactory;
+
+    /** URL of the token endpoint that the client credentials are POSTed to; never null. */
+    private final String tokenEndpointUri;
+
+    /** Retry backoff in milliseconds, handed to {@link Retry}. */
+    private final long loginRetryBackoffMs;
+
+    /** Maximum retry backoff in milliseconds, handed to {@link Retry}. */
+    private final long loginRetryBackoffMaxMs;
+
+    /** Connect timeout in milliseconds per attempt; null leaves the connection's default. */
+    private final Integer loginConnectTimeoutMs;
+
+    /** Read timeout in milliseconds per attempt; null leaves the connection's default. */
+    private final Integer loginReadTimeoutMs;
+
+    /**
+     * Creates a retriever that POSTs client credentials to {@code tokenEndpointUri}.
+     *
+     * @param clientId               OAuth client ID; must not be null
+     * @param clientSecret           OAuth client secret; must not be null
+     * @param scope                  optional scope; may be null or blank
+     * @param sslSocketFactory       factory for HTTPS connections; may be null
+     * @param tokenEndpointUri       token endpoint URL; must not be null
+     * @param loginRetryBackoffMs    retry backoff in milliseconds
+     * @param loginRetryBackoffMaxMs maximum retry backoff in milliseconds
+     * @param loginConnectTimeoutMs  connect timeout in milliseconds; may be null
+     * @param loginReadTimeoutMs     read timeout in milliseconds; may be null
+     * @throws NullPointerException if {@code clientId}, {@code clientSecret} or
+     *                              {@code tokenEndpointUri} is null
+     */
+    public HttpAccessTokenRetriever(String clientId,
+        String clientSecret,
+        String scope,
+        SSLSocketFactory sslSocketFactory,
+        String tokenEndpointUri,
+        long loginRetryBackoffMs,
+        long loginRetryBackoffMaxMs,
+        Integer loginConnectTimeoutMs,
+        Integer loginReadTimeoutMs) {
+        // Name each required argument so a failure identifies which one was missing.
+        this.clientId = Objects.requireNonNull(clientId, "clientId");
+        this.clientSecret = Objects.requireNonNull(clientSecret, "clientSecret");
+        this.scope = scope;
+        this.sslSocketFactory = sslSocketFactory;
+        this.tokenEndpointUri = Objects.requireNonNull(tokenEndpointUri, "tokenEndpointUri");
+        this.loginRetryBackoffMs = loginRetryBackoffMs;
+        this.loginRetryBackoffMaxMs = loginRetryBackoffMaxMs;
+        this.loginConnectTimeoutMs = loginConnectTimeoutMs;
+        this.loginReadTimeoutMs = loginReadTimeoutMs;
+    }
+
+    /**
+     * Retrieves a JWT access token in its serialized three-part form. The implementation
+     * is free to determine how it should be retrieved but should not perform validation
+     * on the result.
+     *
+     * <p>
+     * This implementation POSTs a {@code client_credentials} grant, authenticated with an
+     * HTTP Basic {@code Authorization} header, to the configured token endpoint. Each attempt
+     * is run through {@link Retry} with the configured backoff. The token is then taken from
+     * the {@code access_token} attribute of the JSON response.
+     *
+     * <p>
+     * <b>Note</b>: This is a blocking function and callers should be aware that the
+     * implementation communicates over a network. The facility in the
+     * {@link javax.security.auth.spi.LoginModule} from which this is ultimately called
+     * does not provide an asynchronous approach.
+     *
+     * @return Non-<code>null</code> JWT access token string
+     *
+     * @throws IOException Thrown on errors related to IO during retrieval
+     */
+    @Override
+    public String retrieve() throws IOException {
+        String authorizationHeader = formatAuthorizationHeader(clientId, clientSecret);
+        String requestBody = formatRequestBody(scope);
+
+        Retry<String> retry = new Retry<>(Time.SYSTEM,
+            loginRetryBackoffMs,
+            loginRetryBackoffMaxMs);
+
+        Map<String, String> headers = Collections.singletonMap(AUTHORIZATION_HEADER, authorizationHeader);
+
+        String responseBody = retry.execute(() -> {
+            HttpURLConnection con = (HttpURLConnection) new URL(tokenEndpointUri).openConnection();
+
+            if (sslSocketFactory != null && con instanceof HttpsURLConnection)
+                ((HttpsURLConnection) con).setSSLSocketFactory(sslSocketFactory);
+
+            try {
+                return post(con, headers, requestBody, loginConnectTimeoutMs, loginReadTimeoutMs);
+            } finally {
+                // Release the connection whether the attempt succeeded or failed.
+                con.disconnect();
+            }
+        });
+
+        // The response body contains the access token itself, so it is deliberately not logged.
+        log.debug("retrieve - received token endpoint response from {}", tokenEndpointUri);
+
+        return parseAccessToken(responseBody);
+    }
+
+    /**
+     * Sends {@code requestBody} as an HTTP POST over {@code con} and returns the response body.
+     *
+     * @param con              connection to the token endpoint; request properties are set
+     *                         on it, so it must not be connected yet
+     * @param headers          extra request headers to send; may be {@code null}
+     * @param requestBody      body to POST; may be {@code null} to send no body
+     * @param connectTimeoutMs connect timeout in ms, or {@code null} to keep the default
+     * @param readTimeoutMs    read timeout in ms, or {@code null} to keep the default
+     * @return the non-empty body of a 200 or 201 response
+     * @throws IOException          on I/O failure or on a status code treated as retryable
+     * @throws UnretryableException on a status code for which retrying is pointless
+     */
+    public static String post(HttpURLConnection con,
+        Map<String, String> headers,
+        String requestBody,
+        Integer connectTimeoutMs,
+        Integer readTimeoutMs)
+        throws IOException, UnretryableException {
+        // Request phase: configure the connection, connect, and send the body.
+        handleInput(con, headers, requestBody, connectTimeoutMs, readTimeoutMs);
+
+        // Response phase: read status and body, translating failures into exceptions.
+        final String body = handleOutput(con);
+        return body;
+    }
+
+    /**
+     * Configures {@code con} as a JSON-accepting POST, connects it, and writes the request body.
+     *
+     * @param con              connection to configure; must not be connected yet
+     * @param headers          extra request headers to set; may be {@code null}
+     * @param requestBody      body to send as UTF-8; may be {@code null} to send no body
+     * @param connectTimeoutMs connect timeout in ms, or {@code null} to keep the default
+     * @param readTimeoutMs    read timeout in ms, or {@code null} to keep the default
+     * @throws IOException on failure to connect or write the body
+     */
+    private static void handleInput(HttpURLConnection con,
+        Map<String, String> headers,
+        String requestBody,
+        Integer connectTimeoutMs,
+        Integer readTimeoutMs)
+        throws IOException, UnretryableException {
+        log.debug("handleInput - starting post for {}", con.getURL());
+        con.setRequestMethod("POST");
+        con.setRequestProperty("Accept", "application/json");
+
+        if (headers != null) {
+            for (Map.Entry<String, String> header : headers.entrySet())
+                con.setRequestProperty(header.getKey(), header.getValue());
+        }
+
+        con.setRequestProperty("Cache-Control", "no-cache");
+
+        // Encode once so the advertised Content-Length matches the bytes actually written:
+        // String.length() counts UTF-16 chars, which differs from the UTF-8 byte count for
+        // non-ASCII content (e.g. a non-ASCII scope).
+        byte[] requestBodyBytes = requestBody != null ? requestBody.getBytes(StandardCharsets.UTF_8) : null;
+
+        if (requestBodyBytes != null) {
+            // NOTE(review): the JDK treats Content-Length as a restricted header and may ignore it
+            // unless sun.net.http.allowRestrictedHeaders is enabled; confirm it is still wanted.
+            con.setRequestProperty("Content-Length", String.valueOf(requestBodyBytes.length));
+            con.setDoOutput(true);
+        }
+
+        con.setUseCaches(false);
+
+        if (connectTimeoutMs != null)
+            con.setConnectTimeout(connectTimeoutMs);
+
+        if (readTimeoutMs != null)
+            con.setReadTimeout(readTimeoutMs);
+
+        log.debug("handleInput - preparing to connect to {}", con.getURL());
+        con.connect();
+
+        if (requestBodyBytes != null) {
+            try (OutputStream os = con.getOutputStream()) {
+                ByteArrayInputStream is = new ByteArrayInputStream(requestBodyBytes);
+                log.debug("handleInput - preparing to write request body to {}", con.getURL());
+                copy(is, os);
+            }
+        }
+    }
+
+    static String handleOutput(final HttpURLConnection con) throws IOException 
{
+        int responseCode = con.getResponseCode();
+        log.debug("handleOutput - responseCode: {}", responseCode);
+
+        String responseBody = null;
+
+        try (InputStream is = con.getInputStream()) {
+            ByteArrayOutputStream os = new ByteArrayOutputStream();
+            log.debug("handleOutput - preparing to read response body from 
{}", con.getURL());
+            copy(is, os);
+            responseBody = os.toString(StandardCharsets.UTF_8.name());
+        } catch (Exception e) {
+            log.warn("handleOutput - error retrieving data", e);
+        }
+
+        if (responseCode == HttpURLConnection.HTTP_OK || responseCode == 
HttpURLConnection.HTTP_CREATED) {
+            if (responseBody == null || responseBody.isEmpty())
+                throw new IOException(String.format("The token endpoint 
response was unexpectedly empty despite response code %s from %s", 
responseCode, con.getURL()));
+
+            log.debug("handleOutput - responseCode: {}, response: {}", 
responseCode, responseBody);
+
+            return responseBody;
+        } else {
+            log.warn("handleOutput - error response code: {}, error response 
body: {}", responseCode, responseBody);
+
+            if (UNRETRYABLE_HTTP_CODES.contains(responseCode)) {
+                // We know that this is a non-transient error, so let's not 
keep retrying the
+                // request unnecessarily.
+                throw new UnretryableException(new 
IOException(String.format("The response code %s was encountered reading the 
token endpoint response; will not attempt further retries", responseCode)));
+            } else {
+                // We don't know if this is a transient (retryable) error or 
not, so let's assume
+                // it is.
+                throw new IOException(String.format("The unexpected response 
code %s was encountered reading the token endpoint response", responseCode));
+            }
+        }
+    }
+
+    static void copy(InputStream is, OutputStream os) throws IOException {
+        byte[] buf = new byte[4096];
+        int b;
+
+        while ((b = is.read(buf)) != -1)
+            os.write(buf, 0, b);
+    }
+
+    /**
+     * Extracts the {@code access_token} attribute from a token endpoint JSON response.
+     *
+     * @param responseBody JSON response body from the token endpoint
+     * @return the access token string, as validated by {@code sanitizeString}
+     * @throws IOException if the body cannot be parsed, or if the attribute is absent,
+     *                     JSON {@code null}, or fails validation
+     */
+    static String parseAccessToken(String responseBody) throws IOException {
+        ObjectMapper mapper = new ObjectMapper();
+        JsonNode rootNode = mapper.readTree(responseBody);
+
+        // JsonNode.at() never returns null: an absent attribute is reported as a "missing" node
+        // and an explicit JSON null as a null node, so both must be checked explicitly.
+        // NOTE(review): depending on the Jackson version, readTree() may return null for blank
+        // input, hence the defensive root check.
+        JsonNode accessTokenNode = rootNode == null ? null : rootNode.at("/access_token");
+
+        if (accessTokenNode == null || accessTokenNode.isMissingNode() || accessTokenNode.isNull())
+            throw new IOException("The token endpoint response did not contain an access_token value");
+
+        return sanitizeString("The token endpoint response access_token", accessTokenNode.textValue());
+    }
+
+    /**
+     * Builds the value of an HTTP Basic {@code Authorization} header from client credentials.
+     *
+     * @param clientId     OAuth client ID, validated by {@code sanitizeString}
+     * @param clientSecret OAuth client secret, validated by {@code sanitizeString}
+     * @return {@code "Basic "} followed by the Base64 encoding of {@code clientId:clientSecret}
+     * @throws IOException if either credential fails validation (e.g. null or empty)
+     */
+    static String formatAuthorizationHeader(String clientId, String clientSecret) throws IOException {
+        clientId = sanitizeString("The token endpoint request clientId", clientId);
+        // Label the secret correctly so validation errors name the offending parameter.
+        clientSecret = sanitizeString("The token endpoint request clientSecret", clientSecret);
+
+        String s = String.format("%s:%s", clientId, clientSecret);
+        // RFC 7617 specifies the standard Base64 alphabet. The URL-safe encoder emits '-' and
+        // '_' in place of '+' and '/', which a compliant server would decode to different
+        // credentials.
+        String encoded = Base64.getEncoder().encodeToString(Utils.utf8(s));
+        return String.format("Basic %s", encoded);
+    }
+
+    /**
+     * Builds the form-encoded body for an OAuth {@code client_credentials} token request.
+     *
+     * @param scope optional scope; surrounding whitespace is trimmed, and a null or blank
+     *              value omits the {@code scope} parameter entirely
+     * @return {@code grant_type=client_credentials}, plus {@code &scope=<url-encoded scope>}
+     *         when a scope is given
+     * @throws IOException if UTF-8 is reported as unsupported (should never happen)
+     */
+    static String formatRequestBody(String scope) throws IOException {
+        try {
+            StringBuilder requestParameters = new StringBuilder();
+            requestParameters.append("grant_type=client_credentials");
+
+            String trimmedScope = scope == null ? null : scope.trim();
+
+            if (trimmedScope != null && !trimmedScope.isEmpty()) {
+                String encodedScope = URLEncoder.encode(trimmedScope, StandardCharsets.UTF_8.name());
+                requestParameters.append("&scope=").append(encodedScope);
+            }
+
+            return requestParameters.toString();
+        } catch (UnsupportedEncodingException e) {
+            // The world has gone crazy! Keep the original exception as the cause for diagnosis.
+            throw new IOException(String.format("Encoding %s not supported", StandardCharsets.UTF_8.name()), e);
+        }
+    }
+
+    private static String sanitizeString(String name, String value) throws 
IOException {
+        if (value == null)
+            throw new IOException(String.format("%s value must be non-null", 
name));
+
+        if (value.isEmpty())
+            throw new IOException(String.format("%s value must be non-empty", 
name));

Review comment:
       It is correct as written, though obviously confusing.
   
   I changed the format string to:
   
   ```
   The value for %s must be non-null
   ```
   
   e.g.:
   
   ```
   The value for the token endpoint request client ID parameter must be non-null
   The value for the token endpoint response's access_token JSON attribute must not contain only whitespace
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: jira-unsubscr...@kafka.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to