This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new e10aefd04 [CELEBORN-1857] Support LocalPartitionReader read partition 
by chunkOffsets when enable optimize skew partition read
e10aefd04 is described below

commit e10aefd0465b0885e4356213721db103f49c3f0f
Author: wangshengjie3 <[email protected]>
AuthorDate: Wed Feb 26 23:11:35 2025 +0800

    [CELEBORN-1857] Support LocalPartitionReader read partition by chunkOffsets 
when enable optimize skew partition read
    
    ### What changes were proposed in this pull request?
    Support LocalPartitionReader reading partitions by chunk offsets when the 
skew partition read optimization is enabled
    
    ### Why are the changes needed?
    In [CELEBORN-1319](https://issues.apache.org/jira/browse/CELEBORN-1319), we 
have already implemented the skew partition read optimization based on chunk 
offsets, but we haven't implemented the local partition read. This pull request 
aims to implement the local partition read.
    
    ### Does this PR introduce _any_ user-facing change?
    When `celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled` is set 
to true, reading local skew partition files will read data by chunk offsets.
    
    ### How was this patch tested?
    Existing UTs and cluster tests.
    
    Closes #3111 from wangshengjie123/support-local-partition-reader.
    
    Authored-by: wangshengjie3 <[email protected]>
    Signed-off-by: Shuang <[email protected]>
---
 .../celeborn/client/read/CelebornInputStream.java  |   4 +-
 .../celeborn/client/read/LocalPartitionReader.java |  29 +++-
 .../cluster/LocalReadByChunkOffsetsTest.scala      | 193 +++++++++++++++++++++
 3 files changed, 216 insertions(+), 10 deletions(-)

diff --git 
a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java 
b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
index ea3df88d9..ab2afcf73 100644
--- 
a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
+++ 
b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
@@ -563,7 +563,9 @@ public abstract class CelebornInputStream extends 
InputStream {
                 clientFactory,
                 startMapIndex,
                 endMapIndex,
-                callback);
+                callback,
+                startChunkIndex,
+                endChunkIndex);
           } else {
             return new WorkerPartitionReader(
                 conf,
diff --git 
a/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java
 
b/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java
index 722bab100..1de83eb45 100644
--- 
a/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java
+++ 
b/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java
@@ -57,7 +57,6 @@ public class LocalPartitionReader implements PartitionReader {
   private final int fetchMaxReqsInFlight;
   private final PartitionLocation location;
   private volatile boolean closed = false;
-  private final int numChunks;
   private int returnedChunks = 0;
   private int chunkIndex = 0;
   private String fullPath;
@@ -68,6 +67,8 @@ public class LocalPartitionReader implements PartitionReader {
   private PbStreamHandler streamHandler;
   private TransportClient client;
   private MetricsCallback metricsCallback;
+  private int startChunkIndex;
+  private int endChunkIndex;
 
   @SuppressWarnings("StaticAssignmentInConstructor")
   public LocalPartitionReader(
@@ -78,7 +79,9 @@ public class LocalPartitionReader implements PartitionReader {
       TransportClientFactory clientFactory,
       int startMapIndex,
       int endMapIndex,
-      MetricsCallback metricsCallback)
+      MetricsCallback metricsCallback,
+      int startChunkIndex,
+      int endChunkIndex)
       throws IOException {
     if (readLocalShufflePool == null) {
       synchronized (LocalPartitionReader.class) {
@@ -113,6 +116,12 @@ public class LocalPartitionReader implements 
PartitionReader {
       } else {
         this.streamHandler = pbStreamHandler;
       }
+      this.startChunkIndex = startChunkIndex == -1 ? 0 : startChunkIndex;
+      this.endChunkIndex =
+          endChunkIndex == -1
+              ? streamHandler.getNumChunks() - 1
+              : Math.min(streamHandler.getNumChunks() - 1, endChunkIndex);
+      this.chunkIndex = this.startChunkIndex;
     } catch (IOException | InterruptedException e) {
       throw new IOException(
           "Read shuffle file from local file failed, partition location: "
@@ -123,7 +132,6 @@ public class LocalPartitionReader implements 
PartitionReader {
     }
 
     chunkOffsets = new ArrayList<>(streamHandler.getChunkOffsetsList());
-    numChunks = streamHandler.getNumChunks();
     fullPath = streamHandler.getFullPath();
     mapRangeRead = endMapIndex != Integer.MAX_VALUE;
 
@@ -140,7 +148,7 @@ public class LocalPartitionReader implements 
PartitionReader {
       if (shuffleChannel == null) {
         shuffleChannel = FileChannelUtils.openReadableFileChannel(fullPath);
         if (mapRangeRead) {
-          shuffleChannel.position(chunkOffsets.get(0));
+          shuffleChannel.position(chunkOffsets.get(chunkIndex));
         }
       }
       for (int i = 0; i < toFetch; i++) {
@@ -179,9 +187,9 @@ public class LocalPartitionReader implements 
PartitionReader {
   }
 
   private void fetchChunks() {
-    int inFlight = chunkIndex - returnedChunks;
+    int inFlight = chunkIndex - startChunkIndex - returnedChunks;
     if (inFlight < fetchMaxReqsInFlight) {
-      int toFetch = Math.min(fetchMaxReqsInFlight - inFlight + 1, numChunks - 
chunkIndex);
+      int toFetch = Math.min(fetchMaxReqsInFlight - inFlight + 1, 
endChunkIndex - chunkIndex + 1);
       if (pendingFetchTask.compareAndSet(false, true)) {
         logger.debug(
             "Trigger local reader fetch chunk with {} and fetch {} chunks", 
chunkIndex, toFetch);
@@ -194,14 +202,17 @@ public class LocalPartitionReader implements 
PartitionReader {
 
   @Override
   public boolean hasNext() {
-    logger.debug("Check has next current index: {} chunks {}", returnedChunks, 
numChunks);
-    return returnedChunks < numChunks;
+    logger.debug(
+        "Check has next current index: {} chunks {}",
+        returnedChunks,
+        endChunkIndex - startChunkIndex + 1);
+    return returnedChunks < endChunkIndex - startChunkIndex + 1;
   }
 
   @Override
   public ByteBuf next() throws IOException, InterruptedException {
     checkException();
-    if (chunkIndex < numChunks) {
+    if (chunkIndex <= endChunkIndex) {
       fetchChunks();
     }
     ByteBuf chunk = null;
diff --git 
a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/LocalReadByChunkOffsetsTest.scala
 
b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/LocalReadByChunkOffsetsTest.scala
new file mode 100644
index 000000000..33dcc12da
--- /dev/null
+++ 
b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/LocalReadByChunkOffsetsTest.scala
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.service.deploy.cluster
+
+import java.io.ByteArrayOutputStream
+import java.nio.charset.StandardCharsets
+import java.util.{Collections, HashMap => JHashMap}
+
+import scala.collection.mutable
+
+import org.apache.commons.lang3.RandomStringUtils
+import org.apache.commons.lang3.tuple.Pair
+import org.junit.Assert
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.celeborn.client.{LifecycleManager, ShuffleClientImpl}
+import org.apache.celeborn.client.read.MetricsCallback
+import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.identity.UserIdentifier
+import org.apache.celeborn.common.internal.Logging
+import org.apache.celeborn.common.protocol.CompressionCodec
+import org.apache.celeborn.service.deploy.MiniClusterFeature
+
+class LocalReadByChunkOffsetsTest extends AnyFunSuite
+  with Logging with MiniClusterFeature with BeforeAndAfterAll {
+  var masterPort = 0
+
+  override def beforeAll(): Unit = {
+    logInfo("test initialized , setup Celeborn mini cluster")
+    val workerConfig = Map(
+      CelebornConf.SHUFFLE_CHUNK_SIZE.key -> "8k",
+      CelebornConf.WORKER_FLUSHER_BUFFER_SIZE.key -> "8k")
+    val (m, _) = setupMiniClusterWithRandomPorts(workerConf = workerConfig)
+    masterPort = m.conf.masterPort
+  }
+
+  override def afterAll(): Unit = {
+    logInfo("all test complete , stop Celeborn mini cluster")
+    shutdownMiniCluster()
+  }
+
+  test("CELEBORN-1857: test LocalPartitionReader read partition by 
chunkOffsets when enable optimize skew partition read") {
+    val APP = "CELEBORN-1857"
+
+    val clientConf = new CelebornConf()
+      .set(CelebornConf.MASTER_ENDPOINTS.key, s"localhost:$masterPort")
+      .set(CelebornConf.SHUFFLE_COMPRESSION_CODEC.key, 
CompressionCodec.NONE.name)
+      .set(CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED, false)
+      .set(CelebornConf.CLIENT_PUSH_BUFFER_MAX_SIZE.key, "1k")
+      
.set(CelebornConf.CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ_ENABLED, true)
+      .set(CelebornConf.READ_LOCAL_SHUFFLE_FILE, true)
+      .set("celeborn.data.io.numConnectionsPerPeer", "1")
+    val lifecycleManager = new LifecycleManager(APP, clientConf)
+    val shuffleClient = new ShuffleClientImpl(APP, clientConf, 
UserIdentifier("mock", "mock"))
+    shuffleClient.setupLifecycleManagerRef(lifecycleManager.self)
+
+    val dataPrefix = Array("000000", "111111", "222222", "333333", "444444", 
"555555")
+    val dataPrefixMap = new mutable.HashMap[String, String]
+    val STR1 = dataPrefix(0) + RandomStringUtils.random(4 * 1024)
+    dataPrefixMap.put(dataPrefix(0), STR1)
+    val DATA1 = STR1.getBytes(StandardCharsets.UTF_8)
+    val OFFSET1 = 0
+    val LENGTH1 = DATA1.length
+    shuffleClient.pushData(1, 0, 0, 0, DATA1, OFFSET1, LENGTH1, 1, 1)
+
+    val STR2 = dataPrefix(1) + RandomStringUtils.random(3 * 1024)
+    dataPrefixMap.put(dataPrefix(1), STR2)
+    val DATA2 = STR2.getBytes(StandardCharsets.UTF_8)
+    val OFFSET2 = 0
+    val LENGTH2 = DATA2.length
+    shuffleClient.pushData(1, 0, 0, 0, DATA2, OFFSET2, LENGTH2, 1, 1)
+    Thread.sleep(1000)
+
+    val STR3 = dataPrefix(2) + RandomStringUtils.random(4 * 1024)
+    dataPrefixMap.put(dataPrefix(2), STR3)
+    val DATA3 = STR3.getBytes(StandardCharsets.UTF_8)
+    val LENGTH3 = DATA3.length
+    shuffleClient.pushData(1, 0, 0, 0, DATA3, 0, LENGTH3, 1, 1)
+    Thread.sleep(1000)
+
+    val STR4 = dataPrefix(3) + RandomStringUtils.random(2 * 1024)
+    dataPrefixMap.put(dataPrefix(3), STR4)
+    val DATA4 = STR4.getBytes(StandardCharsets.UTF_8)
+    val LENGTH4 = DATA4.length
+    shuffleClient.pushData(1, 0, 0, 0, DATA4, 0, LENGTH4, 1, 1)
+    Thread.sleep(1000)
+
+    val STR5 = dataPrefix(4) + RandomStringUtils.random(2 * 1024)
+    dataPrefixMap.put(dataPrefix(4), STR5)
+    val DATA5 = STR5.getBytes(StandardCharsets.UTF_8)
+    val LENGTH5 = DATA5.length
+    shuffleClient.pushData(1, 0, 0, 0, DATA5, 0, LENGTH5, 1, 1)
+    Thread.sleep(1000)
+
+    val STR6 = dataPrefix(5) + RandomStringUtils.random(6 * 1024)
+    dataPrefixMap.put(dataPrefix(5), STR6)
+    val DATA6 = STR6.getBytes(StandardCharsets.UTF_8)
+    val LENGTH6 = DATA6.length
+    shuffleClient.pushData(1, 0, 0, 0, DATA6, 0, LENGTH6, 1, 1)
+    shuffleClient.pushMergedData(1, 0, 0)
+    Thread.sleep(1000)
+
+    shuffleClient.mapperEnd(1, 0, 0, 1)
+
+    val metricsCallback = new MetricsCallback {
+      override def incBytesRead(bytesWritten: Long): Unit = {}
+      override def incReadTime(time: Long): Unit = {}
+    }
+
+    // chunkOffset is [0, 9404, 25913, 35393, 49576]
+    // chunk0 -> DATA1, chunk1 -> DATA2+DATA3, chunk2 -> DATA4+DATA5, chunk3 
-> DATA6
+    val subMap = new JHashMap[String, Pair[Integer, Integer]]()
+    // pair of (1, 2) means read chunk1 and chunk 2
+    // why not test pair of (0, 1), because we want to test 
fileChannel.position of chunk index not 0
+    subMap.put("0-0", Pair.of(1, 2))
+
+    val inputStream = shuffleClient.readPartition(
+      1,
+      1,
+      0,
+      0,
+      0,
+      3, // startMapId > endMapId, means sub-partition size
+      1, // sub-partition index
+      null,
+      null,
+      null,
+      Collections.emptyMap(), // failed batch could not be null
+      subMap, // sub-partition chunk range
+      null,
+      metricsCallback)
+    val outputStream = new ByteArrayOutputStream()
+
+    var b = inputStream.read()
+    while (b != -1) {
+      outputStream.write(b)
+      b = inputStream.read()
+    }
+
+    val readBytes = outputStream.toByteArray
+    val dataPrefix1 = Array("111111", "222222", "333333", "444444")
+    val readStringMap = getReadStringMap(readBytes, dataPrefix1, dataPrefixMap)
+
+    Assert.assertEquals(LENGTH2 + LENGTH3 + LENGTH4 + LENGTH5, 
readBytes.length)
+    for ((prefix, data) <- readStringMap) {
+      Assert.assertEquals(dataPrefixMap(prefix), data)
+    }
+
+    Thread.sleep(5000L)
+    shuffleClient.shutdown()
+    lifecycleManager.rpcEnv.shutdown()
+  }
+
+  def getReadStringMap(
+      readBytes: Array[Byte],
+      dataPrefix: Array[String],
+      dataPrefixMap: mutable.HashMap[String, String]): mutable.HashMap[String, 
String] = {
+    val readString = new String(readBytes, StandardCharsets.UTF_8)
+    val prefixStringMap = new mutable.HashMap[String, String]
+
+    var remainingString = readString
+    while (remainingString.nonEmpty) {
+      dataPrefix.find(prefix => remainingString.startsWith(prefix)) match {
+        case Some(prefix) =>
+          val expectedLength = dataPrefixMap.get(prefix).get.length
+          val subString = remainingString.substring(0, expectedLength)
+          prefixStringMap.put(prefix, subString)
+          remainingString = remainingString.substring(expectedLength)
+        case None =>
+          remainingString = ""
+      }
+    }
+
+    prefixStringMap
+  }
+
+}

Reply via email to