This is an automated email from the ASF dual-hosted git repository.

roman pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit c5430e2e5d4eeb0aba14ce3ea8401747afe0182d
Author: Roman Khachatryan <khachatryan.ro...@gmail.com>
AuthorDate: Thu Apr 7 20:56:56 2022 +0200

    [FLINK-25511][state/changelog] Discard pre-emptively uploaded state changes 
not included into any checkpoint
---
 .../fs_state_changelog_configuration.html          |   6 +
 .../changelog/fs/FsStateChangelogOptions.java      |   7 +
 .../changelog/fs/FsStateChangelogStorage.java      |  40 +-
 .../flink/changelog/fs/FsStateChangelogWriter.java |  62 ++-
 .../flink/changelog/fs/StateChangeFsUploader.java  |  40 +-
 .../apache/flink/changelog/fs/StateChangeSet.java  |   4 +-
 .../changelog/fs/StateChangeUploadScheduler.java   |  21 +-
 .../flink/changelog/fs/StateChangeUploader.java    |  11 +-
 .../flink/changelog/fs/TaskChangelogRegistry.java  |  82 ++++
 .../changelog/fs/TaskChangelogRegistryImpl.java    |  91 +++++
 .../apache/flink/changelog/fs/UploadResult.java    |   5 +-
 .../changelog/fs/ChangelogStorageMetricsTest.java  |  49 ++-
 .../changelog/fs/FsStateChangelogStorageTest.java  |   8 +-
 .../fs/FsStateChangelogWriterSqnTest.java          |   3 +-
 .../changelog/fs/FsStateChangelogWriterTest.java   |   6 +-
 .../fs/TaskChangelogRegistryImplTest.java          |  59 +++
 .../state/changelog/StateChangelogWriter.java      |  26 +-
 .../inmemory/InMemoryStateChangelogWriter.java     |  14 +-
 .../changelog/ChangelogKeyedStateBackend.java      |  21 +-
 .../state/changelog/ChangelogTruncateHelper.java   |  98 +++++
 .../changelog/ChangelogStateBackendTestUtils.java  |  34 +-
 .../state/changelog/ChangelogStateDiscardTest.java | 452 +++++++++++++++++++++
 .../state/changelog/StateChangeLoggerTestBase.java |   6 +-
 23 files changed, 1051 insertions(+), 94 deletions(-)

diff --git 
a/docs/layouts/shortcodes/generated/fs_state_changelog_configuration.html 
b/docs/layouts/shortcodes/generated/fs_state_changelog_configuration.html
index 45e33c4e3e2..cc3b64dfcec 100644
--- a/docs/layouts/shortcodes/generated/fs_state_changelog_configuration.html
+++ b/docs/layouts/shortcodes/generated/fs_state_changelog_configuration.html
@@ -32,6 +32,12 @@
             <td>Boolean</td>
             <td>Whether to enable compression when serializing changelog.</td>
         </tr>
+        <tr>
+            <td><h5>dstl.dfs.discard.num-threads</h5></td>
+            <td style="word-wrap: break-word;">1</td>
+            <td>Integer</td>
+            <td>Number of threads to use to discard changelog (e.g. 
pre-emptively uploaded unused state).</td>
+        </tr>
         <tr>
             <td><h5>dstl.dfs.preemptive-persist-threshold</h5></td>
             <td style="word-wrap: break-word;">5 mb</td>
diff --git 
a/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/FsStateChangelogOptions.java
 
b/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/FsStateChangelogOptions.java
index 159e775843d..edbe722c9df 100644
--- 
a/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/FsStateChangelogOptions.java
+++ 
b/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/FsStateChangelogOptions.java
@@ -88,6 +88,13 @@ public class FsStateChangelogOptions {
                     .defaultValue(5)
                     .withDescription("Number of threads to use for upload.");
 
+    public static final ConfigOption<Integer> NUM_DISCARD_THREADS =
+            ConfigOptions.key("dstl.dfs.discard.num-threads")
+                    .intType()
+                    .defaultValue(1)
+                    .withDescription(
+                            "Number of threads to use to discard changelog 
(e.g. pre-emptively uploaded unused state).");
+
     public static final ConfigOption<MemorySize> IN_FLIGHT_DATA_LIMIT =
             ConfigOptions.key("dstl.dfs.upload.max-in-flight")
                     .memoryType()
diff --git 
a/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/FsStateChangelogStorage.java
 
b/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/FsStateChangelogStorage.java
index bfe9abc8f33..df12cad0b97 100644
--- 
a/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/FsStateChangelogStorage.java
+++ 
b/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/FsStateChangelogStorage.java
@@ -37,9 +37,11 @@ import java.io.IOException;
 import java.util.UUID;
 import java.util.concurrent.atomic.AtomicInteger;
 
+import static 
org.apache.flink.changelog.fs.FsStateChangelogOptions.NUM_DISCARD_THREADS;
 import static 
org.apache.flink.changelog.fs.FsStateChangelogOptions.PREEMPTIVE_PERSIST_THRESHOLD;
 import static 
org.apache.flink.changelog.fs.StateChangeUploadScheduler.directScheduler;
 import static 
org.apache.flink.changelog.fs.StateChangeUploadScheduler.fromConfig;
+import static 
org.apache.flink.changelog.fs.TaskChangelogRegistry.defaultChangelogRegistry;
 
 /** Filesystem-based implementation of {@link StateChangelogStorage}. */
 @Experimental
@@ -57,11 +59,22 @@ public class FsStateChangelogStorage extends 
FsStateChangelogStorageForRecovery
      */
     private final AtomicInteger logIdGenerator = new AtomicInteger(0);
 
+    private final TaskChangelogRegistry changelogRegistry;
+
     public FsStateChangelogStorage(Configuration config, 
TaskManagerJobMetricGroup metricGroup)
             throws IOException {
+        this(config, metricGroup, 
defaultChangelogRegistry(config.get(NUM_DISCARD_THREADS)));
+    }
+
+    public FsStateChangelogStorage(
+            Configuration config,
+            TaskManagerJobMetricGroup metricGroup,
+            TaskChangelogRegistry changelogRegistry)
+            throws IOException {
         this(
-                fromConfig(config, new 
ChangelogStorageMetricGroup(metricGroup)),
-                config.get(PREEMPTIVE_PERSIST_THRESHOLD).getBytes());
+                fromConfig(config, new 
ChangelogStorageMetricGroup(metricGroup), changelogRegistry),
+                config.get(PREEMPTIVE_PERSIST_THRESHOLD).getBytes(),
+                changelogRegistry);
     }
 
     @VisibleForTesting
@@ -69,7 +82,8 @@ public class FsStateChangelogStorage extends 
FsStateChangelogStorageForRecovery
             Path basePath,
             boolean compression,
             int bufferSize,
-            ChangelogStorageMetricGroup metricGroup)
+            ChangelogStorageMetricGroup metricGroup,
+            TaskChangelogRegistry changelogRegistry)
             throws IOException {
         this(
                 directScheduler(
@@ -78,15 +92,20 @@ public class FsStateChangelogStorage extends 
FsStateChangelogStorageForRecovery
                                 basePath.getFileSystem(),
                                 compression,
                                 bufferSize,
-                                metricGroup)),
-                PREEMPTIVE_PERSIST_THRESHOLD.defaultValue().getBytes());
+                                metricGroup,
+                                changelogRegistry)),
+                PREEMPTIVE_PERSIST_THRESHOLD.defaultValue().getBytes(),
+                changelogRegistry);
     }
 
     @VisibleForTesting
     public FsStateChangelogStorage(
-            StateChangeUploadScheduler uploader, long 
preEmptivePersistThresholdInBytes) {
-        this.uploader = uploader;
+            StateChangeUploadScheduler uploader,
+            long preEmptivePersistThresholdInBytes,
+            TaskChangelogRegistry changelogRegistry) {
         this.preEmptivePersistThresholdInBytes = 
preEmptivePersistThresholdInBytes;
+        this.changelogRegistry = changelogRegistry;
+        this.uploader = uploader;
     }
 
     @Override
@@ -95,7 +114,12 @@ public class FsStateChangelogStorage extends 
FsStateChangelogStorageForRecovery
         UUID logId = new UUID(0, logIdGenerator.getAndIncrement());
         LOG.info("createWriter for operator {}/{}: {}", operatorID, 
keyGroupRange, logId);
         return new FsStateChangelogWriter(
-                logId, keyGroupRange, uploader, 
preEmptivePersistThresholdInBytes, mailboxExecutor);
+                logId,
+                keyGroupRange,
+                uploader,
+                preEmptivePersistThresholdInBytes,
+                mailboxExecutor,
+                changelogRegistry);
     }
 
     @Override
diff --git 
a/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/FsStateChangelogWriter.java
 
b/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/FsStateChangelogWriter.java
index 2c255f2ab6c..858fa9fb6b4 100644
--- 
a/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/FsStateChangelogWriter.java
+++ 
b/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/FsStateChangelogWriter.java
@@ -105,6 +105,12 @@ class FsStateChangelogWriter implements 
StateChangelogWriter<ChangelogStateHandl
      */
     private SequenceNumber lowestSequenceNumber = INITIAL_SQN;
 
+    /**
+     * {@link SequenceNumber} after which changes will NOT be requested, 
inclusive. Decreased on
+     * {@link #truncateAndClose(SequenceNumber)}.
+     */
+    private SequenceNumber highestSequenceNumber = 
SequenceNumber.of(Long.MAX_VALUE);
+
     /**
      * Active changes, that all share the same {@link #activeSequenceNumber}.
      *
@@ -131,17 +137,21 @@ class FsStateChangelogWriter implements 
StateChangelogWriter<ChangelogStateHandl
 
     private final MailboxExecutor mailboxExecutor;
 
+    private final TaskChangelogRegistry changelogRegistry;
+
     FsStateChangelogWriter(
             UUID logId,
             KeyGroupRange keyGroupRange,
             StateChangeUploadScheduler uploader,
             long preEmptivePersistThresholdInBytes,
-            MailboxExecutor mailboxExecutor) {
+            MailboxExecutor mailboxExecutor,
+            TaskChangelogRegistry changelogRegistry) {
         this.logId = logId;
         this.keyGroupRange = keyGroupRange;
         this.uploader = uploader;
         this.preEmptivePersistThresholdInBytes = 
preEmptivePersistThresholdInBytes;
         this.mailboxExecutor = mailboxExecutor;
+        this.changelogRegistry = changelogRegistry;
     }
 
     @Override
@@ -244,8 +254,14 @@ class FsStateChangelogWriter implements 
StateChangelogWriter<ChangelogStateHandl
                     } else {
                         uploadCompletionListeners.removeIf(listener -> 
listener.onSuccess(results));
                         for (UploadResult result : results) {
-                            if 
(result.sequenceNumber.compareTo(lowestSequenceNumber) >= 0) {
-                                uploaded.put(result.sequenceNumber, result);
+                            SequenceNumber resultSqn = result.sequenceNumber;
+                            if (resultSqn.compareTo(lowestSequenceNumber) >= 0
+                                    && 
resultSqn.compareTo(highestSequenceNumber) < 0) {
+                                uploaded.put(resultSqn, result);
+                            } else {
+                                // uploaded already truncated, i.e. 
materialized state changes,
+                                // or closed
+                                
changelogRegistry.notUsed(result.streamStateHandle, logId);
                             }
                         }
                     }
@@ -270,7 +286,18 @@ class FsStateChangelogWriter implements 
StateChangelogWriter<ChangelogStateHandl
         checkArgument(to.compareTo(activeSequenceNumber) <= 0);
         lowestSequenceNumber = to;
         notUploaded.headMap(lowestSequenceNumber, false).clear();
-        uploaded.headMap(lowestSequenceNumber, false).clear();
+
+        Map<SequenceNumber, UploadResult> toDiscard = uploaded.headMap(to);
+        notifyStateNotUsed(toDiscard);
+        toDiscard.clear();
+    }
+
+    @Override
+    public void truncateAndClose(SequenceNumber from) {
+        LOG.debug("truncate {} tail from sqn {} (incl.)", logId, from);
+        highestSequenceNumber = from;
+        notifyStateNotUsed(uploaded.tailMap(from));
+        close();
     }
 
     private void rollover() {
@@ -288,7 +315,20 @@ class FsStateChangelogWriter implements 
StateChangelogWriter<ChangelogStateHandl
 
     @Override
     public void confirm(SequenceNumber from, SequenceNumber to) {
-        // do nothing
+        checkState(from.compareTo(to) <= 0, "Invalid confirm range: [%s,%s)", 
from, to);
+        checkState(
+                from.compareTo(activeSequenceNumber) <= 0
+                        && to.compareTo(activeSequenceNumber) <= 0,
+                "Invalid confirm range: [%s,%s), active sqn: %s",
+                from,
+                to,
+                activeSequenceNumber);
+        // it is possible that "uploaded" has already been truncated (after 
checkpoint subsumption)
+        // so do not check that "uploaded" contains the specified range
+        LOG.debug("Confirm [{}, {})", from, to);
+        uploaded.subMap(from, to).values().stream()
+                .map(UploadResult::getStreamStateHandle)
+                .forEach(changelogRegistry::stopTracking);
     }
 
     @Override
@@ -319,11 +359,6 @@ class FsStateChangelogWriter implements 
StateChangelogWriter<ChangelogStateHandl
         return activeSequenceNumber;
     }
 
-    @VisibleForTesting
-    public SequenceNumber getLowestSequenceNumber() {
-        return lowestSequenceNumber;
-    }
-
     private void ensureCanPersist(SequenceNumber from) throws IOException {
         checkNotNull(from);
         if (highestFailed != null && highestFailed.f0.compareTo(from) >= 0) {
@@ -400,4 +435,11 @@ class FsStateChangelogWriter implements 
StateChangelogWriter<ChangelogStateHandl
         tailMap.clear();
         return toUpload;
     }
+
+    private void notifyStateNotUsed(Map<SequenceNumber, UploadResult> 
notUsedState) {
+        LOG.trace("Uploaded state to discard: {}", notUsedState);
+        for (UploadResult result : notUsedState.values()) {
+            changelogRegistry.notUsed(result.streamStateHandle, logId);
+        }
+    }
 }
diff --git 
a/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/StateChangeFsUploader.java
 
b/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/StateChangeFsUploader.java
index f2755d469a4..28ecf03c78e 100644
--- 
a/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/StateChangeFsUploader.java
+++ 
b/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/StateChangeFsUploader.java
@@ -17,12 +17,14 @@
 
 package org.apache.flink.changelog.fs;
 
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.changelog.fs.StateChangeUploadScheduler.UploadTask;
 import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
 import org.apache.flink.runtime.state.SnappyStreamCompressionDecorator;
 import org.apache.flink.runtime.state.StreamCompressionDecorator;
+import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.UncompressedStreamCompressionDecorator;
 import org.apache.flink.runtime.state.filesystem.FileStateHandle;
 import org.apache.flink.util.clock.Clock;
@@ -40,6 +42,8 @@ import java.util.Collection;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.UUID;
+import java.util.function.BiFunction;
+import java.util.stream.Collectors;
 
 import static org.apache.flink.core.fs.FileSystem.WriteMode.NO_OVERWRITE;
 
@@ -47,7 +51,7 @@ import static 
org.apache.flink.core.fs.FileSystem.WriteMode.NO_OVERWRITE;
  * A synchronous {@link StateChangeUploadScheduler} implementation that 
uploads the changes using
  * {@link FileSystem}.
  */
-class StateChangeFsUploader implements StateChangeUploader {
+public class StateChangeFsUploader implements StateChangeUploader {
     private static final Logger LOG = 
LoggerFactory.getLogger(StateChangeFsUploader.class);
 
     private final Path basePath;
@@ -57,13 +61,35 @@ class StateChangeFsUploader implements StateChangeUploader {
     private final int bufferSize;
     private final ChangelogStorageMetricGroup metrics;
     private final Clock clock;
+    private final TaskChangelogRegistry changelogRegistry;
+    private final BiFunction<Path, Long, StreamStateHandle> handleFactory;
 
+    @VisibleForTesting
     public StateChangeFsUploader(
             Path basePath,
             FileSystem fileSystem,
             boolean compression,
             int bufferSize,
-            ChangelogStorageMetricGroup metrics) {
+            ChangelogStorageMetricGroup metrics,
+            TaskChangelogRegistry changelogRegistry) {
+        this(
+                basePath,
+                fileSystem,
+                compression,
+                bufferSize,
+                metrics,
+                changelogRegistry,
+                FileStateHandle::new);
+    }
+
+    public StateChangeFsUploader(
+            Path basePath,
+            FileSystem fileSystem,
+            boolean compression,
+            int bufferSize,
+            ChangelogStorageMetricGroup metrics,
+            TaskChangelogRegistry changelogRegistry,
+            BiFunction<Path, Long, StreamStateHandle> handleFactory) {
         this.basePath = basePath;
         this.fileSystem = fileSystem;
         this.format = new StateChangeFormat();
@@ -71,6 +97,8 @@ class StateChangeFsUploader implements StateChangeUploader {
         this.bufferSize = bufferSize;
         this.metrics = metrics;
         this.clock = SystemClock.getInstance();
+        this.changelogRegistry = changelogRegistry;
+        this.handleFactory = handleFactory;
     }
 
     public UploadTasksResult upload(Collection<UploadTask> tasks) throws 
IOException {
@@ -114,7 +142,13 @@ class StateChangeFsUploader implements StateChangeUploader 
{
                 for (UploadTask task : tasks) {
                     tasksOffsets.put(task, format.write(stream, 
task.changeSets));
                 }
-                FileStateHandle handle = new FileStateHandle(path, 
stream.getPos());
+                StreamStateHandle handle = handleFactory.apply(path, 
stream.getPos());
+                changelogRegistry.startTracking(
+                        handle,
+                        tasks.stream()
+                                .flatMap(t -> t.getChangeSets().stream())
+                                .map(StateChangeSet::getLogId)
+                                .collect(Collectors.toSet()));
                 // WARN: streams have to be closed before returning the results
                 // otherwise JM may receive invalid handles
                 return new UploadTasksResult(tasksOffsets, handle);
diff --git 
a/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/StateChangeSet.java
 
b/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/StateChangeSet.java
index 5fcc08aae5b..836bfb82fa5 100644
--- 
a/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/StateChangeSet.java
+++ 
b/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/StateChangeSet.java
@@ -17,6 +17,7 @@
 
 package org.apache.flink.changelog.fs;
 
+import org.apache.flink.annotation.Internal;
 import org.apache.flink.runtime.state.changelog.SequenceNumber;
 import org.apache.flink.runtime.state.changelog.StateChange;
 
@@ -33,7 +34,8 @@ import static java.util.Collections.unmodifiableList;
  * that constructor arguments are not modified outside.
  */
 @ThreadSafe
-class StateChangeSet {
+@Internal
+public class StateChangeSet {
     private final UUID logId;
     private final List<StateChange> changes;
     private final SequenceNumber sequenceNumber;
diff --git 
a/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/StateChangeUploadScheduler.java
 
b/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/StateChangeUploadScheduler.java
index 8edaa6d2fba..8c82b6602ac 100644
--- 
a/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/StateChangeUploadScheduler.java
+++ 
b/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/StateChangeUploadScheduler.java
@@ -17,6 +17,7 @@
 
 package org.apache.flink.changelog.fs;
 
+import org.apache.flink.annotation.Internal;
 import org.apache.flink.configuration.ReadableConfig;
 import org.apache.flink.core.fs.Path;
 import org.apache.flink.runtime.io.AvailabilityProvider;
@@ -59,7 +60,8 @@ import static 
org.apache.flink.util.Preconditions.checkArgument;
  * directly calls {@link StateChangeUploader#upload(Collection)}. Other 
implementations might batch
  * the tasks for efficiency.
  */
-interface StateChangeUploadScheduler extends AutoCloseable {
+@Internal
+public interface StateChangeUploadScheduler extends AutoCloseable {
 
     /**
      * Schedule the upload and {@link UploadTask#complete(List) complete} or 
{@link
@@ -82,7 +84,10 @@ interface StateChangeUploadScheduler extends AutoCloseable {
     }
 
     static StateChangeUploadScheduler fromConfig(
-            ReadableConfig config, ChangelogStorageMetricGroup metricGroup) 
throws IOException {
+            ReadableConfig config,
+            ChangelogStorageMetricGroup metricGroup,
+            TaskChangelogRegistry changelogRegistry)
+            throws IOException {
         Path basePath = new Path(config.get(BASE_PATH));
         long bytes = config.get(UPLOAD_BUFFER_SIZE).getBytes();
         checkArgument(bytes <= Integer.MAX_VALUE);
@@ -93,7 +98,8 @@ interface StateChangeUploadScheduler extends AutoCloseable {
                         basePath.getFileSystem(),
                         config.get(COMPRESSION_ENABLED),
                         bufferSize,
-                        metricGroup);
+                        metricGroup,
+                        changelogRegistry);
         BatchingStateChangeUploadScheduler batchingStore =
                 new BatchingStateChangeUploadScheduler(
                         config.get(PERSIST_DELAY).toMillis(),
@@ -110,6 +116,7 @@ interface StateChangeUploadScheduler extends AutoCloseable {
         return () -> AvailabilityProvider.AVAILABLE;
     }
 
+    /** Upload Task for {@link StateChangeUploadScheduler}. */
     @ThreadSafe
     final class UploadTask {
         final Collection<StateChangeSet> changeSets;
@@ -150,9 +157,17 @@ interface StateChangeUploadScheduler extends AutoCloseable 
{
             return size;
         }
 
+        public Collection<StateChangeSet> getChangeSets() {
+            return changeSets;
+        }
+
         @Override
         public String toString() {
             return "changeSets=" + changeSets;
         }
+
+        public boolean isFinished() {
+            return finished.get();
+        }
     }
 }
diff --git 
a/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/StateChangeUploader.java
 
b/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/StateChangeUploader.java
index 88d4e4b4955..1f18fe25d8b 100644
--- 
a/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/StateChangeUploader.java
+++ 
b/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/StateChangeUploader.java
@@ -17,6 +17,8 @@
 
 package org.apache.flink.changelog.fs;
 
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.changelog.fs.StateChangeUploadScheduler.UploadTask;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.util.Preconditions;
@@ -34,13 +36,15 @@ import static java.util.stream.Collectors.toList;
  * changelog parts. It has a single {@link #upload} method with a collection 
of {@link UploadTask}
  * argument which is meant to initiate such an upload.
  */
-interface StateChangeUploader extends AutoCloseable {
+@Internal
+public interface StateChangeUploader extends AutoCloseable {
     /**
      * Execute the upload task and return the results. It is the caller 
responsibility to {@link
      * UploadTask#complete(List) complete} the tasks.
      */
     UploadTasksResult upload(Collection<UploadTask> tasks) throws IOException;
 
+    /** Result of executing one or more {@link UploadTask upload tasks}. */
     final class UploadTasksResult {
         private final Map<UploadTask, Map<StateChangeSet, Long>> tasksOffsets;
         private final StreamStateHandle handle;
@@ -73,5 +77,10 @@ interface StateChangeUploader extends AutoCloseable {
         public void discard() throws Exception {
             handle.discardState();
         }
+
+        @VisibleForTesting
+        public StreamStateHandle getStreamStateHandle() {
+            return handle;
+        }
     }
 }
diff --git 
a/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/TaskChangelogRegistry.java
 
b/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/TaskChangelogRegistry.java
new file mode 100644
index 00000000000..ff089fecb82
--- /dev/null
+++ 
b/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/TaskChangelogRegistry.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.changelog.fs;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.runtime.state.StreamStateHandle;
+
+import java.util.Set;
+import java.util.UUID;
+import java.util.concurrent.Executor;
+import java.util.concurrent.Executors;
+
+/**
+ * TM-side registry of {@link org.apache.flink.runtime.state.StateObject 
StateObjects}, each
+ * representing one or more changelog segments. Changelog segments are 
uploaded by {@link
+ * org.apache.flink.runtime.state.changelog.StateChangelogWriter 
StateChangelogWriters} of a {@link
+ * org.apache.flink.runtime.state.changelog.StateChangelogStorage 
StateChangelogStorage}.
+ *
+ * <p>Initially, when {@link #startTracking(StreamStateHandle, Set) starting 
the tracking}, the
+ * ownership of the changelog segments is not clear, and it is assumed that JM 
<strong>might</strong>
+ * be the owner. Once the backends are not using the segments, JM can not 
become an owner anymore,
+ * and the state is discarded.
+ *
+ * <p>However, if at any point it becomes known that JM is the owner, tracking 
is {@link
+ * #stopTracking(StreamStateHandle) stopped} and the state will not be 
discarded.
+ *
+ * <p>It is the client's responsibility to make sure that JM can not become an 
owner when calling
+ * {@link #notUsed(StreamStateHandle, UUID)}.
+ */
+@Internal
+public interface TaskChangelogRegistry {
+
+    /** Start tracking the state uploaded for the given backends. */
+    void startTracking(StreamStateHandle handle, Set<UUID> backendIDs);
+
+    /** Stop tracking the state, so that this registry does not discard it (some other 
component is responsible for it). */
+    void stopTracking(StreamStateHandle handle);
+
+    /**
+     * Mark the state as unused by the given backend, e.g. if it was 
pre-emptively uploaded and
+     * materialized. Once no backend is using the state, it is discarded 
(unless it was {@link
+     * #stopTracking(StreamStateHandle) unregistered} earlier).
+     */
+    void notUsed(StreamStateHandle handle, UUID backendId);
+
+    TaskChangelogRegistry NO_OP =
+            new TaskChangelogRegistry() {
+                @Override
+                public void startTracking(StreamStateHandle handle, Set<UUID> 
backendIDs) {}
+
+                @Override
+                public void stopTracking(StreamStateHandle handle) {}
+
+                @Override
+                public void notUsed(StreamStateHandle handle, UUID backendId) 
{}
+            };
+
+    static TaskChangelogRegistry defaultChangelogRegistry(int 
numAsyncDiscardThreads) {
+        return 
defaultChangelogRegistry(Executors.newFixedThreadPool(numAsyncDiscardThreads));
+    }
+
+    @VisibleForTesting
+    static TaskChangelogRegistry defaultChangelogRegistry(Executor executor) {
+        return new TaskChangelogRegistryImpl(executor);
+    }
+}
diff --git 
a/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/TaskChangelogRegistryImpl.java
 
b/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/TaskChangelogRegistryImpl.java
new file mode 100644
index 00000000000..55376148363
--- /dev/null
+++ 
b/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/TaskChangelogRegistryImpl.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.changelog.fs;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.runtime.state.PhysicalStateHandleID;
+import org.apache.flink.runtime.state.StreamStateHandle;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.concurrent.ThreadSafe;
+
+import java.util.Map;
+import java.util.Set;
+import java.util.UUID;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CopyOnWriteArraySet;
+import java.util.concurrent.Executor;
+
+@Internal
+@ThreadSafe
+class TaskChangelogRegistryImpl implements TaskChangelogRegistry {
+    private static final Logger LOG = 
LoggerFactory.getLogger(TaskChangelogRegistryImpl.class);
+
+    private final Map<PhysicalStateHandleID, Set<UUID>> entries = new 
ConcurrentHashMap<>();
+    private final Executor executor;
+
+    public TaskChangelogRegistryImpl(Executor executor) {
+        this.executor = executor;
+    }
+
+    @Override
+    public void startTracking(StreamStateHandle handle, Set<UUID> backendIDs) {
+        LOG.debug(
+                "start tracking state, key: {}, state: {}",
+                handle.getStreamStateHandleID(),
+                handle);
+        entries.put(handle.getStreamStateHandleID(), new 
CopyOnWriteArraySet<>(backendIDs));
+    }
+
+    @Override
+    public void stopTracking(StreamStateHandle handle) {
+        LOG.debug(
+                "stop tracking state, key: {}, state: {}", 
handle.getStreamStateHandleID(), handle);
+        entries.remove(handle.getStreamStateHandleID());
+    }
+
+    @Override
+    public void notUsed(StreamStateHandle handle, UUID backendId) {
+        PhysicalStateHandleID key = handle.getStreamStateHandleID();
+        LOG.debug("backend {} not using state, key: {}, state: {}", backendId, 
key, handle);
+        Set<UUID> backends = entries.get(key);
+        if (backends == null) {
+            LOG.warn("backend {} was not using state, key: {}, state: {}", 
backendId, key, handle);
+            return;
+        }
+        backends.remove(backendId);
+        if (backends.isEmpty() && entries.remove(key) != null) {
+            LOG.debug("state is not used by any backend, schedule discard: 
{}/{}", key, handle);
+            scheduleDiscard(handle);
+        }
+    }
+
+    private void scheduleDiscard(StreamStateHandle handle) {
+        executor.execute(
+                () -> {
+                    try {
+                        LOG.trace("discard uploaded but unused state changes: 
{}", handle);
+                        handle.discardState();
+                    } catch (Exception e) {
+                        LOG.warn("unable to discard uploaded but unused state 
changes", e);
+                    }
+                });
+    }
+}
diff --git 
a/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/UploadResult.java
 
b/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/UploadResult.java
index a09d7cf7b5d..b2caa5c3bd9 100644
--- 
a/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/UploadResult.java
+++ 
b/flink-dstl/flink-dstl-dfs/src/main/java/org/apache/flink/changelog/fs/UploadResult.java
@@ -17,12 +17,15 @@
 
 package org.apache.flink.changelog.fs;
 
+import org.apache.flink.annotation.Internal;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.changelog.SequenceNumber;
 
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
-final class UploadResult {
+/** Result of uploading state changes. */
+@Internal
+public final class UploadResult {
     public final StreamStateHandle streamStateHandle;
     public final long offset;
     public final SequenceNumber sequenceNumber;
diff --git 
a/flink-dstl/flink-dstl-dfs/src/test/java/org/apache/flink/changelog/fs/ChangelogStorageMetricsTest.java
 
b/flink-dstl/flink-dstl-dfs/src/test/java/org/apache/flink/changelog/fs/ChangelogStorageMetricsTest.java
index 59f1ae270ae..1002b713829 100644
--- 
a/flink-dstl/flink-dstl-dfs/src/test/java/org/apache/flink/changelog/fs/ChangelogStorageMetricsTest.java
+++ 
b/flink-dstl/flink-dstl-dfs/src/test/java/org/apache/flink/changelog/fs/ChangelogStorageMetricsTest.java
@@ -64,7 +64,11 @@ public class ChangelogStorageMetricsTest {
 
         try (FsStateChangelogStorage storage =
                 new FsStateChangelogStorage(
-                        Path.fromLocalFile(tempFolder.toFile()), false, 100, 
metrics)) {
+                        Path.fromLocalFile(tempFolder.toFile()),
+                        false,
+                        100,
+                        metrics,
+                        TaskChangelogRegistry.NO_OP)) {
             FsStateChangelogWriter writer = createWriter(storage);
             int numUploads = 5;
             for (int i = 0; i < numUploads; i++) {
@@ -84,7 +88,11 @@ public class ChangelogStorageMetricsTest {
 
         try (FsStateChangelogStorage storage =
                 new FsStateChangelogStorage(
-                        Path.fromLocalFile(tempFolder.toFile()), false, 100, 
metrics)) {
+                        Path.fromLocalFile(tempFolder.toFile()),
+                        false,
+                        100,
+                        metrics,
+                        TaskChangelogRegistry.NO_OP)) {
             FsStateChangelogWriter writer = createWriter(storage);
 
             // upload single byte to infer header size
@@ -111,7 +119,12 @@ public class ChangelogStorageMetricsTest {
         ChangelogStorageMetricGroup metrics =
                 new 
ChangelogStorageMetricGroup(createUnregisteredTaskManagerJobMetricGroup());
         try (FsStateChangelogStorage storage =
-                new FsStateChangelogStorage(Path.fromLocalFile(file), false, 
100, metrics)) {
+                new FsStateChangelogStorage(
+                        Path.fromLocalFile(file),
+                        false,
+                        100,
+                        metrics,
+                        TaskChangelogRegistry.NO_OP)) {
             FsStateChangelogWriter writer = createWriter(storage);
 
             int numUploads = 5;
@@ -136,7 +149,13 @@ public class ChangelogStorageMetricsTest {
                 new 
ChangelogStorageMetricGroup(createUnregisteredTaskManagerJobMetricGroup());
         Path basePath = Path.fromLocalFile(tempFolder.toFile());
         StateChangeFsUploader uploader =
-                new StateChangeFsUploader(basePath, basePath.getFileSystem(), 
false, 100, metrics);
+                new StateChangeFsUploader(
+                        basePath,
+                        basePath.getFileSystem(),
+                        false,
+                        100,
+                        metrics,
+                        TaskChangelogRegistry.NO_OP);
         ManuallyTriggeredScheduledExecutorService scheduler =
                 new ManuallyTriggeredScheduledExecutorService();
         BatchingStateChangeUploadScheduler batcher =
@@ -153,7 +172,9 @@ public class ChangelogStorageMetricsTest {
                                 metrics.getTotalAttemptsPerUpload()),
                         metrics);
 
-        FsStateChangelogStorage storage = new FsStateChangelogStorage(batcher, 
Integer.MAX_VALUE);
+        FsStateChangelogStorage storage =
+                new FsStateChangelogStorage(
+                        batcher, Integer.MAX_VALUE, 
TaskChangelogRegistry.NO_OP);
         FsStateChangelogWriter[] writers = new 
FsStateChangelogWriter[numWriters];
         for (int i = 0; i < numWriters; i++) {
             writers[i] =
@@ -203,7 +224,9 @@ public class ChangelogStorageMetricsTest {
                                 metrics.getTotalAttemptsPerUpload()),
                         metrics);
 
-        FsStateChangelogStorage storage = new FsStateChangelogStorage(batcher, 
Integer.MAX_VALUE);
+        FsStateChangelogStorage storage =
+                new FsStateChangelogStorage(
+                        batcher, Integer.MAX_VALUE, 
TaskChangelogRegistry.NO_OP);
         FsStateChangelogWriter writer = createWriter(storage);
 
         try {
@@ -243,7 +266,9 @@ public class ChangelogStorageMetricsTest {
                                 metrics.getTotalAttemptsPerUpload()),
                         metrics);
 
-        FsStateChangelogStorage storage = new FsStateChangelogStorage(batcher, 
Integer.MAX_VALUE);
+        FsStateChangelogStorage storage =
+                new FsStateChangelogStorage(
+                        batcher, Integer.MAX_VALUE, 
TaskChangelogRegistry.NO_OP);
         FsStateChangelogWriter writer = createWriter(storage);
 
         try {
@@ -281,7 +306,13 @@ public class ChangelogStorageMetricsTest {
 
         Path path = Path.fromLocalFile(tempFolder.toFile());
         StateChangeFsUploader delegate =
-                new StateChangeFsUploader(path, path.getFileSystem(), false, 
100, metrics);
+                new StateChangeFsUploader(
+                        path,
+                        path.getFileSystem(),
+                        false,
+                        100,
+                        metrics,
+                        TaskChangelogRegistry.NO_OP);
         ManuallyTriggeredScheduledExecutorService scheduler =
                 new ManuallyTriggeredScheduledExecutorService();
         BatchingStateChangeUploadScheduler batcher =
@@ -298,7 +329,7 @@ public class ChangelogStorageMetricsTest {
                                 metrics.getTotalAttemptsPerUpload()),
                         metrics);
         try (FsStateChangelogStorage storage =
-                new FsStateChangelogStorage(batcher, Long.MAX_VALUE)) {
+                new FsStateChangelogStorage(batcher, Long.MAX_VALUE, 
TaskChangelogRegistry.NO_OP)) {
             FsStateChangelogWriter writer = createWriter(storage);
             int numUploads = 11;
             for (int i = 0; i < numUploads; i++) {
diff --git 
a/flink-dstl/flink-dstl-dfs/src/test/java/org/apache/flink/changelog/fs/FsStateChangelogStorageTest.java
 
b/flink-dstl/flink-dstl-dfs/src/test/java/org/apache/flink/changelog/fs/FsStateChangelogStorageTest.java
index e23945cbcfa..906806675a4 100644
--- 
a/flink-dstl/flink-dstl-dfs/src/test/java/org/apache/flink/changelog/fs/FsStateChangelogStorageTest.java
+++ 
b/flink-dstl/flink-dstl-dfs/src/test/java/org/apache/flink/changelog/fs/FsStateChangelogStorageTest.java
@@ -54,7 +54,8 @@ public class FsStateChangelogStorageTest
                 Path.fromLocalFile(temporaryFolder),
                 compression,
                 1024 * 1024 * 10,
-                createUnregisteredChangelogStorageMetricGroup());
+                createUnregisteredChangelogStorageMetricGroup(),
+                TaskChangelogRegistry.NO_OP);
     }
 
     /**
@@ -97,7 +98,10 @@ public class FsStateChangelogStorageTest
                             }
                         };
                 StateChangelogWriter<?> writer =
-                        new FsStateChangelogStorage(scheduler, 0 /* persist 
immediately */)
+                        new FsStateChangelogStorage(
+                                        scheduler,
+                                        0,
+                                        TaskChangelogRegistry.NO_OP /* persist 
immediately */)
                                 .createWriter(
                                         new OperatorID().toString(),
                                         KeyGroupRange.of(0, 0),
diff --git 
a/flink-dstl/flink-dstl-dfs/src/test/java/org/apache/flink/changelog/fs/FsStateChangelogWriterSqnTest.java
 
b/flink-dstl/flink-dstl-dfs/src/test/java/org/apache/flink/changelog/fs/FsStateChangelogWriterSqnTest.java
index b005572d380..926f188bb5c 100644
--- 
a/flink-dstl/flink-dstl-dfs/src/test/java/org/apache/flink/changelog/fs/FsStateChangelogWriterSqnTest.java
+++ 
b/flink-dstl/flink-dstl-dfs/src/test/java/org/apache/flink/changelog/fs/FsStateChangelogWriterSqnTest.java
@@ -79,7 +79,8 @@ public class FsStateChangelogWriterSqnTest {
                         StateChangeUploadScheduler.directScheduler(
                                 new TestingStateChangeUploader()),
                         Long.MAX_VALUE,
-                        new SyncMailboxExecutor())) {
+                        new SyncMailboxExecutor(),
+                        TaskChangelogRegistry.NO_OP)) {
             if (writerSqnTestSettings.withAppend) {
                 append(writer);
             }
diff --git 
a/flink-dstl/flink-dstl-dfs/src/test/java/org/apache/flink/changelog/fs/FsStateChangelogWriterTest.java
 
b/flink-dstl/flink-dstl-dfs/src/test/java/org/apache/flink/changelog/fs/FsStateChangelogWriterTest.java
index 78015e29e78..b9e6227cf1d 100644
--- 
a/flink-dstl/flink-dstl-dfs/src/test/java/org/apache/flink/changelog/fs/FsStateChangelogWriterTest.java
+++ 
b/flink-dstl/flink-dstl-dfs/src/test/java/org/apache/flink/changelog/fs/FsStateChangelogWriterTest.java
@@ -92,8 +92,9 @@ class FsStateChangelogWriterTest {
                     byte[] bytes = getBytes();
                     SequenceNumber sqn = append(writer, bytes);
                     writer.persist(sqn);
+                    uploader.completeUpload();
                     uploader.reset();
-                    writer.confirm(sqn, writer.lastAppendedSqnUnsafe().next());
+                    writer.confirm(sqn, writer.nextSequenceNumber());
                     writer.persist(sqn);
                     assertNoUpload(uploader, "confirmed changes shouldn't be 
re-uploaded");
                 });
@@ -223,7 +224,8 @@ class FsStateChangelogWriterTest {
                         KeyGroupRange.of(KEY_GROUP, KEY_GROUP),
                         StateChangeUploadScheduler.directScheduler(uploader),
                         appendPersistThreshold,
-                        new SyncMailboxExecutor())) {
+                        new SyncMailboxExecutor(),
+                        TaskChangelogRegistry.NO_OP)) {
             test.accept(writer, uploader);
         }
     }
diff --git 
a/flink-dstl/flink-dstl-dfs/src/test/java/org/apache/flink/changelog/fs/TaskChangelogRegistryImplTest.java
 
b/flink-dstl/flink-dstl-dfs/src/test/java/org/apache/flink/changelog/fs/TaskChangelogRegistryImplTest.java
new file mode 100644
index 00000000000..e5dc73ad7be
--- /dev/null
+++ 
b/flink-dstl/flink-dstl-dfs/src/test/java/org/apache/flink/changelog/fs/TaskChangelogRegistryImplTest.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.changelog.fs;
+
+import org.apache.flink.runtime.state.TestingStreamStateHandle;
+
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.UUID;
+
+import static org.apache.flink.util.concurrent.Executors.directExecutor;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+/** {@link TaskChangelogRegistryImpl} test. */
+public class TaskChangelogRegistryImplTest {
+
+    @Test
+    public void testDiscardedWhenNotUsed() {
+        TaskChangelogRegistry registry = new 
TaskChangelogRegistryImpl(directExecutor());
+        TestingStreamStateHandle handle = new TestingStreamStateHandle();
+        List<UUID> backends = Arrays.asList(UUID.randomUUID(), 
UUID.randomUUID());
+        registry.startTracking(handle, new HashSet<>(backends));
+        for (UUID backend : backends) {
+            assertFalse(handle.isDisposed());
+            registry.notUsed(handle, backend);
+        }
+        assertTrue(handle.isDisposed());
+    }
+
+    @Test
+    public void testNotDiscardedIfStoppedTracking() {
+        TaskChangelogRegistry registry = new 
TaskChangelogRegistryImpl(directExecutor());
+        TestingStreamStateHandle handle = new TestingStreamStateHandle();
+        List<UUID> backends = Arrays.asList(UUID.randomUUID(), 
UUID.randomUUID());
+        registry.startTracking(handle, new HashSet<>(backends));
+        registry.stopTracking(handle);
+        backends.forEach(id -> registry.notUsed(handle, id));
+        assertFalse(handle.isDisposed());
+    }
+}
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/changelog/StateChangelogWriter.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/changelog/StateChangelogWriter.java
index 48e8844219e..0ee4f615a31 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/changelog/StateChangelogWriter.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/changelog/StateChangelogWriter.java
@@ -18,7 +18,6 @@
 package org.apache.flink.runtime.state.changelog;
 
 import org.apache.flink.annotation.Internal;
-import org.apache.flink.annotation.VisibleForTesting;
 
 import java.io.IOException;
 import java.util.concurrent.CompletableFuture;
@@ -51,9 +50,17 @@ public interface StateChangelogWriter<Handle extends 
ChangelogStateHandle> exten
     CompletableFuture<Handle> persist(SequenceNumber from) throws IOException;
 
     /**
-     * Truncate in-memory view of this state changelog to free up resources. 
Called upon state
-     * materialization. Any {@link #persist(SequenceNumber) persisted} state 
changes will not be
-     * discarded; any ongoing persist calls will not be affected.
+     * Truncate this state changelog to free up the resources and collect any 
garbage. That means:
+     *
+     * <ul>
+     *   <li>Discard the written state changes - in the provided range [from; 
to)
+     *   <li>Truncate the in-memory view of this changelog - in the range [0; 
to)
+     * </ul>
+     *
+     * Called upon state materialization. Any ongoing persist calls will not 
be affected.
+     *
+     * <p>WARNING: the range [from; to) must not include any range that is 
included in any
+     * checkpoint that is neither subsumed nor aborted.
      *
      * @param to exclusive
      */
@@ -73,12 +80,17 @@ public interface StateChangelogWriter<Handle extends 
ChangelogStateHandle> exten
      */
     void reset(SequenceNumber from, SequenceNumber to);
 
+    /**
+     * Truncate the tail of the log and close it. No new appends will be possible. 
Any appended but not
+     * persisted records will be lost.
+     *
+     * @param from {@link SequenceNumber} from which to truncate the 
changelog, inclusive
+     */
+    void truncateAndClose(SequenceNumber from);
+
     /**
      * Close this log. No new appends will be possible. Any appended but not 
persisted records will
      * be lost.
      */
     void close();
-
-    @VisibleForTesting
-    SequenceNumber getLowestSequenceNumber();
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/changelog/inmemory/InMemoryStateChangelogWriter.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/changelog/inmemory/InMemoryStateChangelogWriter.java
index 19043ded884..3bd0a969d83 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/changelog/inmemory/InMemoryStateChangelogWriter.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/changelog/inmemory/InMemoryStateChangelogWriter.java
@@ -34,7 +34,6 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.NavigableMap;
-import java.util.SortedMap;
 import java.util.TreeMap;
 import java.util.concurrent.CompletableFuture;
 import java.util.stream.Collectors;
@@ -104,18 +103,13 @@ class InMemoryStateChangelogWriter implements 
StateChangelogWriter<InMemoryChang
     }
 
     @Override
-    public SequenceNumber getLowestSequenceNumber() {
-        return changesByKeyGroup.values().stream()
-                .filter(map -> !map.isEmpty())
-                .map(SortedMap::firstKey)
-                .min(Comparator.naturalOrder())
-                .orElse(nextSequenceNumber());
+    public void truncate(SequenceNumber to) {
+        changesByKeyGroup.forEach((kg, changesBySqn) -> 
changesBySqn.headMap(to, false).clear());
     }
 
     @Override
-    public void truncate(SequenceNumber before) {
-        changesByKeyGroup.forEach(
-                (kg, changesBySqn) -> changesBySqn.headMap(before, 
false).clear());
+    public void truncateAndClose(SequenceNumber from) {
+        close();
     }
 
     @Override
diff --git 
a/flink-state-backends/flink-statebackend-changelog/src/main/java/org/apache/flink/state/changelog/ChangelogKeyedStateBackend.java
 
b/flink-state-backends/flink-statebackend-changelog/src/main/java/org/apache/flink/state/changelog/ChangelogKeyedStateBackend.java
index 35849546c91..581034debae 100644
--- 
a/flink-state-backends/flink-statebackend-changelog/src/main/java/org/apache/flink/state/changelog/ChangelogKeyedStateBackend.java
+++ 
b/flink-state-backends/flink-statebackend-changelog/src/main/java/org/apache/flink/state/changelog/ChangelogKeyedStateBackend.java
@@ -22,6 +22,7 @@ import org.apache.flink.annotation.Internal;
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.state.CheckpointListener;
+import org.apache.flink.api.common.state.InternalCheckpointListener;
 import org.apache.flink.api.common.state.State;
 import org.apache.flink.api.common.state.StateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
@@ -105,7 +106,8 @@ import static 
org.apache.flink.util.Preconditions.checkNotNull;
 public class ChangelogKeyedStateBackend<K>
         implements CheckpointableKeyedStateBackend<K>,
                 CheckpointListener,
-                TestableKeyedStateBackend<K> {
+                TestableKeyedStateBackend<K>,
+                InternalCheckpointListener {
     private static final Logger LOG = 
LoggerFactory.getLogger(ChangelogKeyedStateBackend.class);
 
     /**
@@ -182,6 +184,8 @@ public class ChangelogKeyedStateBackend<K>
 
     private long lastConfirmedMaterializationId = -1L;
 
+    private final ChangelogTruncateHelper changelogTruncateHelper;
+
     public ChangelogKeyedStateBackend(
             AbstractKeyedStateBackend<K> keyedStateBackend,
             String subtaskName,
@@ -218,6 +222,8 @@ public class ChangelogKeyedStateBackend<K>
         this.keyValueStatesByName = new HashMap<>();
         this.changelogStateFactory = changelogStateFactory;
         this.stateChangelogWriter = stateChangelogWriter;
+        this.lastUploadedTo = stateChangelogWriter.initialSequenceNumber();
+        this.closer.register(() -> 
stateChangelogWriter.truncateAndClose(lastUploadedTo));
         this.changelogSnapshotState = completeRestore(initialState);
         this.streamFactory =
                 new CheckpointStreamFactory() {
@@ -243,6 +249,7 @@ public class ChangelogKeyedStateBackend<K>
                     }
                 };
         this.closer.register(keyedStateBackend);
+        this.changelogTruncateHelper = new 
ChangelogTruncateHelper(stateChangelogWriter);
     }
 
     // -------------------- CheckpointableKeyedStateBackend 
--------------------------------
@@ -366,6 +373,7 @@ public class ChangelogKeyedStateBackend<K>
         lastCheckpointId = checkpointId;
         lastUploadedFrom = changelogSnapshotState.lastMaterializedTo();
         lastUploadedTo = stateChangelogWriter.nextSequenceNumber();
+        changelogTruncateHelper.checkpoint(checkpointId, lastUploadedTo);
 
         LOG.info(
                 "snapshot of {} for checkpoint {}, change range: {}..{}",
@@ -664,8 +672,7 @@ public class ChangelogKeyedStateBackend<K>
     public void updateChangelogSnapshotState(
             SnapshotResult<KeyedStateHandle> materializedSnapshot,
             long materializationID,
-            SequenceNumber upTo)
-            throws Exception {
+            SequenceNumber upTo) {
 
         LOG.info(
                 "Task {} finishes materialization, updates the snapshotState 
upTo {} : {}",
@@ -678,8 +685,7 @@ public class ChangelogKeyedStateBackend<K>
                         Collections.emptyList(),
                         upTo,
                         materializationID);
-
-        stateChangelogWriter.truncate(upTo);
+        changelogTruncateHelper.materialized(upTo);
     }
 
     // TODO: this method may change after the ownership PR
@@ -694,6 +700,11 @@ public class ChangelogKeyedStateBackend<K>
         return keyedStateBackend.getDelegatedKeyedStateBackend(recursive);
     }
 
+    @Override
+    public void notifyCheckpointSubsumed(long checkpointId) throws Exception {
+        changelogTruncateHelper.checkpointSubsumed(checkpointId);
+    }
+
     public ChangelogRestoreTarget<K> getChangelogRestoreTarget() {
         return new ChangelogRestoreTarget<K>() {
             @Override
diff --git 
a/flink-state-backends/flink-statebackend-changelog/src/main/java/org/apache/flink/state/changelog/ChangelogTruncateHelper.java
 
b/flink-state-backends/flink-statebackend-changelog/src/main/java/org/apache/flink/state/changelog/ChangelogTruncateHelper.java
new file mode 100644
index 00000000000..6a6b71868b5
--- /dev/null
+++ 
b/flink-state-backends/flink-statebackend-changelog/src/main/java/org/apache/flink/state/changelog/ChangelogTruncateHelper.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.state.changelog;
+
+import org.apache.flink.runtime.state.changelog.SequenceNumber;
+import org.apache.flink.runtime.state.changelog.StateChangelogWriter;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.NavigableMap;
+import java.util.TreeMap;
+
+/**
+ * A helper class to track changelog usage by {@link 
ChangelogKeyedStateBackend}. A changelog
+ * segment is not used when:
+ *
+ * <ol>
+ *   <li>All the checkpoints using it were subsumed, and
+ *   <li>It was materialized (and therefore will not be used in future 
checkpoints)
+ * </ol>
+ *
+ * <p>At this point, {@link 
org.apache.flink.runtime.state.changelog.StateChangelogWriter} is
+ * notified and will {@link 
org.apache.flink.runtime.state.changelog.StateChangelogWriter#truncate
+ * truncate} the changelog. That implies discarding any (changelog) state that 
is not known to JM.
+ */
+class ChangelogTruncateHelper {
+    private static final Logger LOG = 
LoggerFactory.getLogger(ChangelogTruncateHelper.class);
+
+    private final StateChangelogWriter<?> stateChangelogWriter;
+    private final NavigableMap<Long, SequenceNumber> checkpointedUpTo = new 
TreeMap<>();
+
+    private SequenceNumber subsumedUpTo;
+
+    private SequenceNumber materializedUpTo;
+
+    ChangelogTruncateHelper(StateChangelogWriter<?> stateChangelogWriter) {
+        this.stateChangelogWriter = stateChangelogWriter;
+    }
+
+    /**
+     * Set the highest {@link SequenceNumber} of changelog used by the given 
checkpoint.
+     *
+     * @param lastUploadedTo exclusive
+     */
+    public void checkpoint(long checkpointId, SequenceNumber lastUploadedTo) {
+        checkpointedUpTo.put(checkpointId, lastUploadedTo);
+    }
+
+    /** Handle checkpoint subsumption, potentially {@link #truncate() 
truncating} the changelog. */
+    public void checkpointSubsumed(long checkpointId) {
+        SequenceNumber sqn = checkpointedUpTo.get(checkpointId);
+        LOG.debug("checkpoint {} subsumed, max sqn: {}", checkpointId, sqn);
+        if (sqn != null) {
+            subsumedUpTo = sqn;
+            checkpointedUpTo.headMap(checkpointId, true).clear();
+            truncate();
+        }
+    }
+
+    /**
+     * Handle changelog materialization, potentially {@link #truncate() 
truncating} the changelog.
+     *
+     * @param upTo exclusive
+     */
+    public void materialized(SequenceNumber upTo) {
+        materializedUpTo = upTo;
+        truncate();
+    }
+
+    private void truncate() {
+        if (subsumedUpTo != null && materializedUpTo != null) {
+            SequenceNumber to =
+                    subsumedUpTo.compareTo(materializedUpTo) < 0 ? 
subsumedUpTo : materializedUpTo;
+            LOG.debug(
+                    "truncate changelog to {} (subsumed up to: {}, 
materialized up to: {})",
+                    to,
+                    subsumedUpTo,
+                    materializedUpTo);
+            stateChangelogWriter.truncate(to);
+        }
+    }
+}
diff --git 
a/flink-state-backends/flink-statebackend-changelog/src/test/java/org/apache/flink/state/changelog/ChangelogStateBackendTestUtils.java
 
b/flink-state-backends/flink-statebackend-changelog/src/test/java/org/apache/flink/state/changelog/ChangelogStateBackendTestUtils.java
index 942a22982be..1113d3adf4d 100644
--- 
a/flink-state-backends/flink-statebackend-changelog/src/test/java/org/apache/flink/state/changelog/ChangelogStateBackendTestUtils.java
+++ 
b/flink-state-backends/flink-statebackend-changelog/src/test/java/org/apache/flink/state/changelog/ChangelogStateBackendTestUtils.java
@@ -28,6 +28,7 @@ import 
org.apache.flink.api.common.typeutils.base.IntSerializer;
 import org.apache.flink.api.java.typeutils.GenericTypeInfo;
 import org.apache.flink.changelog.fs.ChangelogStorageMetricGroup;
 import org.apache.flink.changelog.fs.FsStateChangelogStorage;
+import org.apache.flink.changelog.fs.TaskChangelogRegistry;
 import org.apache.flink.core.fs.CloseableRegistry;
 import org.apache.flink.core.fs.Path;
 import org.apache.flink.metrics.groups.UnregisteredMetricsGroup;
@@ -53,9 +54,6 @@ import org.apache.flink.runtime.state.StateBackendTestBase;
 import org.apache.flink.runtime.state.TestTaskStateManager;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
-import org.apache.flink.runtime.state.changelog.ChangelogStateHandle;
-import org.apache.flink.runtime.state.changelog.SequenceNumber;
-import org.apache.flink.runtime.state.changelog.StateChangelogWriter;
 import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
 import org.apache.flink.runtime.testutils.statemigration.TestType;
 import org.apache.flink.util.CloseableIterator;
@@ -140,7 +138,8 @@ public class ChangelogStateBackendTestUtils {
                                 1024,
                                 new ChangelogStorageMetricGroup(
                                         UnregisteredMetricGroups
-                                                
.createUnregisteredTaskManagerJobMetricGroup())))
+                                                
.createUnregisteredTaskManagerJobMetricGroup()),
+                                TaskChangelogRegistry.NO_OP))
                 .build();
     }
 
@@ -178,7 +177,7 @@ public class ChangelogStateBackendTestUtils {
             keyedBackend.setCurrentKey(2);
             state.update(new StateBackendTestBase.TestPojo("u2", 2));
 
-            materialize(keyedBackend, periodicMaterializationManager);
+            periodicMaterializationManager.triggerMaterialization();
 
             keyedBackend.setCurrentKey(2);
             state.update(new StateBackendTestBase.TestPojo("u2", 22));
@@ -186,7 +185,7 @@ public class ChangelogStateBackendTestUtils {
             keyedBackend.setCurrentKey(3);
             state.update(new StateBackendTestBase.TestPojo("u3", 3));
 
-            materialize(keyedBackend, periodicMaterializationManager);
+            periodicMaterializationManager.triggerMaterialization();
 
             keyedBackend.setCurrentKey(4);
             state.update(new StateBackendTestBase.TestPojo("u4", 4));
@@ -238,25 +237,6 @@ public class ChangelogStateBackendTestUtils {
         }
     }
 
-    /**
-     * Explicitly trigger materialization. Materialization is expected to 
complete before returning
-     * from this method by the use of direct executor when constructing 
materializer.
-     * Automatic/periodic triggering is disabled by NOT starting the 
periodicMaterializationManager.
-     *
-     * <p>Additionally, verify changelog truncation happened upon completion.
-     */
-    private static void materialize(
-            ChangelogKeyedStateBackend<Integer> keyedBackend,
-            PeriodicMaterializationManager periodicMaterializationManager) {
-        StateChangelogWriter<? extends ChangelogStateHandle> writer =
-                keyedBackend.getChangelogWriter();
-        SequenceNumber sqn = writer.nextSequenceNumber();
-        periodicMaterializationManager.triggerMaterialization();
-        assertTrue(
-                "Materialization didn't truncate the changelog",
-                sqn.compareTo(writer.getLowestSequenceNumber()) <= 0);
-    }
-
     public static void testMaterializedRestoreForPriorityQueue(
             StateBackend stateBackend, Environment env, 
CheckpointStreamFactory streamFactory)
             throws Exception {
@@ -289,12 +269,12 @@ public class ChangelogStateBackendTestUtils {
 
             assertThat(actualList, containsInAnyOrder(elementA100, elementA10, 
elementA20));
 
-            materialize(keyedBackend, periodicMaterializationManager);
+            periodicMaterializationManager.triggerMaterialization();
 
             TestType elementB9 = new TestType("b", 9);
             assertTrue(priorityQueue.add(elementB9));
 
-            materialize(keyedBackend, periodicMaterializationManager);
+            periodicMaterializationManager.triggerMaterialization();
 
             TestType elementC9 = new TestType("c", 9);
             TestType elementC8 = new TestType("c", 8);
diff --git 
a/flink-state-backends/flink-statebackend-changelog/src/test/java/org/apache/flink/state/changelog/ChangelogStateDiscardTest.java
 
b/flink-state-backends/flink-statebackend-changelog/src/test/java/org/apache/flink/state/changelog/ChangelogStateDiscardTest.java
new file mode 100644
index 00000000000..01d0ec71702
--- /dev/null
+++ 
b/flink-state-backends/flink-statebackend-changelog/src/test/java/org/apache/flink/state/changelog/ChangelogStateDiscardTest.java
@@ -0,0 +1,452 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.state.changelog;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeutils.base.StringSerializer;
+import org.apache.flink.changelog.fs.FsStateChangelogStorage;
+import org.apache.flink.changelog.fs.StateChangeSet;
+import org.apache.flink.changelog.fs.StateChangeUploadScheduler;
+import org.apache.flink.changelog.fs.StateChangeUploadScheduler.UploadTask;
+import org.apache.flink.changelog.fs.StateChangeUploader;
+import org.apache.flink.changelog.fs.TaskChangelogRegistry;
+import org.apache.flink.changelog.fs.UploadResult;
+import org.apache.flink.core.fs.CloseableRegistry;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.mailbox.SyncMailboxExecutor;
+import org.apache.flink.runtime.query.KvStateRegistry;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
+import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.LocalRecoveryConfig;
+import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.TestingStreamStateHandle;
+import org.apache.flink.runtime.state.UncompressedStreamCompressionDecorator;
+import org.apache.flink.runtime.state.changelog.StateChangelogStorage;
+import org.apache.flink.runtime.state.changelog.StateChangelogWriter;
+import org.apache.flink.runtime.state.heap.HeapKeyedStateBackendBuilder;
+import org.apache.flink.runtime.state.heap.HeapPriorityQueueSetFactory;
+import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;
+import 
org.apache.flink.runtime.state.memory.MemoryBackendCheckpointStorageAccess;
+import org.apache.flink.runtime.state.metrics.LatencyTrackingStateConfig;
+import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
+import org.apache.flink.util.function.BiConsumerWithException;
+import org.apache.flink.util.function.TriConsumerWithException;
+
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Function;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+
+import static java.util.Collections.emptyList;
+import static java.util.Collections.singletonList;
+import static java.util.function.Function.identity;
+import static java.util.stream.Collectors.toList;
+import static java.util.stream.Collectors.toMap;
+import static java.util.stream.Collectors.toSet;
+import static 
org.apache.flink.changelog.fs.StateChangeUploadScheduler.directScheduler;
+import static org.apache.flink.runtime.checkpoint.CheckpointType.CHECKPOINT;
+import static org.apache.flink.runtime.state.SnapshotResult.empty;
+import static org.apache.flink.util.Preconditions.checkState;
+import static org.apache.flink.util.concurrent.Executors.directExecutor;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Verifies that any unused state created by {@link ChangelogStateBackend} is 
discarded. This is
+ * achieved by testing integration between {@link ChangelogKeyedStateBackend} 
and {@link
+ * StateChangelogWriter} created by {@link FsStateChangelogStorage}.
+ */
+public class ChangelogStateDiscardTest {
+    private static final Random RANDOM = new Random();
+
+    /**
+     * Verifies that pre-emptively uploaded state is still retained after the checkpoint taken on
+     * top of it is subsumed, and is discarded only after the subsequent materialization.
+     */
+    @Test
+    public void testPreEmptiveUploadDiscardedOnMaterialization() throws 
Exception {
+        singleBackendTest(
+                (backend, writer, uploader) -> {
+                    changeAndLogRandomState(backend, uploader.results::size);
+                    checkpoint(backend, 1L);
+                    backend.notifyCheckpointSubsumed(1L);
+                    assertRetained(uploader.results); // may still be used for 
future checkpoints
+                    materialize(backend, writer);
+                    assertDiscarded(uploader.results);
+                });
+    }
+
+    /**
+     * Verifies that pre-emptively uploaded state which was materialized before checkpoint 1 is
+     * retained while that checkpoint is not subsumed, and is discarded once the backend is
+     * notified about the subsumption.
+     */
+    @Test
+    public void testPreEmptiveUploadDiscardedOnSubsumption() throws Exception {
+        singleBackendTest(
+                (backend, writer, uploader) -> {
+                    changeAndLogRandomState(backend, uploader.results::size);
+                    materialize(backend, writer);
+                    checkpoint(backend, 1L);
+                    assertRetained(uploader.results); // may be used in 
non-subsumed checkpoints
+                    backend.notifyCheckpointSubsumed(1L);
+                    assertDiscarded(uploader.results);
+                });
+    }
+
+    /**
+     * Verifies that pre-emptively uploaded state is retained after a checkpoint followed by a
+     * materialization as long as no checkpoint subsumption notification was received.
+     */
+    @Test
+    public void testPreEmptiveUploadNotDiscardedWithoutNotification() throws 
Exception {
+        singleBackendTest(
+                (backend, writer, uploader) -> {
+                    changeAndLogRandomState(backend, uploader.results::size);
+                    checkpoint(backend, 1L);
+                    materialize(backend, writer);
+                    assertRetained(uploader.results);
+                });
+    }
+
+    /**
+     * Verifies that state is discarded when its upload completes only after the changelog was
+     * already truncated (materialization, checkpoint and subsumption of that checkpoint).
+     *
+     * <p>Uses a {@link TestingUploadScheduler} so that uploads stay pending until they are
+     * completed explicitly.
+     */
+    @Test
+    public void 
testPreEmptiveUploadDiscardedOnMaterializationIfCompletedLater() throws 
Exception {
+        final TaskChangelogRegistry registry =
+                
TaskChangelogRegistry.defaultChangelogRegistry(directExecutor());
+        final TestingUploadScheduler scheduler = new 
TestingUploadScheduler(registry);
+        singleBackendTest(
+                new FsStateChangelogStorage(scheduler, 0L, registry),
+                (backend, writer) -> {
+                    changeAndLogRandomState(backend, scheduler.uploads::size);
+                    truncate(writer, backend);
+
+                    // sanity check: truncation happened while all uploads were still pending
+                    
checkState(scheduler.uploads.stream().noneMatch(UploadTask::isFinished));
+                    List<UploadResult> results =
+                            
scheduler.completeUploads(ChangelogStateDiscardTest::uploadResult);
+
+                    assertDiscarded(
+                            results.stream()
+                                    .map(h -> (TestingStreamStateHandle) 
h.getStreamStateHandle())
+                                    .collect(toList()));
+                });
+    }
+
+    /**
+     * Verifies what happens to uploaded state once the test case has finished, i.e. after {@code
+     * singleBackendTest} has closed the backend and the storage: handles uploaded before checkpoint
+     * 1 must be retained, handles uploaded after it must be discarded.
+     */
+    @Test
+    public void testPreEmptiveUploadDiscardedOnClose() throws Exception {
+        final List<TestingStreamStateHandle> afterCheckpoint = new 
ArrayList<>();
+        final List<TestingStreamStateHandle> beforeCheckpoint = new 
ArrayList<>();
+        singleBackendTest(
+                (backend, writer, uploader) -> {
+                    changeAndLogRandomState(backend, uploader.results::size);
+                    uploader.drainResultsTo(beforeCheckpoint);
+                    checkpoint(backend, 1L);
+                    changeAndLogRandomState(backend, uploader.results::size);
+                    uploader.drainResultsTo(afterCheckpoint);
+                });
+        // assertions run after the backend and the storage were closed by singleBackendTest
+        assertRetained(beforeCheckpoint);
+        assertDiscarded(afterCheckpoint);
+    }
+
+    /**
+     * Tests that an upload shared by several backends is discarded only once none of the backends
+     * for which it was initiated uses it anymore.
+     *
+     * <p>Scenario:
+     *
+     * <ol>
+     *   <li>Two backends start pre-emptive uploads, both go into the same file (StreamStateHandle)
+     *   <li>First backend materializes (starts and finishes)
+     *   <li>State should not be discarded
+     *   <li>Second backend materializes (starts and finishes)
+     *   <li>State should be discarded
+     * </ol>
+     */
+    @Test
+    public void testPreEmptiveUploadForMultipleBackends() throws Exception {
+        // using the same range (rescaling not involved)
+        final KeyGroupRange kgRange = KeyGroupRange.of(0, 10);
+        final JobID jobId = new JobID();
+        final ExecutionConfig cfg = new ExecutionConfig();
+
+        final TaskChangelogRegistry registry =
+                TaskChangelogRegistry.defaultChangelogRegistry(directExecutor());
+        final TestingUploadScheduler scheduler = new TestingUploadScheduler(registry);
+        final StateChangelogStorage<?> storage =
+                new FsStateChangelogStorage(scheduler, 0, registry);
+        try {
+            final StateChangelogWriter<?> w1 =
+                    storage.createWriter("test-operator-1", kgRange, new SyncMailboxExecutor());
+            final StateChangelogWriter<?> w2 =
+                    storage.createWriter("test-operator-2", kgRange, new SyncMailboxExecutor());
+
+            try (ChangelogKeyedStateBackend<String> b1 = backend(jobId, kgRange, cfg, w1);
+                    ChangelogKeyedStateBackend<String> b2 = backend(jobId, kgRange, cfg, w2)) {
+
+                changeAndLogRandomState(b1, scheduler.uploads::size);
+                changeAndLogRandomState(b2, scheduler.uploads::size);
+
+                // emulate sharing the same file
+                final TestingStreamStateHandle handle = new TestingStreamStateHandle();
+                scheduler.completeUploads(task -> uploadResult(task, () -> handle));
+
+                truncate(w1, b1);
+                assertRetained(singletonList(handle));
+
+                // every backend has to be truncated using its own writer (w2 belongs to b2)
+                truncate(w2, b2);
+                assertDiscarded(singletonList(handle));
+            }
+        } finally {
+            // same cleanup as in singleBackendTest: the storage is closed after the backends
+            storage.close();
+        }
+    }
+
+    private void singleBackendTest(
+            TriConsumerWithException<
+                            ChangelogKeyedStateBackend<String>,
+                            StateChangelogWriter<?>,
+                            TestingUploader,
+                            Exception>
+                    testCase)
+            throws Exception {
+        TaskChangelogRegistry registry =
+                
TaskChangelogRegistry.defaultChangelogRegistry(directExecutor());
+        TestingUploader uploader = new TestingUploader(registry);
+        long preEmptivePersistThresholdInBytes = 0L; // flush ASAP
+        singleBackendTest(
+                new FsStateChangelogStorage(
+                        directScheduler(uploader), 
preEmptivePersistThresholdInBytes, registry),
+                (backend, writer) -> testCase.accept(backend, writer, 
uploader));
+    }
+
+    /**
+     * Runs the given test case against a single backend whose writer is created by the given
+     * storage. Provided storage will be closed.
+     */
+    private void singleBackendTest(
+            StateChangelogStorage<?> storage,
+            BiConsumerWithException<
+                            ChangelogKeyedStateBackend<String>, StateChangelogWriter<?>, Exception>
+                    testCase)
+            throws Exception {
+        final JobID job = new JobID();
+        final KeyGroupRange keyGroups = KeyGroupRange.of(0, 10);
+        final ExecutionConfig executionConfig = new ExecutionConfig();
+        final StateChangelogWriter<?> changelogWriter =
+                storage.createWriter("test-operator", keyGroups, new SyncMailboxExecutor());
+        // the backend (resource) is closed first, the storage afterwards in the finally clause
+        try (ChangelogKeyedStateBackend<String> keyedBackend =
+                backend(job, keyGroups, executionConfig, changelogWriter)) {
+            testCase.accept(keyedBackend, changelogWriter);
+        } finally {
+            storage.close();
+        }
+    }
+
+    /**
+     * Creates a {@link ChangelogKeyedStateBackend} that wraps a keyed state backend built with
+     * {@link HeapKeyedStateBackendBuilder} and logs state changes to the given writer.
+     *
+     * @param jobId job ID used for the task KvState registry and the checkpoint storage access
+     * @param kgRange key group range of the backend; its size is also used as the number of key
+     *     groups
+     * @param executionConfig execution config passed to both the nested and the changelog backend
+     * @param writer changelog writer to be used by the created backend
+     * @return the created backend
+     * @throws IOException declared by the signature (presumably raised while building the nested
+     *     backend or the storage access; not visible here)
+     */
+    private static ChangelogKeyedStateBackend<String> backend(
+            JobID jobId,
+            KeyGroupRange kgRange,
+            ExecutionConfig executionConfig,
+            StateChangelogWriter<?> writer)
+            throws IOException {
+        AbstractKeyedStateBackend<String> nestedBackend =
+                new HeapKeyedStateBackendBuilder<>(
+                                new 
KvStateRegistry().createTaskRegistry(jobId, new JobVertexID()),
+                                StringSerializer.INSTANCE,
+                                StringSerializer.class.getClassLoader(),
+                                kgRange.getNumberOfKeyGroups(),
+                                kgRange,
+                                executionConfig,
+                                TtlTimeProvider.DEFAULT,
+                                LatencyTrackingStateConfig.disabled(),
+                                emptyList(),
+                                
UncompressedStreamCompressionDecorator.INSTANCE,
+                                new LocalRecoveryConfig(null),
+                                new HeapPriorityQueueSetFactory(
+                                        kgRange, 
kgRange.getNumberOfKeyGroups(), 128),
+                                true,
+                                new CloseableRegistry())
+                        .build();
+        return new ChangelogKeyedStateBackend<>(
+                nestedBackend,
+                "test-subtask",
+                executionConfig,
+                TtlTimeProvider.DEFAULT,
+                writer,
+                emptyList(),
+                new MemoryBackendCheckpointStorageAccess(
+                        jobId, null, null, 1 /* don't expect any 
materialization */));
+    }
+
+    /**
+     * Generates a random string of 10 characters.
+     *
+     * <p>The random bytes are decoded as ISO-8859-1, which maps every byte to exactly one
+     * character. Decoding with the platform default charset would make the result depend on the
+     * environment and (e.g. with UTF-8) replace most of the random bytes with the replacement
+     * character.
+     */
+    private static String randomString() {
+        byte[] bytes = new byte[10];
+        RANDOM.nextBytes(bytes);
+        // fully-qualified to keep the block self-contained (no additional import required)
+        return new String(bytes, java.nio.charset.StandardCharsets.ISO_8859_1);
+    }
+
+    private static void changeAndLogRandomState(
+            ChangelogKeyedStateBackend<String> backend, Supplier<Integer> 
changelogLength)
+            throws Exception {
+        for (int numExistingResults = changelogLength.get();
+                changelogLength.get() == numExistingResults; ) {
+            changeAndRandomState(backend);
+        }
+    }
+
+    private static void 
changeAndRandomState(ChangelogKeyedStateBackend<String> backend)
+            throws Exception {
+        backend.setCurrentKey(randomString());
+        backend.getPartitionedState(
+                        "ns",
+                        StringSerializer.INSTANCE,
+                        new ValueStateDescriptor<>(randomString(), 
String.class))
+                .update(randomString());
+    }
+
+    private static List<UploadResult> uploadResult(UploadTask upload) {
+        return uploadResult(upload, TestingStreamStateHandle::new);
+    }
+
+    private static List<UploadResult> uploadResult(
+            UploadTask upload, Supplier<StreamStateHandle> handleSupplier) {
+        return upload.getChangeSets().stream()
+                .map(
+                        changes ->
+                                new UploadResult(
+                                        handleSupplier.get(),
+                                        0L, // offset
+                                        changes.getSequenceNumber(),
+                                        changes.getSize()))
+                .collect(toList());
+    }
+
+    private static void assertRetained(List<TestingStreamStateHandle> 
toRetain) {
+        assertTrue(
+                "Some state handles were discarded: \n" + toRetain,
+                
toRetain.stream().noneMatch(TestingStreamStateHandle::isDisposed));
+    }
+
+    private static void assertDiscarded(List<TestingStreamStateHandle> 
toDiscard) {
+        assertTrue(
+                "Not all state handles were discarded: \n" + toDiscard,
+                
toDiscard.stream().allMatch(TestingStreamStateHandle::isDisposed));
+    }
+
+    private static void checkpoint(ChangelogKeyedStateBackend<String> backend, 
long checkpointId)
+            throws Exception {
+        backend.snapshot(
+                checkpointId,
+                1L,
+                new MemCheckpointStreamFactory(1000),
+                CheckpointOptions.unaligned(
+                        CHECKPOINT, 
CheckpointStorageLocationReference.getDefault()));
+    }
+
+    /**
+     * A {@link StateChangeUploader} that answers every upload with a newly created {@link
+     * TestingStreamStateHandle}. The usage of that handle is tracked with the {@link #registry}.
+     */
+    private static class TestingUploader implements StateChangeUploader {
+        private final List<TestingStreamStateHandle> results = new ArrayList<>();
+        private final TaskChangelogRegistry registry;
+
+        public TestingUploader(TaskChangelogRegistry registry) {
+            this.registry = registry;
+        }
+
+        @Override
+        public UploadTasksResult upload(Collection<UploadTask> tasks) throws IOException {
+            final TestingStreamStateHandle uploaded = new TestingStreamStateHandle();
+            results.add(uploaded);
+            // todo: avoid making StateChangeSet and its internals public?
+            // todo: make the contract more explicit or extract common code
+            final Map<UploadTask, Map<StateChangeSet, Long>> offsetsByTask =
+                    tasks.stream().collect(toMap(task -> task, task -> mapOffsets(task)));
+            for (UploadTask task : tasks) {
+                startTracking(registry, uploaded, task);
+            }
+            return new UploadTasksResult(offsetsByTask, uploaded);
+        }
+
+        /** Maps every change set of the task to offset zero. */
+        private Map<StateChangeSet, Long> mapOffsets(UploadTask task) {
+            return task.getChangeSets().stream()
+                    .collect(toMap(changeSet -> changeSet, ignored -> 0L));
+        }
+
+        @Override
+        public void close() throws Exception {}
+
+        /** Moves all handles collected so far into the given list. */
+        private void drainResultsTo(List<TestingStreamStateHandle> toDiscard) {
+            toDiscard.addAll(results);
+            results.clear();
+        }
+    }
+
+    /**
+     * An upload scheduler that only collects the scheduled upload tasks; they are completed on
+     * demand via {@link #completeUploads(Function)}. State handles used for completion are tracked
+     * by {@link #registry}.
+     */
+    private static class TestingUploadScheduler implements StateChangeUploadScheduler {
+        private final List<UploadTask> uploads = new ArrayList<>();
+        private final TaskChangelogRegistry registry;
+
+        private TestingUploadScheduler(TaskChangelogRegistry registry) {
+            this.registry = registry;
+        }
+
+        @Override
+        public void upload(UploadTask uploadTask) throws IOException {
+            uploads.add(uploadTask);
+        }
+
+        /**
+         * Completes all accumulated tasks with the results obtained from the provider, registers
+         * the state handles of these results with the {@link #registry} and forgets the tasks.
+         *
+         * @param resultsProvider produces the upload results for a task
+         * @return upload results of all completed tasks
+         */
+        public List<UploadResult> completeUploads(
+                Function<UploadTask, List<UploadResult>> resultsProvider) {
+            final List<UploadResult> completed = new ArrayList<>();
+            uploads.forEach(
+                    pending -> {
+                        final List<UploadResult> taskResults = resultsProvider.apply(pending);
+                        taskResults.forEach(
+                                taskResult ->
+                                        startTracking(
+                                                registry,
+                                                taskResult.getStreamStateHandle(),
+                                                pending));
+                        completed.addAll(taskResults);
+                        pending.complete(taskResults);
+                        checkState(pending.isFinished());
+                    });
+            uploads.clear();
+            return completed;
+        }
+
+        @Override
+        public void close() {}
+    }
+
+    /**
+     * Starts tracking the handle in the registry for the set of log IDs collected from the change
+     * sets of the upload.
+     */
+    private static void startTracking(
+            TaskChangelogRegistry registry,
+            StreamStateHandle streamStateHandle,
+            UploadTask upload) {
+        registry.startTracking(
+                streamStateHandle,
+                upload.getChangeSets().stream()
+                        .map(changeSet -> changeSet.getLogId())
+                        .collect(Collectors.toSet()));
+    }
+
+    /**
+     * Emulates a finished materialization by handing an empty snapshot result and the writer's next
+     * sequence number to {@code updateChangelogSnapshotState}.
+     *
+     * <p>NOTE(review): the second argument (0L) is presumably the materialization ID - confirm
+     * against {@code ChangelogKeyedStateBackend}.
+     */
+    private static void materialize(
+            ChangelogKeyedStateBackend<String> backend, 
StateChangelogWriter<?> writer) {
+        backend.updateChangelogSnapshotState(empty(), 0L, 
writer.nextSequenceNumber());
+    }
+
+    /**
+     * Performs the sequence used by the tests to get the changelog truncated: materializes the
+     * backend, takes checkpoint 1 and notifies the backend that checkpoint 1 was subsumed.
+     *
+     * @param writer writer whose next sequence number is used for the materialization; expected to
+     *     be the writer of the given backend
+     * @param backend backend to materialize and checkpoint
+     */
+    private static void truncate(
+            StateChangelogWriter<?> writer, ChangelogKeyedStateBackend<String> 
backend)
+            throws Exception {
+        materialize(backend, writer);
+        checkpoint(backend, 1L);
+        backend.notifyCheckpointSubsumed(1L);
+    }
+}
diff --git 
a/flink-state-backends/flink-statebackend-changelog/src/test/java/org/apache/flink/state/changelog/StateChangeLoggerTestBase.java
 
b/flink-state-backends/flink-statebackend-changelog/src/test/java/org/apache/flink/state/changelog/StateChangeLoggerTestBase.java
index b23b9d547bf..70604262496 100644
--- 
a/flink-state-backends/flink-statebackend-changelog/src/test/java/org/apache/flink/state/changelog/StateChangeLoggerTestBase.java
+++ 
b/flink-state-backends/flink-statebackend-changelog/src/test/java/org/apache/flink/state/changelog/StateChangeLoggerTestBase.java
@@ -134,11 +134,9 @@ abstract class StateChangeLoggerTestBase<Namespace> {
         public void reset(SequenceNumber from, SequenceNumber to) {}
 
         @Override
-        public void close() {}
+        public void truncateAndClose(SequenceNumber from) {}
 
         @Override
-        public SequenceNumber getLowestSequenceNumber() {
-            return initialSequenceNumber();
-        }
+        public void close() {}
     }
 }

Reply via email to