This is an automated email from the ASF dual-hosted git repository.

scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 7356785e9ba Calculate byte size via sampling in StateBackedIterable if size is not cheap to calculate (#33780)
7356785e9ba is described below

commit 7356785e9bae219b24ab023252bfc6f3d54dce77
Author: Radosław Stankiewicz <[email protected]>
AuthorDate: Tue Feb 4 10:01:10 2025 +0100

    Calculate byte size via sampling in StateBackedIterable if size is not cheap to calculate (#33780)
---
 .../beam/fn/harness/state/StateBackedIterable.java | 42 +++++++++++++++++-----
 .../fn/harness/state/StateBackedIterableTest.java  |  4 ++-
 2 files changed, 36 insertions(+), 10 deletions(-)

diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateBackedIterable.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateBackedIterable.java
index 8030ca334ef..7b6a6195f32 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateBackedIterable.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateBackedIterable.java
@@ -30,6 +30,8 @@ import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+
+import java.util.concurrent.ThreadLocalRandom;
 import java.util.function.Supplier;
 import org.apache.beam.fn.harness.Cache;
 import org.apache.beam.fn.harness.Caches;
@@ -106,6 +108,11 @@ public class StateBackedIterable<T>
     private boolean observerNeedsAdvance = false;
     private boolean exceptionLogged = false;
 
+    // Lowest sampling probability: 0.001%.
+    private static final int SAMPLING_TOKEN_UPPER_BOUND = 1000000;
+    private static final int SAMPLING_CUTOFF = 10;
+    private int samplingToken = 0;
+
     static <T> WrappedObservingIterator<T> create(
         Iterator<T> iterator, org.apache.beam.sdk.coders.Coder<T> 
elementCoder) {
       WrappedObservingIterator<T> result = new 
WrappedObservingIterator<>(iterator, elementCoder);
@@ -125,6 +132,18 @@ public class StateBackedIterable<T>
       this.elementCoder = elementCoder;
     }
 
+    private boolean sampleElement() {
+      // Sampling probability decreases as the element count is increasing.
+      // We unconditionally sample the first samplingCutoff elements. For the
+      // next samplingCutoff elements, the sampling probability drops from 100%
+      // to 50%. The probability of sampling the Nth element is:
+      // min(1, samplingCutoff / N), with an additional lower bound of
+      // samplingCutoff / samplingTokenUpperBound. This algorithm may be 
refined
+      // later.
+      samplingToken = Math.min(samplingToken + 1, SAMPLING_TOKEN_UPPER_BOUND);
+      return ThreadLocalRandom.current().nextInt(samplingToken) < 
SAMPLING_CUTOFF;
+    }
+
     @Override
     public boolean hasNext() {
       if (observerNeedsAdvance) {
@@ -138,15 +157,20 @@ public class StateBackedIterable<T>
     public T next() {
       T value = wrappedIterator.next();
       try {
-        elementCoder.registerByteSizeObserver(value, observerProxy);
-        if (observerProxy.getIsLazy()) {
-          // The observer will only be notified of bytes as the result
-          // is used. We defer advancing the observer until hasNext in an
-          // attempt to capture those bytes.
-          observerNeedsAdvance = true;
-        } else {
-          observerNeedsAdvance = false;
-          observerProxy.advance();
+        boolean cheap = elementCoder.isRegisterByteSizeObserverCheap(value);
+        if (cheap || sampleElement()) {
+          observerProxy.setScalingFactor(
+              cheap ? 1.0 : Math.max(samplingToken, SAMPLING_CUTOFF) / 
(double) SAMPLING_CUTOFF);
+          elementCoder.registerByteSizeObserver(value, observerProxy);
+          if (observerProxy.getIsLazy()) {
+            // The observer will only be notified of bytes as the result
+            // is used. We defer advancing the observer until hasNext in an
+            // attempt to capture those bytes.
+            observerNeedsAdvance = true;
+          } else {
+            observerNeedsAdvance = false;
+            observerProxy.advance();
+          }
         }
       } catch (Exception e) {
         if (!exceptionLogged) {
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateBackedIterableTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateBackedIterableTest.java
index fdb373c269b..0e2598cf078 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateBackedIterableTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateBackedIterableTest.java
@@ -269,7 +269,9 @@ public class StateBackedIterableTest {
               .sum();
       observer.advance();
       // 5 comes from size and hasNext (see IterableLikeCoder)
-      assertEquals(iterateBytes + 5, observer.total);
+      // observer receives scaled, StringUtf8Coder is not cheap so sampling 
may produce value that
+      // is off
+      assertEquals((float) iterateBytes + 5, (float) observer.total, 3);
     }
   }
 

Reply via email to