This is an automated email from the ASF dual-hosted git repository.

scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 24efe7619d9 Track windmill current active work budget. (#30048)
24efe7619d9 is described below

commit 24efe7619d9a0e5fae70a6d69beabb2d23f362b1
Author: martin trieu <marti...@google.com>
AuthorDate: Thu Feb 8 02:00:15 2024 -0800

    Track windmill current active work budget. (#30048)
---
 .../dataflow/worker/StreamingDataflowWorker.java   |  6 +-
 .../dataflow/worker/streaming/ActiveWorkState.java | 79 +++++++++++++++-----
 .../runners/dataflow/worker/streaming/Work.java    |  2 -
 .../worker/streaming/ActiveWorkStateTest.java      | 83 +++++++++++++++++++---
 4 files changed, 142 insertions(+), 28 deletions(-)

diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
index 463ab953fae..e8ca3a2834f 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
@@ -1970,8 +1970,10 @@ public class StreamingDataflowWorker {
           failedWork
               .computeIfAbsent(heartbeatResponse.getShardingKey(), key -> new 
ArrayList<>())
               .add(
-                  new FailedTokens(
-                      heartbeatResponse.getWorkToken(), 
heartbeatResponse.getCacheToken()));
+                  FailedTokens.newBuilder()
+                      .setWorkToken(heartbeatResponse.getWorkToken())
+                      .setCacheToken(heartbeatResponse.getCacheToken())
+                      .build());
         }
       }
       ComputationState state = 
computationMap.get(computationHeartbeatResponse.getComputationId());
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java
index ff46356d956..b4b46932393 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java
@@ -19,6 +19,7 @@ package org.apache.beam.runners.dataflow.worker.streaming;
 
 import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList;
 
+import com.google.auto.value.AutoValue;
 import java.io.PrintWriter;
 import java.util.ArrayDeque;
 import java.util.Deque;
@@ -28,6 +29,7 @@ import java.util.Map;
 import java.util.Map.Entry;
 import java.util.Optional;
 import java.util.Queue;
+import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.BiConsumer;
 import java.util.stream.Stream;
 import javax.annotation.Nullable;
@@ -38,6 +40,7 @@ import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem;
 import 
org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
 import org.apache.beam.sdk.annotations.Internal;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
@@ -70,11 +73,19 @@ public final class ActiveWorkState {
   @GuardedBy("this")
   private final WindmillStateCache.ForComputation computationStateCache;
 
+  /**
+   * Current budget that is being processed or queued on the user worker. 
Incremented when work is
+   * activated in {@link #activateWorkForKey(ShardedKey, Work)}, and 
decremented when work is
+   * completed in {@link #completeWorkAndGetNextWorkForKey(ShardedKey, long)}.
+   */
+  private final AtomicReference<GetWorkBudget> activeGetWorkBudget;
+
   private ActiveWorkState(
       Map<ShardedKey, Deque<Work>> activeWork,
       WindmillStateCache.ForComputation computationStateCache) {
     this.activeWork = activeWork;
     this.computationStateCache = computationStateCache;
+    this.activeGetWorkBudget = new AtomicReference<>(GetWorkBudget.noBudget());
   }
 
   static ActiveWorkState create(WindmillStateCache.ForComputation 
computationStateCache) {
@@ -88,6 +99,12 @@ public final class ActiveWorkState {
     return new ActiveWorkState(activeWork, computationStateCache);
   }
 
+  private static String elapsedString(Instant start, Instant end) {
+    Duration activeFor = new Duration(start, end);
+    // Duration's toString always starts with "PT"; remove that here.
+    return activeFor.toString().substring(2);
+  }
+
   /**
    * Activates {@link Work} for the {@link ShardedKey}. Outcome can be 1 of 3 
{@link
    * ActivateWorkResult}
@@ -103,12 +120,12 @@ public final class ActiveWorkState {
    */
   synchronized ActivateWorkResult activateWorkForKey(ShardedKey shardedKey, 
Work work) {
     Deque<Work> workQueue = activeWork.getOrDefault(shardedKey, new 
ArrayDeque<>());
-
     // This key does not have any work queued up on it. Create one, insert 
Work, and mark the work
     // to be executed.
     if (!activeWork.containsKey(shardedKey) || workQueue.isEmpty()) {
       workQueue.addLast(work);
       activeWork.put(shardedKey, workQueue);
+      incrementActiveWorkBudget(work);
       return ActivateWorkResult.EXECUTE;
     }
 
@@ -121,16 +138,27 @@ public final class ActiveWorkState {
 
     // Queue the work for later processing.
     workQueue.addLast(work);
+    incrementActiveWorkBudget(work);
     return ActivateWorkResult.QUEUED;
   }
 
-  public static final class FailedTokens {
-    public long workToken;
-    public long cacheToken;
+  @AutoValue
+  public abstract static class FailedTokens {
+    public static Builder newBuilder() {
+      return new AutoValue_ActiveWorkState_FailedTokens.Builder();
+    }
+
+    public abstract long workToken();
+
+    public abstract long cacheToken();
+
+    @AutoValue.Builder
+    public abstract static class Builder {
+      public abstract Builder setWorkToken(long value);
 
-    public FailedTokens(long workToken, long cacheToken) {
-      this.workToken = workToken;
-      this.cacheToken = cacheToken;
+      public abstract Builder setCacheToken(long value);
+
+      public abstract FailedTokens build();
     }
   }
 
@@ -148,17 +176,17 @@ public final class ActiveWorkState {
       for (FailedTokens failedToken : failedTokens) {
         for (Work queuedWork : entry.getValue()) {
           WorkItem workItem = queuedWork.getWorkItem();
-          if (workItem.getWorkToken() == failedToken.workToken
-              && workItem.getCacheToken() == failedToken.cacheToken) {
+          if (workItem.getWorkToken() == failedToken.workToken()
+              && workItem.getCacheToken() == failedToken.cacheToken()) {
             LOG.debug(
                 "Failing work "
                     + computationStateCache.getComputation()
                     + " "
                     + entry.getKey().shardingKey()
                     + " "
-                    + failedToken.workToken
+                    + failedToken.workToken()
                     + " "
-                    + failedToken.cacheToken
+                    + failedToken.cacheToken()
                     + ". The work will be retried and is not lost.");
             queuedWork.setFailed();
             break;
@@ -168,6 +196,16 @@ public final class ActiveWorkState {
     }
   }
 
+  private void incrementActiveWorkBudget(Work work) {
+    activeGetWorkBudget.updateAndGet(
+        getWorkBudget -> getWorkBudget.apply(1, 
work.getWorkItem().getSerializedSize()));
+  }
+
+  private void decrementActiveWorkBudget(Work work) {
+    activeGetWorkBudget.updateAndGet(
+        getWorkBudget -> getWorkBudget.subtract(1, 
work.getWorkItem().getSerializedSize()));
+  }
+
   /**
    * Removes the complete work from the {@link Queue<Work>}. The {@link Work} 
is marked as completed
    * if its workToken matches the one that is passed in. Returns the next 
{@link Work} in the {@link
@@ -208,6 +246,7 @@ public final class ActiveWorkState {
 
     // We consumed the matching work item.
     workQueue.remove();
+    decrementActiveWorkBudget(completedWork);
   }
 
   private synchronized Optional<Work> getNextWork(Queue<Work> workQueue, 
ShardedKey shardedKey) {
@@ -285,6 +324,15 @@ public final class ActiveWorkState {
                     .build());
   }
 
+  /**
+   * Returns the current aggregate {@link GetWorkBudget} that is active on the 
user worker. Active
+   * means that the work is received from Windmill, being processed or queued 
to be processed in
+   * {@link ActiveWorkState}, and not committed back to Windmill.
+   */
+  GetWorkBudget currentActiveWorkBudget() {
+    return activeGetWorkBudget.get();
+  }
+
   synchronized void printActiveWork(PrintWriter writer, Instant now) {
     writer.println(
         "<table border=\"1\" "
@@ -328,12 +376,11 @@ public final class ActiveWorkState {
       writer.println(commitsPendingCount - MAX_PRINTABLE_COMMIT_PENDING_KEYS);
       writer.println("<br>");
     }
-  }
 
-  private static String elapsedString(Instant start, Instant end) {
-    Duration activeFor = new Duration(start, end);
-    // Duration's toString always starts with "PT"; remove that here.
-    return activeFor.toString().substring(2);
+    writer.println("<br>");
+    writer.println("Current Active Work Budget: ");
+    writer.println(currentActiveWorkBudget());
+    writer.println("<br>");
   }
 
   enum ActivateWorkResult {
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
index 8d4ba33a1ab..6c85c615af1 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
@@ -42,14 +42,12 @@ import org.joda.time.Instant;
 
 @NotThreadSafe
 public class Work implements Runnable {
-
   private final Windmill.WorkItem workItem;
   private final Supplier<Instant> clock;
   private final Instant startTime;
   private final Map<Windmill.LatencyAttribution.State, Duration> 
totalDurationPerState;
   private final Consumer<Work> processWorkFn;
   private TimedState currentState;
-
   private volatile boolean isFailed;
 
   private Work(Windmill.WorkItem workItem, Supplier<Instant> clock, 
Consumer<Work> processWorkFn) {
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java
 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java
index b384bb03185..82ff24c03bb 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java
@@ -32,12 +32,12 @@ import java.util.Deque;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Optional;
-import javax.annotation.Nullable;
 import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler;
 import 
org.apache.beam.runners.dataflow.worker.streaming.ActiveWorkState.ActivateWorkResult;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
 import 
org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
 import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
 import org.joda.time.Instant;
@@ -50,9 +50,9 @@ import org.junit.runners.JUnit4;
 
 @RunWith(JUnit4.class)
 public class ActiveWorkStateTest {
-  @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
   private final WindmillStateCache.ForComputation computationStateCache =
       mock(WindmillStateCache.ForComputation.class);
+  @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
   private Map<ShardedKey, Deque<Work>> readOnlyActiveWork;
 
   private ActiveWorkState activeWorkState;
@@ -61,11 +61,7 @@ public class ActiveWorkStateTest {
     return ShardedKey.create(ByteString.copyFromUtf8(str), shardKey);
   }
 
-  private static Work emptyWork() {
-    return createWork(null);
-  }
-
-  private static Work createWork(@Nullable Windmill.WorkItem workItem) {
+  private static Work createWork(Windmill.WorkItem workItem) {
     return Work.create(workItem, Instant::now, Collections.emptyList(), unused 
-> {});
   }
 
@@ -92,7 +88,8 @@ public class ActiveWorkStateTest {
   @Test
   public void testActivateWorkForKey_EXECUTE_unknownKey() {
     ActivateWorkResult activateWorkResult =
-        activeWorkState.activateWorkForKey(shardedKey("someKey", 1L), 
emptyWork());
+        activeWorkState.activateWorkForKey(
+            shardedKey("someKey", 1L), createWork(createWorkItem(1L)));
 
     assertEquals(ActivateWorkResult.EXECUTE, activateWorkResult);
   }
@@ -214,6 +211,76 @@ public class ActiveWorkStateTest {
     assertFalse(readOnlyActiveWork.containsKey(shardedKey));
   }
 
+  @Test
+  public void 
testCurrentActiveWorkBudget_correctlyAggregatesActiveWorkBudget_oneShardKey() {
+    ShardedKey shardedKey = shardedKey("someKey", 1L);
+    Work work1 = createWork(createWorkItem(1L));
+    Work work2 = createWork(createWorkItem(2L));
+
+    activeWorkState.activateWorkForKey(shardedKey, work1);
+    activeWorkState.activateWorkForKey(shardedKey, work2);
+
+    GetWorkBudget expectedActiveBudget1 =
+        GetWorkBudget.builder()
+            .setItems(2)
+            .setBytes(
+                work1.getWorkItem().getSerializedSize() + 
work2.getWorkItem().getSerializedSize())
+            .build();
+
+    
assertThat(activeWorkState.currentActiveWorkBudget()).isEqualTo(expectedActiveBudget1);
+
+    activeWorkState.completeWorkAndGetNextWorkForKey(
+        shardedKey, work1.getWorkItem().getWorkToken());
+    // work1 has been completed, so only work2 still counts toward the active budget.
+    GetWorkBudget expectedActiveBudget2 =
+        GetWorkBudget.builder()
+            .setItems(1)
+            .setBytes(work2.getWorkItem().getSerializedSize())
+            .build();
+
+    
assertThat(activeWorkState.currentActiveWorkBudget()).isEqualTo(expectedActiveBudget2);
+  }
+
+  @Test
+  public void 
testCurrentActiveWorkBudget_correctlyAggregatesActiveWorkBudget_whenWorkCompleted()
 {
+    ShardedKey shardedKey = shardedKey("someKey", 1L);
+    Work work1 = createWork(createWorkItem(1L));
+    Work work2 = createWork(createWorkItem(2L));
+
+    activeWorkState.activateWorkForKey(shardedKey, work1);
+    activeWorkState.activateWorkForKey(shardedKey, work2);
+    activeWorkState.completeWorkAndGetNextWorkForKey(
+        shardedKey, work1.getWorkItem().getWorkToken());
+    // work1 has been completed, so only work2 still counts toward the active budget.
+    GetWorkBudget expectedActiveBudget =
+        GetWorkBudget.builder()
+            .setItems(1)
+            .setBytes(work2.getWorkItem().getSerializedSize())
+            .build();
+
+    
assertThat(activeWorkState.currentActiveWorkBudget()).isEqualTo(expectedActiveBudget);
+  }
+
+  @Test
+  public void 
testCurrentActiveWorkBudget_correctlyAggregatesActiveWorkBudget_multipleShardKeys()
 {
+    ShardedKey shardedKey1 = shardedKey("someKey", 1L);
+    ShardedKey shardedKey2 = shardedKey("someKey", 2L);
+    Work work1 = createWork(createWorkItem(1L));
+    Work work2 = createWork(createWorkItem(2L));
+
+    activeWorkState.activateWorkForKey(shardedKey1, work1);
+    activeWorkState.activateWorkForKey(shardedKey2, work2);
+
+    GetWorkBudget expectedActiveBudget =
+        GetWorkBudget.builder()
+            .setItems(2)
+            .setBytes(
+                work1.getWorkItem().getSerializedSize() + 
work2.getWorkItem().getSerializedSize())
+            .build();
+
+    
assertThat(activeWorkState.currentActiveWorkBudget()).isEqualTo(expectedActiveBudget);
+  }
+
   @Test
   public void testInvalidateStuckCommits() {
     Map<ShardedKey, Long> invalidatedCommits = new HashMap<>();

Reply via email to