This is an automated email from the ASF dual-hosted git repository. scwhittle pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new 6ce104769c4 Change StateBackedIterable to implement ElementByteSizeObservableIterable avoiding iteration to estimate observe bytes. (#29517) 6ce104769c4 is described below commit 6ce104769c45b56a01760f8e6574e2290cd7c4e8 Author: Sam Whittle <scwhit...@users.noreply.github.com> AuthorDate: Thu Nov 23 12:13:05 2023 +0100 Change StateBackedIterable to implement ElementByteSizeObservableIterable avoiding iteration to estimate observe bytes. (#29517) * Change StateBackedIterable to implement ElementByteSizeObservableIterable reducing byte estimation costs. --- .../beam/fn/harness/state/StateBackedIterable.java | 87 +++++++++++++++++++++- .../fn/harness/state/StateBackedIterableTest.java | 58 +++++++++++++++ 2 files changed, 142 insertions(+), 3 deletions(-) diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateBackedIterable.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateBackedIterable.java index 9c95e9ad90e..22e0822b619 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateBackedIterable.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateBackedIterable.java @@ -43,12 +43,17 @@ import org.apache.beam.sdk.fn.stream.PrefetchableIterable; import org.apache.beam.sdk.fn.stream.PrefetchableIterators; import org.apache.beam.sdk.util.BufferedElementCountingOutputStream; import org.apache.beam.sdk.util.VarInt; +import org.apache.beam.sdk.util.common.ElementByteSizeObservableIterable; +import org.apache.beam.sdk.util.common.ElementByteSizeObservableIterator; +import org.apache.beam.sdk.util.common.ElementByteSizeObserver; import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * A {@link BeamFnStateClient state} backed iterable which allows for fetching elements over the @@ -62,12 +67,17 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams @SuppressWarnings({ "rawtypes" // TODO(https://github.com/apache/beam/issues/20447) }) -public class StateBackedIterable<T> implements Iterable<T>, Serializable { +public class StateBackedIterable<T> + extends ElementByteSizeObservableIterable<T, ElementByteSizeObservableIterator<T>> + implements Serializable { + private static final Logger LOG = LoggerFactory.getLogger(StateBackedIterable.class); @VisibleForTesting final StateRequest request; @VisibleForTesting final List<T> prefix; private final transient PrefetchableIterable<T> suffix; + private final org.apache.beam.sdk.coders.Coder<T> elemCoder; + public StateBackedIterable( Cache<?, ?> cache, BeamFnStateClient beamFnStateClient, @@ -81,11 +91,82 @@ public class StateBackedIterable<T> implements Iterable<T>, Serializable { this.suffix = StateFetchingIterators.readAllAndDecodeStartingFrom( Caches.subCache(cache, stateKey), beamFnStateClient, request, elemCoder); + this.elemCoder = elemCoder; + } + + @SuppressWarnings("nullness") + private static class WrappedObservingIterator<T> extends ElementByteSizeObservableIterator<T> { + private final Iterator<T> wrappedIterator; + private final org.apache.beam.sdk.coders.Coder<T> elementCoder; + + // Logically final and non-null but initialized after construction by factory method for + // initialization ordering. + private ElementByteSizeObserver observerProxy = null; + + private boolean observerNeedsAdvance = false; + private boolean exceptionLogged = false; + + static <T> WrappedObservingIterator<T> create( + Iterator<T> iterator, org.apache.beam.sdk.coders.Coder<T> elementCoder) { + WrappedObservingIterator<T> result = new WrappedObservingIterator<>(iterator, elementCoder); + result.observerProxy = + new ElementByteSizeObserver() { + @Override + protected void reportElementSize(long elementByteSize) { + result.notifyValueReturned(elementByteSize); + } + }; + return result; + } + + private WrappedObservingIterator( + Iterator<T> iterator, org.apache.beam.sdk.coders.Coder<T> elementCoder) { + this.wrappedIterator = iterator; + this.elementCoder = elementCoder; + } + + @Override + public boolean hasNext() { + if (observerNeedsAdvance) { + observerProxy.advance(); + observerNeedsAdvance = false; + } + return wrappedIterator.hasNext(); + } + + @Override + public T next() { + T value = wrappedIterator.next(); + try { + elementCoder.registerByteSizeObserver(value, observerProxy); + if (observerProxy.getIsLazy()) { + // The observer will only be notified of bytes as the result + // is used. We defer advancing the observer until hasNext in an + // attempt to capture those bytes. + observerNeedsAdvance = true; + } else { + observerNeedsAdvance = false; + observerProxy.advance(); + } + } catch (Exception e) { + if (!exceptionLogged) { + LOG.warn("Lazily observed byte size will be under reported due to exception", e); + exceptionLogged = true; + } + } + return value; + } + + @Override + public void remove() { + super.remove(); + } } @Override - public Iterator<T> iterator() { - return PrefetchableIterators.concat(prefix.iterator(), suffix.iterator()); + protected ElementByteSizeObservableIterator<T> createIterator() { + return WrappedObservingIterator.create( + PrefetchableIterators.concat(prefix.iterator(), suffix.iterator()), elemCoder); } protected Object writeReplace() throws ObjectStreamException { diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateBackedIterableTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateBackedIterableTest.java index 4d53bcaef11..f758c367f73 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateBackedIterableTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateBackedIterableTest.java @@ -19,6 +19,7 @@ package org.apache.beam.fn.harness.state; import static java.util.Arrays.asList; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -36,11 +37,13 @@ import org.apache.beam.fn.harness.Caches; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.util.ByteStringOutputStream; +import org.apache.beam.sdk.util.common.ElementByteSizeObserver; import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.FluentIterable; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Streams; import org.junit.Test; import org.junit.experimental.runners.Enclosed; import org.junit.runner.RunWith; @@ -213,6 +216,61 @@ public class StateBackedIterableTest { } } } + + private static class TestByteObserver extends ElementByteSizeObserver { + public long total = 0; + + @Override + protected void reportElementSize(long elementByteSize) { + total += elementByteSize; + } + }; + + @Test + public void testByteObservingStateBackedIterable() throws Exception { + FakeBeamFnStateClient fakeBeamFnStateClient = + new FakeBeamFnStateClient( + StringUtf8Coder.of(), + ImmutableMap.of( + key("nonEmptySuffix"), asList("C", "D", "E", "F", "G", "H", "I", "J", "K"), + key("emptySuffix"), asList())); + + StateBackedIterable<String> iterable = + new StateBackedIterable<>( + Caches.noop(), + fakeBeamFnStateClient, + "instruction", + key(suffixKey), + StringUtf8Coder.of(), + prefix); + StateBackedIterable.Coder<String> coder = + new StateBackedIterable.Coder<>( + () -> Caches.noop(), + fakeBeamFnStateClient, + () -> "instructionId", + StringUtf8Coder.of()); + + assertTrue(coder.isRegisterByteSizeObserverCheap(iterable)); + TestByteObserver observer = new TestByteObserver(); + coder.registerByteSizeObserver(iterable, observer); + assertTrue(observer.getIsLazy()); + + long iterateBytes = + Streams.stream(iterable) + .mapToLong( + s -> { + try { + // 1 comes from hasNext = true flag (see IterableLikeCoder) + return 1 + StringUtf8Coder.of().getEncodedElementByteSize(s); + } catch (Exception e) { + throw new RuntimeException(e); + } + }) + .sum(); + observer.advance(); + // 5 comes from size and hasNext (see IterableLikeCoder) + assertEquals(iterateBytes + 5, observer.total); + } } @RunWith(JUnit4.class)