This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new d37bb4c68 [CELEBORN-1131] Add Client/Server bootstrap framework to transport layer
d37bb4c68 is described below

commit d37bb4c6824a170a62cbac20321910e8b57d186c
Author: Chandni Singh <[email protected]>
AuthorDate: Wed Nov 29 21:06:23 2023 +0800

    [CELEBORN-1131] Add Client/Server bootstrap framework to transport layer
    
    ### What changes were proposed in this pull request?
    This adds the client/server bootstrap framework to the transport layer in Celeborn. The framework is copied from Spark.
    It is part of the epic: https://issues.apache.org/jira/browse/CELEBORN-1011.
    
    ### Why are the changes needed?
    The changes are needed for adding authentication to Celeborn. See 
[CELEBORN-1011](https://issues.apache.org/jira/browse/CELEBORN-1011).
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    This patch is part of a larger change that includes tests.
    
    Closes #2120 from otterc/CELEBORN-1131-PR1.
    
    Lead-authored-by: Chandni Singh <[email protected]>
    Co-authored-by: otterc <[email protected]>
    Signed-off-by: Shuang <[email protected]>
---
 .../flink/network/FlinkTransportClientFactory.java |  3 +-
 .../celeborn/common/network/TransportContext.java  | 37 +++++++++++++++-----
 .../network/client/TransportClientBootstrap.java   | 40 ++++++++++++++++++++++
 .../network/client/TransportClientFactory.java     | 33 ++++++++++++++++--
 .../common/network/server/TransportServer.java     | 24 +++++++++++--
 .../network/server/TransportServerBootstrap.java   | 36 +++++++++++++++++++
 6 files changed, 159 insertions(+), 14 deletions(-)

diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/FlinkTransportClientFactory.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/FlinkTransportClientFactory.java
index e9e716ef3..68e418a54 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/FlinkTransportClientFactory.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/FlinkTransportClientFactory.java
@@ -18,6 +18,7 @@
 package org.apache.celeborn.plugin.flink.network;
 
 import java.io.IOException;
+import java.util.Collections;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.Supplier;
 
@@ -39,7 +40,7 @@ public class FlinkTransportClientFactory extends 
TransportClientFactory {
   private final int fetchMaxRetries;
 
   public FlinkTransportClientFactory(TransportContext context, int 
fetchMaxRetries) {
-    super(context);
+    super(context, Collections.emptyList());
     bufferSuppliers = JavaUtils.newConcurrentHashMap();
     this.fetchMaxRetries = fetchMaxRetries;
     this.pooledAllocator = new UnpooledByteBufAllocator(true);
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java 
b/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java
index 50bb5cac1..c8796ea32 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java
@@ -17,6 +17,9 @@
 
 package org.apache.celeborn.common.network;
 
+import java.util.Collections;
+import java.util.List;
+
 import io.netty.channel.Channel;
 import io.netty.channel.ChannelDuplexHandler;
 import io.netty.channel.ChannelInboundHandlerAdapter;
@@ -27,6 +30,7 @@ import org.slf4j.LoggerFactory;
 
 import org.apache.celeborn.common.metrics.source.AbstractSource;
 import org.apache.celeborn.common.network.client.TransportClient;
+import org.apache.celeborn.common.network.client.TransportClientBootstrap;
 import org.apache.celeborn.common.network.client.TransportClientFactory;
 import org.apache.celeborn.common.network.client.TransportResponseHandler;
 import org.apache.celeborn.common.network.protocol.MessageEncoder;
@@ -93,35 +97,52 @@ public class TransportContext {
     this(conf, msgHandler, false, false, null);
   }
 
+  public TransportClientFactory 
createClientFactory(List<TransportClientBootstrap> bootstraps) {
+    return new TransportClientFactory(this, bootstraps);
+  }
+
   public TransportClientFactory createClientFactory() {
-    return new TransportClientFactory(this);
+    return createClientFactory(Collections.emptyList());
   }
 
   /** Create a server which will attempt to bind to a specific host and port. 
*/
   public TransportServer createServer(String host, int port) {
-    return new TransportServer(this, host, port, source);
+    return new TransportServer(this, host, port, source, msgHandler, 
Collections.emptyList());
+  }
+
+  public TransportServer createServer(
+      String host, int port, List<TransportServerBootstrap> bootstraps) {
+    return new TransportServer(this, host, port, source, msgHandler, 
bootstraps);
   }
 
   public TransportServer createServer(int port) {
-    return createServer(null, port);
+    return createServer(null, port, Collections.emptyList());
   }
 
   /** For Suite only */
   public TransportServer createServer() {
-    return createServer(null, 0);
+    return createServer(null, 0, Collections.emptyList());
   }
 
-  public TransportChannelHandler initializePipeline(SocketChannel channel) {
-    return initializePipeline(channel, new TransportFrameDecoder());
+  public TransportChannelHandler initializePipeline(
+      SocketChannel channel, ChannelInboundHandlerAdapter decoder) {
+    return initializePipeline(channel, decoder, msgHandler);
   }
 
   public TransportChannelHandler initializePipeline(
-      SocketChannel channel, ChannelInboundHandlerAdapter decoder) {
+      SocketChannel channel, BaseMessageHandler resolvedMsgHandler) {
+    return initializePipeline(channel, new TransportFrameDecoder(), 
resolvedMsgHandler);
+  }
+
+  public TransportChannelHandler initializePipeline(
+      SocketChannel channel,
+      ChannelInboundHandlerAdapter decoder,
+      BaseMessageHandler resolvedMsgHandler) {
     try {
       if (channelsLimiter != null) {
         channel.pipeline().addLast("limiter", channelsLimiter);
       }
-      TransportChannelHandler channelHandler = createChannelHandler(channel, 
msgHandler);
+      TransportChannelHandler channelHandler = createChannelHandler(channel, 
resolvedMsgHandler);
       channel
           .pipeline()
           .addLast("encoder", ENCODER)
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientBootstrap.java
 
b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientBootstrap.java
new file mode 100644
index 000000000..bdad119fc
--- /dev/null
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientBootstrap.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.client;
+
+import io.netty.channel.Channel;
+
+/**
+ * A bootstrap which is executed on a TransportClient before it is returned to 
the user. This
+ * enables an initial exchange of information (e.g., SASL authentication 
tokens) on a once-per-
+ * connection basis.
+ *
+ * <p>Since connections (and TransportClients) are reused as much as possible, 
it is generally
+ * reasonable to perform an expensive bootstrapping operation, as they often 
share a lifespan with
+ * the JVM itself.
+ */
+public interface TransportClientBootstrap {
+  /**
+   * Performs the bootstrapping operation, throwing an exception on failure.
+   *
+   * @param client the transport client to bootstrap
+   * @param channel the associated channel with the transport client
+   * @throws RuntimeException
+   */
+  void doBootstrap(TransportClient client, Channel channel) throws 
RuntimeException;
+}
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java
 
b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java
index ba1a0dde9..a4d4d7515 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java
@@ -21,12 +21,15 @@ import java.io.Closeable;
 import java.io.IOException;
 import java.net.InetSocketAddress;
 import java.net.SocketAddress;
+import java.util.List;
 import java.util.Random;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
 
 import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
+import com.google.common.collect.Lists;
 import io.netty.bootstrap.Bootstrap;
 import io.netty.buffer.ByteBufAllocator;
 import io.netty.channel.*;
@@ -39,6 +42,7 @@ import org.apache.celeborn.common.network.TransportContext;
 import org.apache.celeborn.common.network.server.TransportChannelHandler;
 import org.apache.celeborn.common.network.util.*;
 import org.apache.celeborn.common.util.JavaUtils;
+import org.apache.celeborn.common.util.Utils;
 
 /**
  * Factory for creating {@link TransportClient}s by using createClient.
@@ -68,6 +72,7 @@ public class TransportClientFactory implements Closeable {
   private static final Logger logger = 
LoggerFactory.getLogger(TransportClientFactory.class);
 
   private final TransportContext context;
+  private final List<TransportClientBootstrap> clientBootstraps;
   private final ConcurrentHashMap<SocketAddress, ClientPool> connectionPool;
 
   /** Random number generator for picking connections between peers. */
@@ -84,9 +89,11 @@ public class TransportClientFactory implements Closeable {
   private EventLoopGroup workerGroup;
   protected ByteBufAllocator pooledAllocator;
 
-  public TransportClientFactory(TransportContext context) {
+  public TransportClientFactory(
+      TransportContext context, List<TransportClientBootstrap> 
clientBootstraps) {
     this.context = Preconditions.checkNotNull(context);
     TransportConf conf = context.getConf();
+    this.clientBootstraps = 
Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps));
     this.connectionPool = JavaUtils.newConcurrentHashMap();
     this.numConnectionsPerPeer = conf.numConnectionsPerPeer();
     this.connectTimeoutMs = conf.connectTimeoutMs();
@@ -241,6 +248,7 @@ public class TransportClientFactory implements Closeable {
         });
 
     // Connect to the remote server
+    long preConnect = System.nanoTime();
     ChannelFuture cf = bootstrap.connect(address);
     if (!cf.await(connectTimeoutMs)) {
       throw new CelebornIOException(
@@ -250,10 +258,31 @@ public class TransportClientFactory implements Closeable {
     }
 
     TransportClient client = clientRef.get();
+    Channel channel = channelRef.get();
     assert client != null : "Channel future completed successfully with null 
client";
 
+    // Execute any client bootstraps synchronously before marking the Client 
as successful.
+    long preBootstrap = System.nanoTime();
+    logger.debug("Running bootstraps for {} ...", address);
+    try {
+      for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
+        clientBootstrap.doBootstrap(client, channel);
+      }
+    } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap 
may be written in Scala
+      long bootstrapTime = System.nanoTime() - preBootstrap;
+      logger.error(
+          "Exception while bootstrapping client after {}",
+          Utils.nanoDurationToString(bootstrapTime),
+          e);
+      client.close();
+      throw Throwables.propagate(e);
+    }
+    long postBootstrap = System.nanoTime();
     logger.debug(
-        "Connection from {} to {} successful", 
client.getChannel().localAddress(), address);
+        "Successfully created connection to {} after {} ({} spent in 
bootstraps)",
+        address,
+        Utils.nanoDurationToString(postBootstrap - preConnect),
+        Utils.nanoDurationToString(postBootstrap - preBootstrap));
 
     return client;
   }
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/server/TransportServer.java
 
b/common/src/main/java/org/apache/celeborn/common/network/server/TransportServer.java
index 73bb38bac..53f7b7169 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/server/TransportServer.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/server/TransportServer.java
@@ -19,8 +19,11 @@ package org.apache.celeborn.common.network.server;
 
 import java.io.Closeable;
 import java.net.InetSocketAddress;
+import java.util.List;
 import java.util.concurrent.TimeUnit;
 
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
 import io.netty.bootstrap.ServerBootstrap;
 import io.netty.buffer.PooledByteBufAllocator;
 import io.netty.channel.ChannelFuture;
@@ -44,17 +47,25 @@ public class TransportServer implements Closeable {
 
   private final TransportContext context;
   private final TransportConf conf;
-
+  private final BaseMessageHandler appMessageHandler;
+  private final List<TransportServerBootstrap> bootstraps;
   private ServerBootstrap bootstrap;
   private ChannelFuture channelFuture;
   private AbstractSource source;
   private int port = -1;
 
   public TransportServer(
-      TransportContext context, String hostToBind, int portToBind, 
AbstractSource source) {
+      TransportContext context,
+      String hostToBind,
+      int portToBind,
+      AbstractSource source,
+      BaseMessageHandler appMessageHandler,
+      List<TransportServerBootstrap> bootstraps) {
     this.context = context;
     this.conf = context.getConf();
     this.source = source;
+    this.appMessageHandler = Preconditions.checkNotNull(appMessageHandler);
+    this.bootstraps = 
Lists.newArrayList(Preconditions.checkNotNull(bootstraps));
 
     boolean shouldClose = true;
     try {
@@ -124,7 +135,14 @@ public class TransportServer implements Closeable {
         new ChannelInitializer<SocketChannel>() {
           @Override
           protected void initChannel(SocketChannel ch) {
-            context.initializePipeline(ch);
+            BaseMessageHandler baseHandler = appMessageHandler;
+            logger.debug("number of bootstraps {}", bootstraps.size());
+            for (TransportServerBootstrap bootstrap : bootstraps) {
+              logger.debug(
+                  "Adding bootstrap to TransportServer {}.", 
bootstrap.getClass().getName());
+              baseHandler = bootstrap.doBootstrap(ch, baseHandler);
+            }
+            context.initializePipeline(ch, baseHandler);
           }
         });
   }
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/server/TransportServerBootstrap.java
 
b/common/src/main/java/org/apache/celeborn/common/network/server/TransportServerBootstrap.java
new file mode 100644
index 000000000..cc367a05f
--- /dev/null
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/server/TransportServerBootstrap.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.server;
+
+import io.netty.channel.Channel;
+
+/**
+ * A bootstrap which is executed on a TransportServer's client channel once a 
client connects to the
+ * server. This allows customizing the client channel to allow for things such 
as SASL
+ * authentication.
+ */
+public interface TransportServerBootstrap {
+  /**
+   * Customizes the channel to include new features, if needed.
+   *
+   * @param channel The connected channel opened by the client.
+   * @param baseMessageHandler The RPC handler for the server.
+   * @return The base message handler to use for the channel.
+   */
+  BaseMessageHandler doBootstrap(Channel channel, BaseMessageHandler 
baseMessageHandler);
+}

Reply via email to