This is an automated email from the ASF dual-hosted git repository.

sewen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new faf394d  [FLINK-15014][state-processor-api] Refactor KeyedStateInputFormat to support multiple types of user functions
faf394d is described below

commit faf394d8a03b658d3a867d306ed8bef3da1ed478
Author: Seth Wiesman <sjwies...@gmail.com>
AuthorDate: Wed Nov 6 10:26:42 2019 -0600

    [FLINK-15014][state-processor-api] Refactor KeyedStateInputFormat to support multiple types of user functions
    
    This closes #10382
---
 .../apache/flink/state/api/ExistingSavepoint.java  |   7 +-
 .../state/api/input/KeyedStateInputFormat.java     | 166 +++-----------------
 .../state/api/input/MultiStateKeyIterator.java     |  10 +-
 .../input/operator/KeyedStateReaderOperator.java   | 167 +++++++++++++++++++++
 .../api/input/operator/StateReaderOperator.java    | 132 ++++++++++++++++
 .../MemoryStateBackendReaderKeyedStateITCase.java  |  32 ++++
 .../RocksDBStateBackendReaderKeyedStateITCase.java |  33 ++++
 .../state/api/SavepointReaderKeyedStateITCase.java | 149 ++++--------------
 .../state/api/input/KeyedStateInputFormatTest.java |  19 +--
 .../flink/state/api/utils/SavepointTestBase.java   |  91 +++++++++++
 .../flink/state/api/utils/WaitingSource.java       | 123 +++++++++++++++
 11 files changed, 644 insertions(+), 285 deletions(-)

diff --git 
a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/ExistingSavepoint.java
 
b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/ExistingSavepoint.java
index f336ce1..6c91660 100644
--- 
a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/ExistingSavepoint.java
+++ 
b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/ExistingSavepoint.java
@@ -32,11 +32,13 @@ import org.apache.flink.api.java.typeutils.TupleTypeInfo;
 import org.apache.flink.api.java.typeutils.TypeExtractor;
 import org.apache.flink.runtime.checkpoint.OperatorState;
 import org.apache.flink.runtime.state.StateBackend;
+import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.state.api.functions.KeyedStateReaderFunction;
 import org.apache.flink.state.api.input.BroadcastStateInputFormat;
 import org.apache.flink.state.api.input.KeyedStateInputFormat;
 import org.apache.flink.state.api.input.ListStateInputFormat;
 import org.apache.flink.state.api.input.UnionStateInputFormat;
+import org.apache.flink.state.api.input.operator.KeyedStateReaderOperator;
 import org.apache.flink.state.api.runtime.metadata.SavepointMetadata;
 import org.apache.flink.util.Preconditions;
 
@@ -275,11 +277,10 @@ public class ExistingSavepoint extends 
WritableSavepoint<ExistingSavepoint> {
                TypeInformation<OUT> outTypeInfo) throws IOException {
 
                OperatorState operatorState = metadata.getOperatorState(uid);
-               KeyedStateInputFormat<K, OUT> inputFormat = new 
KeyedStateInputFormat<>(
+               KeyedStateInputFormat<K, VoidNamespace, OUT> inputFormat = new 
KeyedStateInputFormat<>(
                        operatorState,
                        stateBackend,
-                       keyTypeInfo,
-                       function);
+                       new KeyedStateReaderOperator<>(function, keyTypeInfo));
 
                return env.createInput(inputFormat, outTypeInfo);
        }
diff --git 
a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/input/KeyedStateInputFormat.java
 
b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/input/KeyedStateInputFormat.java
index 454b74c..daf6278 100644
--- 
a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/input/KeyedStateInputFormat.java
+++ 
b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/input/KeyedStateInputFormat.java
@@ -19,17 +19,10 @@
 package org.apache.flink.state.api.input;
 
 import org.apache.flink.annotation.Internal;
-import org.apache.flink.api.common.functions.util.FunctionUtils;
 import org.apache.flink.api.common.io.DefaultInputSplitAssigner;
 import org.apache.flink.api.common.io.RichInputFormat;
 import org.apache.flink.api.common.io.statistics.BaseStatistics;
-import org.apache.flink.api.common.state.ListState;
-import org.apache.flink.api.common.state.ListStateDescriptor;
-import org.apache.flink.api.common.state.StateDescriptor;
-import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.api.common.typeinfo.Types;
-import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.api.common.typeutils.base.StringSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.core.fs.CloseableRegistry;
 import org.apache.flink.core.io.InputSplitAssigner;
@@ -41,34 +34,25 @@ import 
org.apache.flink.runtime.state.DefaultKeyedStateStore;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.StateBackend;
-import org.apache.flink.runtime.state.VoidNamespace;
-import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.state.api.functions.KeyedStateReaderFunction;
+import org.apache.flink.state.api.input.operator.StateReaderOperator;
 import org.apache.flink.state.api.input.splits.KeyGroupRangeInputSplit;
 import org.apache.flink.state.api.runtime.NeverFireProcessingTimeService;
 import org.apache.flink.state.api.runtime.SavepointEnvironment;
 import org.apache.flink.state.api.runtime.SavepointRuntimeContext;
-import org.apache.flink.state.api.runtime.VoidTriggerable;
 import org.apache.flink.streaming.api.operators.InternalTimeServiceManager;
-import org.apache.flink.streaming.api.operators.InternalTimerService;
-import org.apache.flink.streaming.api.operators.KeyContext;
 import org.apache.flink.streaming.api.operators.StreamOperatorStateContext;
 import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer;
 import org.apache.flink.streaming.api.operators.StreamTaskStateInitializerImpl;
-import org.apache.flink.streaming.api.operators.TimerSerializer;
 import org.apache.flink.util.CollectionUtil;
 import org.apache.flink.util.Preconditions;
 
 import javax.annotation.Nonnull;
 
 import java.io.IOException;
-import java.util.Collections;
 import java.util.Comparator;
 import java.util.Iterator;
 import java.util.List;
-import java.util.Set;
-import java.util.stream.Collectors;
-import java.util.stream.StreamSupport;
 
 /**
  * Input format for reading partitioned state.
@@ -77,54 +61,39 @@ import java.util.stream.StreamSupport;
  * @param <OUT> The type of the output of the {@link KeyedStateReaderFunction}.
  */
 @Internal
-public class KeyedStateInputFormat<K, OUT> extends RichInputFormat<OUT, 
KeyGroupRangeInputSplit> implements KeyContext {
+public class KeyedStateInputFormat<K, N, OUT> extends RichInputFormat<OUT, 
KeyGroupRangeInputSplit> {
 
        private static final long serialVersionUID = 8230460226049597182L;
 
-       private static final String USER_TIMERS_NAME = "user-timers";
-
        private final OperatorState operatorState;
 
        private final StateBackend stateBackend;
 
-       private final TypeInformation<K> keyType;
-
-       private final KeyedStateReaderFunction<K, OUT> userFunction;
-
-       private transient TypeSerializer<K> keySerializer;
+       private final StateReaderOperator<?, K, N, OUT> operator;
 
        private transient CloseableRegistry registry;
 
        private transient BufferingCollector<OUT> out;
 
-       private transient Iterator<K> keys;
-
-       private transient AbstractKeyedStateBackend<K> keyedStateBackend;
-
-       private transient Context ctx;
+       private transient Iterator<Tuple2<K, N>> keysAndNamespaces;
 
        /**
         * Creates an input format for reading partitioned state from an 
operator in a savepoint.
         *
         * @param operatorState The state to be queried.
         * @param stateBackend  The state backed used to snapshot the operator.
-        * @param keyType       The type information describing the key type.
-        * @param userFunction  The {@link KeyedStateReaderFunction} called for 
each key in the operator.
         */
        public KeyedStateInputFormat(
                OperatorState operatorState,
                StateBackend stateBackend,
-               TypeInformation<K> keyType,
-               KeyedStateReaderFunction<K, OUT> userFunction) {
+               StateReaderOperator<?, K, N, OUT> operator) {
                Preconditions.checkNotNull(operatorState, "The operator state 
cannot be null");
                Preconditions.checkNotNull(stateBackend, "The state backend 
cannot be null");
-               Preconditions.checkNotNull(keyType, "The key type information 
cannot be null");
-               Preconditions.checkNotNull(userFunction, "The userfunction 
cannot be null");
+               Preconditions.checkNotNull(operator, "The operator cannot be 
null");
 
                this.operatorState = operatorState;
                this.stateBackend = stateBackend;
-               this.keyType = keyType;
-               this.userFunction = userFunction;
+               this.operator = operator;
        }
 
        @Override
@@ -160,7 +129,6 @@ public class KeyedStateInputFormat<K, OUT> extends 
RichInputFormat<OUT, KeyGroup
        @Override
        public void openInputFormat() {
                out = new BufferingCollector<>();
-               keySerializer = 
keyType.createSerializer(getRuntimeContext().getExecutionConfig());
        }
 
        @Override
@@ -176,43 +144,21 @@ public class KeyedStateInputFormat<K, OUT> extends 
RichInputFormat<OUT, KeyGroup
 
                final StreamOperatorStateContext context = 
getStreamOperatorStateContext(environment);
 
-               keyedStateBackend = (AbstractKeyedStateBackend<K>) 
context.keyedStateBackend();
+               AbstractKeyedStateBackend<K> keyedStateBackend = 
(AbstractKeyedStateBackend<K>) context.keyedStateBackend();
 
                final DefaultKeyedStateStore keyedStateStore = new 
DefaultKeyedStateStore(keyedStateBackend, 
getRuntimeContext().getExecutionConfig());
                SavepointRuntimeContext ctx = new 
SavepointRuntimeContext(getRuntimeContext(), keyedStateStore);
-               FunctionUtils.setFunctionRuntimeContext(userFunction, ctx);
-
-               keys = getKeyIterator(ctx);
 
-               final InternalTimerService<VoidNamespace> timerService = 
restoreTimerService(context);
+               InternalTimeServiceManager<K> timeServiceManager = 
(InternalTimeServiceManager<K>) context.internalTimerServiceManager();
                try {
-                       this.ctx = new Context(keyedStateBackend, timerService);
+                       
operator.setup(getRuntimeContext().getExecutionConfig(), keyedStateBackend, 
timeServiceManager, ctx);
+                       operator.open();
+                       keysAndNamespaces = operator.getKeysAndNamespaces(ctx);
                } catch (Exception e) {
                        throw new IOException("Failed to restore timer state", 
e);
                }
        }
 
-       @SuppressWarnings("unchecked")
-       private InternalTimerService<VoidNamespace> 
restoreTimerService(StreamOperatorStateContext context) {
-               InternalTimeServiceManager<K> timeServiceManager = 
(InternalTimeServiceManager<K>) context.internalTimerServiceManager();
-               TimerSerializer<K, VoidNamespace> timerSerializer = new 
TimerSerializer<>(keySerializer, VoidNamespaceSerializer.INSTANCE);
-               return 
timeServiceManager.getInternalTimerService(USER_TIMERS_NAME, timerSerializer, 
VoidTriggerable.instance());
-       }
-
-       @SuppressWarnings("unchecked")
-       private Iterator<K> getKeyIterator(SavepointRuntimeContext ctx) throws 
IOException {
-               final List<StateDescriptor<?, ?>> stateDescriptors;
-               try  {
-                       FunctionUtils.openFunction(userFunction, new 
Configuration());
-                       ctx.disableStateRegistration();
-                       stateDescriptors = ctx.getStateDescriptors();
-               } catch (Exception e) {
-                       throw new IOException("Failed to open user defined 
function", e);
-               }
-
-               return new MultiStateKeyIterator<>(stateDescriptors, 
keyedStateBackend);
-       }
-
        private StreamOperatorStateContext 
getStreamOperatorStateContext(Environment environment) throws IOException {
                StreamTaskStateInitializer initializer = new 
StreamTaskStateInitializerImpl(
                        environment,
@@ -223,8 +169,8 @@ public class KeyedStateInputFormat<K, OUT> extends 
RichInputFormat<OUT, KeyGroup
                                operatorState.getOperatorID(),
                                operatorState.getOperatorID().toString(),
                                new NeverFireProcessingTimeService(),
-                               this,
-                               keySerializer,
+                               operator,
+                               
operator.getKeyType().createSerializer(environment.getExecutionConfig()),
                                registry,
                                getRuntimeContext().getMetricGroup());
                } catch (Exception e) {
@@ -239,7 +185,7 @@ public class KeyedStateInputFormat<K, OUT> extends 
RichInputFormat<OUT, KeyGroup
 
        @Override
        public boolean reachedEnd() {
-               return !out.hasNext() && !keys.hasNext();
+               return !out.hasNext() && !keysAndNamespaces.hasNext();
        }
 
        @Override
@@ -248,31 +194,20 @@ public class KeyedStateInputFormat<K, OUT> extends 
RichInputFormat<OUT, KeyGroup
                        return out.next();
                }
 
-               final K key = keys.next();
-               setCurrentKey(key);
+               final Tuple2<K, N> keyAndNamespace = keysAndNamespaces.next();
+               operator.setCurrentKey(keyAndNamespace.f0);
 
                try {
-                       userFunction.readKey(key, ctx, out);
+                       operator.processElement(keyAndNamespace.f0, 
keyAndNamespace.f1, out);
                } catch (Exception e) {
                        throw new IOException("User defined function 
KeyedStateReaderFunction#readKey threw an exception", e);
                }
 
-               keys.remove();
+               keysAndNamespaces.remove();
 
                return out.next();
        }
 
-       @Override
-       @SuppressWarnings("unchecked")
-       public void setCurrentKey(Object key) {
-               keyedStateBackend.setCurrentKey((K) key);
-       }
-
-       @Override
-       public Object getCurrentKey() {
-               return keyedStateBackend.getCurrentKey();
-       }
-
        private static KeyGroupRangeInputSplit createKeyGroupRangeInputSplit(
                OperatorState operatorState,
                int maxParallelism,
@@ -294,65 +229,4 @@ public class KeyedStateInputFormat<K, OUT> extends 
RichInputFormat<OUT, KeyGroup
                
keyGroups.sort(Comparator.comparing(KeyGroupRange::getStartKeyGroup));
                return keyGroups;
        }
-
-       private static class Context<K> implements 
KeyedStateReaderFunction.Context {
-
-               private static final String EVENT_TIMER_STATE = 
"event-time-timers";
-
-               private static final String PROC_TIMER_STATE = 
"proc-time-timers";
-
-               ListState<Long> eventTimers;
-
-               ListState<Long> procTimers;
-
-               private Context(AbstractKeyedStateBackend<K> keyedStateBackend, 
InternalTimerService<VoidNamespace> timerService) throws Exception {
-                       eventTimers = keyedStateBackend.getPartitionedState(
-                               USER_TIMERS_NAME,
-                               StringSerializer.INSTANCE,
-                               new ListStateDescriptor<>(EVENT_TIMER_STATE, 
Types.LONG)
-                       );
-
-                       timerService.forEachEventTimeTimer((namespace, timer) 
-> {
-                               if (namespace.equals(VoidNamespace.INSTANCE)) {
-                                       eventTimers.add(timer);
-                               }
-                       });
-
-                       procTimers = keyedStateBackend.getPartitionedState(
-                               USER_TIMERS_NAME,
-                               StringSerializer.INSTANCE,
-                               new ListStateDescriptor<>(PROC_TIMER_STATE, 
Types.LONG)
-                       );
-
-                       timerService.forEachProcessingTimeTimer((namespace, 
timer) -> {
-                               if (namespace.equals(VoidNamespace.INSTANCE)) {
-                                       procTimers.add(timer);
-                               }
-                       });
-               }
-
-               @Override
-               public Set<Long> registeredEventTimeTimers() throws Exception {
-                       Iterable<Long> timers = eventTimers.get();
-                       if (timers == null) {
-                               return Collections.emptySet();
-                       }
-
-                       return StreamSupport
-                               .stream(timers.spliterator(), false)
-                               .collect(Collectors.toSet());
-               }
-
-               @Override
-               public Set<Long> registeredProcessingTimeTimers() throws 
Exception {
-                       Iterable<Long> timers = procTimers.get();
-                       if (timers == null) {
-                               return Collections.emptySet();
-                       }
-
-                       return StreamSupport
-                               .stream(timers.spliterator(), false)
-                               .collect(Collectors.toSet());
-               }
-       }
 }
diff --git 
a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/input/MultiStateKeyIterator.java
 
b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/input/MultiStateKeyIterator.java
index bed0032..4af28f4 100644
--- 
a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/input/MultiStateKeyIterator.java
+++ 
b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/input/MultiStateKeyIterator.java
@@ -18,9 +18,10 @@
 
 package org.apache.flink.state.api.input;
 
+import org.apache.flink.annotation.Internal;
 import org.apache.flink.api.common.state.State;
 import org.apache.flink.api.common.state.StateDescriptor;
-import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
+import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.util.Preconditions;
@@ -35,16 +36,17 @@ import java.util.List;
  *
  * @param <K> Type of the key by which state is keyed.
  */
-final class MultiStateKeyIterator<K> implements Iterator<K> {
+@Internal
+public final class MultiStateKeyIterator<K> implements Iterator<K> {
        private final List<? extends StateDescriptor<?, ?>> descriptors;
 
-       private final AbstractKeyedStateBackend<K> backend;
+       private final KeyedStateBackend<K> backend;
 
        private final Iterator<K> internal;
 
        private K currentKey;
 
-       MultiStateKeyIterator(List<? extends StateDescriptor<?, ?>> 
descriptors, AbstractKeyedStateBackend<K> backend) {
+       public MultiStateKeyIterator(List<? extends StateDescriptor<?, ?>> 
descriptors, KeyedStateBackend<K> backend) {
                this.descriptors = Preconditions.checkNotNull(descriptors);
 
                this.backend = Preconditions.checkNotNull(backend);
diff --git 
a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/input/operator/KeyedStateReaderOperator.java
 
b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/input/operator/KeyedStateReaderOperator.java
new file mode 100644
index 0000000..1c99a5b
--- /dev/null
+++ 
b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/input/operator/KeyedStateReaderOperator.java
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.state.api.input.operator;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.StateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.common.typeutils.base.StringSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.runtime.state.KeyedStateBackend;
+import org.apache.flink.runtime.state.VoidNamespace;
+import org.apache.flink.runtime.state.VoidNamespaceSerializer;
+import org.apache.flink.state.api.functions.KeyedStateReaderFunction;
+import org.apache.flink.state.api.input.MultiStateKeyIterator;
+import org.apache.flink.state.api.runtime.SavepointRuntimeContext;
+import org.apache.flink.streaming.api.operators.InternalTimerService;
+import org.apache.flink.util.Collector;
+
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.StreamSupport;
+
+/**
+ * A {@link StateReaderOperator} for executing a {@link 
KeyedStateReaderFunction}.
+ *
+ * @param <KEY> The key type read from the state backend.
+ * @param <OUT> The output type of the function.
+ */
+@Internal
+public class KeyedStateReaderOperator<KEY, OUT>
+       extends StateReaderOperator<KeyedStateReaderFunction<KEY, OUT>, KEY, 
VoidNamespace, OUT> {
+
+       private static final String USER_TIMERS_NAME = "user-timers";
+
+       private transient Context<KEY> context;
+
+       public KeyedStateReaderOperator(KeyedStateReaderFunction<KEY, OUT> 
function, TypeInformation<KEY> keyType) {
+               super(function, keyType, VoidNamespaceSerializer.INSTANCE);
+       }
+
+       @Override
+       public void open() throws Exception {
+               super.open();
+
+               InternalTimerService<VoidNamespace> timerService = 
getInternalTimerService(USER_TIMERS_NAME);
+               context = new Context<>(getKeyedStateBackend(), timerService);
+       }
+
+       @Override
+       public void processElement(KEY key, VoidNamespace namespace, 
Collector<OUT> out) throws Exception {
+               function.readKey(key, context, out);
+       }
+
+       @Override
+       public Iterator<Tuple2<KEY, VoidNamespace>> 
getKeysAndNamespaces(SavepointRuntimeContext ctx) throws Exception {
+               ctx.disableStateRegistration();
+               List<StateDescriptor<?, ?>> stateDescriptors = 
ctx.getStateDescriptors();
+               Iterator<KEY> keys = new 
MultiStateKeyIterator<>(stateDescriptors, getKeyedStateBackend());
+               return new NamespaceDecorator<>(keys);
+       }
+
+       private static class Context<K> implements 
KeyedStateReaderFunction.Context {
+
+               private static final String EVENT_TIMER_STATE = 
"event-time-timers";
+
+               private static final String PROC_TIMER_STATE = 
"proc-time-timers";
+
+               ListState<Long> eventTimers;
+
+               ListState<Long> procTimers;
+
+               private Context(KeyedStateBackend<K> keyedStateBackend, 
InternalTimerService<VoidNamespace> timerService) throws Exception {
+                       eventTimers = keyedStateBackend.getPartitionedState(
+                               USER_TIMERS_NAME,
+                               StringSerializer.INSTANCE,
+                               new ListStateDescriptor<>(EVENT_TIMER_STATE, 
Types.LONG));
+
+                       timerService.forEachEventTimeTimer((namespace, timer) 
-> {
+                               if (namespace.equals(VoidNamespace.INSTANCE)) {
+                                       eventTimers.add(timer);
+                               }
+                       });
+
+                       procTimers = keyedStateBackend.getPartitionedState(
+                               USER_TIMERS_NAME,
+                               StringSerializer.INSTANCE,
+                               new ListStateDescriptor<>(PROC_TIMER_STATE, 
Types.LONG));
+
+                       timerService.forEachProcessingTimeTimer((namespace, 
timer) -> {
+                               if (namespace.equals(VoidNamespace.INSTANCE)) {
+                                       procTimers.add(timer);
+                               }
+                       });
+               }
+
+               @Override
+               public Set<Long> registeredEventTimeTimers() throws Exception {
+                       Iterable<Long> timers = eventTimers.get();
+                       if (timers == null) {
+                               return Collections.emptySet();
+                       }
+
+                       return StreamSupport
+                               .stream(timers.spliterator(), false)
+                               .collect(Collectors.toSet());
+               }
+
+               @Override
+               public Set<Long> registeredProcessingTimeTimers() throws 
Exception {
+                       Iterable<Long> timers = procTimers.get();
+                       if (timers == null) {
+                               return Collections.emptySet();
+                       }
+
+                       return StreamSupport
+                               .stream(timers.spliterator(), false)
+                               .collect(Collectors.toSet());
+               }
+       }
+
+       private static class NamespaceDecorator<KEY> implements 
Iterator<Tuple2<KEY, VoidNamespace>> {
+
+               private final Iterator<KEY> keys;
+
+               private NamespaceDecorator(Iterator<KEY> keys) {
+                       this.keys = keys;
+               }
+
+               @Override
+               public boolean hasNext() {
+                       return keys.hasNext();
+               }
+
+               @Override
+               public Tuple2<KEY, VoidNamespace> next() {
+                       KEY key = keys.next();
+                       return Tuple2.of(key, VoidNamespace.INSTANCE);
+               }
+
+               @Override
+               public void remove() {
+                       keys.remove();
+               }
+       }
+}
diff --git 
a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/input/operator/StateReaderOperator.java
 
b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/input/operator/StateReaderOperator.java
new file mode 100644
index 0000000..eabc2f8
--- /dev/null
+++ 
b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/input/operator/StateReaderOperator.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.state.api.input.operator;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.functions.Function;
+import org.apache.flink.api.common.functions.util.FunctionUtils;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.state.KeyedStateBackend;
+import org.apache.flink.state.api.runtime.SavepointRuntimeContext;
+import org.apache.flink.state.api.runtime.VoidTriggerable;
+import org.apache.flink.streaming.api.operators.InternalTimeServiceManager;
+import org.apache.flink.streaming.api.operators.InternalTimerService;
+import org.apache.flink.streaming.api.operators.KeyContext;
+import org.apache.flink.streaming.api.operators.TimerSerializer;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+/**
+ * Base class for executing functions that read keyed state.
+ *
+ * @param <F> The type of the user function.
+ * @param <KEY> The key type.
+ * @param <N> The namespace type.
+ * @param <OUT> The output type.
+ */
+@Internal
+public abstract class StateReaderOperator<F extends Function, KEY, N, OUT> 
implements KeyContext, Serializable {
+
+       private static final long serialVersionUID = 1L;
+
+       protected final F function;
+
+       private final TypeInformation<KEY> keyType;
+
+       protected final TypeSerializer<N> namespaceSerializer;
+
+       private transient ExecutionConfig executionConfig;
+
+       private transient KeyedStateBackend<KEY> keyedStateBackend;
+
+       private transient TypeSerializer<KEY> keySerializer;
+
+       private transient InternalTimeServiceManager<KEY> timerServiceManager;
+
+       protected StateReaderOperator(F function, TypeInformation<KEY> keyType, 
TypeSerializer<N> namespaceSerializer) {
+               Preconditions.checkNotNull(function, "The user function must 
not be null");
+               Preconditions.checkNotNull(keyType, "The key type must not be 
null");
+               Preconditions.checkNotNull(namespaceSerializer, "The namespace 
serializer must not be null");
+
+               this.function = function;
+               this.keyType = keyType;
+               this.namespaceSerializer = namespaceSerializer;
+       }
+
+       public abstract void processElement(KEY key, N namespace, 
Collector<OUT> out) throws Exception;
+
+       public abstract Iterator<Tuple2<KEY, N>> 
getKeysAndNamespaces(SavepointRuntimeContext ctx) throws Exception;
+
+       public final void setup(
+               ExecutionConfig executionConfig,
+               KeyedStateBackend<KEY> keyKeyedStateBackend,
+               InternalTimeServiceManager<KEY> timerServiceManager,
+               SavepointRuntimeContext ctx) {
+
+               this.executionConfig = executionConfig;
+               this.keyedStateBackend = keyKeyedStateBackend;
+               this.timerServiceManager = timerServiceManager;
+               this.keySerializer = keyType.createSerializer(executionConfig);
+
+               FunctionUtils.setFunctionRuntimeContext(function, ctx);
+       }
+
+       protected final InternalTimerService<N> getInternalTimerService(String 
name) {
+               TimerSerializer<KEY, N> timerSerializer = new 
TimerSerializer<>(keySerializer, namespaceSerializer);
+               return timerServiceManager.getInternalTimerService(name, 
timerSerializer, VoidTriggerable.instance());
+       }
+
+       public void open() throws Exception {
+               FunctionUtils.openFunction(function, new Configuration());
+       }
+
+       public void close() throws Exception {
+               FunctionUtils.closeFunction(function);
+       }
+
+       @Override
+       @SuppressWarnings("unchecked")
+       public final void setCurrentKey(Object key) {
+               keyedStateBackend.setCurrentKey((KEY) key);
+       }
+
+       @Override
+       public final Object getCurrentKey() {
+               return keyedStateBackend.getCurrentKey();
+       }
+
+       public final KeyedStateBackend<KEY> getKeyedStateBackend() {
+               return keyedStateBackend;
+       }
+
+       public final TypeInformation<KEY> getKeyType() {
+               return keyType;
+       }
+
+       public final ExecutionConfig getExecutionConfig() {
+               return this.executionConfig;
+       }
+}
diff --git 
a/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/MemoryStateBackendReaderKeyedStateITCase.java
 
b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/MemoryStateBackendReaderKeyedStateITCase.java
new file mode 100644
index 0000000..61d446a
--- /dev/null
+++ 
b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/MemoryStateBackendReaderKeyedStateITCase.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.state.api;
+
+import org.apache.flink.runtime.state.memory.MemoryStateBackend;
+
+/**
+ * Runs the keyed state reader IT case against a {@link MemoryStateBackend}.
+ */
+public class MemoryStateBackendReaderKeyedStateITCase
+       extends SavepointReaderKeyedStateITCase<MemoryStateBackend> {
+
+       /** Supplies a newly constructed memory state backend on every call. */
+       @Override
+       protected MemoryStateBackend getStateBackend() {
+               final MemoryStateBackend backend = new MemoryStateBackend();
+               return backend;
+       }
+}
diff --git 
a/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/RocksDBStateBackendReaderKeyedStateITCase.java
 
b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/RocksDBStateBackendReaderKeyedStateITCase.java
new file mode 100644
index 0000000..df9fb85
--- /dev/null
+++ 
b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/RocksDBStateBackendReaderKeyedStateITCase.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.state.api;
+
+import org.apache.flink.contrib.streaming.state.RocksDBStateBackend;
+import org.apache.flink.runtime.state.StateBackend;
+import org.apache.flink.runtime.state.memory.MemoryStateBackend;
+
+/**
+ * Runs the keyed state reader IT case against a {@link RocksDBStateBackend}.
+ */
+public class RocksDBStateBackendReaderKeyedStateITCase
+       extends SavepointReaderKeyedStateITCase<RocksDBStateBackend> {
+
+       /**
+        * Supplies a new RocksDB state backend wrapping a {@link MemoryStateBackend}.
+        *
+        * <p>The inner backend is typed as {@link StateBackend} so that the constructor
+        * overload accepting a {@code StateBackend} is selected.
+        */
+       @Override
+       protected RocksDBStateBackend getStateBackend() {
+               final StateBackend checkpointBackend = new MemoryStateBackend();
+               return new RocksDBStateBackend(checkpointBackend);
+       }
+}
diff --git 
a/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/SavepointReaderKeyedStateITCase.java
 
b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/SavepointReaderKeyedStateITCase.java
index 1ce2376..06e864f 100644
--- 
a/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/SavepointReaderKeyedStateITCase.java
+++ 
b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/SavepointReaderKeyedStateITCase.java
@@ -18,171 +18,74 @@
 
 package org.apache.flink.state.api;
 
-import org.apache.flink.api.common.JobID;
-import org.apache.flink.api.common.JobSubmissionResult;
 import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
-import org.apache.flink.api.common.time.Deadline;
 import org.apache.flink.api.common.typeinfo.Types;
 import org.apache.flink.api.java.ExecutionEnvironment;
-import org.apache.flink.client.ClientUtils;
-import org.apache.flink.client.program.ClusterClient;
 import org.apache.flink.configuration.Configuration;
-import org.apache.flink.contrib.streaming.state.RocksDBStateBackend;
-import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.state.StateBackend;
-import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.state.api.functions.KeyedStateReaderFunction;
+import org.apache.flink.state.api.utils.SavepointTestBase;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
 import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
-import org.apache.flink.streaming.api.functions.source.SourceFunction;
-import org.apache.flink.test.util.AbstractTestBase;
-import org.apache.flink.util.AbstractID;
 import org.apache.flink.util.Collector;
 
 import org.junit.Assert;
 import org.junit.Test;
 
-import java.time.Duration;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Objects;
 import java.util.Set;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.TimeUnit;
 
 /**
  * IT case for reading state.
  */
-public class SavepointReaderKeyedStateITCase extends AbstractTestBase {
+public abstract class SavepointReaderKeyedStateITCase<B extends StateBackend> 
extends SavepointTestBase {
        private static final String uid = "stateful-operator";
 
        private static ValueStateDescriptor<Integer> valueState = new 
ValueStateDescriptor<>("value", Types.INT);
 
-       @Test
-       public void testKeyedInputFormat() throws Exception {
-               runKeyedState(new MemoryStateBackend());
-               // Reset the cluster so we can change the
-               // state backend in the StreamEnvironment.
-               // If we don't do this the tests will fail.
-               miniClusterResource.after();
-               miniClusterResource.before();
-               runKeyedState(new RocksDBStateBackend((StateBackend) new 
MemoryStateBackend()));
-       }
-
-       private void runKeyedState(StateBackend backend) throws Exception {
-               StreamExecutionEnvironment streamEnv = 
StreamExecutionEnvironment.getExecutionEnvironment();
-               streamEnv.setStateBackend(backend);
-               streamEnv.setParallelism(4);
+       private static final List<Pojo> elements = Arrays.asList(
+               Pojo.of(1, 1),
+               Pojo.of(2, 2),
+               Pojo.of(3, 3));
 
-               streamEnv
-                       .addSource(new SavepointSource())
-                       .rebalance()
-                       .keyBy(id -> id.key)
-                       .process(new KeyedStatefulOperator())
-                       .uid(uid)
-                       .addSink(new DiscardingSink<>());
+       protected abstract B getStateBackend();
 
-               JobGraph jobGraph = streamEnv.getStreamGraph().getJobGraph();
-
-               String path = takeSavepoint(jobGraph);
+       @Test
+       public void testUserKeyedStateReader() throws Exception {
+               String savepointPath = takeSavepoint(elements, source -> {
+                       StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
+                       env.setStateBackend(getStateBackend());
+                       env.setParallelism(4);
+
+                       env
+                               .addSource(source)
+                               .rebalance()
+                               .keyBy(id -> id.key)
+                               .process(new KeyedStatefulOperator())
+                               .uid(uid)
+                               .addSink(new DiscardingSink<>());
+
+                       return env;
+               });
 
                ExecutionEnvironment batchEnv = 
ExecutionEnvironment.getExecutionEnvironment();
-               ExistingSavepoint savepoint = Savepoint.load(batchEnv, path, 
backend);
+               ExistingSavepoint savepoint = Savepoint.load(batchEnv, 
savepointPath, getStateBackend());
 
                List<Pojo> results = savepoint
                        .readKeyedState(uid, new Reader())
                        .collect();
 
-               Set<Pojo> expected = SavepointSource.getElements();
+               Set<Pojo> expected = new HashSet<>(elements);
 
                Assert.assertEquals("Unexpected results from keyed state", 
expected, new HashSet<>(results));
        }
 
-       private String takeSavepoint(JobGraph jobGraph) throws Exception {
-               SavepointSource.initializeForTest();
-
-               ClusterClient<?> client = 
miniClusterResource.getClusterClient();
-
-               JobID jobId = jobGraph.getJobID();
-
-               Deadline deadline = Deadline.fromNow(Duration.ofMinutes(5));
-
-               String dirPath = getTempDirPath(new AbstractID().toHexString());
-
-               try {
-                       JobSubmissionResult result = 
ClientUtils.submitJob(client, jobGraph);
-
-                       boolean finished = false;
-                       while (deadline.hasTimeLeft()) {
-                               if (SavepointSource.isFinished()) {
-                                       finished = true;
-
-                                       break;
-                               }
-                       }
-
-                       if (!finished) {
-                               Assert.fail("Failed to initialize state within 
deadline");
-                       }
-
-                       CompletableFuture<String> path = 
client.triggerSavepoint(result.getJobID(), dirPath);
-                       return path.get(deadline.timeLeft().toMillis(), 
TimeUnit.MILLISECONDS);
-               } finally {
-                       client.cancel(jobId).get();
-               }
-       }
-
-       private static class SavepointSource implements SourceFunction<Pojo> {
-               private static volatile boolean finished;
-
-               private volatile boolean running = true;
-
-               private static final Pojo[] elements = {
-                       Pojo.of(1, 1),
-                       Pojo.of(2, 2),
-                       Pojo.of(3, 3)};
-
-               @Override
-               public void run(SourceContext<Pojo> ctx) {
-                       synchronized (ctx.getCheckpointLock()) {
-                               for (Pojo element : elements) {
-                                       ctx.collect(element);
-                               }
-
-                               finished = true;
-                       }
-
-                       while (running) {
-                               try {
-                                       Thread.sleep(100);
-                               } catch (InterruptedException e) {
-                                       // ignore
-                               }
-                       }
-               }
-
-               @Override
-               public void cancel() {
-                       running = false;
-               }
-
-               private static void initializeForTest() {
-                       finished = false;
-               }
-
-               private static boolean isFinished() {
-                       return finished;
-               }
-
-               private static Set<Pojo> getElements() {
-                       return new HashSet<>(Arrays.asList(elements));
-               }
-       }
-
        private static class KeyedStatefulOperator extends 
KeyedProcessFunction<Integer, Pojo, Void> {
 
                private transient ValueState<Integer> state;
diff --git 
a/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/input/KeyedStateInputFormatTest.java
 
b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/input/KeyedStateInputFormatTest.java
index e036bc8..d0b55b6 100644
--- 
a/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/input/KeyedStateInputFormatTest.java
+++ 
b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/input/KeyedStateInputFormatTest.java
@@ -27,8 +27,10 @@ import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.checkpoint.OperatorState;
 import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
 import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.state.api.functions.KeyedStateReaderFunction;
+import org.apache.flink.state.api.input.operator.KeyedStateReaderOperator;
 import org.apache.flink.state.api.input.splits.KeyGroupRangeInputSplit;
 import org.apache.flink.state.api.runtime.OperatorIDGenerator;
 import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
@@ -65,7 +67,7 @@ public class KeyedStateInputFormatTest {
                OperatorState operatorState = new OperatorState(operatorID, 1, 
128);
                operatorState.putState(0, state);
 
-               KeyedStateInputFormat<?, ?> format = new 
KeyedStateInputFormat<>(operatorState, new MemoryStateBackend(), Types.INT, new 
ReaderFunction());
+               KeyedStateInputFormat<?, ?, ?> format = new 
KeyedStateInputFormat<>(operatorState, new MemoryStateBackend(), new 
KeyedStateReaderOperator<>(new ReaderFunction(), Types.INT));
                KeyGroupRangeInputSplit[] splits = format.createInputSplits(4);
                Assert.assertEquals("Failed to properly partition operator 
state into input splits", 4, splits.length);
        }
@@ -78,7 +80,7 @@ public class KeyedStateInputFormatTest {
                OperatorState operatorState = new OperatorState(operatorID, 1, 
128);
                operatorState.putState(0, state);
 
-               KeyedStateInputFormat<?, ?> format = new 
KeyedStateInputFormat<>(operatorState, new MemoryStateBackend(), Types.INT, new 
ReaderFunction());
+               KeyedStateInputFormat<?, ?, ?> format = new 
KeyedStateInputFormat<>(operatorState, new MemoryStateBackend(), new 
KeyedStateReaderOperator<>(new ReaderFunction(), Types.INT));
                KeyGroupRangeInputSplit[] splits = 
format.createInputSplits(129);
                Assert.assertEquals("Failed to properly partition operator 
state into input splits", 128, splits.length);
        }
@@ -91,7 +93,7 @@ public class KeyedStateInputFormatTest {
                OperatorState operatorState = new OperatorState(operatorID, 1, 
128);
                operatorState.putState(0, state);
 
-               KeyedStateInputFormat<?, ?> format = new 
KeyedStateInputFormat<>(operatorState, new MemoryStateBackend(), Types.INT, new 
ReaderFunction());
+               KeyedStateInputFormat<?, ?, ?> format = new 
KeyedStateInputFormat<>(operatorState, new MemoryStateBackend(), new 
KeyedStateReaderOperator<>(new ReaderFunction(), Types.INT));
                KeyGroupRangeInputSplit split = format.createInputSplits(1)[0];
 
                KeyedStateReaderFunction<Integer, Integer> userFunction = new 
ReaderFunction();
@@ -109,7 +111,7 @@ public class KeyedStateInputFormatTest {
                OperatorState operatorState = new OperatorState(operatorID, 1, 
128);
                operatorState.putState(0, state);
 
-               KeyedStateInputFormat<?, ?> format = new 
KeyedStateInputFormat<>(operatorState, new MemoryStateBackend(), Types.INT, new 
ReaderFunction());
+               KeyedStateInputFormat<?, ?, ?> format = new 
KeyedStateInputFormat<>(operatorState, new MemoryStateBackend(), new 
KeyedStateReaderOperator<>(new ReaderFunction(), Types.INT));
                KeyGroupRangeInputSplit split = format.createInputSplits(1)[0];
 
                KeyedStateReaderFunction<Integer, Integer> userFunction = new 
DoubleReaderFunction();
@@ -127,7 +129,7 @@ public class KeyedStateInputFormatTest {
                OperatorState operatorState = new OperatorState(operatorID, 1, 
128);
                operatorState.putState(0, state);
 
-               KeyedStateInputFormat<?, ?> format = new 
KeyedStateInputFormat<>(operatorState, new MemoryStateBackend(), Types.INT, new 
ReaderFunction());
+               KeyedStateInputFormat<?, ?, ?> format = new 
KeyedStateInputFormat<>(operatorState, new MemoryStateBackend(), new 
KeyedStateReaderOperator<>(new ReaderFunction(), Types.INT));
                KeyGroupRangeInputSplit split = format.createInputSplits(1)[0];
 
                KeyedStateReaderFunction<Integer, Integer> userFunction = new 
InvalidReaderFunction();
@@ -145,7 +147,7 @@ public class KeyedStateInputFormatTest {
                OperatorState operatorState = new OperatorState(operatorID, 1, 
128);
                operatorState.putState(0, state);
 
-               KeyedStateInputFormat<?, ?> format = new 
KeyedStateInputFormat<>(operatorState, new MemoryStateBackend(), Types.INT, new 
TimerReaderFunction());
+               KeyedStateInputFormat<?, ?, ?> format = new 
KeyedStateInputFormat<>(operatorState, new MemoryStateBackend(), new 
KeyedStateReaderOperator<>(new TimerReaderFunction(), Types.INT));
                KeyGroupRangeInputSplit split = format.createInputSplits(1)[0];
 
                KeyedStateReaderFunction<Integer, Integer> userFunction = new 
TimerReaderFunction();
@@ -157,11 +159,10 @@ public class KeyedStateInputFormatTest {
 
        @Nonnull
        private List<Integer> readInputSplit(KeyGroupRangeInputSplit split, 
KeyedStateReaderFunction<Integer, Integer> userFunction) throws IOException {
-               KeyedStateInputFormat<Integer, Integer> format = new 
KeyedStateInputFormat<>(
+               KeyedStateInputFormat<Integer, VoidNamespace, Integer> format = 
new KeyedStateInputFormat<>(
                        new OperatorState(OperatorIDGenerator.fromUid("uid"), 
1, 4),
                        new MemoryStateBackend(),
-                       Types.INT,
-                       userFunction);
+                       new KeyedStateReaderOperator<>(userFunction, 
Types.INT));
 
                List<Integer> data = new ArrayList<>();
 
diff --git 
a/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/utils/SavepointTestBase.java
 
b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/utils/SavepointTestBase.java
new file mode 100644
index 0000000..0d25682
--- /dev/null
+++ 
b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/utils/SavepointTestBase.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.state.api.utils;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.JobSubmissionResult;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.TypeExtractor;
+import org.apache.flink.client.ClientUtils;
+import org.apache.flink.client.program.ClusterClient;
+import org.apache.flink.runtime.jobgraph.JobGraph;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.FromElementsFunction;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.util.AbstractID;
+
+import java.io.IOException;
+import java.util.Collection;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Function;
+
+/**
+ * A test base that includes utilities for taking a savepoint.
+ */
+public abstract class SavepointTestBase extends AbstractTestBase {
+
+       /**
+        * Submits a job fed by {@code data} to the mini cluster, waits until the source has
+        * emitted all elements, triggers a savepoint and finally cancels the job.
+        *
+        * @param data the elements the source emits; must be non-empty and its first element non-null
+        * @param jobGraphFactory builds the streaming job around the provided source and returns
+        *     the environment the job graph is extracted from
+        * @param <T> the element type
+        * @return the path of the savepoint
+        * @throws IllegalArgumentException if {@code data} is null, empty or starts with a null element
+        * @throws RuntimeException if submitting the job or taking the savepoint fails or
+        *     does not finish within five minutes
+        */
+       public <T> String takeSavepoint(Collection<T> data, Function<SourceFunction<T>, StreamExecutionEnvironment> jobGraphFactory) throws Exception {
+
+               WaitingSource<T> waitingSource = createSource(data);
+
+               // The job's environment is created by the factory, so no environment is set up here.
+               JobGraph jobGraph = jobGraphFactory.apply(waitingSource).getStreamGraph().getJobGraph();
+               JobID jobId = jobGraph.getJobID();
+
+               ClusterClient<?> client = miniClusterResource.getClusterClient();
+
+               try {
+                       JobSubmissionResult result = ClientUtils.submitJob(client, jobGraph);
+
+                       return CompletableFuture
+                               .runAsync(waitingSource::awaitSource)
+                               .thenCompose(ignore -> triggerSavepoint(client, result.getJobID()))
+                               .get(5, TimeUnit.MINUTES);
+               } catch (InterruptedException e) {
+                       // Restore the interrupt flag so callers can still observe the interruption.
+                       Thread.currentThread().interrupt();
+                       throw new RuntimeException("Failed to take savepoint", e);
+               } catch (Exception e) {
+                       throw new RuntimeException("Failed to take savepoint", e);
+               } finally {
+                       client.cancel(jobId);
+               }
+       }
+
+       /**
+        * Wraps {@code data} into a {@link WaitingSource} whose type information is derived
+        * from the first element of the collection.
+        *
+        * @throws IllegalArgumentException if {@code data} is null, empty or starts with a null element
+        */
+       private <T> WaitingSource<T> createSource(Collection<T> data) throws Exception {
+               if (data == null || data.isEmpty()) {
+                       throw new IllegalArgumentException("Collection must not be empty");
+               }
+
+               T first = data.iterator().next();
+               if (first == null) {
+                       throw new IllegalArgumentException("Collection must not contain null elements");
+               }
+
+               TypeInformation<T> typeInfo = TypeExtractor.getForObject(first);
+               SourceFunction<T> inner = new FromElementsFunction<>(typeInfo.createSerializer(new ExecutionConfig()), data);
+               return new WaitingSource<>(inner, typeInfo);
+       }
+
+       /**
+        * Triggers a savepoint for {@code jobID} into a newly created temporary directory.
+        *
+        * @throws RuntimeException if the temporary directory path cannot be obtained
+        */
+       private CompletableFuture<String> triggerSavepoint(ClusterClient<?> client, JobID jobID) {
+               try {
+                       String dirPath = getTempDirPath(new AbstractID().toHexString());
+                       return client.triggerSavepoint(jobID, dirPath);
+               } catch (IOException e) {
+                       throw new RuntimeException(e);
+               }
+       }
+}
diff --git 
a/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/utils/WaitingSource.java
 
b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/utils/WaitingSource.java
new file mode 100644
index 0000000..768cb33
--- /dev/null
+++ 
b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/utils/WaitingSource.java
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.state.api.utils;
+
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.testutils.OneShotLatch;
+import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.UUID;
+
+/**
+ * A wrapper class that allows threads to block until the inner source completes.
+ * Its {@code run} method does not return until explicitly canceled so external processes can
+ * perform operations such as taking savepoints.
+ *
+ * @param <T> The output type of the inner source.
+ */
+@SuppressWarnings("SynchronizationOnLocalVariableOrMethodParameter")
+public class WaitingSource<T> extends RichSourceFunction<T> implements ResultTypeQueryable<T> {
+
+       // Latches live in a static map and are looked up by guardId rather than held as a field.
+       // NOTE(review): presumably so the serialized copy running in the task shares the latch
+       // with the instance held by the test; this only works within a single JVM - confirm.
+       private static final Map<String, OneShotLatch> guards = new HashMap<>();
+
+       private final SourceFunction<T> source;
+
+       private final TypeInformation<T> returnType;
+
+       private final String guardId;
+
+       private volatile boolean running;
+
+       /**
+        * Wraps {@code source} and registers a latch for it under a random id.
+        *
+        * @param source the inner source whose completion can be awaited
+        * @param returnType the type information reported by {@link #getProducedType()}
+        */
+       public WaitingSource(SourceFunction<T> source, TypeInformation<T> returnType) {
+               this.source = source;
+               this.returnType = returnType;
+               this.guardId = UUID.randomUUID().toString();
+
+               guards.put(guardId, new OneShotLatch());
+               this.running = true;
+       }
+
+       /**
+        * Stores the runtime context on this wrapper and forwards it to the inner source
+        * when that source is a {@link RichSourceFunction}.
+        */
+       @Override
+       public void setRuntimeContext(RuntimeContext t) {
+               // Keep the wrapper's own context initialised so getRuntimeContext() works on it.
+               super.setRuntimeContext(t);
+               if (source instanceof RichSourceFunction) {
+                       ((RichSourceFunction<T>) source).setRuntimeContext(t);
+               }
+       }
+
+       /** Forwards {@code open} to the inner source when it is a {@link RichSourceFunction}. */
+       @Override
+       public void open(Configuration parameters) throws Exception {
+               if (source instanceof RichSourceFunction) {
+                       ((RichSourceFunction<T>) source).open(parameters);
+               }
+       }
+
+       /** Forwards {@code close} to the inner source when it is a {@link RichSourceFunction}. */
+       @Override
+       public void close() throws Exception {
+               if (source instanceof RichSourceFunction) {
+                       ((RichSourceFunction<T>) source).close();
+               }
+       }
+
+       /**
+        * Runs the inner source to completion, triggers the latch, and then idles until
+        * {@link #cancel()} is called.
+        */
+       @Override
+       public void run(SourceContext<T> ctx) throws Exception {
+               OneShotLatch latch = guards.get(guardId);
+               try {
+                       source.run(ctx);
+               } finally {
+                       // Release waiters even if the inner source failed.
+                       latch.trigger();
+               }
+
+               while (running) {
+                       try {
+                               Thread.sleep(100);
+                       } catch (InterruptedException e) {
+                               // ignore: the loop only ends once cancel() clears the running flag
+                       }
+               }
+       }
+
+       /** Cancels the inner source and lets {@link #run} return. */
+       @Override
+       public void cancel() {
+               source.cancel();
+               running = false;
+       }
+
+       /**
+        * This method blocks until the inner source has completed.
+        *
+        * @throws RuntimeException if the waiting thread is interrupted; the interrupt flag is
+        *     restored and the {@link InterruptedException} is attached as the cause
+        */
+       public void awaitSource() throws RuntimeException {
+               try {
+                       guards.get(guardId).await();
+               } catch (InterruptedException e) {
+                       Thread.currentThread().interrupt();
+                       throw new RuntimeException("Failed to initialize source", e);
+               }
+       }
+
+       /** Returns the type information supplied at construction time. */
+       @Override
+       public TypeInformation<T> getProducedType() {
+               return returnType;
+       }
+}

Reply via email to