This is an automated email from the ASF dual-hosted git repository.

scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new d9117f423d9 [Dataflow Streaming] [Multi Key] Introduce 
KeyGroupWorkQueue and integrate with BoundedQueueExecutor (#38767)
d9117f423d9 is described below

commit d9117f423d903b00ff5b10c3e8c915cf7caba279
Author: Arun Pandian <[email protected]>
AuthorDate: Mon Jun 8 03:46:18 2026 -0700

    [Dataflow Streaming] [Multi Key] Introduce KeyGroupWorkQueue and integrate 
with BoundedQueueExecutor (#38767)
    
    * [Dataflow Streaming] Introduce KeyGroupWorkQueue and integrate with 
BoundedQueueExecutor
        - Add Uint128Proto and key_group to WorkItem in windmill.proto.
        - Update Work model to support key groups.
        - Implement KeyGroupWorkQueue for grouping tasks by key group. 
KeyGroupWorkQueue
          is a FIFO queue that allows polling elements based on
          global order and also by order within a key group.
        - Integrate BoundedQueueExecutor with KeyGroupWorkQueue and support 
targeted polling.
---
 .../dataflow/worker/StreamingDataflowWorker.java   |   8 +-
 .../dataflow/worker/streaming/ExecutableWork.java  |  12 +-
 .../runners/dataflow/worker/streaming/Work.java    |  71 +++
 .../dataflow/worker/util/BoundedQueueExecutor.java |  33 +-
 .../dataflow/worker/util/KeyGroupWorkQueue.java    | 462 +++++++++++++++++++
 .../worker/StreamingDataflowWorkerTest.java        |  12 +-
 .../worker/util/BoundedQueueExecutorTest.java      | 150 +++++-
 .../worker/util/KeyGroupWorkQueueTest.java         | 504 +++++++++++++++++++++
 .../processing/StreamingCommitFinalizerTest.java   |   3 +-
 .../failures/WorkFailureProcessorTest.java         |   3 +-
 .../work/refresh/ActiveWorkRefresherTest.java      |   3 +-
 .../worker/windmill/src/main/proto/windmill.proto  |   7 +
 12 files changed, 1247 insertions(+), 21 deletions(-)

diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
index 4c3e58978ec..9e82343474c 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
@@ -179,6 +179,9 @@ public final class StreamingDataflowWorker {
   // Experiment make the monitor within BoundedQueueExecutor fair
   public static final String 
BOUNDED_QUEUE_EXECUTOR_USE_FAIR_MONITOR_EXPERIMENT =
       "windmill_bounded_queue_executor_use_fair_monitor";
+  // Don't use. Experiment guarding multi key bundles. The feature is work in 
progress and
+  // incomplete.
+  private static final String UNSTABLE_ENABLE_MULTI_KEY_BUNDLE = 
"unstable_enable_multi_key_bundle";
 
   private final WindmillStateCache stateCache;
   private AtomicReference<StreamingWorkerStatusPages> statusPages = new 
AtomicReference<>();
@@ -1018,6 +1021,8 @@ public final class StreamingDataflowWorker {
   private static BoundedQueueExecutor 
createWorkUnitExecutor(DataflowWorkerHarnessOptions options) {
     boolean useFairMonitor =
         DataflowRunner.hasExperiment(options, 
BOUNDED_QUEUE_EXECUTOR_USE_FAIR_MONITOR_EXPERIMENT);
+    boolean useKeyGroupWorkQueue =
+        DataflowRunner.hasExperiment(options, 
UNSTABLE_ENABLE_MULTI_KEY_BUNDLE);
     return new BoundedQueueExecutor(
         chooseMaxThreads(options),
         THREAD_EXPIRATION_TIME_SEC,
@@ -1025,7 +1030,8 @@ public final class StreamingDataflowWorker {
         chooseMaxBundlesOutstanding(options),
         chooseMaxBytesOutstanding(options),
         new 
ThreadFactoryBuilder().setNameFormat("DataflowWorkUnits-%d").setDaemon(true).build(),
-        useFairMonitor);
+        useFairMonitor,
+        useKeyGroupWorkQueue);
   }
 
   public static void main(String[] args) throws Exception {
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutableWork.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutableWork.java
index ecaa673f557..7748a554f0f 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutableWork.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutableWork.java
@@ -62,11 +62,11 @@ public final class ExecutableWork {
     }
   }
 
-  public final WorkId id() {
+  public WorkId id() {
     return work().id();
   }
 
-  public final Windmill.WorkItem getWorkItem() {
+  public Windmill.WorkItem getWorkItem() {
     return work().getWorkItem();
   }
 
@@ -74,4 +74,12 @@ public final class ExecutableWork {
   public String toString() {
     return "ExecutableWork{" + id() + "}";
   }
+
+  public String getComputationId() {
+    return work().getComputationId();
+  }
+
+  public Work.KeyGroup getKeyGroup() {
+    return work().getKeyGroup();
+  }
 }
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
index cb01e1e508c..53ed30fdedb 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
@@ -25,6 +25,7 @@ import java.util.EnumMap;
 import java.util.IntSummaryStatistics;
 import java.util.Map;
 import java.util.Map.Entry;
+import java.util.Objects;
 import java.util.Optional;
 import java.util.concurrent.TimeUnit;
 import java.util.function.Consumer;
@@ -52,6 +53,7 @@ import 
org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSe
 import org.apache.beam.sdk.annotations.Internal;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
+import org.checkerframework.checker.nullness.qual.Nullable;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
 
@@ -74,6 +76,7 @@ public final class Work implements RefreshableWork {
   private final Instant startTime;
   private final Map<LatencyAttribution.State, Duration> totalDurationPerState;
   private final WorkId id;
+  private final KeyGroup keyGroup;
   private final String latencyTrackingId;
   private final long serializedWorkItemSize;
   private volatile TimedState currentState;
@@ -101,6 +104,10 @@ public final class Work implements RefreshableWork {
     // keyUniverse inside EnumMap every time.
     this.totalDurationPerState = new EnumMap<>(EMPTY_ENUM_MAP);
     this.id = WorkId.of(workItem);
+    this.keyGroup =
+        workItem.hasKeyGroup()
+            ? KeyGroup.create(workItem.getKeyGroup().getHigh(), 
workItem.getKeyGroup().getLow())
+            : KeyGroup.DEFAULT;
     this.latencyTrackingId =
         Long.toHexString(workItem.getShardingKey())
             + '-'
@@ -383,6 +390,14 @@ public final class Work implements RefreshableWork {
     abstract Instant startTime();
   }
 
+  public String getComputationId() {
+    return processingContext.computationId();
+  }
+
+  public KeyGroup getKeyGroup() {
+    return keyGroup;
+  }
+
   @AutoValue
   public abstract static class ProcessingContext {
 
@@ -416,4 +431,60 @@ public final class Work implements RefreshableWork {
       return Optional.ofNullable(getDataClient().getStateData(computationId(), 
request));
     }
   }
+
+  /**
+   * WorkItems with same key group and computation are eligible to be executed 
together in a
+   * multi-key bundle.
+   */
+  public static final class KeyGroup {
+
+    // Work items equaling to the default keyGroup will always be executed
+    // separately and not in a multi-key bundle
+    public static final KeyGroup DEFAULT = new KeyGroup(0, 0);
+
+    private final long high;
+    private final long low;
+
+    private KeyGroup(long high, long low) {
+      this.high = high;
+      this.low = low;
+    }
+
+    public static KeyGroup create(long high, long low) {
+      if (high == 0 && low == 0) {
+        return DEFAULT;
+      }
+      return new KeyGroup(high, low);
+    }
+
+    public long high() {
+      return high;
+    }
+
+    public long low() {
+      return low;
+    }
+
+    @Override
+    public boolean equals(@Nullable Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (!(o instanceof KeyGroup)) {
+        return false;
+      }
+      KeyGroup other = (KeyGroup) o;
+      return high == other.high && low == other.low;
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hash(high, low);
+    }
+
+    @Override
+    public String toString() {
+      return String.format("%016x%016x", high, low);
+    }
+  }
 }
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java
index c6fd96e0a4c..8964246c116 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java
@@ -29,15 +29,15 @@ import java.util.concurrent.atomic.AtomicBoolean;
 import javax.annotation.concurrent.GuardedBy;
 import 
org.apache.beam.runners.dataflow.worker.streaming.BoundedQueueExecutorWorkHandle;
 import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork;
+import org.apache.beam.runners.dataflow.worker.streaming.Work;
+import org.apache.beam.runners.dataflow.worker.streaming.Work.KeyGroup;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Monitor;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Monitor.Guard;
+import org.checkerframework.checker.nullness.qual.Nullable;
 
 /** An executor for executing work on windmill items. */
-@SuppressWarnings({
-  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
 public class BoundedQueueExecutor {
 
   private final ThreadPoolExecutor executor;
@@ -78,6 +78,9 @@ public class BoundedQueueExecutor {
   @GuardedBy("this")
   private long totalTimeMaxActiveThreadsUsed;
 
+  // If set the keyGroupWorkQueue is used by the underlying executor.
+  private final @Nullable KeyGroupWorkQueue keyGroupWorkQueue;
+
   public BoundedQueueExecutor(
       int initialMaximumPoolSize,
       long keepAliveTime,
@@ -85,7 +88,9 @@ public class BoundedQueueExecutor {
       int maximumElementsOutstanding,
       long maximumBytesOutstanding,
       ThreadFactory threadFactory,
-      boolean useFairMonitor) {
+      boolean useFairMonitor,
+      boolean useKeyGroupWorkQueue) {
+    this.keyGroupWorkQueue = useKeyGroupWorkQueue ? new 
KeyGroupWorkQueue(useFairMonitor) : null;
     this.maximumPoolSize = initialMaximumPoolSize;
     monitor = new Monitor(useFairMonitor);
     executor =
@@ -94,7 +99,7 @@ public class BoundedQueueExecutor {
             initialMaximumPoolSize,
             keepAliveTime,
             unit,
-            new LinkedBlockingQueue<>(),
+            keyGroupWorkQueue != null ? keyGroupWorkQueue : new 
LinkedBlockingQueue<>(),
             threadFactory) {
           @Override
           protected void beforeExecute(Thread t, Runnable r) {
@@ -313,7 +318,7 @@ public class BoundedQueueExecutor {
     }
   }
 
-  private static final class QueuedWork implements Runnable {
+  static final class QueuedWork implements Runnable {
 
     private final ExecutableWork work;
     private final BoundedQueueExecutorWorkHandleImpl handle;
@@ -378,6 +383,22 @@ public class BoundedQueueExecutor {
     return new BoundedQueueExecutorWorkHandleImpl(elements, bytes);
   }
 
+  public @Nullable ExecutableWork pollWork(
+      String computationId, Work.KeyGroup keyGroup, 
BoundedQueueExecutorWorkHandle handle) {
+    checkArgument(handle instanceof BoundedQueueExecutorWorkHandleImpl);
+    checkArgument(computationId != null && keyGroup != null && 
!keyGroup.equals(KeyGroup.DEFAULT));
+    BoundedQueueExecutorWorkHandleImpl internalHandle = 
(BoundedQueueExecutorWorkHandleImpl) handle;
+    if (keyGroupWorkQueue == null) {
+      return null;
+    }
+    @Nullable QueuedWork queuedWork = 
keyGroupWorkQueue.pollWork(computationId, keyGroup);
+    if (queuedWork == null) {
+      return null;
+    }
+    internalHandle.merge(queuedWork.getHandle());
+    return queuedWork.getWork();
+  }
+
   private void decrementCounters(int elements, long bytes) {
     // All threads queue decrements and one thread grabs the monitor and 
updates
     // counters. We do this to reduce contention on monitor which is locked by
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueue.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueue.java
new file mode 100644
index 00000000000..d151157ec68
--- /dev/null
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueue.java
@@ -0,0 +1,462 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.util;
+
+import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull;
+import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
+import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
+
+import java.util.AbstractQueue;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.locks.Condition;
+import java.util.concurrent.locks.ReentrantLock;
+import javax.annotation.concurrent.GuardedBy;
+import org.apache.beam.runners.dataflow.worker.streaming.Work;
+import org.apache.beam.runners.dataflow.worker.streaming.Work.KeyGroup;
+import 
org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor.QueuedWork;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.checkerframework.checker.nullness.qual.Nullable;
+
+/**
+ * A custom, thread-safe doubly-linked BlockingQueue. In addition to global 
FIFO ordering, the queue
+ * supports polling work by computation + key group in FIFO order
+ */
+class KeyGroupWorkQueue extends AbstractQueue<Runnable> implements 
BlockingQueue<Runnable> {
+
+  public static final Runnable SENTINEL_RUNNABLE =
+      () -> {
+        throw new IllegalStateException("sentinel runnable called");
+      };
+
+  static class Node {
+    // If keyGroup is non-null, task is an instance of QueuedWork
+    final Runnable task;
+    final @Nullable String computationId;
+    final Work.@Nullable KeyGroup keyGroup;
+    // cached keyGroupList if the Node is part of one.
+    @Nullable KeyGroupWorkList keyGroupList;
+
+    // prevNode, nextNode are used for the global order across all queued 
Runnables
+    @Nullable Node prevNode;
+    @Nullable Node nextNode;
+
+    // prevKeyGroupNode and nextKeyGroupNode are used for the keyGroup level 
lists linking
+    // QueuedWork with same keyGroup
+    @Nullable Node prevKeyGroupNode;
+    @Nullable Node nextKeyGroupNode;
+
+    Node(Runnable task) {
+      this.task = task;
+      if (task instanceof QueuedWork) {
+        this.computationId = ((QueuedWork) task).getWork().getComputationId();
+        this.keyGroup = ((QueuedWork) task).getWork().getKeyGroup();
+      } else {
+        this.computationId = null;
+        this.keyGroup = null;
+      }
+    }
+  }
+
+  /** Double linked list implementing key group level queue */
+  private static class KeyGroupWorkList {
+    final Node head = new Node(SENTINEL_RUNNABLE);
+    final Node tail = new Node(SENTINEL_RUNNABLE);
+
+    KeyGroupWorkList() {
+      head.nextKeyGroupNode = tail;
+      tail.prevKeyGroupNode = head;
+    }
+
+    boolean isEmpty() {
+      return head.nextKeyGroupNode == tail;
+    }
+
+    void append(Node node) {
+      Node last = checkStateNotNull(tail.prevKeyGroupNode);
+      node.prevKeyGroupNode = last;
+      node.nextKeyGroupNode = tail;
+      last.nextKeyGroupNode = node;
+      tail.prevKeyGroupNode = node;
+    }
+
+    void remove(Node node) {
+      @Nullable Node prev = node.prevKeyGroupNode;
+      @Nullable Node next = node.nextKeyGroupNode;
+      if (prev != null && next != null) {
+        prev.nextKeyGroupNode = next;
+        next.prevKeyGroupNode = prev;
+        node.prevKeyGroupNode = null;
+        node.nextKeyGroupNode = null;
+      }
+    }
+  }
+
+  private final ReentrantLock lock;
+  private final Condition notEmpty;
+
+  // Sentinels for the global list
+  @GuardedBy("lock")
+  private final Node globalHead = new Node(SENTINEL_RUNNABLE);
+
+  @GuardedBy("lock")
+  private final Node globalTail = new Node(SENTINEL_RUNNABLE);
+
+  @GuardedBy("lock")
+  private final Map<QueueKey, KeyGroupWorkList> keyGroupQueueMap = new 
HashMap<>();
+
+  @GuardedBy("lock")
+  private int size = 0;
+
+  public KeyGroupWorkQueue(boolean fair) {
+    this.lock = new ReentrantLock(fair);
+    this.notEmpty = lock.newCondition();
+    globalHead.nextNode = globalTail;
+    globalTail.prevNode = globalHead;
+  }
+
+  @GuardedBy("lock")
+  private void unlinkNode(Node node) {
+    // An existing node should always have previous and next since we have 
sentinels
+    // 1. Unlink from global list
+    Node prevG = checkArgumentNotNull(node.prevNode);
+    Node nextG = checkArgumentNotNull(node.nextNode);
+    prevG.nextNode = nextG;
+    nextG.prevNode = prevG;
+    node.prevNode = null;
+    node.nextNode = null;
+
+    // 2. Unlink from key group list
+    KeyGroupWorkList list = node.keyGroupList;
+    if (list != null) {
+      list.remove(node);
+      if (list.isEmpty()) {
+        String compId = checkStateNotNull(node.computationId);
+        Work.KeyGroup keyGroup = checkStateNotNull(node.keyGroup);
+        QueueKey key = new QueueKey(compId, keyGroup);
+        keyGroupQueueMap.remove(key);
+      }
+      node.keyGroupList = null;
+    }
+    --size;
+  }
+
+  @GuardedBy("lock")
+  private @Nullable Node removeFirstGlobal() {
+    Node first = checkStateNotNull(globalHead.nextNode);
+    if (first == globalTail) {
+      return null;
+    }
+    unlinkNode(first);
+    return first;
+  }
+
+  /**
+   * Remove and Return QueuedWork for the computationId, keyGroup in the FIFO 
order. Returns null,
+   * if there are no matches.
+   *
+   * @param keyGroup should not be equal to KeyGroup.DEFAULT
+   */
+  public @Nullable QueuedWork pollWork(String computationId, Work.KeyGroup 
keyGroup) {
+    checkArgument(computationId != null && keyGroup != null && 
!keyGroup.equals(KeyGroup.DEFAULT));
+    QueueKey key = new QueueKey(computationId, keyGroup);
+    lock.lock();
+    try {
+      KeyGroupWorkList keyGroupWorkList = keyGroupQueueMap.get(key);
+      if (keyGroupWorkList == null || keyGroupWorkList.isEmpty()) {
+        return null;
+      }
+
+      // Retrieve the first pending task for this computation and keyGroup in 
O(1)
+      Node firstNode = 
checkStateNotNull(keyGroupWorkList.head.nextKeyGroupNode);
+      if (firstNode == keyGroupWorkList.tail) {
+        return null;
+      }
+      unlinkNode(firstNode);
+
+      return (QueuedWork) firstNode.task;
+    } finally {
+      lock.unlock();
+    }
+  }
+
+  @Override
+  public boolean offer(@NonNull Runnable runnable) {
+    Node node = new Node(checkStateNotNull(runnable));
+    lock.lock();
+    try {
+      // Append to global list tail
+      Node lastG = checkStateNotNull(globalTail.prevNode);
+      node.prevNode = lastG;
+      node.nextNode = globalTail;
+      lastG.nextNode = node;
+      globalTail.prevNode = node;
+
+      // Append to key group list if applicable
+      String compId = node.computationId;
+      Work.KeyGroup keyGroup = node.keyGroup;
+      if (compId != null && keyGroup != null && 
!keyGroup.equals(KeyGroup.DEFAULT)) {
+        QueueKey key = new QueueKey(compId, keyGroup);
+        KeyGroupWorkList keyGroupWorkList =
+            keyGroupQueueMap.computeIfAbsent(key, k -> new KeyGroupWorkList());
+        keyGroupWorkList.append(node);
+        node.keyGroupList = keyGroupWorkList;
+      }
+
+      ++size;
+      notEmpty.signal();
+      return true;
+    } finally {
+      lock.unlock();
+    }
+  }
+
+  @Override
+  public void put(Runnable e) throws InterruptedException {
+    offer(e); // Unbounded queue
+  }
+
+  @Override
+  public boolean offer(Runnable e, long timeout, TimeUnit unit) throws 
InterruptedException {
+    return offer(e); // Unbounded queue
+  }
+
+  @Override
+  public @Nullable Runnable poll() {
+    lock.lock();
+    try {
+      @Nullable Node node = removeFirstGlobal();
+      return (node != null) ? node.task : null;
+    } finally {
+      lock.unlock();
+    }
+  }
+
+  @Override
+  public Runnable take() throws InterruptedException {
+    lock.lockInterruptibly();
+    try {
+      while (size == 0) {
+        notEmpty.await();
+      }
+      @Nullable Node node = removeFirstGlobal();
+      checkStateNotNull(node, "Queue is empty but size was " + size);
+      return node.task;
+    } finally {
+      lock.unlock();
+    }
+  }
+
+  @Override
+  public @Nullable Runnable poll(long timeout, TimeUnit unit) throws 
InterruptedException {
+    long nanos = unit.toNanos(timeout);
+    lock.lockInterruptibly();
+    try {
+      while (size == 0) {
+        if (nanos <= 0) {
+          return null;
+        }
+        nanos = notEmpty.awaitNanos(nanos);
+      }
+      @Nullable Node node = removeFirstGlobal();
+      return (node != null) ? node.task : null;
+    } finally {
+      lock.unlock();
+    }
+  }
+
+  @Override
+  public @Nullable Runnable peek() {
+    lock.lock();
+    try {
+      Node first = checkStateNotNull(globalHead.nextNode);
+      if (first == globalTail) {
+        return null;
+      }
+      return first.task;
+    } finally {
+      lock.unlock();
+    }
+  }
+
+  @Override
+  public int size() {
+    lock.lock();
+    try {
+      return size;
+    } finally {
+      lock.unlock();
+    }
+  }
+
+  @Override
+  public boolean isEmpty() {
+    lock.lock();
+    try {
+      return size == 0;
+    } finally {
+      lock.unlock();
+    }
+  }
+
+  @Override
+  public boolean remove(Object o) {
+    if (o == null) return false;
+    lock.lock();
+    try {
+      // Walk the global queue in O(N) to find and unlink the node
+      Node curr = checkStateNotNull(globalHead.nextNode);
+      while (curr != globalTail) {
+        if (o.equals(curr.task)) {
+          unlinkNode(curr);
+          return true;
+        }
+        curr = checkStateNotNull(curr.nextNode);
+      }
+      return false;
+    } finally {
+      lock.unlock();
+    }
+  }
+
+  @Override
+  public boolean contains(Object o) {
+    if (o == null) return false;
+    lock.lock();
+    try {
+      Node curr = checkStateNotNull(globalHead.nextNode);
+      while (curr != globalTail) {
+        if (o.equals(curr.task)) {
+          return true;
+        }
+        curr = checkStateNotNull(curr.nextNode);
+      }
+      return false;
+    } finally {
+      lock.unlock();
+    }
+  }
+
+  @Override
+  public int remainingCapacity() {
+    return Integer.MAX_VALUE;
+  }
+
+  @Override
+  public int drainTo(Collection<? super Runnable> c) {
+    return drainTo(c, Integer.MAX_VALUE);
+  }
+
+  @Override
+  public int drainTo(Collection<? super Runnable> c, int maxElements) {
+    if (c == null) throw new NullPointerException();
+    if (c == this) throw new IllegalArgumentException();
+    if (maxElements <= 0) return 0;
+    lock.lock();
+    try {
+      int added = 0;
+      Node curr = checkStateNotNull(globalHead.nextNode);
+      while (curr != globalTail && added < maxElements) {
+        Node next = checkStateNotNull(curr.nextNode);
+        unlinkNode(curr);
+        Runnable task = curr.task;
+        c.add(task);
+        ++added;
+        curr = next;
+      }
+      return added;
+    } finally {
+      lock.unlock();
+    }
+  }
+
+  @Override
+  public void clear() {
+    lock.lock();
+    try {
+      Node curr = checkStateNotNull(globalHead.nextNode);
+      while (curr != globalTail) {
+        Node next = checkStateNotNull(curr.nextNode);
+        unlinkNode(curr);
+        curr = next;
+      }
+    } finally {
+      lock.unlock();
+    }
+  }
+
+  @Override
+  public Iterator<Runnable> iterator() {
+    lock.lock();
+    try {
+      ImmutableList.Builder<Runnable> builder = 
ImmutableList.builderWithExpectedSize(size);
+      Node curr = checkStateNotNull(globalHead.nextNode);
+      while (curr != globalTail) {
+        builder.add(curr.task);
+        curr = checkStateNotNull(curr.nextNode);
+      }
+      return builder.build().iterator();
+    } finally {
+      lock.unlock();
+    }
+  }
+
+  static final class QueueKey {
+    private final String computationId;
+    private final Work.KeyGroup keyGroup;
+
+    QueueKey(String computationId, Work.KeyGroup keyGroup) {
+      this.computationId = computationId;
+      this.keyGroup = keyGroup;
+    }
+
+    @Override
+    public boolean equals(@Nullable Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (!(o instanceof QueueKey)) {
+        return false;
+      }
+      QueueKey other = (QueueKey) o;
+      return computationId.equals(other.computationId) && 
keyGroup.equals(other.keyGroup);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hash(computationId, keyGroup);
+    }
+
+    @Override
+    public String toString() {
+      return "QueueKey{"
+          + "computationId='"
+          + computationId
+          + '\''
+          + ", keyGroup="
+          + keyGroup
+          + '}';
+    }
+  }
+}
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
index d58f2007699..5bcdffcc256 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
@@ -3036,7 +3036,8 @@ public class StreamingDataflowWorkerTest {
                 .setNameFormat("DataflowWorkUnits-%d")
                 .setDaemon(true)
                 .build(),
-            /*useFairMonitor=*/ false);
+            /*useFairMonitor=*/ false,
+            /*useKeyGroupWorkQueue=*/ false);
 
     ComputationState computationState =
         new ComputationState(
@@ -3097,7 +3098,8 @@ public class StreamingDataflowWorkerTest {
                 .setNameFormat("DataflowWorkUnits-%d")
                 .setDaemon(true)
                 .build(),
-            /*useFairMonitor=*/ false);
+            /*useFairMonitor=*/ false,
+            /*useKeyGroupWorkQueue=*/ false);
 
     ComputationState computationState =
         new ComputationState(
@@ -3167,7 +3169,8 @@ public class StreamingDataflowWorkerTest {
                 .setNameFormat("DataflowWorkUnits-%d")
                 .setDaemon(true)
                 .build(),
-            /*useFairMonitor=*/ false);
+            /*useFairMonitor=*/ false,
+            /*useKeyGroupWorkQueue=*/ false);
 
     ComputationState computationState =
         new ComputationState(
@@ -3241,7 +3244,8 @@ public class StreamingDataflowWorkerTest {
                 .setNameFormat("DataflowWorkUnits-%d")
                 .setDaemon(true)
                 .build(),
-            /*useFairMonitor=*/ false);
+            /*useFairMonitor=*/ false,
+            /*useKeyGroupWorkQueue=*/ false);
 
     ComputationState computationState =
         new ComputationState(
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java
 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java
index 55fe82c7163..a98102751fb 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java
@@ -20,6 +20,8 @@ package org.apache.beam.runners.dataflow.worker.util;
 import static org.hamcrest.Matchers.greaterThan;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertThat;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Mockito.mock;
@@ -33,6 +35,7 @@ import 
org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork;
 import org.apache.beam.runners.dataflow.worker.streaming.Watermarks;
 import org.apache.beam.runners.dataflow.worker.streaming.Work;
 import 
org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor.BoundedQueueExecutorWorkHandleImpl;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem;
 import 
org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient;
 import 
org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
@@ -66,13 +69,30 @@ public class BoundedQueueExecutorTest {
   @Rule public transient Timeout globalTimeout = Timeout.seconds(300);
   private BoundedQueueExecutor executor;
 
+  private static final Work.KeyGroup DEFAULT_KEY_GROUP = Work.KeyGroup.create(1, 2);
+
   private static ExecutableWork createWork(Consumer<Work> executeWorkFn) {
+    return createWorkWithCompId("computationId", executeWorkFn);
+  }
+
+  private static ExecutableWork createWorkWithCompId(
+      String computationId, Consumer<Work> executeWorkFn) {
+    return createWorkWithCompIdAndKeyGroup(computationId, DEFAULT_KEY_GROUP, executeWorkFn);
+  }
+
+  private static ExecutableWork createWorkWithCompIdAndKeyGroup(
+      String computationId, Work.KeyGroup keyGroup, Consumer<Work> executeWorkFn) {
     WorkItem workItem =
         WorkItem.newBuilder()
             .setKey(ByteString.EMPTY)
             .setShardingKey(1)
             .setWorkToken(33)
             .setCacheToken(1)
+            .setKeyGroup(
+                Windmill.Uint128Proto.newBuilder()
+                    .setHigh(keyGroup.high())
+                    .setLow(keyGroup.low())
+                    .build())
             .build();
     return ExecutableWork.create(
         Work.create(
@@ -80,10 +100,7 @@ public class BoundedQueueExecutorTest {
             workItem.getSerializedSize(),
             Watermarks.builder().setInputDataWatermark(Instant.now()).build(),
             Work.createProcessingContext(
-                "computationId",
-                new FakeGetDataClient(),
-                ignored -> {},
-                mock(HeartbeatSender.class)),
+                computationId, new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.class)),
             false,
             Instant::now),
         (work, handle) -> {
@@ -116,7 +133,8 @@ public class BoundedQueueExecutorTest {
                 .setNameFormat("DataflowWorkUnits-%d")
                 .setDaemon(true)
                 .build(),
-            useFairMonitor);
+            useFairMonitor,
+            /*useKeyGroupWorkQueue=*/ false);
   }
 
   @Test
@@ -413,4 +431,126 @@ public class BoundedQueueExecutorTest {
             + "Work Queue Bytes: 0/10000000<br>/n";
     assertEquals(expectedSummaryHtml, executor.summaryHtml());
   }
+
+  @Test
+  public void testPollWork() throws Exception {
+    // Create separate BoundedQueueExecutor with 1 thread so we can block it easily
+    BoundedQueueExecutor testExecutor =
+        new BoundedQueueExecutor(
+            1,
+            60,
+            TimeUnit.SECONDS,
+            100,
+            10000000,
+            new ThreadFactoryBuilder().setNameFormat("testStealing-%d").setDaemon(true).build(),
+            useFairMonitor,
+            /*useKeyGroupWorkQueue=*/ true);
+
+    // 1. Create blocker task to occupy the worker thread
+    CountDownLatch blockerStart = new CountDownLatch(1);
+    CountDownLatch blockerStop = new CountDownLatch(1);
+    ExecutableWork blockerWork =
+        createWorkWithCompIdAndKeyGroup(
+            "blockerComp",
+            DEFAULT_KEY_GROUP,
+            ignored -> {
+              blockerStart.countDown();
+              try {
+                blockerStop.await();
+              } catch (InterruptedException e) {
+                throw new RuntimeException(e);
+              }
+            });
+
+    testExecutor.execute(blockerWork, 0);
+    blockerStart.await();
+
+    // 2. Create two distinct key groups
+    Work.KeyGroup keyGroup1 = Work.KeyGroup.create(1, 1);
+    Work.KeyGroup keyGroup2 = Work.KeyGroup.create(1, 2);
+
+    // Create executable tasks
+    CountDownLatch targetStart = new CountDownLatch(1);
+    ExecutableWork work1 = createWorkWithCompIdAndKeyGroup("compA", keyGroup1, ignored -> {});
+    ExecutableWork work2 =
+        createWorkWithCompIdAndKeyGroup(
+            "compA",
+            keyGroup2,
+            ignored -> {
+              targetStart.countDown();
+            });
+
+    // Enqueue tasks (they will wait in the queue because the thread is blocked)
+    testExecutor.execute(work1, 100);
+    testExecutor.execute(work2, 150);
+
+    // Total outstanding elements must be 3 (blocker + work1 + work2)
+    assertEquals(3, testExecutor.elementsOutstanding());
+
+    // Steal work2 using pollWork with compA and keyGroup2
+    try (BoundedQueueExecutorWorkHandleImpl stealHandle = testExecutor.createBudgetHandle(0, 0L)) {
+      ExecutableWork stolen = testExecutor.pollWork("compA", keyGroup2, stealHandle);
+      assertNotNull(stolen);
+      assertEquals(work2, stolen);
+
+      // Run the stolen task
+      stolen.run(stealHandle);
+      targetStart.await();
+    }
+
+    // Steal work1 using pollWork with compA and keyGroup1
+    try (BoundedQueueExecutorWorkHandleImpl stealHandle = testExecutor.createBudgetHandle(0, 0L)) {
+      ExecutableWork stolen = testExecutor.pollWork("compA", keyGroup1, stealHandle);
+      assertNotNull(stolen);
+      assertEquals(work1, stolen);
+    }
+
+    // Unblock the blocker and shut down
+    blockerStop.countDown();
+    testExecutor.shutdown();
+  }
+
+  @Test
+  public void testPollWorkWithLinkedBlockingQueue() throws Exception {
+    BoundedQueueExecutor testExecutor =
+        new BoundedQueueExecutor(
+            1,
+            60,
+            TimeUnit.SECONDS,
+            100,
+            10000000,
+            new ThreadFactoryBuilder().setNameFormat("testLinkedQueue-%d").setDaemon(true).build(),
+            useFairMonitor,
+            /* useKeyGroupWorkQueue= */ false);
+
+    CountDownLatch blockerStart = new CountDownLatch(1);
+    CountDownLatch blockerStop = new CountDownLatch(1);
+    ExecutableWork blockerWork =
+        createWorkWithCompIdAndKeyGroup(
+            "blockerComp",
+            DEFAULT_KEY_GROUP,
+            ignored -> {
+              blockerStart.countDown();
+              try {
+                blockerStop.await();
+              } catch (InterruptedException e) {
+                throw new RuntimeException(e);
+              }
+            });
+
+    testExecutor.execute(blockerWork, 0);
+    blockerStart.await();
+
+    Work.KeyGroup keyGroup = Work.KeyGroup.create(1, 1);
+    ExecutableWork work = createWorkWithCompIdAndKeyGroup("compA", keyGroup, ignored -> {});
+    testExecutor.execute(work, 100);
+
+    try (BoundedQueueExecutorWorkHandleImpl stealHandle = testExecutor.createBudgetHandle(0, 0L)) {
+      ExecutableWork stolen = testExecutor.pollWork("compA", keyGroup, stealHandle);
+      assertNull(stolen);
+    }
+
+    blockerStop.countDown();
+    testExecutor.shutdown();
+  }
 }
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueueTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueueTest.java
new file mode 100644
index 00000000000..994aa2030f3
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueueTest.java
@@ -0,0 +1,504 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.util;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.mock;
+
+import java.lang.Thread.State;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
+import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork;
+import org.apache.beam.runners.dataflow.worker.streaming.Watermarks;
+import org.apache.beam.runners.dataflow.worker.streaming.Work;
+import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor.QueuedWork;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem;
+import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient;
+import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
+import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.joda.time.Instant;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class KeyGroupWorkQueueTest {
+
+  @Parameters(name = "fairQueue={0}")
+  public static Iterable<Object[]> data() {
+    return Arrays.asList(new Object[][] {{true}, {false}});
+  }
+
+  @Parameterized.Parameter public boolean fairQueue;
+
+  private BoundedQueueExecutor executor;
+
+  @Before
+  public void setUp() {
+    executor =
+        new BoundedQueueExecutor(
+            2,
+            60,
+            TimeUnit.SECONDS,
+            100,
+            10000000,
+            new ThreadFactoryBuilder().setNameFormat("Test-%d").setDaemon(true).build(),
+            fairQueue,
+            /*useKeyGroupWorkQueue=*/ true);
+  }
+
+  private static final Work.KeyGroup TEST_KEY_GROUP = Work.KeyGroup.create(1, 2);
+
+  private QueuedWork createQueuedWork(String computationId, long workBytes) {
+    return createQueuedWork(computationId, TEST_KEY_GROUP, workBytes);
+  }
+
+  private QueuedWork createQueuedWork(
+      String computationId, Work.@Nullable KeyGroup keyGroup, long workBytes) {
+    WorkItem.Builder workItemBuilder =
+        WorkItem.newBuilder()
+            .setKey(ByteString.EMPTY)
+            .setShardingKey(1)
+            .setWorkToken(33)
+            .setCacheToken(1);
+    if (keyGroup != null) {
+      workItemBuilder.setKeyGroup(
+          org.apache.beam.runners.dataflow.worker.windmill.Windmill.Uint128Proto.newBuilder()
+              .setHigh(keyGroup.high())
+              .setLow(keyGroup.low())
+              .build());
+    }
+    WorkItem workItem = workItemBuilder.build();
+    ExecutableWork work =
+        ExecutableWork.create(
+            Work.create(
+                workItem,
+                workItem.getSerializedSize(),
+                Watermarks.builder().setInputDataWatermark(Instant.now()).build(),
+                Work.createProcessingContext(
+                    computationId,
+                    new FakeGetDataClient(),
+                    ignored -> {},
+                    mock(HeartbeatSender.class)),
+                false,
+                Instant::now),
+            (w, h) -> {});
+    return new QueuedWork(work, executor.createBudgetHandle(1, workBytes));
+  }
+
+  private static class NoOpRunnable implements Runnable {
+    final String id;
+
+    NoOpRunnable(String id) {
+      this.id = id;
+    }
+
+    @Override
+    public void run() {}
+
+    @Override
+    public String toString() {
+      return "NoOpRunnable(" + id + ")";
+    }
+  }
+
+  @Test
+  public void testBasicOfferAndPoll() {
+    KeyGroupWorkQueue queue = new KeyGroupWorkQueue(fairQueue);
+    assertTrue(queue.isEmpty());
+    assertEquals(0, queue.size());
+
+    NoOpRunnable task1 = new NoOpRunnable("1");
+    NoOpRunnable task2 = new NoOpRunnable("2");
+
+    assertTrue(queue.offer(task1));
+    assertTrue(queue.offer(task2));
+    assertEquals(2, queue.size());
+
+    assertEquals(task1, queue.poll());
+    assertEquals(task2, queue.poll());
+    assertNull(queue.poll());
+    assertTrue(queue.isEmpty());
+  }
+
+  @Test
+  public void testRemove() {
+    KeyGroupWorkQueue queue = new KeyGroupWorkQueue(fairQueue);
+    NoOpRunnable task1 = new NoOpRunnable("1");
+    NoOpRunnable task2 = new NoOpRunnable("2");
+
+    queue.offer(task1);
+    queue.offer(task2);
+
+    assertTrue(queue.remove(task1));
+    assertEquals(1, queue.size());
+    assertEquals(task2, queue.poll());
+    assertFalse(queue.remove(task1)); // Already gone
+  }
+
+  @Test
+  public void testDrainTo() {
+    KeyGroupWorkQueue queue = new KeyGroupWorkQueue(fairQueue);
+    NoOpRunnable task1 = new NoOpRunnable("1");
+    NoOpRunnable task2 = new NoOpRunnable("2");
+    queue.offer(task1);
+    queue.offer(task2);
+
+    List<Runnable> drained = new ArrayList<>();
+    assertEquals(2, queue.drainTo(drained));
+    assertEquals(2, drained.size());
+    assertEquals(task1, drained.get(0));
+    assertEquals(task2, drained.get(1));
+    assertTrue(queue.isEmpty());
+  }
+
+  @Test
+  public void testIteratorSafeTraversalAndImmutable() {
+    KeyGroupWorkQueue queue = new KeyGroupWorkQueue(fairQueue);
+    NoOpRunnable task1 = new NoOpRunnable("1");
+    NoOpRunnable task2 = new NoOpRunnable("2");
+    queue.offer(task1);
+    queue.offer(task2);
+
+    Iterator<Runnable> it = queue.iterator();
+    assertTrue(it.hasNext());
+    assertEquals(task1, it.next());
+    assertTrue(it.hasNext());
+    assertEquals(task2, it.next());
+    assertFalse(it.hasNext());
+
+    // Assert that mutating the iterator throws UnsupportedOperationException
+    it = queue.iterator();
+    assertTrue(it.hasNext());
+    it.next();
+    try {
+      it.remove();
+      fail("Iterator must be immutable");
+    } catch (UnsupportedOperationException e) {
+      // Expected
+    }
+  }
+
+  @Test
+  public void testPollWorkTargeted() {
+    KeyGroupWorkQueue queue = new KeyGroupWorkQueue(fairQueue);
+
+    QueuedWork workA1 = createQueuedWork("compA", 100);
+    QueuedWork workB1 = createQueuedWork("compB", 200);
+    QueuedWork workA2 = createQueuedWork("compA", 150);
+
+    queue.offer(workA1);
+    queue.offer(workB1);
+    queue.offer(workA2);
+
+    assertEquals(3, queue.size());
+    assertFalse(queue.isEmpty());
+
+    // Targeted poll A
+    QueuedWork polledA1 = queue.pollWork("compA", TEST_KEY_GROUP);
+    assertNotNull(polledA1);
+    assertEquals("compA", polledA1.getWork().getComputationId());
+    assertEquals(100, polledA1.getHandle().bytes());
+    assertEquals(2, queue.size());
+    assertFalse(queue.isEmpty());
+
+    // Poll next should be B1 (since A1 was stolen, B1 is now first global)
+    assertEquals(workB1, queue.poll());
+    assertEquals(1, queue.size());
+    assertFalse(queue.isEmpty());
+
+    // Last should be A2
+    assertEquals(workA2, queue.poll());
+    assertEquals(0, queue.size());
+    assertTrue(queue.isEmpty());
+
+    assertEquals(null, queue.poll());
+  }
+
+  @Test
+  public void testConcurrentStress() throws InterruptedException, ExecutionException {
+    KeyGroupWorkQueue queue = new KeyGroupWorkQueue(fairQueue);
+    int producerThreads = 4;
+    int consumerThreads = 4;
+    int tasksPerProducer = 1000;
+    int totalTasks = producerThreads * tasksPerProducer;
+    Runnable poisonPill =
+        new Runnable() {
+          @Override
+          public void run() {}
+
+          @Override
+          public String toString() {
+            return "POISON_PILL";
+          }
+        };
+
+    ExecutorService executorService =
+        Executors.newFixedThreadPool(producerThreads + consumerThreads);
+    CountDownLatch startLatch = new CountDownLatch(1);
+    CountDownLatch producersDoneLatch = new CountDownLatch(producerThreads);
+    CountDownLatch consumersDoneLatch = new CountDownLatch(consumerThreads);
+    CountDownLatch consumedLatch = new CountDownLatch(totalTasks);
+    List<Future<?>> futures = new ArrayList<>();
+
+    // Start consumers
+    for (int i = 0; i < consumerThreads; i++) {
+      int consumerId = i;
+      futures.add(
+          executorService.submit(
+              () -> {
+                try {
+                  startLatch.await();
+                  int iteration = consumerId % 4;
+                  while (true) {
+                    int strategy = iteration;
+                    iteration = (iteration + 1) % 4;
+                    Runnable task = null;
+                    if (strategy == 0) {
+                      String compId = "comp-" + (consumedLatch.getCount() % 5);
+                      task = queue.pollWork(compId, TEST_KEY_GROUP);
+                    } else if (strategy == 1) {
+                      task = queue.poll();
+                    } else if (strategy == 2) {
+                      task = queue.poll(10, TimeUnit.MICROSECONDS);
+                    } else if (strategy == 3) {
+                      task = queue.take();
+                    }
+
+                    if (task == poisonPill) {
+                      break;
+                    }
+                    if (task != null) {
+                      consumedLatch.countDown();
+                    }
+                  }
+                } catch (Exception e) {
+                  throw new RuntimeException(e);
+                } finally {
+                  consumersDoneLatch.countDown();
+                }
+              }));
+    }
+
+    // Start producers
+    for (int i = 0; i < producerThreads; i++) {
+      futures.add(
+          executorService.submit(
+              () -> {
+                try {
+                  startLatch.await();
+                  for (int j = 0; j < tasksPerProducer; j++) {
+                    String compId = "comp-" + (j % 5);
+                    queue.offer(createQueuedWork(compId, 10));
+                  }
+                } catch (Exception e) {
+                  throw new RuntimeException(e);
+                } finally {
+                  producersDoneLatch.countDown();
+                }
+              }));
+    }
+
+    // Release the start latch to start the test
+    startLatch.countDown();
+
+    // Wait for all tasks to be consumed
+    assertTrue(consumedLatch.await(30, TimeUnit.SECONDS));
+
+    // Send poison pills to stop all consumers
+    for (int i = 0; i < consumerThreads; i++) {
+      queue.offer(poisonPill);
+    }
+
+    // Wait for consumers to finish
+    assertTrue(consumersDoneLatch.await(30, TimeUnit.SECONDS));
+    // Wait for producers to finish
+    assertTrue(producersDoneLatch.await(30, TimeUnit.SECONDS));
+
+    // Check for exceptions in threads
+    for (Future<?> future : futures) {
+      future.get();
+    }
+
+    executorService.shutdown();
+    assertTrue(executorService.awaitTermination(30, TimeUnit.SECONDS));
+
+    assertEquals(0, queue.size());
+    assertTrue(queue.isEmpty());
+  }
+
+  @Test
+  public void testTakeBlocksAndWakesUp() throws InterruptedException {
+    final KeyGroupWorkQueue queue = new KeyGroupWorkQueue(fairQueue);
+    final NoOpRunnable task = new NoOpRunnable("take-task");
+    final AtomicReference<Runnable> result = new AtomicReference<>();
+    final CountDownLatch started = new CountDownLatch(1);
+    final CountDownLatch finished = new CountDownLatch(1);
+
+    Thread t =
+        new Thread(
+            () -> {
+              started.countDown();
+              try {
+                result.set(queue.take());
+              } catch (InterruptedException e) {
+                // Ignore
+              } finally {
+                finished.countDown();
+              }
+            });
+    t.setDaemon(true);
+    t.start();
+
+    assertTrue(started.await(30, TimeUnit.SECONDS));
+    waitForThreadState(t, State.WAITING);
+
+    queue.offer(task);
+
+    assertTrue(finished.await(30, TimeUnit.SECONDS));
+    assertEquals(task, result.get());
+  }
+
+  @Test
+  public void testPollWithTimeout() throws InterruptedException {
+    final KeyGroupWorkQueue queue = new KeyGroupWorkQueue(fairQueue);
+    final NoOpRunnable task = new NoOpRunnable("poll-task");
+    final AtomicReference<Runnable> result = new AtomicReference<>();
+    final CountDownLatch started = new CountDownLatch(1);
+    final CountDownLatch finished = new CountDownLatch(1);
+
+    // 1. Verify timeout returns null
+    Thread t1 =
+        new Thread(
+            () -> {
+              started.countDown();
+              try {
+                result.set(queue.poll(500, TimeUnit.MILLISECONDS));
+              } catch (InterruptedException e) {
+                // Ignore
+              } finally {
+                finished.countDown();
+              }
+            });
+    t1.setDaemon(true);
+    t1.start();
+
+    assertTrue(started.await(30, TimeUnit.SECONDS));
+    waitForThreadState(t1, State.TIMED_WAITING);
+
+    assertTrue(finished.await(30, TimeUnit.SECONDS));
+    assertNull(result.get());
+
+    // 2. Verify timed poll receives task offered concurrently
+    final CountDownLatch started2 = new CountDownLatch(1);
+    final CountDownLatch finished2 = new CountDownLatch(1);
+    final AtomicReference<Runnable> result2 = new AtomicReference<>();
+
+    Thread t2 =
+        new Thread(
+            () -> {
+              started2.countDown();
+              try {
+                result2.set(queue.poll(2, TimeUnit.SECONDS));
+              } catch (InterruptedException e) {
+                // Ignore
+              } finally {
+                finished2.countDown();
+              }
+            });
+    t2.setDaemon(true);
+    t2.start();
+
+    assertTrue(started2.await(30, TimeUnit.SECONDS));
+    waitForThreadState(t2, State.TIMED_WAITING);
+
+    queue.offer(task);
+
+    assertTrue(finished2.await(30, TimeUnit.SECONDS));
+    assertEquals(task, result2.get());
+  }
+
+  @Test
+  public void testPollWorkWithKeyGroup() {
+    KeyGroupWorkQueue queue = new KeyGroupWorkQueue(fairQueue);
+
+    Work.KeyGroup keyGroup1 = Work.KeyGroup.create(1, 1);
+    Work.KeyGroup keyGroup2 = Work.KeyGroup.create(1, 2);
+    Work.KeyGroup keyGroupNotExist = Work.KeyGroup.create(3, 4);
+
+    QueuedWork workA1 = createQueuedWork("compA", keyGroup1, 100);
+    QueuedWork workA2 = createQueuedWork("compA", keyGroup2, 150);
+
+    queue.offer(workA1);
+    queue.offer(workA2);
+
+    assertEquals(2, queue.size());
+
+    QueuedWork polledNotExist = queue.pollWork("compA", keyGroupNotExist);
+    assertNull(polledNotExist);
+    assertEquals(2, queue.size());
+
+    // Poll with keyGroup2 first - should return workA2
+    QueuedWork polledA2 = queue.pollWork("compA", keyGroup2);
+    assertNotNull(polledA2);
+    assertEquals(workA2, polledA2);
+    assertEquals(1, queue.size());
+
+    // Poll with keyGroup2 again - should return null
+    assertNull(queue.pollWork("compA", keyGroup2));
+
+    // Poll with keyGroup1 - should return workA1
+    QueuedWork polledA1 = queue.pollWork("compA", keyGroup1);
+    assertNotNull(polledA1);
+    assertEquals(workA1, polledA1);
+    assertTrue(queue.isEmpty());
+
+    polledNotExist = queue.pollWork("compA", keyGroupNotExist);
+    assertNull(polledNotExist);
+    assertTrue(queue.isEmpty());
+  }
+
+  private void waitForThreadState(Thread t, State state) throws InterruptedException {
+    long timeoutMs = 30000;
+    long start = System.currentTimeMillis();
+    while (t.getState() != state) {
+      if (System.currentTimeMillis() - start > timeoutMs) {
+        fail("Thread did not reach " + state + " state within " + timeoutMs + "ms");
+      }
+      Thread.sleep(1);
+    }
+  }
+}
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingCommitFinalizerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingCommitFinalizerTest.java
index 07b4b14fd11..ef0d8e43485 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingCommitFinalizerTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingCommitFinalizerTest.java
@@ -62,7 +62,8 @@ public class StreamingCommitFinalizerTest {
                 .setNameFormat("FinalizationCallback-%d")
                 .setDaemon(true)
                 .build(),
-            /*useFairMonitor=*/ false);
+            /*useFairMonitor=*/ false,
+            /*useKeyGroupWorkQueue=*/ false);
 
     cleanupExecutor =
         Executors.newScheduledThreadPool(
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java
index 51bd4816b03..0610ed44c27 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java
@@ -64,7 +64,8 @@ public class WorkFailureProcessorTest {
                 .setNameFormat("DataflowWorkUnits-%d")
                 .setDaemon(true)
                 .build(),
-            /*useFairMonitor=*/ false);
+            /*useFairMonitor=*/ false,
+            /*useKeyGroupWorkQueue=*/ false);
 
     return WorkFailureProcessor.forTesting(workExecutor, failureTracker, Optional::empty, clock, 0);
   }
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresherTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresherTest.java
index 88a82c6f76b..f32282056e4 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresherTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresherTest.java
@@ -80,7 +80,8 @@ public class ActiveWorkRefresherTest {
         1,
         10000000,
         new ThreadFactoryBuilder().setNameFormat("DataflowWorkUnits-%d").setDaemon(true).build(),
-        /*useFairMonitor=*/ false);
+        /*useFairMonitor=*/ false,
+        /*useKeyGroupWorkQueue=*/ false);
   }
 
   private static ComputationState createComputationState(int computationIdSuffix) {
diff --git a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto
index 1da7ef9be8b..aaa09c105fc 100644
--- a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto
+++ b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto
@@ -421,6 +421,11 @@ message WatermarkHold {
   optional string state_family = 4;
 }
 
+message Uint128Proto {
+  required fixed64 high = 1;
+  required fixed64 low = 2;
+}
+
 // Proto describing a hot key detected on a given WorkItem.
 message HotKeyInfo {
   // The age of the hot key measured from when it was first detected.
@@ -448,6 +453,8 @@ message WorkItem {
   // present, this field includes metadata associated with any hot key.
   optional HotKeyInfo hot_key_info = 11;
 
+  optional Uint128Proto key_group = 18;
+
   reserved 12, 13, 14, 15, 16;
 }
 

Reply via email to