This is an automated email from the ASF dual-hosted git repository. zhengchenyu pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push: new 70495697f [#1398] fix(mr)(tez): Make attempId computable and move it to taskAttemptId in BlockId layout. (#1418) 70495697f is described below commit 70495697fd0be6292989c0a69f5464d2ca86ce36 Author: QI Jiale <qijial...@foxmail.com> AuthorDate: Fri Jul 5 15:19:27 2024 +0800 [#1398] fix(mr)(tez): Make attempId computable and move it to taskAttemptId in BlockId layout. (#1418) ### What changes were proposed in this pull request? Before this PR, in the MR and TEZ engines: 1. attemptId is stored in the sequenceNo of the BlockId instead of in the taskAttemptId. 2. attemptId has a fixed width of 6 bits. After this PR: 1. attemptId is stored in the taskAttemptId, which is more reasonable. 2. The number of bits used for attemptId is calculated from the maximum number of allowed failures and whether speculative execution is enabled. ### Why are the changes needed? Fix: #1398 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing unit tests and integration tests. --- .../hadoop/mapred/RssMapOutputCollector.java | 4 +- .../org/apache/hadoop/mapreduce/RssMRUtils.java | 87 +++++++++++---------- .../mapreduce/task/reduce/RssEventFetcher.java | 2 +- .../hadoop/mapred/SortWriteBufferManagerTest.java | 12 +-- .../apache/hadoop/mapreduce/RssMRUtilsTest.java | 17 ++-- .../mapreduce/task/reduce/EventFetcherTest.java | 28 +++---- .../hadoop/mapreduce/task/reduce/FetcherTest.java | 5 +- .../shuffle/manager/RssShuffleManagerBase.java | 29 ++----- .../shuffle/manager/RssShuffleManagerBaseTest.java | 40 ---------- .../java/org/apache/tez/common/RssTezUtils.java | 91 ++++++++++++---------- .../common/shuffle/impl/RssShuffleManager.java | 6 +- .../common/shuffle/impl/RssTezFetcherTask.java | 8 +- .../orderedgrouped/RssShuffleScheduler.java | 6 +- .../library/common/sort/impl/RssSorter.java | 4 +- .../library/common/sort/impl/RssUnSorter.java | 4 +- .../output/RssOrderedPartitionedKVOutput.java | 4 +- .../library/output/RssUnorderedKVOutput.java | 4 +- 
.../output/RssUnorderedPartitionedKVOutput.java | 4 +- .../org/apache/tez/common/RssTezUtilsTest.java | 10 +-- .../library/common/sort/impl/RssSorterTest.java | 5 +- .../library/common/sort/impl/RssUnSorterTest.java | 5 +- .../apache/uniffle/client/util/ClientUtils.java | 17 ++++ .../org/apache/uniffle/client/ClientUtilsTest.java | 42 ++++++++++ .../org/apache/uniffle/common/util/BlockId.java | 4 +- .../uniffle/test/TezWordCountWithFailuresTest.java | 2 +- .../uniffle/server/buffer/BufferTestBase.java | 2 +- .../handler/impl/HadoopShuffleReadHandlerTest.java | 2 +- 27 files changed, 246 insertions(+), 198 deletions(-) diff --git a/client-mr/core/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java b/client-mr/core/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java index 3acf0b417..30f98c52e 100644 --- a/client-mr/core/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java +++ b/client-mr/core/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java @@ -100,8 +100,8 @@ public class RssMapOutputCollector<K extends Object, V extends Object> ApplicationAttemptId applicationAttemptId = RssMRUtils.getApplicationAttemptId(); String appId = applicationAttemptId.toString(); long taskAttemptId = - RssMRUtils.convertTaskAttemptIdToLong( - mapTask.getTaskID(), applicationAttemptId.getAttemptId()); + RssMRUtils.createRssTaskAttemptId( + mapTask.getTaskID(), applicationAttemptId.getAttemptId(), mrJobConf); double sendThreshold = RssMRUtils.getDouble( rssJobConf, diff --git a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java index f1220486b..9012e618e 100644 --- a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java +++ b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java @@ -35,6 +35,7 @@ import org.slf4j.LoggerFactory; import org.apache.uniffle.client.api.ShuffleWriteClient; import 
org.apache.uniffle.client.factory.ShuffleClientFactory; +import org.apache.uniffle.client.util.ClientUtils; import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.util.BlockIdLayout; @@ -44,37 +45,61 @@ public class RssMRUtils { private static final Logger LOG = LoggerFactory.getLogger(RssMRUtils.class); private static final BlockIdLayout LAYOUT = BlockIdLayout.DEFAULT; - private static final int MAX_ATTEMPT_LENGTH = 6; - private static final int MAX_ATTEMPT_ID = (1 << MAX_ATTEMPT_LENGTH) - 1; - private static final int MAX_SEQUENCE_NO = - (1 << (LAYOUT.sequenceNoBits - MAX_ATTEMPT_LENGTH)) - 1; - - // Class TaskAttemptId have two field id and mapId, rss taskAttemptID have 21 bits, - // mapId is 19 bits, id is 2 bits. MR have a trick logic, taskAttemptId will increase - // 1000 * (appAttemptId - 1), so we will decrease it. - public static long convertTaskAttemptIdToLong(TaskAttemptID taskAttemptID, int appAttemptId) { - int lowBytes = taskAttemptID.getTaskID().getId(); - if (lowBytes > LAYOUT.maxTaskAttemptId) { - throw new RssException("TaskAttempt " + taskAttemptID + " low bytes " + lowBytes + " exceed"); - } + + // Class TaskAttemptId have two field id and mapId. MR have a trick logic, taskAttemptId will + // increase 1000 * (appAttemptId - 1), so we will decrease it. 
+ public static int createRssTaskAttemptId( + TaskAttemptID taskAttemptID, int appAttemptId, int maxAttemptNo) { + int attemptBits = ClientUtils.getNumberOfSignificantBits(maxAttemptNo); + if (appAttemptId < 1) { throw new RssException("appAttemptId " + appAttemptId + " is wrong"); } - int highBytes = taskAttemptID.getId() - (appAttemptId - 1) * 1000; - if (highBytes > MAX_ATTEMPT_ID || highBytes < 0) { + int attemptId = taskAttemptID.getId() - (appAttemptId - 1) * 1000; + if (attemptId > maxAttemptNo || attemptId < 0) { throw new RssException( - "TaskAttempt " + taskAttemptID + " high bytes " + highBytes + " exceed"); + "TaskAttempt " + taskAttemptID + " attemptId " + attemptId + " exceed " + maxAttemptNo); } - return LAYOUT.getBlockId(highBytes, 0, lowBytes); + int taskId = taskAttemptID.getTaskID().getId(); + + int mapIndexBits = ClientUtils.getNumberOfSignificantBits(taskId); + if (mapIndexBits + attemptBits > LAYOUT.taskAttemptIdBits) { + throw new RssException( + "Observing taskId[" + + taskId + + "] that would produce a taskAttemptId with " + + (mapIndexBits + attemptBits) + + " bits which is larger than the allowed " + + LAYOUT.taskAttemptIdBits + + "]). 
Please consider providing more bits for taskAttemptIds."); + } + + return (taskId << attemptBits) | attemptId; + } + + public static int createRssTaskAttemptId( + TaskAttemptID taskAttemptID, int appAttemptId, int maxFailures, boolean speculation) { + int maxAttemptNo = ClientUtils.getMaxAttemptNo(maxFailures, speculation); + return createRssTaskAttemptId(taskAttemptID, appAttemptId, maxAttemptNo); + } + + public static int createRssTaskAttemptId( + TaskAttemptID taskAttemptID, int appAttemptId, Configuration conf) { + int maxFailures = conf.getInt(MRJobConfig.MAP_MAX_ATTEMPTS, 4); + boolean speculation = conf.getBoolean(MRJobConfig.MAP_SPECULATIVE, true); + return createRssTaskAttemptId(taskAttemptID, appAttemptId, maxFailures, speculation); } public static TaskAttemptID createMRTaskAttemptId( - JobID jobID, TaskType taskType, long rssTaskAttemptId, int appAttemptId) { + JobID jobID, TaskType taskType, long rssTaskAttemptId, int appAttemptId, int maxAttemptNo) { + int attemptBits = ClientUtils.getNumberOfSignificantBits(maxAttemptNo); if (appAttemptId < 1) { throw new RssException("appAttemptId " + appAttemptId + " is wrong"); } - TaskID taskID = new TaskID(jobID, taskType, LAYOUT.getTaskAttemptId(rssTaskAttemptId)); - int id = LAYOUT.getSequenceNo(rssTaskAttemptId) + 1000 * (appAttemptId - 1); + int task = (int) rssTaskAttemptId >> attemptBits; + int attempt = (int) rssTaskAttemptId & ((1 << attemptBits) - 1); + TaskID taskID = new TaskID(jobID, taskType, task); + int id = attempt + 1000 * (appAttemptId - 1); return new TaskAttemptID(taskID, id); } @@ -228,27 +253,11 @@ public class RssMRUtils { } public static long getBlockId(int partitionId, long taskAttemptId, int nextSeqNo) { - long attemptId = taskAttemptId >> (LAYOUT.partitionIdBits + LAYOUT.taskAttemptIdBits); - if (attemptId < 0 || attemptId > MAX_ATTEMPT_ID) { - throw new RssException( - "Can't support attemptId [" + attemptId + "], the max value should be " + MAX_ATTEMPT_ID); - } - if (nextSeqNo < 0 || 
nextSeqNo > MAX_SEQUENCE_NO) { - throw new RssException( - "Can't support sequence [" + nextSeqNo + "], the max value should be " + MAX_SEQUENCE_NO); - } - - int atomicInt = (int) ((nextSeqNo << MAX_ATTEMPT_LENGTH) + attemptId); - long taskId = - taskAttemptId - (attemptId << (LAYOUT.partitionIdBits + LAYOUT.taskAttemptIdBits)); - - return LAYOUT.getBlockId(atomicInt, partitionId, taskId); + return LAYOUT.getBlockId(nextSeqNo, partitionId, taskAttemptId); } - public static long getTaskAttemptId(long blockId) { - int mapId = LAYOUT.getTaskAttemptId(blockId); - int attemptId = LAYOUT.getSequenceNo(blockId) & MAX_ATTEMPT_ID; - return LAYOUT.getBlockId(attemptId, 0, mapId); + public static int getTaskAttemptId(long blockId) { + return LAYOUT.getTaskAttemptId(blockId); } public static int estimateTaskConcurrency(JobConf jobConf) { diff --git a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssEventFetcher.java b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssEventFetcher.java index 9b86cabee..ece22be74 100644 --- a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssEventFetcher.java +++ b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssEventFetcher.java @@ -75,7 +75,7 @@ public class RssEventFetcher<K, V> { String errMsg = "TaskAttemptIDs are inconsistent with map tasks"; for (TaskAttemptID taskAttemptID : successMaps) { if (!obsoleteMaps.contains(taskAttemptID)) { - long rssTaskId = RssMRUtils.convertTaskAttemptIdToLong(taskAttemptID, appAttemptId); + long rssTaskId = RssMRUtils.createRssTaskAttemptId(taskAttemptID, appAttemptId, jobConf); int mapIndex = taskAttemptID.getTaskID().getId(); // There can be multiple successful attempts on same map task. // So we only need to accept one of them. 
diff --git a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java index 3cb017239..62f1f86f8 100644 --- a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java +++ b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java @@ -76,7 +76,7 @@ public class SortWriteBufferManagerTest { manager = new SortWriteBufferManager<BytesWritable, BytesWritable>( 10240, - 1L, + 1, 10, serializationFactory.getSerializer(BytesWritable.class), serializationFactory.getSerializer(BytesWritable.class), @@ -140,7 +140,7 @@ public class SortWriteBufferManagerTest { manager = new SortWriteBufferManager<BytesWritable, BytesWritable>( 100, - 1L, + 1, 10, serializationFactory.getSerializer(BytesWritable.class), serializationFactory.getSerializer(BytesWritable.class), @@ -192,7 +192,7 @@ public class SortWriteBufferManagerTest { manager = new SortWriteBufferManager<BytesWritable, BytesWritable>( 10240, - 1L, + 1, 10, serializationFactory.getSerializer(BytesWritable.class), serializationFactory.getSerializer(BytesWritable.class), @@ -244,7 +244,7 @@ public class SortWriteBufferManagerTest { manager = new SortWriteBufferManager<BytesWritable, BytesWritable>( 10240, - 1L, + 1, 10, serializationFactory.getSerializer(BytesWritable.class), serializationFactory.getSerializer(BytesWritable.class), @@ -311,7 +311,7 @@ public class SortWriteBufferManagerTest { manager = new SortWriteBufferManager<BytesWritable, BytesWritable>( 10240, - 1L, + 1, 10, serializationFactory.getSerializer(BytesWritable.class), serializationFactory.getSerializer(BytesWritable.class), @@ -390,7 +390,7 @@ public class SortWriteBufferManagerTest { SortWriteBufferManager<Text, IntWritable> manager = new SortWriteBufferManager<Text, IntWritable>( 10240, - 1L, + 1, 10, keySerializer, valueSerializer, diff --git 
a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/RssMRUtilsTest.java b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/RssMRUtilsTest.java index 2d3710ca1..31e3590d2 100644 --- a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/RssMRUtilsTest.java +++ b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/RssMRUtilsTest.java @@ -45,20 +45,21 @@ public class RssMRUtilsTest { TaskAttemptID mrTaskAttemptId = new TaskAttemptID(taskId, 3); boolean isException = false; try { - RssMRUtils.convertTaskAttemptIdToLong(mrTaskAttemptId, 1); + RssMRUtils.createRssTaskAttemptId(mrTaskAttemptId, 1, 4); } catch (RssException e) { isException = true; } assertTrue(isException); - taskAttemptId = (1 << 20) + 0x123; - mrTaskAttemptId = RssMRUtils.createMRTaskAttemptId(new JobID(), TaskType.MAP, taskAttemptId, 1); - long testId = RssMRUtils.convertTaskAttemptIdToLong(mrTaskAttemptId, 1); + taskAttemptId = (0x123 << 3) + 1; + mrTaskAttemptId = + RssMRUtils.createMRTaskAttemptId(new JobID(), TaskType.MAP, taskAttemptId, 1, 4); + int testId = RssMRUtils.createRssTaskAttemptId(mrTaskAttemptId, 1, 4); assertEquals(taskAttemptId, testId); - TaskID taskID = new TaskID(new org.apache.hadoop.mapred.JobID(), TaskType.MAP, (int) (1 << 21)); + TaskID taskID = new TaskID(new org.apache.hadoop.mapred.JobID(), TaskType.MAP, 1 << 21); mrTaskAttemptId = new TaskAttemptID(taskID, 2); isException = false; try { - RssMRUtils.convertTaskAttemptIdToLong(mrTaskAttemptId, 1); + RssMRUtils.createRssTaskAttemptId(mrTaskAttemptId, 1, 4); } catch (RssException e) { isException = true; } @@ -70,7 +71,7 @@ public class RssMRUtilsTest { JobID jobID = new JobID(); TaskID taskId = new TaskID(jobID, TaskType.MAP, 233); TaskAttemptID taskAttemptID = new TaskAttemptID(taskId, 1); - long taskAttemptId = RssMRUtils.convertTaskAttemptIdToLong(taskAttemptID, 1); + long taskAttemptId = RssMRUtils.createRssTaskAttemptId(taskAttemptID, 1, 4); long blockId = RssMRUtils.getBlockId(1, taskAttemptId, 
0); long newTaskAttemptId = RssMRUtils.getTaskAttemptId(blockId); assertEquals(taskAttemptId, newTaskAttemptId); @@ -85,7 +86,7 @@ public class RssMRUtilsTest { JobID jobID = new JobID(); TaskID taskId = new TaskID(jobID, TaskType.MAP, 233); TaskAttemptID taskAttemptID = new TaskAttemptID(taskId, 1); - long taskAttemptId = RssMRUtils.convertTaskAttemptIdToLong(taskAttemptID, 1); + long taskAttemptId = RssMRUtils.createRssTaskAttemptId(taskAttemptID, 1, 4); long mask = (1L << layout.partitionIdBits) - 1; for (int partitionId = 0; partitionId <= 3000; partitionId++) { for (int seqNo = 0; seqNo <= 10; seqNo++) { diff --git a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/EventFetcherTest.java b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/EventFetcherTest.java index 8e8ab8f4e..6dfb81f45 100644 --- a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/EventFetcherTest.java +++ b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/EventFetcherTest.java @@ -59,8 +59,8 @@ public class EventFetcherTest { Roaring64NavigableMap expected = Roaring64NavigableMap.bitmapOf(); for (int mapIndex = 0; mapIndex < mapTaskNum; mapIndex++) { long rssTaskId = - RssMRUtils.convertTaskAttemptIdToLong( - new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1); + RssMRUtils.createRssTaskAttemptId( + new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1, 4); expected.addLong(rssTaskId); } @@ -89,8 +89,8 @@ public class EventFetcherTest { Roaring64NavigableMap expected = Roaring64NavigableMap.bitmapOf(); for (int mapIndex = 0; mapIndex < mapTaskNum; mapIndex++) { long rssTaskId = - RssMRUtils.convertTaskAttemptIdToLong( - new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1); + RssMRUtils.createRssTaskAttemptId( + new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1, 4); expected.addLong(rssTaskId); } @@ -121,8 +121,8 @@ public class EventFetcherTest { Roaring64NavigableMap expected = 
Roaring64NavigableMap.bitmapOf(); for (int mapIndex = 0; mapIndex < mapTaskNum; mapIndex++) { long rssTaskId = - RssMRUtils.convertTaskAttemptIdToLong( - new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1); + RssMRUtils.createRssTaskAttemptId( + new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1, 4); expected.addLong(rssTaskId); } Roaring64NavigableMap taskIdBitmap = ef.fetchAllRssTaskIds(); @@ -146,8 +146,8 @@ public class EventFetcherTest { Roaring64NavigableMap expected = Roaring64NavigableMap.bitmapOf(); for (int mapIndex = 0; mapIndex < mapTaskNum; mapIndex++) { long rssTaskId = - RssMRUtils.convertTaskAttemptIdToLong( - new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1); + RssMRUtils.createRssTaskAttemptId( + new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1, 4); expected.addLong(rssTaskId); } IllegalStateException ex = @@ -172,8 +172,8 @@ public class EventFetcherTest { Roaring64NavigableMap expected = Roaring64NavigableMap.bitmapOf(); for (int mapIndex = 0; mapIndex < mapTaskNum; mapIndex++) { long rssTaskId = - RssMRUtils.convertTaskAttemptIdToLong( - new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1); + RssMRUtils.createRssTaskAttemptId( + new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1, 4); expected.addLong(rssTaskId); } IllegalStateException ex = @@ -205,14 +205,14 @@ public class EventFetcherTest { for (int mapIndex = 0; mapIndex < mapTaskNum; mapIndex++) { if (!tipFailed.contains(mapIndex) && !obsoleted.contains(mapIndex)) { long rssTaskId = - RssMRUtils.convertTaskAttemptIdToLong( - new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1); + RssMRUtils.createRssTaskAttemptId( + new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 0), 1, 4); expected.addLong(rssTaskId); } if (obsoleted.contains(mapIndex)) { long rssTaskId = - RssMRUtils.convertTaskAttemptIdToLong( - new TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 1), 1); + RssMRUtils.createRssTaskAttemptId( + new 
TaskAttemptID("12345", 1, TaskType.MAP, mapIndex, 1), 1, 4); expected.addLong(rssTaskId); } } diff --git a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java index 6c5f6776d..9ca680cb3 100644 --- a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java +++ b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java @@ -291,7 +291,8 @@ public class FetcherTest { null, new Progress(), new MROutputFiles()); - TaskAttemptID taskAttemptID = RssMRUtils.createMRTaskAttemptId(new JobID(), TaskType.MAP, 1, 1); + TaskAttemptID taskAttemptID = + RssMRUtils.createMRTaskAttemptId(new JobID(), TaskType.MAP, 1, 1, 4); byte[] buffer = new byte[10]; MapOutput mapOutput1 = merger.reserve(taskAttemptID, 10, 1); RssBypassWriter.write(mapOutput1, buffer); @@ -350,7 +351,7 @@ public class FetcherTest { SortWriteBufferManager<Text, Text> manager = new SortWriteBufferManager( 10240, - 1L, + 1, 10, serializationFactory.getSerializer(Text.class), serializationFactory.getSerializer(Text.class), diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java index bbeec90dd..6a281db2e 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java @@ -194,8 +194,10 @@ public abstract class RssShuffleManagerBase implements RssShuffleManagerInterfac + maxPartitions); } - int attemptIdBits = getAttemptIdBits(getMaxAttemptNo(maxFailures, speculation)); - int partitionIdBits = 32 - Integer.numberOfLeadingZeros(maxPartitions - 1); // [1..31] + int attemptIdBits = + ClientUtils.getNumberOfSignificantBits( + 
ClientUtils.getMaxAttemptNo(maxFailures, speculation)); + int partitionIdBits = ClientUtils.getNumberOfSignificantBits(maxPartitions - 1); // [1..31] int taskAttemptIdBits = partitionIdBits + attemptIdBits; // [1+attemptIdBits..31+attemptIdBits] int sequenceNoBits = 63 - partitionIdBits - taskAttemptIdBits; // [1-attemptIdBits..61] @@ -334,23 +336,6 @@ public abstract class RssShuffleManagerBase implements RssShuffleManagerInterfac } } - protected static int getMaxAttemptNo(int maxFailures, boolean speculation) { - // attempt number is zero based: 0, 1, …, maxFailures-1 - // max maxFailures < 1 is not allowed but for safety, we interpret that as maxFailures == 1 - int maxAttemptNo = maxFailures < 1 ? 0 : maxFailures - 1; - - // with speculative execution enabled we could observe +1 attempts - if (speculation) { - maxAttemptNo++; - } - - return maxAttemptNo; - } - - protected static int getAttemptIdBits(int maxAttemptNo) { - return 32 - Integer.numberOfLeadingZeros(maxAttemptNo); - } - /** See static overload of this method. 
*/ public abstract long getTaskAttemptIdForBlockId(int mapIndex, int attemptNo); @@ -369,8 +354,8 @@ public abstract class RssShuffleManagerBase implements RssShuffleManagerInterfac */ protected static long getTaskAttemptIdForBlockId( int mapIndex, int attemptNo, int maxFailures, boolean speculation, int maxTaskAttemptIdBits) { - int maxAttemptNo = getMaxAttemptNo(maxFailures, speculation); - int attemptBits = getAttemptIdBits(maxAttemptNo); + int maxAttemptNo = ClientUtils.getMaxAttemptNo(maxFailures, speculation); + int attemptBits = ClientUtils.getNumberOfSignificantBits(maxAttemptNo); if (attemptNo > maxAttemptNo) { // this should never happen, if it does, our assumptions are wrong, @@ -384,7 +369,7 @@ public abstract class RssShuffleManagerBase implements RssShuffleManagerInterfac + "."); } - int mapIndexBits = 32 - Integer.numberOfLeadingZeros(mapIndex); + int mapIndexBits = ClientUtils.getNumberOfSignificantBits(mapIndex); if (mapIndexBits + attemptBits > maxTaskAttemptIdBits) { throw new RssException( "Observing mapIndex[" diff --git a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBaseTest.java b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBaseTest.java index fffb7af3f..610b42c8c 100644 --- a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBaseTest.java +++ b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBaseTest.java @@ -370,46 +370,6 @@ public class RssShuffleManagerBaseTest { assertTrue(e.getMessage().startsWith("All block id bit config keys must be provided ")); } - @Test - public void testGetMaxAttemptNo() { - // without speculation - assertEquals(0, RssShuffleManagerBase.getMaxAttemptNo(-1, false)); - assertEquals(0, RssShuffleManagerBase.getMaxAttemptNo(0, false)); - assertEquals(0, RssShuffleManagerBase.getMaxAttemptNo(1, false)); - assertEquals(1, RssShuffleManagerBase.getMaxAttemptNo(2, 
false)); - assertEquals(2, RssShuffleManagerBase.getMaxAttemptNo(3, false)); - assertEquals(3, RssShuffleManagerBase.getMaxAttemptNo(4, false)); - assertEquals(4, RssShuffleManagerBase.getMaxAttemptNo(5, false)); - assertEquals(1023, RssShuffleManagerBase.getMaxAttemptNo(1024, false)); - - // with speculation - assertEquals(1, RssShuffleManagerBase.getMaxAttemptNo(-1, true)); - assertEquals(1, RssShuffleManagerBase.getMaxAttemptNo(0, true)); - assertEquals(1, RssShuffleManagerBase.getMaxAttemptNo(1, true)); - assertEquals(2, RssShuffleManagerBase.getMaxAttemptNo(2, true)); - assertEquals(3, RssShuffleManagerBase.getMaxAttemptNo(3, true)); - assertEquals(4, RssShuffleManagerBase.getMaxAttemptNo(4, true)); - assertEquals(5, RssShuffleManagerBase.getMaxAttemptNo(5, true)); - assertEquals(1024, RssShuffleManagerBase.getMaxAttemptNo(1024, true)); - } - - @Test - public void testGetAttemptIdBits() { - assertEquals(0, RssShuffleManagerBase.getAttemptIdBits(0)); - assertEquals(1, RssShuffleManagerBase.getAttemptIdBits(1)); - assertEquals(2, RssShuffleManagerBase.getAttemptIdBits(2)); - assertEquals(2, RssShuffleManagerBase.getAttemptIdBits(3)); - assertEquals(3, RssShuffleManagerBase.getAttemptIdBits(4)); - assertEquals(3, RssShuffleManagerBase.getAttemptIdBits(5)); - assertEquals(3, RssShuffleManagerBase.getAttemptIdBits(6)); - assertEquals(3, RssShuffleManagerBase.getAttemptIdBits(7)); - assertEquals(4, RssShuffleManagerBase.getAttemptIdBits(8)); - assertEquals(4, RssShuffleManagerBase.getAttemptIdBits(9)); - assertEquals(10, RssShuffleManagerBase.getAttemptIdBits(1023)); - assertEquals(11, RssShuffleManagerBase.getAttemptIdBits(1024)); - assertEquals(11, RssShuffleManagerBase.getAttemptIdBits(1025)); - } - private long bits(String string) { return Long.parseLong(string.replaceAll("[|]", ""), 2); } diff --git a/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java b/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java index e168ec0d3..3df582bde 
100644 --- a/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java +++ b/client-tez/src/main/java/org/apache/tez/common/RssTezUtils.java @@ -24,6 +24,7 @@ import java.util.Set; import com.google.common.base.Preconditions; import org.apache.commons.lang3.StringUtils; import org.apache.hadoop.conf.Configuration; +import org.apache.tez.dag.api.TezConfiguration; import org.apache.tez.dag.records.TezTaskAttemptID; import org.apache.tez.runtime.library.api.TezRuntimeConfiguration; import org.apache.tez.runtime.library.common.InputAttemptIdentifier; @@ -51,6 +52,7 @@ import org.slf4j.LoggerFactory; import org.apache.uniffle.client.api.ShuffleWriteClient; import org.apache.uniffle.client.factory.ShuffleClientFactory; +import org.apache.uniffle.client.util.ClientUtils; import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.util.BlockIdLayout; @@ -60,10 +62,6 @@ public class RssTezUtils { private static final Logger LOG = LoggerFactory.getLogger(RssTezUtils.class); private static final BlockIdLayout LAYOUT = BlockIdLayout.DEFAULT; - private static final int MAX_ATTEMPT_LENGTH = 6; - private static final int MAX_ATTEMPT_ID = (1 << MAX_ATTEMPT_LENGTH) - 1; - private static final int MAX_SEQUENCE_NO = - (1 << (LAYOUT.sequenceNoBits - MAX_ATTEMPT_LENGTH)) - 1; public static final String HOST_NAME = "hostname"; @@ -159,32 +157,11 @@ public class RssTezUtils { } public static long getBlockId(int partitionId, long taskAttemptId, int nextSeqNo) { - LOG.info( - "GetBlockId, partitionId:{}, taskAttemptId:{}, nextSeqNo:{}", - partitionId, - taskAttemptId, - nextSeqNo); - long attemptId = taskAttemptId >> (LAYOUT.partitionIdBits + LAYOUT.taskAttemptIdBits); - if (attemptId < 0 || attemptId > MAX_ATTEMPT_ID) { - throw new RssException( - "Can't support attemptId [" + attemptId + "], the max value should be " + MAX_ATTEMPT_ID); - } - if (nextSeqNo < 0 || nextSeqNo > MAX_SEQUENCE_NO) { - throw 
new RssException( - "Can't support sequence [" + nextSeqNo + "], the max value should be " + MAX_SEQUENCE_NO); - } - - int atomicInt = (int) ((nextSeqNo << MAX_ATTEMPT_LENGTH) + attemptId); - long taskId = - taskAttemptId - (attemptId << (LAYOUT.partitionIdBits + LAYOUT.taskAttemptIdBits)); - - return LAYOUT.getBlockId(atomicInt, partitionId, taskId); + return LAYOUT.getBlockId(nextSeqNo, partitionId, taskAttemptId); } - public static long getTaskAttemptId(long blockId) { - int mapId = LAYOUT.getTaskAttemptId(blockId); - int attemptId = LAYOUT.getSequenceNo(blockId) & MAX_ATTEMPT_ID; - return LAYOUT.getBlockId(attemptId, 0, mapId); + public static int getTaskAttemptId(long blockId) { + return LAYOUT.getTaskAttemptId(blockId); } public static int estimateTaskConcurrency(Configuration jobConf, int mapNum, int reduceNum) { @@ -276,23 +253,55 @@ public class RssTezUtils { } } - public static long convertTaskAttemptIdToLong(TezTaskAttemptID taskAttemptID) { - int lowBytes = taskAttemptID.getTaskID().getId(); - if (lowBytes > LAYOUT.maxTaskAttemptId) { - throw new RssException("TaskAttempt " + taskAttemptID + " low bytes " + lowBytes + " exceed"); + public static int createRssTaskAttemptId(TezTaskAttemptID taskAttemptId, int maxAttemptNo) { + int attemptBits = ClientUtils.getNumberOfSignificantBits(maxAttemptNo); + + int attemptId = taskAttemptId.getId(); + if (attemptId > maxAttemptNo || attemptId < 0) { + throw new RssException( + "TaskAttempt " + taskAttemptId + " attemptId " + attemptId + " exceed"); } - int highBytes = taskAttemptID.getId(); - if (highBytes > MAX_ATTEMPT_ID || highBytes < 0) { + int taskId = taskAttemptId.getTaskID().getId(); + + int mapIndexBits = ClientUtils.getNumberOfSignificantBits(taskId); + if (mapIndexBits + attemptBits > LAYOUT.taskAttemptIdBits) { throw new RssException( - "TaskAttempt " + taskAttemptID + " high bytes " + highBytes + " exceed."); + "Observing taskId[" + + taskId + + "] that would produce a taskAttemptId with " + + 
(mapIndexBits + attemptBits) + + " bits which is larger than the allowed " + + LAYOUT.taskAttemptIdBits + + "]). Please consider providing more bits for taskAttemptIds."); } - long id = LAYOUT.getBlockId(highBytes, 0, lowBytes); - LOG.info("ConvertTaskAttemptIdToLong taskAttemptID:{}, id is {}, .", taskAttemptID, id); + + int id = (taskId << attemptBits) | attemptId; + LOG.info("createRssTaskAttemptId taskAttemptId:{}, id is {}, .", taskAttemptId, id); return id; } + public static int createRssTaskAttemptId(TezTaskAttemptID taskAttemptId, Configuration conf) { + int maxAttemptNo = getMaxAttemptNo(conf); + return createRssTaskAttemptId(taskAttemptId, maxAttemptNo); + } + + public static int getMaxAttemptNo(Configuration conf) { + int maxFailures = + conf.getInt( + TezConfiguration.TEZ_AM_TASK_MAX_FAILED_ATTEMPTS, + TezConfiguration.TEZ_AM_TASK_MAX_FAILED_ATTEMPTS_DEFAULT); + boolean speculation = + conf.getBoolean( + TezConfiguration.TEZ_AM_SPECULATION_ENABLED, + TezConfiguration.TEZ_AM_SPECULATION_ENABLED_DEFAULT); + return ClientUtils.getMaxAttemptNo(maxFailures, speculation); + } + public static Roaring64NavigableMap fetchAllRssTaskIds( - Set<InputAttemptIdentifier> successMapTaskAttempts, int totalMapsCount, int appAttemptId) { + Set<InputAttemptIdentifier> successMapTaskAttempts, + int totalMapsCount, + int appAttemptId, + int maxAttemptNo) { String errMsg = "TaskAttemptIDs are inconsistent with map tasks"; Roaring64NavigableMap rssTaskIdBitmap = Roaring64NavigableMap.bitmapOf(); Roaring64NavigableMap mapTaskIdBitmap = Roaring64NavigableMap.bitmapOf(); @@ -301,9 +310,9 @@ public class RssTezUtils { for (InputAttemptIdentifier inputAttemptIdentifier : successMapTaskAttempts) { String pathComponent = inputAttemptIdentifier.getPathComponent(); - TezTaskAttemptID mapTaskAttemptID = IdUtils.convertTezTaskAttemptID(pathComponent); - long rssTaskId = RssTezUtils.convertTaskAttemptIdToLong(mapTaskAttemptID); - long mapTaskId = mapTaskAttemptID.getTaskID().getId(); + 
TezTaskAttemptID mapTaskAttemptId = IdUtils.convertTezTaskAttemptID(pathComponent); + long rssTaskId = RssTezUtils.createRssTaskAttemptId(mapTaskAttemptId, maxAttemptNo); + long mapTaskId = mapTaskAttemptId.getTaskID().getId(); LOG.info( "FetchAllRssTaskIds, pathComponent: {}, mapTaskId:{}, rssTaskId:{}, is contains:{}", diff --git a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssShuffleManager.java b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssShuffleManager.java index 4fa7a7b98..dfdc6f840 100644 --- a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssShuffleManager.java +++ b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssShuffleManager.java @@ -67,6 +67,7 @@ import org.apache.hadoop.util.Time; import org.apache.hadoop.yarn.api.records.ApplicationAttemptId; import org.apache.tez.common.CallableWithNdc; import org.apache.tez.common.InputContextUtils; +import org.apache.tez.common.RssTezUtils; import org.apache.tez.common.TezRuntimeFrameworkConfigs; import org.apache.tez.common.TezUtilsInternal; import org.apache.tez.common.UmbilicalUtils; @@ -590,6 +591,8 @@ public class RssShuffleManager extends ShuffleManager { partitionToServers.get(partition), partitionToServers); + int maxAttemptNo = RssTezUtils.getMaxAttemptNo(conf); + RssTezFetcherTask fetcher = new RssTezFetcherTask( RssShuffleManager.this, @@ -604,7 +607,8 @@ public class RssShuffleManager extends ShuffleManager { rssAllBlockIdBitmapMap, rssSuccessBlockIdBitmapMap, numInputs, - partitionToServers.size()); + partitionToServers.size(), + maxAttemptNo); rssRunningFetchers.add(fetcher); if (isShutdown.get()) { LOG.info( diff --git a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcherTask.java b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcherTask.java index 63922eb64..3052c5dbc 100644 --- 
a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcherTask.java +++ b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/RssTezFetcherTask.java @@ -76,6 +76,7 @@ public class RssTezFetcherTask extends CallableWithNdc<FetchResult> { private final int partitionNum; private final int shuffleId; private final ApplicationAttemptId applicationAttemptId; + private final int maxAttemptNo; public RssTezFetcherTask( FetcherCallback fetcherCallback, @@ -90,7 +91,8 @@ public class RssTezFetcherTask extends CallableWithNdc<FetchResult> { Map<Integer, Roaring64NavigableMap> rssAllBlockIdBitmapMap, Map<Integer, Roaring64NavigableMap> rssSuccessBlockIdBitmapMap, int numPhysicalInputs, - int partitionNum) { + int partitionNum, + int maxAttemptNo) { if (CollectionUtils.isEmpty(inputs)) { throw new RssException("inputs should not be empty"); } @@ -135,6 +137,7 @@ public class RssTezFetcherTask extends CallableWithNdc<FetchResult> { conf.getInt( RssTezConfig.RSS_PARTITION_NUM_PER_RANGE, RssTezConfig.RSS_PARTITION_NUM_PER_RANGE_DEFAULT_VALUE); + this.maxAttemptNo = maxAttemptNo; LOG.info( "RssTezFetcherTask fetch partition:{}, with inputs:{}, readBufferSize:{}, partitionNumPerRange:{}.", this.partition, @@ -163,7 +166,8 @@ public class RssTezFetcherTask extends CallableWithNdc<FetchResult> { // final RssEventFetcher eventFetcher = new RssEventFetcher(inputs, numPhysicalInputs); int appAttemptId = applicationAttemptId.getAttemptId(); Roaring64NavigableMap taskIdBitmap = - RssTezUtils.fetchAllRssTaskIds(new HashSet<>(inputs), numPhysicalInputs, appAttemptId); + RssTezUtils.fetchAllRssTaskIds( + new HashSet<>(inputs), numPhysicalInputs, appAttemptId, this.maxAttemptNo); LOG.info( "Inputs:{}, num input:{}, appAttemptId:{}, taskIdBitmap:{}", inputs, diff --git a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleScheduler.java 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleScheduler.java index 8fb873649..72a9e1763 100644 --- a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleScheduler.java +++ b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssShuffleScheduler.java @@ -282,6 +282,8 @@ class RssShuffleScheduler extends ShuffleScheduler { private RemoteStorageInfo remoteStorageInfo; private int indexReadLimit; + private final int maxAttemptNo; + RssShuffleScheduler( InputContext inputContext, Configuration conf, @@ -538,6 +540,7 @@ class RssShuffleScheduler extends ShuffleScheduler { this.basePath = this.conf.get(RssTezConfig.RSS_REMOTE_STORAGE_PATH); String remoteStorageConf = this.conf.get(RssTezConfig.RSS_REMOTE_STORAGE_CONF); this.remoteStorageInfo = new RemoteStorageInfo(basePath, remoteStorageConf); + this.maxAttemptNo = RssTezUtils.getMaxAttemptNo(conf); LOG.info( "RSSShuffleScheduler running for sourceVertex: " @@ -1834,7 +1837,8 @@ class RssShuffleScheduler extends ShuffleScheduler { RssTezUtils.fetchAllRssTaskIds( partitionIdToSuccessMapTaskAttempts.get(mapHost.getPartitionId()), this.numInputs, - appAttemptId); + appAttemptId, + maxAttemptNo); LOG.info( "In reduce: {}, RSS Tez client has fetched blockIds and taskIds successfully, partitionId:{}.", diff --git a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssSorter.java b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssSorter.java index 94c9aa90d..fe4f11e13 100644 --- a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssSorter.java +++ b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssSorter.java @@ -61,7 +61,8 @@ public class RssSorter extends ExternalSorter { long initialMemoryAvailable, int shuffleId, ApplicationAttemptId applicationAttemptId, - Map<Integer, List<ShuffleServerInfo>> 
partitionToServers) + Map<Integer, List<ShuffleServerInfo>> partitionToServers, + long taskAttemptId) throws IOException { super(outputContext, conf, numOutputs, initialMemoryAvailable); this.partitionToServers = partitionToServers; @@ -81,7 +82,6 @@ public class RssSorter extends ExternalSorter { conf.getDouble( RssTezConfig.RSS_CLIENT_SORT_MEMORY_USE_THRESHOLD, RssTezConfig.RSS_CLIENT_DEFAULT_SORT_MEMORY_USE_THRESHOLD); - long taskAttemptId = RssTezUtils.convertTaskAttemptIdToLong(tezTaskAttemptID); long maxSegmentSize = conf.getLong( diff --git a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorter.java b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorter.java index 0248bb8a2..94e87a40e 100644 --- a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorter.java +++ b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorter.java @@ -60,7 +60,8 @@ public class RssUnSorter extends ExternalSorter { long initialMemoryAvailable, int shuffleId, ApplicationAttemptId applicationAttemptId, - Map<Integer, List<ShuffleServerInfo>> partitionToServers) + Map<Integer, List<ShuffleServerInfo>> partitionToServers, + long taskAttemptId) throws IOException { super(outputContext, conf, numOutputs, initialMemoryAvailable); this.partitionToServers = partitionToServers; @@ -80,7 +81,6 @@ public class RssUnSorter extends ExternalSorter { conf.getDouble( RssTezConfig.RSS_CLIENT_SORT_MEMORY_USE_THRESHOLD, RssTezConfig.RSS_CLIENT_DEFAULT_SORT_MEMORY_USE_THRESHOLD); - long taskAttemptId = RssTezUtils.convertTaskAttemptIdToLong(tezTaskAttemptID); long maxSegmentSize = conf.getLong( RssTezConfig.RSS_CLIENT_MAX_BUFFER_SIZE, diff --git a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java index 6bf563c9a..997dcdfa6 100644 --- 
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java +++ b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java @@ -217,6 +217,7 @@ public class RssOrderedPartitionedKVOutput extends AbstractLogicalOutput { public void start() throws Exception { if (!isStarted.get()) { memoryUpdateCallbackHandler.validateUpdateReceived(); + long rssTaskAttemptId = RssTezUtils.createRssTaskAttemptId(taskAttemptId, conf); sorter = new RssSorter( taskAttemptId, @@ -227,7 +228,8 @@ public class RssOrderedPartitionedKVOutput extends AbstractLogicalOutput { memoryUpdateCallbackHandler.getMemoryAssigned(), shuffleId, applicationAttemptId, - partitionToServers); + partitionToServers, + rssTaskAttemptId); LOG.info("Initialized RssSorter."); isStarted.set(true); } diff --git a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedKVOutput.java b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedKVOutput.java index 6aa32b327..3e08e7944 100644 --- a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedKVOutput.java +++ b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedKVOutput.java @@ -222,6 +222,7 @@ public class RssUnorderedKVOutput extends AbstractLogicalOutput { public void start() throws Exception { if (!isStarted.get()) { memoryUpdateCallbackHandler.validateUpdateReceived(); + long rssTaskAttemptId = RssTezUtils.createRssTaskAttemptId(taskAttemptId, conf); sorter = new RssUnSorter( taskAttemptId, @@ -232,7 +233,8 @@ public class RssUnorderedKVOutput extends AbstractLogicalOutput { memoryUpdateCallbackHandler.getMemoryAssigned(), shuffleId, applicationAttemptId, - partitionToServers); + partitionToServers, + rssTaskAttemptId); LOG.info("Initialized RssUnSorter."); isStarted.set(true); } diff --git a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedPartitionedKVOutput.java 
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedPartitionedKVOutput.java index bc78a2f67..7edb5a1f3 100644 --- a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedPartitionedKVOutput.java +++ b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedPartitionedKVOutput.java @@ -220,6 +220,7 @@ public class RssUnorderedPartitionedKVOutput extends AbstractLogicalOutput { public void start() throws Exception { if (!isStarted.get()) { memoryUpdateCallbackHandler.validateUpdateReceived(); + long rssTaskAttemptId = RssTezUtils.createRssTaskAttemptId(taskAttemptId, conf); sorter = new RssUnSorter( taskAttemptId, @@ -230,7 +231,8 @@ public class RssUnorderedPartitionedKVOutput extends AbstractLogicalOutput { memoryUpdateCallbackHandler.getMemoryAssigned(), shuffleId, applicationAttemptId, - partitionToServers); + partitionToServers, + rssTaskAttemptId); LOG.info("Initialized RssUnSorter."); isStarted.set(true); } diff --git a/client-tez/src/test/java/org/apache/tez/common/RssTezUtilsTest.java b/client-tez/src/test/java/org/apache/tez/common/RssTezUtilsTest.java index c71dbefcf..12404a14f 100644 --- a/client-tez/src/test/java/org/apache/tez/common/RssTezUtilsTest.java +++ b/client-tez/src/test/java/org/apache/tez/common/RssTezUtilsTest.java @@ -56,17 +56,17 @@ public class RssTezUtilsTest { boolean isException = false; try { - RssTezUtils.convertTaskAttemptIdToLong(tezTaskAttemptId); + RssTezUtils.createRssTaskAttemptId(tezTaskAttemptId, 3); } catch (RssException e) { isException = true; } assertTrue(isException); - taskId = TezTaskID.getInstance(vId, (int) (1 << 21)); + taskId = TezTaskID.getInstance(vId, 1 << 21); tezTaskAttemptId = TezTaskAttemptID.getInstance(taskId, 2); isException = false; try { - RssTezUtils.convertTaskAttemptIdToLong(tezTaskAttemptId); + RssTezUtils.createRssTaskAttemptId(tezTaskAttemptId, 3); } catch (RssException e) { isException = true; } @@ -80,7 +80,7 @@ public class 
RssTezUtilsTest { TezVertexID vId = TezVertexID.getInstance(dagId, 35); TezTaskID tId = TezTaskID.getInstance(vId, 389); TezTaskAttemptID tezTaskAttemptId = TezTaskAttemptID.getInstance(tId, 2); - long taskAttemptId = RssTezUtils.convertTaskAttemptIdToLong(tezTaskAttemptId); + long taskAttemptId = RssTezUtils.createRssTaskAttemptId(tezTaskAttemptId, 3); long blockId = RssTezUtils.getBlockId(1, taskAttemptId, 0); long newTaskAttemptId = RssTezUtils.getTaskAttemptId(blockId); assertEquals(taskAttemptId, newTaskAttemptId); @@ -97,7 +97,7 @@ public class RssTezUtilsTest { TezVertexID vId = TezVertexID.getInstance(dagId, 35); TezTaskID tId = TezTaskID.getInstance(vId, 389); TezTaskAttemptID tezTaskAttemptId = TezTaskAttemptID.getInstance(tId, 2); - long taskAttemptId = RssTezUtils.convertTaskAttemptIdToLong(tezTaskAttemptId); + long taskAttemptId = RssTezUtils.createRssTaskAttemptId(tezTaskAttemptId, 3); long mask = (1L << layout.partitionIdBits) - 1; for (int partitionId = 0; partitionId <= 3000; partitionId++) { for (int seqNo = 0; seqNo <= 10; seqNo++) { diff --git a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssSorterTest.java b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssSorterTest.java index 588abf32d..db9f9cea5 100644 --- a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssSorterTest.java +++ b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssSorterTest.java @@ -31,6 +31,7 @@ import org.apache.hadoop.io.Text; import org.apache.hadoop.yarn.api.ApplicationConstants; import org.apache.hadoop.yarn.api.records.ApplicationAttemptId; import org.apache.hadoop.yarn.api.records.ApplicationId; +import org.apache.tez.common.RssTezUtils; import org.apache.tez.common.TezRuntimeFrameworkConfigs; import org.apache.tez.dag.records.TezTaskAttemptID; import org.apache.tez.runtime.api.OutputContext; @@ -87,6 +88,7 @@ public class RssSorterTest { long 
initialMemoryAvailable = 10240000; int shuffleId = 1001; + long rssTaskAttemptId = RssTezUtils.createRssTaskAttemptId(tezTaskAttemptID, 3); RssSorter rssSorter = new RssSorter( @@ -98,7 +100,8 @@ public class RssSorterTest { initialMemoryAvailable, shuffleId, applicationAttemptId, - partitionToServers); + partitionToServers, + rssTaskAttemptId); rssSorter.collect(new Text("0"), new Text("0"), 0); rssSorter.collect(new Text("0"), new Text("1"), 0); diff --git a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorterTest.java b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorterTest.java index e54a37fd7..57f3f7626 100644 --- a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorterTest.java +++ b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/impl/RssUnSorterTest.java @@ -27,6 +27,7 @@ import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.Text; import org.apache.hadoop.yarn.api.ApplicationConstants; +import org.apache.tez.common.RssTezUtils; import org.apache.tez.common.TezRuntimeFrameworkConfigs; import org.apache.tez.dag.records.TezTaskAttemptID; import org.apache.tez.runtime.api.OutputContext; @@ -82,6 +83,7 @@ public class RssUnSorterTest { long initialMemoryAvailable = 10240000; int shuffleId = 1001; + long rssTaskAttemptId = RssTezUtils.createRssTaskAttemptId(tezTaskAttemptID, 3); RssUnSorter rssSorter = new RssUnSorter( @@ -93,7 +95,8 @@ public class RssUnSorterTest { initialMemoryAvailable, shuffleId, APPATTEMPT_ID, - partitionToServers); + partitionToServers, + rssTaskAttemptId); rssSorter.collect(new Text("0"), new Text("0"), 0); rssSorter.collect(new Text("0"), new Text("1"), 0); diff --git a/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java b/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java index b3d40dcde..29fc4b241 100644 --- 
a/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java +++ b/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java @@ -112,4 +112,21 @@ public class ClientUtils { String.format("The value of %s should be one of %s", clientType, types)); } } + + public static int getMaxAttemptNo(int maxFailures, boolean speculation) { + // attempt number is zero based: 0, 1, …, maxFailures-1 + // max maxFailures < 1 is not allowed but for safety, we interpret that as maxFailures == 1 + int maxAttemptNo = maxFailures < 1 ? 0 : maxFailures - 1; + + // with speculative execution enabled we could observe +1 attempts + if (speculation) { + maxAttemptNo++; + } + + return maxAttemptNo; + } + + public static int getNumberOfSignificantBits(int number) { + return 32 - Integer.numberOfLeadingZeros(number); + } } diff --git a/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java b/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java index dd9ef62b1..611a46b44 100644 --- a/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java +++ b/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java @@ -36,6 +36,8 @@ import org.apache.uniffle.client.util.DefaultIdHelper; import org.apache.uniffle.common.util.BlockIdLayout; import org.apache.uniffle.common.util.RssUtils; +import static org.apache.uniffle.client.util.ClientUtils.getMaxAttemptNo; +import static org.apache.uniffle.client.util.ClientUtils.getNumberOfSignificantBits; import static org.apache.uniffle.client.util.ClientUtils.waitUntilDoneOrFail; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.fail; @@ -134,4 +136,44 @@ public class ClientUtilsTest { // Ignore } } + + @Test + public void testGetMaxAttemptNo() { + // without speculation + assertEquals(0, getMaxAttemptNo(-1, false)); + assertEquals(0, getMaxAttemptNo(0, false)); + assertEquals(0, getMaxAttemptNo(1, false)); + assertEquals(1, 
getMaxAttemptNo(2, false)); + assertEquals(2, getMaxAttemptNo(3, false)); + assertEquals(3, getMaxAttemptNo(4, false)); + assertEquals(4, getMaxAttemptNo(5, false)); + assertEquals(1023, getMaxAttemptNo(1024, false)); + + // with speculation + assertEquals(1, getMaxAttemptNo(-1, true)); + assertEquals(1, getMaxAttemptNo(0, true)); + assertEquals(1, getMaxAttemptNo(1, true)); + assertEquals(2, getMaxAttemptNo(2, true)); + assertEquals(3, getMaxAttemptNo(3, true)); + assertEquals(4, getMaxAttemptNo(4, true)); + assertEquals(5, getMaxAttemptNo(5, true)); + assertEquals(1024, getMaxAttemptNo(1024, true)); + } + + @Test + public void testGetNumberOfSignificantBits() { + assertEquals(0, getNumberOfSignificantBits(0)); + assertEquals(1, getNumberOfSignificantBits(1)); + assertEquals(2, getNumberOfSignificantBits(2)); + assertEquals(2, getNumberOfSignificantBits(3)); + assertEquals(3, getNumberOfSignificantBits(4)); + assertEquals(3, getNumberOfSignificantBits(5)); + assertEquals(3, getNumberOfSignificantBits(6)); + assertEquals(3, getNumberOfSignificantBits(7)); + assertEquals(4, getNumberOfSignificantBits(8)); + assertEquals(4, getNumberOfSignificantBits(9)); + assertEquals(10, getNumberOfSignificantBits(1023)); + assertEquals(11, getNumberOfSignificantBits(1024)); + assertEquals(11, getNumberOfSignificantBits(1025)); + } } diff --git a/common/src/main/java/org/apache/uniffle/common/util/BlockId.java b/common/src/main/java/org/apache/uniffle/common/util/BlockId.java index 36025f66b..6c93e6345 100644 --- a/common/src/main/java/org/apache/uniffle/common/util/BlockId.java +++ b/common/src/main/java/org/apache/uniffle/common/util/BlockId.java @@ -32,10 +32,10 @@ public class BlockId { public final BlockIdLayout layout; public final int sequenceNo; public final int partitionId; - public final int taskAttemptId; + public final long taskAttemptId; protected BlockId( - long blockId, BlockIdLayout layout, int sequenceNo, int partitionId, int taskAttemptId) { + long blockId, 
BlockIdLayout layout, int sequenceNo, int partitionId, long taskAttemptId) { this.blockId = blockId; this.layout = layout; this.sequenceNo = sequenceNo; diff --git a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezWordCountWithFailuresTest.java b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezWordCountWithFailuresTest.java index 438bbdf39..d31d842ab 100644 --- a/integration-test/tez/src/test/java/org/apache/uniffle/test/TezWordCountWithFailuresTest.java +++ b/integration-test/tez/src/test/java/org/apache/uniffle/test/TezWordCountWithFailuresTest.java @@ -362,7 +362,7 @@ public class TezWordCountWithFailuresTest extends IntegrationTestBase { // verifyMode is 0: avoid recompute succeeded task is true Assertions.assertEquals(0, progressMap.get("Tokenizer").getKilledTaskAttemptCount()); } else if (verifyMode == 1) { - // verifyMode is 1: avoid recompute succeeded task is true + // verifyMode is 1: avoid recompute succeeded task is false Assertions.assertTrue(progressMap.get("Tokenizer").getKilledTaskAttemptCount() > 0); } return 0; diff --git a/server/src/test/java/org/apache/uniffle/server/buffer/BufferTestBase.java b/server/src/test/java/org/apache/uniffle/server/buffer/BufferTestBase.java index 2314a1b6c..7aef68c15 100644 --- a/server/src/test/java/org/apache/uniffle/server/buffer/BufferTestBase.java +++ b/server/src/test/java/org/apache/uniffle/server/buffer/BufferTestBase.java @@ -51,7 +51,7 @@ public abstract class BufferTestBase { return createData(partitionId, 0, len); } - protected ShufflePartitionedData createData(int partitionId, int taskAttemptId, int len) { + protected ShufflePartitionedData createData(int partitionId, long taskAttemptId, int len) { byte[] buf = new byte[len]; new Random().nextBytes(buf); long blockId = diff --git a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandlerTest.java b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandlerTest.java 
index c11fc27a7..d1b663f1f 100644 --- a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandlerTest.java +++ b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandlerTest.java @@ -108,7 +108,7 @@ public class HadoopShuffleReadHandlerTest extends HadoopTestBase { int totalBlockNum = 0; int expectTotalBlockNum = 6; int blockSize = 7; - int taskAttemptId = 0; + long taskAttemptId = 0; // write expectTotalBlockNum - 1 complete block HadoopShuffleHandlerTestBase.writeTestData(