This is an automated email from the ASF dual-hosted git repository.

scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 57320861aa2 Ensure that Operations are aborted when MapTaskExecutor is closed. Add tests around setup/teardown of DoFns (#36631)
57320861aa2 is described below

commit 57320861aa24516ec75d4b47d7d28d95aacb2010
Author: Sam Whittle <[email protected]>
AuthorDate: Mon Nov 17 09:54:23 2025 +0100

    Ensure that Operations are aborted when MapTaskExecutor is closed. Add tests around setup/teardown of DoFns (#36631)
---
 ...it_Java_ValidatesRunner_Dataflow_Streaming.json |   3 +-
 ...ostCommit_Java_ValidatesRunner_Dataflow_V2.json |   3 +-
 runners/google-cloud-dataflow-java/build.gradle    |  30 ++++-
 .../worker/IntrinsicMapTaskExecutorFactory.java    |  53 +++++---
 .../worker/util/common/worker/MapTaskExecutor.java |  71 ++++++++---
 .../work/processing/StreamingWorkScheduler.java    |  13 +-
 .../IntrinsicMapTaskExecutorFactoryTest.java       | 137 +++++++++++++++++++--
 .../runners/dataflow/worker/SimpleParDoFnTest.java |  12 +-
 .../worker/StreamingDataflowWorkerTest.java        | 114 ++++++++++++++++-
 .../worker/testing/TestCountingSource.java         |   5 +
 .../util/common/worker/MapTaskExecutorTest.java    |  39 ++++++
 11 files changed, 415 insertions(+), 65 deletions(-)
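
The heart of the change is a track-and-abort cleanup pattern: remember each Operation as it is created, and if building or closing the executor fails, abort every started operation so that ParDo operations tear down their DoFns. A minimal sketch of the idea (Instruction and createOperation are hypothetical stand-ins for the worker's real factory types, not code from this commit):

    List<Operation> created = new ArrayList<>();
    try {
      for (Instruction insn : instructions) {
        created.add(createOperation(insn)); // may throw partway through
      }
    } catch (RuntimeException e) {
      for (Operation op : created) {
        try {
          op.abort(); // aborting a ParDoOperation tears down its DoFn
        } catch (Exception suppressed) {
          e.addSuppressed(suppressed);
        }
      }
      throw e;
    }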

diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json
index 24fc17d4c74..743ee4b948f 100644
--- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json
+++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json
@@ -4,5 +4,6 @@
  "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test",
  "https://github.com/apache/beam/pull/31268": "noting that PR #31268 should run this test",
  "https://github.com/apache/beam/pull/31490": "noting that PR #31490 should run this test",
-  "https://github.com/apache/beam/pull/35159": "moving WindowedValue and making an interface"
+  "https://github.com/apache/beam/pull/35159": "moving WindowedValue and making an interface",
+  "https://github.com/apache/beam/pull/36631": "dofn lifecycle",
 }
diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2.json
index 24fc17d4c74..47d924953c5 100644
--- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2.json
+++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2.json
@@ -4,5 +4,6 @@
  "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test",
  "https://github.com/apache/beam/pull/31268": "noting that PR #31268 should run this test",
  "https://github.com/apache/beam/pull/31490": "noting that PR #31490 should run this test",
-  "https://github.com/apache/beam/pull/35159": "moving WindowedValue and making an interface"
+  "https://github.com/apache/beam/pull/35159": "moving WindowedValue and making an interface",
+  "https://github.com/apache/beam/pull/36631": "dofn lifecycle validation",
 }
diff --git a/runners/google-cloud-dataflow-java/build.gradle b/runners/google-cloud-dataflow-java/build.gradle
index b4ba32c1cc9..415132fa7d2 100644
--- a/runners/google-cloud-dataflow-java/build.gradle
+++ b/runners/google-cloud-dataflow-java/build.gradle
@@ -205,7 +205,6 @@ def commonLegacyExcludeCategories = [
   'org.apache.beam.sdk.testing.UsesGaugeMetrics',
   'org.apache.beam.sdk.testing.UsesMultimapState',
   'org.apache.beam.sdk.testing.UsesTestStream',
-  'org.apache.beam.sdk.testing.UsesParDoLifecycle', // doesn't support remote runner
   'org.apache.beam.sdk.testing.UsesMetricsPusher',
   'org.apache.beam.sdk.testing.UsesBundleFinalizer',
  'org.apache.beam.sdk.testing.UsesBoundedTrieMetrics', // Dataflow QM as of now does not support returning back BoundedTrie in metric result.
@@ -452,7 +451,17 @@ task validatesRunner {
     excludedTests: [
      // TODO(https://github.com/apache/beam/issues/21472)
      'org.apache.beam.sdk.transforms.GroupByKeyTest$BasicTests.testAfterProcessingTimeContinuationTriggerUsingState',
-    ]
+
+      // These tests use static state and don't work with remote execution.
+      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundle',
+      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundleStateful',
+      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInProcessElement',
+      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInProcessElementStateful',
+      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInSetup',
+      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInSetupStateful',
+      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInStartBundle',
+      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInStartBundleStateful',
+      ]
   ))
 }
 
@@ -474,7 +483,17 @@ task validatesRunnerStreaming {
       // GroupIntoBatches.withShardedKey not supported on streaming runner v1
      // https://github.com/apache/beam/issues/22592
      'org.apache.beam.sdk.transforms.GroupIntoBatchesTest.testWithShardedKeyInGlobalWindow',
-    ]
+
+      // These tests use static state and don't work with remote execution.
+      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundle',
+      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundleStateful',
+      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInProcessElement',
+      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInProcessElementStateful',
+      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInSetup',
+      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInSetupStateful',
+      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInStartBundle',
+      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInStartBundleStateful',
+]
   ))
 }
 
@@ -543,8 +562,7 @@ task validatesRunnerV2 {
    excludedTests: [
      'org.apache.beam.sdk.transforms.ReshuffleTest.testReshuffleWithTimestampsStreaming',
 
-      // TODO(https://github.com/apache/beam/issues/18592): respect ParDo lifecycle.
-      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testFnCallSequenceStateful',
+      // These tests use static state and don't work with remote execution.
      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundle',
      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundleStateful',
      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInProcessElement',
@@ -586,7 +604,7 @@ task validatesRunnerV2Streaming {
      'org.apache.beam.sdk.transforms.GroupByKeyTest$BasicTests.testAfterProcessingTimeContinuationTriggerUsingState',
      'org.apache.beam.sdk.transforms.GroupByKeyTest.testCombiningAccumulatingProcessingTime',
 
-      // TODO(https://github.com/apache/beam/issues/18592): respect ParDo lifecycle.
+      // These tests use static state and don't work with remote execution.
      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundle',
      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundleStateful',
      'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInProcessElement',
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactory.java
index 91fb640a175..d3f2aacc74d 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactory.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactory.java
@@ -105,11 +105,32 @@ public class IntrinsicMapTaskExecutorFactory implements DataflowMapTaskExecutorF
     Networks.replaceDirectedNetworkNodes(
         network, createOutputReceiversTransform(stageName, counterSet));
 
-    // Swap out all the ParallelInstruction nodes with Operation nodes
-    Networks.replaceDirectedNetworkNodes(
-        network,
-        createOperationTransformForParallelInstructionNodes(
-            stageName, network, options, readerFactory, sinkFactory, executionContext));
+    // Swap out all the ParallelInstruction nodes with Operation nodes. While updating the
+    // network, we keep track of the created Operations so that if an exception is encountered
+    // we can properly abort started operations.
+    ArrayList<Operation> createdOperations = new ArrayList<>();
+    try {
+      Networks.replaceDirectedNetworkNodes(
+          network,
+          createOperationTransformForParallelInstructionNodes(
+              stageName,
+              network,
+              options,
+              readerFactory,
+              sinkFactory,
+              executionContext,
+              createdOperations));
+    } catch (RuntimeException exn) {
+      for (Operation o : createdOperations) {
+        try {
+          o.abort();
+        } catch (Exception exn2) {
+          exn.addSuppressed(exn2);
+        }
+      }
+      throw exn;
+    }
 
    // Collect all the operations within the network and attach all the operations as receivers
     // to preceding output receivers.
@@ -144,7 +165,8 @@ public class IntrinsicMapTaskExecutorFactory implements DataflowMapTaskExecutorF
       final PipelineOptions options,
       final ReaderFactory readerFactory,
       final SinkFactory sinkFactory,
-      final DataflowExecutionContext<?> executionContext) {
+      final DataflowExecutionContext<?> executionContext,
+      final List<Operation> createdOperations) {
 
    return new TypeSafeNodeFunction<ParallelInstructionNode>(ParallelInstructionNode.class) {
       @Override
@@ -156,20 +178,22 @@ public class IntrinsicMapTaskExecutorFactory implements DataflowMapTaskExecutorF
                 instruction.getOriginalName(),
                 instruction.getSystemName(),
                 instruction.getName());
+        OperationNode result;
         try {
          DataflowOperationContext context = executionContext.createOperationContext(nameContext);
           if (instruction.getRead() != null) {
-            return createReadOperation(
-                network, node, options, readerFactory, executionContext, context);
+            result =
+                createReadOperation(
+                    network, node, options, readerFactory, executionContext, context);
           } else if (instruction.getWrite() != null) {
-            return createWriteOperation(node, options, sinkFactory, executionContext, context);
+            result = createWriteOperation(node, options, sinkFactory, executionContext, context);
           } else if (instruction.getParDo() != null) {
-            return createParDoOperation(network, node, options, executionContext, context);
+            result = createParDoOperation(network, node, options, executionContext, context);
           } else if (instruction.getPartialGroupByKey() != null) {
-            return createPartialGroupByKeyOperation(
-                network, node, options, executionContext, context);
+            result =
+                createPartialGroupByKeyOperation(network, node, options, executionContext, context);
           } else if (instruction.getFlatten() != null) {
-            return createFlattenOperation(network, node, context);
+            result = createFlattenOperation(network, node, context);
           } else {
             throw new IllegalArgumentException(
                 String.format("Unexpected instruction: %s", instruction));
@@ -177,6 +201,8 @@ public class IntrinsicMapTaskExecutorFactory implements DataflowMapTaskExecutorF
         } catch (Exception e) {
           throw new RuntimeException(e);
         }
+        createdOperations.add(result.getOperation());
+        return result;
       }
     };
   }
@@ -328,7 +354,6 @@ public class IntrinsicMapTaskExecutorFactory implements DataflowMapTaskExecutorF
        Coder<?> coder =
            CloudObjects.coderFromCloudObject(CloudObject.fromSpec(cloudOutput.getCodec()));
 
-        @SuppressWarnings("unchecked")
         ElementCounter outputCounter =
             new DataflowOutputCounter(
                 cloudOutput.getName(),
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutor.java
index 877e3198e91..58b95f286d5 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutor.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutor.java
@@ -18,8 +18,8 @@
 package org.apache.beam.runners.dataflow.worker.util.common.worker;
 
 import java.io.Closeable;
+import java.util.ArrayList;
 import java.util.List;
-import java.util.ListIterator;
 import org.apache.beam.runners.core.metrics.ExecutionStateTracker;
 import org.apache.beam.runners.dataflow.worker.counters.CounterSet;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
@@ -36,7 +36,9 @@ public class MapTaskExecutor implements WorkExecutor {
  private static final Logger LOG = LoggerFactory.getLogger(MapTaskExecutor.class);
 
   /** The operations in the map task, in execution order. */
-  public final List<Operation> operations;
+  public final ArrayList<Operation> operations;
+
+  private boolean closed = false;
 
   private final ExecutionStateTracker executionStateTracker;
 
@@ -54,7 +56,7 @@ public class MapTaskExecutor implements WorkExecutor {
       CounterSet counters,
       ExecutionStateTracker executionStateTracker) {
     this.counters = counters;
-    this.operations = operations;
+    this.operations = new ArrayList<>(operations);
     this.executionStateTracker = executionStateTracker;
   }
 
@@ -63,6 +65,7 @@ public class MapTaskExecutor implements WorkExecutor {
     return counters;
   }
 
+  /** May be reused if execute() returns without an exception being thrown. */
   @Override
   public void execute() throws Exception {
     LOG.debug("Executing map task");
@@ -74,13 +77,11 @@ public class MapTaskExecutor implements WorkExecutor {
         // Starting a root operation such as a ReadOperation does the work
         // of processing the input dataset.
         LOG.debug("Starting operations");
-        ListIterator<Operation> iterator = operations.listIterator(operations.size());
-        while (iterator.hasPrevious()) {
+        for (int i = operations.size() - 1; i >= 0; --i) {
           if (Thread.currentThread().isInterrupted()) {
             throw new InterruptedException("Worker aborted");
           }
-          Operation op = iterator.previous();
-          op.start();
+          operations.get(i).start();
         }
 
         // Finish operations, in forward-execution-order, so that a
@@ -94,16 +95,13 @@ public class MapTaskExecutor implements WorkExecutor {
           op.finish();
         }
       } catch (Exception | Error exn) {
-        LOG.debug("Aborting operations", exn);
-        for (Operation op : operations) {
-          try {
-            op.abort();
-          } catch (Exception | Error exn2) {
-            exn.addSuppressed(exn2);
-            if (exn2 instanceof InterruptedException) {
-              Thread.currentThread().interrupt();
-            }
-          }
+        try {
+          closeInternal();
+        } catch (Exception closeExn) {
+          exn.addSuppressed(closeExn);
+        }
+        if (exn instanceof InterruptedException) {
+          Thread.currentThread().interrupt();
         }
         throw exn;
       }
@@ -164,6 +162,45 @@ public class MapTaskExecutor implements WorkExecutor {
     }
   }
 
+  private void closeInternal() throws Exception {
+    if (closed) {
+      return;
+    }
+    LOG.debug("Aborting operations");
+    @Nullable Exception exn = null;
+    for (Operation op : operations) {
+      try {
+        op.abort();
+      } catch (Exception | Error exn2) {
+        if (exn2 instanceof InterruptedException) {
+          Thread.currentThread().interrupt();
+        }
+        if (exn == null) {
+          if (exn2 instanceof Exception) {
+            exn = (Exception) exn2;
+          } else {
+            exn = new RuntimeException(exn2);
+          }
+        } else {
+          exn.addSuppressed(exn2);
+        }
+      }
+    }
+    closed = true;
+    if (exn != null) {
+      throw exn;
+    }
+  }
+
+  @Override
+  public void close() {
+    try {
+      closeInternal();
+    } catch (Exception e) {
+      LOG.error("Exception while closing MapTaskExecutor, ignoring", e);
+    }
+  }
+
   @Override
   public List<Integer> reportProducedEmptyOutput() {
     List<Integer> emptyOutputSinkIndexes = Lists.newArrayList();
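
Taken together, the MapTaskExecutor changes above make close() idempotent: the first close (or a failed execute()) aborts each operation exactly once, and later calls are no-ops. A hedged usage sketch of the resulting contract (construction arguments abbreviated; the abort-once behavior is what the new tests below verify):

    // try-with-resources guarantees the operations are aborted even if
    // execute() throws; closing again afterwards will not re-abort them.
    try (MapTaskExecutor executor =
        new MapTaskExecutor(operations, counterSet, stateTracker)) {
      executor.execute();
    }
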
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java
index a4cd5d6d8a6..e61c2d1f4a0 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java
@@ -415,6 +415,7 @@ public class StreamingWorkScheduler {
 
       // Release the execution state for another thread to use.
       computationState.releaseComputationWorkExecutor(computationWorkExecutor);
+      computationWorkExecutor = null;
 
      work.setState(Work.State.COMMIT_QUEUED);
      outputBuilder.addAllPerWorkItemLatencyAttributions(work.getLatencyAttributions(sampler));
@@ -422,11 +423,13 @@ public class StreamingWorkScheduler {
       return ExecuteWorkResult.create(
          outputBuilder, stateReader.getBytesRead() + localSideInputStateFetcher.getBytesRead());
     } catch (Throwable t) {
-      // If processing failed due to a thrown exception, close the executionState. Do not
-      // return/release the executionState back to computationState as that will lead to this
-      // executionState instance being reused.
-      LOG.debug("Invalidating executor after work item {} failed", workItem.getWorkToken(), t);
-      computationWorkExecutor.invalidate();
+      if (computationWorkExecutor != null) {
+        // If processing failed due to a thrown exception, close the executionState. Do not
+        // return/release the executionState back to computationState as that will lead to this
+        // executionState instance being reused.
+        LOG.debug("Invalidating executor after work item {} failed", workItem.getWorkToken(), t);
+        computationWorkExecutor.invalidate();
+      }
 
      // Re-throw the exception, it will be caught and handled by workFailureProcessor downstream.
       throw t;
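
This StreamingWorkScheduler fix is the usual null-out-on-ownership-transfer pattern: once the executor has been released for reuse, the local reference is cleared so the catch block cannot invalidate an executor that another work item may already be using. Schematically (a simplified paraphrase of the code above, with acquireExecutor and release as stand-ins for the real calls):

    ComputationWorkExecutor executor = acquireExecutor();
    try {
      executor.executeWork(work);
      release(executor);  // ownership transfers back to the pool
      executor = null;    // we no longer own it
      return result;
    } catch (Throwable t) {
      if (executor != null) {
        executor.invalidate(); // only invalidate if we still own it
      }
      throw t;
    }
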
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactoryTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactoryTest.java
index e77ae309d35..3443ae0022b 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactoryTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactoryTest.java
@@ -24,11 +24,16 @@ import static org.apache.beam.runners.dataflow.worker.counters.CounterName.named
 import static org.apache.beam.sdk.util.SerializableUtils.serializeToByteArray;
 import static org.apache.beam.sdk.util.StringUtils.byteArrayToJsonString;
 import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.hasItems;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertSame;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 import static org.mockito.ArgumentMatchers.anyBoolean;
 import static org.mockito.ArgumentMatchers.anyLong;
 import static org.mockito.ArgumentMatchers.eq;
@@ -52,6 +57,7 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Function;
 import org.apache.beam.runners.dataflow.util.CloudObject;
 import org.apache.beam.runners.dataflow.util.CloudObjects;
@@ -254,8 +260,9 @@ public class IntrinsicMapTaskExecutorFactoryTest {
     List<ParallelInstruction> instructions =
         Arrays.asList(
             createReadInstruction("Read", 
ReaderFactoryTest.SingletonTestReaderFactory.class),
-            createParDoInstruction(0, 0, "DoFn1", "DoFnUserName"),
-            createParDoInstruction(1, 0, "DoFnWithContext", "DoFnWithContextUserName"));
+            createParDoInstruction(0, 0, "DoFn1", "DoFnUserName", new TestDoFn()),
+            createParDoInstruction(
+                1, 0, "DoFnWithContext", "DoFnWithContextUserName", new TestDoFn()));
 
     MapTask mapTask = new MapTask();
     mapTask.setStageName(STAGE);
@@ -330,6 +337,7 @@ public class IntrinsicMapTaskExecutorFactoryTest {
                             PCOLLECTION_ID))));
     when(network.outDegree(instructionNode)).thenReturn(1);
 
+    ArrayList<Operation> createdOperations = new ArrayList<>();
     Node operationNode =
         mapTaskExecutorFactory
             .createOperationTransformForParallelInstructionNodes(
@@ -338,11 +346,13 @@ public class IntrinsicMapTaskExecutorFactoryTest {
                 PipelineOptionsFactory.create(),
                 readerRegistry,
                 sinkRegistry,
-                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"))
+                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"),
+                createdOperations)
             .apply(instructionNode);
     assertThat(operationNode, instanceOf(OperationNode.class));
    assertThat(((OperationNode) operationNode).getOperation(), instanceOf(ReadOperation.class));
    ReadOperation readOperation = (ReadOperation) ((OperationNode) operationNode).getOperation();
+    assertThat(createdOperations, contains(readOperation));
 
     assertEquals(1, readOperation.receivers.length);
     assertEquals(0, readOperation.receivers[0].getReceiverCount());
@@ -391,6 +401,7 @@ public class IntrinsicMapTaskExecutorFactoryTest {
         ParallelInstructionNode.create(
            createWriteInstruction(producerIndex, producerOutputNum, "WriteOperation"),
             ExecutionLocation.UNKNOWN);
+    ArrayList<Operation> createdOperations = new ArrayList<>();
     Node operationNode =
         mapTaskExecutorFactory
             .createOperationTransformForParallelInstructionNodes(
@@ -399,11 +410,13 @@ public class IntrinsicMapTaskExecutorFactoryTest {
                 options,
                 readerRegistry,
                 sinkRegistry,
-                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"))
+                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"),
+                createdOperations)
             .apply(instructionNode);
     assertThat(operationNode, instanceOf(OperationNode.class));
    assertThat(((OperationNode) operationNode).getOperation(), instanceOf(WriteOperation.class));
    WriteOperation writeOperation = (WriteOperation) ((OperationNode) operationNode).getOperation();
+    assertThat(createdOperations, contains(writeOperation));
 
     assertEquals(0, writeOperation.receivers.length);
    assertEquals(Operation.InitializationState.UNSTARTED, writeOperation.initializationState);
@@ -461,17 +474,15 @@ public class IntrinsicMapTaskExecutorFactoryTest {
 
   static ParallelInstruction createParDoInstruction(
       int producerIndex, int producerOutputNum, String systemName) {
-    return createParDoInstruction(producerIndex, producerOutputNum, systemName, "");
+    return createParDoInstruction(producerIndex, producerOutputNum, systemName, "", new TestDoFn());
   }
 
   static ParallelInstruction createParDoInstruction(
-      int producerIndex, int producerOutputNum, String systemName, String userName) {
+      int producerIndex, int producerOutputNum, String systemName, String userName, DoFn<?, ?> fn) {
     InstructionInput cloudInput = new InstructionInput();
     cloudInput.setProducerInstructionIndex(producerIndex);
     cloudInput.setOutputNum(producerOutputNum);
 
-    TestDoFn fn = new TestDoFn();
-
     String serializedFn =
         StringUtils.byteArrayToJsonString(
             SerializableUtils.serializeToByteArray(
@@ -541,14 +552,16 @@ public class IntrinsicMapTaskExecutorFactoryTest {
                         .getMultiOutputInfos()
                         .get(0))));
 
+    ArrayList<Operation> createdOperations = new ArrayList<>();
     Node operationNode =
         mapTaskExecutorFactory
             .createOperationTransformForParallelInstructionNodes(
-                STAGE, network, options, readerRegistry, sinkRegistry, context)
+                STAGE, network, options, readerRegistry, sinkRegistry, context, createdOperations)
             .apply(instructionNode);
     assertThat(operationNode, instanceOf(OperationNode.class));
    assertThat(((OperationNode) operationNode).getOperation(), instanceOf(ParDoOperation.class));
    ParDoOperation parDoOperation = (ParDoOperation) ((OperationNode) operationNode).getOperation();
+    assertThat(createdOperations, contains(parDoOperation));
 
     assertEquals(1, parDoOperation.receivers.length);
     assertEquals(0, parDoOperation.receivers[0].getReceiverCount());
@@ -608,6 +621,7 @@ public class IntrinsicMapTaskExecutorFactoryTest {
                             PCOLLECTION_ID))));
     when(network.outDegree(instructionNode)).thenReturn(1);
 
+    ArrayList<Operation> createdOperations = new ArrayList<>();
     Node operationNode =
         mapTaskExecutorFactory
             .createOperationTransformForParallelInstructionNodes(
@@ -616,11 +630,13 @@ public class IntrinsicMapTaskExecutorFactoryTest {
                 PipelineOptionsFactory.create(),
                 readerRegistry,
                 sinkRegistry,
-                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"))
+                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"),
+                createdOperations)
             .apply(instructionNode);
     assertThat(operationNode, instanceOf(OperationNode.class));
    assertThat(((OperationNode) operationNode).getOperation(), instanceOf(ParDoOperation.class));
    ParDoOperation pgbkOperation = (ParDoOperation) ((OperationNode) operationNode).getOperation();
+    assertThat(createdOperations, contains(pgbkOperation));
 
     assertEquals(1, pgbkOperation.receivers.length);
     assertEquals(0, pgbkOperation.receivers[0].getReceiverCount());
@@ -660,6 +676,7 @@ public class IntrinsicMapTaskExecutorFactoryTest {
                             PCOLLECTION_ID))));
     when(network.outDegree(instructionNode)).thenReturn(1);
 
+    ArrayList<Operation> createdOperations = new ArrayList<>();
     Node operationNode =
         mapTaskExecutorFactory
             .createOperationTransformForParallelInstructionNodes(
@@ -668,11 +685,13 @@ public class IntrinsicMapTaskExecutorFactoryTest {
                 options,
                 readerRegistry,
                 sinkRegistry,
-                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"))
+                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"),
+                createdOperations)
             .apply(instructionNode);
     assertThat(operationNode, instanceOf(OperationNode.class));
    assertThat(((OperationNode) operationNode).getOperation(), instanceOf(ParDoOperation.class));
    ParDoOperation pgbkOperation = (ParDoOperation) ((OperationNode) operationNode).getOperation();
+    assertThat(createdOperations, contains(pgbkOperation));
 
     assertEquals(1, pgbkOperation.receivers.length);
     assertEquals(0, pgbkOperation.receivers[0].getReceiverCount());
@@ -738,6 +757,7 @@ public class IntrinsicMapTaskExecutorFactoryTest {
                             PCOLLECTION_ID))));
     when(network.outDegree(instructionNode)).thenReturn(1);
 
+    ArrayList<Operation> createdOperations = new ArrayList<>();
     Node operationNode =
         mapTaskExecutorFactory
             .createOperationTransformForParallelInstructionNodes(
@@ -746,15 +766,108 @@ public class IntrinsicMapTaskExecutorFactoryTest {
                 options,
                 readerRegistry,
                 sinkRegistry,
-                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"))
+                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"),
+                createdOperations)
             .apply(instructionNode);
     assertThat(operationNode, instanceOf(OperationNode.class));
    assertThat(((OperationNode) operationNode).getOperation(), instanceOf(FlattenOperation.class));
     FlattenOperation flattenOperation =
         (FlattenOperation) ((OperationNode) operationNode).getOperation();
+    assertThat(createdOperations, contains(flattenOperation));
 
     assertEquals(1, flattenOperation.receivers.length);
     assertEquals(0, flattenOperation.receivers[0].getReceiverCount());
    assertEquals(Operation.InitializationState.UNSTARTED, flattenOperation.initializationState);
   }
+
+  static class TestTeardownDoFn extends DoFn<String, String> {
+    static AtomicInteger setupCalls = new AtomicInteger();
+    static AtomicInteger teardownCalls = new AtomicInteger();
+
+    private final boolean throwExceptionOnSetup;
+    private boolean setupCalled = false;
+
+    TestTeardownDoFn(boolean throwExceptionOnSetup) {
+      this.throwExceptionOnSetup = throwExceptionOnSetup;
+    }
+
+    @Setup
+    public void setup() {
+      assertFalse(setupCalled);
+      setupCalled = true;
+      setupCalls.addAndGet(1);
+      if (throwExceptionOnSetup) {
+        throw new RuntimeException("Test setup exception");
+      }
+    }
+
+    @ProcessElement
+    public void process(ProcessContext c) {
+      fail("no elements should be processed");
+    }
+
+    @Teardown
+    public void teardown() {
+      assertTrue(setupCalled);
+      setupCalled = false;
+      teardownCalls.addAndGet(1);
+    }
+  }
+
+  @Test
+  public void testCreateMapTaskExecutorException() throws Exception {
+    List<ParallelInstruction> instructions =
+        Arrays.asList(
+            createReadInstruction("Read"),
+            createParDoInstruction(0, 0, "DoFn1", "DoFn1", new 
TestTeardownDoFn(false)),
+            createParDoInstruction(0, 0, "DoFn2", "DoFn2", new 
TestTeardownDoFn(false)),
+            createParDoInstruction(0, 0, "ErrorFn", "", new 
TestTeardownDoFn(true)),
+            createParDoInstruction(0, 0, "DoFn3", "DoFn3", new 
TestTeardownDoFn(false)),
+            createFlattenInstruction(1, 0, 2, 0, "Flatten"),
+            createWriteInstruction(3, 0, "Write"));
+
+    MapTask mapTask = new MapTask();
+    mapTask.setStageName(STAGE);
+    mapTask.setSystemName("systemName");
+    mapTask.setInstructions(instructions);
+    mapTask.setFactory(Transport.getJsonFactory());
+
+    assertThrows(
+        "Test setup exception",
+        RuntimeException.class,
+        () ->
+            mapTaskExecutorFactory.create(
+                mapTaskToNetwork.apply(mapTask),
+                options,
+                STAGE,
+                readerRegistry,
+                sinkRegistry,
+                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"),
+                counterSet,
+                idGenerator));
+    assertEquals(3, TestTeardownDoFn.setupCalls.getAndSet(0));
+    // We only tear down the instruction we were unable to create.  The other
+    // infos are cached within UserParDoFnFactory and not torn down.
+    assertEquals(1, TestTeardownDoFn.teardownCalls.getAndSet(0));
+
+    assertThrows(
+        "Test setup exception",
+        RuntimeException.class,
+        () ->
+            mapTaskExecutorFactory.create(
+                mapTaskToNetwork.apply(mapTask),
+                options,
+                STAGE,
+                readerRegistry,
+                sinkRegistry,
+                BatchModeExecutionContext.forTesting(options, counterSet, "testStage"),
+                counterSet,
+                idGenerator));
+    // The non-erroring functions are cached, and a new setup call is made on the
+    // erroring DoFn.
+    assertEquals(1, TestTeardownDoFn.setupCalls.get());
+    // We only tear down the instruction we were unable to create.  The other
+    // infos are cached within UserParDoFnFactory and not torn down.
+    assertEquals(1, TestTeardownDoFn.teardownCalls.get());
+  }
 }
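
The lifecycle assertions in this test file follow a common pattern: because the worker serializes and caches DoFn instances, setup/teardown calls are counted in static AtomicIntegers rather than instance fields, which would be lost across copies. A minimal sketch of the pattern (a generic counting DoFn for illustration, not one of the classes above):

    static class LifecycleCountingFn extends DoFn<String, String> {
      static final AtomicInteger setups = new AtomicInteger();
      static final AtomicInteger teardowns = new AtomicInteger();

      @Setup
      public void setup() {
        setups.incrementAndGet(); // static: visible across DoFn copies
      }

      @ProcessElement
      public void process(ProcessContext c) {
        c.output(c.element());
      }

      @Teardown
      public void teardown() {
        teardowns.incrementAndGet();
      }
    }
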
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFnTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFnTest.java
index bb92fca3d8b..9e45425562a 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFnTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFnTest.java
@@ -198,7 +198,7 @@ public class SimpleParDoFnTest {
         new TestDoFn(
             ImmutableList.of(
                 new TupleTag<>("tag1"), new TupleTag<>("tag2"), new 
TupleTag<>("tag3")));
-    DoFnInfo<?, ?> fnInfo =
+    DoFnInfo<Integer, String> fnInfo =
         DoFnInfo.forFn(
             fn,
             WindowingStrategy.globalDefault(),
@@ -279,7 +279,7 @@ public class SimpleParDoFnTest {
   @SuppressWarnings("AssertionFailureIgnored")
   public void testUnexpectedNumberOfReceivers() throws Exception {
     TestDoFn fn = new TestDoFn(Collections.emptyList());
-    DoFnInfo<?, ?> fnInfo =
+    DoFnInfo<Integer, String> fnInfo =
         DoFnInfo.forFn(
             fn,
             WindowingStrategy.globalDefault(),
@@ -330,7 +330,7 @@ public class SimpleParDoFnTest {
   @Test
   public void testErrorPropagation() throws Exception {
     TestErrorDoFn fn = new TestErrorDoFn();
-    DoFnInfo<?, ?> fnInfo =
+    DoFnInfo<Integer, String> fnInfo =
         DoFnInfo.forFn(
             fn,
             WindowingStrategy.globalDefault(),
@@ -423,7 +423,7 @@ public class SimpleParDoFnTest {
                 new TupleTag<>("undecl1"),
                 new TupleTag<>("undecl2"),
                 new TupleTag<>("undecl3")));
-    DoFnInfo<?, ?> fnInfo =
+    DoFnInfo<Integer, String> fnInfo =
         DoFnInfo.forFn(
             fn,
             WindowingStrategy.globalDefault(),
@@ -485,7 +485,7 @@ public class SimpleParDoFnTest {
     }
 
     StateTestingDoFn fn = new StateTestingDoFn();
-    DoFnInfo<?, ?> fnInfo =
+    DoFnInfo<Integer, String> fnInfo =
         DoFnInfo.forFn(
             fn,
             WindowingStrategy.globalDefault(),
@@ -578,7 +578,7 @@ public class SimpleParDoFnTest {
     }
 
     DoFn<Integer, String> fn = new RepeaterDoFn();
-    DoFnInfo<?, ?> fnInfo =
+    DoFnInfo<Integer, String> fnInfo =
         DoFnInfo.forFn(
             fn,
             WindowingStrategy.globalDefault(),
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
index e16a8b9f88c..df90bb96139 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java
@@ -3276,6 +3276,9 @@ public class StreamingDataflowWorkerTest {
 
    TestCountingSource counter = new TestCountingSource(3).withThrowOnFirstSnapshot(true);
 
+    // Reset static state that may leak across tests.
+    TestExceptionInvalidatesCacheFn.resetStaticState();
+    TestCountingSource.resetStaticState();
     List<ParallelInstruction> instructions =
         Arrays.asList(
             new ParallelInstruction()
@@ -3310,7 +3313,10 @@ public class StreamingDataflowWorkerTest {
                 .build());
     worker.start();
 
-    // Three GetData requests
+    // Three GetData requests:
+    // - first processing has no state
+    // - recovering from checkpoint exception has no persisted state
+    // - recovering from processing exception recovers last committed state
     for (int i = 0; i < 3; i++) {
       ByteString state;
       if (i == 0 || i == 1) {
@@ -3437,6 +3443,11 @@ public class StreamingDataflowWorkerTest {
                       parseCommitRequest(sb.toString()))
                   .build()));
     }
+
+    // Ensure that the invalidated DoFn had tearDown called on it.
+    assertEquals(1, TestExceptionInvalidatesCacheFn.tearDownCallCount.get());
+    assertEquals(2, TestExceptionInvalidatesCacheFn.setupCallCount.get());
+
     worker.stop();
   }
 
@@ -3484,7 +3495,7 @@ public class StreamingDataflowWorkerTest {
   }
 
   @Test
-  public void testActiveWorkFailure() throws Exception {
+  public void testQueuedWorkFailure() throws Exception {
     List<ParallelInstruction> instructions =
         Arrays.asList(
             makeSourceInstruction(StringUtf8Coder.of()),
@@ -3515,6 +3526,9 @@ public class StreamingDataflowWorkerTest {
     server.whenGetWorkCalled().thenReturn(workItem).thenReturn(workItemToFail);
     server.waitForEmptyWorkQueue();
 
+    // Wait for key to schedule, it will be blocked.
+    BlockingFn.counter().acquire(1);
+
    // Mock Windmill sending a heartbeat response failing the second work item while the first
     // is still processing.
     ComputationHeartbeatResponse.Builder failedHeartbeat =
@@ -3534,6 +3548,64 @@ public class StreamingDataflowWorkerTest {
        server.waitForAndGetCommitsWithTimeout(1, Duration.standardSeconds((5)));
     assertEquals(1, commits.size());
 
+    assertEquals(0, BlockingFn.teardownCounter.get());
+    assertEquals(1, BlockingFn.setupCounter.get());
+
+    worker.stop();
+  }
+
+  @Test
+  public void testActiveWorkFailure() throws Exception {
+    List<ParallelInstruction> instructions =
+        Arrays.asList(
+            makeSourceInstruction(StringUtf8Coder.of()),
+            makeDoFnInstruction(blockingFn, 0, StringUtf8Coder.of()),
+            makeSinkInstruction(StringUtf8Coder.of(), 0));
+
+    StreamingDataflowWorker worker =
+        makeWorker(
+            defaultWorkerParams("--activeWorkRefreshPeriodMillis=100")
+                .setInstructions(instructions)
+                .publishCounters()
+                .build());
+    worker.start();
+
+    GetWorkResponse workItemToFail =
+        makeInput(0, TimeUnit.MILLISECONDS.toMicros(0), "key", DEFAULT_SHARDING_KEY);
+    long failedWorkToken = workItemToFail.getWork(0).getWork(0).getWorkToken();
+    long failedCacheToken = workItemToFail.getWork(0).getWork(0).getCacheToken();
+    GetWorkResponse workItem =
+        makeInput(1, TimeUnit.MILLISECONDS.toMicros(0), "key", DEFAULT_SHARDING_KEY);
+
+    // Queue up the work item for the key.
+    server.whenGetWorkCalled().thenReturn(workItemToFail).thenReturn(workItem);
+    server.waitForEmptyWorkQueue();
+
+    // Wait for key to schedule, it will be blocked.
+    BlockingFn.counter().acquire(1);
+
+    // Mock Windmill sending a heartbeat response failing the first work item while it is
+    // processing.
+    ComputationHeartbeatResponse.Builder failedHeartbeat =
+        ComputationHeartbeatResponse.newBuilder();
+    failedHeartbeat
+        .setComputationId(DEFAULT_COMPUTATION_ID)
+        .addHeartbeatResponsesBuilder()
+        .setCacheToken(failedCacheToken)
+        .setWorkToken(failedWorkToken)
+        .setShardingKey(DEFAULT_SHARDING_KEY)
+        .setFailed(true);
+    server.sendFailedHeartbeats(Collections.singletonList(failedHeartbeat.build()));
+
+    // Release the blocked call, there should not be a commit and the DoFn should be invalidated.
+    BlockingFn.blocker().countDown();
+    Map<Long, Windmill.WorkItemCommitRequest> commits =
+        server.waitForAndGetCommitsWithTimeout(1, Duration.standardSeconds((5)));
+    assertEquals(1, commits.size());
+
+    assertEquals(0, BlockingFn.teardownCounter.get());
+    assertEquals(1, BlockingFn.setupCounter.get());
+
     worker.stop();
   }
 
@@ -4246,6 +4318,18 @@ public class StreamingDataflowWorkerTest {
         new AtomicReference<>(new CountDownLatch(1));
    public static AtomicReference<Semaphore> counter = new AtomicReference<>(new Semaphore(0));
     public static AtomicInteger callCounter = new AtomicInteger(0);
+    public static AtomicInteger setupCounter = new AtomicInteger(0);
+    public static AtomicInteger teardownCounter = new AtomicInteger(0);
+
+    @Setup
+    public void setup() {
+      setupCounter.incrementAndGet();
+    }
+
+    @Teardown
+    public void tearDown() {
+      teardownCounter.incrementAndGet();
+    }
 
     @ProcessElement
     public void processElement(ProcessContext c) throws InterruptedException {
@@ -4278,6 +4362,8 @@ public class StreamingDataflowWorkerTest {
             blocker.set(new CountDownLatch(1));
             counter.set(new Semaphore(0));
             callCounter.set(0);
+            setupCounter.set(0);
+            teardownCounter.set(0);
           }
         }
       };
@@ -4397,11 +4483,33 @@ public class StreamingDataflowWorkerTest {
   static class TestExceptionInvalidatesCacheFn
       extends DoFn<ValueWithRecordId<KV<Integer, Integer>>, String> {
 
-    static boolean thrown = false;
+    public static AtomicInteger setupCallCount = new AtomicInteger();
+    public static AtomicInteger tearDownCallCount = new AtomicInteger();
+    private static boolean thrown = false;
+    private boolean setupCalled = false;
+
+    static void resetStaticState() {
+      setupCallCount.set(0);
+      tearDownCallCount.set(0);
+      thrown = false;
+    }
 
     @StateId("int")
    private final StateSpec<ValueState<Integer>> counter = StateSpecs.value(VarIntCoder.of());
 
+    @Setup
+    public void setUp() {
+      assertFalse(setupCalled);
+      setupCalled = true;
+      setupCallCount.addAndGet(1);
+    }
+
+    @Teardown
+    public void tearDown() {
+      assertTrue(setupCalled);
+      tearDownCallCount.addAndGet(1);
+    }
+
     @ProcessElement
    public void processElement(ProcessContext c, @StateId("int") ValueState<Integer> state)
         throws Exception {
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/testing/TestCountingSource.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/testing/TestCountingSource.java
index 6771e9dbb71..21e4d8c55e7 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/testing/TestCountingSource.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/testing/TestCountingSource.java
@@ -65,6 +65,11 @@ public class TestCountingSource
     TestCountingSource.finalizeTracker = finalizeTracker;
   }
 
+  public static void resetStaticState() {
+    finalizeTracker = null;
+    thrown = false;
+  }
+
   public TestCountingSource(int numMessagesPerShard) {
     this(numMessagesPerShard, 0, false, false, true);
   }
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutorTest.java
index 2eeaa06eb5e..188466a5057 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutorTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutorTest.java
@@ -519,4 +519,43 @@ public class MapTaskExecutorTest {
     Mockito.verify(o2, atLeastOnce()).abortReadLoop();
     Mockito.verify(stateTracker).deactivate();
   }
+
+  @Test
+  public void testCloseAbortsOperations() throws Exception {
+    Operation o1 = Mockito.mock(Operation.class);
+    Operation o2 = Mockito.mock(Operation.class);
+    List<Operation> operations = Arrays.asList(o1, o2);
+    ExecutionStateTracker stateTracker = Mockito.spy(ExecutionStateTracker.newForTest());
+    Mockito.verifyNoMoreInteractions(stateTracker);
+    try (MapTaskExecutor executor = new MapTaskExecutor(operations, counterSet, stateTracker)) {}
+
+    Mockito.verify(o1).abort();
+    Mockito.verify(o2).abort();
+  }
+
+  @Test
+  public void testExceptionAndThenCloseAbortsJustOnce() throws Exception {
+    Operation o1 = Mockito.mock(Operation.class);
+    Operation o2 = Mockito.mock(Operation.class);
+    Mockito.doThrow(new Exception("in start")).when(o2).start();
+
+    ExecutionStateTracker stateTracker = Mockito.spy(ExecutionStateTracker.newForTest());
+    MapTaskExecutor executor = new MapTaskExecutor(Arrays.asList(o1, o2), counterSet, stateTracker);
+    try {
+      executor.execute();
+      fail("Should have thrown");
+    } catch (Exception e) {
+    }
+    InOrder inOrder = Mockito.inOrder(o2, stateTracker);
+    inOrder.verify(stateTracker).activate();
+    inOrder.verify(o2).start();
+    inOrder.verify(o2).abort();
+    inOrder.verify(stateTracker).deactivate();
+
+    // Order of o1 abort doesn't matter
+    Mockito.verify(o1).abort();
+    Mockito.verifyNoMoreInteractions(o1);
+    // Closing after already closed should not call abort again.
+    executor.close();
+  }
 }

