This is an automated email from the ASF dual-hosted git repository.

zhengchenyu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 70495697f [#1398] fix(mr)(tez): Make attemptId computable and move it 
to taskAttemptId in BlockId layout. (#1418)
70495697f is described below

commit 70495697fd0be6292989c0a69f5464d2ca86ce36
Author: QI Jiale <qijial...@foxmail.com>
AuthorDate: Fri Jul 5 15:19:27 2024 +0800

    [#1398] fix(mr)(tez): Make attemptId computable and move it to taskAttemptId 
in BlockId layout. (#1418)
    
    ### What changes were proposed in this pull request?
    
    Before this PR, in MR and TEZ engine:
    1. attemptId is stored in the sequenceNo field of the BlockId instead of in taskAttemptId.
    2. attemptId has a fixed width of 6 bits.
    
    After this PR:
    1. attemptId is stored in taskAttemptId, which is where it logically belongs.
    2. The number of bits used for attemptId is calculated from the maximum number of allowed failures and whether 
speculative execution is enabled.
    
    ### Why are the changes needed?
    
    Fix: #1398
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing unit tests and integration tests.
---
 .../hadoop/mapred/RssMapOutputCollector.java       |  4 +-
 .../org/apache/hadoop/mapreduce/RssMRUtils.java    | 87 +++++++++++----------
 .../mapreduce/task/reduce/RssEventFetcher.java     |  2 +-
 .../hadoop/mapred/SortWriteBufferManagerTest.java  | 12 +--
 .../apache/hadoop/mapreduce/RssMRUtilsTest.java    | 17 ++--
 .../mapreduce/task/reduce/EventFetcherTest.java    | 28 +++----
 .../hadoop/mapreduce/task/reduce/FetcherTest.java  |  5 +-
 .../shuffle/manager/RssShuffleManagerBase.java     | 29 ++-----
 .../shuffle/manager/RssShuffleManagerBaseTest.java | 40 ----------
 .../java/org/apache/tez/common/RssTezUtils.java    | 91 ++++++++++++----------
 .../common/shuffle/impl/RssShuffleManager.java     |  6 +-
 .../common/shuffle/impl/RssTezFetcherTask.java     |  8 +-
 .../orderedgrouped/RssShuffleScheduler.java        |  6 +-
 .../library/common/sort/impl/RssSorter.java        |  4 +-
 .../library/common/sort/impl/RssUnSorter.java      |  4 +-
 .../output/RssOrderedPartitionedKVOutput.java      |  4 +-
 .../library/output/RssUnorderedKVOutput.java       |  4 +-
 .../output/RssUnorderedPartitionedKVOutput.java    |  4 +-
 .../org/apache/tez/common/RssTezUtilsTest.java     | 10 +--
 .../library/common/sort/impl/RssSorterTest.java    |  5 +-
 .../library/common/sort/impl/RssUnSorterTest.java  |  5 +-
 .../apache/uniffle/client/util/ClientUtils.java    | 17 ++++
 .../org/apache/uniffle/client/ClientUtilsTest.java | 42 ++++++++++
 .../org/apache/uniffle/common/util/BlockId.java    |  4 +-
 .../uniffle/test/TezWordCountWithFailuresTest.java |  2 +-
 .../uniffle/server/buffer/BufferTestBase.java      |  2 +-
 .../handler/impl/HadoopShuffleReadHandlerTest.java |  2 +-
 27 files changed, 246 insertions(+), 198 deletions(-)

diff --git 
a/client-mr/core/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java
 
b/client-mr/core/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java
index 3acf0b417..30f98c52e 100644
--- 
a/client-mr/core/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java
+++ 
b/client-mr/core/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java
@@ -100,8 +100,8 @@ public class RssMapOutputCollector<K extends Object, V 
extends Object>
     ApplicationAttemptId applicationAttemptId = 
RssMRUtils.getApplicationAttemptId();
     String appId = applicationAttemptId.toString();
     long taskAttemptId =
-        RssMRUtils.convertTaskAttemptIdToLong(
-            mapTask.getTaskID(), applicationAttemptId.getAttemptId());
+        RssMRUtils.createRssTaskAttemptId(
+            mapTask.getTaskID(), applicationAttemptId.getAttemptId(), 
mrJobConf);
     double sendThreshold =
         RssMRUtils.getDouble(
             rssJobConf,
diff --git 
a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java 
b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java
index f1220486b..9012e618e 100644
--- a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java
+++ b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java
@@ -35,6 +35,7 @@ import org.slf4j.LoggerFactory;
 
 import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.client.factory.ShuffleClientFactory;
+import org.apache.uniffle.client.util.ClientUtils;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.util.BlockIdLayout;
@@ -44,37 +45,61 @@ public class RssMRUtils {
 
   private static final Logger LOG = LoggerFactory.getLogger(RssMRUtils.class);
   private static final BlockIdLayout LAYOUT = BlockIdLayout.DEFAULT;
-  private static final int MAX_ATTEMPT_LENGTH = 6;
-  private static final int MAX_ATTEMPT_ID = (1 << MAX_ATTEMPT_LENGTH) - 1;
-  private static final int MAX_SEQUENCE_NO =
-      (1 << (LAYOUT.sequenceNoBits - MAX_ATTEMPT_LENGTH)) - 1;
-
-  // Class TaskAttemptId have two field id and mapId, rss taskAttemptID have 
21 bits,
-  // mapId is 19 bits, id is 2 bits. MR have a trick logic, taskAttemptId will 
increase
-  // 1000 * (appAttemptId - 1), so we will decrease it.
-  public static long convertTaskAttemptIdToLong(TaskAttemptID taskAttemptID, 
int appAttemptId) {
-    int lowBytes = taskAttemptID.getTaskID().getId();
-    if (lowBytes > LAYOUT.maxTaskAttemptId) {
-      throw new RssException("TaskAttempt " + taskAttemptID + " low bytes " + 
lowBytes + " exceed");
-    }
+
+  // Class TaskAttemptId have two field id and mapId. MR have a trick logic, 
taskAttemptId will
+  // increase 1000 * (appAttemptId - 1), so we will decrease it.
+  public static int createRssTaskAttemptId(
+      TaskAttemptID taskAttemptID, int appAttemptId, int maxAttemptNo) {
+    int attemptBits = ClientUtils.getNumberOfSignificantBits(maxAttemptNo);
+
     if (appAttemptId < 1) {
       throw new RssException("appAttemptId " + appAttemptId + " is wrong");
     }
-    int highBytes = taskAttemptID.getId() - (appAttemptId - 1) * 1000;
-    if (highBytes > MAX_ATTEMPT_ID || highBytes < 0) {
+    int attemptId = taskAttemptID.getId() - (appAttemptId - 1) * 1000;
+    if (attemptId > maxAttemptNo || attemptId < 0) {
       throw new RssException(
-          "TaskAttempt " + taskAttemptID + " high bytes " + highBytes + " 
exceed");
+          "TaskAttempt " + taskAttemptID + " attemptId " + attemptId + " 
exceed " + maxAttemptNo);
     }
-    return LAYOUT.getBlockId(highBytes, 0, lowBytes);
+    int taskId = taskAttemptID.getTaskID().getId();
+
+    int mapIndexBits = ClientUtils.getNumberOfSignificantBits(taskId);
+    if (mapIndexBits + attemptBits > LAYOUT.taskAttemptIdBits) {
+      throw new RssException(
+          "Observing taskId["
+              + taskId
+              + "] that would produce a taskAttemptId with "
+              + (mapIndexBits + attemptBits)
+              + " bits which is larger than the allowed "
+              + LAYOUT.taskAttemptIdBits
+              + "]). Please consider providing more bits for taskAttemptIds.");
+    }
+
+    return (taskId << attemptBits) | attemptId;
+  }
+
+  public static int createRssTaskAttemptId(
+      TaskAttemptID taskAttemptID, int appAttemptId, int maxFailures, boolean 
speculation) {
+    int maxAttemptNo = ClientUtils.getMaxAttemptNo(maxFailures, speculation);
+    return createRssTaskAttemptId(taskAttemptID, appAttemptId, maxAttemptNo);
+  }
+
+  public static int createRssTaskAttemptId(
+      TaskAttemptID taskAttemptID, int appAttemptId, Configuration conf) {
+    int maxFailures = conf.getInt(MRJobConfig.MAP_MAX_ATTEMPTS, 4);
+    boolean speculation = conf.getBoolean(MRJobConfig.MAP_SPECULATIVE, true);
+    return createRssTaskAttemptId(taskAttemptID, appAttemptId, maxFailures, 
speculation);
   }
 
   public static TaskAttemptID createMRTaskAttemptId(
-      JobID jobID, TaskType taskType, long rssTaskAttemptId, int appAttemptId) 
{
+      JobID jobID, TaskType taskType, long rssTaskAttemptId, int appAttemptId, 
int maxAttemptNo) {
+    int attemptBits = ClientUtils.getNumberOfSignificantBits(maxAttemptNo);
     if (appAttemptId < 1) {
       throw new RssException("appAttemptId " + appAttemptId + " is wrong");
     }
-    TaskID taskID = new TaskID(jobID, taskType, 
LAYOUT.getTaskAttemptId(rssTaskAttemptId));
-    int id = LAYOUT.getSequenceNo(rssTaskAttemptId) + 1000 * (appAttemptId - 
1);
+    int task = (int) rssTaskAttemptId >> attemptBits;
+    int attempt = (int) rssTaskAttemptId & ((1 << attemptBits) - 1);
+    TaskID taskID = new TaskID(jobID, taskType, task);
+    int id = attempt + 1000 * (appAttemptId - 1);
     return new TaskAttemptID(taskID, id);
   }
 
@@ -228,27 +253,11 @@ public class RssMRUtils {
   }
 
   public static long getBlockId(int partitionId, long taskAttemptId, int 
nextSeqNo) {
-    long attemptId = taskAttemptId >> (LAYOUT.partitionIdBits + 
LAYOUT.taskAttemptIdBits);
-    if (attemptId < 0 || attemptId > MAX_ATTEMPT_ID) {
-      throw new RssException(
-          "Can't support attemptId [" + attemptId + "], the max value should 
be " + MAX_ATTEMPT_ID);
-    }
-    if (nextSeqNo < 0 || nextSeqNo > MAX_SEQUENCE_NO) {
-      throw new RssException(
-          "Can't support sequence [" + nextSeqNo + "], the max value should be 
" + MAX_SEQUENCE_NO);
-    }
-
-    int atomicInt = (int) ((nextSeqNo << MAX_ATTEMPT_LENGTH) + attemptId);
-    long taskId =
-        taskAttemptId - (attemptId << (LAYOUT.partitionIdBits + 
LAYOUT.taskAttemptIdBits));
-
-    return LAYOUT.getBlockId(atomicInt, partitionId, taskId);
+    return LAYOUT.getBlockId(nextSeqNo, partitionId, taskAttemptId);
   }
 
-  public static long getTaskAttemptId(long blockId) {
-    int mapId = LAYOUT.getTaskAttemptId(blockId);
-    int attemptId = LAYOUT.getSequenceNo(blockId) & MAX_ATTEMPT_ID;
-    return LAYOUT.getBlockId(attemptId, 0, mapId);
+  public static int getTaskAttemptId(long blockId) {
+    return LAYOUT.getTaskAttemptId(blockId);
   }
 
   public static int estimateTaskConcurrency(JobConf jobConf) {
diff --git 
a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssEventFetcher.java
 
b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssEventFetcher.java
index 9b86cabee..ece22be74 100644
--- 
a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssEventFetcher.java
+++ 
b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssEventFetcher.java
@@ -75,7 +75,7 @@ public class RssEventFetcher<K, V> {
     String errMsg = "TaskAttemptIDs are inconsistent with map tasks";
     for (TaskAttemptID taskAttemptID : successMaps) {
       if (!obsoleteMaps.contains(taskAttemptID)) {
-        long rssTaskId = RssMRUtils.convertTaskAttemptIdToLong(taskAttemptID, 
appAttemptId);
+        long rssTaskId = RssMRUtils.createRssTaskAttemptId(taskAttemptID, 
appAttemptId, jobConf);
         int mapIndex = taskAttemptID.getTaskID().getId();
         // There can be multiple successful attempts on same map task.
         // So we only need to accept one of them.
diff --git 
a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
 
b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
index 3cb017239..62f1f86f8 100644
--- 
a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
+++ 
b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
@@ -76,7 +76,7 @@ public class SortWriteBufferManagerTest {
     manager =
         new SortWriteBufferManager<BytesWritable, BytesWritable>(
             10240,
-            1L,
+            1,
             10,
             serializationFactory.getSerializer(BytesWritable.class),
             serializationFactory.getSerializer(BytesWritable.class),
@@ -140,7 +140,7 @@ public class SortWriteBufferManagerTest {
     manager =
         new SortWriteBufferManager<BytesWritable, BytesWritable>(
             100,
-            1L,
+            1,
             10,
             serializationFactory.getSerializer(BytesWritable.class),
             serializationFactory.getSerializer(BytesWritable.class),
@@ -192,7 +192,7 @@ public class SortWriteBufferManagerTest {
     manager =
         new SortWriteBufferManager<BytesWritable, BytesWritable>(
             10240,
-            1L,
+            1,
             10,
             serializationFactory.getSerializer(BytesWritable.class),
             serializationFactory.getSerializer(BytesWritable.class),
@@ -244,7 +244,7 @@ public class SortWriteBufferManagerTest {
     manager =
         new SortWriteBufferManager<BytesWritable, BytesWritable>(
             10240,
-            1L,
+            1,
             10,
             serializationFactory.getSerializer(BytesWritable.class),
             serializationFactory.getSerializer(BytesWritable.class),
@@ -311,7 +311,7 @@ public class SortWriteBufferManagerTest {
     manager =
         new SortWriteBufferManager<BytesWritable, BytesWritable>(
             10240,
-            1L,
+            1,
             10,
             serializationFactory.getSerializer(BytesWritable.class),
             serializationFactory.getSerializer(BytesWritable.class),
@@ -390,7 +390,7 @@ public class SortWriteBufferManagerTest {
     SortWriteBufferManager<Text, IntWritable> manager =
         new SortWriteBufferManager<Text, IntWritable>(
             10240,
-            1L,
+            1,
             10,
             keySerializer,
             valueSerializer,
diff --git 
a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/RssMRUtilsTest.java 
b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/RssMRUtilsTest.java
index 2d3710ca1..31e3590d2 100644
--- 
a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/RssMRUtilsTest.java
+++ 
b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/RssMRUtilsTest.java
@@ -45,20 +45,21 @@ public class RssMRUtilsTest {
     TaskAttemptID mrTaskAttemptId = new TaskAttemptID(taskId, 3);
     boolean isException = false;
     try {
-      RssMRUtils.convertTaskAttemptIdToLong(mrTaskAttemptId, 1);
+      RssMRUtils.createRssTaskAttemptId(mrTaskAttemptId, 1, 4);
     } catch (RssException e) {
       isException = true;
     }
     assertTrue(isException);
-    taskAttemptId = (1 << 20) + 0x123;
-    mrTaskAttemptId = RssMRUtils.createMRTaskAttemptId(new JobID(), 
TaskType.MAP, taskAttemptId, 1);
-    long testId = RssMRUtils.convertTaskAttemptIdToLong(mrTaskAttemptId, 1);
+    taskAttemptId = (0x123 << 3) + 1;
+    mrTaskAttemptId =
+        RssMRUtils.createMRTaskAttemptId(new JobID(), TaskType.MAP, 
taskAttemptId, 1, 4);
+    int testId = RssMRUtils.createRssTaskAttemptId(mrTaskAttemptId, 1, 4);
     assertEquals(taskAttemptId, testId);
-    TaskID taskID = new TaskID(new org.apache.hadoop.mapred.JobID(), 
TaskType.MAP, (int) (1 << 21));
+    TaskID taskID = new TaskID(new org.apache.hadoop.mapred.JobID(), 
TaskType.MAP, 1 << 21);
     mrTaskAttemptId = new TaskAttemptID(taskID, 2);
     isException = false;
     try {
-      RssMRUtils.convertTaskAttemptIdToLong(mrTaskAttemptId, 1);
+      RssMRUtils.createRssTaskAttemptId(mrTaskAttemptId, 1, 4);
     } catch (RssException e) {
       isException = true;
     }
@@ -70,7 +71,7 @@ public class RssMRUtilsTest {
     JobID jobID = new JobID();
     TaskID taskId = new TaskID(jobID, TaskType.MAP, 233);
     TaskAttemptID taskAttemptID = new TaskAttemptID(taskId, 1);
-    long taskAttemptId = RssMRUtils.convertTaskAttemptIdToLong(taskAttemptID, 
1);
+    long taskAttemptId = RssMRUtils.createRssTaskAttemptId(taskAttemptID, 1, 
4);
     long blockId = RssMRUtils.getBlockId(1, taskAttemptId, 0);
     long newTaskAttemptId = RssMRUtils.getTaskAttemptId(blockId);
     assertEquals(taskAttemptId, newTaskAttemptId);
@@ -85,7 +86,7 @@ public class RssMRUtilsTest {
     JobID jobID = new JobID();
     TaskID taskId = new TaskID(jobID, TaskType.MAP, 233);
     TaskAttemptID taskAttemptID = new TaskAttemptID(taskId, 1);
-    long taskAttemptId = RssMRUtils.convertTaskAttemptIdToLong(taskAttemptID, 
1);
+    long taskAttemptId = RssMRUtils.createRssTaskAttemptId(taskAttemptID, 1, 
4);
     long mask = (1L << layout.partitionIdBits) - 1;
     for (int partitionId = 0; partitionId <= 3000; partitionId++) {
       for (int seqNo = 0; seqNo <= 10; seqNo++) {
diff --git 
a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/EventFetcherTest.java
 
b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/EventFetcherTest.java
index 8e8ab8f4e..6dfb81f45 100644
--- 
a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/EventFetcherTest.java
+++ 
b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/EventFetcherTest.java
@@ -59,8 +59,8 @@ public class EventFetcherTest {
     Roaring64NavigableMap expected = Roaring64NavigableMap.bitmapOf();
     for (int mapIndex = 0; mapIndex < mapTaskNum; mapIndex++) {
       long rssTaskId =
-          RssMRUtils.convertTaskAttemptIdToLong(
-              new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1);
+          RssMRUtils.createRssTaskAttemptId(
+              new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1, 4);
       expected.addLong(rssTaskId);
     }
 
@@ -89,8 +89,8 @@ public class EventFetcherTest {
     Roaring64NavigableMap expected = Roaring64NavigableMap.bitmapOf();
     for (int mapIndex = 0; mapIndex < mapTaskNum; mapIndex++) {
       long rssTaskId =
-          RssMRUtils.convertTaskAttemptIdToLong(
-              new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1);
+          RssMRUtils.createRssTaskAttemptId(
+              new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1, 4);
       expected.addLong(rssTaskId);
     }
 
@@ -121,8 +121,8 @@ public class EventFetcherTest {
     Roaring64NavigableMap expected = Roaring64NavigableMap.bitmapOf();
     for (int mapIndex = 0; mapIndex < mapTaskNum; mapIndex++) {
       long rssTaskId =
-          RssMRUtils.convertTaskAttemptIdToLong(
-              new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1);
+          RssMRUtils.createRssTaskAttemptId(
+              new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1, 4);
       expected.addLong(rssTaskId);
     }
     Roaring64NavigableMap taskIdBitmap = ef.fetchAllRssTaskIds();
@@ -146,8 +146,8 @@ public class EventFetcherTest {
     Roaring64NavigableMap expected = Roaring64NavigableMap.bitmapOf();
     for (int mapIndex = 0; mapIndex < mapTaskNum; mapIndex++) {
       long rssTaskId =
-          RssMRUtils.convertTaskAttemptIdToLong(
-              new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1);
+          RssMRUtils.createRssTaskAttemptId(
+              new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1, 4);
       expected.addLong(rssTaskId);
     }
     IllegalStateException ex =
@@ -172,8 +172,8 @@ public class EventFetcherTest {
     Roaring64NavigableMap expected = Roaring64NavigableMap.bitmapOf();
     for (int mapIndex = 0; mapIndex < mapTaskNum; mapIndex++) {
       long rssTaskId =
-          RssMRUtils.convertTaskAttemptIdToLong(
-              new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1);
+          RssMRUtils.createRssTaskAttemptId(
+              new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1, 4);
       expected.addLong(rssTaskId);
     }
     IllegalStateException ex =
@@ -205,14 +205,14 @@ public class EventFetcherTest {
     for (int mapIndex = 0; mapIndex < mapTaskNum; mapIndex++) {
       if (!tipFailed.contains(mapIndex) && !obsoleted.contains(mapIndex)) {
         long rssTaskId =
-            RssMRUtils.convertTaskAttemptIdToLong(
-                new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1);
+            RssMRUtils.createRssTaskAttemptId(
+                new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1, 
4);
         expected.addLong(rssTaskId);
       }
       if (obsoleted.contains(mapIndex)) {
         long rssTaskId =
-            RssMRUtils.convertTaskAttemptIdToLong(
-                new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 1), 1);
+            RssMRUtils.createRssTaskAttemptId(
+                new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 1), 1, 
4);
         expected.addLong(rssTaskId);
       }
     }
diff --git 
a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
 
b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
index 6c5f6776d..9ca680cb3 100644
--- 
a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
+++ 
b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
@@ -291,7 +291,8 @@ public class FetcherTest {
             null,
             new Progress(),
             new MROutputFiles());
-    TaskAttemptID taskAttemptID = RssMRUtils.createMRTaskAttemptId(new 
JobID(), TaskType.MAP, 1, 1);
+    TaskAttemptID taskAttemptID =
+        RssMRUtils.createMRTaskAttemptId(new JobID(), TaskType.MAP, 1, 1, 4);
     byte[] buffer = new byte[10];
     MapOutput mapOutput1 = merger.reserve(taskAttemptID, 10, 1);
     RssBypassWriter.write(mapOutput1, buffer);
@@ -350,7 +351,7 @@ public class FetcherTest {
     SortWriteBufferManager<Text, Text> manager =
         new SortWriteBufferManager(
             10240,
-            1L,
+            1,
             10,
             serializationFactory.getSerializer(Text.class),
             serializationFactory.getSerializer(Text.class),
diff --git 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
index bbeec90dd..6a281db2e 100644
--- 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
+++ 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
@@ -194,8 +194,10 @@ public abstract class RssShuffleManagerBase implements 
RssShuffleManagerInterfac
               + maxPartitions);
     }
 
-    int attemptIdBits = getAttemptIdBits(getMaxAttemptNo(maxFailures, 
speculation));
-    int partitionIdBits = 32 - Integer.numberOfLeadingZeros(maxPartitions - 
1); // [1..31]
+    int attemptIdBits =
+        ClientUtils.getNumberOfSignificantBits(
+            ClientUtils.getMaxAttemptNo(maxFailures, speculation));
+    int partitionIdBits = ClientUtils.getNumberOfSignificantBits(maxPartitions 
- 1); // [1..31]
     int taskAttemptIdBits = partitionIdBits + attemptIdBits; // 
[1+attemptIdBits..31+attemptIdBits]
     int sequenceNoBits = 63 - partitionIdBits - taskAttemptIdBits; // 
[1-attemptIdBits..61]
 
@@ -334,23 +336,6 @@ public abstract class RssShuffleManagerBase implements 
RssShuffleManagerInterfac
     }
   }
 
-  protected static int getMaxAttemptNo(int maxFailures, boolean speculation) {
-    // attempt number is zero based: 0, 1, …, maxFailures-1
-    // max maxFailures < 1 is not allowed but for safety, we interpret that as 
maxFailures == 1
-    int maxAttemptNo = maxFailures < 1 ? 0 : maxFailures - 1;
-
-    // with speculative execution enabled we could observe +1 attempts
-    if (speculation) {
-      maxAttemptNo++;
-    }
-
-    return maxAttemptNo;
-  }
-
-  protected static int getAttemptIdBits(int maxAttemptNo) {
-    return 32 - Integer.numberOfLeadingZeros(maxAttemptNo);
-  }
-
   /** See static overload of this method. */
   public abstract long getTaskAttemptIdForBlockId(int mapIndex, int attemptNo);
 
@@ -369,8 +354,8 @@ public abstract class RssShuffleManagerBase implements 
RssShuffleManagerInterfac
    */
   protected static long getTaskAttemptIdForBlockId(
       int mapIndex, int attemptNo, int maxFailures, boolean speculation, int 
maxTaskAttemptIdBits) {
-    int maxAttemptNo = getMaxAttemptNo(maxFailures, speculation);
-    int attemptBits = getAttemptIdBits(maxAttemptNo);
+    int maxAttemptNo = ClientUtils.getMaxAttemptNo(maxFailures, speculation);
+    int attemptBits = ClientUtils.getNumberOfSignificantBits(maxAttemptNo);
 
     if (attemptNo > maxAttemptNo) {
       // this should never happen, if it does, our assumptions are wrong,
@@ -384,7 +369,7 @@ public abstract class RssShuffleManagerBase implements 
RssShuffleManagerInterfac
               + ".");
     }
 
-    int mapIndexBits = 32 - Integer.numberOfLeadingZeros(mapIndex);
+    int mapIndexBits = ClientUtils.getNumberOfSignificantBits(mapIndex);
     if (mapIndexBits + attemptBits > maxTaskAttemptIdBits) {
       throw new RssException(
           "Observing mapIndex["
diff --git 
a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBaseTest.java
 
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBaseTest.java
index fffb7af3f..610b42c8c 100644
--- 
a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBaseTest.java
+++ 
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBaseTest.java
@@ -370,46 +370,6 @@ public class RssShuffleManagerBaseTest {
     assertTrue(e.getMessage().startsWith("All block id bit config keys must be 
provided "));
   }
 
-  @Test
-  public void testGetMaxAttemptNo() {
-    // without speculation
-    assertEquals(0, RssShuffleManagerBase.getMaxAttemptNo(-1, false));
-    assertEquals(0, RssShuffleManagerBase.getMaxAttemptNo(0, false));
-    assertEquals(0, RssShuffleManagerBase.getMaxAttemptNo(1, false));
-    assertEquals(1, RssShuffleManagerBase.getMaxAttemptNo(2, false));
-    assertEquals(2, RssShuffleManagerBase.getMaxAttemptNo(3, false));
-    assertEquals(3, RssShuffleManagerBase.getMaxAttemptNo(4, false));
-    assertEquals(4, RssShuffleManagerBase.getMaxAttemptNo(5, false));
-    assertEquals(1023, RssShuffleManagerBase.getMaxAttemptNo(1024, false));
-
-    // with speculation
-    assertEquals(1, RssShuffleManagerBase.getMaxAttemptNo(-1, true));
-    assertEquals(1, RssShuffleManagerBase.getMaxAttemptNo(0, true));
-    assertEquals(1, RssShuffleManagerBase.getMaxAttemptNo(1, true));
-    assertEquals(2, RssShuffleManagerBase.getMaxAttemptNo(2, true));
-    assertEquals(3, RssShuffleManagerBase.getMaxAttemptNo(3, true));
-    assertEquals(4, RssShuffleManagerBase.getMaxAttemptNo(4, true));
-    assertEquals(5, RssShuffleManagerBase.getMaxAttemptNo(5, true));
-    assertEquals(1024, RssShuffleManagerBase.getMaxAttemptNo(1024, true));
-  }
-
-  @Test
-  public void testGetAttemptIdBits() {
-    assertEquals(0, RssShuffleManagerBase.getAttemptIdBits(0));
-    assertEquals(1, RssShuffleManagerBase.getAttemptIdBits(1));
-    assertEquals(2, RssShuffleManagerBase.getAttemptIdBits(2));
-    assertEquals(2, RssShuffleManagerBase.getAttemptIdBits(3));
-    assertEquals(3, RssShuffleManagerBase.getAttemptIdBits(4));
-    assertEquals(3, RssShuffleManagerBase.getAttemptIdBits(5));
-    assertEquals(3, RssShuffleManagerBase.getAttemptIdBits(6));
-    assertEquals(3, RssShuffleManagerBase.getAttemptIdBits(7));
-    assertEquals(4, RssShuffleManagerBase.getAttemptIdBits(8));
-    assertEquals(4, RssShuffleManagerBase.getAttemptIdBits(9));
-    assertEquals(10, RssShuffleManagerBase.getAttemptIdBits(1023));
-    assertEquals(11, RssShuffleManagerBase.getAttemptIdBits(1024));
-    assertEquals(11, RssShuffleManagerBase.getAttemptIdBits(1025));
-  }
-
   private long bits(String string) {
     return Long.parseLong(string.replaceAll("[|]", ""), 2);
   }
diff --git a/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java 
b/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java
index e168ec0d3..3df582bde 100644
--- a/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java
+++ b/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java
@@ -24,6 +24,7 @@ import java.util.Set;
 import com.google.common.base.Preconditions;
 import org.apache.commons.lang3.StringUtils;
 import org.apache.hadoop.conf.Configuration;
+import org.apache.tez.dag.api.TezConfiguration;
 import org.apache.tez.dag.records.TezTaskAttemptID;
 import org.apache.tez.runtime.library.api.TezRuntimeConfiguration;
 import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
@@ -51,6 +52,7 @@ import org.slf4j.LoggerFactory;
 
 import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.client.factory.ShuffleClientFactory;
+import org.apache.uniffle.client.util.ClientUtils;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.util.BlockIdLayout;
@@ -60,10 +62,6 @@ public class RssTezUtils {
 
   private static final Logger LOG = LoggerFactory.getLogger(RssTezUtils.class);
   private static final BlockIdLayout LAYOUT = BlockIdLayout.DEFAULT;
-  private static final int MAX_ATTEMPT_LENGTH = 6;
-  private static final int MAX_ATTEMPT_ID = (1 << MAX_ATTEMPT_LENGTH) - 1;
-  private static final int MAX_SEQUENCE_NO =
-      (1 << (LAYOUT.sequenceNoBits - MAX_ATTEMPT_LENGTH)) - 1;
 
   public static final String HOST_NAME = "hostname";
 
@@ -159,32 +157,11 @@ public class RssTezUtils {
   }
 
   public static long getBlockId(int partitionId, long taskAttemptId, int 
nextSeqNo) {
-    LOG.info(
-        "GetBlockId, partitionId:{}, taskAttemptId:{}, nextSeqNo:{}",
-        partitionId,
-        taskAttemptId,
-        nextSeqNo);
-    long attemptId = taskAttemptId >> (LAYOUT.partitionIdBits + 
LAYOUT.taskAttemptIdBits);
-    if (attemptId < 0 || attemptId > MAX_ATTEMPT_ID) {
-      throw new RssException(
-          "Can't support attemptId [" + attemptId + "], the max value should 
be " + MAX_ATTEMPT_ID);
-    }
-    if (nextSeqNo < 0 || nextSeqNo > MAX_SEQUENCE_NO) {
-      throw new RssException(
-          "Can't support sequence [" + nextSeqNo + "], the max value should be 
" + MAX_SEQUENCE_NO);
-    }
-
-    int atomicInt = (int) ((nextSeqNo << MAX_ATTEMPT_LENGTH) + attemptId);
-    long taskId =
-        taskAttemptId - (attemptId << (LAYOUT.partitionIdBits + 
LAYOUT.taskAttemptIdBits));
-
-    return LAYOUT.getBlockId(atomicInt, partitionId, taskId);
+    return LAYOUT.getBlockId(nextSeqNo, partitionId, taskAttemptId);
   }
 
-  public static long getTaskAttemptId(long blockId) {
-    int mapId = LAYOUT.getTaskAttemptId(blockId);
-    int attemptId = LAYOUT.getSequenceNo(blockId) & MAX_ATTEMPT_ID;
-    return LAYOUT.getBlockId(attemptId, 0, mapId);
+  public static int getTaskAttemptId(long blockId) {
+    return LAYOUT.getTaskAttemptId(blockId);
   }
 
   public static int estimateTaskConcurrency(Configuration jobConf, int mapNum, 
int reduceNum) {
@@ -276,23 +253,55 @@ public class RssTezUtils {
     }
   }
 
-  public static long convertTaskAttemptIdToLong(TezTaskAttemptID 
taskAttemptID) {
-    int lowBytes = taskAttemptID.getTaskID().getId();
-    if (lowBytes > LAYOUT.maxTaskAttemptId) {
-      throw new RssException("TaskAttempt " + taskAttemptID + " low bytes " + 
lowBytes + " exceed");
+  public static int createRssTaskAttemptId(TezTaskAttemptID taskAttemptId, int 
maxAttemptNo) {
+    int attemptBits = ClientUtils.getNumberOfSignificantBits(maxAttemptNo);
+
+    int attemptId = taskAttemptId.getId();
+    if (attemptId > maxAttemptNo || attemptId < 0) {
+      throw new RssException(
+          "TaskAttempt " + taskAttemptId + " attemptId " + attemptId + " 
exceed");
     }
-    int highBytes = taskAttemptID.getId();
-    if (highBytes > MAX_ATTEMPT_ID || highBytes < 0) {
+    int taskId = taskAttemptId.getTaskID().getId();
+
+    int mapIndexBits = ClientUtils.getNumberOfSignificantBits(taskId);
+    if (mapIndexBits + attemptBits > LAYOUT.taskAttemptIdBits) {
       throw new RssException(
-          "TaskAttempt " + taskAttemptID + " high bytes " + highBytes + " 
exceed.");
+          "Observing taskId["
+              + taskId
+              + "] that would produce a taskAttemptId with "
+              + (mapIndexBits + attemptBits)
+              + " bits which is larger than the allowed "
+              + LAYOUT.taskAttemptIdBits
+              + "]). Please consider providing more bits for taskAttemptIds.");
     }
-    long id = LAYOUT.getBlockId(highBytes, 0, lowBytes);
-    LOG.info("ConvertTaskAttemptIdToLong taskAttemptID:{}, id is {}, .", 
taskAttemptID, id);
+
+    int id = (taskId << attemptBits) | attemptId;
+    LOG.info("createRssTaskAttemptId taskAttemptId:{}, id is {}, .", 
taskAttemptId, id);
     return id;
   }
 
+  public static int createRssTaskAttemptId(TezTaskAttemptID taskAttemptId, 
Configuration conf) {
+    int maxAttemptNo = getMaxAttemptNo(conf);
+    return createRssTaskAttemptId(taskAttemptId, maxAttemptNo);
+  }
+
+  public static int getMaxAttemptNo(Configuration conf) {
+    int maxFailures =
+        conf.getInt(
+            TezConfiguration.TEZ_AM_TASK_MAX_FAILED_ATTEMPTS,
+            TezConfiguration.TEZ_AM_TASK_MAX_FAILED_ATTEMPTS_DEFAULT);
+    boolean speculation =
+        conf.getBoolean(
+            TezConfiguration.TEZ_AM_SPECULATION_ENABLED,
+            TezConfiguration.TEZ_AM_SPECULATION_ENABLED_DEFAULT);
+    return ClientUtils.getMaxAttemptNo(maxFailures, speculation);
+  }
+
   public static Roaring64NavigableMap fetchAllRssTaskIds(
-      Set<InputAttemptIdentifier> successMapTaskAttempts, int totalMapsCount, 
int appAttemptId) {
+      Set<InputAttemptIdentifier> successMapTaskAttempts,
+      int totalMapsCount,
+      int appAttemptId,
+      int maxAttemptNo) {
     String errMsg = "TaskAttemptIDs are inconsistent with map tasks";
     Roaring64NavigableMap rssTaskIdBitmap = Roaring64NavigableMap.bitmapOf();
     Roaring64NavigableMap mapTaskIdBitmap = Roaring64NavigableMap.bitmapOf();
@@ -301,9 +310,9 @@ public class RssTezUtils {
 
     for (InputAttemptIdentifier inputAttemptIdentifier : 
successMapTaskAttempts) {
       String pathComponent = inputAttemptIdentifier.getPathComponent();
-      TezTaskAttemptID mapTaskAttemptID = 
IdUtils.convertTezTaskAttemptID(pathComponent);
-      long rssTaskId = 
RssTezUtils.convertTaskAttemptIdToLong(mapTaskAttemptID);
-      long mapTaskId = mapTaskAttemptID.getTaskID().getId();
+      TezTaskAttemptID mapTaskAttemptId = 
IdUtils.convertTezTaskAttemptID(pathComponent);
+      long rssTaskId = RssTezUtils.createRssTaskAttemptId(mapTaskAttemptId, 
maxAttemptNo);
+      long mapTaskId = mapTaskAttemptId.getTaskID().getId();
 
       LOG.info(
           "FetchAllRssTaskIds, pathComponent: {}, mapTaskId:{}, rssTaskId:{}, 
is contains:{}",
diff --git 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssShuffleManager.java
 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssShuffleManager.java
index 4fa7a7b98..dfdc6f840 100644
--- 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssShuffleManager.java
+++ 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssShuffleManager.java
@@ -67,6 +67,7 @@ import org.apache.hadoop.util.Time;
 import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
 import org.apache.tez.common.CallableWithNdc;
 import org.apache.tez.common.InputContextUtils;
+import org.apache.tez.common.RssTezUtils;
 import org.apache.tez.common.TezRuntimeFrameworkConfigs;
 import org.apache.tez.common.TezUtilsInternal;
 import org.apache.tez.common.UmbilicalUtils;
@@ -590,6 +591,8 @@ public class RssShuffleManager extends ShuffleManager {
                     partitionToServers.get(partition),
                     partitionToServers);
 
+                int maxAttemptNo = RssTezUtils.getMaxAttemptNo(conf);
+
                 RssTezFetcherTask fetcher =
                     new RssTezFetcherTask(
                         RssShuffleManager.this,
@@ -604,7 +607,8 @@ public class RssShuffleManager extends ShuffleManager {
                         rssAllBlockIdBitmapMap,
                         rssSuccessBlockIdBitmapMap,
                         numInputs,
-                        partitionToServers.size());
+                        partitionToServers.size(),
+                        maxAttemptNo);
                 rssRunningFetchers.add(fetcher);
                 if (isShutdown.get()) {
                   LOG.info(
diff --git 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcherTask.java
 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcherTask.java
index 63922eb64..3052c5dbc 100644
--- 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcherTask.java
+++ 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcherTask.java
@@ -76,6 +76,7 @@ public class RssTezFetcherTask extends 
CallableWithNdc<FetchResult> {
   private final int partitionNum;
   private final int shuffleId;
   private final ApplicationAttemptId applicationAttemptId;
+  private final int maxAttemptNo;
 
   public RssTezFetcherTask(
       FetcherCallback fetcherCallback,
@@ -90,7 +91,8 @@ public class RssTezFetcherTask extends 
CallableWithNdc<FetchResult> {
       Map<Integer, Roaring64NavigableMap> rssAllBlockIdBitmapMap,
       Map<Integer, Roaring64NavigableMap> rssSuccessBlockIdBitmapMap,
       int numPhysicalInputs,
-      int partitionNum) {
+      int partitionNum,
+      int maxAttemptNo) {
     if (CollectionUtils.isEmpty(inputs)) {
       throw new RssException("inputs should not be empty");
     }
@@ -135,6 +137,7 @@ public class RssTezFetcherTask extends 
CallableWithNdc<FetchResult> {
         conf.getInt(
             RssTezConfig.RSS_PARTITION_NUM_PER_RANGE,
             RssTezConfig.RSS_PARTITION_NUM_PER_RANGE_DEFAULT_VALUE);
+    this.maxAttemptNo = maxAttemptNo;
     LOG.info(
         "RssTezFetcherTask fetch partition:{}, with inputs:{}, 
readBufferSize:{}, partitionNumPerRange:{}.",
         this.partition,
@@ -163,7 +166,8 @@ public class RssTezFetcherTask extends 
CallableWithNdc<FetchResult> {
     // final RssEventFetcher eventFetcher = new RssEventFetcher(inputs, 
numPhysicalInputs);
     int appAttemptId = applicationAttemptId.getAttemptId();
     Roaring64NavigableMap taskIdBitmap =
-        RssTezUtils.fetchAllRssTaskIds(new HashSet<>(inputs), 
numPhysicalInputs, appAttemptId);
+        RssTezUtils.fetchAllRssTaskIds(
+            new HashSet<>(inputs), numPhysicalInputs, appAttemptId, 
this.maxAttemptNo);
     LOG.info(
         "Inputs:{}, num input:{}, appAttemptId:{}, taskIdBitmap:{}",
         inputs,
diff --git 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleScheduler.java
 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleScheduler.java
index 8fb873649..72a9e1763 100644
--- 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleScheduler.java
+++ 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleScheduler.java
@@ -282,6 +282,8 @@ class RssShuffleScheduler extends ShuffleScheduler {
   private RemoteStorageInfo remoteStorageInfo;
   private int indexReadLimit;
 
+  private final int maxAttemptNo;
+
   RssShuffleScheduler(
       InputContext inputContext,
       Configuration conf,
@@ -538,6 +540,7 @@ class RssShuffleScheduler extends ShuffleScheduler {
     this.basePath = this.conf.get(RssTezConfig.RSS_REMOTE_STORAGE_PATH);
     String remoteStorageConf = 
this.conf.get(RssTezConfig.RSS_REMOTE_STORAGE_CONF);
     this.remoteStorageInfo = new RemoteStorageInfo(basePath, 
remoteStorageConf);
+    this.maxAttemptNo = RssTezUtils.getMaxAttemptNo(conf);
 
     LOG.info(
         "RSSShuffleScheduler running for sourceVertex: "
@@ -1834,7 +1837,8 @@ class RssShuffleScheduler extends ShuffleScheduler {
         RssTezUtils.fetchAllRssTaskIds(
             partitionIdToSuccessMapTaskAttempts.get(mapHost.getPartitionId()),
             this.numInputs,
-            appAttemptId);
+            appAttemptId,
+            maxAttemptNo);
 
     LOG.info(
         "In reduce: {}, RSS Tez client has fetched blockIds and taskIds 
successfully, partitionId:{}.",
diff --git 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssSorter.java
 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssSorter.java
index 94c9aa90d..fe4f11e13 100644
--- 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssSorter.java
+++ 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssSorter.java
@@ -61,7 +61,8 @@ public class RssSorter extends ExternalSorter {
       long initialMemoryAvailable,
       int shuffleId,
       ApplicationAttemptId applicationAttemptId,
-      Map<Integer, List<ShuffleServerInfo>> partitionToServers)
+      Map<Integer, List<ShuffleServerInfo>> partitionToServers,
+      long taskAttemptId)
       throws IOException {
     super(outputContext, conf, numOutputs, initialMemoryAvailable);
     this.partitionToServers = partitionToServers;
@@ -81,7 +82,6 @@ public class RssSorter extends ExternalSorter {
         conf.getDouble(
             RssTezConfig.RSS_CLIENT_SORT_MEMORY_USE_THRESHOLD,
             RssTezConfig.RSS_CLIENT_DEFAULT_SORT_MEMORY_USE_THRESHOLD);
-    long taskAttemptId = 
RssTezUtils.convertTaskAttemptIdToLong(tezTaskAttemptID);
 
     long maxSegmentSize =
         conf.getLong(
diff --git 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorter.java
 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorter.java
index 0248bb8a2..94e87a40e 100644
--- 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorter.java
+++ 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorter.java
@@ -60,7 +60,8 @@ public class RssUnSorter extends ExternalSorter {
       long initialMemoryAvailable,
       int shuffleId,
       ApplicationAttemptId applicationAttemptId,
-      Map<Integer, List<ShuffleServerInfo>> partitionToServers)
+      Map<Integer, List<ShuffleServerInfo>> partitionToServers,
+      long taskAttemptId)
       throws IOException {
     super(outputContext, conf, numOutputs, initialMemoryAvailable);
     this.partitionToServers = partitionToServers;
@@ -80,7 +81,6 @@ public class RssUnSorter extends ExternalSorter {
         conf.getDouble(
             RssTezConfig.RSS_CLIENT_SORT_MEMORY_USE_THRESHOLD,
             RssTezConfig.RSS_CLIENT_DEFAULT_SORT_MEMORY_USE_THRESHOLD);
-    long taskAttemptId = 
RssTezUtils.convertTaskAttemptIdToLong(tezTaskAttemptID);
     long maxSegmentSize =
         conf.getLong(
             RssTezConfig.RSS_CLIENT_MAX_BUFFER_SIZE,
diff --git 
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
 
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
index 6bf563c9a..997dcdfa6 100644
--- 
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
+++ 
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
@@ -217,6 +217,7 @@ public class RssOrderedPartitionedKVOutput extends 
AbstractLogicalOutput {
   public void start() throws Exception {
     if (!isStarted.get()) {
       memoryUpdateCallbackHandler.validateUpdateReceived();
+      long rssTaskAttemptId = 
RssTezUtils.createRssTaskAttemptId(taskAttemptId, conf);
       sorter =
           new RssSorter(
               taskAttemptId,
@@ -227,7 +228,8 @@ public class RssOrderedPartitionedKVOutput extends 
AbstractLogicalOutput {
               memoryUpdateCallbackHandler.getMemoryAssigned(),
               shuffleId,
               applicationAttemptId,
-              partitionToServers);
+              partitionToServers,
+              rssTaskAttemptId);
       LOG.info("Initialized RssSorter.");
       isStarted.set(true);
     }
diff --git 
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedKVOutput.java
 
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedKVOutput.java
index 6aa32b327..3e08e7944 100644
--- 
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedKVOutput.java
+++ 
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedKVOutput.java
@@ -222,6 +222,7 @@ public class RssUnorderedKVOutput extends 
AbstractLogicalOutput {
   public void start() throws Exception {
     if (!isStarted.get()) {
       memoryUpdateCallbackHandler.validateUpdateReceived();
+      long rssTaskAttemptId = 
RssTezUtils.createRssTaskAttemptId(taskAttemptId, conf);
       sorter =
           new RssUnSorter(
               taskAttemptId,
@@ -232,7 +233,8 @@ public class RssUnorderedKVOutput extends 
AbstractLogicalOutput {
               memoryUpdateCallbackHandler.getMemoryAssigned(),
               shuffleId,
               applicationAttemptId,
-              partitionToServers);
+              partitionToServers,
+              rssTaskAttemptId);
       LOG.info("Initialized RssUnSorter.");
       isStarted.set(true);
     }
diff --git 
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedPartitionedKVOutput.java
 
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedPartitionedKVOutput.java
index bc78a2f67..7edb5a1f3 100644
--- 
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedPartitionedKVOutput.java
+++ 
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedPartitionedKVOutput.java
@@ -220,6 +220,7 @@ public class RssUnorderedPartitionedKVOutput extends 
AbstractLogicalOutput {
   public void start() throws Exception {
     if (!isStarted.get()) {
       memoryUpdateCallbackHandler.validateUpdateReceived();
+      long rssTaskAttemptId = 
RssTezUtils.createRssTaskAttemptId(taskAttemptId, conf);
       sorter =
           new RssUnSorter(
               taskAttemptId,
@@ -230,7 +231,8 @@ public class RssUnorderedPartitionedKVOutput extends 
AbstractLogicalOutput {
               memoryUpdateCallbackHandler.getMemoryAssigned(),
               shuffleId,
               applicationAttemptId,
-              partitionToServers);
+              partitionToServers,
+              rssTaskAttemptId);
       LOG.info("Initialized RssUnSorter.");
       isStarted.set(true);
     }
diff --git 
a/client-tez/src/test/java/org/apache/tez/common/RssTezUtilsTest.java 
b/client-tez/src/test/java/org/apache/tez/common/RssTezUtilsTest.java
index c71dbefcf..12404a14f 100644
--- a/client-tez/src/test/java/org/apache/tez/common/RssTezUtilsTest.java
+++ b/client-tez/src/test/java/org/apache/tez/common/RssTezUtilsTest.java
@@ -56,17 +56,17 @@ public class RssTezUtilsTest {
 
     boolean isException = false;
     try {
-      RssTezUtils.convertTaskAttemptIdToLong(tezTaskAttemptId);
+      RssTezUtils.createRssTaskAttemptId(tezTaskAttemptId, 3);
     } catch (RssException e) {
       isException = true;
     }
     assertTrue(isException);
 
-    taskId = TezTaskID.getInstance(vId, (int) (1 << 21));
+    taskId = TezTaskID.getInstance(vId, 1 << 21);
     tezTaskAttemptId = TezTaskAttemptID.getInstance(taskId, 2);
     isException = false;
     try {
-      RssTezUtils.convertTaskAttemptIdToLong(tezTaskAttemptId);
+      RssTezUtils.createRssTaskAttemptId(tezTaskAttemptId, 3);
     } catch (RssException e) {
       isException = true;
     }
@@ -80,7 +80,7 @@ public class RssTezUtilsTest {
     TezVertexID vId = TezVertexID.getInstance(dagId, 35);
     TezTaskID tId = TezTaskID.getInstance(vId, 389);
     TezTaskAttemptID tezTaskAttemptId = TezTaskAttemptID.getInstance(tId, 2);
-    long taskAttemptId = 
RssTezUtils.convertTaskAttemptIdToLong(tezTaskAttemptId);
+    long taskAttemptId = RssTezUtils.createRssTaskAttemptId(tezTaskAttemptId, 
3);
     long blockId = RssTezUtils.getBlockId(1, taskAttemptId, 0);
     long newTaskAttemptId = RssTezUtils.getTaskAttemptId(blockId);
     assertEquals(taskAttemptId, newTaskAttemptId);
@@ -97,7 +97,7 @@ public class RssTezUtilsTest {
     TezVertexID vId = TezVertexID.getInstance(dagId, 35);
     TezTaskID tId = TezTaskID.getInstance(vId, 389);
     TezTaskAttemptID tezTaskAttemptId = TezTaskAttemptID.getInstance(tId, 2);
-    long taskAttemptId = 
RssTezUtils.convertTaskAttemptIdToLong(tezTaskAttemptId);
+    long taskAttemptId = RssTezUtils.createRssTaskAttemptId(tezTaskAttemptId, 
3);
     long mask = (1L << layout.partitionIdBits) - 1;
     for (int partitionId = 0; partitionId <= 3000; partitionId++) {
       for (int seqNo = 0; seqNo <= 10; seqNo++) {
diff --git 
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssSorterTest.java
 
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssSorterTest.java
index 588abf32d..db9f9cea5 100644
--- 
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssSorterTest.java
+++ 
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssSorterTest.java
@@ -31,6 +31,7 @@ import org.apache.hadoop.io.Text;
 import org.apache.hadoop.yarn.api.ApplicationConstants;
 import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.tez.common.RssTezUtils;
 import org.apache.tez.common.TezRuntimeFrameworkConfigs;
 import org.apache.tez.dag.records.TezTaskAttemptID;
 import org.apache.tez.runtime.api.OutputContext;
@@ -87,6 +88,7 @@ public class RssSorterTest {
 
     long initialMemoryAvailable = 10240000;
     int shuffleId = 1001;
+    long rssTaskAttemptId = 
RssTezUtils.createRssTaskAttemptId(tezTaskAttemptID, 3);
 
     RssSorter rssSorter =
         new RssSorter(
@@ -98,7 +100,8 @@ public class RssSorterTest {
             initialMemoryAvailable,
             shuffleId,
             applicationAttemptId,
-            partitionToServers);
+            partitionToServers,
+            rssTaskAttemptId);
 
     rssSorter.collect(new Text("0"), new Text("0"), 0);
     rssSorter.collect(new Text("0"), new Text("1"), 0);
diff --git 
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorterTest.java
 
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorterTest.java
index e54a37fd7..57f3f7626 100644
--- 
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorterTest.java
+++ 
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorterTest.java
@@ -27,6 +27,7 @@ import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.Text;
 import org.apache.hadoop.yarn.api.ApplicationConstants;
+import org.apache.tez.common.RssTezUtils;
 import org.apache.tez.common.TezRuntimeFrameworkConfigs;
 import org.apache.tez.dag.records.TezTaskAttemptID;
 import org.apache.tez.runtime.api.OutputContext;
@@ -82,6 +83,7 @@ public class RssUnSorterTest {
 
     long initialMemoryAvailable = 10240000;
     int shuffleId = 1001;
+    long rssTaskAttemptId = 
RssTezUtils.createRssTaskAttemptId(tezTaskAttemptID, 3);
 
     RssUnSorter rssSorter =
         new RssUnSorter(
@@ -93,7 +95,8 @@ public class RssUnSorterTest {
             initialMemoryAvailable,
             shuffleId,
             APPATTEMPT_ID,
-            partitionToServers);
+            partitionToServers,
+            rssTaskAttemptId);
 
     rssSorter.collect(new Text("0"), new Text("0"), 0);
     rssSorter.collect(new Text("0"), new Text("1"), 0);
diff --git 
a/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java 
b/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java
index b3d40dcde..29fc4b241 100644
--- a/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java
+++ b/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java
@@ -112,4 +112,21 @@ public class ClientUtils {
           String.format("The value of %s should be one of %s", clientType, 
types));
     }
   }
+
+  public static int getMaxAttemptNo(int maxFailures, boolean speculation) {
+    // attempt number is zero based: 0, 1, …, maxFailures-1
+    // max maxFailures < 1 is not allowed but for safety, we interpret that as 
maxFailures == 1
+    int maxAttemptNo = maxFailures < 1 ? 0 : maxFailures - 1;
+
+    // with speculative execution enabled we could observe +1 attempts
+    if (speculation) {
+      maxAttemptNo++;
+    }
+
+    return maxAttemptNo;
+  }
+
+  public static int getNumberOfSignificantBits(int number) {
+    return 32 - Integer.numberOfLeadingZeros(number);
+  }
 }
diff --git 
a/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java 
b/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java
index dd9ef62b1..611a46b44 100644
--- a/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java
+++ b/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java
@@ -36,6 +36,8 @@ import org.apache.uniffle.client.util.DefaultIdHelper;
 import org.apache.uniffle.common.util.BlockIdLayout;
 import org.apache.uniffle.common.util.RssUtils;
 
+import static org.apache.uniffle.client.util.ClientUtils.getMaxAttemptNo;
+import static 
org.apache.uniffle.client.util.ClientUtils.getNumberOfSignificantBits;
 import static org.apache.uniffle.client.util.ClientUtils.waitUntilDoneOrFail;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.fail;
@@ -134,4 +136,44 @@ public class ClientUtilsTest {
       // Ignore
     }
   }
+
+  @Test
+  public void testGetMaxAttemptNo() {
+    // without speculation
+    assertEquals(0, getMaxAttemptNo(-1, false));
+    assertEquals(0, getMaxAttemptNo(0, false));
+    assertEquals(0, getMaxAttemptNo(1, false));
+    assertEquals(1, getMaxAttemptNo(2, false));
+    assertEquals(2, getMaxAttemptNo(3, false));
+    assertEquals(3, getMaxAttemptNo(4, false));
+    assertEquals(4, getMaxAttemptNo(5, false));
+    assertEquals(1023, getMaxAttemptNo(1024, false));
+
+    // with speculation
+    assertEquals(1, getMaxAttemptNo(-1, true));
+    assertEquals(1, getMaxAttemptNo(0, true));
+    assertEquals(1, getMaxAttemptNo(1, true));
+    assertEquals(2, getMaxAttemptNo(2, true));
+    assertEquals(3, getMaxAttemptNo(3, true));
+    assertEquals(4, getMaxAttemptNo(4, true));
+    assertEquals(5, getMaxAttemptNo(5, true));
+    assertEquals(1024, getMaxAttemptNo(1024, true));
+  }
+
+  @Test
+  public void testGetNumberOfSignificantBits() {
+    assertEquals(0, getNumberOfSignificantBits(0));
+    assertEquals(1, getNumberOfSignificantBits(1));
+    assertEquals(2, getNumberOfSignificantBits(2));
+    assertEquals(2, getNumberOfSignificantBits(3));
+    assertEquals(3, getNumberOfSignificantBits(4));
+    assertEquals(3, getNumberOfSignificantBits(5));
+    assertEquals(3, getNumberOfSignificantBits(6));
+    assertEquals(3, getNumberOfSignificantBits(7));
+    assertEquals(4, getNumberOfSignificantBits(8));
+    assertEquals(4, getNumberOfSignificantBits(9));
+    assertEquals(10, getNumberOfSignificantBits(1023));
+    assertEquals(11, getNumberOfSignificantBits(1024));
+    assertEquals(11, getNumberOfSignificantBits(1025));
+  }
 }
diff --git a/common/src/main/java/org/apache/uniffle/common/util/BlockId.java 
b/common/src/main/java/org/apache/uniffle/common/util/BlockId.java
index 36025f66b..6c93e6345 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/BlockId.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/BlockId.java
@@ -32,10 +32,10 @@ public class BlockId {
   public final BlockIdLayout layout;
   public final int sequenceNo;
   public final int partitionId;
-  public final int taskAttemptId;
+  public final long taskAttemptId;
 
   protected BlockId(
-      long blockId, BlockIdLayout layout, int sequenceNo, int partitionId, int 
taskAttemptId) {
+      long blockId, BlockIdLayout layout, int sequenceNo, int partitionId, 
long taskAttemptId) {
     this.blockId = blockId;
     this.layout = layout;
     this.sequenceNo = sequenceNo;
diff --git 
a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezWordCountWithFailuresTest.java
 
b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezWordCountWithFailuresTest.java
index 438bbdf39..d31d842ab 100644
--- 
a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezWordCountWithFailuresTest.java
+++ 
b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezWordCountWithFailuresTest.java
@@ -362,7 +362,7 @@ public class TezWordCountWithFailuresTest extends 
IntegrationTestBase {
         // verifyMode is 0: avoid recompute succeeded task is true
         Assertions.assertEquals(0, 
progressMap.get("Tokenizer").getKilledTaskAttemptCount());
       } else if (verifyMode == 1) {
-        // verifyMode is 1: avoid recompute succeeded task is true
+        // verifyMode is 1: avoid recompute succeeded task is false
         
Assertions.assertTrue(progressMap.get("Tokenizer").getKilledTaskAttemptCount() 
> 0);
       }
       return 0;
diff --git 
a/server/src/test/java/org/apache/uniffle/server/buffer/BufferTestBase.java 
b/server/src/test/java/org/apache/uniffle/server/buffer/BufferTestBase.java
index 2314a1b6c..7aef68c15 100644
--- a/server/src/test/java/org/apache/uniffle/server/buffer/BufferTestBase.java
+++ b/server/src/test/java/org/apache/uniffle/server/buffer/BufferTestBase.java
@@ -51,7 +51,7 @@ public abstract class BufferTestBase {
     return createData(partitionId, 0, len);
   }
 
-  protected ShufflePartitionedData createData(int partitionId, int 
taskAttemptId, int len) {
+  protected ShufflePartitionedData createData(int partitionId, long 
taskAttemptId, int len) {
     byte[] buf = new byte[len];
     new Random().nextBytes(buf);
     long blockId =
diff --git 
a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandlerTest.java
 
b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandlerTest.java
index c11fc27a7..d1b663f1f 100644
--- 
a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandlerTest.java
+++ 
b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandlerTest.java
@@ -108,7 +108,7 @@ public class HadoopShuffleReadHandlerTest extends 
HadoopTestBase {
     int totalBlockNum = 0;
     int expectTotalBlockNum = 6;
     int blockSize = 7;
-    int taskAttemptId = 0;
+    long taskAttemptId = 0;
 
     // write expectTotalBlockNum - 1 complete block
     HadoopShuffleHandlerTestBase.writeTestData(

Reply via email to