This is an automated email from the ASF dual-hosted git repository.
zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 07cc9ae82 [#2460] feat(spark3): Tracking failure of pushing data for
spark UI (#2481)
07cc9ae82 is described below
commit 07cc9ae82852a9c18bcca36ae96a8b56d906368b
Author: Junfan Zhang <[email protected]>
AuthorDate: Tue May 13 17:22:53 2025 +0800
[#2460] feat(spark3): Tracking failure of pushing data for spark UI (#2481)
### What changes were proposed in this pull request?
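Track data push failures per shuffle server (the number of push failures, the number of require-buffer failures, and the last push failure reason) and show them in the Spark UI.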
### Why are the changes needed?
It makes it easier to inspect data writing.
### Does this PR introduce _any_ user-facing change?
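Yes. The per-server table on the Spark UI shuffle page gains three columns: "Require Buffer Failures", "Push Failures", and "Last push failure reason".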

### How was this patch tested?
Internal tests
---
.../spark/shuffle/events/ShuffleWriteMetric.java | 37 +++++++++++++++++-
.../apache/spark/shuffle/writer/DataPusher.java | 2 +-
.../spark/shuffle/writer/WriteBufferManager.java | 2 +-
.../shuffle/manager/ShuffleManagerGrpcService.java | 5 +--
.../spark/shuffle/writer/DataPusherTest.java | 2 +-
.../scala/org/apache/spark/UniffleListener.scala | 8 +++-
.../org/apache/spark/UniffleStatusStore.scala | 6 ++-
.../scala/org/apache/spark/ui/ShufflePage.scala | 28 +++++++++++---
.../client/impl/ShuffleWriteClientImpl.java | 8 +++-
.../client/response/SendShuffleDataResult.java | 2 +-
.../client/common}/ShuffleServerPushCost.java | 33 +++++++++++++++-
.../common}/ShuffleServerPushCostTracker.java | 27 +++++++++++--
.../client/impl/grpc/ShuffleServerGrpcClient.java | 6 +++
.../RssReportShuffleWriteMetricRequest.java | 44 +++++++++++++++++++---
.../client/request/RssSendShuffleDataRequest.java | 12 +++++-
proto/src/main/proto/Rss.proto | 5 +++
16 files changed, 197 insertions(+), 30 deletions(-)
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/events/ShuffleWriteMetric.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/events/ShuffleWriteMetric.java
index 4b24c1fa8..c8939491c 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/events/ShuffleWriteMetric.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/events/ShuffleWriteMetric.java
@@ -17,8 +17,43 @@
package org.apache.spark.shuffle.events;
+import org.apache.uniffle.proto.RssProtos;
+
public class ShuffleWriteMetric extends ShuffleMetric {
- public ShuffleWriteMetric(long durationMillis, long byteSize) {
+ private final long requireBufferFailureNumber;
+ private final long pushFailureNumber;
+ private final String lastFailureReason;
+
+ public ShuffleWriteMetric(
+ long durationMillis,
+ long byteSize,
+ long requireBufferFailureNumber,
+ long pushFailureNumber,
+ String lastFailureReason) {
super(durationMillis, byteSize);
+ this.requireBufferFailureNumber = requireBufferFailureNumber;
+ this.pushFailureNumber = pushFailureNumber;
+ this.lastFailureReason = lastFailureReason;
+ }
+
+ public long getRequireBufferFailureNumber() {
+ return requireBufferFailureNumber;
+ }
+
+ public long getPushFailureNumber() {
+ return pushFailureNumber;
+ }
+
+ public String getLastFailureReason() {
+ return lastFailureReason;
+ }
+
+ public static ShuffleWriteMetric from(RssProtos.ShuffleWriteMetric proto) {
+ return new ShuffleWriteMetric(
+ proto.getDurationMillis(),
+ proto.getByteSize(),
+ proto.getRequireBufferFailureNumber(),
+ proto.getPushFailureNumber(),
+ proto.getLastPushFailureReason());
}
}
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
index f14583654..5b3455179 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
@@ -35,8 +35,8 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.client.common.ShuffleServerPushCostTracker;
import org.apache.uniffle.client.impl.FailedBlockSendTracker;
-import org.apache.uniffle.client.impl.ShuffleServerPushCostTracker;
import org.apache.uniffle.client.response.SendShuffleDataResult;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.exception.RssException;
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index 169bb20b1..8b0862b81 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -47,7 +47,7 @@ import org.apache.spark.shuffle.RssSparkConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import org.apache.uniffle.client.impl.ShuffleServerPushCostTracker;
+import org.apache.uniffle.client.common.ShuffleServerPushCostTracker;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.compression.Codec;
diff --git
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
index 213f1a774..03998889f 100644
---
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
+++
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java
@@ -729,10 +729,7 @@ public class ShuffleManagerGrpcService extends
ShuffleManagerImplBase {
request.getMetricsMap().entrySet().stream()
.collect(
Collectors.toMap(
- Map.Entry::getKey,
- x ->
- new ShuffleWriteMetric(
- x.getValue().getDurationMillis(),
x.getValue().getByteSize()))));
+ Map.Entry::getKey, x ->
ShuffleWriteMetric.from(x.getValue()))));
RssSparkShuffleUtils.getActiveSparkContext().listenerBus().post(event);
RssProtos.ReportShuffleWriteMetricResponse reply =
RssProtos.ReportShuffleWriteMetricResponse.newBuilder()
diff --git
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
index 8ac07664f..4d0d661c3 100644
---
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
+++
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
@@ -30,9 +30,9 @@ import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import org.junit.jupiter.api.Test;
+import org.apache.uniffle.client.common.ShuffleServerPushCostTracker;
import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.impl.FailedBlockSendTracker;
-import org.apache.uniffle.client.impl.ShuffleServerPushCostTracker;
import org.apache.uniffle.client.impl.ShuffleWriteClientImpl;
import org.apache.uniffle.client.response.SendShuffleDataResult;
import org.apache.uniffle.common.ShuffleBlockInfo;
diff --git
a/client-spark/extension/src/main/scala/org/apache/spark/UniffleListener.scala
b/client-spark/extension/src/main/scala/org/apache/spark/UniffleListener.scala
index 8c84593ea..560007cbb 100644
---
a/client-spark/extension/src/main/scala/org/apache/spark/UniffleListener.scala
+++
b/client-spark/extension/src/main/scala/org/apache/spark/UniffleListener.scala
@@ -17,6 +17,7 @@
package org.apache.spark
+import org.apache.commons.lang3.StringUtils
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent,
SparkListenerJobEnd, SparkListenerJobStart, SparkListenerTaskEnd}
import org.apache.spark.shuffle.events.{ShuffleAssignmentInfoEvent,
TaskShuffleReadInfoEvent, TaskShuffleWriteInfoEvent}
@@ -95,9 +96,14 @@ class UniffleListener(conf: SparkConf, kvstore:
ElementTrackingStore)
val metrics = event.getMetrics
for (metric <- metrics.asScala) {
val id = metric._1
- val agg_metric = this.aggregatedShuffleWriteMetric.computeIfAbsent(id, _
=> new AggregatedShuffleWriteMetric(0, 0))
+ val agg_metric = this.aggregatedShuffleWriteMetric.computeIfAbsent(id, _
=> new AggregatedShuffleWriteMetric(0, 0, 0, 0, ""))
agg_metric.byteSize += metric._2.getByteSize
agg_metric.durationMillis += metric._2.getDurationMillis
+ agg_metric.requireBufferFailureNumber +=
metric._2.getRequireBufferFailureNumber
+ agg_metric.pushFailureNumber += metric._2.getPushFailureNumber
+ if (!StringUtils.isEmpty(metric._2.getLastFailureReason)) {
+ agg_metric.lastPushFailureReason = metric._2.getLastFailureReason
+ }
}
}
diff --git
a/client-spark/extension/src/main/scala/org/apache/spark/UniffleStatusStore.scala
b/client-spark/extension/src/main/scala/org/apache/spark/UniffleStatusStore.scala
index 021311981..17063528f 100644
---
a/client-spark/extension/src/main/scala/org/apache/spark/UniffleStatusStore.scala
+++
b/client-spark/extension/src/main/scala/org/apache/spark/UniffleStatusStore.scala
@@ -105,7 +105,11 @@ class AggregatedShuffleWriteMetricsUIData(val metrics:
ConcurrentHashMap[String,
@KVIndex
def id: String = classOf[AggregatedShuffleWriteMetricsUIData].getName()
}
-class AggregatedShuffleWriteMetric(durationMillis: Long, byteSize: Long)
+class AggregatedShuffleWriteMetric(durationMillis: Long,
+ byteSize: Long,
+ var requireBufferFailureNumber: Long,
+ var pushFailureNumber: Long,
+ var lastPushFailureReason: String)
extends AggregatedShuffleMetric(durationMillis, byteSize)
class AggregatedShuffleReadMetricsUIData(val metrics:
ConcurrentHashMap[String, AggregatedShuffleReadMetric]) {
diff --git
a/client-spark/extension/src/main/scala/org/apache/spark/ui/ShufflePage.scala
b/client-spark/extension/src/main/scala/org/apache/spark/ui/ShufflePage.scala
index 32864da90..d3e811243 100644
---
a/client-spark/extension/src/main/scala/org/apache/spark/ui/ShufflePage.scala
+++
b/client-spark/extension/src/main/scala/org/apache/spark/ui/ShufflePage.scala
@@ -39,7 +39,7 @@ class ShufflePage(parent: ShuffleTab) extends WebUIPage("")
with Logging {
</td>
</tr>
- private def allServerRow(kv: (String, String, String, Double, String,
String, Double)) = <tr>
+ private def allServerRow(kv: (String, String, String, Double, Long, Long,
String, String, String, Double)) = <tr>
<td>{kv._1}</td>
<td>{kv._2}</td>
<td>{kv._3}</td>
@@ -47,6 +47,9 @@ class ShufflePage(parent: ShuffleTab) extends WebUIPage("")
with Logging {
<td>{kv._5}</td>
<td>{kv._6}</td>
<td>{kv._7}</td>
+ <td>{kv._8}</td>
+ <td>{kv._9}</td>
+ <td>{kv._10}</td>
</tr>
private def createShuffleMetricsRows(shuffleWriteMetrics: (Seq[Double],
Seq[String]), shuffleReadMetrics: (Seq[Double], Seq[String])):
Seq[scala.xml.Elem] = {
@@ -137,7 +140,18 @@ class ShufflePage(parent: ShuffleTab) extends
WebUIPage("") with Logging {
originReadMetric.metrics
)
val allServersTableUI = UIUtils.listingTable(
- Seq("Shuffle Server ID", "Write Bytes", "Write Duration", "Write Speed
(MB/sec)", "Read Bytes", "Read Duration", "Read Speed (MB/sec)"),
+ Seq(
+ "Shuffle Server ID",
+ "Write Bytes",
+ "Write Duration",
+ "Write Speed (MB/sec)",
+ "Require Buffer Failures",
+ "Push Failures",
+ "Last push failure reason",
+ "Read Bytes",
+ "Read Duration",
+ "Read Speed (MB/sec)"
+ ),
allServerRow,
allServers,
fixedWidth = true
@@ -294,7 +308,7 @@ class ShufflePage(parent: ShuffleTab) extends WebUIPage("")
with Logging {
}
private def unionByServerId(write: ConcurrentHashMap[String,
AggregatedShuffleWriteMetric],
- read: ConcurrentHashMap[String,
AggregatedShuffleReadMetric]): Seq[(String, String, String, Double, String,
String, Double)] = {
+ read: ConcurrentHashMap[String,
AggregatedShuffleReadMetric]): Seq[(String, String, String, Double, Long, Long,
String, String, String, Double)] = {
val writeMetrics = write.asScala
val readMetrics = read.asScala
val allServerIds = writeMetrics.keySet ++ readMetrics.keySet
@@ -303,7 +317,8 @@ class ShufflePage(parent: ShuffleTab) extends WebUIPage("")
with Logging {
writeMetrics
.mapValues {
metrics =>
- (metrics.byteSize, metrics.durationMillis,
(metrics.byteSize.toDouble / metrics.durationMillis) / 1000.00)
+ (metrics.byteSize, metrics.durationMillis,
(metrics.byteSize.toDouble / metrics.durationMillis) / 1000.00,
+ metrics.requireBufferFailureNumber, metrics.pushFailureNumber,
metrics.lastPushFailureReason)
}
.toMap
val readMetricsToMap =
@@ -315,13 +330,16 @@ class ShufflePage(parent: ShuffleTab) extends
WebUIPage("") with Logging {
.toMap
val unionMetrics = allServerIds.toSeq.map { serverId =>
- val writeMetric = writeMetricsToMap.getOrElse(serverId, (0L, 0L, 0.00))
+ val writeMetric = writeMetricsToMap.getOrElse(serverId, (0L, 0L, 0.00,
0L, 0L, ""))
val readMetric = readMetricsToMap.getOrElse(serverId, (0L, 0L, 0.00))
(
serverId,
Utils.bytesToString(writeMetric._1),
UIUtils.formatDuration(writeMetric._2),
roundToTwoDecimals(writeMetric._3),
+ writeMetric._4,
+ writeMetric._5,
+ writeMetric._6,
Utils.bytesToString(readMetric._1),
UIUtils.formatDuration(readMetric._2),
roundToTwoDecimals(readMetric._3)
diff --git
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
index e7791e103..5756f761b 100644
---
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
+++
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
@@ -48,6 +48,7 @@ import org.slf4j.LoggerFactory;
import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
import org.apache.uniffle.client.api.ShuffleServerClient;
import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.client.common.ShuffleServerPushCostTracker;
import org.apache.uniffle.client.factory.CoordinatorClientFactory;
import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.factory.ShuffleServerClientFactory;
@@ -201,7 +202,8 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
stageAttemptNumber,
retryMax,
retryIntervalMax,
- shuffleIdToBlocks);
+ shuffleIdToBlocks,
+ shuffleServerPushCostTracker);
long s = System.currentTimeMillis();
RssSendShuffleDataResponse response =
getShuffleServerClient(ssi).sendShuffleData(request);
@@ -234,6 +236,8 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
}
LOG.warn(
"{}, it failed wth statusCode[{}]", logMsg,
response.getStatusCode());
+ shuffleServerPushCostTracker.recordPushFailure(
+ ssi.getId(), response.getStatusCode());
return false;
}
@@ -247,6 +251,8 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
.orElse(0);
shuffleServerPushCostTracker.record(ssi.getId(),
sentBytes, pushDuration);
} catch (Exception e) {
+ shuffleServerPushCostTracker.recordPushFailure(
+ ssi.getId(), StatusCode.INTERNAL_ERROR);
recordFailedBlocks(
failedBlockSendTracker, serverToBlocks, ssi,
StatusCode.INTERNAL_ERROR);
if (defectiveServers != null) {
diff --git
a/client/src/main/java/org/apache/uniffle/client/response/SendShuffleDataResult.java
b/client/src/main/java/org/apache/uniffle/client/response/SendShuffleDataResult.java
index 31bd76be0..393fb8dfd 100644
---
a/client/src/main/java/org/apache/uniffle/client/response/SendShuffleDataResult.java
+++
b/client/src/main/java/org/apache/uniffle/client/response/SendShuffleDataResult.java
@@ -19,8 +19,8 @@ package org.apache.uniffle.client.response;
import java.util.Set;
+import org.apache.uniffle.client.common.ShuffleServerPushCostTracker;
import org.apache.uniffle.client.impl.FailedBlockSendTracker;
-import org.apache.uniffle.client.impl.ShuffleServerPushCostTracker;
public class SendShuffleDataResult {
diff --git
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleServerPushCost.java
b/internal-client/src/main/java/org/apache/uniffle/client/common/ShuffleServerPushCost.java
similarity index 67%
rename from
client/src/main/java/org/apache/uniffle/client/impl/ShuffleServerPushCost.java
rename to
internal-client/src/main/java/org/apache/uniffle/client/common/ShuffleServerPushCost.java
index 92383a1bf..25ea43b8f 100644
---
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleServerPushCost.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/common/ShuffleServerPushCost.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.uniffle.client.impl;
+package org.apache.uniffle.client.common;
import java.util.concurrent.atomic.AtomicLong;
@@ -24,10 +24,27 @@ public class ShuffleServerPushCost {
private final AtomicLong sentBytes;
private final AtomicLong sentDurationMs;
+ private final AtomicLong requireBufferFailureCounter;
+ private final AtomicLong pushFailureCounter;
+
+ private String lastPushFailureReason;
+
public ShuffleServerPushCost(String shuffleServerId) {
this.shuffleServerId = shuffleServerId;
this.sentBytes = new AtomicLong();
this.sentDurationMs = new AtomicLong();
+ this.requireBufferFailureCounter = new AtomicLong();
+ this.pushFailureCounter = new AtomicLong();
+ this.lastPushFailureReason = null;
+ }
+
+ public void incRequiredBufferFailure(long delta) {
+ this.requireBufferFailureCounter.addAndGet(delta);
+ }
+
+ public void incSentFailure(long delta, String failureReason) {
+ this.pushFailureCounter.addAndGet(delta);
+ this.lastPushFailureReason = failureReason;
}
public void incSentBytes(long bytes) {
@@ -45,6 +62,8 @@ public class ShuffleServerPushCost {
this.incSentBytes(cost.sentBytes.get());
this.incDurationMs(cost.sentDurationMs.get());
+ this.incRequiredBufferFailure(cost.requireBufferFailureCounter.get());
+ this.incSentFailure(cost.pushFailureCounter.get(),
cost.lastPushFailureReason);
}
public long speed() {
@@ -62,6 +81,18 @@ public class ShuffleServerPushCost {
return sentDurationMs.get();
}
+ public long requiredBufferFailureNumber() {
+ return requireBufferFailureCounter.get();
+ }
+
+ public long pushFailureNumber() {
+ return pushFailureCounter.get();
+ }
+
+ public String getLastPushFailureReason() {
+ return lastPushFailureReason;
+ }
+
@Override
public String toString() {
return "ShuffleServerPushCost{"
diff --git
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleServerPushCostTracker.java
b/internal-client/src/main/java/org/apache/uniffle/client/common/ShuffleServerPushCostTracker.java
similarity index 79%
rename from
client/src/main/java/org/apache/uniffle/client/impl/ShuffleServerPushCostTracker.java
rename to
internal-client/src/main/java/org/apache/uniffle/client/common/ShuffleServerPushCostTracker.java
index d3d434320..bc6ab049a 100644
---
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleServerPushCostTracker.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/common/ShuffleServerPushCostTracker.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.uniffle.client.impl;
+package org.apache.uniffle.client.common;
import java.util.ArrayList;
import java.util.Collections;
@@ -30,6 +30,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.uniffle.client.request.RssReportShuffleWriteMetricRequest;
+import org.apache.uniffle.common.rpc.StatusCode;
/** This class is to track the underlying assigned shuffle servers' data
pushing speed. */
public class ShuffleServerPushCostTracker {
@@ -53,6 +54,18 @@ public class ShuffleServerPushCostTracker {
}
}
+ public void recordRequireBufferFailure(String id) {
+ ShuffleServerPushCost cost =
+ this.tracking.computeIfAbsent(id, key -> new
ShuffleServerPushCost(key));
+ cost.incRequiredBufferFailure(1);
+ }
+
+ public void recordPushFailure(String id, StatusCode failureStatusCode) {
+ ShuffleServerPushCost cost =
+ this.tracking.computeIfAbsent(id, key -> new
ShuffleServerPushCost(key));
+ cost.incSentFailure(1, failureStatusCode.name());
+ }
+
public void record(String id, long sentBytes, long pushDuration) {
ShuffleServerPushCost cost =
this.tracking.computeIfAbsent(id, key -> new
ShuffleServerPushCost(key));
@@ -97,8 +110,14 @@ public class ShuffleServerPushCostTracker {
.collect(
Collectors.toMap(
Map.Entry::getKey,
- x ->
- new
RssReportShuffleWriteMetricRequest.TaskShuffleWriteMetric(
- x.getValue().sentDurationMillis(),
x.getValue().sentBytes())));
+ x -> {
+ ShuffleServerPushCost cost = x.getValue();
+ return new
RssReportShuffleWriteMetricRequest.TaskShuffleWriteMetric(
+ cost.sentDurationMillis(),
+ cost.sentBytes(),
+ cost.requiredBufferFailureNumber(),
+ cost.pushFailureNumber(),
+ cost.getLastPushFailureReason());
+ }));
}
}
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
index 19b7a0b94..d2537e839 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
@@ -40,6 +40,7 @@ import org.slf4j.LoggerFactory;
import org.apache.uniffle.client.api.ClientInfo;
import org.apache.uniffle.client.api.ShuffleServerClient;
+import org.apache.uniffle.client.common.ShuffleServerPushCostTracker;
import org.apache.uniffle.client.request.RetryableRequest;
import org.apache.uniffle.client.request.RssAppHeartBeatRequest;
import org.apache.uniffle.client.request.RssFinishShuffleRequest;
@@ -539,6 +540,7 @@ public class ShuffleServerGrpcClient extends GrpcClient
implements ShuffleServer
Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> shuffleIdToBlocks =
request.getShuffleIdToBlocks();
int stageAttemptNumber = request.getStageAttemptNumber();
+ ShuffleServerPushCostTracker costTracker = request.getCostTracker();
boolean isSuccessful = true;
AtomicReference<StatusCode> failedStatusCode = new
AtomicReference<>(StatusCode.INTERNAL_ERROR);
@@ -598,6 +600,10 @@ public class ShuffleServerGrpcClient extends GrpcClient
implements ShuffleServer
long requireId = allocationResult.getLeft();
needSplitPartitionIds.addAll(allocationResult.getRight());
if (requireId == FAILED_REQUIRE_ID) {
+ ClientInfo clientInfo = getClientInfo();
+ if (clientInfo != null && costTracker != null) {
+
costTracker.recordRequireBufferFailure(clientInfo.getShuffleServerInfo().getId());
+ }
throw new RssException(
String.format(
"requirePreAllocation failed! size[%s], host[%s],
port[%s]",
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssReportShuffleWriteMetricRequest.java
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssReportShuffleWriteMetricRequest.java
index 1c0f75235..d07ac9536 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssReportShuffleWriteMetricRequest.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssReportShuffleWriteMetricRequest.java
@@ -18,6 +18,7 @@
package org.apache.uniffle.client.request;
import java.util.Map;
+import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.uniffle.proto.RssProtos;
@@ -49,11 +50,17 @@ public class RssReportShuffleWriteMetricRequest {
.collect(
Collectors.toMap(
Map.Entry::getKey,
- x ->
- RssProtos.ShuffleWriteMetric.newBuilder()
- .setByteSize(x.getValue().getByteSize())
-
.setDurationMillis(x.getValue().getDurationMillis())
- .build())));
+ x -> {
+ TaskShuffleWriteMetric metric = x.getValue();
+ return RssProtos.ShuffleWriteMetric.newBuilder()
+ .setByteSize(metric.getByteSize())
+ .setDurationMillis(metric.getDurationMillis())
+
.setPushFailureNumber(metric.getPushFailureNumber())
+
.setRequireBufferFailureNumber(metric.getRequireBufferFailureNumber())
+ .setLastPushFailureReason(
+
Optional.ofNullable(metric.getLastFailureReason()).orElse(""))
+ .build();
+ })));
return builder.build();
}
@@ -61,9 +68,22 @@ public class RssReportShuffleWriteMetricRequest {
private long durationMillis;
private long byteSize;
- public TaskShuffleWriteMetric(long durationMillis, long byteSize) {
+ private long requireBufferFailureNumber;
+ private long pushFailureNumber;
+
+ private String lastFailureReason;
+
+ public TaskShuffleWriteMetric(
+ long durationMillis,
+ long byteSize,
+ long requireBufferFailureNumber,
+ long pushFailureNumber,
+ String lastFailureReason) {
this.durationMillis = durationMillis;
this.byteSize = byteSize;
+ this.requireBufferFailureNumber = requireBufferFailureNumber;
+ this.pushFailureNumber = pushFailureNumber;
+ this.lastFailureReason = lastFailureReason;
}
public long getDurationMillis() {
@@ -73,5 +93,17 @@ public class RssReportShuffleWriteMetricRequest {
public long getByteSize() {
return byteSize;
}
+
+ public long getRequireBufferFailureNumber() {
+ return requireBufferFailureNumber;
+ }
+
+ public long getPushFailureNumber() {
+ return pushFailureNumber;
+ }
+
+ public String getLastFailureReason() {
+ return lastFailureReason;
+ }
}
}
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssSendShuffleDataRequest.java
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssSendShuffleDataRequest.java
index 1b5fdcff8..2d883b4ef 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssSendShuffleDataRequest.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssSendShuffleDataRequest.java
@@ -20,6 +20,7 @@ package org.apache.uniffle.client.request;
import java.util.List;
import java.util.Map;
+import org.apache.uniffle.client.common.ShuffleServerPushCostTracker;
import org.apache.uniffle.common.ShuffleBlockInfo;
public class RssSendShuffleDataRequest {
@@ -29,13 +30,14 @@ public class RssSendShuffleDataRequest {
private int retryMax;
private long retryIntervalMax;
private Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> shuffleIdToBlocks;
+ private ShuffleServerPushCostTracker costTracker;
public RssSendShuffleDataRequest(
String appId,
int retryMax,
long retryIntervalMax,
Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> shuffleIdToBlocks) {
- this(appId, 0, retryMax, retryIntervalMax, shuffleIdToBlocks);
+ this(appId, 0, retryMax, retryIntervalMax, shuffleIdToBlocks, null);
}
public RssSendShuffleDataRequest(
@@ -43,12 +45,14 @@ public class RssSendShuffleDataRequest {
int stageAttemptNumber,
int retryMax,
long retryIntervalMax,
- Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> shuffleIdToBlocks) {
+ Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> shuffleIdToBlocks,
+ ShuffleServerPushCostTracker costTracker) {
this.appId = appId;
this.retryMax = retryMax;
this.retryIntervalMax = retryIntervalMax;
this.shuffleIdToBlocks = shuffleIdToBlocks;
this.stageAttemptNumber = stageAttemptNumber;
+ this.costTracker = costTracker;
}
public String getAppId() {
@@ -67,6 +71,10 @@ public class RssSendShuffleDataRequest {
return stageAttemptNumber;
}
+ public ShuffleServerPushCostTracker getCostTracker() {
+ return costTracker;
+ }
+
public Map<Integer, Map<Integer, List<ShuffleBlockInfo>>>
getShuffleIdToBlocks() {
return shuffleIdToBlocks;
}
diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto
index 6d156c5a5..0df3e0ffd 100644
--- a/proto/src/main/proto/Rss.proto
+++ b/proto/src/main/proto/Rss.proto
@@ -595,6 +595,11 @@ message ReportShuffleWriteMetricRequest {
message ShuffleWriteMetric {
int64 durationMillis = 1;
int64 byteSize = 2;
+
+ int64 requireBufferFailureNumber = 3;
+ int64 pushFailureNumber = 4;
+
+ string lastPushFailureReason = 5;
}
message ShuffleReadMetric {