This is an automated email from the ASF dual-hosted git repository.

exceptionfactory pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/nifi.git


The following commit(s) were added to refs/heads/main by this push:
     new 9098c013f2 NIFI-12300 Add OAuth2 Support to RestLookupService (#8462)
9098c013f2 is described below

commit 9098c013f2dbce7a73b5562598a7da0cda808489
Author: Greg Foreman <gfore...@spinnerconsulting.com>
AuthorDate: Fri Apr 26 00:50:42 2024 -0400

    NIFI-12300 Add OAuth2 Support to RestLookupService (#8462)
    
    Signed-off-by: David Handermann <exceptionfact...@apache.org>
---
 .../nifi-lookup-services/pom.xml                   |  4 ++
 .../apache/nifi/lookup/AuthenticationStrategy.java | 51 +++++++++++++++++
 .../org/apache/nifi/lookup/RestLookupService.java  | 65 +++++++++++++++++++---
 .../apache/nifi/lookup/TestRestLookupService.java  | 32 +++++++++++
 4 files changed, 143 insertions(+), 9 deletions(-)

diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-lookup-services-bundle/nifi-lookup-services/pom.xml b/nifi-nar-bundles/nifi-standard-services/nifi-lookup-services-bundle/nifi-lookup-services/pom.xml
index 3926ae32a6..f4d43f0621 100644
--- a/nifi-nar-bundles/nifi-standard-services/nifi-lookup-services-bundle/nifi-lookup-services/pom.xml
+++ b/nifi-nar-bundles/nifi-standard-services/nifi-lookup-services-bundle/nifi-lookup-services/pom.xml
@@ -148,6 +148,10 @@
             <artifactId>nifi-schema-registry-service-api</artifactId>
             <scope>test</scope>
         </dependency>
+        <dependency>
+            <groupId>org.apache.nifi</groupId>
+            <artifactId>nifi-oauth2-provider-api</artifactId>
+        </dependency>
     </dependencies>
     <build>
         <plugins>
diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-lookup-services-bundle/nifi-lookup-services/src/main/java/org/apache/nifi/lookup/AuthenticationStrategy.java b/nifi-nar-bundles/nifi-standard-services/nifi-lookup-services-bundle/nifi-lookup-services/src/main/java/org/apache/nifi/lookup/AuthenticationStrategy.java
new file mode 100644
index 0000000000..854c47dd2c
--- /dev/null
+++ b/nifi-nar-bundles/nifi-standard-services/nifi-lookup-services-bundle/nifi-lookup-services/src/main/java/org/apache/nifi/lookup/AuthenticationStrategy.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.nifi.lookup;
+
+import org.apache.nifi.components.DescribedValue;
+
+public enum AuthenticationStrategy implements DescribedValue {
+
+    NONE("None","No Authentication"),
+    BASIC("Basic", "Basic Authentication"),
+    OAUTH2("OAuth2", "OAuth2 Authentication");
+
+    private final String displayName;
+    private final String description;
+
+    AuthenticationStrategy(final String displayName, final String description) {
+        this.displayName = displayName;
+        this.description = description;
+    }
+
+    @Override
+    public String getValue() {
+        return name();
+    }
+
+    @Override
+    public String getDisplayName() {
+        return displayName;
+    }
+
+    @Override
+    public String getDescription() {
+        return description;
+    }
+
+}
diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-lookup-services-bundle/nifi-lookup-services/src/main/java/org/apache/nifi/lookup/RestLookupService.java b/nifi-nar-bundles/nifi-standard-services/nifi-lookup-services-bundle/nifi-lookup-services/src/main/java/org/apache/nifi/lookup/RestLookupService.java
index aa3d0864b2..a319b1c761 100644
--- a/nifi-nar-bundles/nifi-standard-services/nifi-lookup-services-bundle/nifi-lookup-services/src/main/java/org/apache/nifi/lookup/RestLookupService.java
+++ b/nifi-nar-bundles/nifi-standard-services/nifi-lookup-services-bundle/nifi-lookup-services/src/main/java/org/apache/nifi/lookup/RestLookupService.java
@@ -41,6 +41,8 @@ import org.apache.nifi.components.Validator;
 import org.apache.nifi.controller.AbstractControllerService;
 import org.apache.nifi.controller.ConfigurationContext;
 import org.apache.nifi.expression.ExpressionLanguageScope;
+import org.apache.nifi.migration.PropertyConfiguration;
+import org.apache.nifi.oauth2.OAuth2AccessTokenProvider;
 import org.apache.nifi.processor.util.StandardValidators;
 import org.apache.nifi.proxy.ProxyConfiguration;
 import org.apache.nifi.proxy.ProxyConfigurationService;
@@ -125,11 +127,30 @@ public class RestLookupService extends AbstractControllerService implements Reco
         .identifiesControllerService(SSLContextService.class)
         .build();
 
+    public static final PropertyDescriptor AUTHENTICATION_STRATEGY = new PropertyDescriptor.Builder()
+            .name("rest-lookup-authentication-strategy")
+            .displayName("Authentication Strategy")
+            .description("Authentication strategy to use with REST service.")
+            .required(true)
+            .allowableValues(AuthenticationStrategy.class)
+            .defaultValue(AuthenticationStrategy.NONE)
+            .build();
+
+    public static final PropertyDescriptor OAUTH2_ACCESS_TOKEN_PROVIDER = new PropertyDescriptor.Builder()
+        .name("rest-lookup-oauth2-access-token-provider")
+        .displayName("OAuth2 Access Token Provider")
+        .description("Enables managed retrieval of OAuth2 Bearer Token applied to HTTP requests using the Authorization Header.")
+        .identifiesControllerService(OAuth2AccessTokenProvider.class)
+        .required(true)
+        .dependsOn(AUTHENTICATION_STRATEGY, AuthenticationStrategy.OAUTH2)
+        .build();
+
     public static final PropertyDescriptor PROP_BASIC_AUTH_USERNAME = new PropertyDescriptor.Builder()
         .name("rest-lookup-basic-auth-username")
         .displayName("Basic Authentication Username")
         .description("The username to be used by the client to authenticate against the Remote URL.  Cannot include control characters (0-31), ':', or DEL (127).")
         .required(false)
+        .dependsOn(AUTHENTICATION_STRATEGY, AuthenticationStrategy.BASIC)
         .expressionLanguageSupported(ExpressionLanguageScope.ENVIRONMENT)
         .addValidator(StandardValidators.createRegexMatchingValidator(Pattern.compile("^[\\x20-\\x39\\x3b-\\x7e\\x80-\\xff]+$")))
         .build();
@@ -139,6 +160,7 @@ public class RestLookupService extends AbstractControllerService implements Reco
         .displayName("Basic Authentication Password")
         .description("The password to be used by the client to authenticate against the Remote URL.")
         .required(false)
+        .dependsOn(AUTHENTICATION_STRATEGY, AuthenticationStrategy.BASIC)
         .sensitive(true)
         .expressionLanguageSupported(ExpressionLanguageScope.ENVIRONMENT)
         .addValidator(StandardValidators.createRegexMatchingValidator(Pattern.compile("^[\\x20-\\x7e\\x80-\\xff]+$")))
@@ -150,6 +172,7 @@ public class RestLookupService extends AbstractControllerService implements Reco
         .description("Whether to communicate with the website using Digest Authentication. 'Basic Authentication Username' and 'Basic Authentication Password' are used "
                 + "for authentication.")
         .required(false)
+        .dependsOn(AUTHENTICATION_STRATEGY, AuthenticationStrategy.BASIC)
         .defaultValue("false")
         .allowableValues("true", "false")
         .build();
@@ -201,6 +224,8 @@ public class RestLookupService extends AbstractControllerService implements Reco
             RECORD_PATH,
             RESPONSE_HANDLING_STRATEGY,
             SSL_CONTEXT_SERVICE,
+            AUTHENTICATION_STRATEGY,
+            OAUTH2_ACCESS_TOKEN_PROVIDER,
             PROXY_CONFIGURATION_SERVICE,
             PROP_BASIC_AUTH_USERNAME,
             PROP_BASIC_AUTH_PASSWORD,
@@ -225,6 +250,7 @@ public class RestLookupService extends AbstractControllerService implements Reco
     private volatile String basicPass;
     private volatile boolean isDigest;
     private volatile ResponseHandlingStrategy responseHandlingStrategy;
+    private volatile Optional<OAuth2AccessTokenProvider> oauth2AccessTokenProviderOptional;
 
     @OnEnabled
     public void onEnabled(final ConfigurationContext context) {
@@ -232,6 +258,14 @@ public class RestLookupService extends AbstractControllerService implements Reco
         proxyConfigurationService = context.getProperty(PROXY_CONFIGURATION_SERVICE)
                 .asControllerService(ProxyConfigurationService.class);
 
+        if (context.getProperty(OAUTH2_ACCESS_TOKEN_PROVIDER).isSet()) {
+            OAuth2AccessTokenProvider oauth2AccessTokenProvider = context.getProperty(OAUTH2_ACCESS_TOKEN_PROVIDER).asControllerService(OAuth2AccessTokenProvider.class);
+            oauth2AccessTokenProvider.getAccessDetails();
+            oauth2AccessTokenProviderOptional = Optional.of(oauth2AccessTokenProvider);
+        } else {
+            oauth2AccessTokenProviderOptional = Optional.empty();
+        }
+
         OkHttpClient.Builder builder = new OkHttpClient.Builder();
 
         setAuthenticator(builder, context);
@@ -363,6 +397,13 @@ public class RestLookupService extends AbstractControllerService implements Reco
         }
     }
 
+    @Override
+    public void migrateProperties(final PropertyConfiguration config) {
+        if (config.isPropertySet(PROP_BASIC_AUTH_USERNAME)) {
+            config.setProperty(AUTHENTICATION_STRATEGY, AuthenticationStrategy.BASIC.getValue());
+        }
+    }
+
     protected void validateVerb(String method) throws LookupFailureException {
         if (!VALID_VERBS.contains(method)) {
             throw new LookupFailureException(String.format("%s is not a supported HTTP verb.", method));
@@ -444,32 +485,38 @@ public class RestLookupService extends AbstractControllerService implements Reco
             final MediaType mt = MediaType.parse(mimeType);
             requestBody = RequestBody.create(body, mt);
         }
-        Request.Builder request = new Request.Builder()
+        final Request.Builder request = new Request.Builder()
                 .url(endpoint);
         switch (method) {
             case "delete":
-                request = body != null ? request.delete(requestBody) : request.delete();
+                if (body != null) request.delete(requestBody); else request.delete();
                 break;
             case "get":
-                request = request.get();
+                request.get();
                 break;
             case "post":
-                request = request.post(requestBody);
+                request.post(requestBody);
                 break;
             case "put":
-                request = request.put(requestBody);
+                request.put(requestBody);
                 break;
         }
 
         if (headers != null) {
             for (Map.Entry<String, PropertyValue> header : headers.entrySet()) {
-                request = request.addHeader(header.getKey(), header.getValue().evaluateAttributeExpressions(context).getValue());
+                request.addHeader(header.getKey(), header.getValue().evaluateAttributeExpressions(context).getValue());
             }
         }
 
-        if (!basicUser.isEmpty() && !isDigest) {
-            String credential = Credentials.basic(basicUser, basicPass);
-            request = request.header("Authorization", credential);
+        if (!isDigest) {
+            if (!basicUser.isEmpty()) {
+                String credential = Credentials.basic(basicUser, basicPass);
+                request.header("Authorization", credential);
+            } else {
+                oauth2AccessTokenProviderOptional.ifPresent(oauth2AccessTokenProvider ->
+                    request.header("Authorization", "Bearer " + oauth2AccessTokenProvider.getAccessDetails().getAccessToken())
+                );
+            }
         }
 
         return request.build();
diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-lookup-services-bundle/nifi-lookup-services/src/test/java/org/apache/nifi/lookup/TestRestLookupService.java b/nifi-nar-bundles/nifi-standard-services/nifi-lookup-services-bundle/nifi-lookup-services/src/test/java/org/apache/nifi/lookup/TestRestLookupService.java
index 724f5b8df0..04901cc123 100644
--- a/nifi-nar-bundles/nifi-standard-services/nifi-lookup-services-bundle/nifi-lookup-services/src/test/java/org/apache/nifi/lookup/TestRestLookupService.java
+++ b/nifi-nar-bundles/nifi-standard-services/nifi-lookup-services-bundle/nifi-lookup-services/src/test/java/org/apache/nifi/lookup/TestRestLookupService.java
@@ -19,6 +19,7 @@ package org.apache.nifi.lookup;
 import okhttp3.mockwebserver.MockResponse;
 import okhttp3.mockwebserver.MockWebServer;
 import okhttp3.mockwebserver.RecordedRequest;
+import org.apache.nifi.oauth2.OAuth2AccessTokenProvider;
 import org.apache.nifi.reporting.InitializationException;
 import org.apache.nifi.serialization.RecordReader;
 import org.apache.nifi.serialization.RecordReaderFactory;
@@ -31,6 +32,7 @@ import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Timeout;
 import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Answers;
 import org.mockito.Mock;
 import org.mockito.junit.jupiter.MockitoExtension;
 
@@ -51,6 +53,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
 @Timeout(10)
@@ -197,6 +200,35 @@ class TestRestLookupService {
         assertInstanceOf(IOException.class, exception.getCause());
     }
 
+    @Test
+    void testOAuth2AuthorizationHeader() throws Exception {
+        String accessToken = "access_token";
+        String oauth2AccessTokenProviderId = "oauth2AccessTokenProviderId";
+
+        OAuth2AccessTokenProvider oauth2AccessTokenProvider = mock(OAuth2AccessTokenProvider.class, Answers.RETURNS_DEEP_STUBS);
+        when(oauth2AccessTokenProvider.getIdentifier()).thenReturn(oauth2AccessTokenProviderId);
+        when(oauth2AccessTokenProvider.getAccessDetails().getAccessToken()).thenReturn(accessToken);
+        runner.addControllerService(oauth2AccessTokenProviderId, oauth2AccessTokenProvider);
+        runner.enableControllerService(oauth2AccessTokenProvider);
+
+        runner.setProperty(RestLookupService.AUTHENTICATION_STRATEGY, AuthenticationStrategy.OAUTH2);
+        runner.setProperty(restLookupService, RestLookupService.OAUTH2_ACCESS_TOKEN_PROVIDER, oauth2AccessTokenProvider.getIdentifier());
+        runner.enableControllerService(restLookupService);
+
+        when(recordReaderFactory.createRecordReader(any(), any(), anyLong(), any())).thenReturn(recordReader);
+        when(recordReader.nextRecord()).thenReturn(record);
+        mockWebServer.enqueue(new MockResponse());
+
+        final Optional<Record> recordFound = restLookupService.lookup(Collections.emptyMap());
+        assertTrue(recordFound.isPresent());
+
+        RecordedRequest recordedRequest = mockWebServer.takeRequest();
+
+        String actualAuthorizationHeader = recordedRequest.getHeader("Authorization");
+        assertEquals("Bearer " + accessToken, actualAuthorizationHeader);
+
+    }
+
     private void assertRecordedRequestFound() throws InterruptedException {
         final RecordedRequest request = mockWebServer.takeRequest();
 

Reply via email to