Repository: flink
Updated Branches:
  refs/heads/master c7ec74e45 -> cbde2c2a3


[FLINK-2339] Prevent asynchronous checkpoint calls from overtaking each other


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/cbde2c2a
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/cbde2c2a
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/cbde2c2a

Branch: refs/heads/master
Commit: cbde2c2a3d71e17990d76d603e1bb6d275c888be
Parents: c7ec74e
Author: Stephan Ewen <[email protected]>
Authored: Thu Jul 9 16:34:58 2015 +0200
Committer: Stephan Ewen <[email protected]>
Committed: Thu Jul 9 16:34:58 2015 +0200

----------------------------------------------------------------------
 .../io/network/api/TaskEventHandler.java        |   4 +-
 .../taskmanager/DispatherThreadFactory.java     |  50 ++++
 .../flink/runtime/taskmanager/MemoryLogger.java |  14 +-
 .../apache/flink/runtime/taskmanager/Task.java  |  69 +++++-
 .../runtime/taskmanager/TaskAsyncCallTest.java  | 247 +++++++++++++++++++
 5 files changed, 370 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/cbde2c2a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/TaskEventHandler.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/TaskEventHandler.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/TaskEventHandler.java
index 95fce96..ccd0feb 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/TaskEventHandler.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/TaskEventHandler.java
@@ -29,7 +29,7 @@ import org.apache.flink.runtime.util.event.EventListener;
  */
 public class TaskEventHandler {
 
-       // Listeners for each event type
+       /** Listeners for each event type */
        private final Multimap<Class<? extends TaskEvent>, 
EventListener<TaskEvent>> listeners = HashMultimap.create();
 
        public void subscribe(EventListener<TaskEvent> listener, Class<? 
extends TaskEvent> eventType) {
@@ -45,7 +45,7 @@ public class TaskEventHandler {
        }
 
        /**
-        * Publishes the task event to all subscribed event listeners..
+        * Publishes the task event to all subscribed event listeners.
         *
         * @param event The event to publish.
         */

http://git-wip-us.apache.org/repos/asf/flink/blob/cbde2c2a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/DispatherThreadFactory.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/DispatherThreadFactory.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/DispatherThreadFactory.java
new file mode 100644
index 0000000..f5f1565
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/DispatherThreadFactory.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.taskmanager;
+
+import java.util.concurrent.ThreadFactory;
+
+/**
+ * Thread factory that creates threads with a given name, associates them with a given
+ * thread group, and sets them to daemon mode.
+ */
+public class DispatherThreadFactory implements ThreadFactory {
+       
+       private final ThreadGroup group;
+       
+       private final String threadName;
+       
+       /**
+        * Creates a new thread factory.
+        * 
+        * @param group The group that the threads will be associated with.
+        * @param threadName The name for the threads.
+        */
+       public DispatherThreadFactory(ThreadGroup group, String threadName) {
+               this.group = group;
+               this.threadName = threadName;
+       }
+
+       @Override
+       public Thread newThread(Runnable r) {
+               Thread t = new Thread(group, r, threadName);
+               t.setDaemon(true);
+               return t;
+       }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/cbde2c2a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/MemoryLogger.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/MemoryLogger.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/MemoryLogger.java
index 5c821e9..9258482 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/MemoryLogger.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/MemoryLogger.java
@@ -53,12 +53,16 @@ public class MemoryLogger extends Thread {
        private final ActorSystem monitored;
        
        private volatile boolean running = true;
-
        
-       public MemoryLogger(Logger logger, long interval) {
-               this(logger, interval, null);
-       }
-               
+       /**
+        * Creates a new memory logger that logs at the given interval and lives as long as the
+        * given actor system.
+        * 
+        * @param logger The logger to use for outputting the memory statistics.
+        * @param interval The interval at which the thread logs.
+        * @param monitored The actor system whose lifetime the thread is bound to. The thread terminates
+        *                  once the actor system terminates.
+        */
        public MemoryLogger(Logger logger, long interval, ActorSystem 
monitored) {
                super("Memory Logger");
                setDaemon(true);

http://git-wip-us.apache.org/repos/asf/flink/blob/cbde2c2a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java 
b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
index 25ad28d..d9168e3 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
@@ -20,6 +20,7 @@ package org.apache.flink.runtime.taskmanager;
 
 import akka.actor.ActorRef;
 import akka.util.Timeout;
+
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.cache.DistributedCache;
 import org.apache.flink.configuration.Configuration;
@@ -64,7 +65,10 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CopyOnWriteArrayList;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
+import java.util.concurrent.RejectedExecutionException;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
 
@@ -188,6 +192,7 @@ public class Task implements Runnable {
        //  proper happens-before semantics on parallel modification
        // 
------------------------------------------------------------------------
 
+       /** Atomic flag that makes sure the invokable is canceled exactly once upon error */
        private final AtomicBoolean invokableHasBeenCanceled;
        
        /** The invokable of this task, if initialized */
@@ -199,6 +204,9 @@ public class Task implements Runnable {
        /** The observed exception, in case the task execution failed */
        private volatile Throwable failureCause;
 
+       /** Serial executor for asynchronous calls (checkpoints, etc.), lazily initialized */
+       private volatile ExecutorService asyncCallDispatcher;
+       
        /** The handle to the state that the operator was initialized with. 
Will be set to null after the
         * initialization, to be memory friendly */
        private volatile SerializedValue<StateHandle<?>> operatorState;
@@ -290,11 +298,11 @@ public class Task implements Runnable {
                        this.inputGates[i] = gate;
                        inputGatesById.put(gate.getConsumedResultId(), gate);
                }
+
+               invokableHasBeenCanceled = new AtomicBoolean(false);
                
                // finally, create the executing thread, but do not start it
                executingThread = new Thread(TASK_THREADS_GROUP, this, 
taskNameWithSubtask);
-               
-               invokableHasBeenCanceled = new AtomicBoolean(false);
        }
 
        // 
------------------------------------------------------------------------
@@ -646,9 +654,17 @@ public class Task implements Runnable {
                        try {
                                LOG.info("Freeing task resources for " + 
taskNameWithSubtask);
                                
+                               // stop the async dispatcher.
+                               // copy the dispatcher reference to the stack, to guard against a concurrent release
+                               ExecutorService dispatcher = 
this.asyncCallDispatcher;
+                               if (dispatcher != null && 
!dispatcher.isShutdown()) {
+                                       dispatcher.shutdownNow();
+                               }
+                               
                                // free the network resources
                                network.unregisterTask(this);
 
+                               // free memory resources
                                if (invokable != null) {
                                        memoryManager.releaseAll(invokable);
                                }
@@ -797,6 +813,7 @@ public class Task implements Runnable {
                                                Runnable canceler = new 
TaskCanceler(LOG, invokable, executingThread, taskNameWithSubtask);
                                                Thread cancelThread = new 
Thread(executingThread.getThreadGroup(), canceler,
                                                                "Canceler for " 
+ taskNameWithSubtask);
+                                               cancelThread.setDaemon(true);
                                                cancelThread.start();
                                        }
                                        return;
@@ -955,11 +972,49 @@ public class Task implements Runnable {
                        LOG.debug("Ignoring partition state notification for 
not running task.");
                }
        }
-       
+
+       /**
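+        * Dispatches an asynchronous call on the invokable via a single-threaded executor, so calls run one at a time in submission order; the call is dropped if the task is canceled or failed.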
+        * 
+        * @param runnable The async call runnable.
+        * @param callName The name of the call, for logging purposes.
+        */
        private void executeAsyncCallRunnable(Runnable runnable, String 
callName) {
-               Thread thread = new Thread(runnable, callName);
-               thread.setDaemon(true);
-               thread.start();
+               // make sure the executor is initialized. lock against concurrent calls to this function
+               synchronized (this) {
+                       if (isCanceledOrFailed()) {
+                               return;
+                       }
+                       
+                       // get ourselves a reference on the stack that cannot be concurrently modified
+                       ExecutorService executor = this.asyncCallDispatcher;
+                       if (executor == null) {
+                               // first time use, initialize
+                               executor = Executors.newSingleThreadExecutor(
+                                               new 
DispatherThreadFactory(TASK_THREADS_GROUP, "Async calls on " + 
taskNameWithSubtask));
+                               this.asyncCallDispatcher = executor;
+                               
+                               // double-check the execution state, and make sure we clean up after ourselves
+                               // if we created the dispatcher while the task was concurrently canceled
+                               if (isCanceledOrFailed()) {
+                                       executor.shutdown();
+                                       asyncCallDispatcher = null;
+                                       return;
+                               }
+                       }
+
+                       LOG.debug("Invoking async call {} on task {}", 
callName, taskNameWithSubtask);
+
+                       try {
+                               executor.submit(runnable);
+                       }
+                       catch (RejectedExecutionException e) {
+                               // it may be that we are concurrently canceled. if not, report that something is fishy
+                               if (!isCanceledOrFailed()) {
+                                       throw new RuntimeException("Async call 
was rejected, even though the task was not canceled.", e);
+                               }
+                       }
+               }
        }
 
        // 
------------------------------------------------------------------------
@@ -1051,7 +1106,7 @@ public class Task implements Runnable {
 
                                        executer.interrupt();
                                        try {
-                                               executer.join(5000);
+                                               executer.join(10000);
                                        }
                                        catch (InterruptedException e) {
                                                // we can ignore this

http://git-wip-us.apache.org/repos/asf/flink/blob/cbde2c2a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
new file mode 100644
index 0000000..618c01f
--- /dev/null
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
@@ -0,0 +1,247 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.taskmanager;
+
+import akka.actor.ActorSystem;
+import akka.actor.Props;
+import akka.actor.UntypedActor;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.akka.AkkaUtils;
+import org.apache.flink.runtime.blob.BlobKey;
+import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
+import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
+import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
+import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor;
+import org.apache.flink.runtime.execution.ExecutionState;
+import org.apache.flink.runtime.execution.librarycache.LibraryCacheManager;
+import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.filecache.FileCache;
+import org.apache.flink.runtime.io.disk.iomanager.IOManager;
+import org.apache.flink.runtime.io.network.NetworkEnvironment;
+import 
org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
+import org.apache.flink.runtime.jobgraph.tasks.CheckpointCommittingOperator;
+import org.apache.flink.runtime.jobgraph.tasks.CheckpointedOperator;
+import org.apache.flink.runtime.memorymanager.MemoryManager;
+
+import org.apache.flink.runtime.state.StateHandle;
+import org.apache.flink.runtime.util.SerializedValue;
+import org.junit.AfterClass;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import scala.concurrent.duration.FiniteDuration;
+
+import java.util.Collections;
+import java.util.concurrent.TimeUnit;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.fail;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class TaskAsyncCallTest {
+
+       private static final int NUM_CALLS = 1000;
+       
+       private static ActorSystem actorSystem;
+       
+       private static OneShotLatch awaitLatch;
+       private static OneShotLatch triggerLatch;
+
+       // 
------------------------------------------------------------------------
+       //  Init & Shutdown
+       // 
------------------------------------------------------------------------
+
+       @BeforeClass
+       public static void startActorSystem() {
+               actorSystem = AkkaUtils.createLocalActorSystem(new 
Configuration());
+       }
+
+       @AfterClass
+       public static void shutdown() {
+               actorSystem.shutdown();
+               actorSystem.awaitTermination();
+       }
+
+       @Before
+       public void createQueuesAndActors() {
+               awaitLatch = new OneShotLatch();
+               triggerLatch = new OneShotLatch();
+       }
+
+
+       // 
------------------------------------------------------------------------
+       //  Tests 
+       // 
------------------------------------------------------------------------
+       
+       @Test
+       public void testCheckpointCallsInOrder() {
+               try {
+                       Task task = createTask();
+                       task.startTaskThread();
+                       
+                       awaitLatch.await();
+                       
+                       for (int i = 1; i <= NUM_CALLS; i++) {
+                               task.triggerCheckpointBarrier(i, 156865867234L);
+                       }
+                       
+                       triggerLatch.await();
+                       
+                       assertFalse(task.isCanceledOrFailed());
+                       assertEquals(ExecutionState.RUNNING, 
task.getExecutionState());
+                       
+                       task.cancelExecution();
+                       task.getExecutingThread().join();
+               }
+               catch (Exception e) {
+                       e.printStackTrace();
+                       fail(e.getMessage());
+               }
+       }
+
+       @Test
+       public void testMixedAsyncCallsInOrder() {
+               try {
+                       Task task = createTask();
+                       task.startTaskThread();
+
+                       awaitLatch.await();
+
+                       for (int i = 1; i <= NUM_CALLS; i++) {
+                               task.triggerCheckpointBarrier(i, 156865867234L);
+                               task.confirmCheckpoint(i, null);
+                       }
+
+                       triggerLatch.await();
+
+                       assertFalse(task.isCanceledOrFailed());
+                       assertEquals(ExecutionState.RUNNING, 
task.getExecutionState());
+
+                       task.cancelExecution();
+                       task.getExecutingThread().join();
+               }
+               catch (Exception e) {
+                       e.printStackTrace();
+                       fail(e.getMessage());
+               }
+       }
+       
+       private static Task createTask() {
+               
+               LibraryCacheManager libCache = mock(LibraryCacheManager.class);
+               
when(libCache.getClassLoader(any(JobID.class))).thenReturn(ClassLoader.getSystemClassLoader());
+               
+               ResultPartitionManager partitionManager = 
mock(ResultPartitionManager.class);
+               ResultPartitionConsumableNotifier consumableNotifier = 
mock(ResultPartitionConsumableNotifier.class);
+               NetworkEnvironment networkEnvironment = 
mock(NetworkEnvironment.class);
+               
when(networkEnvironment.getPartitionManager()).thenReturn(partitionManager);
+               
when(networkEnvironment.getPartitionConsumableNotifier()).thenReturn(consumableNotifier);
+               
when(networkEnvironment.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC);
+
+               TaskDeploymentDescriptor tdd = new TaskDeploymentDescriptor(
+                               new JobID(), new JobVertexID(), new 
ExecutionAttemptID(),
+                               "Test Task", 0, 1,
+                               new Configuration(), new Configuration(),
+                               CheckpointsInOrderInvokable.class.getName(),
+                               
Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
+                               
Collections.<InputGateDeploymentDescriptor>emptyList(),
+                               Collections.<BlobKey>emptyList(),
+                               0);
+
+               return new Task(tdd,
+                               mock(MemoryManager.class),
+                               mock(IOManager.class),
+                               networkEnvironment,
+                               mock(BroadcastVariableManager.class),
+                               
actorSystem.actorOf(Props.create(BlackHoleActor.class)),
+                               
actorSystem.actorOf(Props.create(BlackHoleActor.class)),
+                               new FiniteDuration(60, TimeUnit.SECONDS),
+                               libCache,
+                               mock(FileCache.class));
+       }
+       
+       public static class CheckpointsInOrderInvokable extends 
AbstractInvokable
+                       implements CheckpointedOperator, 
CheckpointCommittingOperator {
+
+               private volatile long lastCheckpointId = 0;
+               
+               private volatile Exception error;
+               
+               @Override
+               public void registerInputOutput() {}
+
+               @Override
+               public void invoke() throws Exception {
+                       awaitLatch.trigger();
+                       
+                       // block until an ordering error is recorded, all calls have been seen, or the task is canceled
+                       synchronized (this) {
+                               while (error == null && lastCheckpointId < 
NUM_CALLS) {
+                                       wait();
+                               }
+                       }
+                       
+                       triggerLatch.trigger();
+                       if (error != null) {
+                               throw error;
+                       }
+               }
+
+               @Override
+               public void triggerCheckpoint(long checkpointId, long 
timestamp) throws Exception {
+                       lastCheckpointId++;
+                       if (checkpointId == lastCheckpointId) {
+                               if (lastCheckpointId == NUM_CALLS) {
+                                       triggerLatch.trigger();
+                               }
+                       }
+                       else if (this.error == null) {
+                               this.error = new Exception("calls out of 
order");
+                               synchronized (this) {
+                                       notifyAll();
+                               }
+                       }
+               }
+
+               @Override
+               public void confirmCheckpoint(long checkpointId, 
SerializedValue<StateHandle<?>> state) throws Exception {
+                       if (checkpointId != lastCheckpointId && this.error == 
null) {
+                               this.error = new Exception("calls out of 
order");
+                               synchronized (this) {
+                                       notifyAll();
+                               }
+                       }
+               }
+       }
+       
+       public static class BlackHoleActor extends UntypedActor {
+
+               @Override
+               public void onReceive(Object message) {}
+       }
+}

Reply via email to