This is an automated email from the ASF dual-hosted git repository.
scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new d9117f423d9 [Dataflow Streaming] [Multi Key] Introduce
KeyGroupWorkQueue and integrate with BoundedQueueExecutor (#38767)
d9117f423d9 is described below
commit d9117f423d903b00ff5b10c3e8c915cf7caba279
Author: Arun Pandian <[email protected]>
AuthorDate: Mon Jun 8 03:46:18 2026 -0700
[Dataflow Streaming] [Multi Key] Introduce KeyGroupWorkQueue and integrate
with BoundedQueueExecutor (#38767)
* [Dataflow Streaming] Introduce KeyGroupWorkQueue and integrate with
BoundedQueueExecutor
- Add Uint128Proto and key_group to WorkItem in windmill.proto.
- Update Work model to support key groups.
- Implement KeyGroupWorkQueue for grouping tasks by key group.
KepGroupWorkQueue
is a FIFO queue that allows polling elements based on
global order and also by order within a key group.
- Integrate BoundedQueueExecutor with KepGroupWorkQueue and support
targeted polling.
---
.../dataflow/worker/StreamingDataflowWorker.java | 8 +-
.../dataflow/worker/streaming/ExecutableWork.java | 12 +-
.../runners/dataflow/worker/streaming/Work.java | 71 +++
.../dataflow/worker/util/BoundedQueueExecutor.java | 33 +-
.../dataflow/worker/util/KeyGroupWorkQueue.java | 462 +++++++++++++++++++
.../worker/StreamingDataflowWorkerTest.java | 12 +-
.../worker/util/BoundedQueueExecutorTest.java | 150 +++++-
.../worker/util/KeyGroupWorkQueueTest.java | 504 +++++++++++++++++++++
.../processing/StreamingCommitFinalizerTest.java | 3 +-
.../failures/WorkFailureProcessorTest.java | 3 +-
.../work/refresh/ActiveWorkRefresherTest.java | 3 +-
.../worker/windmill/src/main/proto/windmill.proto | 7 +
12 files changed, 1247 insertions(+), 21 deletions(-)
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
index 4c3e58978ec..9e82343474c 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
@@ -179,6 +179,9 @@ public final class StreamingDataflowWorker {
// Experiment make the monitor within BoundedQueueExecutor fair
public static final String
BOUNDED_QUEUE_EXECUTOR_USE_FAIR_MONITOR_EXPERIMENT =
"windmill_bounded_queue_executor_use_fair_monitor";
+ // Don't use. Experiment guarding multi key bundles. The feature is work in
progress and
+ // incomplete.
+ private static final String UNSTABLE_ENABLE_MULTI_KEY_BUNDLE =
"unstable_enable_multi_key_bundle";
private final WindmillStateCache stateCache;
private AtomicReference<StreamingWorkerStatusPages> statusPages = new
AtomicReference<>();
@@ -1018,6 +1021,8 @@ public final class StreamingDataflowWorker {
private static BoundedQueueExecutor
createWorkUnitExecutor(DataflowWorkerHarnessOptions options) {
boolean useFairMonitor =
DataflowRunner.hasExperiment(options,
BOUNDED_QUEUE_EXECUTOR_USE_FAIR_MONITOR_EXPERIMENT);
+ boolean useKeyGroupWorkQueue =
+ DataflowRunner.hasExperiment(options,
UNSTABLE_ENABLE_MULTI_KEY_BUNDLE);
return new BoundedQueueExecutor(
chooseMaxThreads(options),
THREAD_EXPIRATION_TIME_SEC,
@@ -1025,7 +1030,8 @@ public final class StreamingDataflowWorker {
chooseMaxBundlesOutstanding(options),
chooseMaxBytesOutstanding(options),
new
ThreadFactoryBuilder().setNameFormat("DataflowWorkUnits-%d").setDaemon(true).build(),
- useFairMonitor);
+ useFairMonitor,
+ useKeyGroupWorkQueue);
}
public static void main(String[] args) throws Exception {
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutableWork.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutableWork.java
index ecaa673f557..7748a554f0f 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutableWork.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutableWork.java
@@ -62,11 +62,11 @@ public final class ExecutableWork {
}
}
- public final WorkId id() {
+ public WorkId id() {
return work().id();
}
- public final Windmill.WorkItem getWorkItem() {
+ public Windmill.WorkItem getWorkItem() {
return work().getWorkItem();
}
@@ -74,4 +74,12 @@ public final class ExecutableWork {
public String toString() {
return "ExecutableWork{" + id() + "}";
}
+
+ public String getComputationId() {
+ return work().getComputationId();
+ }
+
+ public Work.KeyGroup getKeyGroup() {
+ return work().getKeyGroup();
+ }
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
index cb01e1e508c..53ed30fdedb 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java
@@ -25,6 +25,7 @@ import java.util.EnumMap;
import java.util.IntSummaryStatistics;
import java.util.Map;
import java.util.Map.Entry;
+import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
@@ -52,6 +53,7 @@ import
org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSe
import org.apache.beam.sdk.annotations.Internal;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
+import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.Duration;
import org.joda.time.Instant;
@@ -74,6 +76,7 @@ public final class Work implements RefreshableWork {
private final Instant startTime;
private final Map<LatencyAttribution.State, Duration> totalDurationPerState;
private final WorkId id;
+ private final KeyGroup keyGroup;
private final String latencyTrackingId;
private final long serializedWorkItemSize;
private volatile TimedState currentState;
@@ -101,6 +104,10 @@ public final class Work implements RefreshableWork {
// keyUniverse inside EnumMap every time.
this.totalDurationPerState = new EnumMap<>(EMPTY_ENUM_MAP);
this.id = WorkId.of(workItem);
+ this.keyGroup =
+ workItem.hasKeyGroup()
+ ? KeyGroup.create(workItem.getKeyGroup().getHigh(),
workItem.getKeyGroup().getLow())
+ : KeyGroup.DEFAULT;
this.latencyTrackingId =
Long.toHexString(workItem.getShardingKey())
+ '-'
@@ -383,6 +390,14 @@ public final class Work implements RefreshableWork {
abstract Instant startTime();
}
+ public String getComputationId() {
+ return processingContext.computationId();
+ }
+
+ public KeyGroup getKeyGroup() {
+ return keyGroup;
+ }
+
@AutoValue
public abstract static class ProcessingContext {
@@ -416,4 +431,60 @@ public final class Work implements RefreshableWork {
return Optional.ofNullable(getDataClient().getStateData(computationId(),
request));
}
}
+
+ /**
+ * WorkItems with same key group and computation are eligible to be executed
together in a
+ * multi-key bundle.
+ */
+ public static final class KeyGroup {
+
+ // Work items equaling to the default keyGroup will always be executed
+ // separately and not in a multi-key bundle
+ public static final KeyGroup DEFAULT = new KeyGroup(0, 0);
+
+ private final long high;
+ private final long low;
+
+ private KeyGroup(long high, long low) {
+ this.high = high;
+ this.low = low;
+ }
+
+ public static KeyGroup create(long high, long low) {
+ if (high == 0 && low == 0) {
+ return DEFAULT;
+ }
+ return new KeyGroup(high, low);
+ }
+
+ public long high() {
+ return high;
+ }
+
+ public long low() {
+ return low;
+ }
+
+ @Override
+ public boolean equals(@Nullable Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (!(o instanceof KeyGroup)) {
+ return false;
+ }
+ KeyGroup other = (KeyGroup) o;
+ return high == other.high && low == other.low;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(high, low);
+ }
+
+ @Override
+ public String toString() {
+ return String.format("%016x%016x", high, low);
+ }
+ }
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java
index c6fd96e0a4c..8964246c116 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java
@@ -29,15 +29,15 @@ import java.util.concurrent.atomic.AtomicBoolean;
import javax.annotation.concurrent.GuardedBy;
import
org.apache.beam.runners.dataflow.worker.streaming.BoundedQueueExecutorWorkHandle;
import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork;
+import org.apache.beam.runners.dataflow.worker.streaming.Work;
+import org.apache.beam.runners.dataflow.worker.streaming.Work.KeyGroup;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Monitor;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Monitor.Guard;
+import org.checkerframework.checker.nullness.qual.Nullable;
/** An executor for executing work on windmill items. */
-@SuppressWarnings({
- "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
public class BoundedQueueExecutor {
private final ThreadPoolExecutor executor;
@@ -78,6 +78,9 @@ public class BoundedQueueExecutor {
@GuardedBy("this")
private long totalTimeMaxActiveThreadsUsed;
+ // If set the keyGroupWorkQueue is used by the underlying executor.
+ private final @Nullable KeyGroupWorkQueue keyGroupWorkQueue;
+
public BoundedQueueExecutor(
int initialMaximumPoolSize,
long keepAliveTime,
@@ -85,7 +88,9 @@ public class BoundedQueueExecutor {
int maximumElementsOutstanding,
long maximumBytesOutstanding,
ThreadFactory threadFactory,
- boolean useFairMonitor) {
+ boolean useFairMonitor,
+ boolean useKeyGroupWorkQueue) {
+ this.keyGroupWorkQueue = useKeyGroupWorkQueue ? new
KeyGroupWorkQueue(useFairMonitor) : null;
this.maximumPoolSize = initialMaximumPoolSize;
monitor = new Monitor(useFairMonitor);
executor =
@@ -94,7 +99,7 @@ public class BoundedQueueExecutor {
initialMaximumPoolSize,
keepAliveTime,
unit,
- new LinkedBlockingQueue<>(),
+ keyGroupWorkQueue != null ? keyGroupWorkQueue : new
LinkedBlockingQueue<>(),
threadFactory) {
@Override
protected void beforeExecute(Thread t, Runnable r) {
@@ -313,7 +318,7 @@ public class BoundedQueueExecutor {
}
}
- private static final class QueuedWork implements Runnable {
+ static final class QueuedWork implements Runnable {
private final ExecutableWork work;
private final BoundedQueueExecutorWorkHandleImpl handle;
@@ -378,6 +383,22 @@ public class BoundedQueueExecutor {
return new BoundedQueueExecutorWorkHandleImpl(elements, bytes);
}
+ public @Nullable ExecutableWork pollWork(
+ String computationId, Work.KeyGroup keyGroup,
BoundedQueueExecutorWorkHandle handle) {
+ checkArgument(handle instanceof BoundedQueueExecutorWorkHandleImpl);
+ checkArgument(computationId != null && keyGroup != null &&
!keyGroup.equals(KeyGroup.DEFAULT));
+ BoundedQueueExecutorWorkHandleImpl internalHandle =
(BoundedQueueExecutorWorkHandleImpl) handle;
+ if (keyGroupWorkQueue == null) {
+ return null;
+ }
+ @Nullable QueuedWork queuedWork =
keyGroupWorkQueue.pollWork(computationId, keyGroup);
+ if (queuedWork == null) {
+ return null;
+ }
+ internalHandle.merge(queuedWork.getHandle());
+ return queuedWork.getWork();
+ }
+
private void decrementCounters(int elements, long bytes) {
// All threads queue decrements and one thread grabs the monitor and
updates
// counters. We do this to reduce contention on monitor which is locked by
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueue.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueue.java
new file mode 100644
index 00000000000..d151157ec68
--- /dev/null
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueue.java
@@ -0,0 +1,462 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.util;
+
+import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull;
+import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
+import static
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
+
+import java.util.AbstractQueue;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.locks.Condition;
+import java.util.concurrent.locks.ReentrantLock;
+import javax.annotation.concurrent.GuardedBy;
+import org.apache.beam.runners.dataflow.worker.streaming.Work;
+import org.apache.beam.runners.dataflow.worker.streaming.Work.KeyGroup;
+import
org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor.QueuedWork;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.checkerframework.checker.nullness.qual.Nullable;
+
+/**
+ * A custom, thread-safe doubly-linked BlockingQueue. In addition to global
FIFO ordering, the queue
+ * supports polling work by computation + key group in FIFO order
+ */
+class KeyGroupWorkQueue extends AbstractQueue<Runnable> implements
BlockingQueue<Runnable> {
+
+ public static final Runnable SENTINEL_RUNNABLE =
+ () -> {
+ throw new IllegalStateException("sentinel runnable called");
+ };
+
+ static class Node {
+ // If keyGroup is non-null, task is an instance of QueuedWork
+ final Runnable task;
+ final @Nullable String computationId;
+ final Work.@Nullable KeyGroup keyGroup;
+ // cached keyGroupList if the Node is part of one.
+ @Nullable KeyGroupWorkList keyGroupList;
+
+ // prevNode, nextNode are used for the global order across all queued
Runnables
+ @Nullable Node prevNode;
+ @Nullable Node nextNode;
+
+ // prevKeyGroupNode and nextKeyGroupNode are used for the keyGroup level
lists linking
+ // QueuedWork with same keyGroup
+ @Nullable Node prevKeyGroupNode;
+ @Nullable Node nextKeyGroupNode;
+
+ Node(Runnable task) {
+ this.task = task;
+ if (task instanceof QueuedWork) {
+ this.computationId = ((QueuedWork) task).getWork().getComputationId();
+ this.keyGroup = ((QueuedWork) task).getWork().getKeyGroup();
+ } else {
+ this.computationId = null;
+ this.keyGroup = null;
+ }
+ }
+ }
+
+ /** Double linked list implementing key group level queue */
+ private static class KeyGroupWorkList {
+ final Node head = new Node(SENTINEL_RUNNABLE);
+ final Node tail = new Node(SENTINEL_RUNNABLE);
+
+ KeyGroupWorkList() {
+ head.nextKeyGroupNode = tail;
+ tail.prevKeyGroupNode = head;
+ }
+
+ boolean isEmpty() {
+ return head.nextKeyGroupNode == tail;
+ }
+
+ void append(Node node) {
+ Node last = checkStateNotNull(tail.prevKeyGroupNode);
+ node.prevKeyGroupNode = last;
+ node.nextKeyGroupNode = tail;
+ last.nextKeyGroupNode = node;
+ tail.prevKeyGroupNode = node;
+ }
+
+ void remove(Node node) {
+ @Nullable Node prev = node.prevKeyGroupNode;
+ @Nullable Node next = node.nextKeyGroupNode;
+ if (prev != null && next != null) {
+ prev.nextKeyGroupNode = next;
+ next.prevKeyGroupNode = prev;
+ node.prevKeyGroupNode = null;
+ node.nextKeyGroupNode = null;
+ }
+ }
+ }
+
+ private final ReentrantLock lock;
+ private final Condition notEmpty;
+
+ // Sentinels for the global list
+ @GuardedBy("lock")
+ private final Node globalHead = new Node(SENTINEL_RUNNABLE);
+
+ @GuardedBy("lock")
+ private final Node globalTail = new Node(SENTINEL_RUNNABLE);
+
+ @GuardedBy("lock")
+ private final Map<QueueKey, KeyGroupWorkList> keyGroupQueueMap = new
HashMap<>();
+
+ @GuardedBy("lock")
+ private int size = 0;
+
+ public KeyGroupWorkQueue(boolean fair) {
+ this.lock = new ReentrantLock(fair);
+ this.notEmpty = lock.newCondition();
+ globalHead.nextNode = globalTail;
+ globalTail.prevNode = globalHead;
+ }
+
+ @GuardedBy("lock")
+ private void unlinkNode(Node node) {
+ // An existing node should always have previous and next since we have
sentinels
+ // 1. Unlink from global list
+ Node prevG = checkArgumentNotNull(node.prevNode);
+ Node nextG = checkArgumentNotNull(node.nextNode);
+ prevG.nextNode = nextG;
+ nextG.prevNode = prevG;
+ node.prevNode = null;
+ node.nextNode = null;
+
+ // 2. Unlink from key group list
+ KeyGroupWorkList list = node.keyGroupList;
+ if (list != null) {
+ list.remove(node);
+ if (list.isEmpty()) {
+ String compId = checkStateNotNull(node.computationId);
+ Work.KeyGroup keyGroup = checkStateNotNull(node.keyGroup);
+ QueueKey key = new QueueKey(compId, keyGroup);
+ keyGroupQueueMap.remove(key);
+ }
+ node.keyGroupList = null;
+ }
+ --size;
+ }
+
+ @GuardedBy("lock")
+ private @Nullable Node removeFirstGlobal() {
+ Node first = checkStateNotNull(globalHead.nextNode);
+ if (first == globalTail) {
+ return null;
+ }
+ unlinkNode(first);
+ return first;
+ }
+
+ /**
+ * Remove and Return QueuedWork for the computationId, keyGroup in the FIFO
order. Returns null,
+ * if there are no matches.
+ *
+ * @param keyGroup should not be equal to KeyGroup.DEFAULT
+ */
+ public @Nullable QueuedWork pollWork(String computationId, Work.KeyGroup
keyGroup) {
+ checkArgument(computationId != null && keyGroup != null &&
!keyGroup.equals(KeyGroup.DEFAULT));
+ QueueKey key = new QueueKey(computationId, keyGroup);
+ lock.lock();
+ try {
+ KeyGroupWorkList keyGroupWorkList = keyGroupQueueMap.get(key);
+ if (keyGroupWorkList == null || keyGroupWorkList.isEmpty()) {
+ return null;
+ }
+
+ // Retrieve the first pending task for this computation and keyGroup in
O(1)
+ Node firstNode =
checkStateNotNull(keyGroupWorkList.head.nextKeyGroupNode);
+ if (firstNode == keyGroupWorkList.tail) {
+ return null;
+ }
+ unlinkNode(firstNode);
+
+ return (QueuedWork) firstNode.task;
+ } finally {
+ lock.unlock();
+ }
+ }
+
+ @Override
+ public boolean offer(@NonNull Runnable runnable) {
+ Node node = new Node(checkStateNotNull(runnable));
+ lock.lock();
+ try {
+ // Append to global list tail
+ Node lastG = checkStateNotNull(globalTail.prevNode);
+ node.prevNode = lastG;
+ node.nextNode = globalTail;
+ lastG.nextNode = node;
+ globalTail.prevNode = node;
+
+ // Append to key group list if applicable
+ String compId = node.computationId;
+ Work.KeyGroup keyGroup = node.keyGroup;
+ if (compId != null && keyGroup != null &&
!keyGroup.equals(KeyGroup.DEFAULT)) {
+ QueueKey key = new QueueKey(compId, keyGroup);
+ KeyGroupWorkList keyGroupWorkList =
+ keyGroupQueueMap.computeIfAbsent(key, k -> new KeyGroupWorkList());
+ keyGroupWorkList.append(node);
+ node.keyGroupList = keyGroupWorkList;
+ }
+
+ ++size;
+ notEmpty.signal();
+ return true;
+ } finally {
+ lock.unlock();
+ }
+ }
+
+ @Override
+ public void put(Runnable e) throws InterruptedException {
+ offer(e); // Unbounded queue
+ }
+
+ @Override
+ public boolean offer(Runnable e, long timeout, TimeUnit unit) throws
InterruptedException {
+ return offer(e); // Unbounded queue
+ }
+
+ @Override
+ public @Nullable Runnable poll() {
+ lock.lock();
+ try {
+ @Nullable Node node = removeFirstGlobal();
+ return (node != null) ? node.task : null;
+ } finally {
+ lock.unlock();
+ }
+ }
+
+ @Override
+ public Runnable take() throws InterruptedException {
+ lock.lockInterruptibly();
+ try {
+ while (size == 0) {
+ notEmpty.await();
+ }
+ @Nullable Node node = removeFirstGlobal();
+ checkStateNotNull(node, "Queue is empty but size was " + size);
+ return node.task;
+ } finally {
+ lock.unlock();
+ }
+ }
+
+ @Override
+ public @Nullable Runnable poll(long timeout, TimeUnit unit) throws
InterruptedException {
+ long nanos = unit.toNanos(timeout);
+ lock.lockInterruptibly();
+ try {
+ while (size == 0) {
+ if (nanos <= 0) {
+ return null;
+ }
+ nanos = notEmpty.awaitNanos(nanos);
+ }
+ @Nullable Node node = removeFirstGlobal();
+ return (node != null) ? node.task : null;
+ } finally {
+ lock.unlock();
+ }
+ }
+
+ @Override
+ public @Nullable Runnable peek() {
+ lock.lock();
+ try {
+ Node first = checkStateNotNull(globalHead.nextNode);
+ if (first == globalTail) {
+ return null;
+ }
+ return first.task;
+ } finally {
+ lock.unlock();
+ }
+ }
+
+ @Override
+ public int size() {
+ lock.lock();
+ try {
+ return size;
+ } finally {
+ lock.unlock();
+ }
+ }
+
+ @Override
+ public boolean isEmpty() {
+ lock.lock();
+ try {
+ return size == 0;
+ } finally {
+ lock.unlock();
+ }
+ }
+
+ @Override
+ public boolean remove(Object o) {
+ if (o == null) return false;
+ lock.lock();
+ try {
+ // Walk the global queue in O(N) to find and unlink the node
+ Node curr = checkStateNotNull(globalHead.nextNode);
+ while (curr != globalTail) {
+ if (o.equals(curr.task)) {
+ unlinkNode(curr);
+ return true;
+ }
+ curr = checkStateNotNull(curr.nextNode);
+ }
+ return false;
+ } finally {
+ lock.unlock();
+ }
+ }
+
+ @Override
+ public boolean contains(Object o) {
+ if (o == null) return false;
+ lock.lock();
+ try {
+ Node curr = checkStateNotNull(globalHead.nextNode);
+ while (curr != globalTail) {
+ if (o.equals(curr.task)) {
+ return true;
+ }
+ curr = checkStateNotNull(curr.nextNode);
+ }
+ return false;
+ } finally {
+ lock.unlock();
+ }
+ }
+
+ @Override
+ public int remainingCapacity() {
+ return Integer.MAX_VALUE;
+ }
+
+ @Override
+ public int drainTo(Collection<? super Runnable> c) {
+ return drainTo(c, Integer.MAX_VALUE);
+ }
+
+ @Override
+ public int drainTo(Collection<? super Runnable> c, int maxElements) {
+ if (c == null) throw new NullPointerException();
+ if (c == this) throw new IllegalArgumentException();
+ if (maxElements <= 0) return 0;
+ lock.lock();
+ try {
+ int added = 0;
+ Node curr = checkStateNotNull(globalHead.nextNode);
+ while (curr != globalTail && added < maxElements) {
+ Node next = checkStateNotNull(curr.nextNode);
+ unlinkNode(curr);
+ Runnable task = curr.task;
+ c.add(task);
+ ++added;
+ curr = next;
+ }
+ return added;
+ } finally {
+ lock.unlock();
+ }
+ }
+
+ @Override
+ public void clear() {
+ lock.lock();
+ try {
+ Node curr = checkStateNotNull(globalHead.nextNode);
+ while (curr != globalTail) {
+ Node next = checkStateNotNull(curr.nextNode);
+ unlinkNode(curr);
+ curr = next;
+ }
+ } finally {
+ lock.unlock();
+ }
+ }
+
+ @Override
+ public Iterator<Runnable> iterator() {
+ lock.lock();
+ try {
+ ImmutableList.Builder<Runnable> builder =
ImmutableList.builderWithExpectedSize(size);
+ Node curr = checkStateNotNull(globalHead.nextNode);
+ while (curr != globalTail) {
+ builder.add(curr.task);
+ curr = checkStateNotNull(curr.nextNode);
+ }
+ return builder.build().iterator();
+ } finally {
+ lock.unlock();
+ }
+ }
+
+ static final class QueueKey {
+ private final String computationId;
+ private final Work.KeyGroup keyGroup;
+
+ QueueKey(String computationId, Work.KeyGroup keyGroup) {
+ this.computationId = computationId;
+ this.keyGroup = keyGroup;
+ }
+
+ @Override
+ public boolean equals(@Nullable Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (!(o instanceof QueueKey)) {
+ return false;
+ }
+ QueueKey other = (QueueKey) o;
+ return computationId.equals(other.computationId) &&
keyGroup.equals(other.keyGroup);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(computationId, keyGroup);
+ }
+
+ @Override
+ public String toString() {
+ return "QueueKey{"
+ + "computationId='"
+ + computationId
+ + '\''
+ + ", keyGroup="
+ + keyGroup
+ + '}';
+ }
+ }
+}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
index d58f2007699..5bcdffcc256 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
@@ -3036,7 +3036,8 @@ public class StreamingDataflowWorkerTest {
.setNameFormat("DataflowWorkUnits-%d")
.setDaemon(true)
.build(),
- /*useFairMonitor=*/ false);
+ /*useFairMonitor=*/ false,
+ /*useKeyGroupWorkQueue=*/ false);
ComputationState computationState =
new ComputationState(
@@ -3097,7 +3098,8 @@ public class StreamingDataflowWorkerTest {
.setNameFormat("DataflowWorkUnits-%d")
.setDaemon(true)
.build(),
- /*useFairMonitor=*/ false);
+ /*useFairMonitor=*/ false,
+ /*useKeyGroupWorkQueue=*/ false);
ComputationState computationState =
new ComputationState(
@@ -3167,7 +3169,8 @@ public class StreamingDataflowWorkerTest {
.setNameFormat("DataflowWorkUnits-%d")
.setDaemon(true)
.build(),
- /*useFairMonitor=*/ false);
+ /*useFairMonitor=*/ false,
+ /*useKeyGroupWorkQueue=*/ false);
ComputationState computationState =
new ComputationState(
@@ -3241,7 +3244,8 @@ public class StreamingDataflowWorkerTest {
.setNameFormat("DataflowWorkUnits-%d")
.setDaemon(true)
.build(),
- /*useFairMonitor=*/ false);
+ /*useFairMonitor=*/ false,
+ /*useKeyGroupWorkQueue=*/ false);
ComputationState computationState =
new ComputationState(
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java
index 55fe82c7163..a98102751fb 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java
@@ -20,6 +20,8 @@ package org.apache.beam.runners.dataflow.worker.util;
import static org.hamcrest.Matchers.greaterThan;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
@@ -33,6 +35,7 @@ import
org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork;
import org.apache.beam.runners.dataflow.worker.streaming.Watermarks;
import org.apache.beam.runners.dataflow.worker.streaming.Work;
import
org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor.BoundedQueueExecutorWorkHandleImpl;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem;
import
org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient;
import
org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
@@ -66,13 +69,30 @@ public class BoundedQueueExecutorTest {
@Rule public transient Timeout globalTimeout = Timeout.seconds(300);
private BoundedQueueExecutor executor;
+ private static final Work.KeyGroup DEFAULT_KEY_GROUP =
Work.KeyGroup.create(1, 2);
+
private static ExecutableWork createWork(Consumer<Work> executeWorkFn) {
+ return createWorkWithCompId("computationId", executeWorkFn);
+ }
+
+ private static ExecutableWork createWorkWithCompId(
+ String computationId, Consumer<Work> executeWorkFn) {
+ return createWorkWithCompIdAndKeyGroup(computationId, DEFAULT_KEY_GROUP,
executeWorkFn);
+ }
+
+ private static ExecutableWork createWorkWithCompIdAndKeyGroup(
+ String computationId, Work.KeyGroup keyGroup, Consumer<Work>
executeWorkFn) {
WorkItem workItem =
WorkItem.newBuilder()
.setKey(ByteString.EMPTY)
.setShardingKey(1)
.setWorkToken(33)
.setCacheToken(1)
+ .setKeyGroup(
+ Windmill.Uint128Proto.newBuilder()
+ .setHigh(keyGroup.high())
+ .setLow(keyGroup.low())
+ .build())
.build();
return ExecutableWork.create(
Work.create(
@@ -80,10 +100,7 @@ public class BoundedQueueExecutorTest {
workItem.getSerializedSize(),
Watermarks.builder().setInputDataWatermark(Instant.now()).build(),
Work.createProcessingContext(
- "computationId",
- new FakeGetDataClient(),
- ignored -> {},
- mock(HeartbeatSender.class)),
+ computationId, new FakeGetDataClient(), ignored -> {},
mock(HeartbeatSender.class)),
false,
Instant::now),
(work, handle) -> {
@@ -116,7 +133,8 @@ public class BoundedQueueExecutorTest {
.setNameFormat("DataflowWorkUnits-%d")
.setDaemon(true)
.build(),
- useFairMonitor);
+ useFairMonitor,
+ /*useKeyGroupWorkQueue=*/ false);
}
@Test
@@ -413,4 +431,126 @@ public class BoundedQueueExecutorTest {
+ "Work Queue Bytes: 0/10000000<br>/n";
assertEquals(expectedSummaryHtml, executor.summaryHtml());
}
+
+ @Test
+ public void testPollWork() throws Exception {
+ // Create separate BoundedQueueExecutor with 1 thread so we can block it
easily
+ BoundedQueueExecutor testExecutor =
+ new BoundedQueueExecutor(
+ 1,
+ 60,
+ TimeUnit.SECONDS,
+ 100,
+ 10000000,
+ new
ThreadFactoryBuilder().setNameFormat("testStealing-%d").setDaemon(true).build(),
+ useFairMonitor,
+ /*useKeyGroupWorkQueue=*/ true);
+
+ // 1. Create blocker task to occupy the worker thread
+ CountDownLatch blockerStart = new CountDownLatch(1);
+ CountDownLatch blockerStop = new CountDownLatch(1);
+ ExecutableWork blockerWork =
+ createWorkWithCompIdAndKeyGroup(
+ "blockerComp",
+ DEFAULT_KEY_GROUP,
+ ignored -> {
+ blockerStart.countDown();
+ try {
+ blockerStop.await();
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ });
+
+ testExecutor.execute(blockerWork, 0);
+ blockerStart.await();
+
+ // 2. Create two distinct key groups
+ Work.KeyGroup keyGroup1 = Work.KeyGroup.create(1, 1);
+ Work.KeyGroup keyGroup2 = Work.KeyGroup.create(1, 2);
+
+ // Create executable tasks
+ CountDownLatch targetStart = new CountDownLatch(1);
+ ExecutableWork work1 = createWorkWithCompIdAndKeyGroup("compA", keyGroup1,
ignored -> {});
+ ExecutableWork work2 =
+ createWorkWithCompIdAndKeyGroup(
+ "compA",
+ keyGroup2,
+ ignored -> {
+ targetStart.countDown();
+ });
+
+ // Enqueue tasks (they will wait in the queue because the thread is
blocked)
+ testExecutor.execute(work1, 100);
+ testExecutor.execute(work2, 150);
+
+ // Total outstanding elements must be 3 (blocker + work1 + work2)
+ assertEquals(3, testExecutor.elementsOutstanding());
+
+ // Steal work2 using pollWork with compA and keyGroup2
+ try (BoundedQueueExecutorWorkHandleImpl stealHandle =
testExecutor.createBudgetHandle(0, 0L)) {
+ ExecutableWork stolen = testExecutor.pollWork("compA", keyGroup2,
stealHandle);
+ assertNotNull(stolen);
+ assertEquals(work2, stolen);
+
+ // Run the stolen task
+ stolen.run(stealHandle);
+ targetStart.await();
+ }
+
+ // Steal work1 using pollWork with compA and keyGroup1
+ try (BoundedQueueExecutorWorkHandleImpl stealHandle =
testExecutor.createBudgetHandle(0, 0L)) {
+ ExecutableWork stolen = testExecutor.pollWork("compA", keyGroup1,
stealHandle);
+ assertNotNull(stolen);
+ assertEquals(work1, stolen);
+ }
+
+ // Unblock the blocker and shut down
+ blockerStop.countDown();
+ testExecutor.shutdown();
+ }
+
+ @Test
+ public void testPollWorkWithLinkedBlockingQueue() throws Exception {
+ BoundedQueueExecutor testExecutor =
+ new BoundedQueueExecutor(
+ 1,
+ 60,
+ TimeUnit.SECONDS,
+ 100,
+ 10000000,
+ new
ThreadFactoryBuilder().setNameFormat("testLinkedQueue-%d").setDaemon(true).build(),
+ useFairMonitor,
+ /* useKeyGroupWorkQueue= */ false);
+
+ CountDownLatch blockerStart = new CountDownLatch(1);
+ CountDownLatch blockerStop = new CountDownLatch(1);
+ ExecutableWork blockerWork =
+ createWorkWithCompIdAndKeyGroup(
+ "blockerComp",
+ DEFAULT_KEY_GROUP,
+ ignored -> {
+ blockerStart.countDown();
+ try {
+ blockerStop.await();
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ });
+
+ testExecutor.execute(blockerWork, 0);
+ blockerStart.await();
+
+ Work.KeyGroup keyGroup = Work.KeyGroup.create(1, 1);
+ ExecutableWork work = createWorkWithCompIdAndKeyGroup("compA", keyGroup,
ignored -> {});
+ testExecutor.execute(work, 100);
+
+ try (BoundedQueueExecutorWorkHandleImpl stealHandle =
testExecutor.createBudgetHandle(0, 0L)) {
+ ExecutableWork stolen = testExecutor.pollWork("compA", keyGroup,
stealHandle);
+ assertNull(stolen);
+ }
+
+ blockerStop.countDown();
+ testExecutor.shutdown();
+ }
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueueTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueueTest.java
new file mode 100644
index 00000000000..994aa2030f3
--- /dev/null
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueueTest.java
@@ -0,0 +1,504 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.util;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.mock;
+
+import java.lang.Thread.State;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
+import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork;
+import org.apache.beam.runners.dataflow.worker.streaming.Watermarks;
+import org.apache.beam.runners.dataflow.worker.streaming.Work;
+import
org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor.QueuedWork;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient;
+import
org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
+import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.joda.time.Instant;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class KeyGroupWorkQueueTest {
+
+ @Parameters(name = "fairQueue={0}")
+ public static Iterable<Object[]> data() {
+ return Arrays.asList(new Object[][] {{true}, {false}});
+ }
+
+ @Parameterized.Parameter public boolean fairQueue;
+
+ private BoundedQueueExecutor executor;
+
+ @Before
+ public void setUp() {
+ executor =
+ new BoundedQueueExecutor(
+ 2,
+ 60,
+ TimeUnit.SECONDS,
+ 100,
+ 10000000,
+ new
ThreadFactoryBuilder().setNameFormat("Test-%d").setDaemon(true).build(),
+ fairQueue,
+ /*useKeyGroupWorkQueue=*/ true);
+ }
+
+ private static final Work.KeyGroup TEST_KEY_GROUP = Work.KeyGroup.create(1,
2);
+
+ private QueuedWork createQueuedWork(String computationId, long workBytes) {
+ return createQueuedWork(computationId, TEST_KEY_GROUP, workBytes);
+ }
+
+ private QueuedWork createQueuedWork(
+ String computationId, Work.@Nullable KeyGroup keyGroup, long workBytes) {
+ WorkItem.Builder workItemBuilder =
+ WorkItem.newBuilder()
+ .setKey(ByteString.EMPTY)
+ .setShardingKey(1)
+ .setWorkToken(33)
+ .setCacheToken(1);
+ if (keyGroup != null) {
+ workItemBuilder.setKeyGroup(
+
org.apache.beam.runners.dataflow.worker.windmill.Windmill.Uint128Proto.newBuilder()
+ .setHigh(keyGroup.high())
+ .setLow(keyGroup.low())
+ .build());
+ }
+ WorkItem workItem = workItemBuilder.build();
+ ExecutableWork work =
+ ExecutableWork.create(
+ Work.create(
+ workItem,
+ workItem.getSerializedSize(),
+
Watermarks.builder().setInputDataWatermark(Instant.now()).build(),
+ Work.createProcessingContext(
+ computationId,
+ new FakeGetDataClient(),
+ ignored -> {},
+ mock(HeartbeatSender.class)),
+ false,
+ Instant::now),
+ (w, h) -> {});
+ return new QueuedWork(work, executor.createBudgetHandle(1, workBytes));
+ }
+
+ private static class NoOpRunnable implements Runnable {
+ final String id;
+
+ NoOpRunnable(String id) {
+ this.id = id;
+ }
+
+ @Override
+ public void run() {}
+
+ @Override
+ public String toString() {
+ return "NoOpRunnable(" + id + ")";
+ }
+ }
+
+ @Test
+ public void testBasicOfferAndPoll() {
+ KeyGroupWorkQueue queue = new KeyGroupWorkQueue(fairQueue);
+ assertTrue(queue.isEmpty());
+ assertEquals(0, queue.size());
+
+ NoOpRunnable task1 = new NoOpRunnable("1");
+ NoOpRunnable task2 = new NoOpRunnable("2");
+
+ assertTrue(queue.offer(task1));
+ assertTrue(queue.offer(task2));
+ assertEquals(2, queue.size());
+
+ assertEquals(task1, queue.poll());
+ assertEquals(task2, queue.poll());
+ assertNull(queue.poll());
+ assertTrue(queue.isEmpty());
+ }
+
+ @Test
+ public void testRemove() {
+ KeyGroupWorkQueue queue = new KeyGroupWorkQueue(fairQueue);
+ NoOpRunnable task1 = new NoOpRunnable("1");
+ NoOpRunnable task2 = new NoOpRunnable("2");
+
+ queue.offer(task1);
+ queue.offer(task2);
+
+ assertTrue(queue.remove(task1));
+ assertEquals(1, queue.size());
+ assertEquals(task2, queue.poll());
+ assertFalse(queue.remove(task1)); // Already gone
+ }
+
+ @Test
+ public void testDrainTo() {
+ KeyGroupWorkQueue queue = new KeyGroupWorkQueue(fairQueue);
+ NoOpRunnable task1 = new NoOpRunnable("1");
+ NoOpRunnable task2 = new NoOpRunnable("2");
+ queue.offer(task1);
+ queue.offer(task2);
+
+ List<Runnable> drained = new ArrayList<>();
+ assertEquals(2, queue.drainTo(drained));
+ assertEquals(2, drained.size());
+ assertEquals(task1, drained.get(0));
+ assertEquals(task2, drained.get(1));
+ assertTrue(queue.isEmpty());
+ }
+
+ @Test
+ public void testIteratorSafeTraversalAndImmutable() {
+ KeyGroupWorkQueue queue = new KeyGroupWorkQueue(fairQueue);
+ NoOpRunnable task1 = new NoOpRunnable("1");
+ NoOpRunnable task2 = new NoOpRunnable("2");
+ queue.offer(task1);
+ queue.offer(task2);
+
+ Iterator<Runnable> it = queue.iterator();
+ assertTrue(it.hasNext());
+ assertEquals(task1, it.next());
+ assertTrue(it.hasNext());
+ assertEquals(task2, it.next());
+ assertFalse(it.hasNext());
+
+ // Assert that mutating the iterator throws UnsupportedOperationException
+ it = queue.iterator();
+ assertTrue(it.hasNext());
+ it.next();
+ try {
+ it.remove();
+ fail("Iterator must be immutable");
+ } catch (UnsupportedOperationException e) {
+ // Expected
+ }
+ }
+
+ @Test
+ public void testPollWorkTargeted() {
+ KeyGroupWorkQueue queue = new KeyGroupWorkQueue(fairQueue);
+
+ QueuedWork workA1 = createQueuedWork("compA", 100);
+ QueuedWork workB1 = createQueuedWork("compB", 200);
+ QueuedWork workA2 = createQueuedWork("compA", 150);
+
+ queue.offer(workA1);
+ queue.offer(workB1);
+ queue.offer(workA2);
+
+ assertEquals(3, queue.size());
+ assertFalse(queue.isEmpty());
+
+ // Targeted poll A
+ QueuedWork polledA1 = queue.pollWork("compA", TEST_KEY_GROUP);
+ assertNotNull(polledA1);
+ assertEquals("compA", polledA1.getWork().getComputationId());
+ assertEquals(100, polledA1.getHandle().bytes());
+ assertEquals(2, queue.size());
+ assertFalse(queue.isEmpty());
+
+ // Poll next should be B1 (since A1 was stolen, B1 is now first global)
+ assertEquals(workB1, queue.poll());
+ assertEquals(1, queue.size());
+ assertFalse(queue.isEmpty());
+
+ // Last should be A2
+ assertEquals(workA2, queue.poll());
+ assertEquals(0, queue.size());
+ assertTrue(queue.isEmpty());
+
+ assertEquals(null, queue.poll());
+ }
+
+ @Test
+ public void testConcurrentStress() throws InterruptedException,
ExecutionException {
+ KeyGroupWorkQueue queue = new KeyGroupWorkQueue(fairQueue);
+ int producerThreads = 4;
+ int consumerThreads = 4;
+ int tasksPerProducer = 1000;
+ int totalTasks = producerThreads * tasksPerProducer;
+ Runnable poisonPill =
+ new Runnable() {
+ @Override
+ public void run() {}
+
+ @Override
+ public String toString() {
+ return "POISON_PILL";
+ }
+ };
+
+ ExecutorService executorService =
+ Executors.newFixedThreadPool(producerThreads + consumerThreads);
+ CountDownLatch startLatch = new CountDownLatch(1);
+ CountDownLatch producersDoneLatch = new CountDownLatch(producerThreads);
+ CountDownLatch consumersDoneLatch = new CountDownLatch(consumerThreads);
+ CountDownLatch consumedLatch = new CountDownLatch(totalTasks);
+ List<Future<?>> futures = new ArrayList<>();
+
+ // Start consumers
+ for (int i = 0; i < consumerThreads; i++) {
+ int consumerId = i;
+ futures.add(
+ executorService.submit(
+ () -> {
+ try {
+ startLatch.await();
+ int iteration = consumerId % 4;
+ while (true) {
+ int strategy = iteration;
+ iteration = (iteration + 1) % 4;
+ Runnable task = null;
+ if (strategy == 0) {
+ String compId = "comp-" + (consumedLatch.getCount() % 5);
+ task = queue.pollWork(compId, TEST_KEY_GROUP);
+ } else if (strategy == 1) {
+ task = queue.poll();
+ } else if (strategy == 2) {
+ task = queue.poll(10, TimeUnit.MICROSECONDS);
+ } else if (strategy == 3) {
+ task = queue.take();
+ }
+
+ if (task == poisonPill) {
+ break;
+ }
+ if (task != null) {
+ consumedLatch.countDown();
+ }
+ }
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ } finally {
+ consumersDoneLatch.countDown();
+ }
+ }));
+ }
+
+ // Start producers
+ for (int i = 0; i < producerThreads; i++) {
+ futures.add(
+ executorService.submit(
+ () -> {
+ try {
+ startLatch.await();
+ for (int j = 0; j < tasksPerProducer; j++) {
+ String compId = "comp-" + (j % 5);
+ queue.offer(createQueuedWork(compId, 10));
+ }
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ } finally {
+ producersDoneLatch.countDown();
+ }
+ }));
+ }
+
+ // Release the start latch to start the test
+ startLatch.countDown();
+
+ // Wait for all tasks to be consumed
+ assertTrue(consumedLatch.await(30, TimeUnit.SECONDS));
+
+ // Send poison pills to stop all consumers
+ for (int i = 0; i < consumerThreads; i++) {
+ queue.offer(poisonPill);
+ }
+
+ // Wait for consumers to finish
+ assertTrue(consumersDoneLatch.await(30, TimeUnit.SECONDS));
+ // Wait for producers to finish
+ assertTrue(producersDoneLatch.await(30, TimeUnit.SECONDS));
+
+ // Check for exceptions in threads
+ for (Future<?> future : futures) {
+ future.get();
+ }
+
+ executorService.shutdown();
+ assertTrue(executorService.awaitTermination(30, TimeUnit.SECONDS));
+
+ assertEquals(0, queue.size());
+ assertTrue(queue.isEmpty());
+ }
+
+ @Test
+ public void testTakeBlocksAndWakesUp() throws InterruptedException {
+ final KeyGroupWorkQueue queue = new KeyGroupWorkQueue(fairQueue);
+ final NoOpRunnable task = new NoOpRunnable("take-task");
+ final AtomicReference<Runnable> result = new AtomicReference<>();
+ final CountDownLatch started = new CountDownLatch(1);
+ final CountDownLatch finished = new CountDownLatch(1);
+
+ Thread t =
+ new Thread(
+ () -> {
+ started.countDown();
+ try {
+ result.set(queue.take());
+ } catch (InterruptedException e) {
+ // Ignore
+ } finally {
+ finished.countDown();
+ }
+ });
+ t.setDaemon(true);
+ t.start();
+
+ assertTrue(started.await(30, TimeUnit.SECONDS));
+ waitForThreadState(t, State.WAITING);
+
+ queue.offer(task);
+
+ assertTrue(finished.await(30, TimeUnit.SECONDS));
+ assertEquals(task, result.get());
+ }
+
+ @Test
+ public void testPollWithTimeout() throws InterruptedException {
+ final KeyGroupWorkQueue queue = new KeyGroupWorkQueue(fairQueue);
+ final NoOpRunnable task = new NoOpRunnable("poll-task");
+ final AtomicReference<Runnable> result = new AtomicReference<>();
+ final CountDownLatch started = new CountDownLatch(1);
+ final CountDownLatch finished = new CountDownLatch(1);
+
+ // 1. Verify timeout returns null
+ Thread t1 =
+ new Thread(
+ () -> {
+ started.countDown();
+ try {
+ result.set(queue.poll(500, TimeUnit.MILLISECONDS));
+ } catch (InterruptedException e) {
+ // Ignore
+ } finally {
+ finished.countDown();
+ }
+ });
+ t1.setDaemon(true);
+ t1.start();
+
+ assertTrue(started.await(30, TimeUnit.SECONDS));
+ waitForThreadState(t1, State.TIMED_WAITING);
+
+ assertTrue(finished.await(30, TimeUnit.SECONDS));
+ assertNull(result.get());
+
+ // 2. Verify timed poll receives task offered concurrently
+ final CountDownLatch started2 = new CountDownLatch(1);
+ final CountDownLatch finished2 = new CountDownLatch(1);
+ final AtomicReference<Runnable> result2 = new AtomicReference<>();
+
+ Thread t2 =
+ new Thread(
+ () -> {
+ started2.countDown();
+ try {
+ result2.set(queue.poll(2, TimeUnit.SECONDS));
+ } catch (InterruptedException e) {
+ // Ignore
+ } finally {
+ finished2.countDown();
+ }
+ });
+ t2.setDaemon(true);
+ t2.start();
+
+ assertTrue(started2.await(30, TimeUnit.SECONDS));
+ waitForThreadState(t2, State.TIMED_WAITING);
+
+ queue.offer(task);
+
+ assertTrue(finished2.await(30, TimeUnit.SECONDS));
+ assertEquals(task, result2.get());
+ }
+
+ @Test
+ public void testPollWorkWithKeyGroup() {
+ KeyGroupWorkQueue queue = new KeyGroupWorkQueue(fairQueue);
+
+ Work.KeyGroup keyGroup1 = Work.KeyGroup.create(1, 1);
+ Work.KeyGroup keyGroup2 = Work.KeyGroup.create(1, 2);
+ Work.KeyGroup keyGroupNotExist = Work.KeyGroup.create(3, 4);
+
+ QueuedWork workA1 = createQueuedWork("compA", keyGroup1, 100);
+ QueuedWork workA2 = createQueuedWork("compA", keyGroup2, 150);
+
+ queue.offer(workA1);
+ queue.offer(workA2);
+
+ assertEquals(2, queue.size());
+
+ QueuedWork polledNotExist = queue.pollWork("compA", keyGroupNotExist);
+ assertNull(polledNotExist);
+ assertEquals(2, queue.size());
+
+ // Poll with keyGroup2 first - should return workA2
+ QueuedWork polledA2 = queue.pollWork("compA", keyGroup2);
+ assertNotNull(polledA2);
+ assertEquals(workA2, polledA2);
+ assertEquals(1, queue.size());
+
+ // Poll with keyGroup2 again - should return null
+ assertNull(queue.pollWork("compA", keyGroup2));
+
+ // Poll with keyGroup1 - should return workA1
+ QueuedWork polledA1 = queue.pollWork("compA", keyGroup1);
+ assertNotNull(polledA1);
+ assertEquals(workA1, polledA1);
+ assertTrue(queue.isEmpty());
+
+ polledNotExist = queue.pollWork("compA", keyGroupNotExist);
+ assertNull(polledNotExist);
+ assertTrue(queue.isEmpty());
+ }
+
+ private void waitForThreadState(Thread t, State state) throws
InterruptedException {
+ long timeoutMs = 30000;
+ long start = System.currentTimeMillis();
+ while (t.getState() != state) {
+ if (System.currentTimeMillis() - start > timeoutMs) {
+ fail("Thread did not reach " + state + " state within " + timeoutMs +
"ms");
+ }
+ Thread.sleep(1);
+ }
+ }
+}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingCommitFinalizerTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingCommitFinalizerTest.java
index 07b4b14fd11..ef0d8e43485 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingCommitFinalizerTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingCommitFinalizerTest.java
@@ -62,7 +62,8 @@ public class StreamingCommitFinalizerTest {
.setNameFormat("FinalizationCallback-%d")
.setDaemon(true)
.build(),
- /*useFairMonitor=*/ false);
+ /*useFairMonitor=*/ false,
+ /*useKeyGroupWorkQueue=*/ false);
cleanupExecutor =
Executors.newScheduledThreadPool(
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java
index 51bd4816b03..0610ed44c27 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java
@@ -64,7 +64,8 @@ public class WorkFailureProcessorTest {
.setNameFormat("DataflowWorkUnits-%d")
.setDaemon(true)
.build(),
- /*useFairMonitor=*/ false);
+ /*useFairMonitor=*/ false,
+ /*useKeyGroupWorkQueue=*/ false);
return WorkFailureProcessor.forTesting(workExecutor, failureTracker,
Optional::empty, clock, 0);
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresherTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresherTest.java
index 88a82c6f76b..f32282056e4 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresherTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresherTest.java
@@ -80,7 +80,8 @@ public class ActiveWorkRefresherTest {
1,
10000000,
new
ThreadFactoryBuilder().setNameFormat("DataflowWorkUnits-%d").setDaemon(true).build(),
- /*useFairMonitor=*/ false);
+ /*useFairMonitor=*/ false,
+ /*useKeyGroupWorkQueue=*/ false);
}
private static ComputationState createComputationState(int
computationIdSuffix) {
diff --git
a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto
b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto
index 1da7ef9be8b..aaa09c105fc 100644
---
a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto
+++
b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto
@@ -421,6 +421,11 @@ message WatermarkHold {
optional string state_family = 4;
}
+message Uint128Proto {
+ required fixed64 high = 1;
+ required fixed64 low = 2;
+}
+
// Proto describing a hot key detected on a given WorkItem.
message HotKeyInfo {
// The age of the hot key measured from when it was first detected.
@@ -448,6 +453,8 @@ message WorkItem {
// present, this field includes metadata associated with any hot key.
optional HotKeyInfo hot_key_info = 11;
+ optional Uint128Proto key_group = 18;
+
reserved 12, 13, 14, 15, 16;
}