This is an automated email from the ASF dual-hosted git repository.
binjieyang pushed a commit to branch CELEBORN-1768
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/CELEBORN-1768 by this push:
new a650218e8 update
a650218e8 is described below
commit a650218e80ce4a613057ce71bdc748f2dfbbafa7
Author: binjie yang <[email protected]>
AuthorDate: Tue Jun 3 20:20:15 2025 +0800
update
---
.../celeborn/ColumnarHashBasedShuffleWriter.java | 2 +-
.../spark/shuffle/celeborn/BasedShuffleWriter.java | 242 ++++++++++++++++++++
.../shuffle/celeborn/HashBasedShuffleWriter.java | 229 +++----------------
.../shuffle/celeborn/SortBasedShuffleWriter.java | 249 ++-------------------
.../shuffle/celeborn/SparkShuffleManager.java | 3 +-
.../celeborn/SortBasedShuffleWriterSuiteJ.java | 35 ++-
6 files changed, 315 insertions(+), 445 deletions(-)
diff --git a/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriter.java b/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriter.java
index b09b1306c..d28673911 100644
--- a/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriter.java
+++ b/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriter.java
@@ -132,7 +132,7 @@ public class ColumnarHashBasedShuffleWriter<K, V, C>
extends HashBasedShuffleWri
}
@Override
- protected void closeWrite() throws IOException {
+ protected void closeWrite() throws IOException, InterruptedException {
if (canUseFastWrite() && isColumnarShuffle) {
closeColumnarWrite();
} else {
diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/BasedShuffleWriter.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/BasedShuffleWriter.java
new file mode 100644
index 000000000..907bea1d1
--- /dev/null
+++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/BasedShuffleWriter.java
@@ -0,0 +1,242 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn;
+
+import java.io.IOException;
+import java.util.concurrent.atomic.LongAdder;
+
+import scala.Option;
+import scala.Product2;
+import scala.collection.Iterator;
+import scala.reflect.ClassTag;
+import scala.reflect.ClassTag$;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.apache.spark.Partitioner;
+import org.apache.spark.ShuffleDependency;
+import org.apache.spark.SparkEnv;
+import org.apache.spark.TaskContext;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.serializer.SerializationStream;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
+import org.apache.spark.shuffle.ShuffleWriter;
+import org.apache.spark.sql.execution.UnsafeRowSerializer;
+import org.apache.spark.storage.BlockManagerId;
+
+import org.apache.celeborn.client.ShuffleClient;
+import org.apache.celeborn.common.CelebornConf;
+
+public abstract class BasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
+
+ protected static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
+ protected static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024;
+
+ protected final int PUSH_BUFFER_INIT_SIZE;
+ protected final int PUSH_BUFFER_MAX_SIZE;
+ protected final ShuffleDependency<K, V, C> dep;
+ protected final Partitioner partitioner;
+ protected final ShuffleWriteMetricsReporter writeMetrics;
+ protected final int shuffleId;
+ protected final int mapId;
+ protected final int encodedAttemptId;
+ protected final TaskContext taskContext;
+ protected final ShuffleClient shuffleClient;
+ protected final int numMappers;
+ protected final int numPartitions;
+ protected final OpenByteArrayOutputStream serBuffer;
+ protected final SerializationStream serOutputStream;
+ private final boolean unsafeRowFastWrite;
+
+ protected final LongAdder[] mapStatusLengths;
+
+ /**
+ * Are we in the process of stopping? Because map tasks can call stop() with success = true and
+ * then call stop() with success = false if they get an exception, we want to make sure we don't
+ * try deleting files, etc. twice.
+ */
+ private volatile boolean stopping = false;
+
+ protected long peakMemoryUsedBytes = 0;
+ protected long tmpRecordsWritten = 0;
+
+ public BasedShuffleWriter(
+ int shuffleId,
+ CelebornShuffleHandle<K, V, C> handle,
+ TaskContext taskContext,
+ CelebornConf conf,
+ ShuffleClient client,
+ ShuffleWriteMetricsReporter metrics) {
+ PUSH_BUFFER_INIT_SIZE = conf.clientPushBufferInitialSize();
+ PUSH_BUFFER_MAX_SIZE = conf.clientPushBufferMaxSize();
+ this.dep = handle.dependency();
+ this.partitioner = dep.partitioner();
+ this.writeMetrics = metrics;
+ this.shuffleId = shuffleId;
+ this.mapId = taskContext.partitionId();
+ // [CELEBORN-1496] using the encoded attempt number instead of task attempt number
+ this.encodedAttemptId = SparkCommonUtils.getEncodedAttemptNumber(taskContext);
+ this.taskContext = taskContext;
+ this.shuffleClient = client;
+ this.numMappers = handle.numMappers();
+ this.numPartitions = dep.partitioner().numPartitions();
+ SerializerInstance serializer = dep.serializer().newInstance();
+ serBuffer = new OpenByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE);
+ serOutputStream = serializer.serializeStream(serBuffer);
+ unsafeRowFastWrite = conf.clientPushUnsafeRowFastWrite();
+
+ mapStatusLengths = new LongAdder[numPartitions];
+ for (int i = 0; i < numPartitions; i++) {
+ mapStatusLengths[i] = new LongAdder();
+ }
+ }
+
+ protected void doWrite(scala.collection.Iterator<Product2<K, V>> records)
+ throws IOException, InterruptedException {
+ if (canUseFastWrite()) {
+ fastWrite0(records);
+ } else if (dep.mapSideCombine()) {
+ if (dep.aggregator().isEmpty()) {
+ throw new UnsupportedOperationException(
+ "When using map side combine, an aggregator must be specified.");
+ }
+ write0(dep.aggregator().get().combineValuesByKey(records, taskContext));
+ } else {
+ write0(records);
+ }
+ }
+
+ @Override
+ public void write(Iterator<Product2<K, V>> records) throws IOException {
+ boolean needCleanupPusher = true;
+ try {
+ doWrite(records);
+ close();
+ needCleanupPusher = false;
+ } catch (InterruptedException e) {
+ TaskInterruptedHelper.throwTaskKillException();
+ } finally {
+ if (needCleanupPusher) {
+ cleanupPusher();
+ }
+ }
+ }
+
+ @Override
+ public Option<MapStatus> stop(boolean success) {
+ try {
+ taskContext.taskMetrics().incPeakExecutionMemory(getPeakMemoryUsedBytes());
+
+ if (stopping) {
+ return Option.empty();
+ } else {
+ stopping = true;
+ if (success) {
+ BlockManagerId bmId = SparkEnv.get().blockManager().shuffleServerId();
+ MapStatus mapStatus =
+ SparkUtils.createMapStatus(
+ bmId, SparkUtils.unwrap(mapStatusLengths), taskContext.taskAttemptId());
+ if (mapStatus == null) {
+ throw new IllegalStateException("Cannot call stop(true) without having called write()");
+ }
+ return Option.apply(mapStatus);
+ } else {
+ return Option.empty();
+ }
+ }
+ } finally {
+ shuffleClient.cleanup(shuffleId, mapId, encodedAttemptId);
+ }
+ }
+
+ // Added in SPARK-32917, for Spark 3.2 and above
+ @SuppressWarnings("MissingOverride")
+ public long[] getPartitionLengths() {
+ throw new UnsupportedOperationException(
+ "Celeborn is not compatible with Spark push mode, please set
spark.shuffle.push.enabled to false");
+ }
+
+ abstract void fastWrite0(scala.collection.Iterator iterator)
+ throws IOException, InterruptedException;
+
+ abstract void write0(scala.collection.Iterator iterator) throws IOException, InterruptedException;
+
+ abstract void updatePeakMemoryUsed();
+
+ abstract void cleanupPusher() throws IOException;
+
+ abstract void closeWrite() throws IOException, InterruptedException;
+
+ @VisibleForTesting
+ boolean canUseFastWrite() {
+ boolean keyIsPartitionId = false;
+ if (unsafeRowFastWrite && dep.serializer() instanceof UnsafeRowSerializer) {
+ // SPARK-39391 renames PartitionIdPassthrough's package
+ String partitionerClassName = partitioner.getClass().getSimpleName();
+ keyIsPartitionId = "PartitionIdPassthrough".equals(partitionerClassName);
+ }
+ return keyIsPartitionId;
+ }
+
+ /** Return the peak memory used so far, in bytes. */
+ public long getPeakMemoryUsedBytes() {
+ updatePeakMemoryUsed();
+ return peakMemoryUsedBytes;
+ }
+
+ protected void pushGiantRecord(int partitionId, byte[] buffer, int numBytes) throws IOException {
+ long start = System.nanoTime();
+ int bytesWritten =
+ shuffleClient.pushData(
+ shuffleId,
+ mapId,
+ encodedAttemptId,
+ partitionId,
+ buffer,
+ 0,
+ numBytes,
+ numMappers,
+ numPartitions);
+ long delta = System.nanoTime() - start;
+ mapStatusLengths[partitionId].add(bytesWritten);
+ writeMetrics.incBytesWritten(bytesWritten);
+ writeMetrics.incWriteTime(delta);
+ }
+
+ /**
+ * Pushes the remaining data and closes the pushers. Importantly, it sends the Mapper End RPC to
+ * the LifecycleManager to update the attempt of the corresponding task, so it should only be
+ * called when the task has completed successfully.
+ */
+ protected void close() throws IOException, InterruptedException {
+ long pushMergedDataTime = System.nanoTime();
+ closeWrite();
+ shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
+ writeMetrics.incWriteTime(System.nanoTime() - pushMergedDataTime);
+ updateRecordsWrittenMetrics();
+
+ long waitStartTime = System.nanoTime();
+ shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
+ writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
+ }
+
+ protected void updateRecordsWrittenMetrics() {
+ writeMetrics.incRecordsWritten(tmpRecordsWritten);
+ tmpRecordsWritten = 0;
+ }
+}
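
The new abstract class turns write() into a template method: doWrite() dispatches to fastWrite0() or write0(), close() pushes the merged data and sends the mapper-end RPC, and stop() builds the MapStatus from mapStatusLengths. A concrete writer only supplies the five abstract hooks. The sketch below is illustrative only; the ExampleShuffleWriter class and its empty hook bodies are hypothetical, not part of this commit, and assume the classes added above are on the classpath.

package org.apache.spark.shuffle.celeborn;

import java.io.IOException;

import org.apache.spark.TaskContext;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;

import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.common.CelebornConf;

// Hypothetical subclass showing the hooks BasedShuffleWriter leaves abstract.
public class ExampleShuffleWriter<K, V, C> extends BasedShuffleWriter<K, V, C> {

  public ExampleShuffleWriter(
      int shuffleId,
      CelebornShuffleHandle<K, V, C> handle,
      TaskContext taskContext,
      CelebornConf conf,
      ShuffleClient client,
      ShuffleWriteMetricsReporter metrics) {
    super(shuffleId, handle, taskContext, conf, client, metrics);
  }

  @Override
  void fastWrite0(scala.collection.Iterator iterator)
      throws IOException, InterruptedException {
    // Serialize UnsafeRow records straight into push buffers (fast path).
  }

  @Override
  void write0(scala.collection.Iterator iterator) throws IOException, InterruptedException {
    // Serialize arbitrary records through serBuffer/serOutputStream (general path).
  }

  @Override
  void updatePeakMemoryUsed() {
    // Update peakMemoryUsedBytes if this writer tracks its own buffers.
  }

  @Override
  void cleanupPusher() throws IOException {
    // Release pusher resources when write() fails before close() runs.
  }

  @Override
  void closeWrite() throws IOException, InterruptedException {
    // Flush per-partition buffers; the base close() then pushes merged data and ends the mapper.
  }
}
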
diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
index c423a97ce..e490fe4c2 100644
--- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
+++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
@@ -19,85 +19,31 @@ package org.apache.spark.shuffle.celeborn;
import java.io.IOException;
import java.util.concurrent.LinkedBlockingQueue;
-import java.util.concurrent.atomic.LongAdder;
-
-import javax.annotation.Nullable;
-
-import scala.Option;
-import scala.Product2;
-import scala.reflect.ClassTag;
-import scala.reflect.ClassTag$;
-
-import com.google.common.annotations.VisibleForTesting;
-import org.apache.spark.Partitioner;
-import org.apache.spark.ShuffleDependency;
-import org.apache.spark.SparkEnv;
+import org.apache.celeborn.client.ShuffleClient;
+import org.apache.celeborn.client.write.DataPusher;
+import org.apache.celeborn.client.write.PushTask;
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.util.Utils;
import org.apache.spark.TaskContext;
import org.apache.spark.annotation.Private;
-import org.apache.spark.scheduler.MapStatus;
-import org.apache.spark.serializer.SerializationStream;
-import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
-import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.execution.UnsafeRowSerializer;
import org.apache.spark.sql.execution.metric.SQLMetric;
-import org.apache.spark.storage.BlockManagerId;
import org.apache.spark.unsafe.Platform;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-
-import org.apache.celeborn.client.ShuffleClient;
-import org.apache.celeborn.client.write.DataPusher;
-import org.apache.celeborn.client.write.PushTask;
-import org.apache.celeborn.common.CelebornConf;
-import org.apache.celeborn.common.util.Utils;
+import scala.Product2;
@Private
-public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
+public class HashBasedShuffleWriter<K, V, C> extends BasedShuffleWriter<K, V,
C> {
private static final Logger logger =
LoggerFactory.getLogger(HashBasedShuffleWriter.class);
- private static final ClassTag<Object> OBJECT_CLASS_TAG =
ClassTag$.MODULE$.Object();
- private static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024;
-
- private final int PUSH_BUFFER_INIT_SIZE;
- private final int PUSH_BUFFER_MAX_SIZE;
- private final ShuffleDependency<K, V, C> dep;
- private final Partitioner partitioner;
- private final ShuffleWriteMetricsReporter writeMetrics;
- private final int shuffleId;
- private final int mapId;
- private final int encodedAttemptId;
- private final TaskContext taskContext;
- private final ShuffleClient shuffleClient;
- private final int numMappers;
- private final int numPartitions;
-
- @Nullable private MapStatus mapStatus;
- private long peakMemoryUsedBytes = 0;
-
- private final OpenByteArrayOutputStream serBuffer;
- private final SerializationStream serOutputStream;
-
private byte[][] sendBuffers;
private int[] sendOffsets;
-
- private final LongAdder[] mapStatusLengths;
- protected long tmpRecordsWritten = 0;
-
- private final SendBufferPool sendBufferPool;
-
- /**
- * Are we in the process of stopping? Because map tasks can call stop() with
success = true and
- * then call stop() with success = false if they get an exception, we want
to make sure we don't
- * try deleting files, etc. twice.
- */
- private volatile boolean stopping = false;
-
private DataPusher dataPusher;
-
- private final boolean unsafeRowFastWrite;
+ private final SendBufferPool sendBufferPool;
// In order to facilitate the writing of unit test code, ShuffleClient needs
to be passed in as
// parameters. By the way, simplify the passed parameters.
@@ -110,31 +56,9 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
ShuffleWriteMetricsReporter metrics,
SendBufferPool sendBufferPool)
throws IOException {
- this.mapId = taskContext.partitionId();
- this.dep = handle.dependency();
- this.shuffleId = shuffleId;
- this.encodedAttemptId =
SparkCommonUtils.getEncodedAttemptNumber(taskContext);
- SerializerInstance serializer = dep.serializer().newInstance();
- this.partitioner = dep.partitioner();
- this.writeMetrics = metrics;
- this.taskContext = taskContext;
- this.numMappers = handle.numMappers();
- this.numPartitions = dep.partitioner().numPartitions();
- this.shuffleClient = client;
-
- unsafeRowFastWrite = conf.clientPushUnsafeRowFastWrite();
- serBuffer = new OpenByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE);
- serOutputStream = serializer.serializeStream(serBuffer);
-
- mapStatusLengths = new LongAdder[numPartitions];
- for (int i = 0; i < numPartitions; i++) {
- mapStatusLengths[i] = new LongAdder();
- }
-
- PUSH_BUFFER_INIT_SIZE = conf.clientPushBufferInitialSize();
- PUSH_BUFFER_MAX_SIZE = conf.clientPushBufferMaxSize();
-
+ super(shuffleId, handle, taskContext, conf, client, metrics);
this.sendBufferPool = sendBufferPool;
+
sendBuffers = sendBufferPool.acquireBuffer(numPartitions);
sendOffsets = new int[numPartitions];
@@ -159,42 +83,6 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
}
@Override
- public void write(scala.collection.Iterator<Product2<K, V>> records) throws
IOException {
- boolean needCleanupPusher = true;
- try {
- if (canUseFastWrite()) {
- fastWrite0(records);
- } else if (dep.mapSideCombine()) {
- if (dep.aggregator().isEmpty()) {
- throw new UnsupportedOperationException(
- "When using map side combine, an aggregator must be specified.");
- }
- write0(dep.aggregator().get().combineValuesByKey(records,
taskContext));
- } else {
- write0(records);
- }
- close();
- needCleanupPusher = false;
- } catch (InterruptedException e) {
- TaskInterruptedHelper.throwTaskKillException();
- } finally {
- if (needCleanupPusher) {
- cleanupPusher();
- }
- }
- }
-
- @VisibleForTesting
- boolean canUseFastWrite() {
- boolean keyIsPartitionId = false;
- if (unsafeRowFastWrite && dep.serializer() instanceof UnsafeRowSerializer)
{
- // SPARK-39391 renames PartitionIdPassthrough's package
- String partitionerClassName = partitioner.getClass().getSimpleName();
- keyIsPartitionId = "PartitionIdPassthrough".equals(partitionerClassName);
- }
- return keyIsPartitionId;
- }
-
protected void fastWrite0(scala.collection.Iterator iterator)
throws IOException, InterruptedException {
final scala.collection.Iterator<Product2<Integer, UnsafeRow>> records =
iterator;
@@ -238,7 +126,9 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
}
}
- private void write0(scala.collection.Iterator iterator) throws IOException,
InterruptedException {
+ @Override
+ protected void write0(scala.collection.Iterator iterator)
+ throws IOException, InterruptedException {
final scala.collection.Iterator<Product2<K, ?>> records = iterator;
while (records.hasNext()) {
@@ -265,6 +155,11 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
}
}
+ @Override
+ void updatePeakMemoryUsed() {
+ // do nothing; the hash-based shuffle writer does not track peak memory usage
+ }
+
private byte[] getOrCreateBuffer(int partitionId) {
byte[] buffer = sendBuffers[partitionId];
if (buffer == null) {
@@ -275,26 +170,6 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
return buffer;
}
- protected void pushGiantRecord(int partitionId, byte[] buffer, int numBytes)
throws IOException {
- logger.debug("Push giant record, size {}.", numBytes);
- long start = System.nanoTime();
- int bytesWritten =
- shuffleClient.pushData(
- shuffleId,
- mapId,
- encodedAttemptId,
- partitionId,
- buffer,
- 0,
- numBytes,
- numMappers,
- numPartitions);
- long delta = System.nanoTime() - start;
- mapStatusLengths[partitionId].add(bytesWritten);
- writeMetrics.incBytesWritten(bytesWritten);
- writeMetrics.incWriteTime(delta);
- }
-
private int getOrUpdateOffset(int partitionId, int serializedRecordSize)
throws IOException, InterruptedException {
int offset = sendOffsets[partitionId];
@@ -325,7 +200,12 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
writeMetrics.incWriteTime(System.nanoTime() - start);
}
- protected void closeWrite() throws IOException {
+ @Override
+ protected void closeWrite() throws IOException, InterruptedException {
+ // wait for all in-flight batches sent by the dataPusher thread to return
+ dataPusher.waitOnTermination();
+ sendBufferPool.returnPushTaskQueue(dataPusher.getAndResetIdleQueue());
+ shuffleClient.prepareForMergeData(shuffleId, mapId, encodedAttemptId);
// merge and push residual data to reduce network traffic
// NB: since dataPusher thread have no in-flight data at this point,
// we now push merged data by task thread will not introduce any
contention
@@ -333,6 +213,8 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
final int size = sendOffsets[i];
if (size > 0) {
mergeData(i, sendBuffers[i], 0, size);
+ // free buffer
+ sendBuffers[i] = null;
}
}
sendBufferPool.returnBuffer(sendBuffers);
@@ -357,7 +239,8 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
writeMetrics.incBytesWritten(bytesWritten);
}
- private void cleanupPusher() throws IOException {
+ @Override
+ protected void cleanupPusher() throws IOException {
try {
dataPusher.waitOnTermination();
sendBufferPool.returnPushTaskQueue(dataPusher.getAndResetIdleQueue());
@@ -365,60 +248,4 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
TaskInterruptedHelper.throwTaskKillException();
}
}
-
- private void close() throws IOException, InterruptedException {
- // here we wait for all the in-flight batches to return which sent by
dataPusher thread
- long pushMergedDataTime = System.nanoTime();
- dataPusher.waitOnTermination();
- sendBufferPool.returnPushTaskQueue(dataPusher.getAndResetIdleQueue());
- shuffleClient.prepareForMergeData(shuffleId, mapId, encodedAttemptId);
- closeWrite();
- shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
- writeMetrics.incWriteTime(System.nanoTime() - pushMergedDataTime);
- updateRecordsWrittenMetrics();
-
- long waitStartTime = System.nanoTime();
- shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
- writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
-
- BlockManagerId bmId = SparkEnv.get().blockManager().shuffleServerId();
- mapStatus =
- SparkUtils.createMapStatus(
- bmId, SparkUtils.unwrap(mapStatusLengths),
taskContext.taskAttemptId());
- }
-
- private void updateRecordsWrittenMetrics() {
- writeMetrics.incRecordsWritten(tmpRecordsWritten);
- tmpRecordsWritten = 0;
- }
-
- @Override
- public Option<MapStatus> stop(boolean success) {
- try {
- taskContext.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes);
-
- if (stopping) {
- return Option.empty();
- } else {
- stopping = true;
- if (success) {
- if (mapStatus == null) {
- throw new IllegalStateException("Cannot call stop(true) without
having called write()");
- }
- return Option.apply(mapStatus);
- } else {
- return Option.empty();
- }
- }
- } finally {
- shuffleClient.cleanup(shuffleId, mapId, encodedAttemptId);
- }
- }
-
- // Added in SPARK-32917, for Spark 3.2 and above
- @SuppressWarnings("MissingOverride")
- public long[] getPartitionLengths() {
- throw new UnsupportedOperationException(
- "Celeborn is not compatible with Spark push mode, please set
spark.shuffle.push.enabled to false");
- }
}
diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
index 3346deb2a..b6fd2f407 100644
--- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
+++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
@@ -18,28 +18,15 @@
package org.apache.spark.shuffle.celeborn;
import java.io.IOException;
-import java.util.concurrent.atomic.LongAdder;
-import scala.Option;
import scala.Product2;
-import scala.reflect.ClassTag;
-import scala.reflect.ClassTag$;
-import com.google.common.annotations.VisibleForTesting;
-import org.apache.spark.Partitioner;
-import org.apache.spark.ShuffleDependency;
-import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
import org.apache.spark.annotation.Private;
-import org.apache.spark.scheduler.MapStatus;
-import org.apache.spark.serializer.SerializationStream;
-import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
-import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.execution.UnsafeRowSerializer;
import org.apache.spark.sql.execution.metric.SQLMetric;
-import org.apache.spark.storage.BlockManagerId;
import org.apache.spark.unsafe.Platform;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -47,96 +34,40 @@ import org.slf4j.LoggerFactory;
import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.exception.CelebornIOException;
-import org.apache.celeborn.common.util.Utils;
@Private
-public class SortBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
+public class SortBasedShuffleWriter<K, V, C> extends BasedShuffleWriter<K, V,
C> {
private static final Logger logger =
LoggerFactory.getLogger(SortBasedShuffleWriter.class);
-
- private static final ClassTag<Object> OBJECT_CLASS_TAG =
ClassTag$.MODULE$.Object();
- private static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024;
-
- private final ShuffleDependency<K, V, C> dep;
- private final Partitioner partitioner;
- private final ShuffleWriteMetricsReporter writeMetrics;
- private final int shuffleId;
- private final int mapId;
- private final int encodedAttemptId;
- private final TaskContext taskContext;
- private final ShuffleClient shuffleClient;
- private final int numMappers;
- private final int numPartitions;
-
- private final long pushBufferMaxSize;
+ private final SendBufferPool sendBufferPool;
private final SortBasedPusher pusher;
- private long peakMemoryUsedBytes = 0;
-
- private final OpenByteArrayOutputStream serBuffer;
- private final SerializationStream serOutputStream;
-
- private final LongAdder[] mapStatusLengths;
- private long tmpRecordsWritten = 0;
-
- /**
- * Are we in the process of stopping? Because map tasks can call stop() with
success = true and
- * then call stop() with success = false if they get an exception, we want
to make sure we don't
- * try deleting files, etc. twice.
- */
- private volatile boolean stopping = false;
-
- private final boolean unsafeRowFastWrite;
public SortBasedShuffleWriter(
int shuffleId,
- ShuffleDependency<K, V, C> dep,
- int numMappers,
+ CelebornShuffleHandle<K, V, C> handle,
TaskContext taskContext,
CelebornConf conf,
ShuffleClient client,
ShuffleWriteMetricsReporter metrics,
SendBufferPool sendBufferPool)
throws IOException {
- this(shuffleId, dep, numMappers, taskContext, conf, client, metrics,
sendBufferPool, null);
+ this(shuffleId, handle, taskContext, conf, client, metrics,
sendBufferPool, null);
}
// In order to facilitate the writing of unit test code, ShuffleClient needs
to be passed in as
// parameters. By the way, simplify the passed parameters.
public SortBasedShuffleWriter(
int shuffleId,
- ShuffleDependency<K, V, C> dep,
- int numMappers,
+ CelebornShuffleHandle<K, V, C> handle,
TaskContext taskContext,
CelebornConf conf,
ShuffleClient client,
ShuffleWriteMetricsReporter metrics,
SendBufferPool sendBufferPool,
- SortBasedPusher pusher)
- throws IOException {
- this.mapId = taskContext.partitionId();
- this.dep = dep;
- this.shuffleId = shuffleId;
- this.encodedAttemptId =
SparkCommonUtils.getEncodedAttemptNumber(taskContext);
- SerializerInstance serializer = dep.serializer().newInstance();
- this.partitioner = dep.partitioner();
- this.writeMetrics = metrics;
- this.taskContext = taskContext;
- this.numMappers = numMappers;
- this.numPartitions = dep.partitioner().numPartitions();
- this.shuffleClient = client;
- unsafeRowFastWrite = conf.clientPushUnsafeRowFastWrite();
-
- serBuffer = new OpenByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE);
- serOutputStream = serializer.serializeStream(serBuffer);
-
- this.mapStatusLengths = new LongAdder[numPartitions];
- for (int i = 0; i < numPartitions; i++) {
- this.mapStatusLengths[i] = new LongAdder();
- }
-
- pushBufferMaxSize = conf.clientPushBufferMaxSize();
-
+ SortBasedPusher pusher) {
+ super(shuffleId, handle, taskContext, conf, client, metrics);
+ this.sendBufferPool = sendBufferPool;
if (pusher == null) {
this.pusher =
new SortBasedPusher(
@@ -159,99 +90,16 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
}
}
- public SortBasedShuffleWriter(
- CelebornShuffleHandle<K, V, C> handle,
- TaskContext taskContext,
- CelebornConf conf,
- ShuffleClient client,
- ShuffleWriteMetricsReporter metrics,
- SendBufferPool sendBufferPool)
- throws IOException {
- this(
- SparkUtils.celebornShuffleId(client, handle, taskContext, true),
- handle.dependency(),
- handle.numMappers(),
- taskContext,
- conf,
- client,
- metrics,
- sendBufferPool);
- }
-
- public SortBasedShuffleWriter(
- CelebornShuffleHandle<K, V, C> handle,
- TaskContext taskContext,
- CelebornConf conf,
- ShuffleClient client,
- ShuffleWriteMetricsReporter metrics,
- SendBufferPool sendBufferPool,
- SortBasedPusher pusher)
- throws IOException {
- this(
- SparkUtils.celebornShuffleId(client, handle, taskContext, true),
- handle.dependency(),
- handle.numMappers(),
- taskContext,
- conf,
- client,
- metrics,
- sendBufferPool,
- pusher);
- }
-
- private void updatePeakMemoryUsed() {
+ @Override
+ protected void updatePeakMemoryUsed() {
long mem = pusher.getPeakMemoryUsedBytes();
if (mem > peakMemoryUsedBytes) {
peakMemoryUsedBytes = mem;
}
}
- /** Return the peak memory used so far, in bytes. */
- public long getPeakMemoryUsedBytes() {
- updatePeakMemoryUsed();
- return peakMemoryUsedBytes;
- }
-
- void doWrite(scala.collection.Iterator<Product2<K, V>> records) throws
IOException {
- if (canUseFastWrite()) {
- fastWrite0(records);
- } else if (dep.mapSideCombine()) {
- if (dep.aggregator().isEmpty()) {
- throw new UnsupportedOperationException(
- "When using map side combine, an aggregator must be specified.");
- }
- write0(dep.aggregator().get().combineValuesByKey(records, taskContext));
- } else {
- write0(records);
- }
- }
-
@Override
- public void write(scala.collection.Iterator<Product2<K, V>> records) throws
IOException {
- boolean needCleanupPusher = true;
- try {
- doWrite(records);
- close();
- needCleanupPusher = false;
- } finally {
- if (needCleanupPusher) {
- cleanupPusher();
- }
- }
- }
-
- @VisibleForTesting
- boolean canUseFastWrite() {
- boolean keyIsPartitionId = false;
- if (unsafeRowFastWrite && dep.serializer() instanceof UnsafeRowSerializer)
{
- // SPARK-39391 renames PartitionIdPassthrough's package
- String partitionerClassName = partitioner.getClass().getSimpleName();
- keyIsPartitionId = "PartitionIdPassthrough".equals(partitionerClassName);
- }
- return keyIsPartitionId;
- }
-
- private void fastWrite0(scala.collection.Iterator iterator) throws
IOException {
+ protected void fastWrite0(scala.collection.Iterator iterator) throws
IOException {
final scala.collection.Iterator<Product2<Integer, UnsafeRow>> records =
iterator;
SQLMetric dataSize = SparkUtils.getDataSize((UnsafeRowSerializer)
dep.serializer());
@@ -267,7 +115,7 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
dataSize.add(serializedRecordSize);
}
- if (serializedRecordSize > pushBufferMaxSize) {
+ if (serializedRecordSize > PUSH_BUFFER_MAX_SIZE) {
byte[] giantBuffer = new byte[serializedRecordSize];
Platform.putInt(giantBuffer, Platform.BYTE_ARRAY_OFFSET,
Integer.reverseBytes(rowSize));
Platform.copyMemory(
@@ -301,7 +149,8 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
writeMetrics.incWriteTime(System.nanoTime() - start);
}
- private void write0(scala.collection.Iterator iterator) throws IOException {
+ @Override
+ protected void write0(scala.collection.Iterator iterator) throws IOException
{
final scala.collection.Iterator<Product2<K, ?>> records = iterator;
while (records.hasNext()) {
@@ -316,7 +165,7 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
final int serializedRecordSize = serBuffer.size();
assert (serializedRecordSize > 0);
- if (serializedRecordSize > pushBufferMaxSize) {
+ if (serializedRecordSize > PUSH_BUFFER_MAX_SIZE) {
pushGiantRecord(partitionId, serBuffer.getBuf(), serializedRecordSize);
} else {
boolean success =
@@ -344,78 +193,18 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
}
}
- private void pushGiantRecord(int partitionId, byte[] buffer, int numBytes)
throws IOException {
- logger.debug("Push giant record, size {}.", Utils.bytesToString(numBytes));
- long start = System.nanoTime();
- int bytesWritten =
- shuffleClient.pushData(
- shuffleId,
- mapId,
- encodedAttemptId,
- partitionId,
- buffer,
- 0,
- numBytes,
- numMappers,
- numPartitions);
- long delta = System.nanoTime() - start;
- mapStatusLengths[partitionId].add(bytesWritten);
- writeMetrics.incBytesWritten(bytesWritten);
- writeMetrics.incWriteTime(delta);
- }
-
- private void cleanupPusher() throws IOException {
+ @Override
+ protected void cleanupPusher() throws IOException {
if (pusher != null) {
pusher.close(false);
}
}
- private void close() throws IOException {
- logger.info("Memory used {}", Utils.bytesToString(pusher.getUsed()));
+ @Override
+ protected void closeWrite() throws IOException, InterruptedException {
long pushStartTime = System.nanoTime();
pusher.pushData(false);
pusher.close(true);
-
- shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);
- writeMetrics.incRecordsWritten(tmpRecordsWritten);
-
- long waitStartTime = System.nanoTime();
- shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
- writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
- }
-
- @Override
- public Option<MapStatus> stop(boolean success) {
- try {
-
taskContext.taskMetrics().incPeakExecutionMemory(getPeakMemoryUsedBytes());
-
- if (stopping) {
- return Option.empty();
- } else {
- stopping = true;
- if (success) {
- BlockManagerId bmId =
SparkEnv.get().blockManager().shuffleServerId();
- MapStatus mapStatus =
- SparkUtils.createMapStatus(
- bmId, SparkUtils.unwrap(mapStatusLengths),
taskContext.taskAttemptId());
- if (mapStatus == null) {
- throw new IllegalStateException("Cannot call stop(true) without
having called write()");
- }
- return Option.apply(mapStatus);
- } else {
- return Option.empty();
- }
- }
- } finally {
- shuffleClient.cleanup(shuffleId, mapId, encodedAttemptId);
- }
- }
-
- // Added in SPARK-32917, for Spark 3.2 and above
- @SuppressWarnings("MissingOverride")
- public long[] getPartitionLengths() {
- throw new UnsupportedOperationException(
- "Celeborn is not compatible with push-based shuffle, please set
spark.shuffle.push.enabled to false");
}
}
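
With the handle passed straight through, callers now resolve the Celeborn shuffle id themselves instead of unpacking the dependency and mapper count, as the SparkShuffleManager and test changes below show. A minimal sketch of the new call shape, assuming the helpers used by the test; the SortWriterFactory wrapper itself is hypothetical and not part of this commit.

package org.apache.spark.shuffle.celeborn;

import java.io.IOException;

import org.apache.spark.TaskContext;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;

import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.common.CelebornConf;

// Hypothetical wrapper showing how a SortBasedShuffleWriter is constructed after this change.
final class SortWriterFactory {

  static <K, V, C> SortBasedShuffleWriter<K, V, C> create(
      CelebornShuffleHandle<K, V, C> handle,
      TaskContext context,
      CelebornConf conf,
      ShuffleClient client,
      ShuffleWriteMetricsReporter metrics)
      throws IOException {
    // The caller resolves the shuffle id and hands the handle to the writer unchanged.
    int shuffleId = SparkUtils.celebornShuffleId(client, handle, context, true);
    return new SortBasedShuffleWriter<>(
        shuffleId, handle, context, conf, client, metrics, SendBufferPool.get(4, 30, 60));
  }

  private SortWriterFactory() {}
}
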
diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index 80ea5c256..43b8d8635 100644
--- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -320,8 +320,7 @@ public class SparkShuffleManager implements ShuffleManager {
if (ShuffleMode.SORT.equals(shuffleMode)) {
return new SortBasedShuffleWriter<>(
shuffleId,
- h.dependency(),
- h.numMappers(),
+ h,
context,
celebornConf,
shuffleClient,
diff --git a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java
index 0963737c0..41fe873a1 100644
--- a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java
+++ b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java
@@ -64,7 +64,13 @@ public class SortBasedShuffleWriterSuiteJ extends
CelebornShuffleWriterSuiteBase
ShuffleWriteMetricsReporter metrics)
throws IOException {
return new SortBasedShuffleWriter<Integer, String, String>(
- handle, context, conf, client, metrics, SendBufferPool.get(4, 30, 60));
+ SparkUtils.celebornShuffleId(client, handle, taskContext, true),
+ handle,
+ context,
+ conf,
+ client,
+ metrics,
+ SendBufferPool.get(4, 30, 60));
}
private SortBasedShuffleWriter<Integer, String, String>
createShuffleWriterWithPusher(
@@ -76,7 +82,14 @@ public class SortBasedShuffleWriterSuiteJ extends
CelebornShuffleWriterSuiteBase
SortBasedPusher pusher)
throws Exception {
return new SortBasedShuffleWriter<Integer, String, String>(
- handle, context, conf, client, metrics, SendBufferPool.get(4, 30, 60),
pusher);
+ SparkUtils.celebornShuffleId(client, handle, taskContext, true),
+ handle,
+ context,
+ conf,
+ client,
+ metrics,
+ SendBufferPool.get(4, 30, 60),
+ pusher);
}
private SortBasedPusher createSortBasedPusher(
@@ -98,18 +111,18 @@ public class SortBasedShuffleWriterSuiteJ extends
CelebornShuffleWriterSuiteBase
SortBasedPusher pusher =
new SortBasedPusher(
taskMemoryManager,
- /*shuffleClient=*/ client,
- /*taskContext=*/ taskContext,
- /*shuffleId=*/ 0,
- /*mapId=*/ 0,
- /*attemptNumber=*/ 0,
- /*taskAttemptId=*/ 0,
- /*numMappers=*/ 0,
- /*numPartitions=*/ numPartitions,
+ /* shuffleClient= */ client,
+ /* taskContext= */ taskContext,
+ /* shuffleId= */ 0,
+ /* mapId= */ 0,
+ /* attemptNumber= */ 0,
+ /* taskAttemptId= */ 0,
+ /* numMappers= */ 0,
+ /* numPartitions= */ numPartitions,
conf,
metricsReporter::incBytesWritten,
mapStatusLengths,
- /*pushSortMemoryThreshold=*/ Utils.byteStringAsBytes("32K"),
+ /* pushSortMemoryThreshold= */ Utils.byteStringAsBytes("32K"),
SendBufferPool.get(4, 30, 60));
return pusher;
}