This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.x by this push:
new 2df345460d22 [SPARK-56674][SS][RTM][STREAMINGSHUFFLE] Add streaming shuffle wire protocol
2df345460d22 is described below
commit 2df345460d224a8cdf395672f264749f09c3b618
Author: Boyang Jerry Peng <[email protected]>
AuthorDate: Fri May 22 10:38:33 2026 -0700
[SPARK-56674][SS][RTM][STREAMINGSHUFFLE] Add streaming shuffle wire protocol
### What changes were proposed in this pull request?
### Context
This is the first PR in a stack that contributes a new **streaming
shuffle** implementation to Apache Spark. Streaming shuffle is a push-based
shuffle designed for low-latency,
continuously-running queries (e.g., Real-Time mode in Structured
Streaming) where the map and reduce stages must run concurrently rather than
sequentially. Each map task hosts a Netty
server that pushes records to reduce tasks as they are produced; reduce
tasks open clients to those servers and consume records as a stream — no
on-disk materialization, no map-stage
barrier.
Because the full implementation spans the network protocol, driver-side
coordination, plugin layer, executor-side writer/reader, engine integration,
and tests, it is split into nine
independently reviewable PRs:
1. **Wire protocol (this PR)** — binary message types in
`network-common`, pure Java.
2. **Output tracker** — driver-side coordination service mapping shuffle
IDs to writer task locations.
3. **Shuffle manager + Netty handlers + logging mixin** — the
`ShuffleManager` plugin entry point, the bidirectional Netty handlers, and a
`TaskContext`-aware logging trait.
4. **SparkEnv + DAGScheduler integration** — wires the tracker into
`SparkEnv` and registers shuffles from `DAGScheduler`.
5. **Writer** — `StreamingShuffleWriter` (server-side push).
6. **Reader** — `StreamingShuffleReader` (client-side pull).
7. **MultiShuffleManager** — routes per-shuffle to streaming or sort
shuffle based on a task-local property, so one cluster can host both.
8. **Tests** — end-to-end suite for the streaming shuffle plugin.
9. **Documentation** — design and configuration reference.
Each PR compiles standalone. The plugin only becomes usable end-to-end
after PRs 1–7.
### Changes in this PR
This PR introduces the binary wire protocol used by the streaming
shuffle, contained in a new package
`org.apache.spark.network.shuffle.streaming` under `common/network-common`. All
classes are pure Java and depend only on Netty (already in
`network-common`) and `java.util.zip.CRC32C` (Java 9+). No new Maven
dependencies are introduced.
The protocol defines four message types, all of which share a 12-byte common
header (4-byte message-type id + 8-byte sequence number). The classes are:
- **`StreamingShuffleMessage`** — abstract sealed base class. Owns a
`ByteBuf`, exposes `setSeqNum`/`getSeqNum`, and dispatches `decode()` to the
concrete subclass based on the
message-type id read from the wire. Documents the refcount/ownership
rules concrete subclasses must follow.
- **`StreamingShuffleMessageType`** — enum of the four message IDs
(`DATA_MESSAGE_UNSAFE_ROW=1`, `CREDIT_CONTROL_MESSAGE=2`,
`TERMINATION_CONTROL_MESSAGE=3`, `TERMINATION_ACK_MESSAGE=4`).
Each enum constant carries an explicit `id()` value that is used on the
wire — not the JVM enum ordinal — so the encoding is stable against changes to
declaration order.
`StreamingShuffleMessageType.decode(int)` throws
`IllegalArgumentException` when the wire id does not match any known message
type.
- **`DataMessage`** — writer → reader. Carries serialized records along
with `(shuffleWriterId, shuffleReaderId, dataSize, CRC32C checksum)` in the
header. The constructor validates
`dataSize == data.readableBytes()`; `encode()` writes exactly `dataSize`
bytes of payload (via `data.retainedSlice(data.readerIndex(), dataSize)`);
`decode()` validates that the
post-header readable bytes equal `dataSize` exactly, rejecting both
truncated and trailing-junk frames at the wire boundary.
- **`CreditControlMessage`** — reader → writer. Sent on connection
establishment as the initial handshake; reserved for finer-grained credit-based
flow control in future revisions.
- **`TerminationControlMessage`** — writer → reader. End-of-stream signal.
- **`TerminationAckMessage`** — reader → writer. Echoes the last sequence
number the reader observed, so the writer can verify no messages were lost or
reordered.
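The sequence-number check behind `TerminationAckMessage` can be sketched as follows. This is a minimal illustration under stated assumptions, not code from the stack: it assumes sequence numbers start at 1 and increase by one for every writer → reader message (this PR defines the field but not the starting value), and the classes `SeqNumSketch`, `WriterSide`, and `ReaderSide` are hypothetical.

```java
import java.util.HashMap;
import java.util.Map;

public class SeqNumSketch {

  /** Writer side (hypothetical): last sequence number sent to each reader. */
  public static final class WriterSide {
    private final Map<Integer, Long> lastSent = new HashMap<>();

    /** Assigns the next sequence number for a message addressed to readerId. */
    public long nextSeqNum(int readerId) {
      // Assumption: the first message to each reader gets seqNum 1.
      long next = lastSent.getOrDefault(readerId, 0L) + 1;
      lastSent.put(readerId, next);
      return next;
    }

    /** True if the seqNum echoed in a TerminationAckMessage matches the last one sent. */
    public boolean ackMatches(int readerId, long ackedSeqNum) {
      return lastSent.getOrDefault(readerId, 0L) == ackedSeqNum;
    }
  }

  /** Reader side (hypothetical): last sequence number received from each writer. */
  public static final class ReaderSide {
    private final Map<Integer, Long> lastReceived = new HashMap<>();

    /** Accepts a message only if it carries exactly the next expected seqNum. */
    public void onMessage(int writerId, long seqNum) {
      long expected = lastReceived.getOrDefault(writerId, 0L) + 1;
      if (seqNum != expected) {
        throw new IllegalStateException(
            "writer " + writerId + ": expected seqNum " + expected + " but got " + seqNum);
      }
      lastReceived.put(writerId, seqNum);
    }

    /** The value a reader would echo back in its TerminationAckMessage. */
    public long lastSeqNum(int writerId) {
      return lastReceived.getOrDefault(writerId, 0L);
    }
  }

  public static void main(String[] args) {
    WriterSide writer = new WriterSide();
    ReaderSide reader = new ReaderSide();
    int writerId = 7;
    int readerId = 3;

    // Two data messages plus the termination message, all delivered in order.
    for (int i = 0; i < 3; i++) {
      reader.onMessage(writerId, writer.nextSeqNum(readerId));
    }
    System.out.println(writer.ackMatches(readerId, reader.lastSeqNum(writerId))); // true

    // A gap (for example, a lost frame) is detected as soon as the next message arrives.
    try {
      reader.onMessage(writerId, 5);
    } catch (IllegalStateException e) {
      System.out.println("rejected: " + e.getMessage());
    }
  }
}
```

The reader catches gaps and reordering immediately. The final ack comparison catches frames that were lost after the reader's last check, such as a missing tail.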
Supporting class:
- **`ShuffleChecksum`** — stateful CRC32C helper that supports both heap
and direct `ByteBuf`s, used by the writer to stamp checksums into outgoing
`DataMessage`s and by the reader to
verify them. Range arguments are validated against the buffer's
`writerIndex()` so the checksum cannot accidentally cover bytes in
`[writerIndex, capacity)` that have not been written.
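The heap/direct equivalence and the `writerIndex()` bound can be shown with a JDK-only sketch. It substitutes `java.nio.ByteBuffer` for Netty's `ByteBuf` so it runs without Netty. The `written` parameter stands in for `writerIndex()`, and the class name `Crc32cSketch` is hypothetical.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.zip.CRC32C;

public class Crc32cSketch {

  /** CRC32C of bytes[start, start + length) read from a heap array. */
  public static long heapChecksum(byte[] bytes, int start, int length) {
    CRC32C crc = new CRC32C();
    crc.update(bytes, start, length);
    return crc.getValue();
  }

  /**
   * CRC32C of [start, start + length) read from a direct buffer. `written` plays the role
   * of ByteBuf.writerIndex(): bytes at or beyond it are never covered.
   */
  public static long directChecksum(ByteBuffer direct, int written, int start, int length) {
    if (start < 0 || length < 0 || length > written - start) {
      throw new IllegalArgumentException(
          "range [" + start + ", " + start + " + " + length + ") exceeds written bytes " + written);
    }
    // Work on a duplicate so the caller's position and limit are left untouched.
    ByteBuffer view = direct.duplicate();
    view.position(start);
    view.limit(start + length);
    CRC32C crc = new CRC32C();
    crc.update(view); // consumes exactly view.remaining() == length bytes
    return crc.getValue();
  }

  public static void main(String[] args) {
    byte[] payload = "streaming-shuffle-payload".getBytes(StandardCharsets.UTF_8);

    // Direct buffer with spare capacity past the written bytes.
    ByteBuffer direct = ByteBuffer.allocateDirect(payload.length + 16);
    direct.put(payload);
    int written = direct.position();

    // Same sub-range (non-zero start), heap vs direct: the values match.
    System.out.println(heapChecksum(payload, 4, 10) == directChecksum(direct, written, 4, 10)); // true

    // A range that reaches into unwritten capacity is rejected.
    try {
      directChecksum(direct, written, written - 2, 4);
    } catch (IllegalArgumentException e) {
      System.out.println("rejected: " + e.getMessage());
    }
  }
}
```

The bound is written as `length > written - start` so that the comparison cannot overflow `int`.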
Tests in `StreamingShuffleMessageSuite` cover:
- Encode/decode round-trips for every message type
- CRC32C determinism and equivalence across heap and direct buffers;
sub-range checksums with non-zero `startIndex` compared against a reference
computation; out-of-bounds,
beyond-`writerIndex`, and negative-argument rejection
- `DataMessage` framing: rejection of `dataSize` larger/smaller than
`data.readableBytes()`, negative `dataSize`, and decoded frames with trailing
junk or truncation
- Unknown wire message-type ids → `IllegalArgumentException`
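The `DataMessage` framing rules tested above can also be sketched with the JDK alone. This is an illustration, not the PR's `DataMessage`. It uses `java.nio.ByteBuffer` (big-endian by default, like Netty's `writeInt`/`writeLong`) instead of `ByteBuf` and only handles the `DATA_MESSAGE_UNSAFE_ROW` id. It also verifies the CRC32C inline, whereas in this PR `DataMessage.decode()` checks framing only and leaves checksum verification to the reader. `DataFrameSketch` is a hypothetical name.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.zip.CRC32C;

public class DataFrameSketch {
  // Wire id of DATA_MESSAGE_UNSAFE_ROW in StreamingShuffleMessageType.
  static final int DATA_MESSAGE_UNSAFE_ROW = 1;
  // 12-byte common header (type id + seqNum) + 20-byte DataMessage header.
  static final int HEADER_BYTES = 12 + 20;

  /** Encodes [typeId:4][seqNum:8][writerId:4][readerId:4][dataSize:4][crc32c:8][payload]. */
  public static ByteBuffer encode(long seqNum, int writerId, int readerId, byte[] payload) {
    CRC32C crc = new CRC32C();
    crc.update(payload, 0, payload.length);
    ByteBuffer frame = ByteBuffer.allocate(HEADER_BYTES + payload.length);
    frame.putInt(DATA_MESSAGE_UNSAFE_ROW)
        .putLong(seqNum)
        .putInt(writerId)
        .putInt(readerId)
        .putInt(payload.length)
        .putLong(crc.getValue())
        .put(payload);
    frame.flip(); // switch the buffer from writing to reading
    return frame;
  }

  /** Decodes one frame, rejecting truncated frames and frames with trailing bytes. */
  public static byte[] decode(ByteBuffer frame) {
    int typeId = frame.getInt();
    if (typeId != DATA_MESSAGE_UNSAFE_ROW) {
      throw new IllegalArgumentException("Unknown message type: " + typeId);
    }
    long seqNum = frame.getLong();
    int writerId = frame.getInt();
    int readerId = frame.getInt();
    int dataSize = frame.getInt();
    long checksum = frame.getLong();
    // Strict equality: fewer bytes means truncation, more means trailing junk.
    if (dataSize < 0 || dataSize != frame.remaining()) {
      throw new IllegalArgumentException("Invalid dataSize=" + dataSize
          + ", readable bytes after header=" + frame.remaining());
    }
    byte[] payload = new byte[dataSize];
    frame.get(payload);
    CRC32C crc = new CRC32C();
    crc.update(payload, 0, payload.length);
    if (crc.getValue() != checksum) {
      throw new IllegalStateException("Checksum mismatch for seqNum=" + seqNum
          + " (writer " + writerId + " -> reader " + readerId + ")");
    }
    return payload;
  }

  public static void main(String[] args) {
    byte[] payload = "row-bytes".getBytes(StandardCharsets.UTF_8);
    byte[] wire = encode(1L, 0, 2, payload).array();
    System.out.println(wire.length); // 41: 32 header bytes + 9 payload bytes
    System.out.println(new String(decode(ByteBuffer.wrap(wire)), StandardCharsets.UTF_8));

    // One byte short (truncation) and one byte extra (trailing junk) are both rejected.
    for (int len : new int[] {wire.length - 1, wire.length + 1}) {
      try {
        decode(ByteBuffer.wrap(Arrays.copyOf(wire, len)));
      } catch (IllegalArgumentException e) {
        System.out.println("rejected: " + e.getMessage());
      }
    }
  }
}
```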
The protocol is designed to be reviewable on its own — the message
layout, sequence-numbering scheme, memory-ownership rules, framing invariants,
and checksum approach can all be
evaluated without reading any of the executor- or driver-side code that
comes in the follow-up PRs.
### Why are the changes needed?
**This is one part of the streaming shuffle needed for Real-Time mode
(RTM).**
Spark's existing `SortShuffleManager` materializes map outputs to local
disk and gates the reduce stage on map-stage completion. This design is
well-suited to batch jobs but introduces
two latency floors that are unacceptable for continuously-running,
low-latency queries:
1. **Map-stage barrier.** Reducers cannot start until every mapper has
finished writing its output files. End-to-end latency is bounded by the
slowest mapper.
2. **Disk materialization.** Every record is serialized to a local file
and read back, even when the consumer is ready immediately. The fsync/read path
adds tens of milliseconds of
overhead per batch.
Real-Time mode in Structured Streaming, and other long-running query
workloads where map and reduce tasks coexist for the lifetime of the query,
require a shuffle that:
- Streams records from mappers to reducers as they are produced.
- Lets the reduce stage run concurrently with the map stage.
- Avoids on-disk intermediate state entirely.
- Provides backpressure so a slow reader can throttle the upstream
iterator without dropping data.
- Detects message loss, reordering, or in-flight corruption (since there
is no on-disk re-read to fall back on).
Streaming shuffle is the new `ShuffleManager` implementation that meets
these requirements. The full implementation involves driver-side coordination,
executor-side push/pull components,
engine integration, and supporting infrastructure — together too large to
land in a single PR.
This PR contributes the **wire protocol** layer, which is the natural
foundation. Every other piece in the stack (the tracker, writer, reader,
handlers, tests) depends on the message
types, sequence-numbering rules, and checksum scheme defined here, so it
must land first. The protocol layer is also the most contract-heavy part of the
design: getting the message
layout, ownership rules, and integrity check right up front avoids costly
churn in later PRs. Making this an independent, standalone-reviewable PR
lets reviewers scrutinize the wire format, refcount discipline, and CRC32C
approach without being distracted by Spark-internal coordination logic that
has not yet been introduced.
The final product can be viewed here for context:
https://github.com/apache/spark/compare/master...jerrypeng:spark:stack/streaming-shuffle-pr9-docs?expand=1
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
Unit tests in `StreamingShuffleMessageSuite` cover encode/decode round-trips
for all message types, checksum determinism, heap/direct buffer equivalence
and range validation, `DataMessage` framing validation, and the unknown-type
error path.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #55620 from jerrypeng/stack/streaming-shuffle-pr1-network-protocol.
Authored-by: Boyang Jerry Peng <[email protected]>
Signed-off-by: Liang-Chi Hsieh <[email protected]>
(cherry picked from commit 7179629dc642aa4fcaef0704ede4bb34ebd3f8f3)
Signed-off-by: Liang-Chi Hsieh <[email protected]>
---
.../shuffle/streaming/CreditControlMessage.java | 91 +++++
.../network/shuffle/streaming/DataMessage.java | 131 +++++++
.../network/shuffle/streaming/ShuffleChecksum.java | 71 ++++
.../shuffle/streaming/StreamingShuffleMessage.java | 146 ++++++++
.../streaming/StreamingShuffleMessageType.java | 45 +++
.../shuffle/streaming/TerminationAckMessage.java | 68 ++++
.../streaming/TerminationControlMessage.java | 68 ++++
.../streaming/StreamingShuffleMessageSuite.java | 399 +++++++++++++++++++++
8 files changed, 1019 insertions(+)
diff --git a/common/network-common/src/main/java/org/apache/spark/network/shuffle/streaming/CreditControlMessage.java b/common/network-common/src/main/java/org/apache/spark/network/shuffle/streaming/CreditControlMessage.java
new file mode 100644
index 000000000000..4ba17b4a011a
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/shuffle/streaming/CreditControlMessage.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.streaming;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.CompositeByteBuf;
+
+/**
+ * Reader → writer control message.
+ *
+ * Current function: serves as a connection-establishment and per-consumption "ready"
+ * signal. The writer uses receipt of any CreditControlMessage as the trigger that the
+ * reader is ready to receive, and does not act on the numeric value of {@link
+ * #numMessages}. Backpressure today is handled at the TCP layer via channel autoRead,
+ * not via this message.
+ *
+ * Future function: this message is reserved as the carrier for a future credit-based
+ * flow-control extension. When that extension lands, {@link #numMessages} will carry
+ * the number of additional DataMessages the writer may send beyond any
+ * previously-granted credit (i.e., a credit-grant delta).
+ */
+public final class CreditControlMessage extends StreamingShuffleMessage {
+ public final int shuffleWriterId;
+ public final int shuffleReaderId;
+
+ /**
+ * In the current protocol revision the writer ignores this value and treats any
+ * CreditControlMessage as a "reader is ready" signal; senders should pass 1.
+ *
+ * Reserved for the future credit-based flow-control extension, in which this field
+ * will carry the number of additional DataMessages the writer may send beyond any
+ * previously-granted credit.
+ */
+ public final int numMessages;
+
+ public CreditControlMessage(int shuffleWriterId, int shuffleReaderId, int numMessages) {
+ this.shuffleWriterId = shuffleWriterId;
+ this.shuffleReaderId = shuffleReaderId;
+ this.numMessages = numMessages;
+ }
+
+ @Override
+ public StreamingShuffleMessageType messageType() {
+ return StreamingShuffleMessageType.CREDIT_CONTROL_MESSAGE;
+ }
+
+ @Override
+ public int headerLength() {
+ // 4 bytes for the shuffle writer ID, 4 bytes for the shuffle reader ID,
+ // 4 bytes for the number of messages
+ return super.headerLength() + 12;
+ }
+
+ @Override
+ public void encode(CompositeByteBuf buf) {
+ super.encode(buf);
+
+ // Write the shuffle writer ID
+ buf.writeInt(shuffleWriterId);
+ // Write the shuffle reader ID
+ buf.writeInt(shuffleReaderId);
+ // Write the number of messages
+ buf.writeInt(numMessages);
+ }
+
+ public static CreditControlMessage decode(ByteBuf buf) {
+ // Read the shuffle writer ID
+ int shuffleWriterId = buf.readInt();
+ // Read the shuffle reader ID
+ int shuffleReaderId = buf.readInt();
+ // Read the number of messages
+ int numMessages = buf.readInt();
+
+ return new CreditControlMessage(shuffleWriterId, shuffleReaderId, numMessages);
+ }
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/shuffle/streaming/DataMessage.java b/common/network-common/src/main/java/org/apache/spark/network/shuffle/streaming/DataMessage.java
new file mode 100644
index 000000000000..efcca96cce56
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/shuffle/streaming/DataMessage.java
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.streaming;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.CompositeByteBuf;
+
+/**
+ * Writer → reader data message. Carries a contiguous payload of serialized shuffle
+ * records together with the routing metadata (shuffleWriterId, shuffleReaderId,
+ * dataSize) and a CRC32C {@link #checksum} over the payload. Each DataMessage
+ * corresponds to one network frame on the streaming-shuffle wire.
+ *
+ * Memory ownership: the {@code data} ByteBuf passed to the constructor is retained
+ * once and stored as {@link #ownedBuf}, so callers must release their own reference
+ * separately. On {@link #release()} the retained reference is dropped. See the
+ * memory-ownership rules in {@link StreamingShuffleMessage}.
+ */
+public final class DataMessage extends StreamingShuffleMessage {
+
+ public final ByteBuf data;
+ public final int shuffleWriterId;
+ public final int shuffleReaderId;
+ public final int dataSize;
+ public final long checksum;
+
+ public DataMessage(int shuffleWriterId, int shuffleReaderId, int dataSize, ByteBuf data,
+ long checksum) {
+ if (dataSize < 0) {
+ throw new IllegalArgumentException(
+ "dataSize must be non-negative: " + dataSize);
+ }
+ if (dataSize != data.readableBytes()) {
+ throw new IllegalArgumentException(
+ "dataSize must equal data.readableBytes(): " +
+ dataSize + " != " + data.readableBytes());
+ }
+ this.shuffleWriterId = shuffleWriterId;
+ this.shuffleReaderId = shuffleReaderId;
+ this.dataSize = dataSize;
+ this.data = data;
+ this.ownedBuf = data.retain();
+ this.checksum = checksum;
+ }
+
+ @Override
+ public StreamingShuffleMessageType messageType() {
+ return StreamingShuffleMessageType.DATA_MESSAGE_UNSAFE_ROW;
+ }
+
+ @Override
+ public int headerLength() {
+ // 4 bytes EACH for shuffle writer ID, shuffle reader ID, data size
+ // 8 bytes for checksum
+ return super.headerLength() + 20;
+ }
+
+ @Override
+ public void encode(CompositeByteBuf buf) {
+ super.encode(buf);
+ buf.writeInt(shuffleWriterId);
+ buf.writeInt(shuffleReaderId);
+ buf.writeInt(dataSize);
+ buf.writeLong(checksum);
+
+ // Only encode exactly `dataSize` bytes of `data`, so the wire payload matches the
+ // header's dataSize even if the underlying ByteBuf has additional readable bytes.
+ // retainedSlice() also bumps the refcount, so this DataMessage's ownedBuf reference
+ // remains valid for later release().
+ buf.addComponent(true, data.retainedSlice(data.readerIndex(), dataSize));
+ }
+
+ /**
+ * Decodes a {@link DataMessage} from {@code message}. The 12-byte common header
+ * (message-type id + sequence number) has already been consumed by
+ * {@link StreamingShuffleMessage#decode(ByteBuf)}; this method reads the remaining
+ * 20-byte DataMessage header and treats the rest of {@code message} as the payload.
+ *
+ * On success, the returned DataMessage retains a reference to {@code message} via
+ * its {@code ownedBuf} field (refcount is incremented in the constructor), and the
+ * caller transfers ownership.
+ *
+ * On failure (an {@link IllegalArgumentException} thrown by the dataSize check below
+ * or by the constructor's validation), the caller still owns {@code message} and is
+ * responsible for releasing it, since the constructor's {@code data.retain()} never ran.
+ */
+ public static DataMessage decode(ByteBuf message) {
+ int shuffleWriterId = message.readInt();
+ int shuffleReaderId = message.readInt();
+ int dataSize = message.readInt();
+ long checksum = message.readLong();
+ // Each streaming-shuffle frame carries exactly one DataMessage, so after reading the
+ // header the remaining readable bytes must equal `dataSize`. Strict equality catches
+ // both undersized frames (truncation) and oversized frames (trailing garbage) at the
+ // wire boundary rather than letting them propagate to getRecordData() later.
+ if (dataSize < 0 || dataSize != message.readableBytes()) {
+ throw new IllegalArgumentException(
+ "Invalid DataMessage dataSize=" + dataSize +
+ ", readable bytes after header=" + message.readableBytes());
+ }
+ return new DataMessage(shuffleWriterId, shuffleReaderId, dataSize, message, checksum);
+ }
+
+ /**
+ * Returns a slice of {@link #data} containing exactly the serialized records
+ * (i.e., {@code dataSize} bytes starting at the current reader index).
+ *
+ * Uses {@code slice} (not {@code readSlice}), so this method DOES NOT advance the
+ * reader index of {@link #data}. The returned slice shares the underlying storage;
+ * no data is copied. Callers may freely consume the returned slice without
+ * affecting subsequent calls to this method.
+ */
+ public ByteBuf getRecordData() {
+ return data.slice(data.readerIndex(), dataSize);
+ }
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/shuffle/streaming/ShuffleChecksum.java b/common/network-common/src/main/java/org/apache/spark/network/shuffle/streaming/ShuffleChecksum.java
new file mode 100644
index 000000000000..2126c52c02d5
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/shuffle/streaming/ShuffleChecksum.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.streaming;
+
+import io.netty.buffer.ByteBuf;
+import java.util.zip.CRC32C;
+import javax.annotation.concurrent.NotThreadSafe;
+
+/**
+ * Helper class for streaming shuffle checksum calculations.
+ */
+@NotThreadSafe
+public final class ShuffleChecksum {
+ private final CRC32C crc = new CRC32C();
+
+ /**
+ * Updates checksum for a specified portion of a ByteBuf message.
+ *
+ * @param message The ByteBuf to calculate checksum for
+ * @param startIndex The index of the first byte to calculate checksum for
+ * @param dataLength The length of the data to calculate checksum for
+ */
+ public void updateChecksum(ByteBuf message, int startIndex, int dataLength) {
+ if (startIndex < 0) {
+ throw new IllegalArgumentException(
+ "startIndex must be non-negative: " + startIndex);
+ }
+ if (dataLength < 0) {
+ throw new IllegalArgumentException(
+ "dataLength must be non-negative: " + dataLength);
+ }
+ // Bound the range against writerIndex() rather than capacity(): the checksum must
+ // cover actual written data only, never bytes in [writerIndex, capacity), which
+ // may be uninitialized.
+ if (startIndex + dataLength > message.writerIndex()) {
+ throw new IllegalArgumentException(
+ "startIndex + dataLength exceeds writerIndex: " +
+ startIndex + " + " + dataLength + " > " + message.writerIndex());
+ }
+ if (message.hasArray()) {
+ // heap-based ByteBuf
+ crc.update(message.array(), message.arrayOffset() + startIndex, dataLength);
+ } else {
+ // off-heap ByteBuf
+ crc.update(message.nioBuffer(startIndex, dataLength));
+ }
+ }
+
+ public long getValue() {
+ return crc.getValue();
+ }
+
+ public void reset() {
+ crc.reset();
+ }
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/shuffle/streaming/StreamingShuffleMessage.java b/common/network-common/src/main/java/org/apache/spark/network/shuffle/streaming/StreamingShuffleMessage.java
new file mode 100644
index 000000000000..f1ce41863bc2
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/shuffle/streaming/StreamingShuffleMessage.java
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.streaming;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.CompositeByteBuf;
+
+/**
+ * Base class for messages sent between the streaming shuffle writers (usually
mappers) and
+ * readers (usually reducers).
+ *
+ * To prevent memory leaks, streaming shuffle programmers should always abide
by the following
+ * principles:
+ *
+ * 1. If you create a buffer via ByteBufAllocator, you must explicitly
release it.
+ * 2. If you create a new StreamingShuffleMessage, you must call .release()
on it.
+ *
+ * To make these rules work out, implementations of StreamingShuffleMessage
should abide
+ * by the following rules:
+ *
+ * 1. StreamingShuffleMessages should *not* modify the refcount of ByteBufs
passed to them
+ * during encoding. For message implementations without ByteBufs, this
isn't a concern.
+ * But for messages that have ByteBufs (e.g. DataMessage), the encoding
method will likely
+ * call compositeByteBuf.addComponent(), which transfers ownership of the
ByteBuf to the
+ * CompositeByteBuf and decrements the refcount of the ByteBuf. So that
the caller can
+ * *always* follow rule 1 above, the ByteBuf should be retained before
being passed to the
+ * CompositeByteBuf; if this is not done, the refcount of the ByteBuf
after leaving
+ * encode() will be 0, and if the caller follows rule 1, they will try to
decrement an
+ * already 0 refcount. See DataMessage for an example of how to do this
properly.
+ * 2. If StreamingShuffleMessages keep a reference the ByteBufs passed to
them during
+ * decoding, they should increment the refcount of that ByteBuf, and
assign it to
+ * ownedBuf. This is so that resources get cleaned up when callers follow
rule 2 above,
+ * i.e. call .release() on the StreamingShuffleMessage. See DataMessage
for an example of
+ * how to do this properly.
+ */
+public abstract sealed class StreamingShuffleMessage
+ permits CreditControlMessage, DataMessage, TerminationAckMessage, TerminationControlMessage {
+ protected ByteBuf ownedBuf = null;
+ private Runnable releaseCallback = null;
+
+ // To detect duplicate, out-of-order, or missing messages, each writer tracks the current
+ // max sequence number it has sent to each reader. Similarly, each reader tracks the
+ // latest sequence number it has received from each writer. Upon receiving a new message
+ // from any writer, the reader checks that the sequence number is the expected one. When
+ // the stream finishes, the reader sends a TerminationAckMessage to the writer with the
+ // max sequence number it has received, and the writer checks that its own latest
+ // recorded sequence number matches it.
+
+ // Thus the sequence number is valid for the following message types:
+ // 1. All message types from a writer to a reader, so that the reader receives every
+ // message the writer sent, in order, with none missing or duplicated.
+ // 2. TerminationAckMessage from a reader to a writer, so that at the end of the
+ // shuffle the writer can confirm the reader received as many messages as it sent.
+ // Other message types from reader to writer do not carry a valid sequence number.
+ private long seqNum;
+ public void setSeqNum(long seqNum) {
+ this.seqNum = seqNum;
+ }
+ public long getSeqNum() { return seqNum; }
+
+ /** Returns the type of this message. */
+ public abstract StreamingShuffleMessageType messageType();
+
+ /** Encodes the current message into the provided ByteBuf. */
+ public void encode(CompositeByteBuf buf) {
+ buf.writeInt(messageType().id());
+ buf.writeLong(seqNum);
+ }
+
+ /**
+ * Returns the number of bytes the encoded message header occupies on the wire.
+ *
+ * For control messages (CreditControl, TerminationControl, TerminationAck) this is
+ * the full encoded length of the message. For {@link DataMessage} this is the header
+ * size only; the variable-length payload of {@code dataSize} bytes follows the header
+ * and is NOT included in this count. Callers use this value to pre-size the
+ * {@link CompositeByteBuf} that the encoded message is written into.
+ */
+ public int headerLength() {
+ // 4 bytes for message type, 8 bytes for the sequence number
+ return 12;
+ }
+
+ /**
+ * Registers a callback that will be invoked exactly once when {@link #release()}
+ * runs to completion. The callback runs AFTER {@code ownedBuf.release()} and is
+ * cleared after invocation, so a subsequent {@code release()} call on the same
+ * thread will not re-run it.
+ *
+ * Replaces any previously-registered callback. Not thread-safe; see {@link #release()}.
+ */
+ public void setReleaseCallback(Runnable releaseCallback) {
+ this.releaseCallback = releaseCallback;
+ }
+
+ /**
+ * Releases any resources associated with this message.
+ * In VERY RARE cases when the task fails unexpectedly, this method may be called twice.
+ * This method is idempotent (a second call on the same thread is a no-op), but it is
+ * NOT thread-safe.
+ */
+ public void release() {
+ if (ownedBuf != null) {
+ ownedBuf.release();
+ ownedBuf = null;
+ }
+ if (releaseCallback != null) {
+ releaseCallback.run();
+ releaseCallback = null;
+ }
+ }
+
+ public static StreamingShuffleMessage decode(ByteBuf message) {
+ StreamingShuffleMessageType messageType =
+ StreamingShuffleMessageType.decode(message.readInt());
+ long seqNum = message.readLong();
+
+ // Switch expression over the enum is exhaustive; the compiler enforces that every
+ // case is handled, so any future StreamingShuffleMessageType added to the enum will
+ // cause this method to fail compilation until the corresponding case is added here.
+ StreamingShuffleMessage shuffleMessage = switch (messageType) {
+ case DATA_MESSAGE_UNSAFE_ROW -> DataMessage.decode(message);
+ case CREDIT_CONTROL_MESSAGE -> CreditControlMessage.decode(message);
+ case TERMINATION_CONTROL_MESSAGE -> TerminationControlMessage.decode(message);
+ case TERMINATION_ACK_MESSAGE -> TerminationAckMessage.decode(message);
+ };
+ shuffleMessage.setSeqNum(seqNum);
+
+ return shuffleMessage;
+ }
+
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/shuffle/streaming/StreamingShuffleMessageType.java b/common/network-common/src/main/java/org/apache/spark/network/shuffle/streaming/StreamingShuffleMessageType.java
new file mode 100644
index 000000000000..d103009c62cc
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/shuffle/streaming/StreamingShuffleMessageType.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.streaming;
+
+public enum StreamingShuffleMessageType {
+ DATA_MESSAGE_UNSAFE_ROW(1),
+ CREDIT_CONTROL_MESSAGE(2),
+ TERMINATION_CONTROL_MESSAGE(3),
+ TERMINATION_ACK_MESSAGE(4);
+
+ private final int id;
+
+ StreamingShuffleMessageType(int id) {
+ this.id = id;
+ }
+
+ public int id() {
+ return id;
+ }
+
+ public static StreamingShuffleMessageType decode(int givenId) {
+ return switch (givenId) {
+ case 1 -> DATA_MESSAGE_UNSAFE_ROW;
+ case 2 -> CREDIT_CONTROL_MESSAGE;
+ case 3 -> TERMINATION_CONTROL_MESSAGE;
+ case 4 -> TERMINATION_ACK_MESSAGE;
+ default -> throw new IllegalArgumentException("Unknown message type: " + givenId);
+ };
+ }
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/shuffle/streaming/TerminationAckMessage.java b/common/network-common/src/main/java/org/apache/spark/network/shuffle/streaming/TerminationAckMessage.java
new file mode 100644
index 000000000000..3a724fd32c16
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/shuffle/streaming/TerminationAckMessage.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.streaming;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.CompositeByteBuf;
+
+/**
+ * Reader → writer acknowledgement of {@link TerminationControlMessage}. The
+ * {@code seqNum} field carries the last sequence number the reader observed from
+ * this writer; the writer compares that value against its own last-sent sequence
+ * number to detect lost or reordered messages, failing the task with
+ * STREAMING_SHUFFLE_INCORRECT_SEQUENCE_NUMBER on mismatch.
+ */
+public final class TerminationAckMessage extends StreamingShuffleMessage {
+ public final int shuffleWriterId;
+ public final int shuffleReaderId;
+
+ public TerminationAckMessage(int shuffleWriterId, int shuffleReaderId) {
+ this.shuffleWriterId = shuffleWriterId;
+ this.shuffleReaderId = shuffleReaderId;
+ }
+
+ @Override
+ public StreamingShuffleMessageType messageType() {
+ return StreamingShuffleMessageType.TERMINATION_ACK_MESSAGE;
+ }
+
+ @Override
+ public int headerLength() {
+ // 4 bytes for the shuffle writer ID, 4 bytes for the shuffle reader ID
+ return super.headerLength() + 8;
+ }
+
+ @Override
+ public void encode(CompositeByteBuf buf) {
+ super.encode(buf);
+
+ // Write the shuffle writer ID
+ buf.writeInt(shuffleWriterId);
+ // Write the shuffle reader ID
+ buf.writeInt(shuffleReaderId);
+ }
+
+ public static TerminationAckMessage decode(ByteBuf buf) {
+ // Read the shuffle writer ID
+ int shuffleWriterId = buf.readInt();
+ // Read the shuffle reader ID
+ int shuffleReaderId = buf.readInt();
+
+ return new TerminationAckMessage(shuffleWriterId, shuffleReaderId);
+ }
+}
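The Javadoc above says the writer compares the sequence number echoed in the acknowledgement with its own last-sent number and fails the task on a mismatch. A minimal sketch of that check follows. It assumes that sequence numbers start at 0 and increase by one per `DataMessage` on a connection, and that the ack echoes the number of the last `DataMessage` the reader received. `TerminationAckCheck` is a hypothetical class; the actual writer (`StreamingShuffleWriter`) arrives later in the stack and may track this differently.

```java
// Hypothetical writer-side state for one (shuffleWriterId, shuffleReaderId) connection.
final class TerminationAckCheck {
  private final int shuffleWriterId;
  private final int shuffleReaderId;
  // Sequence number of the last DataMessage sent; -1 means none sent yet.
  private long lastSentSeqNum = -1L;

  TerminationAckCheck(int shuffleWriterId, int shuffleReaderId) {
    this.shuffleWriterId = shuffleWriterId;
    this.shuffleReaderId = shuffleReaderId;
  }

  // Assigns the sequence number for the next outgoing DataMessage.
  long nextSeqNum() {
    return ++lastSentSeqNum;
  }

  // Validates an incoming ack: it must belong to this connection and echo
  // exactly the sequence number this writer sent last.
  void verifyAck(int ackWriterId, int ackReaderId, long ackSeqNum) {
    if (ackWriterId != shuffleWriterId || ackReaderId != shuffleReaderId) {
      throw new IllegalStateException("Ack for writer " + ackWriterId + " / reader "
          + ackReaderId + " received on connection " + shuffleWriterId + " / "
          + shuffleReaderId);
    }
    if (ackSeqNum != lastSentSeqNum) {
      // Stands in for the STREAMING_SHUFFLE_INCORRECT_SEQUENCE_NUMBER task failure.
      throw new IllegalStateException("STREAMING_SHUFFLE_INCORRECT_SEQUENCE_NUMBER: sent up to "
          + lastSentSeqNum + ", reader acknowledged " + ackSeqNum);
    }
  }
}
```

Under these assumptions, comparing only the final number catches messages lost at the tail of the stream; detecting a gap in the middle would also require the reader to check that each received number is exactly one more than the previous one.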
diff --git a/common/network-common/src/main/java/org/apache/spark/network/shuffle/streaming/TerminationControlMessage.java b/common/network-common/src/main/java/org/apache/spark/network/shuffle/streaming/TerminationControlMessage.java
new file mode 100644
index 000000000000..db00832baac6
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/shuffle/streaming/TerminationControlMessage.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.streaming;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.CompositeByteBuf;
+
+/**
+ * Writer → reader control message that signals end-of-stream. After a writer has
+ * sent its last {@link DataMessage} to a given reader it sends one
+ * TerminationControlMessage on the same connection and waits for the matching
+ * {@link TerminationAckMessage} back. Receipt of this message tells the reader
+ * that no further DataMessages will arrive from this writer.
+ */
+public final class TerminationControlMessage extends StreamingShuffleMessage {
+ public final int shuffleWriterId;
+ public final int shuffleReaderId;
+
+ public TerminationControlMessage(int shuffleWriterId, int shuffleReaderId) {
+ this.shuffleWriterId = shuffleWriterId;
+ this.shuffleReaderId = shuffleReaderId;
+ }
+
+ @Override
+ public StreamingShuffleMessageType messageType() {
+ return StreamingShuffleMessageType.TERMINATION_CONTROL_MESSAGE;
+ }
+
+ @Override
+ public int headerLength() {
+ // 4 bytes for the shuffle writer ID, 4 bytes for the shuffle reader ID
+ return super.headerLength() + 8;
+ }
+
+ @Override
+ public void encode(CompositeByteBuf buf) {
+ super.encode(buf);
+
+ // Write the shuffle writer ID
+ buf.writeInt(shuffleWriterId);
+ // Write the shuffle reader ID
+ buf.writeInt(shuffleReaderId);
+ }
+
+ public static TerminationControlMessage decode(ByteBuf buf) {
+ // Read the shuffle writer ID
+ int shuffleWriterId = buf.readInt();
+ // Read the shuffle reader ID
+ int shuffleReaderId = buf.readInt();
+
+ return new TerminationControlMessage(shuffleWriterId, shuffleReaderId);
+ }
+}
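From the reader's side, the handshake described in the Javadoc above amounts to: record the sequence number of every `DataMessage` from a writer, treat that writer's `TerminationControlMessage` as end-of-stream, and reply with a `TerminationAckMessage` carrying the last number seen. The sketch below models that bookkeeping for all writers feeding one reader. `ReaderTerminationState` is hypothetical; it assumes the echoed value is the sequence number of the last `DataMessage` received (or -1 if none arrived), and it treats data or a second termination after end-of-stream as protocol errors.

```java
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

// Hypothetical reader-side bookkeeping for the end-of-stream handshake.
final class ReaderTerminationState {
  // Sequence number of the last DataMessage observed from each shuffleWriterId.
  private final Map<Integer, Long> lastSeenSeqNum = new HashMap<>();
  // Writers whose TerminationControlMessage has already arrived.
  private final Set<Integer> terminatedWriters = new HashSet<>();

  // Records a DataMessage; data after end-of-stream is a protocol violation.
  void onData(int shuffleWriterId, long seqNum) {
    if (terminatedWriters.contains(shuffleWriterId)) {
      throw new IllegalStateException(
          "DataMessage from writer " + shuffleWriterId + " after termination");
    }
    lastSeenSeqNum.put(shuffleWriterId, seqNum);
  }

  // Handles a TerminationControlMessage and returns the seqNum to echo in the ack.
  long onTermination(int shuffleWriterId) {
    if (!terminatedWriters.add(shuffleWriterId)) {
      throw new IllegalStateException(
          "Duplicate termination from writer " + shuffleWriterId);
    }
    return lastSeenSeqNum.getOrDefault(shuffleWriterId, -1L);
  }

  // True once every expected writer has signalled end-of-stream.
  boolean allTerminated(int expectedWriters) {
    return terminatedWriters.size() == expectedWriters;
  }
}
```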
diff --git a/common/network-common/src/test/java/org/apache/spark/network/shuffle/streaming/StreamingShuffleMessageSuite.java b/common/network-common/src/test/java/org/apache/spark/network/shuffle/streaming/StreamingShuffleMessageSuite.java
new file mode 100644
index 000000000000..2b5d92a4328d
--- /dev/null
+++ b/common/network-common/src/test/java/org/apache/spark/network/shuffle/streaming/StreamingShuffleMessageSuite.java
@@ -0,0 +1,399 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.streaming;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.CompositeByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+/**
+ * Unit tests for {@link StreamingShuffleMessage} encode/decode round-trips and
+ * {@link ShuffleChecksum}.
+ */
+public class StreamingShuffleMessageSuite {
+
+ private static final long SEQ_NUM = 42L;
+
+ // ---- helpers ---------------------------------------------------------------
+
+ private ByteBuf encodeAndSlice(StreamingShuffleMessage msg) {
+ msg.setSeqNum(SEQ_NUM);
+ CompositeByteBuf buf = Unpooled.compositeBuffer();
+ buf.capacity(msg.headerLength());
+ msg.encode(buf);
+ msg.release();
+ // Return a copy so we can freely advance the reader index
+ ByteBuf copy = Unpooled.buffer(buf.readableBytes());
+ copy.writeBytes(buf);
+ buf.release();
+ return copy;
+ }
+
+ // ---- CreditControlMessage --------------------------------------------------
+
+ @Test
+ public void testCreditControlRoundTrip() {
+ CreditControlMessage original = new CreditControlMessage(3, 7, 5);
+ ByteBuf encoded = encodeAndSlice(original);
+ try {
+ StreamingShuffleMessage decoded = StreamingShuffleMessage.decode(encoded);
+ assertInstanceOf(CreditControlMessage.class, decoded);
+ CreditControlMessage credit = (CreditControlMessage) decoded;
+ assertEquals(SEQ_NUM, credit.getSeqNum());
+ assertEquals(3, credit.shuffleWriterId);
+ assertEquals(7, credit.shuffleReaderId);
+ assertEquals(5, credit.numMessages);
+ } finally {
+ encoded.release();
+ }
+ }
+
+ // ---- TerminationControlMessage ---------------------------------------------
+
+ @Test
+ public void testTerminationControlRoundTrip() {
+ TerminationControlMessage original = new TerminationControlMessage(1, 2);
+ ByteBuf encoded = encodeAndSlice(original);
+ try {
+ StreamingShuffleMessage decoded = StreamingShuffleMessage.decode(encoded);
+ assertInstanceOf(TerminationControlMessage.class, decoded);
+ TerminationControlMessage term = (TerminationControlMessage) decoded;
+ assertEquals(SEQ_NUM, term.getSeqNum());
+ assertEquals(1, term.shuffleWriterId);
+ assertEquals(2, term.shuffleReaderId);
+ } finally {
+ encoded.release();
+ }
+ }
+
+ // ---- TerminationAckMessage -------------------------------------------------
+
+ @Test
+ public void testTerminationAckRoundTrip() {
+ TerminationAckMessage original = new TerminationAckMessage(4, 8);
+ ByteBuf encoded = encodeAndSlice(original);
+ try {
+ StreamingShuffleMessage decoded = StreamingShuffleMessage.decode(encoded);
+ assertInstanceOf(TerminationAckMessage.class, decoded);
+ TerminationAckMessage ack = (TerminationAckMessage) decoded;
+ assertEquals(SEQ_NUM, ack.getSeqNum());
+ assertEquals(4, ack.shuffleWriterId);
+ assertEquals(8, ack.shuffleReaderId);
+ } finally {
+ encoded.release();
+ }
+ }
+
+ // ---- DataMessage -----------------------------------------------------------
+
+ @Test
+ public void testDataMessageRoundTrip() {
+ byte[] payload = "hello streaming shuffle".getBytes();
+ ByteBuf payloadBuf = Unpooled.wrappedBuffer(payload);
+ long checksum = 0xDEADBEEFL;
+
+ DataMessage original = new DataMessage(2, 5, payload.length, payloadBuf, checksum);
+ ByteBuf encoded = encodeAndSlice(original);
+ payloadBuf.release();
+
+ try {
+ StreamingShuffleMessage decoded = StreamingShuffleMessage.decode(encoded);
+ assertInstanceOf(DataMessage.class, decoded);
+ DataMessage dm = (DataMessage) decoded;
+ assertEquals(SEQ_NUM, dm.getSeqNum());
+ assertEquals(2, dm.shuffleWriterId);
+ assertEquals(5, dm.shuffleReaderId);
+ assertEquals(payload.length, dm.dataSize);
+ assertEquals(checksum, dm.checksum);
+
+ ByteBuf recordData = dm.getRecordData();
+ byte[] out = new byte[payload.length];
+ recordData.readBytes(out);
+ assertArrayEquals(payload, out);
+ dm.release();
+ } finally {
+ encoded.release();
+ }
+ }
+
+ @Test
+ public void testReleaseIsIdempotent() {
+ byte[] payload = "hi".getBytes();
+ ByteBuf payloadBuf = Unpooled.wrappedBuffer(payload);
+ // Constructor calls data.retain(), so refcount is now 2 (1 original + 1 retain).
+ DataMessage msg = new DataMessage(0, 0, payload.length, payloadBuf, 0L);
+ assertEquals(2, payloadBuf.refCnt());
+
+ msg.release();
+ assertEquals(1, payloadBuf.refCnt());
+
+ // Second release should be a no-op: refcount must not drop further or throw.
+ msg.release();
+ assertEquals(1, payloadBuf.refCnt());
+
+ payloadBuf.release();
+ }
+
+ @Test
+ public void testReleaseCallbackRunsExactlyOnce() {
+ byte[] payload = "hi".getBytes();
+ ByteBuf payloadBuf = Unpooled.wrappedBuffer(payload);
+ DataMessage msg = new DataMessage(0, 0, payload.length, payloadBuf, 0L);
+
+ java.util.concurrent.atomic.AtomicInteger callbackInvocations =
+ new java.util.concurrent.atomic.AtomicInteger(0);
+ msg.setReleaseCallback(callbackInvocations::incrementAndGet);
+
+ msg.release();
+ assertEquals(1, callbackInvocations.get());
+
+ // Second release: idempotent — callback must NOT fire again.
+ msg.release();
+ assertEquals(1, callbackInvocations.get());
+
+ payloadBuf.release();
+ }
+
+ @Test
+ public void testDataMessageRejectsDataSizeLargerThanReadableBytes() {
+ byte[] payload = "hello".getBytes();
+ ByteBuf payloadBuf = Unpooled.wrappedBuffer(payload);
+ try {
+ // dataSize (10) > data.readableBytes() (5)
+ assertThrows(IllegalArgumentException.class,
+ () -> new DataMessage(0, 0, payload.length + 5, payloadBuf, 0L));
+ } finally {
+ payloadBuf.release();
+ }
+ }
+
+ @Test
+ public void testDataMessageRejectsDataSizeSmallerThanReadableBytes() {
+ byte[] payload = "hello".getBytes();
+ ByteBuf payloadBuf = Unpooled.wrappedBuffer(payload);
+ try {
+ // dataSize (3) < data.readableBytes() (5)
+ assertThrows(IllegalArgumentException.class,
+ () -> new DataMessage(0, 0, payload.length - 2, payloadBuf, 0L));
+ } finally {
+ payloadBuf.release();
+ }
+ }
+
+ @Test
+ public void testDataMessageRejectsNegativeDataSize() {
+ byte[] payload = "hello".getBytes();
+ ByteBuf payloadBuf = Unpooled.wrappedBuffer(payload);
+ try {
+ assertThrows(IllegalArgumentException.class,
+ () -> new DataMessage(0, 0, -1, payloadBuf, 0L));
+ } finally {
+ payloadBuf.release();
+ }
+ }
+
+ @Test
+ public void testDataMessageDecodeRejectsTrailingBytes() {
+ // Encode a valid DataMessage, then append junk bytes to the encoded buffer.
+ // decode() should reject because dataSize != message.readableBytes() after the header.
+ byte[] payload = "hello".getBytes();
+ ByteBuf payloadBuf = Unpooled.wrappedBuffer(payload);
+ DataMessage original = new DataMessage(1, 2, payload.length, payloadBuf, 0L);
+ original.setSeqNum(SEQ_NUM);
+ CompositeByteBuf buf = Unpooled.compositeBuffer();
+ buf.capacity(original.headerLength());
+ original.encode(buf);
+ original.release();
+ payloadBuf.release();
+
+ // Copy encoded bytes then append 3 extra junk bytes.
+ ByteBuf corrupted = Unpooled.buffer(buf.readableBytes() + 3);
+ corrupted.writeBytes(buf);
+ buf.release();
+ corrupted.writeBytes(new byte[]{0x7f, 0x7f, 0x7f});
+
+ try {
+ assertThrows(IllegalArgumentException.class,
+ () -> StreamingShuffleMessage.decode(corrupted));
+ } finally {
+ corrupted.release();
+ }
+ }
+
+ @Test
+ public void testDataMessageDecodeRejectsTruncatedFrame() {
+ // Encode a valid DataMessage, then strip a trailing byte from the encoded buffer.
+ // decode() should reject because dataSize > message.readableBytes() after the header.
+ byte[] payload = "hello".getBytes();
+ ByteBuf payloadBuf = Unpooled.wrappedBuffer(payload);
+ DataMessage original = new DataMessage(1, 2, payload.length, payloadBuf, 0L);
+ original.setSeqNum(SEQ_NUM);
+ CompositeByteBuf buf = Unpooled.compositeBuffer();
+ buf.capacity(original.headerLength());
+ original.encode(buf);
+ original.release();
+ payloadBuf.release();
+
+ int truncatedSize = buf.readableBytes() - 1;
+ ByteBuf truncated = Unpooled.buffer(truncatedSize);
+ truncated.writeBytes(buf, truncatedSize);
+ buf.release();
+
+ try {
+ assertThrows(IllegalArgumentException.class,
+ () -> StreamingShuffleMessage.decode(truncated));
+ } finally {
+ truncated.release();
+ }
+ }
+
+ // ---- ShuffleChecksum -------------------------------------------------------
+
+ @Test
+ public void testShuffleChecksumHeapBuffer() {
+ byte[] data = {1, 2, 3, 4, 5};
+ ByteBuf buf = Unpooled.wrappedBuffer(data); // heap-backed
+
+ ShuffleChecksum cs = new ShuffleChecksum();
+ cs.updateChecksum(buf, 0, data.length);
+ long value1 = cs.getValue();
+ assertTrue(value1 != 0);
+
+ cs.reset();
+ cs.updateChecksum(buf, 0, data.length);
+ assertEquals(value1, cs.getValue(), "Checksum should be deterministic");
+ buf.release();
+ }
+
+ @Test
+ public void testShuffleChecksumDirectBuffer() {
+ byte[] data = {10, 20, 30};
+ ByteBuf heap = Unpooled.wrappedBuffer(data);
+ ByteBuf direct = Unpooled.directBuffer(data.length);
+ direct.writeBytes(data);
+
+ ShuffleChecksum csHeap = new ShuffleChecksum();
+ csHeap.updateChecksum(heap, 0, data.length);
+
+ ShuffleChecksum csDirect = new ShuffleChecksum();
+ csDirect.updateChecksum(direct, 0, data.length);
+
+ assertEquals(csHeap.getValue(), csDirect.getValue(),
+ "Heap and direct buffer checksums should match");
+
+ heap.release();
+ direct.release();
+ }
+
+ @Test
+ public void testShuffleChecksumNonZeroStartIndex() {
+ // Verify that startIndex correctly offsets where the checksum begins, and that the
+ // result matches an independent reference computation. Cover heap and direct buffers.
+ byte[] data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ int startIndex = 3;
+ int dataLength = 5; // bytes 3..7 inclusive
+
+ // Reference CRC32C computed over the same sub-range.
+ java.util.zip.CRC32C reference = new java.util.zip.CRC32C();
+ reference.update(data, startIndex, dataLength);
+ long expected = reference.getValue();
+
+ ByteBuf heap = Unpooled.wrappedBuffer(data);
+ ShuffleChecksum csHeap = new ShuffleChecksum();
+ csHeap.updateChecksum(heap, startIndex, dataLength);
+ assertEquals(expected, csHeap.getValue(),
+ "Heap buffer checksum should match reference for sub-range");
+ heap.release();
+
+ ByteBuf direct = Unpooled.directBuffer(data.length);
+ direct.writeBytes(data);
+ ShuffleChecksum csDirect = new ShuffleChecksum();
+ csDirect.updateChecksum(direct, startIndex, dataLength);
+ assertEquals(expected, csDirect.getValue(),
+ "Direct buffer checksum should match reference for sub-range");
+ direct.release();
+ }
+
+ @Test
+ public void testShuffleChecksumRejectsOutOfBoundsRange() {
+ // 5 bytes written, but the test asks for bytes [3..8). Should fail because
+ // startIndex (3) + dataLength (5) = 8 > writerIndex (5).
+ byte[] data = {1, 2, 3, 4, 5};
+ ByteBuf buf = Unpooled.wrappedBuffer(data);
+ try {
+ ShuffleChecksum cs = new ShuffleChecksum();
+ assertThrows(IllegalArgumentException.class,
+ () -> cs.updateChecksum(buf, 3, 5));
+ } finally {
+ buf.release();
+ }
+ }
+
+ @Test
+ public void testShuffleChecksumRejectsRangeBeyondWriterIndexButWithinCapacity() {
+ // capacity = 10 but only 5 bytes written. Requesting [0..8) should fail because
+ // 8 > writerIndex (5), even though 8 <= capacity (10). Guards against checksumming
+ // uninitialized bytes.
+ ByteBuf buf = Unpooled.buffer(10);
+ buf.writeBytes(new byte[]{1, 2, 3, 4, 5});
+ try {
+ ShuffleChecksum cs = new ShuffleChecksum();
+ assertThrows(IllegalArgumentException.class,
+ () -> cs.updateChecksum(buf, 0, 8));
+ } finally {
+ buf.release();
+ }
+ }
+
+ @Test
+ public void testShuffleChecksumRejectsNegativeStartIndex() {
+ ByteBuf buf = Unpooled.wrappedBuffer(new byte[]{1, 2, 3});
+ try {
+ ShuffleChecksum cs = new ShuffleChecksum();
+ assertThrows(IllegalArgumentException.class,
+ () -> cs.updateChecksum(buf, -1, 2));
+ } finally {
+ buf.release();
+ }
+ }
+
+ @Test
+ public void testShuffleChecksumRejectsNegativeDataLength() {
+ ByteBuf buf = Unpooled.wrappedBuffer(new byte[]{1, 2, 3});
+ try {
+ ShuffleChecksum cs = new ShuffleChecksum();
+ assertThrows(IllegalArgumentException.class,
+ () -> cs.updateChecksum(buf, 0, -1));
+ } finally {
+ buf.release();
+ }
+ }
+
+ @Test
+ public void testInvalidMessageTypeThrows() {
+ ByteBuf buf = Unpooled.buffer(12);
+ buf.writeInt(999); // unknown type
+ buf.writeLong(0L);
+ assertThrows(IllegalArgumentException.class, () -> StreamingShuffleMessage.decode(buf));
+ buf.release();
+ }
+}
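The `ShuffleChecksum` tests above pin down its contract: a CRC32C (the reference test compares against `java.util.zip.CRC32C`) over `[startIndex, startIndex + dataLength)`, identical results for heap and direct buffers, and rejection of negative arguments or of ranges past the writer index even when capacity would allow them. A minimal sketch of that contract follows. `RangeChecksum` is a hypothetical name, `java.nio.ByteBuffer` replaces Netty's `ByteBuf`, and the buffer's `limit()` plays the role of `writerIndex()`; it is not the Spark implementation.

```java
import java.nio.ByteBuffer;
import java.util.zip.CRC32C;

// Hypothetical stand-in for ShuffleChecksum, operating on java.nio.ByteBuffer.
final class RangeChecksum {
  private final CRC32C crc = new CRC32C();

  // Folds bytes [startIndex, startIndex + dataLength) of buf into the checksum.
  // buf.limit() acts as the writer index; buf's own position and limit are untouched.
  void update(ByteBuffer buf, int startIndex, int dataLength) {
    if (startIndex < 0 || dataLength < 0) {
      throw new IllegalArgumentException(
          "Negative range: startIndex=" + startIndex + ", dataLength=" + dataLength);
    }
    // Subtract instead of adding so startIndex + dataLength cannot overflow an int.
    if (dataLength > buf.limit() - startIndex) {
      throw new IllegalArgumentException("Range [" + startIndex + ", "
          + ((long) startIndex + dataLength) + ") exceeds written bytes " + buf.limit());
    }
    // Work on a duplicate so the caller's position and limit are preserved.
    ByteBuffer view = buf.duplicate();
    view.limit(startIndex + dataLength);
    view.position(startIndex);
    crc.update(view); // CRC32C handles both heap and direct buffers
  }

  long getValue() {
    return crc.getValue();
  }

  void reset() {
    crc.reset();
  }
}
```

Checking `dataLength > limit - startIndex` rather than `startIndex + dataLength > limit` keeps the bounds test correct even when both arguments are close to `Integer.MAX_VALUE`.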