This is an automated email from the ASF dual-hosted git repository.

scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 8e90c36ff1b [Java] Optimize StreamUtils to avoid copying bytes (#37912)
8e90c36ff1b is described below

commit 8e90c36ff1b6953cd792cdeb20d82853aee22ca0
Author: Sam Whittle <[email protected]>
AuthorDate: Wed Mar 25 13:17:50 2026 +0000

    [Java] Optimize StreamUtils to avoid copying bytes (#37912)
    
    * [Java] Optimize StreamUtils to avoid copying bytes:
    - in the common case where available() matches the actual size of the input stream
    - fix the optimization for known input stream types when they are wrapped by UnownedInputStream
---
 .../java/org/apache/beam/sdk/util/StreamUtils.java | 47 +++++++++--
 .../apache/beam/sdk/util/UnownedInputStream.java   |  4 +
 .../org/apache/beam/sdk/util/CoderUtilsTest.java   | 10 +++
 .../org/apache/beam/sdk/util/StreamUtilsTest.java  | 95 ++++++++++++++++++++++
 4 files changed, 150 insertions(+), 6 deletions(-)

diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/StreamUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/StreamUtils.java
index 28c604361fd..16396551335 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/StreamUtils.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/StreamUtils.java
@@ -22,6 +22,8 @@ import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.io.InputStream;
 import java.lang.ref.SoftReference;
+import java.util.Arrays;
+import javax.annotation.Nullable;
 import org.apache.beam.sdk.annotations.Internal;
 
 /** Utility functions for stream operations. */
@@ -35,34 +37,67 @@ public class StreamUtils {
 
   private static final int BUF_SIZE = 8192;
 
-  private static ThreadLocal<SoftReference<byte[]>> threadLocalBuffer = new 
ThreadLocal<>();
+  private static final ThreadLocal<SoftReference<byte[]>> threadLocalBuffer = 
new ThreadLocal<>();
 
   /** Efficient converting stream to bytes. */
   public static byte[] getBytesWithoutClosing(InputStream stream) throws 
IOException {
+    // Unwrap the stream so the below optimizations based upon class type 
function properly.
+    // We don't use mark or reset in this function.
+    while (stream instanceof UnownedInputStream) {
+      stream = ((UnownedInputStream) stream).getWrappedStream();
+    }
+
     if (stream instanceof ExposedByteArrayInputStream) {
       // Fast path for the exposed version.
       return ((ExposedByteArrayInputStream) stream).readAll();
-    } else if (stream instanceof ByteArrayInputStream) {
+    }
+    if (stream instanceof ByteArrayInputStream) {
       // Fast path for ByteArrayInputStream.
       byte[] ret = new byte[stream.available()];
       stream.read(ret);
       return ret;
     }
-    // Falls back to normal stream copying.
+
+    // Most inputs are fully available so we attempt to first read directly
+    // into a buffer of the right size, assuming available reflects all the 
bytes.
+    int available = stream.available();
+    @Nullable ByteArrayOutputStream outputStream = null;
+    if (available > 0 && available < 1024 * 1024) {
+      byte[] initialBuffer = new byte[available];
+      int initialReadSize = stream.read(initialBuffer);
+      if (initialReadSize == -1) {
+        return new byte[0];
+      }
+      int nextByte = stream.read();
+      if (nextByte == -1) {
+        if (initialReadSize == available) {
+          // Available reflected the full buffer and we copied directly to the
+          // right size.
+          return initialBuffer;
+        }
+        return Arrays.copyOf(initialBuffer, initialReadSize);
+      }
+      outputStream = new ByteArrayOutputStream();
+      outputStream.write(initialBuffer, 0, initialReadSize);
+      outputStream.write(nextByte);
+    } else {
+      outputStream = new ByteArrayOutputStream();
+    }
+
+    // Normal stream copying using the thread-local buffer.
     SoftReference<byte[]> refBuffer = threadLocalBuffer.get();
     byte[] buffer = refBuffer == null ? null : refBuffer.get();
     if (buffer == null) {
       buffer = new byte[BUF_SIZE];
       threadLocalBuffer.set(new SoftReference<>(buffer));
     }
-    ByteArrayOutputStream outStream = new ByteArrayOutputStream();
     while (true) {
       int r = stream.read(buffer);
       if (r == -1) {
         break;
       }
-      outStream.write(buffer, 0, r);
+      outputStream.write(buffer, 0, r);
     }
-    return outStream.toByteArray();
+    return outputStream.toByteArray();
   }
 }
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/UnownedInputStream.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/UnownedInputStream.java
index acf70ed6b00..345e6a8763b 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/UnownedInputStream.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/UnownedInputStream.java
@@ -35,6 +35,10 @@ public class UnownedInputStream extends FilterInputStream {
     super(delegate);
   }
 
+  InputStream getWrappedStream() {
+    return in;
+  }
+
   @Override
   public void close() throws IOException {
     throw new UnsupportedOperationException(
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/CoderUtilsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/CoderUtilsTest.java
index 13208160181..943ac5ddc8c 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/CoderUtilsTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/CoderUtilsTest.java
@@ -18,6 +18,7 @@
 package org.apache.beam.sdk.util;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertSame;
 import static org.junit.Assert.assertThrows;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyString;
@@ -26,7 +27,9 @@ import static org.mockito.Mockito.mock;
 
 import java.io.InputStream;
 import java.io.OutputStream;
+import java.nio.charset.StandardCharsets;
 import org.apache.beam.sdk.coders.AtomicCoder;
+import org.apache.beam.sdk.coders.ByteArrayCoder;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.Coder.Context;
 import org.apache.beam.sdk.coders.CoderException;
@@ -142,4 +145,11 @@ public class CoderUtilsTest {
         CoderException.class,
         () -> CoderUtils.decodeFromByteString(StringUtf8Coder.of(), 
byteString, Context.NESTED));
   }
+
+  @Test
+  public void testDecodeByteArrayWithoutCopy() throws Exception {
+    byte[] data = "test data".getBytes(StandardCharsets.UTF_8);
+    byte[] result = CoderUtils.decodeFromByteArray(ByteArrayCoder.of(), data);
+    assertSame(data, result);
+  }
 }
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/StreamUtilsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/StreamUtilsTest.java
index 68f87d73763..c081c0a33e6 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/StreamUtilsTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/StreamUtilsTest.java
@@ -23,9 +23,11 @@ import static org.junit.Assert.assertSame;
 
 import java.io.BufferedInputStream;
 import java.io.ByteArrayInputStream;
+import java.io.FilterInputStream;
 import java.io.IOException;
 import java.io.InputStream;
 import java.util.Arrays;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -68,4 +70,97 @@ public class StreamUtilsTest {
     assertArrayEquals(testData, bytes);
     assertEquals(0, stream.available());
   }
+
+  @Test
+  public void testGetBytesFromUnownedInputStreamAroundExposed() throws 
IOException {
+    InputStream stream = new UnownedInputStream(new 
ExposedByteArrayInputStream(testData));
+    byte[] bytes = StreamUtils.getBytesWithoutClosing(stream);
+    assertArrayEquals(testData, bytes);
+    assertSame(testData, bytes);
+    assertEquals(0, stream.available());
+  }
+
+  @Test
+  public void testGetBytesFromUnownedInputStreamAroundArray() throws 
IOException {
+    InputStream stream = new UnownedInputStream(new 
ByteArrayInputStream(testData));
+    byte[] bytes = StreamUtils.getBytesWithoutClosing(stream);
+    assertArrayEquals(testData, bytes);
+    assertEquals(0, stream.available());
+  }
+
+  @Test
+  public void testGetBytesFromLimitedInputStream() throws IOException {
+    InputStream stream = ByteStreams.limit(new ByteArrayInputStream(testData), 
Integer.MAX_VALUE);
+    byte[] bytes = StreamUtils.getBytesWithoutClosing(stream);
+    assertArrayEquals(testData, bytes);
+    assertEquals(0, stream.available());
+  }
+
+  @Test
+  public void testGetBytesFromEmptyLimitedInputStream() throws IOException {
+    InputStream stream = ByteStreams.limit(new ByteArrayInputStream(testData), 
0);
+    byte[] bytes = StreamUtils.getBytesWithoutClosing(stream);
+    assertArrayEquals(new byte[0], bytes);
+    assertEquals(0, stream.available());
+  }
+
+  @Test
+  public void testGetBytesFromRepeatedInputStream() throws IOException {
+    byte[] largeBytes = new byte[2 * 1024 * 1024];
+    Arrays.fill(largeBytes, (byte) 1);
+    InputStream stream = ByteStreams.limit(new 
ByteArrayInputStream(largeBytes), Integer.MAX_VALUE);
+    byte[] bytes = StreamUtils.getBytesWithoutClosing(stream);
+    assertArrayEquals(largeBytes, bytes);
+    assertEquals(0, stream.available());
+  }
+
+  public static class LyingInputStream extends FilterInputStream {
+    private final int availableLie;
+
+    public LyingInputStream(InputStream in, int availableLie) {
+      super(in);
+      this.availableLie = availableLie;
+    }
+
+    @Override
+    public int available() throws IOException {
+      return availableLie;
+    }
+  }
+
+  @Test
+  public void testGetBytesFromHugeAvailable() throws IOException {
+    InputStream wrappedStream = new ByteArrayInputStream(testData);
+    InputStream stream = new LyingInputStream(wrappedStream, Integer.MAX_VALUE 
- 1);
+    byte[] bytes = StreamUtils.getBytesWithoutClosing(stream);
+    assertArrayEquals(testData, bytes);
+    assertEquals(0, wrappedStream.available());
+  }
+
+  @Test
+  public void testGetBytesFromZeroAvailable() throws IOException {
+    InputStream wrappedStream = new ByteArrayInputStream(testData);
+    InputStream stream = new LyingInputStream(wrappedStream, 0);
+    byte[] bytes = StreamUtils.getBytesWithoutClosing(stream);
+    assertArrayEquals(testData, bytes);
+    assertEquals(0, wrappedStream.available());
+  }
+
+  @Test
+  public void testGetBytesFromOneExtraAvailable() throws IOException {
+    InputStream wrappedStream = new ByteArrayInputStream(testData);
+    InputStream stream = new LyingInputStream(wrappedStream, 
wrappedStream.available() + 1);
+    byte[] bytes = StreamUtils.getBytesWithoutClosing(stream);
+    assertArrayEquals(testData, bytes);
+    assertEquals(0, wrappedStream.available());
+  }
+
+  @Test
+  public void testGetBytesFromOneLessAvailable() throws IOException {
+    InputStream wrappedStream = new ByteArrayInputStream(testData);
+    InputStream stream = new LyingInputStream(wrappedStream, 
wrappedStream.available() - 1);
+    byte[] bytes = StreamUtils.getBytesWithoutClosing(stream);
+    assertArrayEquals(testData, bytes);
+    assertEquals(0, wrappedStream.available());
+  }
 }

Reply via email to