This is an automated email from the ASF dual-hosted git repository.

trohrmann pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 4b39e4ad43874cb5caedc6ab16240281445134f2
Author: Till Rohrmann <trohrm...@apache.org>
AuthorDate: Tue Oct 19 17:10:34 2021 +0200

    [FLINK-25817] Let TaskLocalStateStoreImpl persist TaskStateSnapshots
    
    This commit lets the TaskLocalStateStoreImpl persist the TaskStateSnapshots into the
    directory of the local state checkpoint. This makes it possible to recover the
    TaskStateSnapshots in case of a process crash. If a TaskStateSnapshot cannot be read,
    the whole local checkpoint directory is deleted so that no corrupted files are kept.
---
 .../state/TaskExecutorLocalStateStoresManager.java |   6 +-
 .../runtime/state/TaskLocalStateStoreImpl.java     | 115 ++++++++++++++++++---
 .../runtime/state/TaskLocalStateStoreImplTest.java | 106 +++++++++++++++----
 3 files changed, 191 insertions(+), 36 deletions(-)

diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskExecutorLocalStateStoresManager.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskExecutorLocalStateStoresManager.java
index 987f600..80367e0 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskExecutorLocalStateStoresManager.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskExecutorLocalStateStoresManager.java
@@ -78,6 +78,11 @@ public class TaskExecutorLocalStateStoresManager {
             @Nonnull Executor discardExecutor)
             throws IOException {
 
+        LOG.debug(
+                "Start {} with local state root directories {}.",
+                getClass().getSimpleName(),
+                localStateRootDirectories);
+
         this.taskStateStoresByAllocationID = new HashMap<>();
         this.localRecoveryEnabled = localRecoveryEnabled;
         this.localStateRootDirectories = localStateRootDirectories;
@@ -193,7 +198,6 @@ public class TaskExecutorLocalStateStoresManager {
     }
 
     public void releaseLocalStateForAllocationId(@Nonnull AllocationID 
allocationID) {
-
         if (LOG.isDebugEnabled()) {
             LOG.debug("Releasing local state under allocation id {}.", 
allocationID);
         }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskLocalStateStoreImpl.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskLocalStateStoreImpl.java
index af4695f..87f4e0c 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskLocalStateStoreImpl.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskLocalStateStoreImpl.java
@@ -25,6 +25,8 @@ import org.apache.flink.core.fs.Path;
 import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.util.ExceptionUtils;
+import org.apache.flink.util.FlinkRuntimeException;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -35,7 +37,11 @@ import javax.annotation.Nullable;
 import javax.annotation.concurrent.GuardedBy;
 
 import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
 import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
 import java.util.AbstractMap;
 import java.util.ArrayList;
 import java.util.Collection;
@@ -60,6 +66,8 @@ public class TaskLocalStateStoreImpl implements 
OwnedTaskLocalStateStore {
     /** Dummy value to use instead of null to satisfy {@link 
ConcurrentHashMap}. */
     @VisibleForTesting static final TaskStateSnapshot NULL_DUMMY = new 
TaskStateSnapshot(0, false);
 
+    public static final String TASK_STATE_SNAPSHOT_FILENAME = 
"_task_state_snapshot";
+
     /** JobID from the owning subtask. */
     @Nonnull private final JobID jobID;
 
@@ -165,6 +173,7 @@ public class TaskLocalStateStoreImpl implements 
OwnedTaskLocalStateStore {
             } else {
                 TaskStateSnapshot previous =
                         storedTaskStateByCheckpointID.put(checkpointId, 
localState);
+                persistLocalStateMetadata(checkpointId, localState);
 
                 if (previous != null) {
                     toDiscard = new AbstractMap.SimpleEntry<>(checkpointId, 
previous);
@@ -177,6 +186,45 @@ public class TaskLocalStateStoreImpl implements 
OwnedTaskLocalStateStore {
         }
     }
 
+    /**
+     * Writes a task state snapshot file that contains the serialized content of the local state.
+     *
+     * <p>The snapshot is written with Java serialization into the file returned by {@link
+     * #getTaskStateSnapshotFile(long)}, which lives in the subtask-specific checkpoint directory.
+     * Failures are rethrown via {@link ExceptionUtils#rethrow(Throwable, String)} with a message
+     * that names the file and the checkpoint, so that the failing location can be identified.
+     *
+     * @param checkpointId identifying the checkpoint
+     * @param localState task state snapshot that will be persisted
+     */
+    private void persistLocalStateMetadata(long checkpointId, TaskStateSnapshot localState) {
+        final File taskStateSnapshotFile = getTaskStateSnapshotFile(checkpointId);
+        try (ObjectOutputStream oos =
+                new ObjectOutputStream(new FileOutputStream(taskStateSnapshotFile))) {
+            oos.writeObject(localState);
+
+            LOG.debug(
+                    "Successfully written local task state snapshot file {} for checkpoint {}.",
+                    taskStateSnapshotFile,
+                    checkpointId);
+        } catch (IOException e) {
+            ExceptionUtils.rethrow(
+                    e,
+                    String.format(
+                            "Could not write the local task state snapshot file '%s' for checkpoint %d.",
+                            taskStateSnapshotFile, checkpointId));
+        }
+    }
+
+    /**
+     * Returns the task state snapshot file of the given checkpoint.
+     *
+     * <p>Note: as a side effect this creates the subtask-specific checkpoint directory if it does
+     * not exist yet.
+     *
+     * @param checkpointId identifying the checkpoint
+     * @return the {@link #TASK_STATE_SNAPSHOT_FILENAME} file inside the subtask-specific
+     *     checkpoint directory
+     * @throws IllegalStateException if no local state directory provider is configured
+     * @throws FlinkRuntimeException if the checkpoint directory cannot be created, e.g. because a
+     *     regular file already occupies its path
+     */
+    @VisibleForTesting
+    File getTaskStateSnapshotFile(long checkpointId) {
+        final File checkpointDirectory =
+                localRecoveryConfig
+                        .getLocalStateDirectoryProvider()
+                        .orElseThrow(
+                                () -> new IllegalStateException("Local recovery must be enabled."))
+                        .subtaskSpecificCheckpointDirectory(checkpointId);
+
+        // mkdirs() returns false both on failure and when the directory already exists, so the
+        // outcome is verified with isDirectory(). Unlike an exists() pre-check, this rejects a
+        // regular file at that path and tolerates the directory appearing in between.
+        if (!checkpointDirectory.mkdirs() && !checkpointDirectory.isDirectory()) {
+            throw new FlinkRuntimeException(
+                    String.format(
+                            "Could not create the checkpoint directory '%s'", checkpointDirectory));
+        }
+
+        return new File(checkpointDirectory, TASK_STATE_SNAPSHOT_FILENAME);
+    }
+
     @Override
     @Nullable
     public TaskStateSnapshot retrieveLocalState(long checkpointID) {
@@ -184,7 +232,7 @@ public class TaskLocalStateStoreImpl implements 
OwnedTaskLocalStateStore {
         TaskStateSnapshot snapshot;
 
         synchronized (lock) {
-            snapshot = storedTaskStateByCheckpointID.get(checkpointID);
+            snapshot = loadTaskStateSnapshot(checkpointID);
         }
 
         if (snapshot != null) {
@@ -216,6 +264,42 @@ public class TaskLocalStateStoreImpl implements 
OwnedTaskLocalStateStore {
         return (snapshot != NULL_DUMMY) ? snapshot : null;
     }
 
+    /**
+     * Returns the snapshot of the given checkpoint, consulting the in-memory map first and the
+     * persisted task state snapshot file second. A snapshot read from disk is added to the map.
+     *
+     * @param checkpointID identifying the checkpoint
+     * @return the snapshot, or {@code null} if neither the map nor the disk has one
+     */
+    @GuardedBy("lock")
+    @Nullable
+    private TaskStateSnapshot loadTaskStateSnapshot(long checkpointID) {
+        final TaskStateSnapshot cached = storedTaskStateByCheckpointID.get(checkpointID);
+        if (cached != null) {
+            return cached;
+        }
+
+        final TaskStateSnapshot fromDisk = tryLoadTaskStateSnapshotFromDisk(checkpointID);
+        if (fromDisk != null) {
+            storedTaskStateByCheckpointID.put(checkpointID, fromDisk);
+        }
+        return fromDisk;
+    }
+
+    /**
+     * Tries to read the persisted task state snapshot of the given checkpoint from disk.
+     *
+     * <p>The file location is resolved without creating any directories, so looking up a
+     * checkpoint for which nothing was persisted leaves the file system untouched. If the file
+     * exists but cannot be read back, the whole local state of the checkpoint is discarded so
+     * that no corrupted files are kept.
+     *
+     * @param checkpointID identifying the checkpoint
+     * @return the deserialized snapshot, or {@code null} if there is no snapshot file or it could
+     *     not be read
+     * @throws IllegalStateException if no local state directory provider is configured
+     */
+    @GuardedBy("lock")
+    @Nullable
+    private TaskStateSnapshot tryLoadTaskStateSnapshotFromDisk(long checkpointID) {
+        // Deliberately not using getTaskStateSnapshotFile(): it creates the checkpoint directory
+        // as a side effect, which a pure lookup must not do.
+        final File checkpointDirectory =
+                localRecoveryConfig
+                        .getLocalStateDirectoryProvider()
+                        .orElseThrow(
+                                () -> new IllegalStateException("Local recovery must be enabled."))
+                        .subtaskSpecificCheckpointDirectory(checkpointID);
+        final File taskStateSnapshotFile =
+                new File(checkpointDirectory, TASK_STATE_SNAPSHOT_FILENAME);
+
+        if (!taskStateSnapshotFile.exists()) {
+            return null;
+        }
+
+        try (ObjectInputStream ois =
+                new ObjectInputStream(new FileInputStream(taskStateSnapshotFile))) {
+            final TaskStateSnapshot taskStateSnapshot = (TaskStateSnapshot) ois.readObject();
+
+            LOG.debug(
+                    "Loaded task state snapshot for checkpoint {} successfully from disk.",
+                    checkpointID);
+
+            return taskStateSnapshot;
+        } catch (IOException | ClassNotFoundException e) {
+            // Deleting local state is noteworthy, so log it visibly and keep the cause.
+            LOG.warn(
+                    "Could not read task state snapshot file {} for checkpoint {}. Deleting the corresponding local state.",
+                    taskStateSnapshotFile,
+                    checkpointID,
+                    e);
+
+            discardLocalStateForCheckpoint(checkpointID, Optional.empty());
+
+            // Never hand out a snapshot whose backing files were just discarded.
+            return null;
+        }
+    }
+
     @Override
     @Nonnull
     public LocalRecoveryConfig getLocalRecoveryConfig() {
@@ -307,14 +391,14 @@ public class TaskLocalStateStoreImpl implements 
OwnedTaskLocalStateStore {
     private void syncDiscardLocalStateForCollection(
             Collection<Map.Entry<Long, TaskStateSnapshot>> toDiscard) {
         for (Map.Entry<Long, TaskStateSnapshot> entry : toDiscard) {
-            discardLocalStateForCheckpoint(entry.getKey(), entry.getValue());
+            discardLocalStateForCheckpoint(entry.getKey(), 
Optional.of(entry.getValue()));
         }
     }
 
     /**
      * Helper method that discards state objects with an executor and reports 
exceptions to the log.
      */
-    private void discardLocalStateForCheckpoint(long checkpointID, 
TaskStateSnapshot o) {
+    private void discardLocalStateForCheckpoint(long checkpointID, 
Optional<TaskStateSnapshot> o) {
 
         if (LOG.isTraceEnabled()) {
             LOG.trace(
@@ -333,17 +417,20 @@ public class TaskLocalStateStoreImpl implements 
OwnedTaskLocalStateStore {
                     subtaskIndex);
         }
 
-        try {
-            o.discardState();
-        } catch (Exception discardEx) {
-            LOG.warn(
-                    "Exception while discarding local task state snapshot of 
checkpoint {} in subtask ({} - {} - {}).",
-                    checkpointID,
-                    jobID,
-                    jobVertexID,
-                    subtaskIndex,
-                    discardEx);
-        }
+        o.ifPresent(
+                taskStateSnapshot -> {
+                    try {
+                        taskStateSnapshot.discardState();
+                    } catch (Exception discardEx) {
+                        LOG.warn(
+                                "Exception while discarding local task state 
snapshot of checkpoint {} in subtask ({} - {} - {}).",
+                                checkpointID,
+                                jobID,
+                                jobVertexID,
+                                subtaskIndex,
+                                discardEx);
+                    }
+                });
 
         Optional<LocalRecoveryDirectoryProvider> directoryProviderOptional =
                 localRecoveryConfig.getLocalStateDirectoryProvider();
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/TaskLocalStateStoreImplTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/TaskLocalStateStoreImplTest.java
index b71d575..8c906b5 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/TaskLocalStateStoreImplTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/TaskLocalStateStoreImplTest.java
@@ -33,53 +33,67 @@ import org.junit.Before;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
 
+import javax.annotation.Nonnull;
+
 import java.io.File;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.StandardOpenOption;
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
-import java.util.SortedMap;
-import java.util.TreeMap;
+import java.util.Map;
 
+import static org.assertj.core.api.Assertions.assertThat;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 
 /** Test for the {@link TaskLocalStateStoreImpl}. */
 public class TaskLocalStateStoreImplTest extends TestLogger {
 
-    private SortedMap<Long, TaskStateSnapshot> internalSnapshotMap;
-    private Object internalLock;
     private TemporaryFolder temporaryFolder;
     private File[] allocationBaseDirs;
     private TaskLocalStateStoreImpl taskLocalStateStore;
+    private JobID jobID;
+    private AllocationID allocationID;
+    private JobVertexID jobVertexID;
+    private int subtaskIdx;
 
     @Before
     public void before() throws Exception {
-        JobID jobID = new JobID();
-        AllocationID allocationID = new AllocationID();
-        JobVertexID jobVertexID = new JobVertexID();
-        int subtaskIdx = 0;
+        jobID = new JobID();
+        allocationID = new AllocationID();
+        jobVertexID = new JobVertexID();
+        subtaskIdx = 0;
         this.temporaryFolder = new TemporaryFolder();
         this.temporaryFolder.create();
         this.allocationBaseDirs =
                 new File[] {temporaryFolder.newFolder(), 
temporaryFolder.newFolder()};
-        this.internalSnapshotMap = new TreeMap<>();
-        this.internalLock = new Object();
 
+        this.taskLocalStateStore =
+                createTaskLocalStateStoreImpl(
+                        allocationBaseDirs, jobID, allocationID, jobVertexID, 
subtaskIdx);
+    }
+
+    @Nonnull
+    private TaskLocalStateStoreImpl createTaskLocalStateStoreImpl(
+            File[] allocationBaseDirs,
+            JobID jobID,
+            AllocationID allocationID,
+            JobVertexID jobVertexID,
+            int subtaskIdx) {
         LocalRecoveryDirectoryProviderImpl directoryProvider =
                 new LocalRecoveryDirectoryProviderImpl(
                         allocationBaseDirs, jobID, jobVertexID, subtaskIdx);
 
         LocalRecoveryConfig localRecoveryConfig = new 
LocalRecoveryConfig(directoryProvider);
-
-        this.taskLocalStateStore =
-                new TaskLocalStateStoreImpl(
-                        jobID,
-                        allocationID,
-                        jobVertexID,
-                        subtaskIdx,
-                        localRecoveryConfig,
-                        Executors.directExecutor(),
-                        internalSnapshotMap,
-                        internalLock);
+        return new TaskLocalStateStoreImpl(
+                jobID,
+                allocationID,
+                jobVertexID,
+                subtaskIdx,
+                localRecoveryConfig,
+                Executors.directExecutor());
     }
 
     @After
@@ -180,6 +194,56 @@ public class TaskLocalStateStoreImplTest extends 
TestLogger {
         checkPrunedAndDiscarded(taskStateSnapshots, 0, chkCount);
     }
 
+    /** A checkpoint that was never stored must not be found in memory or on disk. */
+    @Test
+    public void retrieveNullIfNoPersistedLocalState() {
+        final long unknownCheckpointId = 0L;
+
+        final TaskStateSnapshot retrieved =
+                taskLocalStateStore.retrieveLocalState(unknownCheckpointId);
+
+        assertThat(retrieved).isNull();
+    }
+
+    /** A new store for the same subtask must recover the snapshot from the persisted file. */
+    @Test
+    public void retrievePersistedLocalStateFromDisc() {
+        final TaskStateSnapshot taskStateSnapshot = createTaskStateSnapshot();
+        final long checkpointId = 0L;
+        taskLocalStateStore.storeLocalState(checkpointId, taskStateSnapshot);
+
+        // The freshly created store never saw the storeLocalState call above, so it can only
+        // find the snapshot by reading it from disk. Use the same subtaskIdx as the original
+        // store, otherwise the lookup would target a different directory.
+        final TaskLocalStateStoreImpl newTaskLocalStateStore =
+                createTaskLocalStateStoreImpl(
+                        allocationBaseDirs, jobID, allocationID, jobVertexID, subtaskIdx);
+
+        final TaskStateSnapshot retrievedTaskStateSnapshot =
+                newTaskLocalStateStore.retrieveLocalState(checkpointId);
+
+        assertThat(retrievedTaskStateSnapshot).isEqualTo(taskStateSnapshot);
+    }
+
+    /** Creates a snapshot holding empty subtask states for two freshly generated operator IDs. */
+    @Nonnull
+    private TaskStateSnapshot createTaskStateSnapshot() {
+        final Map<OperatorID, OperatorSubtaskState> subtaskStatesByOperator = new HashMap<>();
+        for (int i = 0; i < 2; i++) {
+            subtaskStatesByOperator.put(new OperatorID(), OperatorSubtaskState.builder().build());
+        }
+        return new TaskStateSnapshot(subtaskStatesByOperator);
+    }
+
+    /** An unreadable snapshot file must lead to the removal of the checkpoint's local state. */
+    @Test
+    public void deletesLocalStateIfRetrievalFails() throws IOException {
+        final TaskStateSnapshot taskStateSnapshot = createTaskStateSnapshot();
+        final long checkpointId = 0L;
+        taskLocalStateStore.storeLocalState(checkpointId, taskStateSnapshot);
+
+        final File taskStateSnapshotFile =
+                taskLocalStateStore.getTaskStateSnapshotFile(checkpointId);
+        assertThat(taskStateSnapshotFile).exists();
+
+        // Replace the whole file content with garbage. Without TRUNCATE_EXISTING only the first
+        // bytes would be overwritten and the rest of the serialized snapshot would stay behind.
+        Files.write(
+                taskStateSnapshotFile.toPath(),
+                new byte[] {1, 2, 3, 4},
+                StandardOpenOption.WRITE,
+                StandardOpenOption.TRUNCATE_EXISTING);
+
+        final TaskLocalStateStoreImpl newTaskLocalStateStore =
+                createTaskLocalStateStoreImpl(
+                        allocationBaseDirs, jobID, allocationID, jobVertexID, subtaskIdx);
+
+        assertThat(newTaskLocalStateStore.retrieveLocalState(checkpointId)).isNull();
+        assertThat(taskStateSnapshotFile.getParentFile()).doesNotExist();
+    }
+
     private void checkStoredAsExpected(List<TestingTaskStateSnapshot> history, 
int start, int end) {
         for (int i = start; i < end; ++i) {
             TestingTaskStateSnapshot expected = history.get(i);

Reply via email to