This is an automated email from the ASF dual-hosted git repository.

roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new bc9aaaaf [#584] feat(netty): Add transport client pool for netty (#771)
bc9aaaaf is described below

commit bc9aaaafad4a2a614064559e6f46c14191817664
Author: xumanbu <[email protected]>
AuthorDate: Mon Apr 3 11:28:56 2023 +0800

    [#584] feat(netty): Add transport client pool for netty (#771)
    
    ### What changes were proposed in this pull request?
      1. Add TransportClient, the Netty RPC client.
      2. Add TransportClientFactory, which manages a pool of connections.
      3. Add TransportContext, which holds the context needed to create a
         TransportClientFactory and sets up Netty channel pipelines with a
         TransportResponseHandler.
      4. Add TransportConf, the Netty transport configuration created from RssConf.
    
    ### Why are the changes needed?
    Fix: #584
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. It adds new client configuration options and the ability to reuse
    Netty clients.
    TODO (@xumanbu): update the user documentation once the Netty feature is
    complete.
    
    ### How was this patch tested?
    Tested locally, including the new TransportClientFactoryTest unit test.
    
    Co-authored-by: jam.xu <[email protected]>
---
 .../uniffle/common/config/RssClientConf.java       |  52 +++++
 .../uniffle/common/netty/MessageEncoder.java       |   5 +
 .../common/netty/client/RpcResponseCallback.java   |  36 +++
 .../common/netty/client/TransportClient.java       | 144 ++++++++++++
 .../netty/client/TransportClientFactory.java       | 250 +++++++++++++++++++++
 .../uniffle/common/netty/client/TransportConf.java |  64 ++++++
 .../common/netty/client/TransportContext.java      |  60 +++++
 .../netty/handle/TransportResponseHandler.java     |  72 ++++++
 .../common/netty/EncoderAndDecoderTest.java        |   4 +-
 .../netty/client/TransportClientFactoryTest.java   | 107 +++++++++
 .../netty/client/TransportClientTestBase.java      | 119 ++++++++++
 11 files changed, 911 insertions(+), 2 deletions(-)

diff --git 
a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java 
b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
index 119ab4e6..10f56bd2 100644
--- a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
+++ b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
@@ -19,6 +19,7 @@ package org.apache.uniffle.common.config;
 
 import org.apache.uniffle.common.ShuffleDataDistributionType;
 import org.apache.uniffle.common.compression.Codec;
+import org.apache.uniffle.common.netty.IOMode;
 
 import static org.apache.uniffle.common.compression.Codec.Type.LZ4;
 
@@ -43,4 +44,55 @@ public class RssClientConf {
       .defaultValue(ShuffleDataDistributionType.NORMAL)
       .withDescription("The type of partition shuffle data distribution, 
including normal and local_order. "
           + "The default value is normal. This config is only valid in 
Spark3.x");
+
+  public static final ConfigOption<Integer> NETTY_IO_CONNECT_TIMEOUT_MS = 
ConfigOptions
+      .key("rss.client.netty.io.connect.timeout.ms")
+      .intType()
+      .defaultValue(10 * 1000)
+      .withDescription("netty connect to server time out mills");
+
+  public static final ConfigOption<IOMode> NETTY_IO_MODE = ConfigOptions
+      .key("rss.client.netty.io.mode")
+      .enumType(IOMode.class)
+      .defaultValue(IOMode.NIO)
+      .withDescription("Netty EventLoopGroup backend, available options: NIO, 
EPOLL.");
+
+  public static final ConfigOption<Integer> NETTY_IO_CONNECTION_TIMEOUT_MS = 
ConfigOptions
+      .key("rss.client.netty.client.connection.timeout.ms")
+      .intType()
+      .defaultValue(10 * 60 * 1000)
+      .withDescription("connection active timeout");
+
+  public static final ConfigOption<Integer> NETTY_CLIENT_THREADS = 
ConfigOptions
+      .key("rss.client.netty.client.threads")
+      .intType()
+      .defaultValue(0)
+      .withDescription("Number of threads used in the client thread pool.");
+
+  public static final ConfigOption<Boolean> NETWORK_CLIENT_PREFER_DIRECT_BUFS 
= ConfigOptions
+      .key("rss.client.netty.client.prefer.direct.bufs")
+      .booleanType()
+      .defaultValue(true)
+      .withDescription("If true, we will prefer allocating off-heap byte 
buffers within Netty.");
+
+  public static final ConfigOption<Integer> 
NETTY_CLIENT_NUM_CONNECTIONS_PER_PEER = ConfigOptions
+      .key("rss.client.netty.client.connections.per.peer")
+      .intType()
+      .defaultValue(2)
+      .withDescription("Number of concurrent connections between two nodes.");
+
+  public static final ConfigOption<Integer> NETTY_CLIENT_RECEIVE_BUFFER = 
ConfigOptions
+      .key("rss.client.netty.client.receive.buffer")
+      .intType()
+      .defaultValue(0)
+      .withDescription("Receive buffer size (SO_RCVBUF). Note: the optimal 
size for receive buffer and send buffer "
+          + "should be latency * network_bandwidth. Assuming latency = 1ms, 
network_bandwidth = 10Gbps "
+          + "buffer size should be ~ 1.25MB.");
+
+  public static final ConfigOption<Integer> NETTY_CLIENT_SEND_BUFFER = 
ConfigOptions
+      .key("rss.client.netty.client.send.buffer")
+      .intType()
+      .defaultValue(0)
+      .withDescription("Send buffer size (SO_SNDBUF).");
+
 }
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/MessageEncoder.java 
b/common/src/main/java/org/apache/uniffle/common/netty/MessageEncoder.java
index 4167e53a..e3537ecd 100644
--- a/common/src/main/java/org/apache/uniffle/common/netty/MessageEncoder.java
+++ b/common/src/main/java/org/apache/uniffle/common/netty/MessageEncoder.java
@@ -39,6 +39,11 @@ public class MessageEncoder extends 
ChannelOutboundHandlerAdapter {
 
   private static final Logger LOG = 
LoggerFactory.getLogger(MessageEncoder.class);
 
+  public static final MessageEncoder INSTANCE = new MessageEncoder();
+
+  private MessageEncoder() {
+  }
+
   @Override
   public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise 
promise) {
     // todo: support zero copy
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/client/RpcResponseCallback.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/client/RpcResponseCallback.java
new file mode 100644
index 00000000..6de925c0
--- /dev/null
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/client/RpcResponseCallback.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.client;
+
+import org.apache.uniffle.common.netty.protocol.RpcResponse;
+
+public interface RpcResponseCallback {
+  /**
+   * Successful serialized result from server.
+   *
+   * <p>After `onSuccess` returns, `response` will be recycled and its content 
will become invalid.
+   * Please copy the content of `response` if you want to use it after 
`onSuccess` returns.
+   */
+  void onSuccess(RpcResponse rpcResponse);
+
+  /**
+   * Exception either propagated from server or raised on client side.
+   */
+  void onFailure(Throwable e);
+}
+
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClient.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClient.java
new file mode 100644
index 00000000..34ebb20a
--- /dev/null
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClient.java
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.client;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.net.SocketAddress;
+import java.util.Objects;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicLong;
+
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.util.concurrent.Future;
+import io.netty.util.concurrent.GenericFutureListener;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.netty.handle.TransportResponseHandler;
+import org.apache.uniffle.common.netty.protocol.Message;
+import org.apache.uniffle.common.util.NettyUtils;
+
+
+public class TransportClient implements Closeable {
+  private static final Logger logger = 
LoggerFactory.getLogger(TransportClient.class);
+
+  private Channel channel;
+  private TransportResponseHandler handler;
+  private volatile boolean timedOut;
+
+  private static final AtomicLong counter = new AtomicLong();
+
+  public TransportClient(Channel channel, TransportResponseHandler handler) {
+    this.channel = Objects.requireNonNull(channel);
+    this.handler = Objects.requireNonNull(handler);
+    this.timedOut = false;
+  }
+
+  public Channel getChannel() {
+    return channel;
+  }
+
+  public boolean isActive() {
+    return !timedOut && (channel.isOpen() || channel.isActive());
+  }
+
+  public SocketAddress getSocketAddress() {
+    return channel.remoteAddress();
+  }
+
+  public ChannelFuture sendShuffleData(Message message, RpcResponseCallback 
callback) {
+    if (logger.isTraceEnabled()) {
+      logger.trace("Pushing data to {}", NettyUtils.getRemoteAddress(channel));
+    }
+    long requestId = requestId();
+    handler.addResponseCallback(requestId, callback);
+    RpcChannelListener listener = new RpcChannelListener(requestId, callback);
+    return channel.writeAndFlush(message).addListener(listener);
+  }
+
+  public static long requestId() {
+    return counter.getAndIncrement();
+  }
+
+  public class StdChannelListener implements GenericFutureListener<Future<? 
super Void>> {
+    final long startTime;
+    final Object requestId;
+
+    public StdChannelListener(Object requestId) {
+      this.startTime = System.currentTimeMillis();
+      this.requestId = requestId;
+    }
+
+    @Override
+    public void operationComplete(Future<? super Void> future) throws 
Exception {
+      if (future.isSuccess()) {
+        if (logger.isTraceEnabled()) {
+          long timeTaken = System.currentTimeMillis() - startTime;
+          logger.trace(
+              "Sending request {} to {} took {} ms",
+              requestId,
+              NettyUtils.getRemoteAddress(channel),
+              timeTaken);
+        }
+      } else {
+        String errorMsg =
+            String.format(
+                "Failed to send request %s to %s: %s, channel will be closed",
+                requestId, NettyUtils.getRemoteAddress(channel), 
future.cause());
+        logger.warn(errorMsg);
+        channel.close();
+        try {
+          handleFailure(errorMsg, future.cause());
+        } catch (Exception e) {
+          logger.error("Uncaught exception in RPC response callback handler!", 
e);
+        }
+      }
+    }
+
+    protected void handleFailure(String errorMsg, Throwable cause) {
+      logger.error("Error encountered " + errorMsg, cause);
+    }
+  }
+
+  private class RpcChannelListener extends StdChannelListener {
+    final long rpcRequestId;
+    final RpcResponseCallback callback;
+
+    RpcChannelListener(long rpcRequestId, RpcResponseCallback callback) {
+      super("RPC " + rpcRequestId);
+      this.rpcRequestId = rpcRequestId;
+      this.callback = callback;
+    }
+
+    @Override
+    protected void handleFailure(String errorMsg, Throwable cause) {
+      handler.removeRpcRequest(rpcRequestId);
+      callback.onFailure(new IOException(errorMsg, cause));
+    }
+  }
+
+
+  @Override
+  public void close() throws IOException {
+    // close is a local operation and should finish with milliseconds; timeout 
just to be safe
+    channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS);
+  }
+
+}
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClientFactory.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClientFactory.java
new file mode 100644
index 00000000..c8056151
--- /dev/null
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClientFactory.java
@@ -0,0 +1,250 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.client;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.net.SocketAddress;
+import java.util.Objects;
+import java.util.Random;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicReference;
+
+import io.netty.bootstrap.Bootstrap;
+import io.netty.buffer.PooledByteBufAllocator;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.socket.SocketChannel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.netty.IOMode;
+import org.apache.uniffle.common.netty.TransportFrameDecoder;
+import org.apache.uniffle.common.netty.handle.TransportResponseHandler;
+import org.apache.uniffle.common.util.JavaUtils;
+import org.apache.uniffle.common.util.NettyUtils;
+
+public class TransportClientFactory implements Closeable {
+
+  /**
+   * A simple data structure to track the pool of clients between two peer 
nodes.
+   */
+  private static class ClientPool {
+    TransportClient[] clients;
+    Object[] locks;
+
+    ClientPool(int size) {
+      clients = new TransportClient[size];
+      locks = new Object[size];
+      for (int i = 0; i < size; i++) {
+        locks[i] = new Object();
+      }
+    }
+  }
+
+  private static final Logger logger = 
LoggerFactory.getLogger(TransportClientFactory.class);
+
+  private final TransportContext context;
+  private final TransportConf conf;
+  private final ConcurrentHashMap<SocketAddress, ClientPool> connectionPool;
+
+  /**
+   * Random number generator for picking connections between peers.
+   */
+  private final Random rand;
+
+  private final int numConnectionsPerPeer;
+
+  private final Class<? extends Channel> socketChannelClass;
+  private EventLoopGroup workerGroup;
+  private PooledByteBufAllocator pooledAllocator;
+
+  public TransportClientFactory(TransportContext context) {
+    this.context = Objects.requireNonNull(context);
+    this.conf = context.getConf();
+    this.connectionPool = JavaUtils.newConcurrentMap();
+    this.numConnectionsPerPeer = conf.numConnectionsPerPeer();
+    this.rand = new Random();
+
+    IOMode ioMode = conf.ioMode();
+    this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode);
+    this.workerGroup =
+        NettyUtils.createEventLoop(ioMode, conf.clientThreads(), 
"netty-rpc-client");
+    this.pooledAllocator =
+        NettyUtils.createPooledByteBufAllocator(
+            conf.preferDirectBufs(), false, conf.clientThreads());
+  }
+
+  public TransportClient createClient(String remoteHost, int remotePort, int 
partitionId)
+      throws IOException, InterruptedException {
+    return createClient(remoteHost, remotePort, partitionId, new 
TransportFrameDecoder());
+  }
+
+  public TransportClient createClient(
+      String remoteHost, int remotePort, int partitionId, 
ChannelInboundHandlerAdapter decoder)
+      throws IOException, InterruptedException {
+    // Get connection from the connection pool first.
+    // If it is not found or not active, create a new one.
+    // Use unresolved address here to avoid DNS resolution each time we 
creates a client.
+    final InetSocketAddress unresolvedAddress =
+        InetSocketAddress.createUnresolved(remoteHost, remotePort);
+
+    // Create the ClientPool if we don't have it yet.
+    ClientPool clientPool = connectionPool.computeIfAbsent(unresolvedAddress, x
+        -> new ClientPool(numConnectionsPerPeer));
+
+    int clientIndex =
+        partitionId < 0 ? rand.nextInt(numConnectionsPerPeer) : partitionId % 
numConnectionsPerPeer;
+    TransportClient cachedClient = clientPool.clients[clientIndex];
+
+    if (cachedClient != null && cachedClient.isActive()) {
+      // Make sure that the channel will not timeout by updating the last use 
time of the
+      // handler. Then check that the client is still alive, in case it timed 
out before
+      // this code was able to update things.
+      TransportResponseHandler handler =
+          
cachedClient.getChannel().pipeline().get(TransportResponseHandler.class);
+
+      if (cachedClient.isActive()) {
+        logger.trace(
+            "Returning cached connection to {}: {}", 
cachedClient.getSocketAddress(), cachedClient);
+        return cachedClient;
+      }
+    }
+
+    // If we reach here, we don't have an existing connection open. Let's 
create a new one.
+    // Multiple threads might race here to create new connections. Keep only 
one of them active.
+    final long preResolveHost = System.nanoTime();
+    final InetSocketAddress resolvedAddress = new 
InetSocketAddress(remoteHost, remotePort);
+    final long hostResolveTimeMs = (System.nanoTime() - preResolveHost) / 
1000000;
+    if (hostResolveTimeMs > 2000) {
+      logger.warn("DNS resolution for {} took {} ms", resolvedAddress, 
hostResolveTimeMs);
+    } else {
+      logger.trace("DNS resolution for {} took {} ms", resolvedAddress, 
hostResolveTimeMs);
+    }
+
+    synchronized (clientPool.locks[clientIndex]) {
+      cachedClient = clientPool.clients[clientIndex];
+
+      if (cachedClient != null) {
+        if (cachedClient.isActive()) {
+          logger.trace("Returning cached connection to {}: {}", 
resolvedAddress, cachedClient);
+          return cachedClient;
+        } else {
+          logger.info("Found inactive connection to {}, creating a new one.", 
resolvedAddress);
+        }
+      }
+      clientPool.clients[clientIndex] = internalCreateClient(resolvedAddress, 
decoder);
+      return clientPool.clients[clientIndex];
+    }
+  }
+
+  public TransportClient createClient(String remoteHost, int remotePort)
+      throws IOException, InterruptedException {
+    return createClient(remoteHost, remotePort, -1);
+  }
+
+  /**
+   * Create a completely new {@link TransportClient} to the given remote host 
/ port. This
+   * connection is not pooled.
+   *
+   * <p>As with {@link #createClient(String, int)}, this method is blocking.
+   */
+  private TransportClient internalCreateClient(
+      InetSocketAddress address, ChannelInboundHandlerAdapter decoder)
+      throws IOException, InterruptedException {
+    Bootstrap bootstrap = new Bootstrap();
+    bootstrap
+        .group(workerGroup)
+        .channel(socketChannelClass)
+        // Disable Nagle's Algorithm since we don't want packets to wait
+        .option(ChannelOption.TCP_NODELAY, true)
+        .option(ChannelOption.SO_KEEPALIVE, true)
+        .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectTimeoutMs())
+        .option(ChannelOption.ALLOCATOR, pooledAllocator);
+
+    if (conf.receiveBuf() > 0) {
+      bootstrap.option(ChannelOption.SO_RCVBUF, conf.receiveBuf());
+    }
+
+    if (conf.sendBuf() > 0) {
+      bootstrap.option(ChannelOption.SO_SNDBUF, conf.sendBuf());
+    }
+
+    final AtomicReference<TransportClient> clientRef = new AtomicReference<>();
+    final AtomicReference<Channel> channelRef = new AtomicReference<>();
+
+    bootstrap.handler(
+        new ChannelInitializer<SocketChannel>() {
+          @Override
+          public void initChannel(SocketChannel ch) {
+            TransportResponseHandler transportResponseHandler = 
context.initializePipeline(ch, decoder);
+            TransportClient client = new TransportClient(ch, 
transportResponseHandler);
+            clientRef.set(client);
+            channelRef.set(ch);
+          }
+        });
+
+    // Connect to the remote server
+    ChannelFuture cf = bootstrap.connect(address);
+    if (!cf.await(conf.connectTimeoutMs())) {
+      throw new IOException(
+          String.format("Connecting to %s timed out (%s ms)", address, 
conf.connectTimeoutMs()));
+    } else if (cf.cause() != null) {
+      throw new IOException(String.format("Failed to connect to %s", address), 
cf.cause());
+    }
+
+    TransportClient client = clientRef.get();
+    assert client != null : "Channel future completed successfully with null 
client";
+
+    logger.debug("Connection to {} successful", address);
+
+    return client;
+  }
+
+  /**
+   * Close all connections in the connection pool, and shutdown the worker 
thread pool.
+   */
+  @Override
+  public void close() {
+    // Go through all clients and close them if they are active.
+    for (ClientPool clientPool : connectionPool.values()) {
+      for (int i = 0; i < clientPool.clients.length; i++) {
+        TransportClient client = clientPool.clients[i];
+        if (client != null) {
+          clientPool.clients[i] = null;
+          JavaUtils.closeQuietly(client);
+        }
+      }
+    }
+    connectionPool.clear();
+
+    
+    if (workerGroup != null && !workerGroup.isShuttingDown()) {
+      workerGroup.shutdownGracefully();
+    }
+  }
+
+  public TransportContext getContext() {
+    return context;
+  }
+}
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/client/TransportConf.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportConf.java
new file mode 100644
index 00000000..a664cc19
--- /dev/null
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportConf.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.client;
+
+import org.apache.uniffle.common.config.RssClientConf;
+import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.netty.IOMode;
+
+public class TransportConf {
+
+  private final RssConf rssConf;
+
+  public TransportConf(RssConf rssConf) {
+    this.rssConf = rssConf;
+  }
+
+  public IOMode ioMode() {
+    return rssConf.get(RssClientConf.NETTY_IO_MODE);
+  }
+
+  public int connectTimeoutMs() {
+    return rssConf.get(RssClientConf.NETTY_IO_CONNECT_TIMEOUT_MS);
+  }
+
+  public int connectionTimeoutMs() {
+    return rssConf.get(RssClientConf.NETTY_IO_CONNECTION_TIMEOUT_MS);
+  }
+
+  public int clientThreads() {
+    return rssConf.get(RssClientConf.NETTY_CLIENT_THREADS);
+  }
+
+  public int numConnectionsPerPeer() {
+    return rssConf.get(RssClientConf.NETTY_CLIENT_NUM_CONNECTIONS_PER_PEER);
+  }
+
+  public boolean preferDirectBufs() {
+    return rssConf.get(RssClientConf.NETWORK_CLIENT_PREFER_DIRECT_BUFS);
+  }
+
+  public int receiveBuf() {
+    return rssConf.get(RssClientConf.NETTY_CLIENT_RECEIVE_BUFFER);
+  }
+
+  public int sendBuf() {
+    return rssConf.get(RssClientConf.NETTY_CLIENT_SEND_BUFFER);
+  }
+
+}
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/client/TransportContext.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportContext.java
new file mode 100644
index 00000000..134b633a
--- /dev/null
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportContext.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.client;
+
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.handler.timeout.IdleStateHandler;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.netty.MessageEncoder;
+import org.apache.uniffle.common.netty.handle.TransportResponseHandler;
+
+public class TransportContext {
+  private static final Logger logger = 
LoggerFactory.getLogger(TransportContext.class);
+
+  private TransportConf transportConf;
+
+  private static final MessageEncoder ENCODER = MessageEncoder.INSTANCE;
+
+  public TransportContext(TransportConf transportConf) {
+    this.transportConf = transportConf;
+  }
+
+  public TransportClientFactory createClientFactory() {
+    return new TransportClientFactory(this);
+  }
+
+  public TransportResponseHandler initializePipeline(
+      SocketChannel channel, ChannelInboundHandlerAdapter decoder) {
+    TransportResponseHandler responseHandler = new 
TransportResponseHandler(channel);
+    channel
+        .pipeline()
+        .addLast("encoder", ENCODER) // out
+        .addLast("decoder", decoder) // in
+        .addLast(
+            "idleStateHandler", new IdleStateHandler(0, 0, 
transportConf.connectionTimeoutMs() / 1000))
+        .addLast("responseHandler", responseHandler);
+    return responseHandler;
+  }
+
+  public TransportConf getConf() {
+    return transportConf;
+  }
+}
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportResponseHandler.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportResponseHandler.java
new file mode 100644
index 00000000..86dd9953
--- /dev/null
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportResponseHandler.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.handle;
+
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.netty.client.RpcResponseCallback;
+import org.apache.uniffle.common.netty.protocol.RpcResponse;
+import org.apache.uniffle.common.util.NettyUtils;
+
+
+public class TransportResponseHandler extends ChannelInboundHandlerAdapter {
+  private static final Logger logger = 
LoggerFactory.getLogger(TransportResponseHandler.class);
+
+  private Map<Long, RpcResponseCallback> outstandingRpcRequests;
+  private Channel channel;
+
+  public TransportResponseHandler(Channel channel) {
+    this.channel = channel;
+    this.outstandingRpcRequests = new ConcurrentHashMap<>();
+  }
+
+  @Override
+  public void channelRead(ChannelHandlerContext ctx, Object msg) throws 
Exception {
+    if (msg instanceof RpcResponse) {
+      RpcResponse responseMessage = (RpcResponse) msg;
+      RpcResponseCallback listener = 
outstandingRpcRequests.get(responseMessage.getRequestId());
+      if (listener == null) {
+        logger.warn("Ignoring response from {} since it is not outstanding",
+            NettyUtils.getRemoteAddress(channel));
+      } else {
+        listener.onSuccess(responseMessage);
+      }
+    } else {
+      throw new RssException("receive unexpected message!");
+    }
+    super.channelRead(ctx, msg);
+  }
+
+  public void addResponseCallback(long requestId, RpcResponseCallback 
callback) {
+    outstandingRpcRequests.put(requestId, callback);
+  }
+
+  public void removeRpcRequest(long requestId) {
+    outstandingRpcRequests.remove(requestId);
+  }
+
+
+}
diff --git 
a/common/src/test/java/org/apache/uniffle/common/netty/EncoderAndDecoderTest.java
 
b/common/src/test/java/org/apache/uniffle/common/netty/EncoderAndDecoderTest.java
index 7adce841..63441b55 100644
--- 
a/common/src/test/java/org/apache/uniffle/common/netty/EncoderAndDecoderTest.java
+++ 
b/common/src/test/java/org/apache/uniffle/common/netty/EncoderAndDecoderTest.java
@@ -107,7 +107,7 @@ public class EncoderAndDecoderTest {
         new ChannelInitializer<SocketChannel>() {
           @Override
           public void initChannel(SocketChannel ch) {
-            ch.pipeline().addLast("ClientEncoder", new MessageEncoder())
+            ch.pipeline().addLast("ClientEncoder", MessageEncoder.INSTANCE)
                 .addLast("ClientDecoder", new TransportFrameDecoder())
                 .addLast("ClientResponseHandler", new MockResponseHandler());
             channelRef.set(ch);
@@ -152,7 +152,7 @@ public class EncoderAndDecoderTest {
     serverBootstrap.childHandler(new ChannelInitializer<SocketChannel>() {
       @Override
       public void initChannel(final SocketChannel ch) {
-        ch.pipeline().addLast("ServerEncoder", new MessageEncoder())
+        ch.pipeline().addLast("ServerEncoder", MessageEncoder.INSTANCE)
             .addLast("ServerDecoder", new TransportFrameDecoder())
             .addLast("ServerResponseHandler", new MockResponseHandler());
       }
diff --git 
a/common/src/test/java/org/apache/uniffle/common/netty/client/TransportClientFactoryTest.java
 
b/common/src/test/java/org/apache/uniffle/common/netty/client/TransportClientFactoryTest.java
new file mode 100644
index 00000000..cdf8958c
--- /dev/null
+++ 
b/common/src/test/java/org/apache/uniffle/common/netty/client/TransportClientFactoryTest.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.client;
+
+import java.io.IOException;
+
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+
+import org.apache.uniffle.common.config.RssBaseConf;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class TransportClientFactoryTest extends TransportClientTestBase {
+
+  private static int SERVER_PORT_RANGE_START = 10000;
+  private static int SERVER_PORT_RANGE_END = 10005;
+
+  @BeforeAll
+  public static void setupServer() {
+    for (int i = SERVER_PORT_RANGE_START; i < SERVER_PORT_RANGE_END + 1; i++) {
+      mockServers.add(new MockServer(i));
+    }
+    startMockServer();
+  }
+
+  @Test
+  public void testCreateClient() throws IOException, InterruptedException {
+    RssBaseConf rssBaseConf = new RssBaseConf();
+    rssBaseConf.setInteger("rss.client.netty.client.connections.per.peer", 1);
+    TransportConf transportConf = new TransportConf(rssBaseConf);
+    TransportContext transportContext = new TransportContext(transportConf);
+    TransportClient transportClient1 = transportContext.createClientFactory()
+        .createClient("localhost", SERVER_PORT_RANGE_START, 1);
+    assertTrue(transportClient1.isActive());
+    transportClient1.close();
+
+    TransportClient transportClient2 = transportContext.createClientFactory()
+        .createClient("localhost", SERVER_PORT_RANGE_START, 1);
+    assertNotEquals(transportClient1, transportClient2);
+    assertTrue(transportClient2.isActive());
+  }
+
+  @Test
+  public void testClientReuse() throws IOException, InterruptedException {
+    RssBaseConf rssBaseConf = new RssBaseConf();
+    TransportConf transportConf = new TransportConf(rssBaseConf);
+    TransportContext transportContext = new TransportContext(transportConf);
+    TransportClientFactory transportClientFactory = 
transportContext.createClientFactory();
+    TransportClient client1 = transportClientFactory.createClient("localhost", 
SERVER_PORT_RANGE_START, 1);
+    TransportClient client2 = transportClientFactory.createClient("localhost", 
SERVER_PORT_RANGE_START, 1);
+    assertEquals(client1, client2);
+  }
+
+  @Test
+  public void testClientDiffPartition() throws IOException, 
InterruptedException {
+    RssBaseConf rssBaseConf = new RssBaseConf();
+    rssBaseConf.setInteger("rss.client.netty.client.connections.per.peer", 1);
+    TransportConf transportConf = new TransportConf(rssBaseConf);
+    TransportContext transportContext = new TransportContext(transportConf);
+    TransportClientFactory transportClientFactory = 
transportContext.createClientFactory();
+    TransportClient client1 = transportClientFactory.createClient("localhost", 
SERVER_PORT_RANGE_START, 1);
+    TransportClient client2 = transportClientFactory.createClient("localhost", 
SERVER_PORT_RANGE_START, 2);
+    assertEquals(client1, client2);
+    transportClientFactory.close();
+
+    rssBaseConf.setInteger("rss.client.netty.client.connections.per.peer", 10);
+    transportConf = new TransportConf(rssBaseConf);
+    transportContext = new TransportContext(transportConf);
+    transportClientFactory = transportContext.createClientFactory();
+    client1 = transportClientFactory.createClient("localhost", 
SERVER_PORT_RANGE_START, 1);
+    client2 = transportClientFactory.createClient("localhost", 
SERVER_PORT_RANGE_START, 2);
+    assertNotEquals(client1, client2);
+    transportClientFactory.close();
+  }
+
+  @Test
+  public void testClientDiffServer() throws IOException, InterruptedException {
+    RssBaseConf rssBaseConf = new RssBaseConf();
+    rssBaseConf.setInteger("rss.client.netty.client.connections.per.peer", 1);
+    TransportConf transportConf = new TransportConf(rssBaseConf);
+    TransportContext transportContext = new TransportContext(transportConf);
+    TransportClientFactory transportClientFactory = 
transportContext.createClientFactory();
+    TransportClient client1 = transportClientFactory.createClient("localhost", 
SERVER_PORT_RANGE_START, 1);
+    TransportClient client2 = transportClientFactory.createClient("localhost", 
SERVER_PORT_RANGE_START + 1, 1);
+    assertNotEquals(client1, client2);
+    transportClientFactory.close();
+  }
+
+}
diff --git 
a/common/src/test/java/org/apache/uniffle/common/netty/client/TransportClientTestBase.java
 
b/common/src/test/java/org/apache/uniffle/common/netty/client/TransportClientTestBase.java
new file mode 100644
index 00000000..dbe7fb41
--- /dev/null
+++ 
b/common/src/test/java/org/apache/uniffle/common/netty/client/TransportClientTestBase.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.client;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.collect.Lists;
+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelPipeline;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.channel.socket.nio.NioServerSocketChannel;
+import org.junit.jupiter.api.AfterAll;
+
+public abstract class TransportClientTestBase {
+
+  protected static List<MockServer> mockServers = Lists.newArrayList();
+
+  protected static void startMockServer() {
+    for (MockServer shuffleServer : mockServers) {
+      try {
+        shuffleServer.start();
+      } catch (IOException e) {
+        throw new RuntimeException(String.format("start mock server on port %s 
failed", shuffleServer.port), e);
+      }
+    }
+  }
+
+
+  @AfterAll
+  public static void shutdownServers() throws Exception {
+    for (MockServer shuffleServer : mockServers) {
+      shuffleServer.stop();
+    }
+    mockServers.clear();
+  }
+
+  public static class MockServer {
+    ServerBootstrap bootstrap;
+    ChannelFuture channelFuture;
+    private EventLoopGroup bossGroup;
+    private EventLoopGroup workerGroup;
+    int port;
+
+    public MockServer(int port) {
+      this.port = port;
+      this.bossGroup = new NioEventLoopGroup(1);
+      this.workerGroup = new NioEventLoopGroup(2);
+    }
+
+    public void start() throws IOException {
+
+      try {
+        bootstrap = new ServerBootstrap();
+        bootstrap.group(bossGroup, workerGroup)
+            .channel(NioServerSocketChannel.class)
+            .childHandler(new ChannelInitializer<SocketChannel>() {
+              @Override
+              public void initChannel(SocketChannel ch) throws Exception {
+                ChannelPipeline p = ch.pipeline();
+                p.addLast(new MockEchoServerHandler());
+              }
+            });
+        channelFuture = bootstrap.bind(port).sync();
+      } catch (InterruptedException e) {
+        stop();
+      }
+    }
+
+    public void stop() {
+      if (channelFuture != null) {
+        channelFuture.channel().close().awaitUninterruptibly(10L, 
TimeUnit.SECONDS);
+        channelFuture = null;
+      }
+      if (bossGroup != null) {
+        bossGroup.shutdownGracefully();
+        workerGroup.shutdownGracefully();
+        bossGroup = null;
+        workerGroup = null;
+      }
+    }
+  }
+
+  static class MockEchoServerHandler extends ChannelInboundHandlerAdapter {
+
+    @Override
+    public void channelRead(ChannelHandlerContext ctx, Object msg) {
+      ctx.writeAndFlush(msg);
+    }
+
+    @Override
+    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
+      cause.printStackTrace();
+      ctx.close();
+    }
+  }
+}


Reply via email to