gemini-code-assist[bot] commented on code in PR #38609: URL: https://github.com/apache/beam/pull/38609#discussion_r3319474440
########## sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java: ########## @@ -0,0 +1,703 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.transforms; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.ReentrantLock; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.state.BagState; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.StateSpecs; +import org.apache.beam.sdk.state.TimeDomain; +import org.apache.beam.sdk.state.Timer; +import org.apache.beam.sdk.state.TimerSpec; +import org.apache.beam.sdk.state.TimerSpecs; +import org.apache.beam.sdk.transforms.reflect.DoFnInvoker; 
+import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TupleTag; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Class that wraps a dofn and converts it from one which process elements synchronously to one + * which processes them asynchronously. + * + * <p>For synchronous dofns the default settings mean that many (100s) of elements will be processed + * in parallel and that processing an element will block all other work on that key. In addition + * runners are optimized for latencies less than a few seconds and longer operations can result in + * high retry rates. Async should be considered when the default parallelism is not correct and/or + * items are expected to take longer than a few seconds to process. 
+ */ +public class AsyncDoFn<K, InputT, OutputT> extends DoFn<KV<K, InputT>, OutputT> { + + private static final Logger LOG = LoggerFactory.getLogger(AsyncDoFn.class); + + private static final int DEFAULT_MIN_BUFFER_CAPACITY = 10; + private static final int DEFAULT_TIMEOUT_SEC = 1; + private static final int DEFAULT_MAX_WAIT_TIME_MS = 500; + private static final int TEARDOWN_AWAIT_SEC = 5; + private static final int INITIAL_BACKOFF_SLEEP_MS = 10; + private static final int BACKPRESSURE_LOG_THRESHOLD_MS = 10000; + + @StateId("to_process") + private final StateSpec<BagState<KV<K, InputT>>> toProcessSpec; + + @TimerId("timer") + private final TimerSpec timerSpec = TimerSpecs.timer(TimeDomain.PROCESSING_TIME); + + private final DoFn<InputT, OutputT> syncFn;
Review Comment:
### Critical Concurrency Issue: Shared `syncFn` Instance Across Background Threads

The `syncFn` instance is shared by every background thread in the executor pool. If `parallelism > 1`, multiple threads will concurrently invoke `processElement`, `startBundle`, and `finishBundle` on the same `syncFn` instance.

Beam guarantees that a given `DoFn` instance is invoked by at most one thread at a time, so `DoFn`s are generally **not** written to be thread-safe and often keep mutable state (especially across bundle lifecycle methods). Sharing one instance between threads will therefore lead to race conditions, state corruption, and unpredictable runtime failures.

#### Recommended Solution:
Use a `ThreadLocal` that holds a per-thread clone of `syncFn`, created with `SerializableUtils.clone(syncFn)` (`org.apache.beam.sdk.util.SerializableUtils`). Each thread then works on its own isolated `DoFn` instance, and background tasks should obtain it with `localSyncFn.get()` instead of reading `syncFn` directly.

```java
private transient ThreadLocal<DoFn<InputT, OutputT>> localSyncFn;
private transient ConcurrentHashMap.KeySetView<DoFnInvoker<InputT, OutputT>, Boolean> localInvokers;
```

In `@Setup`, initialize the `ThreadLocal` and track the cloned invokers:
```java
this.localInvokers = ConcurrentHashMap.newKeySet();
this.localSyncFn = ThreadLocal.withInitial(() -> {
  // Runs once per thread, on that thread's first call to localSyncFn.get().
  DoFn<InputT, OutputT> clone = SerializableUtils.clone(syncFn);
  DoFnInvoker<InputT, OutputT> invoker = DoFnInvokers.invokerFor(clone);
  localInvokers.add(invoker);
  invoker.invokeSetup(new DoFnInvoker.BaseArgumentProvider<InputT, OutputT>() {
    @Override
    public PipelineOptions pipelineOptions() {
      return options;
    }

    @Override
    public String getErrorContext() {
      return "AsyncDoFn/Setup";
    }
  });
  return clone;
});
```

In `@Teardown`, invoke teardown on all cloned invokers:
```java
if (localInvokers != null) {
  for (DoFnInvoker<InputT, OutputT> invoker : localInvokers) {
    try {
      invoker.invokeTeardown();
    } catch (Exception e) {
      LOG.warn("Error tearing down cloned DoFn", e);
    }
  }
  // Prevent a second teardown of the same clones.
  localInvokers.clear();
}
```
########## sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java: ########## @@ -0,0 +1,703 @@ + private final int parallelism; + private final Duration timerFrequency; + private final int maxItemsToBuffer; + private final Duration timeout; + private final Duration maxWaitTime; + private final SerializableFunction<InputT, Object> idFn; + private final boolean useThreadPool; + private final String uuid; + + private transient volatile @Nullable PipelineOptions pipelineOptions; + + // Shared JVM-Wide States (Static Registries) + // Map-backed registry holding shared resources across serialized worker instances. Since runners + // clone DoFn instances on the same worker node, static maps ensure safe JVM-wide resource reuse.
+ private static final ConcurrentHashMap<String, ExecutorService> pool = new ConcurrentHashMap<>(); + // activeElements (processingElements) is global JVM memory (all keys) + private static final ConcurrentHashMap<String, ConcurrentHashMap<Object, InFlightElement<?>>> + processingElements = new ConcurrentHashMap<>(); + private static final ConcurrentHashMap<String, AtomicInteger> itemsInBuffer = + new ConcurrentHashMap<>();
Review Comment:
Declare a static reference-count map that tracks the number of live `AsyncDoFn` instances per `uuid`. Every deserialized copy of this `DoFn` on a worker shares the same `uuid`, so without a count, the first copy to reach `@Teardown` shuts down the shared thread pool and clears the shared maps while other copies may still be using them.
```java
private static final ConcurrentHashMap<String, AtomicInteger> itemsInBuffer =
    new ConcurrentHashMap<>();
private static final ConcurrentHashMap<String, AtomicInteger> refCounts =
    new ConcurrentHashMap<>();
```
########## sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java: ########## @@ -0,0 +1,703 @@ + private static final ReentrantLock lock = new ReentrantLock(); + private static final boolean verboseLogging = false; + + private static class TimestampedOutput<T> { + final T value; + final @Nullable Instant timestamp; + + TimestampedOutput(T value, @Nullable Instant timestamp) { + this.value = value; + this.timestamp = timestamp; + } + } + + private static class InFlightElement<OutputT> { + final CompletableFuture<List<TimestampedOutput<OutputT>>> future; + + InFlightElement(CompletableFuture<List<TimestampedOutput<OutputT>>> future) { + this.future = future; + } + } + + // The In-Memory Accumulating Receiver + // Accumulates elements in-memory during asynchronous background worker execution. + // Buffered elements are only committed downstream once the parent task completes successfully + // and the timer fires.
+ private static class AccumulatingOutputReceiver<T> implements OutputReceiver<T> { + private final List<TimestampedOutput<T>> outputs = + Collections.synchronizedList(new ArrayList<>()); + + @Override + public org.apache.beam.sdk.values.OutputBuilder<T> builder(T value) { + return org.apache.beam.sdk.values.WindowedValues.<T>builder() + .setValue(value) + .setTimestamp(Instant.now()) + .setWindows(java.util.Collections.singletonList(GlobalWindow.INSTANCE)) + .setPaneInfo(org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING) + .setReceiver( + windowedValue -> + outputs.add( + new TimestampedOutput<>( + windowedValue.getValue(), windowedValue.getTimestamp()))); + } + + // Bypasses the nested anonymous OutputBuilder instantiation for standard outputs. + // JVM optimization to prevent garbage collection pressure under high pipeline throughput. + @Override + public void output(T output) { + outputs.add(new TimestampedOutput<>(output, null)); + } + + @Override + public void outputWithTimestamp(T output, Instant timestamp) { + outputs.add(new TimestampedOutput<>(output, timestamp)); + } + + public List<T> getOutputs() { + List<T> rawOutputs = new ArrayList<>(); + for (TimestampedOutput<T> out : outputs) { + rawOutputs.add(out.value); + } + return rawOutputs; + } + + public List<TimestampedOutput<T>> getTimestampedOutputs() { + return outputs; + } + } + + public AsyncDoFn( + DoFn<InputT, OutputT> syncFn, + int parallelism, + Duration timerFrequency, + @Nullable Integer maxItemsToBuffer, + @Nullable Duration timeout, + @Nullable Duration maxWaitTime, + @Nullable SerializableFunction<InputT, Object> idFn, + boolean useThreadPool) { + this( + syncFn, + parallelism, + timerFrequency, + maxItemsToBuffer, + timeout, + maxWaitTime, + idFn, + useThreadPool, + null); + } + + public AsyncDoFn( + DoFn<InputT, OutputT> syncFn, + int parallelism, + Duration timerFrequency, + @Nullable Integer maxItemsToBuffer, + @Nullable Duration timeout, + @Nullable Duration maxWaitTime, + 
@Nullable SerializableFunction<InputT, Object> idFn, + boolean useThreadPool, + @Nullable Coder<KV<K, InputT>> coder) { + this.syncFn = syncFn; + this.parallelism = parallelism; + this.timerFrequency = timerFrequency; + this.maxItemsToBuffer = + (maxItemsToBuffer != null) + ? maxItemsToBuffer + : Math.max(parallelism * 2, DEFAULT_MIN_BUFFER_CAPACITY); + this.timeout = (timeout != null) ? timeout : Duration.standardSeconds(DEFAULT_TIMEOUT_SEC); + this.maxWaitTime = + (maxWaitTime != null) ? maxWaitTime : Duration.millis(DEFAULT_MAX_WAIT_TIME_MS); + this.idFn = + (idFn != null) + ? idFn + : (SerializableFunction<InputT, Object>) + input -> java.util.Objects.requireNonNull(input); + this.useThreadPool = useThreadPool; + this.uuid = UUID.randomUUID().toString(); + this.toProcessSpec = (coder != null) ? StateSpecs.bag(coder) : StateSpecs.bag(); + } + + private ExecutorService getThreadPool() { + ExecutorService threadPool = pool.get(uuid); + if (threadPool == null) { + throw new IllegalStateException("Thread pool not initialized for UUID: " + uuid); + } + return threadPool; + } + + @SuppressWarnings("unchecked") + private ConcurrentHashMap<Object, InFlightElement<OutputT>> getProcessingElements() { + ConcurrentHashMap<Object, InFlightElement<?>> elements = processingElements.get(uuid); + if (elements == null) { + throw new IllegalStateException("Processing elements map not initialized for UUID: " + uuid); + } + return (ConcurrentHashMap<Object, InFlightElement<OutputT>>) (ConcurrentHashMap<?, ?>) elements; + } + + private AtomicInteger getItemsInBuffer() { + AtomicInteger buffer = itemsInBuffer.get(uuid); + if (buffer == null) { + throw new IllegalStateException("Buffer counter not initialized for UUID: " + uuid); + } + return buffer; + } + + @Setup + public void setup(PipelineOptions options) { + this.pipelineOptions = options; + + // Setup the wrapped DoFn + DoFnInvokers.invokerFor(syncFn) + .invokeSetup( + new DoFnInvoker.BaseArgumentProvider<InputT, OutputT>() { + 
@Override + public PipelineOptions pipelineOptions() { + return options; + } + + @Override + public String getErrorContext() { + return "AsyncDoFn/Setup"; + } + }); + + if (useThreadPool) { + LOG.info("Using thread pool for asynchronous execution with parallelism {}", parallelism); + } + + lock.lock(); + try { + pool.computeIfAbsent(uuid, k -> Executors.newFixedThreadPool(parallelism)); + processingElements.computeIfAbsent(uuid, k -> new ConcurrentHashMap<>()); + itemsInBuffer.computeIfAbsent(uuid, k -> new AtomicInteger(0)); + } finally { + lock.unlock(); + }
Review Comment:
Increment the reference count for the current `uuid` in `@Setup`, so that `@Teardown` can determine whether other instances still depend on the shared resources.
```java
lock.lock();
try {
  pool.computeIfAbsent(uuid, k -> Executors.newFixedThreadPool(parallelism));
  processingElements.computeIfAbsent(uuid, k -> new ConcurrentHashMap<>());
  itemsInBuffer.computeIfAbsent(uuid, k -> new AtomicInteger(0));
  refCounts.computeIfAbsent(uuid, k -> new AtomicInteger(0)).incrementAndGet();
} finally {
  lock.unlock();
}
```
########## sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java: ########## @@ -0,0 +1,703 @@ + } + + // Clean up JVM-wide shared resources to prevent thread leaks on the worker + @Teardown + public void teardown() { + DoFnInvokers.invokerFor(syncFn).invokeTeardown(); + + ExecutorService threadPool; + lock.lock(); + try { + threadPool = pool.remove(uuid); + processingElements.remove(uuid); + itemsInBuffer.remove(uuid); + } finally { + lock.unlock(); + } + + if (threadPool != null) { + threadPool.shutdown(); + try { + if (!threadPool.awaitTermination(TEARDOWN_AWAIT_SEC, TimeUnit.SECONDS)) { + threadPool.shutdownNow(); + } + } catch (InterruptedException e) { + threadPool.shutdownNow(); + Thread.currentThread().interrupt(); + } + } + } + + // Asynchronous Scheduling & Deduplication + // Submits tasks to the background thread pool. If an element with the same ID is already + // in-flight, + // the submission is silently ignored to enforce exactly-once semantics.
+ private boolean scheduleIfRoom( + KV<K, InputT> element, BoundedWindow window, Instant timestamp, boolean ignoreBuffer) { + lock.lock(); + try { + ConcurrentHashMap<Object, InFlightElement<OutputT>> activeElements = getProcessingElements(); + Object elementId = idFn.apply(element.getValue()); + + if (activeElements.containsKey(elementId)) { + LOG.info("Item {} already in processing elements", element); + return true; + } + + int currentBuffer = getItemsInBuffer().get(); + if (currentBuffer < maxItemsToBuffer || ignoreBuffer) { + java.util.concurrent.Executor executor = + useThreadPool ? getThreadPool() : java.util.concurrent.ForkJoinPool.commonPool(); + + // Pending asynchronous task that will produce a list of outputs + CompletableFuture<List<TimestampedOutput<OutputT>>> future = + CompletableFuture.supplyAsync( + () -> { + try { + AccumulatingOutputReceiver<OutputT> receiver = + new AccumulatingOutputReceiver<>(); + DoFnInvoker<InputT, OutputT> invoker = DoFnInvokers.invokerFor(syncFn); + + DoFnInvoker.ArgumentProvider<InputT, OutputT> bundleArgProvider = + new DoFnInvoker.BaseArgumentProvider<InputT, OutputT>() { + @Override + public PipelineOptions pipelineOptions() { + PipelineOptions options = pipelineOptions; + if (options == null) { + throw new IllegalStateException("PipelineOptions not set"); + } + return options; + } + + @Override + public DoFn<InputT, OutputT>.FinishBundleContext finishBundleContext( + DoFn<InputT, OutputT> doFn) { + return doFn.new FinishBundleContext() { + @Override + public PipelineOptions getPipelineOptions() { + return pipelineOptions(); + } + + @Override + public void output( + OutputT output, Instant timestamp, BoundedWindow window) { + receiver.outputWithTimestamp(output, timestamp); + } + + @Override + public <T> void output( + TupleTag<T> tag, + T output, + Instant timestamp, + BoundedWindow window) { + throw new UnsupportedOperationException( + "Tagged output not supported in " + + "FinishBundleContext for AsyncDoFn"); + } + 
}; + } + + @Override + public String getErrorContext() { + return "AsyncDoFn/Bundle"; + } + }; + + invoker.invokeStartBundle(bundleArgProvider); + + DoFnInvoker.ArgumentProvider<InputT, OutputT> processArgProvider = + new DoFnInvoker.BaseArgumentProvider<InputT, OutputT>() { + @Override + public InputT element(DoFn<InputT, OutputT> doFn) { + return element.getValue(); + } + + @Override + public OutputReceiver<OutputT> outputReceiver( + DoFn<InputT, OutputT> doFn) { + return receiver; + } + + @Override + public BoundedWindow window() { + return window; + } + + @Override + public Instant timestamp(DoFn<InputT, OutputT> doFn) { + return timestamp; + } + + @Override + public PipelineOptions pipelineOptions() { + PipelineOptions options = pipelineOptions; + if (options == null) { + throw new IllegalStateException("PipelineOptions not set"); + } + return options; + } + + @Override + public String getErrorContext() { + return "AsyncDoFn/Process"; + } + }; + + invoker.invokeProcessElement(processArgProvider); + invoker.invokeFinishBundle(bundleArgProvider); + + return receiver.getTimestampedOutputs(); + } catch (Exception e) { + throw new CompletionException(e); + } + }, + executor); + + // Assigned to 'unused' to satisfy ErrorProne while preserving parent future for + // cancellation + CompletableFuture<List<TimestampedOutput<OutputT>>> unused = + future.whenComplete( + (res, ex) -> { + lock.lock(); + try { + getItemsInBuffer().decrementAndGet(); + } finally { + lock.unlock(); + } + }); + + activeElements.put(elementId, new InFlightElement<>(future)); + getItemsInBuffer().incrementAndGet(); + return true; + } + + return false; + } finally { + lock.unlock(); + } + } + + private void scheduleItem(KV<K, InputT> element, BoundedWindow window, Instant timestamp) { + boolean done = false; + long sleepTime = INITIAL_BACKOFF_SLEEP_MS; + long totalSleep = 0; + long timeoutMs = timeout.getMillis(); + + while (!done && totalSleep < timeoutMs) { + done = scheduleIfRoom(element, 
window, timestamp, false); + if (!done) { + long sleep = Math.min(maxWaitTime.getMillis(), sleepTime); + if (verboseLogging || totalSleep > BACKPRESSURE_LOG_THRESHOLD_MS) { + LOG.info( + "buffer is full for item {}, {} waiting {} ms. Have waited for {} ms.", + element, + getItemsInBuffer().get(), + sleep, + totalSleep); + } + try { + Thread.sleep(sleep); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted while waiting for space in buffer", e); + } + + // Prevents long overflow possibility + if (sleepTime < maxWaitTime.getMillis()) { + sleepTime *= 2; + } + + totalSleep += sleep; + } + } + // Timeout: element skips JVM pool but stays in BagState for timer to reschedule later. + } + + private Instant nextTimeToFire(@Nullable K key) { + long seed = (key == null) ? 0 : key.hashCode(); + double fractionalOffset = Math.abs(seed % 1000000) / 1000000.0; + double timerFrequencySec = timerFrequency.getMillis() / 1000.0; + double nowSec = System.currentTimeMillis() / 1000.0; + + double base = Math.floor((nowSec + timerFrequencySec) / timerFrequencySec) * timerFrequencySec; + double offset = fractionalOffset * timerFrequencySec; + + return Instant.ofEpochMilli((long) ((base + offset) * 1000)); + } + + @ProcessElement + public void processElement( + ProcessContext c, + BoundedWindow window, + @StateId("to_process") BagState<KV<K, InputT>> toProcessState, + @TimerId("timer") Timer timer) { + + KV<K, InputT> element = c.element(); + scheduleItem(element, window, c.timestamp()); + toProcessState.add(element); + + Instant timeToFire = nextTimeToFire(element.getKey()); + timer.set(timeToFire); + } + + @OnTimer("timer") + public void onTimer( + OnTimerContext c, + @StateId("to_process") BagState<KV<K, InputT>> toProcessState, + @TimerId("timer") Timer timer, + OutputReceiver<OutputT> receiver) { + + commitFinishedItems(c.fireTimestamp(), toProcessState, timer, receiver); + } + + // Synchronizes local task results with 
the runner's persistent state container. + // Emits successfully completed elements, cancels rolled-back tasks, and reschedules lost work. + void commitFinishedItems( + Instant fireTimestamp, + BagState<KV<K, InputT>> toProcessState, + Timer timer, + OutputReceiver<OutputT> receiver) { + + Iterable<KV<K, InputT>> toProcessLocal = toProcessState.read(); + if (toProcessLocal == null || !toProcessLocal.iterator().hasNext()) { + // Early Exit: if BagState is empty, we skip checking activeElements for this key. + return; + } + + // Since fireTimestamp is key-scoped, we determine the current key from the first element in + // state + List<KV<K, InputT>> stateList = new ArrayList<>(); + K key = null; + for (KV<K, InputT> element : toProcessLocal) { + stateList.add(element); + if (key == null) { + key = element.getKey(); + } + } + + if (verboseLogging) { + LOG.info("processing timer for key: {}", key); + } + + ConcurrentHashMap<Object, InFlightElement<OutputT>> activeElements = getProcessingElements(); + + List<List<TimestampedOutput<OutputT>>> toReturn = new ArrayList<>(); + Set<KV<K, InputT>> finishedItems = new HashSet<>(); + List<KV<K, InputT>> toReschedule = new ArrayList<>(); + + int itemsFinished = 0; + int itemsNotYetFinished = 0; + int itemsRescheduled = 0; + + Set<Object> finishedElementIds = new HashSet<>(); + Set<Object> inFlightElementIds = new HashSet<>(); + Set<Object> rescheduledElementIds = new HashSet<>(); + + lock.lock(); + try { + for (KV<K, InputT> element : stateList) { + Object elementId = idFn.apply(element.getValue()); + + // Skip processing if we already completed, rescheduled, or found this elementId active in + // this cycle + if (finishedElementIds.contains(elementId) + || rescheduledElementIds.contains(elementId) + || inFlightElementIds.contains(elementId)) { + continue; + } + + if (activeElements.containsKey(elementId)) { + InFlightElement<OutputT> inFlight = activeElements.get(elementId); + if (inFlight.future.isDone()) { + try { + if 
(!inFlight.future.isCancelled()) { + toReturn.add(inFlight.future.get()); + } + finishedItems.add(element); + finishedElementIds.add(elementId); + activeElements.remove(elementId); + itemsFinished++; + } catch (Exception e) { + LOG.error("Error executing async task for element {}", element, e); + finishedItems.add(element); + finishedElementIds.add(elementId); + activeElements.remove(elementId); + } + } else { + inFlightElementIds.add(elementId); + itemsNotYetFinished++; + } + } else { + LOG.info( + "Item {} found in state but not in local active elements, scheduling now", element); + toReschedule.add(element); + rescheduledElementIds.add(elementId); + itemsRescheduled++; + } + } + } finally { + lock.unlock(); + } + + // Reschedule missing elements + for (KV<K, InputT> element : toReschedule) { + scheduleItem(element, GlobalWindow.INSTANCE, fireTimestamp); + } + + // Update State: keep only unfinished items + toProcessState.clear(); + int itemsInProcessingState = 0; + for (KV<K, InputT> element : stateList) { + if (!finishedItems.contains(element)) { + toProcessState.add(element); + itemsInProcessingState++; + } + } + + // Emit completed outputs + // (Emit completed tasks immediately; do not wait for all active tasks to finish). 
+ for (List<TimestampedOutput<OutputT>> outputs : toReturn) { + for (TimestampedOutput<OutputT> out : outputs) { + if (out.timestamp != null) { + receiver.outputWithTimestamp(out.value, out.timestamp); + } else { + receiver.output(out.value); + } + } + } + + LOG.info( + "Items finished: {}, not yet finished: {}, " + "rescheduled: {}, in processing state: {}", + itemsFinished, + itemsNotYetFinished, + itemsRescheduled, + itemsInProcessingState); + + if (itemsInProcessingState > 0) { + Instant timeToFire = nextTimeToFire(key); + timer.set(timeToFire); + } + } + + // Package-private helper methods for testing direct execution without Pipeline / ProcessContext + // boilerplate + void processDirect( + KV<K, InputT> element, + BoundedWindow window, + Instant timestamp, + BagState<KV<K, InputT>> toProcessState, + Timer timer) { + scheduleItem(element, window, timestamp); + toProcessState.add(element); + Instant timeToFire = nextTimeToFire(element.getKey()); + timer.set(timeToFire); + } + + List<OutputT> commitFinishedItemsDirect( + Instant fireTimestamp, BagState<KV<K, InputT>> toProcessState, Timer timer) { + AccumulatingOutputReceiver<OutputT> receiver = new AccumulatingOutputReceiver<>(); + commitFinishedItems(fireTimestamp, toProcessState, timer, receiver); + return receiver.getOutputs(); + } + + boolean isEmpty() { + return getItemsInBuffer().get() == 0; + } + + int getItemsInBufferCount() { + return getItemsInBuffer().get(); + } + + static void resetState() { + lock.lock(); + try { + for (Map.Entry<String, ExecutorService> entry : pool.entrySet()) { + entry.getValue().shutdownNow(); + } + pool.clear(); + processingElements.clear(); + itemsInBuffer.clear(); + } finally { + lock.unlock(); + } + } Review Comment:  Clear the static `refCounts` map when resetting state. 
```java
  static void resetState() {
    lock.lock();
    try {
      for (Map.Entry<String, ExecutorService> entry : pool.entrySet()) {
        entry.getValue().shutdownNow();
      }
      pool.clear();
      processingElements.clear();
      itemsInBuffer.clear();
      refCounts.clear();
    } finally {
      lock.unlock();
    }
  }
```
########## sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java: ########## @@ -0,0 +1,703 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.beam.sdk.transforms; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.ReentrantLock; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.state.BagState; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.StateSpecs; +import org.apache.beam.sdk.state.TimeDomain; +import org.apache.beam.sdk.state.Timer; +import org.apache.beam.sdk.state.TimerSpec; +import org.apache.beam.sdk.state.TimerSpecs; +import org.apache.beam.sdk.transforms.reflect.DoFnInvoker; +import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TupleTag; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Class that wraps a dofn and converts it from one which process elements synchronously to one + * which processes them asynchronously. + * + * <p>For synchronous dofns the default settings mean that many (100s) of elements will be processed + * in parallel and that processing an element will block all other work on that key. In addition + * runners are optimized for latencies less than a few seconds and longer operations can result in + * high retry rates. 
Async should be considered when the default parallelism is not correct and/or + * items are expected to take longer than a few seconds to process. + */ +public class AsyncDoFn<K, InputT, OutputT> extends DoFn<KV<K, InputT>, OutputT> { + + private static final Logger LOG = LoggerFactory.getLogger(AsyncDoFn.class); + + private static final int DEFAULT_MIN_BUFFER_CAPACITY = 10; + private static final int DEFAULT_TIMEOUT_SEC = 1; + private static final int DEFAULT_MAX_WAIT_TIME_MS = 500; + private static final int TEARDOWN_AWAIT_SEC = 5; + private static final int INITIAL_BACKOFF_SLEEP_MS = 10; + private static final int BACKPRESSURE_LOG_THRESHOLD_MS = 10000; + + @StateId("to_process") + private final StateSpec<BagState<KV<K, InputT>>> toProcessSpec; + + @TimerId("timer") + private final TimerSpec timerSpec = TimerSpecs.timer(TimeDomain.PROCESSING_TIME); + + private final DoFn<InputT, OutputT> syncFn; + private final int parallelism; + private final Duration timerFrequency; + private final int maxItemsToBuffer; + private final Duration timeout; + private final Duration maxWaitTime; + private final SerializableFunction<InputT, Object> idFn; + private final boolean useThreadPool; + private final String uuid; + + private transient volatile @Nullable PipelineOptions pipelineOptions; + + // Shared JVM-Wide States (Static Registries) + // Map-backed registry holding shared resources across serialized worker instances. Since runners + // clone DoFn instances on the same worker node, static maps ensure safe JVM-wide resource reuse. 
+ private static final ConcurrentHashMap<String, ExecutorService> pool = new ConcurrentHashMap<>(); + // activeElements (processingElements) is global JVM memory (all keys) + private static final ConcurrentHashMap<String, ConcurrentHashMap<Object, InFlightElement<?>>> + processingElements = new ConcurrentHashMap<>(); + private static final ConcurrentHashMap<String, AtomicInteger> itemsInBuffer = + new ConcurrentHashMap<>(); + + private static final ReentrantLock lock = new ReentrantLock(); + private static final boolean verboseLogging = false; + + private static class TimestampedOutput<T> { + final T value; + final @Nullable Instant timestamp; + + TimestampedOutput(T value, @Nullable Instant timestamp) { + this.value = value; + this.timestamp = timestamp; + } + } + + private static class InFlightElement<OutputT> { + final CompletableFuture<List<TimestampedOutput<OutputT>>> future; + + InFlightElement(CompletableFuture<List<TimestampedOutput<OutputT>>> future) { + this.future = future; + } + } + + // The In-Memory Accumulating Receiver + // Accumulates elements in-memory during asynchronous background worker execution. + // Buffered elements are only committed downstream once the parent task completes successfully + // and the timer fires. 
+ private static class AccumulatingOutputReceiver<T> implements OutputReceiver<T> { + private final List<TimestampedOutput<T>> outputs = + Collections.synchronizedList(new ArrayList<>()); + + @Override + public org.apache.beam.sdk.values.OutputBuilder<T> builder(T value) { + return org.apache.beam.sdk.values.WindowedValues.<T>builder() + .setValue(value) + .setTimestamp(Instant.now()) + .setWindows(java.util.Collections.singletonList(GlobalWindow.INSTANCE)) + .setPaneInfo(org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING) + .setReceiver( + windowedValue -> + outputs.add( + new TimestampedOutput<>( + windowedValue.getValue(), windowedValue.getTimestamp()))); + } + + // Bypasses the nested anonymous OutputBuilder instantiation for standard outputs. + // JVM optimization to prevent garbage collection pressure under high pipeline throughput. + @Override + public void output(T output) { + outputs.add(new TimestampedOutput<>(output, null)); + } + + @Override + public void outputWithTimestamp(T output, Instant timestamp) { + outputs.add(new TimestampedOutput<>(output, timestamp)); + } + + public List<T> getOutputs() { + List<T> rawOutputs = new ArrayList<>(); + for (TimestampedOutput<T> out : outputs) { + rawOutputs.add(out.value); + } + return rawOutputs; + } + + public List<TimestampedOutput<T>> getTimestampedOutputs() { + return outputs; + } + } + + public AsyncDoFn( + DoFn<InputT, OutputT> syncFn, + int parallelism, + Duration timerFrequency, + @Nullable Integer maxItemsToBuffer, + @Nullable Duration timeout, + @Nullable Duration maxWaitTime, + @Nullable SerializableFunction<InputT, Object> idFn, + boolean useThreadPool) { + this( + syncFn, + parallelism, + timerFrequency, + maxItemsToBuffer, + timeout, + maxWaitTime, + idFn, + useThreadPool, + null); + } + + public AsyncDoFn( + DoFn<InputT, OutputT> syncFn, + int parallelism, + Duration timerFrequency, + @Nullable Integer maxItemsToBuffer, + @Nullable Duration timeout, + @Nullable Duration maxWaitTime, + 
@Nullable SerializableFunction<InputT, Object> idFn, + boolean useThreadPool, + @Nullable Coder<KV<K, InputT>> coder) { + this.syncFn = syncFn; + this.parallelism = parallelism; + this.timerFrequency = timerFrequency; + this.maxItemsToBuffer = + (maxItemsToBuffer != null) + ? maxItemsToBuffer + : Math.max(parallelism * 2, DEFAULT_MIN_BUFFER_CAPACITY); + this.timeout = (timeout != null) ? timeout : Duration.standardSeconds(DEFAULT_TIMEOUT_SEC); + this.maxWaitTime = + (maxWaitTime != null) ? maxWaitTime : Duration.millis(DEFAULT_MAX_WAIT_TIME_MS); + this.idFn = + (idFn != null) + ? idFn + : (SerializableFunction<InputT, Object>) + input -> java.util.Objects.requireNonNull(input); + this.useThreadPool = useThreadPool; + this.uuid = UUID.randomUUID().toString(); + this.toProcessSpec = (coder != null) ? StateSpecs.bag(coder) : StateSpecs.bag(); + } + + private ExecutorService getThreadPool() { + ExecutorService threadPool = pool.get(uuid); + if (threadPool == null) { + throw new IllegalStateException("Thread pool not initialized for UUID: " + uuid); + } + return threadPool; + } + + @SuppressWarnings("unchecked") + private ConcurrentHashMap<Object, InFlightElement<OutputT>> getProcessingElements() { + ConcurrentHashMap<Object, InFlightElement<?>> elements = processingElements.get(uuid); + if (elements == null) { + throw new IllegalStateException("Processing elements map not initialized for UUID: " + uuid); + } + return (ConcurrentHashMap<Object, InFlightElement<OutputT>>) (ConcurrentHashMap<?, ?>) elements; + } + + private AtomicInteger getItemsInBuffer() { + AtomicInteger buffer = itemsInBuffer.get(uuid); + if (buffer == null) { + throw new IllegalStateException("Buffer counter not initialized for UUID: " + uuid); + } + return buffer; + } + + @Setup + public void setup(PipelineOptions options) { + this.pipelineOptions = options; + + // Setup the wrapped DoFn + DoFnInvokers.invokerFor(syncFn) + .invokeSetup( + new DoFnInvoker.BaseArgumentProvider<InputT, OutputT>() { + 
@Override + public PipelineOptions pipelineOptions() { + return options; + } + + @Override + public String getErrorContext() { + return "AsyncDoFn/Setup"; + } + }); + + if (useThreadPool) { + LOG.info("Using thread pool for asynchronous execution with parallelism {}", parallelism); + } + + lock.lock(); + try { + pool.computeIfAbsent(uuid, k -> Executors.newFixedThreadPool(parallelism)); + processingElements.computeIfAbsent(uuid, k -> new ConcurrentHashMap<>()); + itemsInBuffer.computeIfAbsent(uuid, k -> new AtomicInteger(0)); + } finally { + lock.unlock(); + } + } + + // Clean up JVM-wide shared resources to prevent thread leaks on the worker + @Teardown + public void teardown() { + DoFnInvokers.invokerFor(syncFn).invokeTeardown(); + + ExecutorService threadPool; + lock.lock(); + try { + threadPool = pool.remove(uuid); + processingElements.remove(uuid); + itemsInBuffer.remove(uuid); + } finally { + lock.unlock(); + } + Review Comment:
### High Severity Bug: Shared Resource Teardown Race Condition

Since runners clone `DoFn` instances on the same worker node, multiple active instances of `AsyncDoFn` with the same `uuid` will share the same `ExecutorService` and static registries.

When any one of these cloned instances is torn down, its `@Teardown` method immediately removes the shared `threadPool` from the static map and shuts it down. Every other active instance sharing the same `uuid` will then fail with `IllegalStateException` the next time it calls `getThreadPool()`, `getProcessingElements()`, or `getItemsInBuffer()`.

#### Solution:
Implement reference counting so that the shared resources are removed and shut down only when the last active instance of `AsyncDoFn` for a given `uuid` is torn down. This needs a new static `ConcurrentHashMap<String, AtomicInteger> refCounts` registry: increment the count for `uuid` in `@Setup`, under the same lock that initializes `pool`, and decrement it in `@Teardown`, releasing the shared resources only when the count reaches zero.
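To make the proposed reference counting concrete, here is a minimal, standalone sketch of the full acquire/release cycle, including the setup-side increment that a teardown-only change implies. `RefCountedPools`, `acquire`, `release`, and `isRegistered` are hypothetical names chosen for illustration, not Beam APIs; only the `ExecutorService` registry is modeled, and `processingElements` / `itemsInBuffer` would be registered and removed alongside it in the same way.

```java
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;

/**
 * Hypothetical JVM-wide registry: one ExecutorService per uuid, shared by every DoFn clone that
 * carries that uuid, and shut down only when the last clone releases it.
 */
final class RefCountedPools {
  private static final ConcurrentHashMap<String, ExecutorService> pools =
      new ConcurrentHashMap<>();
  private static final ConcurrentHashMap<String, AtomicInteger> refCounts =
      new ConcurrentHashMap<>();
  private static final ReentrantLock lock = new ReentrantLock();

  private RefCountedPools() {}

  /** Intended for @Setup: counts this clone and creates the pool on first use. */
  static ExecutorService acquire(String uuid, int parallelism) {
    lock.lock();
    try {
      refCounts.computeIfAbsent(uuid, k -> new AtomicInteger(0)).incrementAndGet();
      return pools.computeIfAbsent(uuid, k -> Executors.newFixedThreadPool(parallelism));
    } finally {
      lock.unlock();
    }
  }

  /** Intended for @Teardown: returns true only if this call shut the shared pool down. */
  static boolean release(String uuid) {
    ExecutorService toShutDown = null;
    lock.lock();
    try {
      AtomicInteger count = refCounts.get(uuid);
      // Only the last clone for this uuid removes the shared entries.
      if (count != null && count.decrementAndGet() == 0) {
        refCounts.remove(uuid);
        toShutDown = pools.remove(uuid);
      }
    } finally {
      lock.unlock();
    }
    if (toShutDown != null) {
      // Shut down outside the lock so other clones are not blocked while it drains.
      toShutDown.shutdown();
      return true;
    }
    return false;
  }

  /** Returns whether a pool is currently registered for the given uuid. */
  static boolean isRegistered(String uuid) {
    return pools.containsKey(uuid);
  }
}
```

In `AsyncDoFn` terms, `setup` would perform the equivalent of `acquire(uuid, parallelism)` and `teardown` the equivalent of `release(uuid)`. Removing the entries under the lock but shutting the pool down after releasing it mirrors the existing `teardown`, which calls `awaitTermination` outside the lock.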
```java
    ExecutorService threadPool = null;
    lock.lock();
    try {
      AtomicInteger refCount = refCounts.get(uuid);
      if (refCount != null && refCount.decrementAndGet() == 0) {
        refCounts.remove(uuid);
        threadPool = pool.remove(uuid);
        processingElements.remove(uuid);
        itemsInBuffer.remove(uuid);
      }
    } finally {
      lock.unlock();
    }
```
########## sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java: ########## @@ -0,0 +1,703 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.beam.sdk.transforms; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.ReentrantLock; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.state.BagState; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.StateSpecs; +import org.apache.beam.sdk.state.TimeDomain; +import org.apache.beam.sdk.state.Timer; +import org.apache.beam.sdk.state.TimerSpec; +import org.apache.beam.sdk.state.TimerSpecs; +import org.apache.beam.sdk.transforms.reflect.DoFnInvoker; +import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TupleTag; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Class that wraps a dofn and converts it from one which process elements synchronously to one + * which processes them asynchronously. + * + * <p>For synchronous dofns the default settings mean that many (100s) of elements will be processed + * in parallel and that processing an element will block all other work on that key. In addition + * runners are optimized for latencies less than a few seconds and longer operations can result in + * high retry rates. 
Async should be considered when the default parallelism is not correct and/or + * items are expected to take longer than a few seconds to process. + */ +public class AsyncDoFn<K, InputT, OutputT> extends DoFn<KV<K, InputT>, OutputT> { + + private static final Logger LOG = LoggerFactory.getLogger(AsyncDoFn.class); + + private static final int DEFAULT_MIN_BUFFER_CAPACITY = 10; + private static final int DEFAULT_TIMEOUT_SEC = 1; + private static final int DEFAULT_MAX_WAIT_TIME_MS = 500; + private static final int TEARDOWN_AWAIT_SEC = 5; + private static final int INITIAL_BACKOFF_SLEEP_MS = 10; + private static final int BACKPRESSURE_LOG_THRESHOLD_MS = 10000; + + @StateId("to_process") + private final StateSpec<BagState<KV<K, InputT>>> toProcessSpec; + + @TimerId("timer") + private final TimerSpec timerSpec = TimerSpecs.timer(TimeDomain.PROCESSING_TIME); + + private final DoFn<InputT, OutputT> syncFn; + private final int parallelism; + private final Duration timerFrequency; + private final int maxItemsToBuffer; + private final Duration timeout; + private final Duration maxWaitTime; + private final SerializableFunction<InputT, Object> idFn; + private final boolean useThreadPool; + private final String uuid; + + private transient volatile @Nullable PipelineOptions pipelineOptions; + + // Shared JVM-Wide States (Static Registries) + // Map-backed registry holding shared resources across serialized worker instances. Since runners + // clone DoFn instances on the same worker node, static maps ensure safe JVM-wide resource reuse. 
+ private static final ConcurrentHashMap<String, ExecutorService> pool = new ConcurrentHashMap<>(); + // activeElements (processingElements) is global JVM memory (all keys) + private static final ConcurrentHashMap<String, ConcurrentHashMap<Object, InFlightElement<?>>> + processingElements = new ConcurrentHashMap<>(); + private static final ConcurrentHashMap<String, AtomicInteger> itemsInBuffer = + new ConcurrentHashMap<>(); + + private static final ReentrantLock lock = new ReentrantLock(); + private static final boolean verboseLogging = false; + + private static class TimestampedOutput<T> { + final T value; + final @Nullable Instant timestamp; + + TimestampedOutput(T value, @Nullable Instant timestamp) { + this.value = value; + this.timestamp = timestamp; + } + } + + private static class InFlightElement<OutputT> { + final CompletableFuture<List<TimestampedOutput<OutputT>>> future; + + InFlightElement(CompletableFuture<List<TimestampedOutput<OutputT>>> future) { + this.future = future; + } + } + + // The In-Memory Accumulating Receiver + // Accumulates elements in-memory during asynchronous background worker execution. + // Buffered elements are only committed downstream once the parent task completes successfully + // and the timer fires. 
+ private static class AccumulatingOutputReceiver<T> implements OutputReceiver<T> { + private final List<TimestampedOutput<T>> outputs = + Collections.synchronizedList(new ArrayList<>()); + + @Override + public org.apache.beam.sdk.values.OutputBuilder<T> builder(T value) { + return org.apache.beam.sdk.values.WindowedValues.<T>builder() + .setValue(value) + .setTimestamp(Instant.now()) + .setWindows(java.util.Collections.singletonList(GlobalWindow.INSTANCE)) + .setPaneInfo(org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING) + .setReceiver( + windowedValue -> + outputs.add( + new TimestampedOutput<>( + windowedValue.getValue(), windowedValue.getTimestamp()))); + } + + // Bypasses the nested anonymous OutputBuilder instantiation for standard outputs. + // JVM optimization to prevent garbage collection pressure under high pipeline throughput. + @Override + public void output(T output) { + outputs.add(new TimestampedOutput<>(output, null)); + } + + @Override + public void outputWithTimestamp(T output, Instant timestamp) { + outputs.add(new TimestampedOutput<>(output, timestamp)); + } + + public List<T> getOutputs() { + List<T> rawOutputs = new ArrayList<>(); + for (TimestampedOutput<T> out : outputs) { + rawOutputs.add(out.value); + } + return rawOutputs; + } + + public List<TimestampedOutput<T>> getTimestampedOutputs() { + return outputs; + } + } + + public AsyncDoFn( + DoFn<InputT, OutputT> syncFn, + int parallelism, + Duration timerFrequency, + @Nullable Integer maxItemsToBuffer, + @Nullable Duration timeout, + @Nullable Duration maxWaitTime, + @Nullable SerializableFunction<InputT, Object> idFn, + boolean useThreadPool) { + this( + syncFn, + parallelism, + timerFrequency, + maxItemsToBuffer, + timeout, + maxWaitTime, + idFn, + useThreadPool, + null); + } + + public AsyncDoFn( + DoFn<InputT, OutputT> syncFn, + int parallelism, + Duration timerFrequency, + @Nullable Integer maxItemsToBuffer, + @Nullable Duration timeout, + @Nullable Duration maxWaitTime, + 
@Nullable SerializableFunction<InputT, Object> idFn, + boolean useThreadPool, + @Nullable Coder<KV<K, InputT>> coder) { + this.syncFn = syncFn; + this.parallelism = parallelism; + this.timerFrequency = timerFrequency; + this.maxItemsToBuffer = + (maxItemsToBuffer != null) + ? maxItemsToBuffer + : Math.max(parallelism * 2, DEFAULT_MIN_BUFFER_CAPACITY); + this.timeout = (timeout != null) ? timeout : Duration.standardSeconds(DEFAULT_TIMEOUT_SEC); + this.maxWaitTime = + (maxWaitTime != null) ? maxWaitTime : Duration.millis(DEFAULT_MAX_WAIT_TIME_MS); + this.idFn = + (idFn != null) + ? idFn + : (SerializableFunction<InputT, Object>) + input -> java.util.Objects.requireNonNull(input); + this.useThreadPool = useThreadPool; + this.uuid = UUID.randomUUID().toString(); + this.toProcessSpec = (coder != null) ? StateSpecs.bag(coder) : StateSpecs.bag(); + } + + private ExecutorService getThreadPool() { + ExecutorService threadPool = pool.get(uuid); + if (threadPool == null) { + throw new IllegalStateException("Thread pool not initialized for UUID: " + uuid); + } + return threadPool; + } + + @SuppressWarnings("unchecked") + private ConcurrentHashMap<Object, InFlightElement<OutputT>> getProcessingElements() { + ConcurrentHashMap<Object, InFlightElement<?>> elements = processingElements.get(uuid); + if (elements == null) { + throw new IllegalStateException("Processing elements map not initialized for UUID: " + uuid); + } + return (ConcurrentHashMap<Object, InFlightElement<OutputT>>) (ConcurrentHashMap<?, ?>) elements; + } + + private AtomicInteger getItemsInBuffer() { + AtomicInteger buffer = itemsInBuffer.get(uuid); + if (buffer == null) { + throw new IllegalStateException("Buffer counter not initialized for UUID: " + uuid); + } + return buffer; + } + + @Setup + public void setup(PipelineOptions options) { + this.pipelineOptions = options; + + // Setup the wrapped DoFn + DoFnInvokers.invokerFor(syncFn) + .invokeSetup( + new DoFnInvoker.BaseArgumentProvider<InputT, OutputT>() { + 
@Override + public PipelineOptions pipelineOptions() { + return options; + } + + @Override + public String getErrorContext() { + return "AsyncDoFn/Setup"; + } + }); + + if (useThreadPool) { + LOG.info("Using thread pool for asynchronous execution with parallelism {}", parallelism); + } + + lock.lock(); + try { + pool.computeIfAbsent(uuid, k -> Executors.newFixedThreadPool(parallelism)); + processingElements.computeIfAbsent(uuid, k -> new ConcurrentHashMap<>()); + itemsInBuffer.computeIfAbsent(uuid, k -> new AtomicInteger(0)); + } finally { + lock.unlock(); + } + } + + // Clean up JVM-wide shared resources to prevent thread leaks on the worker + @Teardown + public void teardown() { + DoFnInvokers.invokerFor(syncFn).invokeTeardown(); + + ExecutorService threadPool; + lock.lock(); + try { + threadPool = pool.remove(uuid); + processingElements.remove(uuid); + itemsInBuffer.remove(uuid); + } finally { + lock.unlock(); + } + + if (threadPool != null) { + threadPool.shutdown(); + try { + if (!threadPool.awaitTermination(TEARDOWN_AWAIT_SEC, TimeUnit.SECONDS)) { + threadPool.shutdownNow(); + } + } catch (InterruptedException e) { + threadPool.shutdownNow(); + Thread.currentThread().interrupt(); + } + } + } + + // Asynchronous Scheduling & Deduplication + // Submits tasks to the background thread pool. If an element with the same ID is already + // in-flight, + // the submission is silently ignored to enforce exactly-once semantics. 
+ private boolean scheduleIfRoom( + KV<K, InputT> element, BoundedWindow window, Instant timestamp, boolean ignoreBuffer) { + lock.lock(); + try { + ConcurrentHashMap<Object, InFlightElement<OutputT>> activeElements = getProcessingElements(); + Object elementId = idFn.apply(element.getValue()); + + if (activeElements.containsKey(elementId)) { + LOG.info("Item {} already in processing elements", element); + return true; + } + + int currentBuffer = getItemsInBuffer().get(); + if (currentBuffer < maxItemsToBuffer || ignoreBuffer) { + java.util.concurrent.Executor executor = + useThreadPool ? getThreadPool() : java.util.concurrent.ForkJoinPool.commonPool(); + + // Pending asynchronous task that will produce a list of outputs + CompletableFuture<List<TimestampedOutput<OutputT>>> future = + CompletableFuture.supplyAsync( + () -> { + try { + AccumulatingOutputReceiver<OutputT> receiver = + new AccumulatingOutputReceiver<>(); + DoFnInvoker<InputT, OutputT> invoker = DoFnInvokers.invokerFor(syncFn); + + DoFnInvoker.ArgumentProvider<InputT, OutputT> bundleArgProvider = + new DoFnInvoker.BaseArgumentProvider<InputT, OutputT>() { + @Override + public PipelineOptions pipelineOptions() { + PipelineOptions options = pipelineOptions; + if (options == null) { + throw new IllegalStateException("PipelineOptions not set"); + } + return options; + } + + @Override + public DoFn<InputT, OutputT>.FinishBundleContext finishBundleContext( + DoFn<InputT, OutputT> doFn) { + return doFn.new FinishBundleContext() { + @Override + public PipelineOptions getPipelineOptions() { + return pipelineOptions(); + } + + @Override + public void output( + OutputT output, Instant timestamp, BoundedWindow window) { + receiver.outputWithTimestamp(output, timestamp); + } + + @Override + public <T> void output( + TupleTag<T> tag, + T output, + Instant timestamp, + BoundedWindow window) { + throw new UnsupportedOperationException( + "Tagged output not supported in " + + "FinishBundleContext for AsyncDoFn"); + } + 
}; + } + + @Override + public String getErrorContext() { + return "AsyncDoFn/Bundle"; + } + }; + + invoker.invokeStartBundle(bundleArgProvider); + + DoFnInvoker.ArgumentProvider<InputT, OutputT> processArgProvider = + new DoFnInvoker.BaseArgumentProvider<InputT, OutputT>() { + @Override + public InputT element(DoFn<InputT, OutputT> doFn) { + return element.getValue(); + } + + @Override + public OutputReceiver<OutputT> outputReceiver( + DoFn<InputT, OutputT> doFn) { + return receiver; + } + + @Override + public BoundedWindow window() { + return window; + } + + @Override + public Instant timestamp(DoFn<InputT, OutputT> doFn) { + return timestamp; + } + + @Override + public PipelineOptions pipelineOptions() { + PipelineOptions options = pipelineOptions; + if (options == null) { + throw new IllegalStateException("PipelineOptions not set"); + } + return options; + } + + @Override + public String getErrorContext() { + return "AsyncDoFn/Process"; + } + }; + + invoker.invokeProcessElement(processArgProvider); + invoker.invokeFinishBundle(bundleArgProvider); + + return receiver.getTimestampedOutputs(); + } catch (Exception e) { + throw new CompletionException(e); + } + }, + executor); + + // Assigned to 'unused' to satisfy ErrorProne while preserving parent future for + // cancellation + CompletableFuture<List<TimestampedOutput<OutputT>>> unused = + future.whenComplete( + (res, ex) -> { + lock.lock(); + try { + getItemsInBuffer().decrementAndGet(); + } finally { + lock.unlock(); + } + }); + + activeElements.put(elementId, new InFlightElement<>(future)); + getItemsInBuffer().incrementAndGet(); + return true; + } + + return false; + } finally { + lock.unlock(); + } + } + + private void scheduleItem(KV<K, InputT> element, BoundedWindow window, Instant timestamp) { + boolean done = false; + long sleepTime = INITIAL_BACKOFF_SLEEP_MS; + long totalSleep = 0; + long timeoutMs = timeout.getMillis(); + + while (!done && totalSleep < timeoutMs) { + done = scheduleIfRoom(element, 
window, timestamp, false); + if (!done) { + long sleep = Math.min(maxWaitTime.getMillis(), sleepTime); + if (verboseLogging || totalSleep > BACKPRESSURE_LOG_THRESHOLD_MS) { + LOG.info( + "buffer is full for item {}, {} waiting {} ms. Have waited for {} ms.", + element, + getItemsInBuffer().get(), + sleep, + totalSleep); + } + try { + Thread.sleep(sleep); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted while waiting for space in buffer", e); + } + + // Prevents long overflow possibility + if (sleepTime < maxWaitTime.getMillis()) { + sleepTime *= 2; + } + + totalSleep += sleep; + } + } + // Timeout: element skips JVM pool but stays in BagState for timer to reschedule later. + } + + private Instant nextTimeToFire(@Nullable K key) { + long seed = (key == null) ? 0 : key.hashCode(); + double fractionalOffset = Math.abs(seed % 1000000) / 1000000.0; + double timerFrequencySec = timerFrequency.getMillis() / 1000.0; + double nowSec = System.currentTimeMillis() / 1000.0; + + double base = Math.floor((nowSec + timerFrequencySec) / timerFrequencySec) * timerFrequencySec; + double offset = fractionalOffset * timerFrequencySec; + + return Instant.ofEpochMilli((long) ((base + offset) * 1000)); + } + + @ProcessElement + public void processElement( + ProcessContext c, + BoundedWindow window, + @StateId("to_process") BagState<KV<K, InputT>> toProcessState, + @TimerId("timer") Timer timer) { + + KV<K, InputT> element = c.element(); + scheduleItem(element, window, c.timestamp()); + toProcessState.add(element); + + Instant timeToFire = nextTimeToFire(element.getKey()); + timer.set(timeToFire); + } + + @OnTimer("timer") + public void onTimer( + OnTimerContext c, + @StateId("to_process") BagState<KV<K, InputT>> toProcessState, + @TimerId("timer") Timer timer, + OutputReceiver<OutputT> receiver) { + + commitFinishedItems(c.fireTimestamp(), toProcessState, timer, receiver); + } + + // Synchronizes local task results with 
the runner's persistent state container. + // Emits successfully completed elements, cancels rolled-back tasks, and reschedules lost work. + void commitFinishedItems( + Instant fireTimestamp, + BagState<KV<K, InputT>> toProcessState, + Timer timer, + OutputReceiver<OutputT> receiver) { + + Iterable<KV<K, InputT>> toProcessLocal = toProcessState.read(); + if (toProcessLocal == null || !toProcessLocal.iterator().hasNext()) { + // Early Exit: if BagState is empty, we skip checking activeElements for this key. + return; + } + + // Since fireTimestamp is key-scoped, we determine the current key from the first element in + // state + List<KV<K, InputT>> stateList = new ArrayList<>(); + K key = null; + for (KV<K, InputT> element : toProcessLocal) { + stateList.add(element); + if (key == null) { + key = element.getKey(); + } + } + + if (verboseLogging) { + LOG.info("processing timer for key: {}", key); + } + + ConcurrentHashMap<Object, InFlightElement<OutputT>> activeElements = getProcessingElements(); + + List<List<TimestampedOutput<OutputT>>> toReturn = new ArrayList<>(); + Set<KV<K, InputT>> finishedItems = new HashSet<>(); + List<KV<K, InputT>> toReschedule = new ArrayList<>(); + + int itemsFinished = 0; + int itemsNotYetFinished = 0; + int itemsRescheduled = 0; + + Set<Object> finishedElementIds = new HashSet<>(); + Set<Object> inFlightElementIds = new HashSet<>(); + Set<Object> rescheduledElementIds = new HashSet<>(); + + lock.lock(); + try { + for (KV<K, InputT> element : stateList) { + Object elementId = idFn.apply(element.getValue()); + + // Skip processing if we already completed, rescheduled, or found this elementId active in + // this cycle + if (finishedElementIds.contains(elementId) + || rescheduledElementIds.contains(elementId) + || inFlightElementIds.contains(elementId)) { + continue; + } + + if (activeElements.containsKey(elementId)) { + InFlightElement<OutputT> inFlight = activeElements.get(elementId); + if (inFlight.future.isDone()) { + try { + if 
(!inFlight.future.isCancelled()) { + toReturn.add(inFlight.future.get()); + } + finishedItems.add(element); + finishedElementIds.add(elementId); + activeElements.remove(elementId); + itemsFinished++; + } catch (Exception e) { + LOG.error("Error executing async task for element {}", element, e); + finishedItems.add(element); + finishedElementIds.add(elementId); + activeElements.remove(elementId); + } + } else { Review Comment:

### Critical Correctness Issue: Silent Data Loss on Exception

When the async task throws, the exception is caught and logged. The element is then added to `finishedItems` and removed from the persistent state (`toProcessState`).

This causes **silent data loss**: the failed element is permanently discarded and the pipeline continues as if it had succeeded. In Apache Beam, exceptions should propagate and fail the bundle (here, the timer callback) so that the runner can retry it.

The failed entry should also be removed from `activeElements` before rethrowing. `activeElements` is a JVM-wide map that survives the failed attempt. If the completed-exceptionally future stays in it, a retry finds the element already "in flight". `scheduleIfRoom` then never resubmits it, and `commitFinishedItems` hits the same failed future and throws again on every attempt.

```java
} catch (Exception e) {
  LOG.error("Error executing async task for element {}", element, e);
  // Drop the failed future so that a retry can reschedule this element.
  activeElements.remove(elementId);
  throw new RuntimeException("Error executing async task for element " + element, e);
}
```

########## sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java: ########## @@ -0,0 +1,703 @@ + public AsyncDoFn( + DoFn<InputT, OutputT> syncFn, + int parallelism, + Duration timerFrequency, + @Nullable Integer maxItemsToBuffer, + @Nullable Duration timeout, + @Nullable Duration maxWaitTime, +
@Nullable SerializableFunction<InputT, Object> idFn, + boolean useThreadPool, + @Nullable Coder<KV<K, InputT>> coder) { + this.syncFn = syncFn; + this.parallelism = parallelism; + this.timerFrequency = timerFrequency; Review Comment:

### High Severity Bug: Division by Zero / Infinite Loop

If `timerFrequency` is `Duration.ZERO`, then `timerFrequencySec` in `nextTimeToFire` is `0.0`. The calculation then proceeds as follows:

- `(nowSec + 0.0) / 0.0` evaluates to `Double.POSITIVE_INFINITY`, and `Math.floor` leaves it unchanged.
- Multiplying that by `0.0` yields `Double.NaN`.
- `(long) Double.NaN` is `0`, so the timer is set to `Instant.ofEpochMilli(0)`.

If `timerFrequency` is negative, the same formula yields a fire time in the past. In both cases the timer fires immediately, and it fires again immediately every time it is re-armed. The pipeline then spins in a tight loop instead of making progress.

We should validate in the constructor that `timerFrequency` is strictly positive.

```java
this.parallelism = parallelism;
if (timerFrequency.getMillis() <= 0) {
  throw new IllegalArgumentException("timerFrequency must be greater than zero");
}
this.timerFrequency = timerFrequency;
```

########## sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AsyncDoFn.java: ########## @@ -0,0 +1,703 @@ +/** + * Class that wraps a dofn and converts it from one which process elements synchronously to one + * which processes them asynchronously. + * + * <p>For synchronous dofns the default settings mean that many (100s) of elements will be processed + * in parallel and that processing an element will block all other work on that key. In addition + * runners are optimized for latencies less than a few seconds and longer operations can result in + * high retry rates.
Async should be considered when the default parallelism is not correct and/or + * items are expected to take longer than a few seconds to process. + */ Review Comment:

### Medium Severity: Document Limitations and Requirements in Javadoc

The class Javadoc should clearly document the requirements and limitations of `AsyncDoFn`, so that users do not run into hard-to-debug runtime issues:

1. **Thread-Safety**: The wrapped `DoFn` must be thread-safe, because its methods are invoked concurrently by multiple background threads.
2. **No Multi-Output Support**: Tagged outputs and `MultiOutputReceiver` are not supported.
3. **Bundle Lifecycle**: `startBundle` and `finishBundle` are invoked once per element, so any batching or aggregation logic in them will not behave as expected.

```java
/**
 * Class that wraps a {@link DoFn} and converts it from one which processes elements synchronously
 * to one which processes them asynchronously.
 *
 * <p>For synchronous DoFns, the default settings mean that many (100s) of elements will be
 * processed in parallel and that processing an element will block all other work on that key. In
 * addition, runners are optimized for latencies less than a few seconds, and longer operations can
 * result in high retry rates. Async should be considered when the default parallelism is not
 * correct and/or items are expected to take longer than a few seconds to process.
 *
 * <h3>Limitations and Requirements</h3>
 *
 * <ul>
 *   <li><b>Thread-Safety:</b> The wrapped {@code DoFn} must be thread-safe because its methods
 *       (including {@code processElement}) will be invoked concurrently by multiple background
 *       threads.</li>
 *   <li><b>No Multi-Output Support:</b> Tagged outputs and {@code MultiOutputReceiver} are not
 *       supported. Attempting to use them will result in an
 *       {@link UnsupportedOperationException}.</li>
 *   <li><b>Bundle Lifecycle:</b> {@code startBundle} and {@code finishBundle} are invoked per
 *       element within the background tasks, meaning any batching or aggregation logic implemented
 *       in them will not function as expected.</li>
 * </ul>
 */
```

-- This is an automated message from the Apache Git Service.
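To make the Silent Data Loss point concrete, here is a standalone sketch of the retry path. It uses plain `java.util.concurrent` and no Beam APIs. The names `FailedFutureRetrySketch`, `schedule`, `commit` and `failedOnce` are hypothetical:

- `inFlight` stands in for `activeElements`.
- `schedule` copies the skip-if-already-present check from `scheduleIfRoom`.
- `commit` applies the suggested catch block.

The sketch assumes, as the review does, that the runner retries the failed timer callback with the element still in `toProcessState`. The second `schedule` call simulates that retry.

```java
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentHashMap;

public class FailedFutureRetrySketch {
  // Hypothetical stand-in for the JVM-wide activeElements map, keyed by element id.
  static final ConcurrentHashMap<String, CompletableFuture<String>> inFlight =
      new ConcurrentHashMap<>();
  // Ids whose first attempt has already failed; simulates a transient error.
  static final Set<String> failedOnce = ConcurrentHashMap.newKeySet();

  // Mirrors the dedup check in scheduleIfRoom: an id already in the map is not resubmitted.
  static void schedule(String id) {
    inFlight.computeIfAbsent(
        id,
        k ->
            CompletableFuture.supplyAsync(
                () -> {
                  if (failedOnce.add(k)) {
                    throw new IllegalStateException("transient failure for " + k);
                  }
                  return "processed " + k;
                }));
  }

  // Mirrors the suggested catch block: remove the failed future, then rethrow.
  static String commit(String id) {
    CompletableFuture<String> future = inFlight.get(id);
    try {
      String result = future.join();
      inFlight.remove(id);
      return result;
    } catch (CompletionException e) {
      inFlight.remove(id); // without this, schedule(id) would never resubmit the element
      throw new RuntimeException("Error executing async task for element " + id, e.getCause());
    }
  }

  public static void main(String[] args) {
    schedule("a");
    try {
      commit("a");
    } catch (RuntimeException e) {
      // prints "attempt 1 failed: transient failure for a"
      System.out.println("attempt 1 failed: " + e.getCause().getMessage());
    }
    // Simulated runner retry: the failed entry is gone, so "a" is resubmitted and succeeds.
    schedule("a");
    System.out.println("attempt 2: " + commit("a")); // prints "attempt 2: processed a"
  }
}
```

If you delete the `inFlight.remove(id)` line in the `catch` block, the second `schedule("a")` becomes a no-op. The retried `commit("a")` then rethrows the original failure on every attempt.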
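The arithmetic behind the Division by Zero point can be checked outside Beam. The sketch below copies the body of `nextTimeToFire` from the diff, with two changes made for determinism: it takes the current time and the key's hash as parameters, and it returns epoch milliseconds instead of a Joda `Instant`. `TimerFrequencySketch`, `nextTimeToFireMillis` and `validateTimerFrequency` are hypothetical names.

```java
public class TimerFrequencySketch {
  // Same arithmetic as nextTimeToFire in the diff; "now" and the key hash are parameters here.
  static long nextTimeToFireMillis(long timerFrequencyMillis, long nowMillis, long seed) {
    double fractionalOffset = Math.abs(seed % 1000000) / 1000000.0;
    double timerFrequencySec = timerFrequencyMillis / 1000.0;
    double nowSec = nowMillis / 1000.0;
    double base = Math.floor((nowSec + timerFrequencySec) / timerFrequencySec) * timerFrequencySec;
    double offset = fractionalOffset * timerFrequencySec;
    return (long) ((base + offset) * 1000);
  }

  // The constructor check suggested in the review, expressed on raw milliseconds.
  static long validateTimerFrequency(long timerFrequencyMillis) {
    if (timerFrequencyMillis <= 0) {
      throw new IllegalArgumentException("timerFrequency must be greater than zero");
    }
    return timerFrequencyMillis;
  }

  public static void main(String[] args) {
    long now = 1_700_000_000_000L;
    // Infinity * 0.0 is NaN, and (long) NaN is 0: the timer lands on the Unix epoch.
    System.out.println(nextTimeToFireMillis(0, now, 42)); // 0
    // A negative frequency produces a fire time in the past.
    System.out.println(nextTimeToFireMillis(-1000, now, 42) < now); // true
    // A positive frequency produces a fire time in the future, as intended.
    System.out.println(nextTimeToFireMillis(1000, now, 42) > now); // true
  }
}
```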
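The Thread-Safety and Bundle Lifecycle limitations can be illustrated with a standalone simulation. It sketches the behavior only and does not use Beam's `DoFnInvoker` machinery. `BatchingFn` is a hypothetical stand-in for a wrapped `DoFn` that buffers elements between `startBundle` and `finishBundle`. One instance is shared by every pool thread, and each task calls `startBundle`, `processElement` and `finishBundle` in turn. This matches the per-element sequence the diff runs inside `supplyAsync`.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

public class PerElementBundleSketch {
  // Hypothetical wrapped "DoFn" that tries to batch elements between bundle callbacks.
  static class BatchingFn {
    private final List<String> batch = new ArrayList<>(); // guarded by "this"
    final List<Integer> flushedBatchSizes = Collections.synchronizedList(new ArrayList<>());
    final AtomicInteger processed = new AtomicInteger();

    synchronized void startBundle() {
      batch.clear();
    }

    synchronized void processElement(String element) {
      batch.add(element);
      processed.incrementAndGet();
    }

    synchronized void finishBundle() {
      // A real batching DoFn would flush "batch" to an external system here.
      flushedBatchSizes.add(batch.size());
    }
  }

  // Runs the per-element startBundle/processElement/finishBundle sequence on a shared instance.
  static BatchingFn runAsync(List<String> elements, int parallelism) {
    BatchingFn fn = new BatchingFn();
    ExecutorService pool = Executors.newFixedThreadPool(parallelism);
    for (String element : elements) {
      pool.execute(
          () -> {
            fn.startBundle();
            fn.processElement(element);
            fn.finishBundle();
          });
    }
    pool.shutdown();
    try {
      if (!pool.awaitTermination(10, TimeUnit.SECONDS)) {
        throw new IllegalStateException("tasks did not finish in time");
      }
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      throw new IllegalStateException("interrupted while waiting for tasks", e);
    }
    return fn;
  }

  public static void main(String[] args) {
    List<String> elements = Arrays.asList("a", "b", "c", "d", "e");
    BatchingFn sequential = runAsync(elements, 1);
    System.out.println(sequential.flushedBatchSizes); // [1, 1, 1, 1, 1]
    BatchingFn concurrent = runAsync(elements, 4);
    // 5 flushes, 5 processed
    System.out.println(
        concurrent.flushedBatchSizes.size() + " flushes, " + concurrent.processed.get() + " processed");
  }
}
```

With one thread, every flush holds exactly one element. With four threads, calls from different tasks interleave, so individual batch sizes depend on scheduling and a flush can even be empty. The number of flushes still equals the number of elements. The `synchronized` methods only keep the shared list from being corrupted; they do not make the three-call sequence atomic.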
