This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new feb3ed90c [CELEBORN-2262] Prepare S3 directory only once and cache s3 
client for MultiPartUploader
feb3ed90c is described below

commit feb3ed90c36d924dbfab8e2bec1972c4ef162486
Author: Enrico Olivelli <[email protected]>
AuthorDate: Wed Feb 25 09:46:54 2026 +0800

    [CELEBORN-2262] Prepare S3 directory only once and cache s3 client for 
MultiPartUploader
    
    ### What changes were proposed in this pull request?
    
    - Create only one S3 client for all MultiPartUploaders
    - Create S3 worker directory only once and not for every slot
    
    ### Why are the changes needed?
    - Because on AWS S3 creating connections is slow (due to credential 
handshaking and TLS handshaking)
    - Because "mkdirs" in AWS S3 is very slow (and it needs several S3 calls)
    
    Sample CPU flamegraph showing the need for connection pooling:
    <img width="2248" height="1275" alt="image" 
src="https://github.com/user-attachments/assets/5fb46d8f-5a1e-41a0-a8ca-01c92a2a3eb0";
 />
    
    Sample CPU flamegraph showing the need for pooling the client due to 
AssumeRoleWithWebIdentity:
    <img width="2248" height="1275" alt="image" 
src="https://github.com/user-attachments/assets/e9efbadd-ef68-40d3-8fb5-d8fe43f56752";
 />
    
    ### Does this PR resolve a correctness bug?
    
    No
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Manual testing.
    
    There is one end-to-end integration test with S3 that exercises this code.
    
    Closes #3604 from eolivelli/improve-s3-apache.
    
    Authored-by: Enrico Olivelli <[email protected]>
    Signed-off-by: Shuang <[email protected]>
---
 .../apache/celeborn/S3MultipartUploadHandler.java  | 162 +++++++++++----------
 tests/spark-it/src/test/resources/log4j2-test.xml  |   2 +-
 .../spark/s3/BasicEndToEndTieredStorageTest.scala  |   3 +
 .../deploy/worker/storage/TierWriterHelper.java    |  26 ++--
 .../service/deploy/worker/Controller.scala         |   5 +
 .../deploy/worker/storage/StorageManager.scala     |  49 ++++++-
 .../service/deploy/worker/storage/TierWriter.scala |   8 +-
 7 files changed, 157 insertions(+), 98 deletions(-)

diff --git 
a/multipart-uploader/multipart-uploader-s3/src/main/java/org/apache/celeborn/S3MultipartUploadHandler.java
 
b/multipart-uploader/multipart-uploader-s3/src/main/java/org/apache/celeborn/S3MultipartUploadHandler.java
index 555f0695c..cf3a0adee 100644
--- 
a/multipart-uploader/multipart-uploader-s3/src/main/java/org/apache/celeborn/S3MultipartUploadHandler.java
+++ 
b/multipart-uploader/multipart-uploader-s3/src/main/java/org/apache/celeborn/S3MultipartUploadHandler.java
@@ -34,6 +34,7 @@ import com.amazonaws.retry.RetryPolicy;
 import com.amazonaws.services.s3.AmazonS3;
 import com.amazonaws.services.s3.AmazonS3ClientBuilder;
 import com.amazonaws.services.s3.model.*;
+import org.apache.commons.lang3.StringUtils;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.s3a.AWSCredentialProviderList;
@@ -49,69 +50,81 @@ public class S3MultipartUploadHandler implements 
MultipartUploadHandler {
   private static final Logger logger = 
LoggerFactory.getLogger(S3MultipartUploadHandler.class);
 
   private String uploadId;
+  private final String key;
 
-  private final AmazonS3 s3Client;
+  private final S3MultipartUploadHandlerSharedState sharedState;
 
-  private final String key;
+  public static class S3MultipartUploadHandlerSharedState implements 
AutoCloseable {
+
+    private final AmazonS3 s3Client;
+    private final String bucketName;
+    private final int s3MultiplePartUploadMaxRetries;
+    private final int baseDelay;
+    private final int maxBackoff;
+
+    public S3MultipartUploadHandlerSharedState(
+        FileSystem hadoopFs,
+        String bucketName,
+        Integer s3MultiplePartUploadMaxRetries,
+        Integer baseDelay,
+        Integer maxBackoff)
+        throws IOException, URISyntaxException {
+      this.bucketName = bucketName;
+      this.s3MultiplePartUploadMaxRetries = s3MultiplePartUploadMaxRetries;
+      this.baseDelay = baseDelay;
+      this.maxBackoff = maxBackoff;
+      Configuration conf = hadoopFs.getConf();
+      URI binding = new URI(String.format("s3a://%s", bucketName));
 
-  private final String bucketName;
-
-  private final Integer s3MultiplePartUploadMaxRetries;
-  private final Integer baseDelay;
-  private final Integer maxBackoff;
-
-  public S3MultipartUploadHandler(
-      FileSystem hadoopFs,
-      String bucketName,
-      String key,
-      Integer s3MultiplePartUploadMaxRetries,
-      Integer baseDelay,
-      Integer maxBackoff)
-      throws IOException, URISyntaxException {
-    this.bucketName = bucketName;
-    this.s3MultiplePartUploadMaxRetries = s3MultiplePartUploadMaxRetries;
-    this.baseDelay = baseDelay;
-    this.maxBackoff = maxBackoff;
-
-    Configuration conf = hadoopFs.getConf();
-    URI binding = new URI(String.format("s3a://%s", bucketName));
-
-    RetryPolicy retryPolicy =
-        new RetryPolicy(
-            PredefinedRetryPolicies.DEFAULT_RETRY_CONDITION,
-            new PredefinedBackoffStrategies.SDKDefaultBackoffStrategy(
-                baseDelay, baseDelay, maxBackoff),
-            s3MultiplePartUploadMaxRetries,
-            false);
-    ClientConfiguration clientConfig =
-        new ClientConfiguration()
-            .withRetryPolicy(retryPolicy)
-            .withMaxErrorRetry(s3MultiplePartUploadMaxRetries);
-    AmazonS3ClientBuilder builder =
-        AmazonS3ClientBuilder.standard()
-            .withCredentials(getCredentialsProvider(binding, conf))
-            .withClientConfiguration(clientConfig);
-    // for MinIO
-    String endpoint = conf.get("fs.s3a.endpoint");
-    if (endpoint != null && !endpoint.isEmpty()) {
-      builder =
-          builder
-              .withEndpointConfiguration(
-                  new AwsClientBuilder.EndpointConfiguration(
-                      endpoint, conf.get(Constants.AWS_REGION)))
-              
.withPathStyleAccessEnabled(conf.getBoolean("fs.s3a.path.style.access", false));
-    } else {
-      builder = builder.withRegion(conf.get(Constants.AWS_REGION));
+      RetryPolicy retryPolicy =
+          new RetryPolicy(
+              PredefinedRetryPolicies.DEFAULT_RETRY_CONDITION,
+              new PredefinedBackoffStrategies.SDKDefaultBackoffStrategy(
+                  baseDelay, baseDelay, maxBackoff),
+              s3MultiplePartUploadMaxRetries,
+              false);
+      ClientConfiguration clientConfig =
+          new ClientConfiguration()
+              .withRetryPolicy(retryPolicy)
+              .withMaxErrorRetry(s3MultiplePartUploadMaxRetries);
+      AmazonS3ClientBuilder builder =
+          AmazonS3ClientBuilder.standard()
+              .withCredentials(getCredentialsProvider(binding, conf))
+              .withClientConfiguration(clientConfig);
+      // for MinIO
+      String endpoint = conf.get(Constants.ENDPOINT);
+      if (!StringUtils.isEmpty(endpoint)) {
+        builder =
+            builder
+                .withEndpointConfiguration(
+                    new AwsClientBuilder.EndpointConfiguration(
+                        endpoint, conf.get(Constants.AWS_REGION)))
+                
.withPathStyleAccessEnabled(conf.getBoolean(Constants.PATH_STYLE_ACCESS, 
false));
+      } else {
+        builder = builder.withRegion(conf.get(Constants.AWS_REGION));
+      }
+      this.s3Client = builder.build();
     }
-    this.s3Client = builder.build();
+
+    @Override
+    public void close() {
+      if (s3Client != null) {
+        s3Client.shutdown();
+      }
+    }
+  }
+
+  public S3MultipartUploadHandler(AutoCloseable sharedState, String key) {
+    this.sharedState = (S3MultipartUploadHandlerSharedState) sharedState;
     this.key = key;
   }
 
   @Override
   public void startUpload() {
     InitiateMultipartUploadRequest initRequest =
-        new InitiateMultipartUploadRequest(bucketName, key);
-    InitiateMultipartUploadResult initResponse = 
s3Client.initiateMultipartUpload(initRequest);
+        new InitiateMultipartUploadRequest(sharedState.bucketName, key);
+    InitiateMultipartUploadResult initResponse =
+        sharedState.s3Client.initiateMultipartUpload(initRequest);
     this.uploadId = initResponse.getUploadId();
   }
 
@@ -131,14 +144,14 @@ public class S3MultipartUploadHandler implements 
MultipartUploadHandler {
       }
       UploadPartRequest uploadRequest =
           new UploadPartRequest()
-              .withBucketName(bucketName)
+              .withBucketName(sharedState.bucketName)
               .withKey(key)
               .withUploadId(uploadId)
               .withPartNumber(partNumber)
               .withInputStream(inStream)
               .withPartSize(partSize)
               .withLastPart(finalFlush);
-      s3Client.uploadPart(uploadRequest);
+      sharedState.s3Client.uploadPart(uploadRequest);
       logger.debug(
           "key {} uploadId {} part number {} uploaded with size {} finalFlush 
{}",
           key,
@@ -155,10 +168,10 @@ public class S3MultipartUploadHandler implements 
MultipartUploadHandler {
   @Override
   public void complete() {
     List<PartETag> partETags = new ArrayList<>();
-    ListPartsRequest listPartsRequest = new ListPartsRequest(bucketName, key, 
uploadId);
+    ListPartsRequest listPartsRequest = new 
ListPartsRequest(sharedState.bucketName, key, uploadId);
     PartListing partListing;
     do {
-      partListing = s3Client.listParts(listPartsRequest);
+      partListing = sharedState.s3Client.listParts(listPartsRequest);
       for (PartSummary part : partListing.getParts()) {
         partETags.add(new PartETag(part.getPartNumber(), part.getETag()));
       }
@@ -167,11 +180,12 @@ public class S3MultipartUploadHandler implements 
MultipartUploadHandler {
     if (partETags.size() == 0) {
       logger.debug(
           "bucket {} key {} uploadId {} has no parts uploaded, aborting 
upload",
-          bucketName,
+          sharedState.bucketName,
           key,
           uploadId);
       abort();
-      logger.debug("bucket {} key {} upload completed with size {}", 
bucketName, key, 0);
+      logger.debug(
+          "bucket {} key {} upload completed with size {}", 
sharedState.bucketName, key, 0);
       return;
     }
     ProgressListener progressListener =
@@ -185,34 +199,36 @@ public class S3MultipartUploadHandler implements 
MultipartUploadHandler {
         };
 
     CompleteMultipartUploadRequest compRequest =
-        new CompleteMultipartUploadRequest(bucketName, key, uploadId, 
partETags)
+        new CompleteMultipartUploadRequest(sharedState.bucketName, key, 
uploadId, partETags)
             .withGeneralProgressListener(progressListener);
     CompleteMultipartUploadResult compResult = null;
-    for (int attempt = 1; attempt <= this.s3MultiplePartUploadMaxRetries; 
attempt++) {
+    for (int attempt = 1; attempt <= 
sharedState.s3MultiplePartUploadMaxRetries; attempt++) {
       try {
-        compResult = s3Client.completeMultipartUpload(compRequest);
+        compResult = sharedState.s3Client.completeMultipartUpload(compRequest);
         break;
       } catch (AmazonClientException e) {
-        if (attempt == this.s3MultiplePartUploadMaxRetries
+        if (attempt == sharedState.s3MultiplePartUploadMaxRetries
             || 
!PredefinedRetryPolicies.DEFAULT_RETRY_CONDITION.shouldRetry(null, e, attempt)) 
{
           logger.error(
               "bucket {} key {} uploadId {} upload failed to complete, will 
not retry",
-              bucketName,
+              sharedState.bucketName,
               key,
               uploadId,
               e);
           throw e;
         }
 
-        long backoffTime = Math.min(maxBackoff, baseDelay * (long) Math.pow(2, 
attempt - 1));
+        long backoffTime =
+            Math.min(
+                sharedState.maxBackoff, sharedState.baseDelay * (long) 
Math.pow(2, attempt - 1));
         try {
           logger.warn(
               "bucket {} key {} uploadId {} upload failed to complete, will 
retry ({}/{})",
-              bucketName,
+              sharedState.bucketName,
               key,
               uploadId,
               attempt,
-              this.s3MultiplePartUploadMaxRetries,
+              sharedState.s3MultiplePartUploadMaxRetries,
               e);
           Thread.sleep(backoffTime);
         } catch (InterruptedException ex) {
@@ -222,7 +238,7 @@ public class S3MultipartUploadHandler implements 
MultipartUploadHandler {
     }
     logger.debug(
         "bucket {} key {} uploadId {} upload completed location is in {} ",
-        bucketName,
+        sharedState.bucketName,
         key,
         uploadId,
         compResult.getLocation());
@@ -231,16 +247,12 @@ public class S3MultipartUploadHandler implements 
MultipartUploadHandler {
   @Override
   public void abort() {
     AbortMultipartUploadRequest abortMultipartUploadRequest =
-        new AbortMultipartUploadRequest(bucketName, key, uploadId);
-    s3Client.abortMultipartUpload(abortMultipartUploadRequest);
+        new AbortMultipartUploadRequest(sharedState.bucketName, key, uploadId);
+    sharedState.s3Client.abortMultipartUpload(abortMultipartUploadRequest);
   }
 
   @Override
-  public void close() {
-    if (s3Client != null) {
-      s3Client.shutdown();
-    }
-  }
+  public void close() {}
 
   static AWSCredentialProviderList getCredentialsProvider(URI binding, 
Configuration conf)
       throws IOException {
diff --git a/tests/spark-it/src/test/resources/log4j2-test.xml 
b/tests/spark-it/src/test/resources/log4j2-test.xml
index 9adcdccfd..0b21a8eb3 100644
--- a/tests/spark-it/src/test/resources/log4j2-test.xml
+++ b/tests/spark-it/src/test/resources/log4j2-test.xml
@@ -21,7 +21,7 @@
         <Console name="stdout" target="SYSTEM_OUT">
             <PatternLayout pattern="%d{yy/MM/dd HH:mm:ss,SSS} %p [%t] %c{1}: 
%m%n%ex"/>
             <Filters>
-                <ThresholdFilter level="ERROR"/>
+                <ThresholdFilter level="INFO"/>
             </Filters>
         </Console>
         <File name="file" fileName="target/unit-tests.log">
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/s3/BasicEndToEndTieredStorageTest.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/s3/BasicEndToEndTieredStorageTest.scala
index f65b00c4f..0b2ea3646 100644
--- 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/s3/BasicEndToEndTieredStorageTest.scala
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/s3/BasicEndToEndTieredStorageTest.scala
@@ -125,6 +125,9 @@ class BasicEndToEndTieredStorageTest extends AnyFunSuite
     val celebornSparkSession = SparkSession.builder()
       .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
       .getOrCreate()
+
+    // execute multiple operations that reserve slots
+    repartition(celebornSparkSession)
     groupBy(celebornSparkSession)
 
     celebornSparkSession.stop()
diff --git 
a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/TierWriterHelper.java
 
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/TierWriterHelper.java
index 8c0103912..33c18aab4 100644
--- 
a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/TierWriterHelper.java
+++ 
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/TierWriterHelper.java
@@ -23,25 +23,29 @@ import org.apache.celeborn.reflect.DynConstructors;
 import org.apache.celeborn.server.common.service.mpu.MultipartUploadHandler;
 
 public class TierWriterHelper {
-  public static MultipartUploadHandler getS3MultipartUploadHandler(
-      FileSystem hadoopFs,
-      String bucketName,
-      String key,
-      int maxRetryies,
-      int baseDelay,
-      int maxBackoff) {
-    return (MultipartUploadHandler)
+
+  public static AutoCloseable getS3MultipartUploadHandlerSharedState(
+      FileSystem hadoopFs, String bucketName, int maxRetryies, int baseDelay, 
int maxBackoff) {
+    return (AutoCloseable)
         DynConstructors.builder()
             .impl(
-                "org.apache.celeborn.S3MultipartUploadHandler",
+                
"org.apache.celeborn.S3MultipartUploadHandler$S3MultipartUploadHandlerSharedState",
                 FileSystem.class,
                 String.class,
-                String.class,
                 Integer.class,
                 Integer.class,
                 Integer.class)
             .build()
-            .newInstance(hadoopFs, bucketName, key, maxRetryies, baseDelay, 
maxBackoff);
+            .newInstance(hadoopFs, bucketName, maxRetryies, baseDelay, 
maxBackoff);
+  }
+
+  public static MultipartUploadHandler getS3MultipartUploadHandler(
+      AutoCloseable sharedState, String key) {
+    return (MultipartUploadHandler)
+        DynConstructors.builder()
+            .impl("org.apache.celeborn.S3MultipartUploadHandler", 
AutoCloseable.class, String.class)
+            .build()
+            .newInstance(sharedState, key);
   }
 
   public static MultipartUploadHandler getOssMultipartUploadHandler(
diff --git 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
index 97ce0e2c8..ee959e4d6 100644
--- 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
+++ 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
@@ -196,6 +196,11 @@ private[deploy] class Controller(
       context.reply(ReserveSlotsResponse(StatusCode.NO_AVAILABLE_WORKING_DIR, 
msg))
       return
     }
+
+    // do this once, and not for each location
+    if (conf.hasS3Storage)
+      storageManager.ensureS3DirectoryForShuffleKey(applicationId, shuffleId)
+
     val primaryLocs = createWriters(
       shuffleKey,
       applicationId,
diff --git 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/StorageManager.scala
 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/StorageManager.scala
index 17b725b2d..58cd4c8ab 100644
--- 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/StorageManager.scala
+++ 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/StorageManager.scala
@@ -79,6 +79,7 @@ final private[worker] class StorageManager(conf: 
CelebornConf, workerSource: Abs
 
   val diskReserveSize = conf.workerDiskReserveSize
   val diskReserveRatio = conf.workerDiskReserveRatio
+  var s3MultipartUploadHandlerSharedState: AutoCloseable = _
 
   // (deviceName -> deviceInfo) and (mount point -> diskInfo)
   val (deviceInfos, diskInfos) = {
@@ -436,6 +437,45 @@ final private[worker] class StorageManager(conf: 
CelebornConf, workerSource: Abs
       isSegmentGranularityVisible = false)
   }
 
+  def ensureS3MultipartUploaderSharedState(): Unit = this.synchronized {
+    if (s3MultipartUploadHandlerSharedState != null)
+      return
+
+    val s3HadoopFs: FileSystem = hadoopFs.get(StorageInfo.Type.S3)
+    if (s3HadoopFs == null)
+      throw new IllegalStateException("S3 is not configured")
+
+    val uri = s3HadoopFs.getUri
+    val bucketName = uri.getHost
+    logInfo(s"Creating S3 client for $uri, bucketName is $bucketName")
+    s3MultipartUploadHandlerSharedState = 
TierWriterHelper.getS3MultipartUploadHandlerSharedState(
+      s3HadoopFs,
+      bucketName,
+      conf.s3MultiplePartUploadMaxRetries,
+      conf.s3MultiplePartUploadBaseDelay,
+      conf.s3MultiplePartUploadMaxBackoff)
+  }
+
+  /**
+   * Ensure that the directory for the Shuffle exists.
+   * This method is not synchronized because from the protocol it is not 
expected
+   * to have more than one reserveSlots running for a given shuffleId.
+   * Also running the method concurrently should not make any harm, the worst 
case
+   * is that we send a little bit more requests to S3
+   */
+  def ensureS3DirectoryForShuffleKey(appId: String, shuffleId: Int): Unit = {
+
+    ensureS3MultipartUploaderSharedState()
+
+    val shuffleDir =
+      new Path(new Path(s3Dir, conf.workerWorkingDir), s"$appId/$shuffleId")
+    logDebug(s"Creating S3 directory at $shuffleDir");
+    FileSystem.mkdirs(
+      StorageManager.hadoopFs.get(StorageInfo.Type.S3),
+      shuffleDir,
+      hdfsPermission)
+  }
+
   @throws[IOException]
   def createPartitionDataWriter(
       appId: String,
@@ -850,6 +890,9 @@ final private[worker] class StorageManager(conf: 
CelebornConf, workerSource: Abs
     if (null != deviceMonitor) {
       deviceMonitor.close()
     }
+
+    if (s3MultipartUploadHandlerSharedState != null)
+      s3MultipartUploadHandlerSharedState.close()
   }
 
   private def flushFileWriters(): Unit = {
@@ -1125,11 +1168,7 @@ final private[worker] class StorageManager(conf: 
CelebornConf, workerSource: Abs
       } else if (storageType == Type.S3 && 
location.getStorageInfo.S3Available()) {
         val shuffleDir =
           new Path(new Path(s3Dir, conf.workerWorkingDir), 
s"$appId/$shuffleId")
-        logDebug(s"trying to create S3 file at $shuffleDir");
-        FileSystem.mkdirs(
-          StorageManager.hadoopFs.get(StorageInfo.Type.S3),
-          shuffleDir,
-          hdfsPermission)
+        // directory has been prepared by ensureS3DirectoryForShuffleKey
         val s3FilePath = new Path(shuffleDir, fileName).toString
         val s3FileInfo = new DiskFileInfo(
           userIdentifier,
diff --git 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/TierWriter.scala
 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/TierWriter.scala
index 90924d0ad..526fb5d92 100644
--- 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/TierWriter.scala
+++ 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/TierWriter.scala
@@ -560,12 +560,8 @@ class DfsTierWriter(
       val key = dfsFileInfo.getFilePath.substring(index + bucketName.length + 
1)
 
       this.s3MultipartUploadHandler = 
TierWriterHelper.getS3MultipartUploadHandler(
-        hadoopFs,
-        bucketName,
-        key,
-        conf.s3MultiplePartUploadMaxRetries,
-        conf.s3MultiplePartUploadBaseDelay,
-        conf.s3MultiplePartUploadMaxBackoff)
+        storageManager.s3MultipartUploadHandlerSharedState,
+        key)
       s3MultipartUploadHandler.startUpload()
     } else if (dfsFileInfo.isOSS) {
       val configuration = hadoopFs.getConf

Reply via email to