This is an automated email from the ASF dual-hosted git repository.

scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new c298da550f8 Refactor commit logic out of StreamingDataflowWorker (#30312)
c298da550f8 is described below

commit c298da550f88b1fd489bc25c08a40d484833e428
Author: martin trieu <marti...@google.com>
AuthorDate: Fri Mar 15 02:01:18 2024 -0700

    Refactor commit logic out of StreamingDataflowWorker (#30312)
---
 .../dataflow/worker/StreamingDataflowWorker.java   | 237 +++-------------
 .../dataflow/worker/WindmillComputationKey.java    |   5 +
 .../client/CloseableStream.java}                   |  30 +-
 .../worker/windmill/client/WindmillStreamPool.java |   7 +
 .../client/commits}/Commit.java                    |  10 +-
 .../windmill/client/commits/CompleteCommit.java    |  67 +++++
 .../commits/StreamingApplianceWorkCommitter.java   | 167 +++++++++++
 .../commits/StreamingEngineWorkCommitter.java      | 233 ++++++++++++++++
 .../windmill/client/commits/WorkCommitter.java     |  54 ++++
 .../worker/windmill/state/WindmillStateCache.java  |   5 +
 .../dataflow/worker/FakeWindmillServer.java        |  32 ++-
 .../worker/StreamingDataflowWorkerTest.java        |   2 +-
 .../StreamingApplianceWorkCommitterTest.java       | 140 ++++++++++
 .../commits/StreamingEngineWorkCommitterTest.java  | 308 +++++++++++++++++++++
 14 files changed, 1077 insertions(+), 220 deletions(-)

diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
index 6f1bb0847bc..4c3ffd08a0b 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
@@ -87,17 +87,14 @@ import org.apache.beam.runners.dataflow.worker.status.DebugCapture.Capturable;
 import org.apache.beam.runners.dataflow.worker.status.LastExceptionDataProvider;
 import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider;
 import org.apache.beam.runners.dataflow.worker.status.WorkerStatusPages;
-import org.apache.beam.runners.dataflow.worker.streaming.Commit;
 import org.apache.beam.runners.dataflow.worker.streaming.ComputationState;
 import org.apache.beam.runners.dataflow.worker.streaming.ExecutionState;
 import org.apache.beam.runners.dataflow.worker.streaming.KeyCommitTooLargeException;
 import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey;
 import org.apache.beam.runners.dataflow.worker.streaming.StageInfo;
-import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue;
 import org.apache.beam.runners.dataflow.worker.streaming.Work;
 import org.apache.beam.runners.dataflow.worker.streaming.Work.State;
 import org.apache.beam.runners.dataflow.worker.streaming.WorkHeartbeatResponseProcessor;
-import org.apache.beam.runners.dataflow.worker.streaming.WorkId;
 import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher;
 import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor;
 import org.apache.beam.runners.dataflow.worker.util.MemoryMonitor;
@@ -110,9 +107,13 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribut
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub;
 import org.apache.beam.runners.dataflow.worker.windmill.appliance.JniWindmillApplianceServer;
-import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream;
 import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream;
 import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool;
+import org.apache.beam.runners.dataflow.worker.windmill.client.commits.Commit;
+import org.apache.beam.runners.dataflow.worker.windmill.client.commits.CompleteCommit;
+import org.apache.beam.runners.dataflow.worker.windmill.client.commits.StreamingApplianceWorkCommitter;
+import org.apache.beam.runners.dataflow.worker.windmill.client.commits.StreamingEngineWorkCommitter;
+import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter;
 import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.ChannelzServlet;
 import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcWindmillServer;
 import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcWindmillStreamFactory;
@@ -217,9 +218,6 @@ public class StreamingDataflowWorker {
   final WindmillStateCache stateCache;
   // Maps from computation ids to per-computation state.
   private final ConcurrentMap<String, ComputationState> computationMap;
-  private final WeightedBoundedQueue<Commit> commitQueue =
-      WeightedBoundedQueue.create(
-          MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES, commit.getSize()));
   // Cache of tokens to commit callbacks.
  // Using Cache with time eviction policy helps us to prevent memory leak when callback ids are
   // discarded by Dataflow service and calling commitCallback is best-effort.
@@ -234,8 +232,6 @@ public class StreamingDataflowWorker {
   private final BoundedQueueExecutor workUnitExecutor;
   private final WindmillServerStub windmillServer;
   private final Thread dispatchThread;
-  @VisibleForTesting final ImmutableList<Thread> commitThreads;
-  private final AtomicLong activeCommitBytes = new AtomicLong();
   private final AtomicLong previousTimeAtMaxThreads = new AtomicLong();
   private final AtomicBoolean running = new AtomicBoolean();
   private final SideInputStateFetcher sideInputStateFetcher;
@@ -296,6 +292,7 @@ public class StreamingDataflowWorker {
 
  private final DataflowExecutionStateSampler sampler = DataflowExecutionStateSampler.instance();
   private final ActiveWorkRefresher activeWorkRefresher;
+  private final WorkCommitter workCommitter;
 
   private StreamingDataflowWorker(
       WindmillServerStub windmillServer,
@@ -403,29 +400,6 @@ public class StreamingDataflowWorker {
     dispatchThread.setPriority(Thread.MIN_PRIORITY);
     dispatchThread.setName("DispatchThread");
 
-    int numCommitThreads = 1;
-    if (windmillServiceEnabled && options.getWindmillServiceCommitThreads() > 0) {
-      numCommitThreads = options.getWindmillServiceCommitThreads();
-    }
-
-    ImmutableList.Builder<Thread> commitThreadsBuilder = ImmutableList.builder();
-    for (int i = 0; i < numCommitThreads; ++i) {
-      Thread commitThread =
-          new Thread(
-              () -> {
-                if (windmillServiceEnabled) {
-                  streamingCommitLoop();
-                } else {
-                  commitLoop();
-                }
-              });
-      commitThread.setDaemon(true);
-      commitThread.setPriority(Thread.MAX_PRIORITY);
-      commitThread.setName("CommitThread " + i);
-      commitThreadsBuilder.add(commitThread);
-    }
-    commitThreads = commitThreadsBuilder.build();
-
     this.publishCounters = publishCounters;
     this.clientId = clientId;
     this.windmillServer = windmillServer;
@@ -438,6 +412,21 @@ public class StreamingDataflowWorker {
 
     this.sideInputStateFetcher =
        new SideInputStateFetcher(metricTrackingWindmillServer::getSideInputData, options);
+    int numCommitThreads = 1;
+    if (windmillServiceEnabled && options.getWindmillServiceCommitThreads() > 0) {
+      numCommitThreads = options.getWindmillServiceCommitThreads();
+    }
+
+    this.workCommitter =
+        windmillServiceEnabled
+            ? StreamingEngineWorkCommitter.create(
+                WindmillStreamPool.create(
+                        NUM_COMMIT_STREAMS, COMMIT_STREAM_TIMEOUT, windmillServer::commitWorkStream)
+                    ::getCloseableStream,
+                numCommitThreads,
+                this::onCompleteCommit)
+            : StreamingApplianceWorkCommitter.create(
+                windmillServer::commitWork, this::onCompleteCommit);
 
     // Register standard file systems.
     FileSystems.setDefaultPipelineOptions(options);
@@ -705,6 +694,11 @@ public class StreamingDataflowWorker {
     return workUnitExecutor.executorQueueIsEmpty();
   }
 
+  @VisibleForTesting
+  int numCommitThreads() {
+    return workCommitter.parallelism();
+  }
+
   @SuppressWarnings("FutureReturnValueIgnored")
   public void start() {
     running.set(true);
@@ -716,7 +710,6 @@ public class StreamingDataflowWorker {
 
     memoryMonitorThread.start();
     dispatchThread.start();
-    commitThreads.forEach(Thread::start);
     sampler.start();
 
     // Periodically report workers counters and other updates.
@@ -778,7 +771,7 @@ public class StreamingDataflowWorker {
           TimeUnit.SECONDS);
       scheduledExecutors.add(statusPageTimer);
     }
-
+    workCommitter.start();
     reportHarnessStartup();
   }
 
@@ -834,12 +827,8 @@ public class StreamingDataflowWorker {
       running.set(false);
       dispatchThread.interrupt();
       dispatchThread.join();
-      // We need to interrupt the commitThreads in case they are blocking on pulling
-      // from the commitQueue.
-      commitThreads.forEach(Thread::interrupt);
-      for (Thread commitThread : commitThreads) {
-        commitThread.join();
-      }
+
+      workCommitter.stop();
       memoryMonitor.stop();
       memoryMonitorThread.join();
       workUnitExecutor.shutdown();
@@ -1086,7 +1075,7 @@ public class StreamingDataflowWorker {
     if (workItem.getSourceState().getOnlyFinalize()) {
       outputBuilder.setSourceStateUpdates(Windmill.SourceState.newBuilder().setOnlyFinalize(true));
       work.setState(State.COMMIT_QUEUED);
-      commitQueue.put(Commit.create(outputBuilder.build(), computationState, work));
+      workCommitter.commit(Commit.create(outputBuilder.build(), computationState, work));
       return;
     }
 
@@ -1315,7 +1304,7 @@ public class StreamingDataflowWorker {
        commitRequest = buildWorkItemTruncationRequest(key, workItem, estimatedCommitSize);
      }
 
-      commitQueue.put(Commit.create(commitRequest, computationState, work));
+      workCommitter.commit(Commit.create(commitRequest, computationState, work));
 
      // Compute shuffle and state byte statistics these will be flushed asynchronously.
       long stateBytesWritten =
@@ -1444,163 +1433,21 @@ public class StreamingDataflowWorker {
     return outputBuilder.build();
   }
 
-  private void commitLoop() {
-    Map<ComputationState, Windmill.ComputationCommitWorkRequest.Builder> computationRequestMap =
-        new HashMap<>();
-    while (running.get()) {
-      computationRequestMap.clear();
-      Windmill.CommitWorkRequest.Builder commitRequestBuilder =
-          Windmill.CommitWorkRequest.newBuilder();
-      long commitBytes = 0;
-      // Block until we have a commit, then batch with additional commits.
-      Commit commit = null;
-      try {
-        commit = commitQueue.take();
-      } catch (InterruptedException e) {
-        Thread.currentThread().interrupt();
-        continue;
-      }
-      while (commit != null) {
-        ComputationState computationState = commit.computationState();
-        commit.work().setState(Work.State.COMMITTING);
-        Windmill.ComputationCommitWorkRequest.Builder computationRequestBuilder =
-            computationRequestMap.get(computationState);
-        if (computationRequestBuilder == null) {
-          computationRequestBuilder = commitRequestBuilder.addRequestsBuilder();
-          computationRequestBuilder.setComputationId(computationState.getComputationId());
-          computationRequestMap.put(computationState, computationRequestBuilder);
-        }
-        computationRequestBuilder.addRequests(commit.request());
-        // Send the request if we've exceeded the bytes or there is no more
-        // pending work.  commitBytes is a long, so this cannot overflow.
-        commitBytes += commit.getSize();
-        if (commitBytes >= TARGET_COMMIT_BUNDLE_BYTES) {
-          break;
-        }
-        commit = commitQueue.poll();
-      }
-      Windmill.CommitWorkRequest commitRequest = commitRequestBuilder.build();
-      LOG.trace("Commit: {}", commitRequest);
-      activeCommitBytes.addAndGet(commitBytes);
-      windmillServer.commitWork(commitRequest);
-      activeCommitBytes.addAndGet(-commitBytes);
-      for (Map.Entry<ComputationState, Windmill.ComputationCommitWorkRequest.Builder> entry :
-          computationRequestMap.entrySet()) {
-        ComputationState computationState = entry.getKey();
-        for (Windmill.WorkItemCommitRequest workRequest : entry.getValue().getRequestsList()) {
-          computationState.completeWorkAndScheduleNextWorkForKey(
-              ShardedKey.create(workRequest.getKey(), workRequest.getShardingKey()),
-              WorkId.builder()
-                  .setCacheToken(workRequest.getCacheToken())
-                  .setWorkToken(workRequest.getWorkToken())
-                  .build());
-        }
-      }
-    }
-  }
-
-  // Adds the commit to the commitStream if it fits, returning true iff it is consumed.
-  private boolean addCommitToStream(Commit commit, CommitWorkStream commitStream) {
-    Preconditions.checkNotNull(commit);
-    final ComputationState state = commit.computationState();
-    final Windmill.WorkItemCommitRequest request = commit.request();
-    // Drop commits for failed work. Such commits will be dropped by Windmill anyway.
-    if (commit.work().isFailed()) {
+  private void onCompleteCommit(CompleteCommit completeCommit) {
+    if (completeCommit.status() != Windmill.CommitStatus.OK) {
       readerCache.invalidateReader(
           WindmillComputationKey.create(
-              state.getComputationId(), request.getKey(), request.getShardingKey()));
+              completeCommit.computationId(), completeCommit.shardedKey()));
       stateCache
-          .forComputation(state.getComputationId())
-          .invalidate(request.getKey(), request.getShardingKey());
-      state.completeWorkAndScheduleNextWorkForKey(
-          ShardedKey.create(request.getKey(), request.getShardingKey()),
-          WorkId.builder()
-              .setWorkToken(request.getWorkToken())
-              .setCacheToken(request.getCacheToken())
-              .build());
-      return true;
-    }
-
-    final int size = commit.getSize();
-    commit.work().setState(Work.State.COMMITTING);
-    activeCommitBytes.addAndGet(size);
-    if (commitStream.commitWorkItem(
-        state.getComputationId(),
-        request,
-        (Windmill.CommitStatus status) -> {
-          if (status != Windmill.CommitStatus.OK) {
-            readerCache.invalidateReader(
-                WindmillComputationKey.create(
-                    state.getComputationId(), request.getKey(), request.getShardingKey()));
-            stateCache
-                .forComputation(state.getComputationId())
-                .invalidate(request.getKey(), request.getShardingKey());
-          }
-          activeCommitBytes.addAndGet(-size);
-          state.completeWorkAndScheduleNextWorkForKey(
-              ShardedKey.create(request.getKey(), request.getShardingKey()),
-              WorkId.builder()
-                  .setCacheToken(request.getCacheToken())
-                  .setWorkToken(request.getWorkToken())
-                  .build());
-        })) {
-      return true;
-    } else {
-      // Back out the stats changes since the commit wasn't consumed.
-      commit.work().setState(Work.State.COMMIT_QUEUED);
-      activeCommitBytes.addAndGet(-size);
-      return false;
+          .forComputation(completeCommit.computationId())
+          .invalidate(completeCommit.shardedKey());
     }
-  }
 
-  // Helper to batch additional commits into the commit stream as long as they fit.
-  // Returns a commit that was removed from the queue but not consumed or null.
-  private Commit batchCommitsToStream(CommitWorkStream commitStream) {
-    int commits = 1;
-    while (running.get()) {
-      Commit commit;
-      try {
-        if (commits < 5) {
-          commit = commitQueue.poll(10 - 2L * commits, TimeUnit.MILLISECONDS);
-        } else {
-          commit = commitQueue.poll();
-        }
-      } catch (InterruptedException e) {
-        // Continue processing until !running.get()
-        continue;
-      }
-      if (commit == null || !addCommitToStream(commit, commitStream)) {
-        return commit;
-      }
-      commits++;
-    }
-    return null;
-  }
-
-  private void streamingCommitLoop() {
-    WindmillStreamPool<CommitWorkStream> streamPool =
-        WindmillStreamPool.create(
-            NUM_COMMIT_STREAMS, COMMIT_STREAM_TIMEOUT, windmillServer::commitWorkStream);
-    Commit initialCommit = null;
-    while (running.get()) {
-      if (initialCommit == null) {
-        try {
-          initialCommit = commitQueue.take();
-        } catch (InterruptedException e) {
-          continue;
-        }
-      }
-      // We initialize the commit stream only after we have a commit to make sure it is fresh.
-      CommitWorkStream commitStream = streamPool.getStream();
-      if (!addCommitToStream(initialCommit, commitStream)) {
-        throw new AssertionError("Initial commit on flushed stream should always be accepted.");
-      }
-      // Batch additional commits to the stream and possibly make an un-batched commit the next
-      // initial commit.
-      initialCommit = batchCommitsToStream(commitStream);
-      commitStream.flush();
-      streamPool.releaseStream(commitStream);
-    }
+    Optional.ofNullable(computationMap.get(completeCommit.computationId()))
+        .ifPresent(
+            state ->
+                state.completeWorkAndScheduleNextWorkForKey(
+                    completeCommit.shardedKey(), completeCommit.workId()));
   }
 
   private Windmill.GetWorkResponse getWork() {
@@ -2094,7 +1941,7 @@ public class StreamingDataflowWorker {
       writer.println(workUnitExecutor.summaryHtml());
 
       writer.print("Active commit: ");
-      appendHumanizedBytes(activeCommitBytes.get(), writer);
+      appendHumanizedBytes(workCommitter.currentActiveCommitBytes(), writer);
       writer.println("<br>");
 
       metricTrackingWindmillServer.printHtml(writer);
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillComputationKey.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillComputationKey.java
index a01b1d297c2..274fa3aff02 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillComputationKey.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillComputationKey.java
@@ -18,6 +18,7 @@
 package org.apache.beam.runners.dataflow.worker;
 
 import com.google.auto.value.AutoValue;
+import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey;
 import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
 import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.TextFormat;
 
@@ -29,6 +30,10 @@ public abstract class WindmillComputationKey {
     return new AutoValue_WindmillComputationKey(computationId, key, shardingKey);
   }
 
+  public static WindmillComputationKey create(String computationId, ShardedKey shardedKey) {
+    return create(computationId, shardedKey.key(), shardedKey.shardingKey());
+  }
+
   public abstract String computationId();
 
   public abstract ByteString key();
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillComputationKey.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/CloseableStream.java
similarity index 55%
copy from runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillComputationKey.java
copy to runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/CloseableStream.java
index a01b1d297c2..e76cc365965 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillComputationKey.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/CloseableStream.java
@@ -15,29 +15,29 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.beam.runners.dataflow.worker;
+package org.apache.beam.runners.dataflow.worker.windmill.client;
 
 import com.google.auto.value.AutoValue;
-import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
-import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.TextFormat;
+import org.apache.beam.sdk.annotations.Internal;
 
+/**
+ * Wrapper around a {@link WindmillStream} that lets callers attach an action to run once the
+ * stream is no longer in use. The closing action may be a no-op.
+ */
+@Internal
 @AutoValue
-public abstract class WindmillComputationKey {
-
-  public static WindmillComputationKey create(
-      String computationId, ByteString key, long shardingKey) {
-    return new AutoValue_WindmillComputationKey(computationId, key, shardingKey);
+public abstract class CloseableStream<StreamT extends WindmillStream> implements AutoCloseable {
+  public static <StreamT extends WindmillStream> CloseableStream<StreamT> create(
+      StreamT stream, Runnable onClose) {
+    return new AutoValue_CloseableStream<>(stream, onClose);
   }
 
-  public abstract String computationId();
-
-  public abstract ByteString key();
+  public abstract StreamT stream();
 
-  public abstract long shardingKey();
+  abstract Runnable onClose();
 
   @Override
-  public final String toString() {
-    return String.format(
-        "%s: %s-%d", computationId(), TextFormat.escapeBytes(key()), 
shardingKey());
+  public void close() throws Exception {
+    onClose().run();
   }
 }
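
The javadoc's no-op option is easiest to see in a short sketch (not part of this commit; `stream` stands in for any CommitWorkStream the caller already owns):

    // Hedged sketch: wrap an owned stream with a no-op close action.
    static void useOwnedStream(CommitWorkStream stream) throws Exception {
      try (CloseableStream<CommitWorkStream> owned = CloseableStream.create(stream, () -> {})) {
        owned.stream().flush(); // use the underlying stream
      } // close() just runs the supplied Runnable, here a no-op
    }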
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPool.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPool.java
index 9f1b67edc1e..0e4e085c066 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPool.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPool.java
@@ -25,6 +25,7 @@ import java.util.concurrent.ThreadLocalRandom;
 import java.util.function.Supplier;
 import javax.annotation.concurrent.GuardedBy;
 import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.sdk.annotations.Internal;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
 import org.checkerframework.checker.nullness.qual.Nullable;
 import org.joda.time.Duration;
@@ -36,6 +37,7 @@ import org.joda.time.Instant;
  * <p>The pool holds a fixed total number of streams, and keeps each stream open for a specified
  * time to allow for better load-balancing.
  */
+@Internal
 @ThreadSafe
 public class WindmillStreamPool<StreamT extends WindmillStream> {
 
@@ -131,6 +133,11 @@ public class WindmillStreamPool<StreamT extends WindmillStream> {
     }
   }
 
+  public CloseableStream<StreamT> getCloseableStream() {
+    StreamT stream = getStream();
+    return CloseableStream.create(stream, () -> releaseStream(stream));
+  }
+
  private synchronized WindmillStreamPool.StreamData<StreamT> createAndCacheStream(int cacheKey) {
     WindmillStreamPool.StreamData<StreamT> newStreamData =
         new WindmillStreamPool.StreamData<>(streamSupplier.get());
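
The new getCloseableStream() is what lets callers treat the pool as a plain factory; a hedged sketch of both uses (method names real, surrounding method hypothetical):

    // Borrow a pooled stream; closing the wrapper releases it back to the pool.
    static void flushViaPool(WindmillStreamPool<CommitWorkStream> pool) throws Exception {
      try (CloseableStream<CommitWorkStream> borrowed = pool.getCloseableStream()) {
        borrowed.stream().flush();
      } // onClose() runs releaseStream(stream) here
    }

    // Or hand the pool to a committer as a Supplier, as the worker diff above does:
    // Supplier<CloseableStream<CommitWorkStream>> factory = pool::getCloseableStream;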
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Commit.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commit.java
similarity index 81%
rename from runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Commit.java
rename to runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commit.java
index 94689796756..b840d22a343 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Commit.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commit.java
@@ -15,13 +15,17 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.beam.runners.dataflow.worker.streaming;
+package org.apache.beam.runners.dataflow.worker.windmill.client.commits;
 
 import com.google.auto.value.AutoValue;
+import org.apache.beam.runners.dataflow.worker.streaming.ComputationState;
+import org.apache.beam.runners.dataflow.worker.streaming.Work;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest;
+import org.apache.beam.sdk.annotations.Internal;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
 
 /** Value class for a queued commit. */
+@Internal
 @AutoValue
 public abstract class Commit {
 
@@ -31,6 +35,10 @@ public abstract class Commit {
     return new AutoValue_Commit(request, computationState, work);
   }
 
+  public final String computationId() {
+    return computationState().getComputationId();
+  }
+
   public abstract WorkItemCommitRequest request();
 
   public abstract ComputationState computationState();
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/CompleteCommit.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/CompleteCommit.java
new file mode 100644
index 00000000000..64fec71b000
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/CompleteCommit.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.commits;
+
+import com.google.auto.value.AutoValue;
+import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey;
+import org.apache.beam.runners.dataflow.worker.streaming.WorkId;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
+
+/**
+ * A {@link Commit} is marked as complete once an attempt has been made to commit it back to
+ * Streaming Engine/Appliance, via {@link
+ * org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub#commitWorkStream(StreamObserver)}
+ * for Streaming Engine or {@link
+ * org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub#commitWork(Windmill.CommitWorkRequest,
+ * StreamObserver)} for Streaming Appliance.
+ */
+@Internal
+@AutoValue
+public abstract class CompleteCommit {
+
+  public static CompleteCommit create(Commit commit, CommitStatus commitStatus) {
+    return new AutoValue_CompleteCommit(
+        commit.computationId(),
+        ShardedKey.create(commit.request().getKey(), commit.request().getShardingKey()),
+        WorkId.builder()
+            .setWorkToken(commit.request().getWorkToken())
+            .setCacheToken(commit.request().getCacheToken())
+            .build(),
+        commitStatus);
+  }
+
+  public static CompleteCommit create(
+      String computationId, ShardedKey shardedKey, WorkId workId, CommitStatus status) {
+    return new AutoValue_CompleteCommit(computationId, shardedKey, workId, status);
+  }
+
+  public static CompleteCommit forFailedWork(Commit commit) {
+    return create(commit, CommitStatus.ABORTED);
+  }
+
+  public abstract String computationId();
+
+  public abstract ShardedKey shardedKey();
+
+  public abstract WorkId workId();
+
+  public abstract CommitStatus status();
+}
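
A short, hedged illustration of the two creation paths (`commit` stands in for a queued Commit; CommitStatus is the import used in the file above):

    // Derive identifiers from the Commit plus the status the backend returned.
    CompleteCommit ok = CompleteCommit.create(commit, CommitStatus.OK);

    // Work that failed before it was ever sent is completed as ABORTED.
    CompleteCommit aborted = CompleteCommit.forFailedWork(commit); // status() == CommitStatus.ABORTED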
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java
new file mode 100644
index 00000000000..344f04cfd00
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.commits;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Consumer;
+import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.runners.dataflow.worker.streaming.ComputationState;
+import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey;
+import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue;
+import org.apache.beam.runners.dataflow.worker.streaming.Work;
+import org.apache.beam.runners.dataflow.worker.streaming.WorkId;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitWorkRequest;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** Streaming appliance implementation of {@link WorkCommitter}. */
+@Internal
+@ThreadSafe
+public final class StreamingApplianceWorkCommitter implements WorkCommitter {
+  private static final Logger LOG = LoggerFactory.getLogger(StreamingApplianceWorkCommitter.class);
+  private static final long TARGET_COMMIT_BUNDLE_BYTES = 32 << 20;
+  private static final int MAX_COMMIT_QUEUE_BYTES = 500 << 20; // 500MB
+
+  private final Consumer<CommitWorkRequest> commitWorkFn;
+  private final WeightedBoundedQueue<Commit> commitQueue;
+  private final ExecutorService commitWorkers;
+  private final AtomicLong activeCommitBytes;
+  private final Consumer<CompleteCommit> onCommitComplete;
+
+  private StreamingApplianceWorkCommitter(
+      Consumer<CommitWorkRequest> commitWorkFn, Consumer<CompleteCommit> onCommitComplete) {
+    this.commitWorkFn = commitWorkFn;
+    this.commitQueue =
+        WeightedBoundedQueue.create(
+            MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES, commit.getSize()));
+    this.commitWorkers =
+        Executors.newSingleThreadScheduledExecutor(
+            new ThreadFactoryBuilder()
+                .setDaemon(true)
+                .setPriority(Thread.MAX_PRIORITY)
+                .setNameFormat("CommitThread-%d")
+                .build());
+    this.activeCommitBytes = new AtomicLong();
+    this.onCommitComplete = onCommitComplete;
+  }
+
+  public static StreamingApplianceWorkCommitter create(
+      Consumer<CommitWorkRequest> commitWork, Consumer<CompleteCommit> onCommitComplete) {
+    return new StreamingApplianceWorkCommitter(commitWork, onCommitComplete);
+  }
+
+  @Override
+  @SuppressWarnings("FutureReturnValueIgnored")
+  public void start() {
+    if (!commitWorkers.isShutdown()) {
+      commitWorkers.submit(this::commitLoop);
+    }
+  }
+
+  @Override
+  public void commit(Commit commit) {
+    commitQueue.put(commit);
+  }
+
+  @Override
+  public long currentActiveCommitBytes() {
+    return activeCommitBytes.get();
+  }
+
+  @Override
+  public void stop() {
+    commitWorkers.shutdownNow();
+  }
+
+  @Override
+  public int parallelism() {
+    return 1;
+  }
+
+  private void commitLoop() {
+    Map<ComputationState, Windmill.ComputationCommitWorkRequest.Builder> computationRequestMap =
+        new HashMap<>();
+    while (true) {
+      computationRequestMap.clear();
+      CommitWorkRequest.Builder commitRequestBuilder = CommitWorkRequest.newBuilder();
+      long commitBytes = 0;
+      // Block until we have a commit, then batch with additional commits.
+      Commit commit;
+      try {
+        commit = commitQueue.take();
+      } catch (InterruptedException e) {
+        Thread.currentThread().interrupt();
+        continue;
+      }
+      while (commit != null) {
+        ComputationState computationState = commit.computationState();
+        commit.work().setState(Work.State.COMMITTING);
+        Windmill.ComputationCommitWorkRequest.Builder computationRequestBuilder =
+            computationRequestMap.get(computationState);
+        if (computationRequestBuilder == null) {
+          computationRequestBuilder = commitRequestBuilder.addRequestsBuilder();
+          computationRequestBuilder.setComputationId(computationState.getComputationId());
+          computationRequestMap.put(computationState, computationRequestBuilder);
+        }
+        computationRequestBuilder.addRequests(commit.request());
+        // Send the request if we've exceeded the bytes or there is no more
+        // pending work.  commitBytes is a long, so this cannot overflow.
+        commitBytes += commit.getSize();
+        if (commitBytes >= TARGET_COMMIT_BUNDLE_BYTES) {
+          break;
+        }
+        commit = commitQueue.poll();
+      }
+      commitWork(commitRequestBuilder.build(), commitBytes);
+      completeWork(computationRequestMap);
+    }
+  }
+
+  private void commitWork(CommitWorkRequest commitRequest, long commitBytes) {
+    LOG.trace("Commit: {}", commitRequest);
+    activeCommitBytes.addAndGet(commitBytes);
+    commitWorkFn.accept(commitRequest);
+    activeCommitBytes.addAndGet(-commitBytes);
+  }
+
+  private void completeWork(
+      Map<ComputationState, Windmill.ComputationCommitWorkRequest.Builder> committedWork) {
+    for (Map.Entry<ComputationState, Windmill.ComputationCommitWorkRequest.Builder> entry :
+        committedWork.entrySet()) {
+      for (Windmill.WorkItemCommitRequest workRequest : entry.getValue().getRequestsList()) {
+        // Appliance errors are propagated by exception on entire batch.
+        onCommitComplete.accept(
+            CompleteCommit.create(
+                entry.getKey().getComputationId(),
+                ShardedKey.create(workRequest.getKey(), workRequest.getShardingKey()),
+                WorkId.builder()
+                    .setCacheToken(workRequest.getCacheToken())
+                    .setWorkToken(workRequest.getWorkToken())
+                    .build(),
+                Windmill.CommitStatus.OK));
+      }
+    }
+  }
+}
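
A hedged wiring sketch for the appliance path, mirroring the StreamingDataflowWorker constructor hunk above (`windmillServer` and `onCompleteCommit` stand in for the worker's field and callback):

    WorkCommitter committer =
        StreamingApplianceWorkCommitter.create(windmillServer::commitWork, this::onCompleteCommit);
    committer.start();        // starts the single commitLoop() worker
    committer.commit(commit); // enqueues; batched up to TARGET_COMMIT_BUNDLE_BYTES per request
    // later, during shutdown:
    committer.stop();         // shutdownNow() on the commit executor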
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java
new file mode 100644
index 00000000000..f6088acf011
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.commits;
+
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Consumer;
+import java.util.function.Supplier;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue;
+import org.apache.beam.runners.dataflow.worker.streaming.Work;
+import org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Streaming Engine implementation of {@link WorkCommitter}. Commits work back to the Streaming
+ * Engine backend.
+ */
+@Internal
+@ThreadSafe
+public final class StreamingEngineWorkCommitter implements WorkCommitter {
+  private static final Logger LOG = LoggerFactory.getLogger(StreamingEngineWorkCommitter.class);
+  private static final int TARGET_COMMIT_BATCH_KEYS = 5;
+  private static final int MAX_COMMIT_QUEUE_BYTES = 500 << 20; // 500MB
+
+  private final Supplier<CloseableStream<CommitWorkStream>> commitWorkStreamFactory;
+  private final WeightedBoundedQueue<Commit> commitQueue;
+  private final ExecutorService commitSenders;
+  private final AtomicLong activeCommitBytes;
+  private final Consumer<CompleteCommit> onCommitComplete;
+  private final int numCommitSenders;
+
+  private StreamingEngineWorkCommitter(
+      Supplier<CloseableStream<CommitWorkStream>> commitWorkStreamFactory,
+      int numCommitSenders,
+      Consumer<CompleteCommit> onCommitComplete) {
+    this.commitWorkStreamFactory = commitWorkStreamFactory;
+    this.commitQueue =
+        WeightedBoundedQueue.create(
+            MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES, commit.getSize()));
+    this.commitSenders =
+        Executors.newFixedThreadPool(
+            numCommitSenders,
+            new ThreadFactoryBuilder()
+                .setDaemon(true)
+                .setPriority(Thread.MAX_PRIORITY)
+                .setNameFormat("CommitThread-%d")
+                .build());
+    this.activeCommitBytes = new AtomicLong();
+    this.onCommitComplete = onCommitComplete;
+    this.numCommitSenders = numCommitSenders;
+  }
+
+  public static StreamingEngineWorkCommitter create(
+      Supplier<CloseableStream<CommitWorkStream>> commitWorkStreamFactory,
+      int numCommitSenders,
+      Consumer<CompleteCommit> onCommitComplete) {
+    return new StreamingEngineWorkCommitter(
+        commitWorkStreamFactory, numCommitSenders, onCommitComplete);
+  }
+
+  @Override
+  @SuppressWarnings("FutureReturnValueIgnored")
+  public void start() {
+    if (!commitSenders.isShutdown()) {
+      for (int i = 0; i < numCommitSenders; i++) {
+        commitSenders.submit(this::streamingCommitLoop);
+      }
+    }
+  }
+
+  @Override
+  public void commit(Commit commit) {
+    commitQueue.put(commit);
+  }
+
+  @Override
+  public long currentActiveCommitBytes() {
+    return activeCommitBytes.get();
+  }
+
+  @Override
+  public void stop() {
+    if (!commitSenders.isTerminated() || !commitSenders.isShutdown()) {
+      commitSenders.shutdown();
+      try {
+        commitSenders.awaitTermination(10, TimeUnit.SECONDS);
+      } catch (InterruptedException e) {
+        LOG.warn("Could not shut down commitSenders gracefully, forcing 
shutdown.", e);
+      }
+      commitSenders.shutdownNow();
+    }
+    drainCommitQueue();
+  }
+
+  private void drainCommitQueue() {
+    Commit queuedCommit = commitQueue.poll();
+    while (queuedCommit != null) {
+      failCommit(queuedCommit);
+      queuedCommit = commitQueue.poll();
+    }
+  }
+
+  private void failCommit(Commit commit) {
+    commit.work().setFailed();
+    onCommitComplete.accept(CompleteCommit.forFailedWork(commit));
+  }
+
+  @Override
+  public int parallelism() {
+    return numCommitSenders;
+  }
+
+  private void streamingCommitLoop() {
+    @Nullable Commit initialCommit = null;
+    try {
+      while (true) {
+        if (initialCommit == null) {
+          try {
+            // Block until we have a commit or are shutting down.
+            initialCommit = commitQueue.take();
+          } catch (InterruptedException e) {
+            continue;
+          }
+        }
+
+        if (initialCommit.work().isFailed()) {
+          onCommitComplete.accept(CompleteCommit.forFailedWork(initialCommit));
+          initialCommit = null;
+          continue;
+        }
+
+        try (CloseableStream<CommitWorkStream> closeableCommitStream =
+            commitWorkStreamFactory.get()) {
+          CommitWorkStream commitStream = closeableCommitStream.stream();
+          if (!tryAddToCommitStream(initialCommit, commitStream)) {
+            throw new AssertionError("Initial commit on flushed stream should 
always be accepted.");
+          }
+          // Batch additional commits to the stream and possibly make an un-batched commit the next
+          // initial commit.
+          initialCommit = batchCommitsToStream(commitStream);
+          commitStream.flush();
+        } catch (Exception e) {
+          LOG.error("Error occurred fetching a CommitWorkStream.", e);
+        }
+      }
+    } finally {
+      if (initialCommit != null) {
+        failCommit(initialCommit);
+      }
+    }
+  }
+
+  /** Adds the commit to the commitStream if it fits, returning true if it is consumed. */
+  private boolean tryAddToCommitStream(Commit commit, CommitWorkStream commitStream) {
+    Preconditions.checkNotNull(commit);
+    commit.work().setState(Work.State.COMMITTING);
+    activeCommitBytes.addAndGet(commit.getSize());
+    boolean isCommitAccepted =
+        commitStream.commitWorkItem(
+            commit.computationId(),
+            commit.request(),
+            (commitStatus) -> {
+              onCommitComplete.accept(CompleteCommit.create(commit, commitStatus));
+              activeCommitBytes.addAndGet(-commit.getSize());
+            });
+
+    // If the commit was not accepted, revert the changes made above.
+    if (!isCommitAccepted) {
+      commit.work().setState(Work.State.COMMIT_QUEUED);
+      activeCommitBytes.addAndGet(-commit.getSize());
+    }
+
+    return isCommitAccepted;
+  }
+
+  // Helper to batch additional commits into the commit stream as long as they fit.
+  // Returns a commit that was removed from the queue but not consumed or null.
+  private Commit batchCommitsToStream(CommitWorkStream commitStream) {
+    int commits = 1;
+    while (true) {
+      Commit commit;
+      try {
+        if (commits < TARGET_COMMIT_BATCH_KEYS) {
+          commit = commitQueue.poll(10 - 2L * commits, TimeUnit.MILLISECONDS);
+        } else {
+          commit = commitQueue.poll();
+        }
+      } catch (InterruptedException e) {
+        // Interrupted while polling; continue processing until shutdown.
+        continue;
+      }
+
+      if (commit == null) {
+        return null;
+      }
+
+      // Drop commits for failed work. Such commits will be dropped by Windmill anyway.
+      if (commit.work().isFailed()) {
+        onCommitComplete.accept(CompleteCommit.forFailedWork(commit));
+        continue;
+      }
+
+      if (!tryAddToCommitStream(commit, commitStream)) {
+        return commit;
+      }
+      commits++;
+    }
+  }
+}
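
The Streaming Engine counterpart wires up the same way, again sketched under stated assumptions (`pool` is the commit-stream pool from the WindmillStreamPool hunk; the sender count is illustrative):

    WorkCommitter committer =
        StreamingEngineWorkCommitter.create(
            pool::getCloseableStream, /* numCommitSenders= */ 2, this::onCompleteCommit);
    committer.start();        // one streamingCommitLoop() per sender thread
    committer.commit(commit); // queued, then batched onto a CommitWorkStream
    committer.stop();         // drains the queue, completing unsent commits as failed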
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/WorkCommitter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/WorkCommitter.java
new file mode 100644
index 00000000000..11a4c00db9d
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/WorkCommitter.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.commits;
+
+import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.sdk.annotations.Internal;
+
+/**
+ * Commits {@link org.apache.beam.runners.dataflow.worker.streaming.Work} that has completed user
+ * processing back to the persistence layer.
+ */
+@Internal
+@ThreadSafe
+public interface WorkCommitter {
+
+  /** Starts internal processing of commits. */
+  void start();
+
+  /**
+   * Adds a commit to the {@link WorkCommitter}. This may block the calling thread depending on
+   * the underlying implementation, and persisting to the persistence layer may be done
+   * asynchronously.
+   */
+  void commit(Commit commit);
+
+  /** Number of bytes currently being committed to the backing persistence layer. */
+  long currentActiveCommitBytes();
+
+  /**
+   * Stops internal processing of commits. In-progress and subsequent commits may be canceled or
+   * dropped.
+   */
+  void stop();
+
+  /**
+   * Number of internal workers the {@link WorkCommitter} uses to commit work to the backing
+   * persistence layer.
+   */
+  int parallelism();
+}
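
Because the interface is small, a trivial implementation shows the whole contract; this hypothetical test double (not in the commit, assumed to live in the same commits package with the usual Windmill import) completes every commit inline as OK:

    final class CompletingWorkCommitter implements WorkCommitter {
      private final java.util.function.Consumer<CompleteCommit> onComplete;

      CompletingWorkCommitter(java.util.function.Consumer<CompleteCommit> onComplete) {
        this.onComplete = onComplete;
      }

      @Override
      public void start() {} // nothing to start

      @Override
      public void commit(Commit commit) {
        // Complete every commit immediately, reporting CommitStatus.OK.
        onComplete.accept(CompleteCommit.create(commit, Windmill.CommitStatus.OK));
      }

      @Override
      public long currentActiveCommitBytes() {
        return 0; // commits never stay in flight
      }

      @Override
      public void stop() {} // nothing to stop

      @Override
      public int parallelism() {
        return 1;
      }
    }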
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java
index 0d4e7c6b645..85c74fe8591 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java
@@ -34,6 +34,7 @@ import org.apache.beam.runners.dataflow.worker.Weighers;
 import org.apache.beam.runners.dataflow.worker.WindmillComputationKey;
 import org.apache.beam.runners.dataflow.worker.status.BaseStatusServlet;
 import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider;
+import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey;
 import org.apache.beam.sdk.state.State;
 import org.apache.beam.sdk.util.Weighted;
 import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
@@ -318,6 +319,10 @@ public class WindmillStateCache implements StatusDataProvider {
       keyIndex.remove(key);
     }
 
+    public final void invalidate(ShardedKey shardedKey) {
+      invalidate(shardedKey.key(), shardedKey.shardingKey());
+    }
+
     /**
     * Returns a per-computation, per-key view of the state cache. Access to the cached data for
     * this key is not thread-safe. Callers should ensure that there is only a single ForKey object
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
index e4985193d1c..89939d5d341 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
@@ -29,6 +29,7 @@ import static org.junit.Assert.assertFalse;
 
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
@@ -45,6 +46,7 @@ import java.util.function.Function;
 import javax.annotation.concurrent.GuardedBy;
 import org.apache.beam.runners.dataflow.worker.streaming.ComputationState;
 import org.apache.beam.runners.dataflow.worker.streaming.WorkHeartbeatResponseProcessor;
+import org.apache.beam.runners.dataflow.worker.streaming.WorkId;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitWorkResponse;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationCommitWorkRequest;
@@ -74,11 +76,12 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /** An in-memory Windmill server that offers provided work and data. */
-public class FakeWindmillServer extends WindmillServerStub {
+public final class FakeWindmillServer extends WindmillServerStub {
   private static final Logger LOG = LoggerFactory.getLogger(FakeWindmillServer.class);
   private final ResponseQueue<Windmill.GetWorkRequest, Windmill.GetWorkResponse> workToOffer;
   private final ResponseQueue<GetDataRequest, GetDataResponse> dataToOffer;
   private final ResponseQueue<Windmill.CommitWorkRequest, CommitWorkResponse> commitsToOffer;
+  private final Map<WorkId, Windmill.CommitStatus> streamingCommitsToOffer;
   // Keys are work tokens.
   private final Map<Long, WorkItemCommitRequest> commitsReceived;
   private final ArrayList<Windmill.ReportStatsRequest> statsReceived;
@@ -109,6 +112,7 @@ public class FakeWindmillServer extends WindmillServerStub {
     commitsToOffer =
         new ResponseQueue<Windmill.CommitWorkRequest, CommitWorkResponse>()
             .returnByDefault(CommitWorkResponse.getDefaultInstance());
+    streamingCommitsToOffer = new HashMap<>();
     commitsReceived = new ConcurrentHashMap<>();
     exceptions = new LinkedBlockingQueue<>();
     expectedExceptionCount = new AtomicInteger();
@@ -139,6 +143,10 @@ public class FakeWindmillServer extends WindmillServerStub {
     return commitsToOffer;
   }
 
+  public Map<WorkId, Windmill.CommitStatus> whenCommitWorkStreamCalled() {
+    return streamingCommitsToOffer;
+  }
+
   @Override
   public Windmill.GetWorkResponse getWork(Windmill.GetWorkRequest request) {
     LOG.debug("getWorkRequest: {}", request.toString());
@@ -376,7 +384,15 @@ public class FakeWindmillServer extends WindmillServerStub {
           droppedStreamingCommits.put(request.getWorkToken(), onDone);
         } else {
           commitsReceived.put(request.getWorkToken(), request);
-          onDone.accept(Windmill.CommitStatus.OK);
+          onDone.accept(
+              Optional.ofNullable(
+                      streamingCommitsToOffer.remove(
+                          WorkId.builder()
+                              .setWorkToken(request.getWorkToken())
+                              .setCacheToken(request.getCacheToken())
+                              .build()))
+                  // Default to CommitStatus.OK
+                  .orElse(Windmill.CommitStatus.OK));
         }
        // Return true to indicate the request was accepted even if we are dropping the commit
         // to simulate a dropped commit.
@@ -502,32 +518,32 @@ public class FakeWindmillServer extends WindmillServerStub {
     this.isReady = ready;
   }
 
-  static class ResponseQueue<T, U> {
+  public static class ResponseQueue<T, U> {
    private final Queue<Function<T, U>> responses = new ConcurrentLinkedQueue<>();
     Duration sleep = Duration.ZERO;
     private Function<T, U> defaultResponse;
 
     // (Fluent) interface for response producers, accessible from tests.
 
-    ResponseQueue<T, U> thenAnswer(Function<T, U> mapFun) {
+    public ResponseQueue<T, U> thenAnswer(Function<T, U> mapFun) {
       responses.add(mapFun);
       return this;
     }
 
-    ResponseQueue<T, U> thenReturn(U response) {
+    public ResponseQueue<T, U> thenReturn(U response) {
       return thenAnswer((request) -> response);
     }
 
-    ResponseQueue<T, U> answerByDefault(Function<T, U> mapFun) {
+    public ResponseQueue<T, U> answerByDefault(Function<T, U> mapFun) {
       defaultResponse = mapFun;
       return this;
     }
 
-    ResponseQueue<T, U> returnByDefault(U response) {
+    public ResponseQueue<T, U> returnByDefault(U response) {
       return answerByDefault((request) -> response);
     }
 
-    ResponseQueue<T, U> delayEachResponseBy(Duration sleep) {
+    public ResponseQueue<T, U> delayEachResponseBy(Duration sleep) {
       this.sleep = sleep;
       return this;
     }
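
A hedged sketch of how a test can use the new whenCommitWorkStreamCalled() hook (token values are illustrative):

    // Report ABORTED for one specific work item; other streaming commits
    // still fall back to the default CommitStatus.OK.
    fakeWindmillServer
        .whenCommitWorkStreamCalled()
        .put(
            WorkId.builder().setWorkToken(1L).setCacheToken(1L).build(),
            Windmill.CommitStatus.ABORTED);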
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
index d00ea64d7d4..d8ead447e8e 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
@@ -3894,7 +3894,7 @@ public class StreamingDataflowWorkerTest {
     options.setWindmillServiceCommitThreads(configNumCommitThreads);
    StreamingDataflowWorker worker = makeWorker(instructions, options, true /* publishCounters */);
     worker.start();
-    assertEquals(expectedNumCommitThreads, worker.commitThreads.size());
+    assertEquals(expectedNumCommitThreads, worker.numCommitThreads());
     worker.stop();
   }
 
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java
new file mode 100644
index 00000000000..cfad6138547
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.commits;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertNotNull;
+
+import com.google.api.services.dataflow.model.MapTask;
+import com.google.common.truth.Correspondence;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.Consumer;
+import org.apache.beam.runners.dataflow.worker.FakeWindmillServer;
+import org.apache.beam.runners.dataflow.worker.streaming.ComputationState;
+import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey;
+import org.apache.beam.runners.dataflow.worker.streaming.Work;
+import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import org.joda.time.Instant;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ErrorCollector;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.Mockito;
+
+@RunWith(JUnit4.class)
+public class StreamingApplianceWorkCommitterTest {
+  @Rule public ErrorCollector errorCollector = new ErrorCollector();
+  private FakeWindmillServer fakeWindmillServer;
+  private StreamingApplianceWorkCommitter workCommitter;
+
+  private static Work createMockWork(long workToken, Consumer<Work> processWorkFn) {
+    return Work.create(
+        Windmill.WorkItem.newBuilder()
+            .setKey(ByteString.EMPTY)
+            .setWorkToken(workToken)
+            .setShardingKey(workToken)
+            .setCacheToken(workToken)
+            .build(),
+        Instant::now,
+        Collections.emptyList(),
+        processWorkFn);
+  }
+
+  private static ComputationState createComputationState(String computationId) {
+    return new ComputationState(
+        computationId,
+        new MapTask().setSystemName("system").setStageName("stage"),
+        Mockito.mock(BoundedQueueExecutor.class),
+        ImmutableMap.of(),
+        null);
+  }
+
+  private StreamingApplianceWorkCommitter createWorkCommitter(
+      Consumer<CompleteCommit> onCommitComplete) {
+    return StreamingApplianceWorkCommitter.create(fakeWindmillServer::commitWork, onCommitComplete);
+  }
+
+  @Before
+  public void setUp() {
+    fakeWindmillServer =
+        new FakeWindmillServer(
+            errorCollector, ignored -> Optional.of(Mockito.mock(ComputationState.class)));
+  }
+
+  @After
+  public void cleanUp() {
+    workCommitter.stop();
+  }
+
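+  // Commits passed to the committer should reach Windmill unchanged, and each one should be
+  // reported back through onCommitComplete as a CompleteCommit with status OK.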
+  @Test
+  public void testCommit() {
+    List<CompleteCommit> completeCommits = new ArrayList<>();
+    workCommitter = createWorkCommitter(completeCommits::add);
+    List<Commit> commits = new ArrayList<>();
+    for (int i = 1; i <= 5; i++) {
+      Work work = createMockWork(i, ignored -> {});
+      Windmill.WorkItemCommitRequest commitRequest =
+          Windmill.WorkItemCommitRequest.newBuilder()
+              .setKey(work.getWorkItem().getKey())
+              .setShardingKey(work.getWorkItem().getShardingKey())
+              .setWorkToken(work.getWorkItem().getWorkToken())
+              .setCacheToken(work.getWorkItem().getCacheToken())
+              .build();
+      commits.add(Commit.create(commitRequest, createComputationState("computationId-" + i), work));
+    }
+
+    workCommitter.start();
+    commits.forEach(workCommitter::commit);
+
+    Map<Long, Windmill.WorkItemCommitRequest> committed =
+        fakeWindmillServer.waitForAndGetCommits(commits.size());
+
+    for (Commit commit : commits) {
+      Windmill.WorkItemCommitRequest request =
+          committed.get(commit.work().getWorkItem().getWorkToken());
+      assertNotNull(request);
+      assertThat(request).isEqualTo(commit.request());
+    }
+
+    assertThat(completeCommits).hasSize(commits.size());
+    assertThat(completeCommits)
+        .comparingElementsUsing(
+            Correspondence.from(
+                (CompleteCommit completeCommit, Commit commit) ->
+                    completeCommit.computationId().equals(commit.computationId())
+                        && completeCommit.status() == Windmill.CommitStatus.OK
+                        && completeCommit.workId().equals(commit.work().id())
+                        && completeCommit
+                            .shardedKey()
+                            .equals(
+                                ShardedKey.create(
+                                    commit.request().getKey(), commit.request().getShardingKey())),
+                "expected to equal"))
+        .containsExactlyElementsIn(commits);
+  }
+}
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java
new file mode 100644
index 00000000000..1bf2e44f9f0
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java
@@ -0,0 +1,308 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.commits;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus.OK;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+import com.google.api.services.dataflow.model.MapTask;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
+import java.util.function.Function;
+import java.util.function.Supplier;
+import org.apache.beam.runners.dataflow.worker.FakeWindmillServer;
+import org.apache.beam.runners.dataflow.worker.streaming.ComputationState;
+import org.apache.beam.runners.dataflow.worker.streaming.Work;
+import org.apache.beam.runners.dataflow.worker.streaming.WorkId;
+import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ErrorCollector;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.Mockito;
+
+@RunWith(JUnit4.class)
+public class StreamingEngineWorkCommitterTest {
+
+  @Rule public ErrorCollector errorCollector = new ErrorCollector();
+  private StreamingEngineWorkCommitter workCommitter;
+  private FakeWindmillServer fakeWindmillServer;
+  private Supplier<CloseableStream<CommitWorkStream>> commitWorkStreamFactory;
+
+  private static Work createMockWork(long workToken, Consumer<Work> processWorkFn) {
+    return Work.create(
+        Windmill.WorkItem.newBuilder()
+            .setKey(ByteString.EMPTY)
+            .setWorkToken(workToken)
+            .setShardingKey(workToken)
+            .setCacheToken(workToken)
+            .build(),
+        Instant::now,
+        Collections.emptyList(),
+        processWorkFn);
+  }
+
+  private static ComputationState createComputationState(String computationId) {
+    return new ComputationState(
+        computationId,
+        new MapTask().setSystemName("system").setStageName("stage"),
+        Mockito.mock(BoundedQueueExecutor.class),
+        ImmutableMap.of(),
+        null);
+  }
+
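+  // Builds the CompleteCommit the committer is expected to emit: failed work goes through
+  // CompleteCommit.forFailedWork, everything else carries the given status.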
+  private static CompleteCommit asCompleteCommit(Commit commit, Windmill.CommitStatus status) {
+    if (commit.work().isFailed()) {
+      return CompleteCommit.forFailedWork(commit);
+    }
+
+    return CompleteCommit.create(commit, status);
+  }
+
+  @Before
+  public void setUp() throws IOException {
+    fakeWindmillServer =
+        new FakeWindmillServer(
+            errorCollector, ignored -> Optional.of(Mockito.mock(ComputationState.class)));
+    commitWorkStreamFactory =
+        WindmillStreamPool.create(
+                1, Duration.standardMinutes(1), fakeWindmillServer::commitWorkStream)
+            ::getCloseableStream;
+  }
+
+  @After
+  public void cleanUp() {
+    workCommitter.stop();
+  }
+
+  private StreamingEngineWorkCommitter createWorkCommitter(
+      Consumer<CompleteCommit> onCommitComplete) {
+    return StreamingEngineWorkCommitter.create(commitWorkStreamFactory, 1, onCommitComplete);
+  }
+
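+  // Happy path: every queued commit should be sent on the commit stream and completed
+  // with CommitStatus.OK.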
+  @Test
+  public void testCommit_sendsCommitsToStreamingEngine() {
+    Set<CompleteCommit> completeCommits = new HashSet<>();
+    workCommitter = createWorkCommitter(completeCommits::add);
+    List<Commit> commits = new ArrayList<>();
+    for (int i = 1; i <= 5; i++) {
+      Work work = createMockWork(i, ignored -> {});
+      WorkItemCommitRequest commitRequest =
+          WorkItemCommitRequest.newBuilder()
+              .setKey(work.getWorkItem().getKey())
+              .setShardingKey(work.getWorkItem().getShardingKey())
+              .setWorkToken(work.getWorkItem().getWorkToken())
+              .setCacheToken(work.getWorkItem().getCacheToken())
+              .build();
+      commits.add(Commit.create(commitRequest, createComputationState("computationId-" + i), work));
+    }
+
+    workCommitter.start();
+    commits.parallelStream().forEach(workCommitter::commit);
+
+    Map<Long, WorkItemCommitRequest> committed =
+        fakeWindmillServer.waitForAndGetCommits(commits.size());
+
+    for (Commit commit : commits) {
+      WorkItemCommitRequest request = committed.get(commit.work().getWorkItem().getWorkToken());
+      assertNotNull(request);
+      assertThat(request).isEqualTo(commit.request());
+      assertThat(completeCommits).contains(asCompleteCommit(commit, Windmill.CommitStatus.OK));
+    }
+  }
+
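+  // Work marked as failed should never be sent to Windmill; it should instead be completed
+  // locally with CommitStatus.ABORTED.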
+  @Test
+  public void testCommit_handlesFailedCommits() {
+    Set<CompleteCommit> completeCommits = new HashSet<>();
+    workCommitter = createWorkCommitter(completeCommits::add);
+    List<Commit> commits = new ArrayList<>();
+    for (int i = 1; i <= 10; i++) {
+      Work work = createMockWork(i, ignored -> {});
+      // Fail half of the work.
+      if (i % 2 == 0) {
+        work.setFailed();
+      }
+      WorkItemCommitRequest commitRequest =
+          WorkItemCommitRequest.newBuilder()
+              .setKey(work.getWorkItem().getKey())
+              .setShardingKey(work.getWorkItem().getShardingKey())
+              .setWorkToken(work.getWorkItem().getWorkToken())
+              .setCacheToken(work.getWorkItem().getCacheToken())
+              .build();
+      commits.add(Commit.create(commitRequest, createComputationState("computationId-" + i), work));
+    }
+
+    workCommitter.start();
+    commits.parallelStream().forEach(workCommitter::commit);
+
+    Map<Long, WorkItemCommitRequest> committed =
+        fakeWindmillServer.waitForAndGetCommits(commits.size() / 2);
+
+    for (Commit commit : commits) {
+      if (commit.work().isFailed()) {
+        assertThat(completeCommits)
+            .contains(asCompleteCommit(commit, Windmill.CommitStatus.ABORTED));
+        assertThat(committed).doesNotContainKey(commit.work().getWorkItem().getWorkToken());
+      } else {
+        assertThat(completeCommits).contains(asCompleteCommit(commit, Windmill.CommitStatus.OK));
+        assertThat(committed)
+            .containsEntry(commit.work().getWorkItem().getWorkToken(), commit.request());
+      }
+    }
+  }
+
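+  // Statuses returned by the (fake) server should be propagated verbatim into the
+  // CompleteCommits handed to onCommitComplete.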
+  @Test
+  public void testCommit_handlesCompleteCommits_commitStatusNotOK() {
+    Set<CompleteCommit> completeCommits = new HashSet<>();
+    workCommitter = createWorkCommitter(completeCommits::add);
+    Map<WorkId, Windmill.CommitStatus> expectedCommitStatus = new HashMap<>();
+    Random commitStatusSelector = new Random();
+    int commitStatusSelectorBound = Windmill.CommitStatus.values().length - 1;
+    // Choose each CommitStatus randomly, to test that the different commit statuses returned
+    // by StreamingEngine are plumbed through to onCommitComplete.
+    Function<Work, Windmill.CommitStatus> computeCommitStatusForTest =
+        work -> {
+          Windmill.CommitStatus commitStatus =
+              work.getWorkItem().getWorkToken() % 2 == 0
+                  ? Windmill.CommitStatus.values()[
+                      commitStatusSelector.nextInt(commitStatusSelectorBound)]
+                  : OK;
+          expectedCommitStatus.put(work.id(), commitStatus);
+          return commitStatus;
+        };
+
+    List<Commit> commits = new ArrayList<>();
+    for (int i = 1; i <= 10; i++) {
+      Work work = createMockWork(i, ignored -> {});
+      WorkItemCommitRequest commitRequest =
+          WorkItemCommitRequest.newBuilder()
+              .setKey(work.getWorkItem().getKey())
+              .setShardingKey(work.getWorkItem().getShardingKey())
+              .setWorkToken(work.getWorkItem().getWorkToken())
+              .setCacheToken(work.getWorkItem().getCacheToken())
+              .build();
+      commits.add(Commit.create(commitRequest, createComputationState("computationId-" + i), work));
+      fakeWindmillServer
+          .whenCommitWorkStreamCalled()
+          .put(work.id(), computeCommitStatusForTest.apply(work));
+    }
+
+    workCommitter.start();
+    commits.parallelStream().forEach(workCommitter::commit);
+
+    Map<Long, WorkItemCommitRequest> committed =
+        fakeWindmillServer.waitForAndGetCommits(commits.size());
+
+    for (Commit commit : commits) {
+      WorkItemCommitRequest request = committed.get(commit.work().getWorkItem().getWorkToken());
+      assertNotNull(request);
+      assertThat(request).isEqualTo(commit.request());
+      assertThat(completeCommits)
+          .contains(asCompleteCommit(commit, expectedCommitStatus.get(commit.work().id())));
+    }
+    assertThat(completeCommits.size()).isEqualTo(commits.size());
+  }
+
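+  // stop() should drain any commits still queued on the committer, failing their work and
+  // completing them with CommitStatus.ABORTED.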
+  @Test
+  public void testStop_drainsCommitQueue() {
+    // Use this fake to queue up commits on the committer.
+    Supplier<CommitWorkStream> fakeCommitWorkStream =
+        () ->
+            new CommitWorkStream() {
+              @Override
+              public boolean commitWorkItem(
+                  String computation,
+                  WorkItemCommitRequest request,
+                  Consumer<Windmill.CommitStatus> onDone) {
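+                // Reject every request so that commits stay queued on the committer.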
+                return false;
+              }
+
+              @Override
+              public void flush() {}
+
+              @Override
+              public void close() {}
+
+              @Override
+              public boolean awaitTermination(int time, TimeUnit unit) {
+                return false;
+              }
+
+              @Override
+              public Instant startTime() {
+                return Instant.now();
+              }
+            };
+    commitWorkStreamFactory =
+        WindmillStreamPool.create(1, Duration.standardMinutes(1), fakeCommitWorkStream)
+            ::getCloseableStream;
+
+    Set<CompleteCommit> completeCommits = new HashSet<>();
+    workCommitter = createWorkCommitter(completeCommits::add);
+
+    List<Commit> commits = new ArrayList<>();
+    for (int i = 1; i <= 10; i++) {
+      Work work = createMockWork(i, ignored -> {});
+      WorkItemCommitRequest commitRequest =
+          WorkItemCommitRequest.newBuilder()
+              .setKey(work.getWorkItem().getKey())
+              .setShardingKey(work.getWorkItem().getShardingKey())
+              .setWorkToken(work.getWorkItem().getWorkToken())
+              .setCacheToken(work.getWorkItem().getCacheToken())
+              .build();
+      commits.add(Commit.create(commitRequest, createComputationState("computationId-" + i), work));
+    }
+
+    workCommitter.start();
+    commits.parallelStream().forEach(workCommitter::commit);
+    workCommitter.stop();
+
+    assertThat(commits.size()).isEqualTo(completeCommits.size());
+    for (CompleteCommit completeCommit : completeCommits) {
+      assertThat(completeCommit.status()).isEqualTo(Windmill.CommitStatus.ABORTED);
+    }
+
+    for (Commit commit : commits) {
+      assertTrue(commit.work().isFailed());
+    }
+  }
+}
