This is an automated email from the ASF dual-hosted git repository.
scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 8e90c36ff1b [Java] Optimize StreamUtils to avoid copying bytes (#37912)
8e90c36ff1b is described below
commit 8e90c36ff1b6953cd792cdeb20d82853aee22ca0
Author: Sam Whittle <[email protected]>
AuthorDate: Wed Mar 25 13:17:50 2026 +0000
[Java] Optimize StreamUtils to avoid copying bytes (#37912)
* [Java] Optimize StreamUtils to avoid copying bytes in two cases:
- the common case where available() matches the actual size of the input
stream
- fix the fast-path optimization for known input stream types when they are
wrapped by UnownedInputStream
---
.../java/org/apache/beam/sdk/util/StreamUtils.java | 47 +++++++++--
.../apache/beam/sdk/util/UnownedInputStream.java | 4 +
.../org/apache/beam/sdk/util/CoderUtilsTest.java | 10 +++
.../org/apache/beam/sdk/util/StreamUtilsTest.java | 95 ++++++++++++++++++++++
4 files changed, 150 insertions(+), 6 deletions(-)
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/StreamUtils.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/StreamUtils.java
index 28c604361fd..16396551335 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/StreamUtils.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/StreamUtils.java
@@ -22,6 +22,8 @@ import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.ref.SoftReference;
+import java.util.Arrays;
+import javax.annotation.Nullable;
import org.apache.beam.sdk.annotations.Internal;
/** Utility functions for stream operations. */
@@ -35,34 +37,67 @@ public class StreamUtils {
private static final int BUF_SIZE = 8192;
- private static ThreadLocal<SoftReference<byte[]>> threadLocalBuffer = new
ThreadLocal<>();
+ private static final ThreadLocal<SoftReference<byte[]>> threadLocalBuffer =
new ThreadLocal<>();
/** Efficient converting stream to bytes. */
public static byte[] getBytesWithoutClosing(InputStream stream) throws
IOException {
+ // Unwrap the stream so the below optimizations based upon class type
function properly.
+ // We don't use mark or reset in this function.
+ while (stream instanceof UnownedInputStream) {
+ stream = ((UnownedInputStream) stream).getWrappedStream();
+ }
+
if (stream instanceof ExposedByteArrayInputStream) {
// Fast path for the exposed version.
return ((ExposedByteArrayInputStream) stream).readAll();
- } else if (stream instanceof ByteArrayInputStream) {
+ }
+ if (stream instanceof ByteArrayInputStream) {
// Fast path for ByteArrayInputStream.
byte[] ret = new byte[stream.available()];
stream.read(ret);
return ret;
}
- // Falls back to normal stream copying.
+
+ // Most inputs are fully available so we attempt to first read directly
+ // into a buffer of the right size, assuming available reflects all the
bytes.
+ int available = stream.available();
+ @Nullable ByteArrayOutputStream outputStream = null;
+ if (available > 0 && available < 1024 * 1024) {
+ byte[] initialBuffer = new byte[available];
+ int initialReadSize = stream.read(initialBuffer);
+ if (initialReadSize == -1) {
+ return new byte[0];
+ }
+ int nextByte = stream.read();
+ if (nextByte == -1) {
+ if (initialReadSize == available) {
+ // Available reflected the full buffer and we copied directly to the
+ // right size.
+ return initialBuffer;
+ }
+ return Arrays.copyOf(initialBuffer, initialReadSize);
+ }
+ outputStream = new ByteArrayOutputStream();
+ outputStream.write(initialBuffer, 0, initialReadSize);
+ outputStream.write(nextByte);
+ } else {
+ outputStream = new ByteArrayOutputStream();
+ }
+
+ // Normal stream copying using the thread-local buffer.
SoftReference<byte[]> refBuffer = threadLocalBuffer.get();
byte[] buffer = refBuffer == null ? null : refBuffer.get();
if (buffer == null) {
buffer = new byte[BUF_SIZE];
threadLocalBuffer.set(new SoftReference<>(buffer));
}
- ByteArrayOutputStream outStream = new ByteArrayOutputStream();
while (true) {
int r = stream.read(buffer);
if (r == -1) {
break;
}
- outStream.write(buffer, 0, r);
+ outputStream.write(buffer, 0, r);
}
- return outStream.toByteArray();
+ return outputStream.toByteArray();
}
}
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/UnownedInputStream.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/UnownedInputStream.java
index acf70ed6b00..345e6a8763b 100644
---
a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/UnownedInputStream.java
+++
b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/UnownedInputStream.java
@@ -35,6 +35,10 @@ public class UnownedInputStream extends FilterInputStream {
super(delegate);
}
+ InputStream getWrappedStream() {
+ return in;
+ }
+
@Override
public void close() throws IOException {
throw new UnsupportedOperationException(
diff --git
a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/CoderUtilsTest.java
b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/CoderUtilsTest.java
index 13208160181..943ac5ddc8c 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/CoderUtilsTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/CoderUtilsTest.java
@@ -18,6 +18,7 @@
package org.apache.beam.sdk.util;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
@@ -26,7 +27,9 @@ import static org.mockito.Mockito.mock;
import java.io.InputStream;
import java.io.OutputStream;
+import java.nio.charset.StandardCharsets;
import org.apache.beam.sdk.coders.AtomicCoder;
+import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.Coder.Context;
import org.apache.beam.sdk.coders.CoderException;
@@ -142,4 +145,11 @@ public class CoderUtilsTest {
CoderException.class,
() -> CoderUtils.decodeFromByteString(StringUtf8Coder.of(),
byteString, Context.NESTED));
}
+
+ @Test
+ public void testDecodeByteArrayWithoutCopy() throws Exception {
+ byte[] data = "test data".getBytes(StandardCharsets.UTF_8);
+ byte[] result = CoderUtils.decodeFromByteArray(ByteArrayCoder.of(), data);
+ assertSame(data, result);
+ }
}
diff --git
a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/StreamUtilsTest.java
b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/StreamUtilsTest.java
index 68f87d73763..c081c0a33e6 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/StreamUtilsTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/StreamUtilsTest.java
@@ -23,9 +23,11 @@ import static org.junit.Assert.assertSame;
import java.io.BufferedInputStream;
import java.io.ByteArrayInputStream;
+import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -68,4 +70,97 @@ public class StreamUtilsTest {
assertArrayEquals(testData, bytes);
assertEquals(0, stream.available());
}
+
+ @Test
+ public void testGetBytesFromUnownedInputStreamAroundExposed() throws
IOException {
+ InputStream stream = new UnownedInputStream(new
ExposedByteArrayInputStream(testData));
+ byte[] bytes = StreamUtils.getBytesWithoutClosing(stream);
+ assertArrayEquals(testData, bytes);
+ assertSame(testData, bytes);
+ assertEquals(0, stream.available());
+ }
+
+ @Test
+ public void testGetBytesFromUnownedInputStreamAroundArray() throws
IOException {
+ InputStream stream = new UnownedInputStream(new
ByteArrayInputStream(testData));
+ byte[] bytes = StreamUtils.getBytesWithoutClosing(stream);
+ assertArrayEquals(testData, bytes);
+ assertEquals(0, stream.available());
+ }
+
+ @Test
+ public void testGetBytesFromLimitedInputStream() throws IOException {
+ InputStream stream = ByteStreams.limit(new ByteArrayInputStream(testData),
Integer.MAX_VALUE);
+ byte[] bytes = StreamUtils.getBytesWithoutClosing(stream);
+ assertArrayEquals(testData, bytes);
+ assertEquals(0, stream.available());
+ }
+
+ @Test
+ public void testGetBytesFromEmptyLimitedInputStream() throws IOException {
+ InputStream stream = ByteStreams.limit(new ByteArrayInputStream(testData),
0);
+ byte[] bytes = StreamUtils.getBytesWithoutClosing(stream);
+ assertArrayEquals(new byte[0], bytes);
+ assertEquals(0, stream.available());
+ }
+
+ @Test
+ public void testGetBytesFromRepeatedInputStream() throws IOException {
+ byte[] largeBytes = new byte[2 * 1024 * 1024];
+ Arrays.fill(largeBytes, (byte) 1);
+ InputStream stream = ByteStreams.limit(new
ByteArrayInputStream(largeBytes), Integer.MAX_VALUE);
+ byte[] bytes = StreamUtils.getBytesWithoutClosing(stream);
+ assertArrayEquals(largeBytes, bytes);
+ assertEquals(0, stream.available());
+ }
+
+ public static class LyingInputStream extends FilterInputStream {
+ private final int availableLie;
+
+ public LyingInputStream(InputStream in, int availableLie) {
+ super(in);
+ this.availableLie = availableLie;
+ }
+
+ @Override
+ public int available() throws IOException {
+ return availableLie;
+ }
+ }
+
+ @Test
+ public void testGetBytesFromHugeAvailable() throws IOException {
+ InputStream wrappedStream = new ByteArrayInputStream(testData);
+ InputStream stream = new LyingInputStream(wrappedStream, Integer.MAX_VALUE
- 1);
+ byte[] bytes = StreamUtils.getBytesWithoutClosing(stream);
+ assertArrayEquals(testData, bytes);
+ assertEquals(0, wrappedStream.available());
+ }
+
+ @Test
+ public void testGetBytesFromZeroAvailable() throws IOException {
+ InputStream wrappedStream = new ByteArrayInputStream(testData);
+ InputStream stream = new LyingInputStream(wrappedStream, 0);
+ byte[] bytes = StreamUtils.getBytesWithoutClosing(stream);
+ assertArrayEquals(testData, bytes);
+ assertEquals(0, wrappedStream.available());
+ }
+
+ @Test
+ public void testGetBytesFromOneExtraAvailable() throws IOException {
+ InputStream wrappedStream = new ByteArrayInputStream(testData);
+ InputStream stream = new LyingInputStream(wrappedStream,
wrappedStream.available() + 1);
+ byte[] bytes = StreamUtils.getBytesWithoutClosing(stream);
+ assertArrayEquals(testData, bytes);
+ assertEquals(0, wrappedStream.available());
+ }
+
+ @Test
+ public void testGetBytesFromOneLessAvailable() throws IOException {
+ InputStream wrappedStream = new ByteArrayInputStream(testData);
+ InputStream stream = new LyingInputStream(wrappedStream,
wrappedStream.available() - 1);
+ byte[] bytes = StreamUtils.getBytesWithoutClosing(stream);
+ assertArrayEquals(testData, bytes);
+ assertEquals(0, wrappedStream.available());
+ }
}