This is an automated email from the ASF dual-hosted git repository.

scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 6ce104769c4 Change StateBackedIterable to implement 
ElementByteSizeObservableIterable, avoiding iteration to estimate observed bytes. 
(#29517)
6ce104769c4 is described below

commit 6ce104769c45b56a01760f8e6574e2290cd7c4e8
Author: Sam Whittle <scwhit...@users.noreply.github.com>
AuthorDate: Thu Nov 23 12:13:05 2023 +0100

    Change StateBackedIterable to implement ElementByteSizeObservableIterable, 
avoiding iteration to estimate observed bytes. (#29517)
    
    * Change StateBackedIterable to implement ElementByteSizeObservableIterable 
reducing byte estimation costs.
---
 .../beam/fn/harness/state/StateBackedIterable.java | 87 +++++++++++++++++++++-
 .../fn/harness/state/StateBackedIterableTest.java  | 58 +++++++++++++++
 2 files changed, 142 insertions(+), 3 deletions(-)

diff --git 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateBackedIterable.java
 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateBackedIterable.java
index 9c95e9ad90e..22e0822b619 100644
--- 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateBackedIterable.java
+++ 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateBackedIterable.java
@@ -43,12 +43,17 @@ import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
 import org.apache.beam.sdk.fn.stream.PrefetchableIterators;
 import org.apache.beam.sdk.util.BufferedElementCountingOutputStream;
 import org.apache.beam.sdk.util.VarInt;
+import org.apache.beam.sdk.util.common.ElementByteSizeObservableIterable;
+import org.apache.beam.sdk.util.common.ElementByteSizeObservableIterator;
+import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
 import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 /**
  * A {@link BeamFnStateClient state} backed iterable which allows for fetching 
elements over the
@@ -62,12 +67,17 @@ import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams
 @SuppressWarnings({
   "rawtypes" // TODO(https://github.com/apache/beam/issues/20447)
 })
-public class StateBackedIterable<T> implements Iterable<T>, Serializable {
+public class StateBackedIterable<T>
+    extends ElementByteSizeObservableIterable<T, 
ElementByteSizeObservableIterator<T>>
+    implements Serializable {
+  private static final Logger LOG = 
LoggerFactory.getLogger(StateBackedIterable.class);
 
   @VisibleForTesting final StateRequest request;
   @VisibleForTesting final List<T> prefix;
   private final transient PrefetchableIterable<T> suffix;
 
+  private final org.apache.beam.sdk.coders.Coder<T> elemCoder;
+
   public StateBackedIterable(
       Cache<?, ?> cache,
       BeamFnStateClient beamFnStateClient,
@@ -81,11 +91,82 @@ public class StateBackedIterable<T> implements Iterable<T>, 
Serializable {
     this.suffix =
         StateFetchingIterators.readAllAndDecodeStartingFrom(
             Caches.subCache(cache, stateKey), beamFnStateClient, request, 
elemCoder);
+    this.elemCoder = elemCoder;
+  }
+
+  @SuppressWarnings("nullness")
+  private static class WrappedObservingIterator<T> extends 
ElementByteSizeObservableIterator<T> { // wraps an iterator and reports each returned element's size
+    private final Iterator<T> wrappedIterator;
+    private final org.apache.beam.sdk.coders.Coder<T> elementCoder;
+
+    // Logically final and non-null but initialized after construction by 
factory method for
+    // initialization ordering.
+    private ElementByteSizeObserver observerProxy = null;
+
+    private boolean observerNeedsAdvance = false; // set by next() for lazy observers; consumed by hasNext()
+    private boolean exceptionLogged = false; // limits the warning in next() to one per iterator
+
+    static <T> WrappedObservingIterator<T> create(
+        Iterator<T> iterator, org.apache.beam.sdk.coders.Coder<T> 
elementCoder) {
+      WrappedObservingIterator<T> result = new 
WrappedObservingIterator<>(iterator, elementCoder);
+      result.observerProxy =
+          new ElementByteSizeObserver() {
+            @Override
+            protected void reportElementSize(long elementByteSize) {
+              result.notifyValueReturned(elementByteSize); // forward the reported size to the wrapping iterator
+            }
+          };
+      return result;
+    }
+
+    private WrappedObservingIterator(
+        Iterator<T> iterator, org.apache.beam.sdk.coders.Coder<T> 
elementCoder) {
+      this.wrappedIterator = iterator;
+      this.elementCoder = elementCoder;
+    }
+
+    @Override
+    public boolean hasNext() {
+      if (observerNeedsAdvance) {
+        observerProxy.advance(); // deferred advance for the element returned by the previous next()
+        observerNeedsAdvance = false;
+      }
+      return wrappedIterator.hasNext();
+    }
+
+    @Override
+    public T next() {
+      T value = wrappedIterator.next();
+      try {
+        elementCoder.registerByteSizeObserver(value, observerProxy); // register the proxy with the coder for this element
+        if (observerProxy.getIsLazy()) {
+          // The observer will only be notified of bytes as the result
+          // is used. We defer advancing the observer until hasNext in an
+          // attempt to capture those bytes.
+          observerNeedsAdvance = true;
+        } else {
+          observerNeedsAdvance = false;
+          observerProxy.advance();
+        }
+      } catch (Exception e) { // size observation is best-effort; the value is still returned below
+        if (!exceptionLogged) {
+          LOG.warn("Lazily observed byte size will be under reported due to 
exception", e);
+          exceptionLogged = true;
+        }
+      }
+      return value;
+    }
+
+    @Override
+    public void remove() {
+      super.remove(); // no extra behavior; delegates to the superclass implementation
+    }
   }
 
   @Override
-  public Iterator<T> iterator() {
-    return PrefetchableIterators.concat(prefix.iterator(), suffix.iterator());
+  protected ElementByteSizeObservableIterator<T> createIterator() {
+    return WrappedObservingIterator.create(
+        PrefetchableIterators.concat(prefix.iterator(), suffix.iterator()), 
elemCoder);
   }
 
   protected Object writeReplace() throws ObjectStreamException {
diff --git 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateBackedIterableTest.java
 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateBackedIterableTest.java
index 4d53bcaef11..f758c367f73 100644
--- 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateBackedIterableTest.java
+++ 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateBackedIterableTest.java
@@ -19,6 +19,7 @@ package org.apache.beam.fn.harness.state;
 
 import static java.util.Arrays.asList;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
 
 import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
@@ -36,11 +37,13 @@ import org.apache.beam.fn.harness.Caches;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.util.ByteStringOutputStream;
+import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
 import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.FluentIterable;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Streams;
 import org.junit.Test;
 import org.junit.experimental.runners.Enclosed;
 import org.junit.runner.RunWith;
@@ -213,6 +216,61 @@ public class StateBackedIterableTest {
         }
       }
     }
+
+    private static class TestByteObserver extends ElementByteSizeObserver {
+      public long total = 0;
+
+      @Override
+      protected void reportElementSize(long elementByteSize) {
+        total += elementByteSize;
+      }
+    };
+
+    @Test
+    public void testByteObservingStateBackedIterable() throws Exception {
+      FakeBeamFnStateClient fakeBeamFnStateClient =
+          new FakeBeamFnStateClient(
+              StringUtf8Coder.of(),
+              ImmutableMap.of(
+                  key("nonEmptySuffix"), asList("C", "D", "E", "F", "G", "H", 
"I", "J", "K"),
+                  key("emptySuffix"), asList()));
+
+      StateBackedIterable<String> iterable =
+          new StateBackedIterable<>(
+              Caches.noop(),
+              fakeBeamFnStateClient,
+              "instruction",
+              key(suffixKey),
+              StringUtf8Coder.of(),
+              prefix);
+      StateBackedIterable.Coder<String> coder =
+          new StateBackedIterable.Coder<>(
+              () -> Caches.noop(),
+              fakeBeamFnStateClient,
+              () -> "instructionId",
+              StringUtf8Coder.of());
+
+      assertTrue(coder.isRegisterByteSizeObserverCheap(iterable));
+      TestByteObserver observer = new TestByteObserver();
+      coder.registerByteSizeObserver(iterable, observer);
+      assertTrue(observer.getIsLazy());
+
+      long iterateBytes =
+          Streams.stream(iterable)
+              .mapToLong(
+                  s -> {
+                    try {
+                      // 1 comes from hasNext = true flag (see 
IterableLikeCoder)
+                      return 1 + 
StringUtf8Coder.of().getEncodedElementByteSize(s);
+                    } catch (Exception e) {
+                      throw new RuntimeException(e);
+                    }
+                  })
+              .sum();
+      observer.advance();
+      // 5 comes from size and hasNext (see IterableLikeCoder)
+      assertEquals(iterateBytes + 5, observer.total);
+    }
   }
 
   @RunWith(JUnit4.class)

Reply via email to