This is an automated email from the ASF dual-hosted git repository.

xianjingfeng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 7665a64ef [#1947] feat(server): Support authentication for 
decommission interface (#1948)
7665a64ef is described below

commit 7665a64ef1476b7b88bec58f12375f62a9f9db7e
Author: xianjingfeng <xianjingfeng...@gmail.com>
AuthorDate: Fri Jul 26 11:16:13 2024 +0800

    [#1947] feat(server): Support authentication for decommission interface 
(#1948)
    
    ### What changes were proposed in this pull request?
    Support authentication for decommission interface.
    
    ### Why are the changes needed?
    To improve security: the decommission interfaces previously required no authentication.
    Fix: #1947
    
    ### Does this PR introduce any user-facing change?
    Yes. To enable authentication, set rss.http.basic.authorizationCredentials to a non-empty string.
    
    ### How was this patch tested?
    UT
---
 .../apache/uniffle/common/config/RssBaseConf.java  | 10 ++++
 .../uniffle/common/web/resource/Authorization.java | 30 ++++++++++++
 .../web/resource/AuthorizationRequestFilter.java   | 55 +++++++++++++++++++++
 .../apache/uniffle/common/metrics/TestUtils.java   | 56 ++++++++++++---------
 .../uniffle/coordinator/CoordinatorServer.java     |  4 ++
 .../coordinator/web/resource/ServerResource.java   |  5 ++
 docs/coordinator_guide.md                          | 10 ++++
 .../java/org/apache/uniffle/test/ServletTest.java  | 57 +++++++++++++++++++---
 8 files changed, 196 insertions(+), 31 deletions(-)

diff --git a/common/src/main/java/org/apache/uniffle/common/config/RssBaseConf.java b/common/src/main/java/org/apache/uniffle/common/config/RssBaseConf.java
index f6b99269c..0ada7bfb6 100644
--- a/common/src/main/java/org/apache/uniffle/common/config/RssBaseConf.java
+++ b/common/src/main/java/org/apache/uniffle/common/config/RssBaseConf.java
@@ -286,6 +286,16 @@ public class RssBaseConf extends RssConf {
           .defaultValue(WritableSerializer.class.getName())
           .withDescription("Serializations are used for creative Serializers and Deserializers");
 
+  public static final ConfigOption<String> REST_AUTHORIZATION_CREDENTIALS =
+      ConfigOptions.key("rss.http.basic.authorizationCredentials")
+          .stringType()
+          .noDefaultValue()
+          .withDescription(
+              "Authorization credentials for the rest interface. "
+                  + "For Basic authentication the credentials are constructed by"
+                  + " first combining the username and the password with a colon (uniffle:uniffle123)"
+                  + ", and then by encoding the resulting string in base64 (dW5pZmZsZTp1bmlmZmxlMTIz).");
+
   public boolean loadConfFromFile(String fileName, List<ConfigOption<Object>> configOptions) {
     Map<String, String> properties = RssUtils.getPropertiesFromFile(fileName);
     if (properties == null) {
diff --git a/common/src/main/java/org/apache/uniffle/common/web/resource/Authorization.java b/common/src/main/java/org/apache/uniffle/common/web/resource/Authorization.java
new file mode 100644
index 000000000..314174eb2
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/web/resource/Authorization.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.web.resource;
+
+import java.lang.annotation.ElementType;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.lang.annotation.Target;
+
+import org.apache.hbase.thirdparty.javax.ws.rs.NameBinding;
+
+@NameBinding
+@Target({ElementType.METHOD, ElementType.TYPE})
+@Retention(RetentionPolicy.RUNTIME)
+public @interface Authorization {}
diff --git a/common/src/main/java/org/apache/uniffle/common/web/resource/AuthorizationRequestFilter.java b/common/src/main/java/org/apache/uniffle/common/web/resource/AuthorizationRequestFilter.java
new file mode 100644
index 000000000..1a927d674
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/web/resource/AuthorizationRequestFilter.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.web.resource;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.security.MessageDigest;
+import javax.servlet.ServletContext;
+
+import org.apache.hbase.thirdparty.javax.ws.rs.container.ContainerRequestContext;
+import org.apache.hbase.thirdparty.javax.ws.rs.container.ContainerRequestFilter;
+import org.apache.hbase.thirdparty.javax.ws.rs.core.Context;
+import org.apache.hbase.thirdparty.javax.ws.rs.core.MediaType;
+import org.apache.hbase.thirdparty.javax.ws.rs.core.Response;
+import org.apache.hbase.thirdparty.javax.ws.rs.ext.Provider;
+
+import org.apache.uniffle.common.config.RssBaseConf;
+
+/**
+ * Rejects requests to resources annotated with {@link Authorization} unless they carry the
+ * configured HTTP Basic credentials.
+ *
+ * <p>The expected credentials are read from the servlet context attribute named by {@link
+ * RssBaseConf#REST_AUTHORIZATION_CREDENTIALS}. When that attribute is absent, requests pass
+ * through unchecked.
+ */
+@Provider
+@Authorization
+public class AuthorizationRequestFilter implements ContainerRequestFilter {
+  private static final String BASIC_SCHEME_PREFIX = "Basic ";
+
+  @Context protected ServletContext servletContext;
+
+  @Override
+  public void filter(ContainerRequestContext requestContext) throws IOException {
+    Object credentials =
+        servletContext.getAttribute(RssBaseConf.REST_AUTHORIZATION_CREDENTIALS.key());
+    if (credentials == null) {
+      // No credentials configured: authorization is disabled.
+      return;
+    }
+    String authorization = requestContext.getHeaderString("Authorization");
+    if (!isAuthorized(authorization, credentials.toString())) {
+      requestContext.abortWith(
+          Response.status(Response.Status.UNAUTHORIZED)
+              // RFC 7235 requires a 401 response to carry a WWW-Authenticate challenge.
+              .header("WWW-Authenticate", "Basic realm=\"uniffle\"")
+              .entity("Authentication Failed")
+              .type(MediaType.TEXT_PLAIN)
+              .build());
+    }
+  }
+
+  /**
+   * Returns whether {@code authorization} is a Basic authorization header whose credentials
+   * exactly match {@code expectedCredentials}.
+   */
+  private static boolean isAuthorized(String authorization, String expectedCredentials) {
+    // The auth-scheme token is case-insensitive (RFC 7235, section 2.1).
+    if (authorization == null
+        || !authorization.regionMatches(
+            true, 0, BASIC_SCHEME_PREFIX, 0, BASIC_SCHEME_PREFIX.length())) {
+      return false;
+    }
+    byte[] provided =
+        authorization.substring(BASIC_SCHEME_PREFIX.length()).getBytes(StandardCharsets.UTF_8);
+    byte[] expected = expectedCredentials.getBytes(StandardCharsets.UTF_8);
+    // Constant-time comparison avoids leaking the secret through response timing.
+    return MessageDigest.isEqual(provided, expected);
+  }
+}
diff --git a/common/src/test/java/org/apache/uniffle/common/metrics/TestUtils.java b/common/src/test/java/org/apache/uniffle/common/metrics/TestUtils.java
index c08d09e65..ffe7068d4 100644
--- a/common/src/test/java/org/apache/uniffle/common/metrics/TestUtils.java
+++ b/common/src/test/java/org/apache/uniffle/common/metrics/TestUtils.java
@@ -19,10 +19,13 @@ package org.apache.uniffle.common.metrics;
 
 import java.io.BufferedReader;
 import java.io.IOException;
+import java.io.InputStream;
 import java.io.InputStreamReader;
 import java.io.OutputStream;
 import java.net.HttpURLConnection;
 import java.net.URL;
+import java.nio.charset.StandardCharsets;
+import java.util.Map;
 
 public class TestUtils {
 
@@ -30,16 +33,9 @@ public class TestUtils {
 
   public static String httpGet(String urlString) throws IOException {
     URL url = new URL(urlString);
-    HttpURLConnection con = (HttpURLConnection) url.openConnection();
-    con.setRequestMethod("GET");
-    StringBuilder content = new StringBuilder();
-    try (BufferedReader in = new BufferedReader(new InputStreamReader(con.getInputStream())); ) {
-      String inputLine;
-      while ((inputLine = in.readLine()) != null) {
-        content.append(inputLine);
-      }
-    }
-    return content.toString();
+    HttpURLConnection conn = (HttpURLConnection) url.openConnection();
+    conn.setRequestMethod("GET");
+    return getResponseStr(conn);
   }
 
   public static String httpPost(String urlString) throws IOException {
@@ -47,24 +43,46 @@ public class TestUtils {
   }
 
   public static String httpPost(String urlString, String postData) throws IOException {
+    return httpPost(urlString, postData, null);
+  }
+
+  public static String httpPost(String urlString, String postData, Map<String, String> headers)
+      throws IOException {
     URL url = new URL(urlString);
-    HttpURLConnection con = (HttpURLConnection) url.openConnection();
-    con.setDoOutput(true);
-    con.setRequestMethod("POST");
-    con.setRequestProperty("Content-type", "application/json");
-    StringBuilder content = new StringBuilder();
-    try (OutputStream outputStream = con.getOutputStream()) {
+    HttpURLConnection conn = (HttpURLConnection) url.openConnection();
+    conn.setDoOutput(true);
+    conn.setRequestMethod("POST");
+    conn.setRequestProperty("Content-type", "application/json");
+    if (headers != null) {
+      for (Map.Entry<String, String> entry : headers.entrySet()) {
+        conn.setRequestProperty(entry.getKey(), entry.getValue());
+      }
+    }
+    try (OutputStream outputStream = conn.getOutputStream()) {
       if (postData != null) {
         outputStream.write(postData.getBytes());
       }
-      try (BufferedReader in = new BufferedReader(new InputStreamReader(con.getInputStream())); ) {
-        String inputLine;
-        while ((inputLine = in.readLine()) != null) {
-          content.append(inputLine);
-        }
-      }
+      return getResponseStr(conn);
     }
+  }
 
-    return content.toString();
+  /** Reads the response body (success or error), joining its lines without separators. */
+  private static String getResponseStr(HttpURLConnection conn) throws IOException {
+    // Responses with status >= 400 expose their body only through getErrorStream().
+    InputStream inputStream =
+        conn.getResponseCode() < 400 ? conn.getInputStream() : conn.getErrorStream();
+    if (inputStream == null) {
+      // getErrorStream() returns null when the server sent no error body.
+      return "";
+    }
+    StringBuilder responseContent = new StringBuilder();
+    try (BufferedReader in =
+        new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
+      String inputLine;
+      while ((inputLine = in.readLine()) != null) {
+        responseContent.append(inputLine);
+      }
+    }
+    return responseContent.toString();
   }
 }
diff --git a/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorServer.java b/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorServer.java
index cd3a09283..baa42dabc 100644
--- a/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorServer.java
+++ b/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorServer.java
@@ -27,6 +27,7 @@ import picocli.CommandLine;
 
 import org.apache.uniffle.common.Arguments;
 import org.apache.uniffle.common.ReconfigurableConfManager;
+import org.apache.uniffle.common.config.RssBaseConf;
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.metrics.GRPCMetrics;
 import org.apache.uniffle.common.metrics.JvmMetrics;
@@ -195,6 +196,9 @@ public class CoordinatorServer {
     CoordinatorFactory coordinatorFactory = new CoordinatorFactory(this);
     server = coordinatorFactory.getServer();
     jettyServer = new JettyServer(coordinatorConf);
+    jettyServer.registerInstance(
+        RssBaseConf.REST_AUTHORIZATION_CREDENTIALS.key(),
+        coordinatorConf.getString(RssBaseConf.REST_AUTHORIZATION_CREDENTIALS));
     // register packages and instances for jersey
     jettyServer.addResourcePackages(
         "org.apache.uniffle.coordinator.web.resource", "org.apache.uniffle.common.web.resource");
diff --git a/coordinator/src/main/java/org/apache/uniffle/coordinator/web/resource/ServerResource.java b/coordinator/src/main/java/org/apache/uniffle/coordinator/web/resource/ServerResource.java
index c9741f3ce..c176f065b 100644
--- a/coordinator/src/main/java/org/apache/uniffle/coordinator/web/resource/ServerResource.java
+++ b/coordinator/src/main/java/org/apache/uniffle/coordinator/web/resource/ServerResource.java
@@ -41,6 +41,7 @@ import org.apache.hbase.thirdparty.javax.ws.rs.core.MediaType;
 import org.apache.uniffle.common.Application;
 import org.apache.uniffle.common.ServerStatus;
 import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.web.resource.Authorization;
 import org.apache.uniffle.common.web.resource.BaseResource;
 import org.apache.uniffle.common.web.resource.Response;
 import org.apache.uniffle.coordinator.ApplicationManager;
@@ -97,6 +98,7 @@ public class ServerResource extends BaseResource {
     return Response.success(serverList);
   }
 
+  @Authorization
   @POST
   @Path("/cancelDecommission")
   public Response<Object> cancelDecommission(CancelDecommissionRequest params) {
@@ -110,6 +112,7 @@ public class ServerResource extends BaseResource {
         });
   }
 
+  @Authorization
   @POST
   @Path("/{id}/cancelDecommission")
   public Response<Object> cancelDecommission(@PathParam("id") String serverId) {
@@ -120,6 +123,7 @@ public class ServerResource extends BaseResource {
         });
   }
 
+  @Authorization
   @POST
   @Path("/decommission")
   public Response<Object> decommission(DecommissionRequest params) {
@@ -133,6 +137,7 @@ public class ServerResource extends BaseResource {
         });
   }
 
+  @Authorization
   @POST
   @Path("/{id}/decommission")
   @Produces({MediaType.APPLICATION_JSON})
diff --git a/docs/coordinator_guide.md b/docs/coordinator_guide.md
index 66a19cc09..b91dcc325 100644
--- a/docs/coordinator_guide.md
+++ b/docs/coordinator_guide.md
@@ -130,6 +130,16 @@ AccessQuotaChecker is a checker when the number of concurrent tasks submitted by
 
 ## RESTful API
 
+### Enable Authorization
+The RESTful API supports Basic authentication. It is enabled by setting `rss.http.basic.authorizationCredentials` to a non-empty string.
+Once enabled, requests to protected interfaces, such as the decommission and cancel-decommission interfaces, must carry the credentials in an `Authorization: Basic <credentials>` header.
+For Basic authentication, the credentials are constructed by first combining the username and the password with
+a colon (uniffle:uniffle123), and then encoding the resulting string in base64 (dW5pZmZsZTp1bmlmZmxlMTIz).
+#### Example cURL:
+>```bash
+>curl -X POST -H 'Authorization: Basic dW5pZmZsZTp1bmlmZmxlMTIz' -H "Content-Type: application/json" http://localhost:19998/api/server/decommission -d '{"serverIds": ["127.0.0.1-19999"]}'
+>```
+
 ### Fetch single shuffle server
 
 <details>
diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/ServletTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/ServletTest.java
index 56c02d63a..2c10e6105 100644
--- a/integration-test/common/src/test/java/org/apache/uniffle/test/ServletTest.java
+++ b/integration-test/common/src/test/java/org/apache/uniffle/test/ServletTest.java
@@ -28,6 +28,7 @@ import java.util.concurrent.TimeUnit;
 
 import com.fasterxml.jackson.core.type.TypeReference;
 import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
 import org.apache.commons.collections4.CollectionUtils;
@@ -71,6 +72,9 @@ public class ServletTest extends IntegrationTestBase {
   private static final String DECOMMISSION_SINGLENODE_URL = URL_PREFIX + "server/%s/decommission";
   private static final String CANCEL_DECOMMISSION_SINGLENODE_URL =
       URL_PREFIX + "server/%s/cancelDecommission";
+  private static final String AUTHORIZATION_CREDENTIALS = "dW5pZmZsZTp1bmlmZmxlMTIz";
+  private static final Map<String, String> authorizationHeader =
+      ImmutableMap.of("Authorization", "Basic " + AUTHORIZATION_CREDENTIALS);
   private static CoordinatorServer coordinatorServer;
   private ObjectMapper objectMapper = new ObjectMapper();
 
@@ -85,6 +89,7 @@ public class ServletTest extends IntegrationTestBase {
     coordinatorConf.set(RssBaseConf.JETTY_HTTP_PORT, 12345);
     coordinatorConf.set(RssBaseConf.JETTY_CORE_POOL_SIZE, 128);
     coordinatorConf.set(RssBaseConf.RPC_SERVER_PORT, 12346);
+    coordinatorConf.set(RssBaseConf.REST_AUTHORIZATION_CREDENTIALS, AUTHORIZATION_CREDENTIALS);
     createCoordinatorServer(coordinatorConf);
 
     ShuffleServerConf shuffleServerConf = getShuffleServerConf(ServerType.GRPC);
@@ -249,7 +254,9 @@ public class ServletTest extends IntegrationTestBase {
     decommissionRequest.setServerIds(Sets.newHashSet("not_exist_serverId"));
     String content =
         TestUtils.httpPost(
-            CANCEL_DECOMMISSION_URL, objectMapper.writeValueAsString(decommissionRequest));
+            CANCEL_DECOMMISSION_URL,
+            objectMapper.writeValueAsString(decommissionRequest),
+            authorizationHeader);
     Response<?> response = objectMapper.readValue(content, Response.class);
     assertEquals(-1, response.getCode());
     assertNotNull(response.getErrMsg());
@@ -257,7 +264,9 @@ public class ServletTest extends IntegrationTestBase {
     cancelDecommissionRequest.setServerIds(Sets.newHashSet(shuffleServer.getId()));
     content =
         TestUtils.httpPost(
-            CANCEL_DECOMMISSION_URL, objectMapper.writeValueAsString(cancelDecommissionRequest));
+            CANCEL_DECOMMISSION_URL,
+            objectMapper.writeValueAsString(cancelDecommissionRequest),
+            authorizationHeader);
     response = objectMapper.readValue(content, Response.class);
     assertEquals(0, response.getCode());
 
@@ -268,7 +277,10 @@ public class ServletTest extends IntegrationTestBase {
             "testDecommissionServlet_appId", 0, Lists.newArrayList(new PartitionRange(0, 1)), ""));
     decommissionRequest.setServerIds(Sets.newHashSet(shuffleServer.getId()));
     content =
-        TestUtils.httpPost(DECOMMISSION_URL, objectMapper.writeValueAsString(decommissionRequest));
+        TestUtils.httpPost(
+            DECOMMISSION_URL,
+            objectMapper.writeValueAsString(decommissionRequest),
+            authorizationHeader);
     response = objectMapper.readValue(content, Response.class);
     assertEquals(0, response.getCode());
     assertEquals(ServerStatus.DECOMMISSIONING, shuffleServer.getServerStatus());
@@ -286,7 +298,9 @@ public class ServletTest extends IntegrationTestBase {
     // Cancel decommission.
     content =
         TestUtils.httpPost(
-            CANCEL_DECOMMISSION_URL, objectMapper.writeValueAsString(cancelDecommissionRequest));
+            CANCEL_DECOMMISSION_URL,
+            objectMapper.writeValueAsString(cancelDecommissionRequest),
+            authorizationHeader);
     response = objectMapper.readValue(content, Response.class);
     assertEquals(0, response.getCode());
     assertEquals(ServerStatus.ACTIVE, shuffleServer.getServerStatus());
@@ -297,13 +311,18 @@ public class ServletTest extends IntegrationTestBase {
     ShuffleServer shuffleServer = grpcShuffleServers.get(0);
     assertEquals(ServerStatus.ACTIVE, shuffleServer.getServerStatus());
     String content =
-        TestUtils.httpPost(String.format(CANCEL_DECOMMISSION_SINGLENODE_URL, "not_exist_serverId"));
+        TestUtils.httpPost(
+            String.format(CANCEL_DECOMMISSION_SINGLENODE_URL, "not_exist_serverId"),
+            null,
+            authorizationHeader);
     Response<?> response = objectMapper.readValue(content, Response.class);
     assertEquals(-1, response.getCode());
     assertNotNull(response.getErrMsg());
     content =
         TestUtils.httpPost(
-            String.format(CANCEL_DECOMMISSION_SINGLENODE_URL, shuffleServer.getId()));
+            String.format(CANCEL_DECOMMISSION_SINGLENODE_URL, shuffleServer.getId()),
+            null,
+            authorizationHeader);
     response = objectMapper.readValue(content, Response.class);
     assertEquals(0, response.getCode());
 
@@ -312,7 +331,11 @@ public class ServletTest extends IntegrationTestBase {
     shuffleServerClient.registerShuffle(
         new RssRegisterShuffleRequest(
             "testDecommissionServlet_appId", 0, Lists.newArrayList(new PartitionRange(0, 1)), ""));
-    content = TestUtils.httpPost(String.format(DECOMMISSION_SINGLENODE_URL, shuffleServer.getId()));
+    content =
+        TestUtils.httpPost(
+            String.format(DECOMMISSION_SINGLENODE_URL, shuffleServer.getId()),
+            null,
+            authorizationHeader);
     response = objectMapper.readValue(content, Response.class);
     assertEquals(0, response.getCode());
     assertEquals(ServerStatus.DECOMMISSIONING, shuffleServer.getServerStatus());
@@ -330,9 +353,24 @@ public class ServletTest extends IntegrationTestBase {
     // Cancel decommission.
     content =
         TestUtils.httpPost(
-            String.format(CANCEL_DECOMMISSION_SINGLENODE_URL, shuffleServer.getId()));
+            String.format(CANCEL_DECOMMISSION_SINGLENODE_URL, shuffleServer.getId()),
+            null,
+            authorizationHeader);
     response = objectMapper.readValue(content, Response.class);
     assertEquals(0, response.getCode());
     assertEquals(ServerStatus.ACTIVE, shuffleServer.getServerStatus());
   }
+
+  @Test
+  public void testRequestWithWrongCredentials() throws Exception {
+    DecommissionRequest decommissionRequest = new DecommissionRequest();
+    decommissionRequest.setServerIds(Sets.newHashSet("not_exist_serverId"));
+    String wrongCredentials = "dW5pZmZsZTp1bmlmZmxlMTIz1";
+    String content =
+        TestUtils.httpPost(
+            CANCEL_DECOMMISSION_URL,
+            objectMapper.writeValueAsString(decommissionRequest),
+            ImmutableMap.of("Authorization", "Basic " + wrongCredentials));
+    assertEquals("Authentication Failed", content);
+  }
 }

Reply via email to