This is an automated email from the ASF dual-hosted git repository.
scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
     new 57320861aa2 Ensure that Operations are aborted when MapTaskExecutor is closed. Add tests around setup/teardown of DoFns (#36631)
57320861aa2 is described below
commit 57320861aa24516ec75d4b47d7d28d95aacb2010
Author: Sam Whittle <[email protected]>
AuthorDate: Mon Nov 17 09:54:23 2025 +0100
    Ensure that Operations are aborted when MapTaskExecutor is closed. Add tests around setup/teardown of DoFns (#36631)
---
...it_Java_ValidatesRunner_Dataflow_Streaming.json | 3 +-
...ostCommit_Java_ValidatesRunner_Dataflow_V2.json | 3 +-
runners/google-cloud-dataflow-java/build.gradle | 30 ++++-
.../worker/IntrinsicMapTaskExecutorFactory.java | 53 +++++---
.../worker/util/common/worker/MapTaskExecutor.java | 71 ++++++++---
.../work/processing/StreamingWorkScheduler.java | 13 +-
.../IntrinsicMapTaskExecutorFactoryTest.java | 137 +++++++++++++++++++--
.../runners/dataflow/worker/SimpleParDoFnTest.java | 12 +-
.../worker/StreamingDataflowWorkerTest.java | 114 ++++++++++++++++-
.../worker/testing/TestCountingSource.java | 5 +
.../util/common/worker/MapTaskExecutorTest.java | 39 ++++++
11 files changed, 415 insertions(+), 65 deletions(-)
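
At a high level, this change gives MapTaskExecutor an idempotent close() that aborts its Operations (running DoFn teardown), and makes IntrinsicMapTaskExecutorFactory abort any already-created Operations if building the executor fails partway. Because close() is now safe to call after a failed execute(), callers can treat the executor as an ordinary AutoCloseable. A minimal caller-side sketch of the contract this commit establishes (the runWorkItem wrapper is hypothetical, not code from this diff):

    // Hypothetical caller sketch: cleanup now runs whether execute() succeeds or throws.
    void runWorkItem(MapTaskExecutor executor) throws Exception {
      try (MapTaskExecutor e = executor) {
        e.execute(); // on failure, close() aborts all operations; a second close is a no-op
      }
    }
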
diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json
index 24fc17d4c74..743ee4b948f 100644
--- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json
+++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json
@@ -4,5 +4,6 @@
   "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test",
   "https://github.com/apache/beam/pull/31268": "noting that PR #31268 should run this test",
   "https://github.com/apache/beam/pull/31490": "noting that PR #31490 should run this test",
-  "https://github.com/apache/beam/pull/35159": "moving WindowedValue and making an interface"
+  "https://github.com/apache/beam/pull/35159": "moving WindowedValue and making an interface",
+  "https://github.com/apache/beam/pull/36631": "dofn lifecycle"
 }
diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2.json
index 24fc17d4c74..47d924953c5 100644
--- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2.json
+++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2.json
@@ -4,5 +4,6 @@
   "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test",
   "https://github.com/apache/beam/pull/31268": "noting that PR #31268 should run this test",
   "https://github.com/apache/beam/pull/31490": "noting that PR #31490 should run this test",
-  "https://github.com/apache/beam/pull/35159": "moving WindowedValue and making an interface"
+  "https://github.com/apache/beam/pull/35159": "moving WindowedValue and making an interface",
+  "https://github.com/apache/beam/pull/36631": "dofn lifecycle validation"
 }
diff --git a/runners/google-cloud-dataflow-java/build.gradle b/runners/google-cloud-dataflow-java/build.gradle
index b4ba32c1cc9..415132fa7d2 100644
--- a/runners/google-cloud-dataflow-java/build.gradle
+++ b/runners/google-cloud-dataflow-java/build.gradle
@@ -205,7 +205,6 @@ def commonLegacyExcludeCategories = [
'org.apache.beam.sdk.testing.UsesGaugeMetrics',
'org.apache.beam.sdk.testing.UsesMultimapState',
'org.apache.beam.sdk.testing.UsesTestStream',
-  'org.apache.beam.sdk.testing.UsesParDoLifecycle', // doesn't support remote runner
'org.apache.beam.sdk.testing.UsesMetricsPusher',
'org.apache.beam.sdk.testing.UsesBundleFinalizer',
   'org.apache.beam.sdk.testing.UsesBoundedTrieMetrics', // Dataflow QM as of now does not support returning back BoundedTrie in metric result.
@@ -452,7 +451,17 @@ task validatesRunner {
         excludedTests: [
           // TODO(https://github.com/apache/beam/issues/21472)
           'org.apache.beam.sdk.transforms.GroupByKeyTest$BasicTests.testAfterProcessingTimeContinuationTriggerUsingState',
-        ]
+
+          // These tests use static state and don't work with remote execution.
+          'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundle',
+          'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundleStateful',
+          'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInProcessElement',
+          'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInProcessElementStateful',
+          'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInSetup',
+          'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInSetupStateful',
+          'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInStartBundle',
+          'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInStartBundleStateful',
+        ]
       ))
 }
@@ -474,7 +483,17 @@ task validatesRunnerStreaming {
           // GroupIntoBatches.withShardedKey not supported on streaming runner v1
           // https://github.com/apache/beam/issues/22592
           'org.apache.beam.sdk.transforms.GroupIntoBatchesTest.testWithShardedKeyInGlobalWindow',
-        ]
+
+          // These tests use static state and don't work with remote execution.
+          'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundle',
+          'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundleStateful',
+          'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInProcessElement',
+          'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInProcessElementStateful',
+          'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInSetup',
+          'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInSetupStateful',
+          'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInStartBundle',
+          'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInStartBundleStateful',
+        ]
       ))
 }
@@ -543,8 +562,7 @@ task validatesRunnerV2 {
         excludedTests: [
           'org.apache.beam.sdk.transforms.ReshuffleTest.testReshuffleWithTimestampsStreaming',
-          // TODO(https://github.com/apache/beam/issues/18592): respect ParDo lifecycle.
-          'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testFnCallSequenceStateful',
+          // These tests use static state and don't work with remote execution.
           'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundle',
           'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundleStateful',
           'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInProcessElement',
@@ -586,7 +604,7 @@ task validatesRunnerV2Streaming {
           'org.apache.beam.sdk.transforms.GroupByKeyTest$BasicTests.testAfterProcessingTimeContinuationTriggerUsingState',
           'org.apache.beam.sdk.transforms.GroupByKeyTest.testCombiningAccumulatingProcessingTime',
-          // TODO(https://github.com/apache/beam/issues/18592): respect ParDo lifecycle.
+          // These tests use static state and don't work with remote execution.
           'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundle',
           'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundleStateful',
           'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInProcessElement',
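
For context on the exclusions above: the ParDoLifecycleTest teardown tests record lifecycle events in static fields of the DoFn class, so the assertions only see those events when the DoFn executes in the same JVM as the test harness. A sketch of the pattern (hypothetical class, not the actual test code):

    import java.util.concurrent.atomic.AtomicInteger;
    import org.apache.beam.sdk.transforms.DoFn;

    // Hypothetical illustration: the counter lives in the JVM running the DoFn, so on a
    // remote runner the worker increments state that the test JVM never observes.
    class LifecycleRecordingFn extends DoFn<String, String> {
      static final AtomicInteger teardownCalls = new AtomicInteger();

      @ProcessElement
      public void process(ProcessContext c) {
        c.output(c.element());
      }

      @Teardown
      public void teardown() {
        teardownCalls.incrementAndGet();
      }
    }
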
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactory.java
index 91fb640a175..d3f2aacc74d 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactory.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactory.java
@@ -105,11 +105,32 @@ public class IntrinsicMapTaskExecutorFactory implements DataflowMapTaskExecutorF
     Networks.replaceDirectedNetworkNodes(
         network, createOutputReceiversTransform(stageName, counterSet));

-    // Swap out all the ParallelInstruction nodes with Operation nodes
-    Networks.replaceDirectedNetworkNodes(
-        network,
-        createOperationTransformForParallelInstructionNodes(
-            stageName, network, options, readerFactory, sinkFactory, executionContext));
+    // Swap out all the ParallelInstruction nodes with Operation nodes.
+    // While updating the network, we keep track of the created Operations
+    // so that if an exception is encountered we can properly abort
+    // started operations.
+    ArrayList<Operation> createdOperations = new ArrayList<>();
+    try {
+      Networks.replaceDirectedNetworkNodes(
+          network,
+          createOperationTransformForParallelInstructionNodes(
+              stageName,
+              network,
+              options,
+              readerFactory,
+              sinkFactory,
+              executionContext,
+              createdOperations));
+    } catch (RuntimeException exn) {
+      for (Operation o : createdOperations) {
+        try {
+          o.abort();
+        } catch (Exception exn2) {
+          exn.addSuppressed(exn2);
+        }
+      }
+      throw exn;
+    }

     // Collect all the operations within the network and attach all the operations as receivers
     // to preceding output receivers.
@@ -144,7 +165,8 @@ public class IntrinsicMapTaskExecutorFactory implements DataflowMapTaskExecutorF
       final PipelineOptions options,
       final ReaderFactory readerFactory,
       final SinkFactory sinkFactory,
-      final DataflowExecutionContext<?> executionContext) {
+      final DataflowExecutionContext<?> executionContext,
+      final List<Operation> createdOperations) {
     return new TypeSafeNodeFunction<ParallelInstructionNode>(ParallelInstructionNode.class) {
       @Override
@@ -156,20 +178,22 @@ public class IntrinsicMapTaskExecutorFactory implements DataflowMapTaskExecutorF
                 instruction.getOriginalName(),
                 instruction.getSystemName(),
                 instruction.getName());
+        OperationNode result;
         try {
          DataflowOperationContext context = executionContext.createOperationContext(nameContext);
           if (instruction.getRead() != null) {
-            return createReadOperation(
-                network, node, options, readerFactory, executionContext, context);
+            result =
+                createReadOperation(
+                    network, node, options, readerFactory, executionContext, context);
           } else if (instruction.getWrite() != null) {
-            return createWriteOperation(node, options, sinkFactory, executionContext, context);
+            result = createWriteOperation(node, options, sinkFactory, executionContext, context);
           } else if (instruction.getParDo() != null) {
-            return createParDoOperation(network, node, options, executionContext, context);
+            result = createParDoOperation(network, node, options, executionContext, context);
           } else if (instruction.getPartialGroupByKey() != null) {
-            return createPartialGroupByKeyOperation(
-                network, node, options, executionContext, context);
+            result =
+                createPartialGroupByKeyOperation(network, node, options, executionContext, context);
           } else if (instruction.getFlatten() != null) {
-            return createFlattenOperation(network, node, context);
+            result = createFlattenOperation(network, node, context);
           } else {
             throw new IllegalArgumentException(
                 String.format("Unexpected instruction: %s", instruction));
@@ -177,6 +201,8 @@ public class IntrinsicMapTaskExecutorFactory implements DataflowMapTaskExecutorF
         } catch (Exception e) {
           throw new RuntimeException(e);
         }
+        createdOperations.add(result.getOperation());
+        return result;
       }
     };
   }
@@ -328,7 +354,6 @@ public class IntrinsicMapTaskExecutorFactory implements DataflowMapTaskExecutorF
     Coder<?> coder =
         CloudObjects.coderFromCloudObject(CloudObject.fromSpec(cloudOutput.getCodec()));

-    @SuppressWarnings("unchecked")
     ElementCounter outputCounter =
         new DataflowOutputCounter(
             cloudOutput.getName(),
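
The factory change above follows a general build-or-abort shape: record each Operation as it is created, and if creation fails partway, abort the ones already built, attaching cleanup failures to the primary exception via addSuppressed. A standalone sketch of that pattern with generic names (not the Beam API):

    import java.util.ArrayList;
    import java.util.List;
    import java.util.concurrent.Callable;

    final class BuildOrAbort {
      interface Resource {
        void abort() throws Exception;
      }

      // Build every resource, or abort the ones already built and rethrow.
      static List<Resource> buildAll(List<Callable<Resource>> factories) throws Exception {
        List<Resource> created = new ArrayList<>();
        try {
          for (Callable<Resource> factory : factories) {
            created.add(factory.call()); // may throw partway through
          }
          return created;
        } catch (Exception e) {
          for (Resource r : created) {
            try {
              r.abort(); // best-effort cleanup of partially constructed state
            } catch (Exception e2) {
              e.addSuppressed(e2); // keep the primary failure, record secondaries
            }
          }
          throw e;
        }
      }
    }
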
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutor.java
index 877e3198e91..58b95f286d5 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutor.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutor.java
@@ -18,8 +18,8 @@
 package org.apache.beam.runners.dataflow.worker.util.common.worker;

 import java.io.Closeable;
+import java.util.ArrayList;
 import java.util.List;
-import java.util.ListIterator;
 import org.apache.beam.runners.core.metrics.ExecutionStateTracker;
 import org.apache.beam.runners.dataflow.worker.counters.CounterSet;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
@@ -36,7 +36,9 @@ public class MapTaskExecutor implements WorkExecutor {
   private static final Logger LOG = LoggerFactory.getLogger(MapTaskExecutor.class);

   /** The operations in the map task, in execution order. */
-  public final List<Operation> operations;
+  public final ArrayList<Operation> operations;
+
+  private boolean closed = false;

   private final ExecutionStateTracker executionStateTracker;
@@ -54,7 +56,7 @@ public class MapTaskExecutor implements WorkExecutor {
CounterSet counters,
ExecutionStateTracker executionStateTracker) {
this.counters = counters;
- this.operations = operations;
+ this.operations = new ArrayList<>(operations);
this.executionStateTracker = executionStateTracker;
}
@@ -63,6 +65,7 @@ public class MapTaskExecutor implements WorkExecutor {
return counters;
}
+ /** May be reused if execute() returns without an exception being thrown. */
@Override
public void execute() throws Exception {
LOG.debug("Executing map task");
@@ -74,13 +77,11 @@ public class MapTaskExecutor implements WorkExecutor {
       // Starting a root operation such as a ReadOperation does the work
       // of processing the input dataset.
       LOG.debug("Starting operations");
-      ListIterator<Operation> iterator = operations.listIterator(operations.size());
-      while (iterator.hasPrevious()) {
+      for (int i = operations.size() - 1; i >= 0; --i) {
         if (Thread.currentThread().isInterrupted()) {
           throw new InterruptedException("Worker aborted");
         }
-        Operation op = iterator.previous();
-        op.start();
+        operations.get(i).start();
       }

       // Finish operations, in forward-execution-order, so that a
@@ -94,16 +95,13 @@ public class MapTaskExecutor implements WorkExecutor {
op.finish();
}
} catch (Exception | Error exn) {
- LOG.debug("Aborting operations", exn);
- for (Operation op : operations) {
- try {
- op.abort();
- } catch (Exception | Error exn2) {
- exn.addSuppressed(exn2);
- if (exn2 instanceof InterruptedException) {
- Thread.currentThread().interrupt();
- }
- }
+ try {
+ closeInternal();
+ } catch (Exception closeExn) {
+ exn.addSuppressed(closeExn);
+ }
+ if (exn instanceof InterruptedException) {
+ Thread.currentThread().interrupt();
}
throw exn;
}
@@ -164,6 +162,45 @@ public class MapTaskExecutor implements WorkExecutor {
}
}
+ private void closeInternal() throws Exception {
+ if (closed) {
+ return;
+ }
+ LOG.debug("Aborting operations");
+ @Nullable Exception exn = null;
+ for (Operation op : operations) {
+ try {
+ op.abort();
+ } catch (Exception | Error exn2) {
+ if (exn2 instanceof InterruptedException) {
+ Thread.currentThread().interrupt();
+ }
+ if (exn == null) {
+ if (exn2 instanceof Exception) {
+ exn = (Exception) exn2;
+ } else {
+ exn = new RuntimeException(exn2);
+ }
+ } else {
+ exn.addSuppressed(exn2);
+ }
+ }
+ }
+ closed = true;
+ if (exn != null) {
+ throw exn;
+ }
+ }
+
+ @Override
+ public void close() {
+ try {
+ closeInternal();
+ } catch (Exception e) {
+ LOG.error("Exception while closing MapTaskExecutor, ignoring", e);
+ }
+ }
+
@Override
public List<Integer> reportProducedEmptyOutput() {
List<Integer> emptyOutputSinkIndexes = Lists.newArrayList();
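
closeInternal() above is what makes close() idempotent: a closed flag guards against repeated aborts, the first abort failure becomes the thrown exception with later ones suppressed, and the thread's interrupt flag is restored if an abort was interrupted. A condensed standalone sketch of that accumulate-and-suppress loop (Op is an illustrative stand-in for Operation):

    final class AbortAll {
      interface Op {
        void abort() throws Exception;
      }

      // Abort everything; throw the first failure with the rest attached as suppressed.
      static void abortAll(Iterable<Op> ops) throws Exception {
        Exception first = null;
        for (Op op : ops) {
          try {
            op.abort();
          } catch (Exception | Error e) {
            if (e instanceof InterruptedException) {
              Thread.currentThread().interrupt(); // restore the interrupt flag
            }
            if (first == null) {
              first = e instanceof Exception ? (Exception) e : new RuntimeException(e);
            } else {
              first.addSuppressed(e);
            }
          }
        }
        if (first != null) {
          throw first;
        }
      }
    }
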
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java
index a4cd5d6d8a6..e61c2d1f4a0 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java
@@ -415,6 +415,7 @@ public class StreamingWorkScheduler {
// Release the execution state for another thread to use.
computationState.releaseComputationWorkExecutor(computationWorkExecutor);
+ computationWorkExecutor = null;
work.setState(Work.State.COMMIT_QUEUED);
outputBuilder.addAllPerWorkItemLatencyAttributions(work.getLatencyAttributions(sampler));
@@ -422,11 +423,13 @@ public class StreamingWorkScheduler {
       return ExecuteWorkResult.create(
           outputBuilder, stateReader.getBytesRead() + localSideInputStateFetcher.getBytesRead());
     } catch (Throwable t) {
-      // If processing failed due to a thrown exception, close the executionState. Do not
-      // return/release the executionState back to computationState as that will lead to this
-      // executionState instance being reused.
-      LOG.debug("Invalidating executor after work item {} failed", workItem.getWorkToken(), t);
-      computationWorkExecutor.invalidate();
+      if (computationWorkExecutor != null) {
+        // If processing failed due to a thrown exception, close the executionState. Do not
+        // return/release the executionState back to computationState as that will lead to this
+        // executionState instance being reused.
+        LOG.debug("Invalidating executor after work item {} failed", workItem.getWorkToken(), t);
+        computationWorkExecutor.invalidate();
+      }
       // Re-throw the exception, it will be caught and handled by workFailureProcessor downstream.
       throw t;
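
The scheduler fix above is a release-then-null hand-off guard: once the executor has been returned to the pool, the local reference is cleared so the catch block cannot invalidate an instance that another thread may already be reusing. A generic sketch of the idiom (Pool and PooledExecutor are illustrative, not Beam types):

    final class PooledWork {
      interface PooledExecutor {
        void run() throws Exception;
        void invalidate();
      }

      interface Pool {
        PooledExecutor acquire();
        void release(PooledExecutor executor);
      }

      // Invalidate only while we still own the executor; after release it belongs to the pool.
      static void runWork(Pool pool) throws Exception {
        PooledExecutor executor = pool.acquire();
        try {
          executor.run();
          pool.release(executor); // ownership handed back to the pool
          executor = null; // guard: the catch block must not touch it now
        } catch (Throwable t) {
          if (executor != null) {
            executor.invalidate();
          }
          throw t;
        }
      }
    }
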
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactoryTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactoryTest.java
index e77ae309d35..3443ae0022b 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactoryTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactoryTest.java
@@ -24,11 +24,16 @@ import static org.apache.beam.runners.dataflow.worker.counters.CounterName.named
import static org.apache.beam.sdk.util.SerializableUtils.serializeToByteArray;
import static org.apache.beam.sdk.util.StringUtils.byteArrayToJsonString;
import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.hasItems;
import static org.hamcrest.Matchers.instanceOf;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
@@ -52,6 +57,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import org.apache.beam.runners.dataflow.util.CloudObject;
import org.apache.beam.runners.dataflow.util.CloudObjects;
@@ -254,8 +260,9 @@ public class IntrinsicMapTaskExecutorFactoryTest {
     List<ParallelInstruction> instructions =
         Arrays.asList(
             createReadInstruction("Read", ReaderFactoryTest.SingletonTestReaderFactory.class),
-            createParDoInstruction(0, 0, "DoFn1", "DoFnUserName"),
-            createParDoInstruction(1, 0, "DoFnWithContext", "DoFnWithContextUserName"));
+            createParDoInstruction(0, 0, "DoFn1", "DoFnUserName", new TestDoFn()),
+            createParDoInstruction(
+                1, 0, "DoFnWithContext", "DoFnWithContextUserName", new TestDoFn()));

     MapTask mapTask = new MapTask();
     mapTask.setStageName(STAGE);
@@ -330,6 +337,7 @@ public class IntrinsicMapTaskExecutorFactoryTest {
                         PCOLLECTION_ID))));
     when(network.outDegree(instructionNode)).thenReturn(1);

+    ArrayList<Operation> createdOperations = new ArrayList<>();
     Node operationNode =
         mapTaskExecutorFactory
             .createOperationTransformForParallelInstructionNodes(
@@ -338,11 +346,13 @@ public class IntrinsicMapTaskExecutorFactoryTest {
                 PipelineOptionsFactory.create(),
                 readerRegistry,
                 sinkRegistry,
-                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"))
+                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"),
+                createdOperations)
             .apply(instructionNode);
     assertThat(operationNode, instanceOf(OperationNode.class));
     assertThat(((OperationNode) operationNode).getOperation(), instanceOf(ReadOperation.class));
     ReadOperation readOperation = (ReadOperation) ((OperationNode) operationNode).getOperation();
+    assertThat(createdOperations, contains(readOperation));

     assertEquals(1, readOperation.receivers.length);
     assertEquals(0, readOperation.receivers[0].getReceiverCount());
@@ -391,6 +401,7 @@ public class IntrinsicMapTaskExecutorFactoryTest {
         ParallelInstructionNode.create(
             createWriteInstruction(producerIndex, producerOutputNum, "WriteOperation"),
             ExecutionLocation.UNKNOWN);
+    ArrayList<Operation> createdOperations = new ArrayList<>();
     Node operationNode =
         mapTaskExecutorFactory
             .createOperationTransformForParallelInstructionNodes(
@@ -399,11 +410,13 @@ public class IntrinsicMapTaskExecutorFactoryTest {
                 options,
                 readerRegistry,
                 sinkRegistry,
-                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"))
+                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"),
+                createdOperations)
             .apply(instructionNode);
     assertThat(operationNode, instanceOf(OperationNode.class));
     assertThat(((OperationNode) operationNode).getOperation(), instanceOf(WriteOperation.class));
     WriteOperation writeOperation = (WriteOperation) ((OperationNode) operationNode).getOperation();
+    assertThat(createdOperations, contains(writeOperation));

     assertEquals(0, writeOperation.receivers.length);
     assertEquals(Operation.InitializationState.UNSTARTED, writeOperation.initializationState);
@@ -461,17 +474,15 @@ public class IntrinsicMapTaskExecutorFactoryTest {
   static ParallelInstruction createParDoInstruction(
       int producerIndex, int producerOutputNum, String systemName) {
-    return createParDoInstruction(producerIndex, producerOutputNum, systemName, "");
+    return createParDoInstruction(producerIndex, producerOutputNum, systemName, "", new TestDoFn());
   }

   static ParallelInstruction createParDoInstruction(
-      int producerIndex, int producerOutputNum, String systemName, String userName) {
+      int producerIndex, int producerOutputNum, String systemName, String userName, DoFn<?, ?> fn) {
     InstructionInput cloudInput = new InstructionInput();
     cloudInput.setProducerInstructionIndex(producerIndex);
     cloudInput.setOutputNum(producerOutputNum);

-    TestDoFn fn = new TestDoFn();
-
     String serializedFn =
         StringUtils.byteArrayToJsonString(
             SerializableUtils.serializeToByteArray(
@@ -541,14 +552,16 @@ public class IntrinsicMapTaskExecutorFactoryTest {
                             .getMultiOutputInfos()
                             .get(0))));

+    ArrayList<Operation> createdOperations = new ArrayList<>();
     Node operationNode =
         mapTaskExecutorFactory
             .createOperationTransformForParallelInstructionNodes(
-                STAGE, network, options, readerRegistry, sinkRegistry, context)
+                STAGE, network, options, readerRegistry, sinkRegistry, context, createdOperations)
             .apply(instructionNode);
     assertThat(operationNode, instanceOf(OperationNode.class));
     assertThat(((OperationNode) operationNode).getOperation(), instanceOf(ParDoOperation.class));
     ParDoOperation parDoOperation = (ParDoOperation) ((OperationNode) operationNode).getOperation();
+    assertThat(createdOperations, contains(parDoOperation));

     assertEquals(1, parDoOperation.receivers.length);
     assertEquals(0, parDoOperation.receivers[0].getReceiverCount());
@@ -608,6 +621,7 @@ public class IntrinsicMapTaskExecutorFactoryTest {
                         PCOLLECTION_ID))));
     when(network.outDegree(instructionNode)).thenReturn(1);

+    ArrayList<Operation> createdOperations = new ArrayList<>();
     Node operationNode =
         mapTaskExecutorFactory
             .createOperationTransformForParallelInstructionNodes(
@@ -616,11 +630,13 @@ public class IntrinsicMapTaskExecutorFactoryTest {
                 PipelineOptionsFactory.create(),
                 readerRegistry,
                 sinkRegistry,
-                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"))
+                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"),
+                createdOperations)
             .apply(instructionNode);
     assertThat(operationNode, instanceOf(OperationNode.class));
     assertThat(((OperationNode) operationNode).getOperation(), instanceOf(ParDoOperation.class));
     ParDoOperation pgbkOperation = (ParDoOperation) ((OperationNode) operationNode).getOperation();
+    assertThat(createdOperations, contains(pgbkOperation));

     assertEquals(1, pgbkOperation.receivers.length);
     assertEquals(0, pgbkOperation.receivers[0].getReceiverCount());
@@ -660,6 +676,7 @@ public class IntrinsicMapTaskExecutorFactoryTest {
                         PCOLLECTION_ID))));
     when(network.outDegree(instructionNode)).thenReturn(1);

+    ArrayList<Operation> createdOperations = new ArrayList<>();
     Node operationNode =
         mapTaskExecutorFactory
             .createOperationTransformForParallelInstructionNodes(
@@ -668,11 +685,13 @@ public class IntrinsicMapTaskExecutorFactoryTest {
                 options,
                 readerRegistry,
                 sinkRegistry,
-                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"))
+                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"),
+                createdOperations)
             .apply(instructionNode);
     assertThat(operationNode, instanceOf(OperationNode.class));
     assertThat(((OperationNode) operationNode).getOperation(), instanceOf(ParDoOperation.class));
     ParDoOperation pgbkOperation = (ParDoOperation) ((OperationNode) operationNode).getOperation();
+    assertThat(createdOperations, contains(pgbkOperation));

     assertEquals(1, pgbkOperation.receivers.length);
     assertEquals(0, pgbkOperation.receivers[0].getReceiverCount());
@@ -738,6 +757,7 @@ public class IntrinsicMapTaskExecutorFactoryTest {
                         PCOLLECTION_ID))));
     when(network.outDegree(instructionNode)).thenReturn(1);

+    ArrayList<Operation> createdOperations = new ArrayList<>();
     Node operationNode =
         mapTaskExecutorFactory
             .createOperationTransformForParallelInstructionNodes(
@@ -746,15 +766,108 @@ public class IntrinsicMapTaskExecutorFactoryTest {
                 options,
                 readerRegistry,
                 sinkRegistry,
-                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"))
+                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"),
+                createdOperations)
             .apply(instructionNode);
     assertThat(operationNode, instanceOf(OperationNode.class));
     assertThat(((OperationNode) operationNode).getOperation(), instanceOf(FlattenOperation.class));
     FlattenOperation flattenOperation =
         (FlattenOperation) ((OperationNode) operationNode).getOperation();
+    assertThat(createdOperations, contains(flattenOperation));

     assertEquals(1, flattenOperation.receivers.length);
     assertEquals(0, flattenOperation.receivers[0].getReceiverCount());
     assertEquals(Operation.InitializationState.UNSTARTED, flattenOperation.initializationState);
   }
+
+ static class TestTeardownDoFn extends DoFn<String, String> {
+ static AtomicInteger setupCalls = new AtomicInteger();
+ static AtomicInteger teardownCalls = new AtomicInteger();
+
+ private final boolean throwExceptionOnSetup;
+ private boolean setupCalled = false;
+
+ TestTeardownDoFn(boolean throwExceptionOnSetup) {
+ this.throwExceptionOnSetup = throwExceptionOnSetup;
+ }
+
+ @Setup
+ public void setup() {
+ assertFalse(setupCalled);
+ setupCalled = true;
+ setupCalls.addAndGet(1);
+ if (throwExceptionOnSetup) {
+ throw new RuntimeException("Test setup exception");
+ }
+ }
+
+ @ProcessElement
+ public void process(ProcessContext c) {
+ fail("no elements should be processed");
+ }
+
+ @Teardown
+ public void teardown() {
+ assertTrue(setupCalled);
+ setupCalled = false;
+ teardownCalls.addAndGet(1);
+ }
+ }
+
+
+  @Test
+  public void testCreateMapTaskExecutorException() throws Exception {
+    List<ParallelInstruction> instructions =
+        Arrays.asList(
+            createReadInstruction("Read"),
+            createParDoInstruction(0, 0, "DoFn1", "DoFn1", new TestTeardownDoFn(false)),
+            createParDoInstruction(0, 0, "DoFn2", "DoFn2", new TestTeardownDoFn(false)),
+            createParDoInstruction(0, 0, "ErrorFn", "", new TestTeardownDoFn(true)),
+            createParDoInstruction(0, 0, "DoFn3", "DoFn3", new TestTeardownDoFn(false)),
+            createFlattenInstruction(1, 0, 2, 0, "Flatten"),
+            createWriteInstruction(3, 0, "Write"));
+
+    MapTask mapTask = new MapTask();
+    mapTask.setStageName(STAGE);
+    mapTask.setSystemName("systemName");
+    mapTask.setInstructions(instructions);
+    mapTask.setFactory(Transport.getJsonFactory());
+
+    assertThrows(
+        "Test setup exception",
+        RuntimeException.class,
+        () ->
+            mapTaskExecutorFactory.create(
+                mapTaskToNetwork.apply(mapTask),
+                options,
+                STAGE,
+                readerRegistry,
+                sinkRegistry,
+                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"),
+                counterSet,
+                idGenerator));
+    assertEquals(3, TestTeardownDoFn.setupCalls.getAndSet(0));
+    // We only tear down the instruction we were unable to create. The other
+    // infos are cached within UserParDoFnFactory and not torn down.
+    assertEquals(1, TestTeardownDoFn.teardownCalls.getAndSet(0));
+
+    assertThrows(
+        "Test setup exception",
+        RuntimeException.class,
+        () ->
+            mapTaskExecutorFactory.create(
+                mapTaskToNetwork.apply(mapTask),
+                options,
+                STAGE,
+                readerRegistry,
+                sinkRegistry,
+                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"),
+                counterSet,
+                idGenerator));
+    // The non-erroring functions are cached, and a new setup call is made on the
+    // erroring DoFn.
+    assertEquals(1, TestTeardownDoFn.setupCalls.get());
+    // We only tear down the instruction we were unable to create. The other
+    // infos are cached within UserParDoFnFactory and not torn down.
+    assertEquals(1, TestTeardownDoFn.teardownCalls.get());
+  }
}
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFnTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFnTest.java
index bb92fca3d8b..9e45425562a 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFnTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFnTest.java
@@ -198,7 +198,7 @@ public class SimpleParDoFnTest {
         new TestDoFn(
             ImmutableList.of(
                 new TupleTag<>("tag1"), new TupleTag<>("tag2"), new TupleTag<>("tag3")));
-    DoFnInfo<?, ?> fnInfo =
+    DoFnInfo<Integer, String> fnInfo =
         DoFnInfo.forFn(
             fn,
             WindowingStrategy.globalDefault(),
@@ -279,7 +279,7 @@ public class SimpleParDoFnTest {
@SuppressWarnings("AssertionFailureIgnored")
public void testUnexpectedNumberOfReceivers() throws Exception {
TestDoFn fn = new TestDoFn(Collections.emptyList());
- DoFnInfo<?, ?> fnInfo =
+ DoFnInfo<Integer, String> fnInfo =
DoFnInfo.forFn(
fn,
WindowingStrategy.globalDefault(),
@@ -330,7 +330,7 @@ public class SimpleParDoFnTest {
@Test
public void testErrorPropagation() throws Exception {
TestErrorDoFn fn = new TestErrorDoFn();
- DoFnInfo<?, ?> fnInfo =
+ DoFnInfo<Integer, String> fnInfo =
DoFnInfo.forFn(
fn,
WindowingStrategy.globalDefault(),
@@ -423,7 +423,7 @@ public class SimpleParDoFnTest {
new TupleTag<>("undecl1"),
new TupleTag<>("undecl2"),
new TupleTag<>("undecl3")));
- DoFnInfo<?, ?> fnInfo =
+ DoFnInfo<Integer, String> fnInfo =
DoFnInfo.forFn(
fn,
WindowingStrategy.globalDefault(),
@@ -485,7 +485,7 @@ public class SimpleParDoFnTest {
}
StateTestingDoFn fn = new StateTestingDoFn();
- DoFnInfo<?, ?> fnInfo =
+ DoFnInfo<Integer, String> fnInfo =
DoFnInfo.forFn(
fn,
WindowingStrategy.globalDefault(),
@@ -578,7 +578,7 @@ public class SimpleParDoFnTest {
}
DoFn<Integer, String> fn = new RepeaterDoFn();
- DoFnInfo<?, ?> fnInfo =
+ DoFnInfo<Integer, String> fnInfo =
DoFnInfo.forFn(
fn,
WindowingStrategy.globalDefault(),
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
index e16a8b9f88c..df90bb96139 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
@@ -3276,6 +3276,9 @@ public class StreamingDataflowWorkerTest {
     TestCountingSource counter = new TestCountingSource(3).withThrowOnFirstSnapshot(true);

+    // Reset static state that may leak across tests.
+    TestExceptionInvalidatesCacheFn.resetStaticState();
+    TestCountingSource.resetStaticState();
     List<ParallelInstruction> instructions =
         Arrays.asList(
             new ParallelInstruction()
@@ -3310,7 +3313,10 @@ public class StreamingDataflowWorkerTest {
.build());
worker.start();
- // Three GetData requests
+ // Three GetData requests:
+ // - first processing has no state
+ // - recovering from checkpoint exception has no persisted state
+ // - recovering from processing exception recovers last committed state
for (int i = 0; i < 3; i++) {
ByteString state;
if (i == 0 || i == 1) {
@@ -3437,6 +3443,11 @@ public class StreamingDataflowWorkerTest {
parseCommitRequest(sb.toString()))
.build()));
}
+
+    // Ensure that the invalidated DoFn had tearDown called on it.
+ assertEquals(1, TestExceptionInvalidatesCacheFn.tearDownCallCount.get());
+ assertEquals(2, TestExceptionInvalidatesCacheFn.setupCallCount.get());
+
worker.stop();
}
@@ -3484,7 +3495,7 @@ public class StreamingDataflowWorkerTest {
}
@Test
- public void testActiveWorkFailure() throws Exception {
+ public void testQueuedWorkFailure() throws Exception {
List<ParallelInstruction> instructions =
Arrays.asList(
makeSourceInstruction(StringUtf8Coder.of()),
@@ -3515,6 +3526,9 @@ public class StreamingDataflowWorkerTest {
     server.whenGetWorkCalled().thenReturn(workItem).thenReturn(workItemToFail);
     server.waitForEmptyWorkQueue();

+    // Wait for key to schedule; it will be blocked.
+    BlockingFn.counter().acquire(1);
+
     // Mock Windmill sending a heartbeat response failing the second work item while the first
     // is still processing.
     ComputationHeartbeatResponse.Builder failedHeartbeat =
@@ -3534,6 +3548,64 @@ public class StreamingDataflowWorkerTest {
         server.waitForAndGetCommitsWithTimeout(1, Duration.standardSeconds((5)));
     assertEquals(1, commits.size());
+    assertEquals(0, BlockingFn.teardownCounter.get());
+    assertEquals(1, BlockingFn.setupCounter.get());
+
+    worker.stop();
+  }
+
+  @Test
+  public void testActiveWorkFailure() throws Exception {
+    List<ParallelInstruction> instructions =
+        Arrays.asList(
+            makeSourceInstruction(StringUtf8Coder.of()),
+            makeDoFnInstruction(blockingFn, 0, StringUtf8Coder.of()),
+            makeSinkInstruction(StringUtf8Coder.of(), 0));
+
+    StreamingDataflowWorker worker =
+        makeWorker(
+            defaultWorkerParams("--activeWorkRefreshPeriodMillis=100")
+                .setInstructions(instructions)
+                .publishCounters()
+                .build());
+    worker.start();
+
+    GetWorkResponse workItemToFail =
+        makeInput(0, TimeUnit.MILLISECONDS.toMicros(0), "key", DEFAULT_SHARDING_KEY);
+    long failedWorkToken = workItemToFail.getWork(0).getWork(0).getWorkToken();
+    long failedCacheToken = workItemToFail.getWork(0).getWork(0).getCacheToken();
+    GetWorkResponse workItem =
+        makeInput(1, TimeUnit.MILLISECONDS.toMicros(0), "key", DEFAULT_SHARDING_KEY);
+
+    // Queue up the work item for the key.
+    server.whenGetWorkCalled().thenReturn(workItemToFail).thenReturn(workItem);
+    server.waitForEmptyWorkQueue();
+
+    // Wait for key to schedule; it will be blocked.
+    BlockingFn.counter().acquire(1);
+
+    // Mock Windmill sending a heartbeat response failing the first work item while it is
+    // processing.
+    ComputationHeartbeatResponse.Builder failedHeartbeat =
+        ComputationHeartbeatResponse.newBuilder();
+    failedHeartbeat
+        .setComputationId(DEFAULT_COMPUTATION_ID)
+        .addHeartbeatResponsesBuilder()
+        .setCacheToken(failedCacheToken)
+        .setWorkToken(failedWorkToken)
+        .setShardingKey(DEFAULT_SHARDING_KEY)
+        .setFailed(true);
+    server.sendFailedHeartbeats(Collections.singletonList(failedHeartbeat.build()));
+
+    // Release the blocked call; there should not be a commit for the failed work item, and the
+    // DoFn should be invalidated.
+    BlockingFn.blocker().countDown();
+    Map<Long, Windmill.WorkItemCommitRequest> commits =
+        server.waitForAndGetCommitsWithTimeout(1, Duration.standardSeconds((5)));
+    assertEquals(1, commits.size());
+
+    assertEquals(0, BlockingFn.teardownCounter.get());
+    assertEquals(1, BlockingFn.setupCounter.get());
+
     worker.stop();
   }
@@ -4246,6 +4318,18 @@ public class StreamingDataflowWorkerTest {
         new AtomicReference<>(new CountDownLatch(1));
     public static AtomicReference<Semaphore> counter = new AtomicReference<>(new Semaphore(0));
     public static AtomicInteger callCounter = new AtomicInteger(0);
+ public static AtomicInteger setupCounter = new AtomicInteger(0);
+ public static AtomicInteger teardownCounter = new AtomicInteger(0);
+
+ @Setup
+ public void setup() {
+ setupCounter.incrementAndGet();
+ }
+
+ @Teardown
+ public void tearDown() {
+ teardownCounter.incrementAndGet();
+ }
@ProcessElement
public void processElement(ProcessContext c) throws InterruptedException {
@@ -4278,6 +4362,8 @@ public class StreamingDataflowWorkerTest {
blocker.set(new CountDownLatch(1));
counter.set(new Semaphore(0));
callCounter.set(0);
+ setupCounter.set(0);
+ teardownCounter.set(0);
}
}
};
@@ -4397,11 +4483,33 @@ public class StreamingDataflowWorkerTest {
static class TestExceptionInvalidatesCacheFn
extends DoFn<ValueWithRecordId<KV<Integer, Integer>>, String> {
- static boolean thrown = false;
+ public static AtomicInteger setupCallCount = new AtomicInteger();
+ public static AtomicInteger tearDownCallCount = new AtomicInteger();
+ private static boolean thrown = false;
+ private boolean setupCalled = false;
+
+ static void resetStaticState() {
+ setupCallCount.set(0);
+ tearDownCallCount.set(0);
+ thrown = false;
+ }
@StateId("int")
private final StateSpec<ValueState<Integer>> counter =
StateSpecs.value(VarIntCoder.of());
+ @Setup
+ public void setUp() {
+ assertFalse(setupCalled);
+ setupCalled = true;
+ setupCallCount.addAndGet(1);
+ }
+
+ @Teardown
+ public void tearDown() {
+ assertTrue(setupCalled);
+ tearDownCallCount.addAndGet(1);
+ }
+
@ProcessElement
    public void processElement(ProcessContext c, @StateId("int") ValueState<Integer> state)
        throws Exception {
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/testing/TestCountingSource.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/testing/TestCountingSource.java
index 6771e9dbb71..21e4d8c55e7 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/testing/TestCountingSource.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/testing/TestCountingSource.java
@@ -65,6 +65,11 @@ public class TestCountingSource
TestCountingSource.finalizeTracker = finalizeTracker;
}
+ public static void resetStaticState() {
+ finalizeTracker = null;
+ thrown = false;
+ }
+
public TestCountingSource(int numMessagesPerShard) {
this(numMessagesPerShard, 0, false, false, true);
}
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutorTest.java
index 2eeaa06eb5e..188466a5057 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutorTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutorTest.java
@@ -519,4 +519,43 @@ public class MapTaskExecutorTest {
Mockito.verify(o2, atLeastOnce()).abortReadLoop();
Mockito.verify(stateTracker).deactivate();
}
+
+  @Test
+  public void testCloseAbortsOperations() throws Exception {
+    Operation o1 = Mockito.mock(Operation.class);
+    Operation o2 = Mockito.mock(Operation.class);
+    List<Operation> operations = Arrays.asList(o1, o2);
+    ExecutionStateTracker stateTracker = Mockito.spy(ExecutionStateTracker.newForTest());
+    Mockito.verifyNoMoreInteractions(stateTracker);
+    try (MapTaskExecutor executor = new MapTaskExecutor(operations, counterSet, stateTracker)) {}
+
+    Mockito.verify(o1).abort();
+    Mockito.verify(o2).abort();
+  }
+
+  @Test
+  public void testExceptionAndThenCloseAbortsJustOnce() throws Exception {
+    Operation o1 = Mockito.mock(Operation.class);
+    Operation o2 = Mockito.mock(Operation.class);
+    Mockito.doThrow(new Exception("in start")).when(o2).start();
+
+    ExecutionStateTracker stateTracker = Mockito.spy(ExecutionStateTracker.newForTest());
+    MapTaskExecutor executor = new MapTaskExecutor(Arrays.asList(o1, o2), counterSet, stateTracker);
+    try {
+      executor.execute();
+      fail("Should have thrown");
+    } catch (Exception e) {
+      // Expected: the failure from o2.start() propagates out of execute().
+    }
+    InOrder inOrder = Mockito.inOrder(o2, stateTracker);
+    inOrder.verify(stateTracker).activate();
+    inOrder.verify(o2).start();
+    inOrder.verify(o2).abort();
+    inOrder.verify(stateTracker).deactivate();
+
+    // Order of o1 abort doesn't matter.
+    Mockito.verify(o1).abort();
+    Mockito.verifyNoMoreInteractions(o1);
+    // Closing after already closed should not call abort again.
+    executor.close();
+  }
}