This is an automated email from the ASF dual-hosted git repository.
rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new d37bb4c68 [CELEBORN-1131] Add Client/Server bootstrap framework to transport layer
d37bb4c68 is described below
commit d37bb4c6824a170a62cbac20321910e8b57d186c
Author: Chandni Singh <[email protected]>
AuthorDate: Wed Nov 29 21:06:23 2023 +0800
[CELEBORN-1131] Add Client/Server bootstrap framework to transport layer
### What changes were proposed in this pull request?
This adds the client/server bootstrap framework to the transport layer in
Celeborn. The framework is copied from Apache Spark.
It is part of the epic: https://issues.apache.org/jira/browse/CELEBORN-1011.
### Why are the changes needed?
The changes are needed for adding authentication to Celeborn. See
[CELEBORN-1011](https://issues.apache.org/jira/browse/CELEBORN-1011).
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
This patch is part of a larger change that includes tests.
Closes #2120 from otterc/CELEBORN-1131-PR1.
Lead-authored-by: Chandni Singh <[email protected]>
Co-authored-by: otterc <[email protected]>
Signed-off-by: Shuang <[email protected]>
---
.../flink/network/FlinkTransportClientFactory.java | 3 +-
.../celeborn/common/network/TransportContext.java | 37 +++++++++++++++-----
.../network/client/TransportClientBootstrap.java | 40 ++++++++++++++++++++++
.../network/client/TransportClientFactory.java | 33 ++++++++++++++++--
.../common/network/server/TransportServer.java | 24 +++++++++++--
.../network/server/TransportServerBootstrap.java | 36 +++++++++++++++++++
6 files changed, 159 insertions(+), 14 deletions(-)
diff --git
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/FlinkTransportClientFactory.java
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/FlinkTransportClientFactory.java
index e9e716ef3..68e418a54 100644
---
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/FlinkTransportClientFactory.java
+++
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/FlinkTransportClientFactory.java
@@ -18,6 +18,7 @@
package org.apache.celeborn.plugin.flink.network;
import java.io.IOException;
+import java.util.Collections;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;
@@ -39,7 +40,7 @@ public class FlinkTransportClientFactory extends
TransportClientFactory {
private final int fetchMaxRetries;
public FlinkTransportClientFactory(TransportContext context, int
fetchMaxRetries) {
- super(context);
+ super(context, Collections.emptyList());
bufferSuppliers = JavaUtils.newConcurrentHashMap();
this.fetchMaxRetries = fetchMaxRetries;
this.pooledAllocator = new UnpooledByteBufAllocator(true);
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java
b/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java
index 50bb5cac1..c8796ea32 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java
@@ -17,6 +17,9 @@
package org.apache.celeborn.common.network;
+import java.util.Collections;
+import java.util.List;
+
import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelInboundHandlerAdapter;
@@ -27,6 +30,7 @@ import org.slf4j.LoggerFactory;
import org.apache.celeborn.common.metrics.source.AbstractSource;
import org.apache.celeborn.common.network.client.TransportClient;
+import org.apache.celeborn.common.network.client.TransportClientBootstrap;
import org.apache.celeborn.common.network.client.TransportClientFactory;
import org.apache.celeborn.common.network.client.TransportResponseHandler;
import org.apache.celeborn.common.network.protocol.MessageEncoder;
@@ -93,35 +97,52 @@ public class TransportContext {
this(conf, msgHandler, false, false, null);
}
+ public TransportClientFactory
createClientFactory(List<TransportClientBootstrap> bootstraps) {
+ return new TransportClientFactory(this, bootstraps);
+ }
+
public TransportClientFactory createClientFactory() {
- return new TransportClientFactory(this);
+ return createClientFactory(Collections.emptyList());
}
/** Create a server which will attempt to bind to a specific host and port. */
public TransportServer createServer(String host, int port) {
- return new TransportServer(this, host, port, source);
+ return new TransportServer(this, host, port, source, msgHandler,
Collections.emptyList());
+ }
+
+ public TransportServer createServer(
+ String host, int port, List<TransportServerBootstrap> bootstraps) {
+ return new TransportServer(this, host, port, source, msgHandler,
bootstraps);
}
public TransportServer createServer(int port) {
- return createServer(null, port);
+ return createServer(null, port, Collections.emptyList());
}
/** For Suite only */
public TransportServer createServer() {
- return createServer(null, 0);
+ return createServer(null, 0, Collections.emptyList());
}
- public TransportChannelHandler initializePipeline(SocketChannel channel) {
- return initializePipeline(channel, new TransportFrameDecoder());
+ public TransportChannelHandler initializePipeline(
+ SocketChannel channel, ChannelInboundHandlerAdapter decoder) {
+ return initializePipeline(channel, decoder, msgHandler);
}
public TransportChannelHandler initializePipeline(
- SocketChannel channel, ChannelInboundHandlerAdapter decoder) {
+ SocketChannel channel, BaseMessageHandler resolvedMsgHandler) {
+ return initializePipeline(channel, new TransportFrameDecoder(),
resolvedMsgHandler);
+ }
+
+ public TransportChannelHandler initializePipeline(
+ SocketChannel channel,
+ ChannelInboundHandlerAdapter decoder,
+ BaseMessageHandler resolvedMsgHandler) {
try {
if (channelsLimiter != null) {
channel.pipeline().addLast("limiter", channelsLimiter);
}
- TransportChannelHandler channelHandler = createChannelHandler(channel,
msgHandler);
+ TransportChannelHandler channelHandler = createChannelHandler(channel,
resolvedMsgHandler);
channel
.pipeline()
.addLast("encoder", ENCODER)
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientBootstrap.java
b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientBootstrap.java
new file mode 100644
index 000000000..bdad119fc
--- /dev/null
+++
b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientBootstrap.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.client;
+
+import io.netty.channel.Channel;
+
+/**
+ * A bootstrap which is executed on a TransportClient before it is returned to the user. This
+ * enables an initial exchange of information (e.g., SASL authentication tokens) on a once-per-
+ * connection basis.
+ *
+ * <p>Since connections (and TransportClients) are reused as much as possible, it is generally
+ * reasonable to perform an expensive bootstrapping operation, as they often share a lifespan with
+ * the JVM itself.
+ */
+public interface TransportClientBootstrap {
+ /**
+ * Performs the bootstrapping operation, throwing an exception on failure.
+ *
+ * @param client the transport client to bootstrap
+ * @param channel the associated channel with the transport client
+ * @throws RuntimeException
+ */
+ void doBootstrap(TransportClient client, Channel channel) throws
RuntimeException;
+}
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java
b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java
index ba1a0dde9..a4d4d7515 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java
@@ -21,12 +21,15 @@ import java.io.Closeable;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
+import java.util.List;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
+import com.google.common.collect.Lists;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.*;
@@ -39,6 +42,7 @@ import org.apache.celeborn.common.network.TransportContext;
import org.apache.celeborn.common.network.server.TransportChannelHandler;
import org.apache.celeborn.common.network.util.*;
import org.apache.celeborn.common.util.JavaUtils;
+import org.apache.celeborn.common.util.Utils;
/**
* Factory for creating {@link TransportClient}s by using createClient.
@@ -68,6 +72,7 @@ public class TransportClientFactory implements Closeable {
private static final Logger logger =
LoggerFactory.getLogger(TransportClientFactory.class);
private final TransportContext context;
+ private final List<TransportClientBootstrap> clientBootstraps;
private final ConcurrentHashMap<SocketAddress, ClientPool> connectionPool;
/** Random number generator for picking connections between peers. */
@@ -84,9 +89,11 @@ public class TransportClientFactory implements Closeable {
private EventLoopGroup workerGroup;
protected ByteBufAllocator pooledAllocator;
- public TransportClientFactory(TransportContext context) {
+ public TransportClientFactory(
+ TransportContext context, List<TransportClientBootstrap>
clientBootstraps) {
this.context = Preconditions.checkNotNull(context);
TransportConf conf = context.getConf();
+ this.clientBootstraps =
Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps));
this.connectionPool = JavaUtils.newConcurrentHashMap();
this.numConnectionsPerPeer = conf.numConnectionsPerPeer();
this.connectTimeoutMs = conf.connectTimeoutMs();
@@ -241,6 +248,7 @@ public class TransportClientFactory implements Closeable {
});
// Connect to the remote server
+ long preConnect = System.nanoTime();
ChannelFuture cf = bootstrap.connect(address);
if (!cf.await(connectTimeoutMs)) {
throw new CelebornIOException(
@@ -250,10 +258,31 @@ public class TransportClientFactory implements Closeable {
}
TransportClient client = clientRef.get();
+ Channel channel = channelRef.get();
assert client != null : "Channel future completed successfully with null
client";
+ // Execute any client bootstraps synchronously before marking the Client as successful.
+ long preBootstrap = System.nanoTime();
+ logger.debug("Running bootstraps for {} ...", address);
+ try {
+ for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
+ clientBootstrap.doBootstrap(client, channel);
+ }
+ } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap
may be written in Scala
+ long bootstrapTime = System.nanoTime() - preBootstrap;
+ logger.error(
+ "Exception while bootstrapping client after {}",
+ Utils.nanoDurationToString(bootstrapTime),
+ e);
+ client.close();
+ throw Throwables.propagate(e);
+ }
+ long postBootstrap = System.nanoTime();
logger.debug(
- "Connection from {} to {} successful",
client.getChannel().localAddress(), address);
+ "Successfully created connection to {} after {} ({} spent in
bootstraps)",
+ address,
+ Utils.nanoDurationToString(postBootstrap - preConnect),
+ Utils.nanoDurationToString(postBootstrap - preBootstrap));
return client;
}
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/server/TransportServer.java
b/common/src/main/java/org/apache/celeborn/common/network/server/TransportServer.java
index 73bb38bac..53f7b7169 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/server/TransportServer.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/server/TransportServer.java
@@ -19,8 +19,11 @@ package org.apache.celeborn.common.network.server;
import java.io.Closeable;
import java.net.InetSocketAddress;
+import java.util.List;
import java.util.concurrent.TimeUnit;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.ChannelFuture;
@@ -44,17 +47,25 @@ public class TransportServer implements Closeable {
private final TransportContext context;
private final TransportConf conf;
-
+ private final BaseMessageHandler appMessageHandler;
+ private final List<TransportServerBootstrap> bootstraps;
private ServerBootstrap bootstrap;
private ChannelFuture channelFuture;
private AbstractSource source;
private int port = -1;
public TransportServer(
- TransportContext context, String hostToBind, int portToBind,
AbstractSource source) {
+ TransportContext context,
+ String hostToBind,
+ int portToBind,
+ AbstractSource source,
+ BaseMessageHandler appMessageHandler,
+ List<TransportServerBootstrap> bootstraps) {
this.context = context;
this.conf = context.getConf();
this.source = source;
+ this.appMessageHandler = Preconditions.checkNotNull(appMessageHandler);
+ this.bootstraps =
Lists.newArrayList(Preconditions.checkNotNull(bootstraps));
boolean shouldClose = true;
try {
@@ -124,7 +135,14 @@ public class TransportServer implements Closeable {
new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel ch) {
- context.initializePipeline(ch);
+ BaseMessageHandler baseHandler = appMessageHandler;
+ logger.debug("number of bootstraps {}", bootstraps.size());
+ for (TransportServerBootstrap bootstrap : bootstraps) {
+ logger.debug(
+ "Adding bootstrap to TransportServer {}.",
bootstrap.getClass().getName());
+ baseHandler = bootstrap.doBootstrap(ch, baseHandler);
+ }
+ context.initializePipeline(ch, baseHandler);
}
});
}
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/server/TransportServerBootstrap.java
b/common/src/main/java/org/apache/celeborn/common/network/server/TransportServerBootstrap.java
new file mode 100644
index 000000000..cc367a05f
--- /dev/null
+++
b/common/src/main/java/org/apache/celeborn/common/network/server/TransportServerBootstrap.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.server;
+
+import io.netty.channel.Channel;
+
+/**
+ * A bootstrap which is executed on a TransportServer's client channel once a client connects to the
+ * server. This allows customizing the client channel to allow for things such as SASL
+ * authentication.
+ */
+public interface TransportServerBootstrap {
+ /**
+ * Customizes the channel to include new features, if needed.
+ *
+ * @param channel The connected channel opened by the client.
+ * @param baseMessageHandler The RPC handler for the server.
+ * @return The base message handler to use for the channel.
+ */
+ BaseMessageHandler doBootstrap(Channel channel, BaseMessageHandler
baseMessageHandler);
+}