This is an automated email from the ASF dual-hosted git repository. sewen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push: new faf394d [FLINK-15014][state-processor-api] Refactor KeyedStateInputFormat to support multiple types of user functions faf394d is described below commit faf394d8a03b658d3a867d306ed8bef3da1ed478 Author: Seth Wiesman <sjwies...@gmail.com> AuthorDate: Wed Nov 6 10:26:42 2019 -0600 [FLINK-15014][state-processor-api] Refactor KeyedStateInputFormat to support multiple types of user functions This closes #10382 --- .../apache/flink/state/api/ExistingSavepoint.java | 7 +- .../state/api/input/KeyedStateInputFormat.java | 166 +++----------------- .../state/api/input/MultiStateKeyIterator.java | 10 +- .../input/operator/KeyedStateReaderOperator.java | 167 +++++++++++++++++++++ .../api/input/operator/StateReaderOperator.java | 132 ++++++++++++++++ .../MemoryStateBackendReaderKeyedStateITCase.java | 32 ++++ .../RocksDBStateBackendReaderKeyedStateITCase.java | 33 ++++ .../state/api/SavepointReaderKeyedStateITCase.java | 149 ++++-------------- .../state/api/input/KeyedStateInputFormatTest.java | 19 +-- .../flink/state/api/utils/SavepointTestBase.java | 91 +++++++++++ .../flink/state/api/utils/WaitingSource.java | 123 +++++++++++++++ 11 files changed, 644 insertions(+), 285 deletions(-) diff --git a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/ExistingSavepoint.java b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/ExistingSavepoint.java index f336ce1..6c91660 100644 --- a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/ExistingSavepoint.java +++ b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/ExistingSavepoint.java @@ -32,11 +32,13 @@ import org.apache.flink.api.java.typeutils.TupleTypeInfo; import org.apache.flink.api.java.typeutils.TypeExtractor; import org.apache.flink.runtime.checkpoint.OperatorState; import org.apache.flink.runtime.state.StateBackend; 
+import org.apache.flink.runtime.state.VoidNamespace; import org.apache.flink.state.api.functions.KeyedStateReaderFunction; import org.apache.flink.state.api.input.BroadcastStateInputFormat; import org.apache.flink.state.api.input.KeyedStateInputFormat; import org.apache.flink.state.api.input.ListStateInputFormat; import org.apache.flink.state.api.input.UnionStateInputFormat; +import org.apache.flink.state.api.input.operator.KeyedStateReaderOperator; import org.apache.flink.state.api.runtime.metadata.SavepointMetadata; import org.apache.flink.util.Preconditions; @@ -275,11 +277,10 @@ public class ExistingSavepoint extends WritableSavepoint<ExistingSavepoint> { TypeInformation<OUT> outTypeInfo) throws IOException { OperatorState operatorState = metadata.getOperatorState(uid); - KeyedStateInputFormat<K, OUT> inputFormat = new KeyedStateInputFormat<>( + KeyedStateInputFormat<K, VoidNamespace, OUT> inputFormat = new KeyedStateInputFormat<>( operatorState, stateBackend, - keyTypeInfo, - function); + new KeyedStateReaderOperator<>(function, keyTypeInfo)); return env.createInput(inputFormat, outTypeInfo); } diff --git a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/input/KeyedStateInputFormat.java b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/input/KeyedStateInputFormat.java index 454b74c..daf6278 100644 --- a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/input/KeyedStateInputFormat.java +++ b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/input/KeyedStateInputFormat.java @@ -19,17 +19,10 @@ package org.apache.flink.state.api.input; import org.apache.flink.annotation.Internal; -import org.apache.flink.api.common.functions.util.FunctionUtils; import org.apache.flink.api.common.io.DefaultInputSplitAssigner; import org.apache.flink.api.common.io.RichInputFormat; import 
org.apache.flink.api.common.io.statistics.BaseStatistics; -import org.apache.flink.api.common.state.ListState; -import org.apache.flink.api.common.state.ListStateDescriptor; -import org.apache.flink.api.common.state.StateDescriptor; -import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.api.common.typeinfo.Types; -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.api.common.typeutils.base.StringSerializer; +import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.fs.CloseableRegistry; import org.apache.flink.core.io.InputSplitAssigner; @@ -41,34 +34,25 @@ import org.apache.flink.runtime.state.DefaultKeyedStateStore; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.StateBackend; -import org.apache.flink.runtime.state.VoidNamespace; -import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.state.api.functions.KeyedStateReaderFunction; +import org.apache.flink.state.api.input.operator.StateReaderOperator; import org.apache.flink.state.api.input.splits.KeyGroupRangeInputSplit; import org.apache.flink.state.api.runtime.NeverFireProcessingTimeService; import org.apache.flink.state.api.runtime.SavepointEnvironment; import org.apache.flink.state.api.runtime.SavepointRuntimeContext; -import org.apache.flink.state.api.runtime.VoidTriggerable; import org.apache.flink.streaming.api.operators.InternalTimeServiceManager; -import org.apache.flink.streaming.api.operators.InternalTimerService; -import org.apache.flink.streaming.api.operators.KeyContext; import org.apache.flink.streaming.api.operators.StreamOperatorStateContext; import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer; import org.apache.flink.streaming.api.operators.StreamTaskStateInitializerImpl; -import 
org.apache.flink.streaming.api.operators.TimerSerializer; import org.apache.flink.util.CollectionUtil; import org.apache.flink.util.Preconditions; import javax.annotation.Nonnull; import java.io.IOException; -import java.util.Collections; import java.util.Comparator; import java.util.Iterator; import java.util.List; -import java.util.Set; -import java.util.stream.Collectors; -import java.util.stream.StreamSupport; /** * Input format for reading partitioned state. @@ -77,54 +61,39 @@ import java.util.stream.StreamSupport; * @param <OUT> The type of the output of the {@link KeyedStateReaderFunction}. */ @Internal -public class KeyedStateInputFormat<K, OUT> extends RichInputFormat<OUT, KeyGroupRangeInputSplit> implements KeyContext { +public class KeyedStateInputFormat<K, N, OUT> extends RichInputFormat<OUT, KeyGroupRangeInputSplit> { private static final long serialVersionUID = 8230460226049597182L; - private static final String USER_TIMERS_NAME = "user-timers"; - private final OperatorState operatorState; private final StateBackend stateBackend; - private final TypeInformation<K> keyType; - - private final KeyedStateReaderFunction<K, OUT> userFunction; - - private transient TypeSerializer<K> keySerializer; + private final StateReaderOperator<?, K, N, OUT> operator; private transient CloseableRegistry registry; private transient BufferingCollector<OUT> out; - private transient Iterator<K> keys; - - private transient AbstractKeyedStateBackend<K> keyedStateBackend; - - private transient Context ctx; + private transient Iterator<Tuple2<K, N>> keysAndNamespaces; /** * Creates an input format for reading partitioned state from an operator in a savepoint. * * @param operatorState The state to be queried. * @param stateBackend The state backed used to snapshot the operator. - * @param keyType The type information describing the key type. - * @param userFunction The {@link KeyedStateReaderFunction} called for each key in the operator. 
*/ public KeyedStateInputFormat( OperatorState operatorState, StateBackend stateBackend, - TypeInformation<K> keyType, - KeyedStateReaderFunction<K, OUT> userFunction) { + StateReaderOperator<?, K, N, OUT> operator) { Preconditions.checkNotNull(operatorState, "The operator state cannot be null"); Preconditions.checkNotNull(stateBackend, "The state backend cannot be null"); - Preconditions.checkNotNull(keyType, "The key type information cannot be null"); - Preconditions.checkNotNull(userFunction, "The userfunction cannot be null"); + Preconditions.checkNotNull(operator, "The operator cannot be null"); this.operatorState = operatorState; this.stateBackend = stateBackend; - this.keyType = keyType; - this.userFunction = userFunction; + this.operator = operator; } @Override @@ -160,7 +129,6 @@ public class KeyedStateInputFormat<K, OUT> extends RichInputFormat<OUT, KeyGroup @Override public void openInputFormat() { out = new BufferingCollector<>(); - keySerializer = keyType.createSerializer(getRuntimeContext().getExecutionConfig()); } @Override @@ -176,43 +144,21 @@ public class KeyedStateInputFormat<K, OUT> extends RichInputFormat<OUT, KeyGroup final StreamOperatorStateContext context = getStreamOperatorStateContext(environment); - keyedStateBackend = (AbstractKeyedStateBackend<K>) context.keyedStateBackend(); + AbstractKeyedStateBackend<K> keyedStateBackend = (AbstractKeyedStateBackend<K>) context.keyedStateBackend(); final DefaultKeyedStateStore keyedStateStore = new DefaultKeyedStateStore(keyedStateBackend, getRuntimeContext().getExecutionConfig()); SavepointRuntimeContext ctx = new SavepointRuntimeContext(getRuntimeContext(), keyedStateStore); - FunctionUtils.setFunctionRuntimeContext(userFunction, ctx); - - keys = getKeyIterator(ctx); - final InternalTimerService<VoidNamespace> timerService = restoreTimerService(context); + InternalTimeServiceManager<K> timeServiceManager = (InternalTimeServiceManager<K>) context.internalTimerServiceManager(); try { - this.ctx = 
new Context(keyedStateBackend, timerService); + operator.setup(getRuntimeContext().getExecutionConfig(), keyedStateBackend, timeServiceManager, ctx); + operator.open(); + keysAndNamespaces = operator.getKeysAndNamespaces(ctx); } catch (Exception e) { throw new IOException("Failed to restore timer state", e); } } - @SuppressWarnings("unchecked") - private InternalTimerService<VoidNamespace> restoreTimerService(StreamOperatorStateContext context) { - InternalTimeServiceManager<K> timeServiceManager = (InternalTimeServiceManager<K>) context.internalTimerServiceManager(); - TimerSerializer<K, VoidNamespace> timerSerializer = new TimerSerializer<>(keySerializer, VoidNamespaceSerializer.INSTANCE); - return timeServiceManager.getInternalTimerService(USER_TIMERS_NAME, timerSerializer, VoidTriggerable.instance()); - } - - @SuppressWarnings("unchecked") - private Iterator<K> getKeyIterator(SavepointRuntimeContext ctx) throws IOException { - final List<StateDescriptor<?, ?>> stateDescriptors; - try { - FunctionUtils.openFunction(userFunction, new Configuration()); - ctx.disableStateRegistration(); - stateDescriptors = ctx.getStateDescriptors(); - } catch (Exception e) { - throw new IOException("Failed to open user defined function", e); - } - - return new MultiStateKeyIterator<>(stateDescriptors, keyedStateBackend); - } - private StreamOperatorStateContext getStreamOperatorStateContext(Environment environment) throws IOException { StreamTaskStateInitializer initializer = new StreamTaskStateInitializerImpl( environment, @@ -223,8 +169,8 @@ public class KeyedStateInputFormat<K, OUT> extends RichInputFormat<OUT, KeyGroup operatorState.getOperatorID(), operatorState.getOperatorID().toString(), new NeverFireProcessingTimeService(), - this, - keySerializer, + operator, + operator.getKeyType().createSerializer(environment.getExecutionConfig()), registry, getRuntimeContext().getMetricGroup()); } catch (Exception e) { @@ -239,7 +185,7 @@ public class KeyedStateInputFormat<K, OUT> 
extends RichInputFormat<OUT, KeyGroup @Override public boolean reachedEnd() { - return !out.hasNext() && !keys.hasNext(); + return !out.hasNext() && !keysAndNamespaces.hasNext(); } @Override @@ -248,31 +194,20 @@ public class KeyedStateInputFormat<K, OUT> extends RichInputFormat<OUT, KeyGroup return out.next(); } - final K key = keys.next(); - setCurrentKey(key); + final Tuple2<K, N> keyAndNamespace = keysAndNamespaces.next(); + operator.setCurrentKey(keyAndNamespace.f0); try { - userFunction.readKey(key, ctx, out); + operator.processElement(keyAndNamespace.f0, keyAndNamespace.f1, out); } catch (Exception e) { throw new IOException("User defined function KeyedStateReaderFunction#readKey threw an exception", e); } - keys.remove(); + keysAndNamespaces.remove(); return out.next(); } - @Override - @SuppressWarnings("unchecked") - public void setCurrentKey(Object key) { - keyedStateBackend.setCurrentKey((K) key); - } - - @Override - public Object getCurrentKey() { - return keyedStateBackend.getCurrentKey(); - } - private static KeyGroupRangeInputSplit createKeyGroupRangeInputSplit( OperatorState operatorState, int maxParallelism, @@ -294,65 +229,4 @@ public class KeyedStateInputFormat<K, OUT> extends RichInputFormat<OUT, KeyGroup keyGroups.sort(Comparator.comparing(KeyGroupRange::getStartKeyGroup)); return keyGroups; } - - private static class Context<K> implements KeyedStateReaderFunction.Context { - - private static final String EVENT_TIMER_STATE = "event-time-timers"; - - private static final String PROC_TIMER_STATE = "proc-time-timers"; - - ListState<Long> eventTimers; - - ListState<Long> procTimers; - - private Context(AbstractKeyedStateBackend<K> keyedStateBackend, InternalTimerService<VoidNamespace> timerService) throws Exception { - eventTimers = keyedStateBackend.getPartitionedState( - USER_TIMERS_NAME, - StringSerializer.INSTANCE, - new ListStateDescriptor<>(EVENT_TIMER_STATE, Types.LONG) - ); - - timerService.forEachEventTimeTimer((namespace, timer) -> { - if 
(namespace.equals(VoidNamespace.INSTANCE)) { - eventTimers.add(timer); - } - }); - - procTimers = keyedStateBackend.getPartitionedState( - USER_TIMERS_NAME, - StringSerializer.INSTANCE, - new ListStateDescriptor<>(PROC_TIMER_STATE, Types.LONG) - ); - - timerService.forEachProcessingTimeTimer((namespace, timer) -> { - if (namespace.equals(VoidNamespace.INSTANCE)) { - procTimers.add(timer); - } - }); - } - - @Override - public Set<Long> registeredEventTimeTimers() throws Exception { - Iterable<Long> timers = eventTimers.get(); - if (timers == null) { - return Collections.emptySet(); - } - - return StreamSupport - .stream(timers.spliterator(), false) - .collect(Collectors.toSet()); - } - - @Override - public Set<Long> registeredProcessingTimeTimers() throws Exception { - Iterable<Long> timers = procTimers.get(); - if (timers == null) { - return Collections.emptySet(); - } - - return StreamSupport - .stream(timers.spliterator(), false) - .collect(Collectors.toSet()); - } - } } diff --git a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/input/MultiStateKeyIterator.java b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/input/MultiStateKeyIterator.java index bed0032..4af28f4 100644 --- a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/input/MultiStateKeyIterator.java +++ b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/input/MultiStateKeyIterator.java @@ -18,9 +18,10 @@ package org.apache.flink.state.api.input; +import org.apache.flink.annotation.Internal; import org.apache.flink.api.common.state.State; import org.apache.flink.api.common.state.StateDescriptor; -import org.apache.flink.runtime.state.AbstractKeyedStateBackend; +import org.apache.flink.runtime.state.KeyedStateBackend; import org.apache.flink.runtime.state.VoidNamespace; import org.apache.flink.runtime.state.VoidNamespaceSerializer; import 
org.apache.flink.util.Preconditions; @@ -35,16 +36,17 @@ import java.util.List; * * @param <K> Type of the key by which state is keyed. */ -final class MultiStateKeyIterator<K> implements Iterator<K> { +@Internal +public final class MultiStateKeyIterator<K> implements Iterator<K> { private final List<? extends StateDescriptor<?, ?>> descriptors; - private final AbstractKeyedStateBackend<K> backend; + private final KeyedStateBackend<K> backend; private final Iterator<K> internal; private K currentKey; - MultiStateKeyIterator(List<? extends StateDescriptor<?, ?>> descriptors, AbstractKeyedStateBackend<K> backend) { + public MultiStateKeyIterator(List<? extends StateDescriptor<?, ?>> descriptors, KeyedStateBackend<K> backend) { this.descriptors = Preconditions.checkNotNull(descriptors); this.backend = Preconditions.checkNotNull(backend); diff --git a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/input/operator/KeyedStateReaderOperator.java b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/input/operator/KeyedStateReaderOperator.java new file mode 100644 index 0000000..1c99a5b --- /dev/null +++ b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/input/operator/KeyedStateReaderOperator.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.state.api.input.operator; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.common.typeutils.base.StringSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.runtime.state.KeyedStateBackend; +import org.apache.flink.runtime.state.VoidNamespace; +import org.apache.flink.runtime.state.VoidNamespaceSerializer; +import org.apache.flink.state.api.functions.KeyedStateReaderFunction; +import org.apache.flink.state.api.input.MultiStateKeyIterator; +import org.apache.flink.state.api.runtime.SavepointRuntimeContext; +import org.apache.flink.streaming.api.operators.InternalTimerService; +import org.apache.flink.util.Collector; + +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; + +/** + * A {@link StateReaderOperator} for executing a {@link KeyedStateReaderFunction}. + * + * @param <KEY> The key type read from the state backend. + * @param <OUT> The output type of the function. 
+ */ +@Internal +public class KeyedStateReaderOperator<KEY, OUT> + extends StateReaderOperator<KeyedStateReaderFunction<KEY, OUT>, KEY, VoidNamespace, OUT> { + + private static final String USER_TIMERS_NAME = "user-timers"; + + private transient Context<KEY> context; + + public KeyedStateReaderOperator(KeyedStateReaderFunction<KEY, OUT> function, TypeInformation<KEY> keyType) { + super(function, keyType, VoidNamespaceSerializer.INSTANCE); + } + + @Override + public void open() throws Exception { + super.open(); + + InternalTimerService<VoidNamespace> timerService = getInternalTimerService(USER_TIMERS_NAME); + context = new Context<>(getKeyedStateBackend(), timerService); + } + + @Override + public void processElement(KEY key, VoidNamespace namespace, Collector<OUT> out) throws Exception { + function.readKey(key, context, out); + } + + @Override + public Iterator<Tuple2<KEY, VoidNamespace>> getKeysAndNamespaces(SavepointRuntimeContext ctx) throws Exception { + ctx.disableStateRegistration(); + List<StateDescriptor<?, ?>> stateDescriptors = ctx.getStateDescriptors(); + Iterator<KEY> keys = new MultiStateKeyIterator<>(stateDescriptors, getKeyedStateBackend()); + return new NamespaceDecorator<>(keys); + } + + private static class Context<K> implements KeyedStateReaderFunction.Context { + + private static final String EVENT_TIMER_STATE = "event-time-timers"; + + private static final String PROC_TIMER_STATE = "proc-time-timers"; + + ListState<Long> eventTimers; + + ListState<Long> procTimers; + + private Context(KeyedStateBackend<K> keyedStateBackend, InternalTimerService<VoidNamespace> timerService) throws Exception { + eventTimers = keyedStateBackend.getPartitionedState( + USER_TIMERS_NAME, + StringSerializer.INSTANCE, + new ListStateDescriptor<>(EVENT_TIMER_STATE, Types.LONG)); + + timerService.forEachEventTimeTimer((namespace, timer) -> { + if (namespace.equals(VoidNamespace.INSTANCE)) { + eventTimers.add(timer); + } + }); + + procTimers = 
keyedStateBackend.getPartitionedState( + USER_TIMERS_NAME, + StringSerializer.INSTANCE, + new ListStateDescriptor<>(PROC_TIMER_STATE, Types.LONG)); + + timerService.forEachProcessingTimeTimer((namespace, timer) -> { + if (namespace.equals(VoidNamespace.INSTANCE)) { + procTimers.add(timer); + } + }); + } + + @Override + public Set<Long> registeredEventTimeTimers() throws Exception { + Iterable<Long> timers = eventTimers.get(); + if (timers == null) { + return Collections.emptySet(); + } + + return StreamSupport + .stream(timers.spliterator(), false) + .collect(Collectors.toSet()); + } + + @Override + public Set<Long> registeredProcessingTimeTimers() throws Exception { + Iterable<Long> timers = procTimers.get(); + if (timers == null) { + return Collections.emptySet(); + } + + return StreamSupport + .stream(timers.spliterator(), false) + .collect(Collectors.toSet()); + } + } + + private static class NamespaceDecorator<KEY> implements Iterator<Tuple2<KEY, VoidNamespace>> { + + private final Iterator<KEY> keys; + + private NamespaceDecorator(Iterator<KEY> keys) { + this.keys = keys; + } + + @Override + public boolean hasNext() { + return keys.hasNext(); + } + + @Override + public Tuple2<KEY, VoidNamespace> next() { + KEY key = keys.next(); + return Tuple2.of(key, VoidNamespace.INSTANCE); + } + + @Override + public void remove() { + keys.remove(); + } + } +} diff --git a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/input/operator/StateReaderOperator.java b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/input/operator/StateReaderOperator.java new file mode 100644 index 0000000..eabc2f8 --- /dev/null +++ b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/input/operator/StateReaderOperator.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.state.api.input.operator; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.functions.Function; +import org.apache.flink.api.common.functions.util.FunctionUtils; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.state.KeyedStateBackend; +import org.apache.flink.state.api.runtime.SavepointRuntimeContext; +import org.apache.flink.state.api.runtime.VoidTriggerable; +import org.apache.flink.streaming.api.operators.InternalTimeServiceManager; +import org.apache.flink.streaming.api.operators.InternalTimerService; +import org.apache.flink.streaming.api.operators.KeyContext; +import org.apache.flink.streaming.api.operators.TimerSerializer; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * Base class for executing functions that read keyed state. + * + * @param <F> The type of the user function. + * @param <KEY> The key type. + * @param <N> The namespace type. 
+ * @param <OUT> The output type. + */ +@Internal +public abstract class StateReaderOperator<F extends Function, KEY, N, OUT> implements KeyContext, Serializable { + + private static final long serialVersionUID = 1L; + + protected final F function; + + private final TypeInformation<KEY> keyType; + + protected final TypeSerializer<N> namespaceSerializer; + + private transient ExecutionConfig executionConfig; + + private transient KeyedStateBackend<KEY> keyedStateBackend; + + private transient TypeSerializer<KEY> keySerializer; + + private transient InternalTimeServiceManager<KEY> timerServiceManager; + + protected StateReaderOperator(F function, TypeInformation<KEY> keyType, TypeSerializer<N> namespaceSerializer) { + Preconditions.checkNotNull(function, "The user function must not be null"); + Preconditions.checkNotNull(keyType, "The key type must not be null"); + Preconditions.checkNotNull(namespaceSerializer, "The namespace serializer must not be null"); + + this.function = function; + this.keyType = keyType; + this.namespaceSerializer = namespaceSerializer; + } + + public abstract void processElement(KEY key, N namespace, Collector<OUT> out) throws Exception; + + public abstract Iterator<Tuple2<KEY, N>> getKeysAndNamespaces(SavepointRuntimeContext ctx) throws Exception; + + public final void setup( + ExecutionConfig executionConfig, + KeyedStateBackend<KEY> keyKeyedStateBackend, + InternalTimeServiceManager<KEY> timerServiceManager, + SavepointRuntimeContext ctx) { + + this.executionConfig = executionConfig; + this.keyedStateBackend = keyKeyedStateBackend; + this.timerServiceManager = timerServiceManager; + this.keySerializer = keyType.createSerializer(executionConfig); + + FunctionUtils.setFunctionRuntimeContext(function, ctx); + } + + protected final InternalTimerService<N> getInternalTimerService(String name) { + TimerSerializer<KEY, N> timerSerializer = new TimerSerializer<>(keySerializer, namespaceSerializer); + return 
timerServiceManager.getInternalTimerService(name, timerSerializer, VoidTriggerable.instance()); + } + + public void open() throws Exception { + FunctionUtils.openFunction(function, new Configuration()); + } + + public void close() throws Exception { + FunctionUtils.closeFunction(function); + } + + @Override + @SuppressWarnings("unchecked") + public final void setCurrentKey(Object key) { + keyedStateBackend.setCurrentKey((KEY) key); + } + + @Override + public final Object getCurrentKey() { + return keyedStateBackend.getCurrentKey(); + } + + public final KeyedStateBackend<KEY> getKeyedStateBackend() { + return keyedStateBackend; + } + + public final TypeInformation<KEY> getKeyType() { + return keyType; + } + + public final ExecutionConfig getExecutionConfig() { + return this.executionConfig; + } +} diff --git a/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/MemoryStateBackendReaderKeyedStateITCase.java b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/MemoryStateBackendReaderKeyedStateITCase.java new file mode 100644 index 0000000..61d446a --- /dev/null +++ b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/MemoryStateBackendReaderKeyedStateITCase.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.state.api; + +import org.apache.flink.runtime.state.memory.MemoryStateBackend; + +/** + * IT Case for reading keyed state from a memory state backend. + */ +public class MemoryStateBackendReaderKeyedStateITCase extends SavepointReaderKeyedStateITCase<MemoryStateBackend> { + + @Override + protected MemoryStateBackend getStateBackend() { + return new MemoryStateBackend(); + } +} diff --git a/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/RocksDBStateBackendReaderKeyedStateITCase.java b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/RocksDBStateBackendReaderKeyedStateITCase.java new file mode 100644 index 0000000..df9fb85 --- /dev/null +++ b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/RocksDBStateBackendReaderKeyedStateITCase.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.state.api; + +import org.apache.flink.contrib.streaming.state.RocksDBStateBackend; +import org.apache.flink.runtime.state.StateBackend; +import org.apache.flink.runtime.state.memory.MemoryStateBackend; + +/** + * IT Case for reading state from a RocksDB keyed state backend. + */ +public class RocksDBStateBackendReaderKeyedStateITCase extends SavepointReaderKeyedStateITCase<RocksDBStateBackend> { + @Override + protected RocksDBStateBackend getStateBackend() { + return new RocksDBStateBackend((StateBackend) new MemoryStateBackend()); + } +} diff --git a/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/SavepointReaderKeyedStateITCase.java b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/SavepointReaderKeyedStateITCase.java index 1ce2376..06e864f 100644 --- a/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/SavepointReaderKeyedStateITCase.java +++ b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/SavepointReaderKeyedStateITCase.java @@ -18,171 +18,74 @@ package org.apache.flink.state.api; -import org.apache.flink.api.common.JobID; -import org.apache.flink.api.common.JobSubmissionResult; import org.apache.flink.api.common.state.ValueState; import org.apache.flink.api.common.state.ValueStateDescriptor; -import org.apache.flink.api.common.time.Deadline; import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.ExecutionEnvironment; -import org.apache.flink.client.ClientUtils; -import org.apache.flink.client.program.ClusterClient; import org.apache.flink.configuration.Configuration; -import org.apache.flink.contrib.streaming.state.RocksDBStateBackend; -import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.state.StateBackend; -import org.apache.flink.runtime.state.memory.MemoryStateBackend; import 
org.apache.flink.state.api.functions.KeyedStateReaderFunction; +import org.apache.flink.state.api.utils.SavepointTestBase; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.api.functions.KeyedProcessFunction; import org.apache.flink.streaming.api.functions.sink.DiscardingSink; -import org.apache.flink.streaming.api.functions.source.SourceFunction; -import org.apache.flink.test.util.AbstractTestBase; -import org.apache.flink.util.AbstractID; import org.apache.flink.util.Collector; import org.junit.Assert; import org.junit.Test; -import java.time.Duration; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Objects; import java.util.Set; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.TimeUnit; /** * IT case for reading state. */ -public class SavepointReaderKeyedStateITCase extends AbstractTestBase { +public abstract class SavepointReaderKeyedStateITCase<B extends StateBackend> extends SavepointTestBase { private static final String uid = "stateful-operator"; private static ValueStateDescriptor<Integer> valueState = new ValueStateDescriptor<>("value", Types.INT); - @Test - public void testKeyedInputFormat() throws Exception { - runKeyedState(new MemoryStateBackend()); - // Reset the cluster so we can change the - // state backend in the StreamEnvironment. - // If we don't do this the tests will fail. 
- miniClusterResource.after(); - miniClusterResource.before(); - runKeyedState(new RocksDBStateBackend((StateBackend) new MemoryStateBackend())); - } - - private void runKeyedState(StateBackend backend) throws Exception { - StreamExecutionEnvironment streamEnv = StreamExecutionEnvironment.getExecutionEnvironment(); - streamEnv.setStateBackend(backend); - streamEnv.setParallelism(4); + private static final List<Pojo> elements = Arrays.asList( + Pojo.of(1, 1), + Pojo.of(2, 2), + Pojo.of(3, 3)); - streamEnv - .addSource(new SavepointSource()) - .rebalance() - .keyBy(id -> id.key) - .process(new KeyedStatefulOperator()) - .uid(uid) - .addSink(new DiscardingSink<>()); + protected abstract B getStateBackend(); - JobGraph jobGraph = streamEnv.getStreamGraph().getJobGraph(); - - String path = takeSavepoint(jobGraph); + @Test + public void testUserKeyedStateReader() throws Exception { + String savepointPath = takeSavepoint(elements, source -> { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setStateBackend(getStateBackend()); + env.setParallelism(4); + + env + .addSource(source) + .rebalance() + .keyBy(id -> id.key) + .process(new KeyedStatefulOperator()) + .uid(uid) + .addSink(new DiscardingSink<>()); + + return env; + }); ExecutionEnvironment batchEnv = ExecutionEnvironment.getExecutionEnvironment(); - ExistingSavepoint savepoint = Savepoint.load(batchEnv, path, backend); + ExistingSavepoint savepoint = Savepoint.load(batchEnv, savepointPath, getStateBackend()); List<Pojo> results = savepoint .readKeyedState(uid, new Reader()) .collect(); - Set<Pojo> expected = SavepointSource.getElements(); + Set<Pojo> expected = new HashSet<>(elements); Assert.assertEquals("Unexpected results from keyed state", expected, new HashSet<>(results)); } - private String takeSavepoint(JobGraph jobGraph) throws Exception { - SavepointSource.initializeForTest(); - - ClusterClient<?> client = miniClusterResource.getClusterClient(); - - JobID jobId = 
jobGraph.getJobID(); - - Deadline deadline = Deadline.fromNow(Duration.ofMinutes(5)); - - String dirPath = getTempDirPath(new AbstractID().toHexString()); - - try { - JobSubmissionResult result = ClientUtils.submitJob(client, jobGraph); - - boolean finished = false; - while (deadline.hasTimeLeft()) { - if (SavepointSource.isFinished()) { - finished = true; - - break; - } - } - - if (!finished) { - Assert.fail("Failed to initialize state within deadline"); - } - - CompletableFuture<String> path = client.triggerSavepoint(result.getJobID(), dirPath); - return path.get(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS); - } finally { - client.cancel(jobId).get(); - } - } - - private static class SavepointSource implements SourceFunction<Pojo> { - private static volatile boolean finished; - - private volatile boolean running = true; - - private static final Pojo[] elements = { - Pojo.of(1, 1), - Pojo.of(2, 2), - Pojo.of(3, 3)}; - - @Override - public void run(SourceContext<Pojo> ctx) { - synchronized (ctx.getCheckpointLock()) { - for (Pojo element : elements) { - ctx.collect(element); - } - - finished = true; - } - - while (running) { - try { - Thread.sleep(100); - } catch (InterruptedException e) { - // ignore - } - } - } - - @Override - public void cancel() { - running = false; - } - - private static void initializeForTest() { - finished = false; - } - - private static boolean isFinished() { - return finished; - } - - private static Set<Pojo> getElements() { - return new HashSet<>(Arrays.asList(elements)); - } - } - private static class KeyedStatefulOperator extends KeyedProcessFunction<Integer, Pojo, Void> { private transient ValueState<Integer> state; diff --git a/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/input/KeyedStateInputFormatTest.java b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/input/KeyedStateInputFormatTest.java index e036bc8..d0b55b6 100644 --- 
a/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/input/KeyedStateInputFormatTest.java +++ b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/input/KeyedStateInputFormatTest.java @@ -27,8 +27,10 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.checkpoint.OperatorState; import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.state.VoidNamespace; import org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.apache.flink.state.api.functions.KeyedStateReaderFunction; +import org.apache.flink.state.api.input.operator.KeyedStateReaderOperator; import org.apache.flink.state.api.input.splits.KeyGroupRangeInputSplit; import org.apache.flink.state.api.runtime.OperatorIDGenerator; import org.apache.flink.streaming.api.functions.KeyedProcessFunction; @@ -65,7 +67,7 @@ public class KeyedStateInputFormatTest { OperatorState operatorState = new OperatorState(operatorID, 1, 128); operatorState.putState(0, state); - KeyedStateInputFormat<?, ?> format = new KeyedStateInputFormat<>(operatorState, new MemoryStateBackend(), Types.INT, new ReaderFunction()); + KeyedStateInputFormat<?, ?, ?> format = new KeyedStateInputFormat<>(operatorState, new MemoryStateBackend(), new KeyedStateReaderOperator<>(new ReaderFunction(), Types.INT)); KeyGroupRangeInputSplit[] splits = format.createInputSplits(4); Assert.assertEquals("Failed to properly partition operator state into input splits", 4, splits.length); } @@ -78,7 +80,7 @@ public class KeyedStateInputFormatTest { OperatorState operatorState = new OperatorState(operatorID, 1, 128); operatorState.putState(0, state); - KeyedStateInputFormat<?, ?> format = new KeyedStateInputFormat<>(operatorState, new MemoryStateBackend(), Types.INT, new ReaderFunction()); + KeyedStateInputFormat<?, ?, ?> format = new 
KeyedStateInputFormat<>(operatorState, new MemoryStateBackend(), new KeyedStateReaderOperator<>(new ReaderFunction(), Types.INT)); KeyGroupRangeInputSplit[] splits = format.createInputSplits(129); Assert.assertEquals("Failed to properly partition operator state into input splits", 128, splits.length); } @@ -91,7 +93,7 @@ public class KeyedStateInputFormatTest { OperatorState operatorState = new OperatorState(operatorID, 1, 128); operatorState.putState(0, state); - KeyedStateInputFormat<?, ?> format = new KeyedStateInputFormat<>(operatorState, new MemoryStateBackend(), Types.INT, new ReaderFunction()); + KeyedStateInputFormat<?, ?, ?> format = new KeyedStateInputFormat<>(operatorState, new MemoryStateBackend(), new KeyedStateReaderOperator<>(new ReaderFunction(), Types.INT)); KeyGroupRangeInputSplit split = format.createInputSplits(1)[0]; KeyedStateReaderFunction<Integer, Integer> userFunction = new ReaderFunction(); @@ -109,7 +111,7 @@ public class KeyedStateInputFormatTest { OperatorState operatorState = new OperatorState(operatorID, 1, 128); operatorState.putState(0, state); - KeyedStateInputFormat<?, ?> format = new KeyedStateInputFormat<>(operatorState, new MemoryStateBackend(), Types.INT, new ReaderFunction()); + KeyedStateInputFormat<?, ?, ?> format = new KeyedStateInputFormat<>(operatorState, new MemoryStateBackend(), new KeyedStateReaderOperator<>(new ReaderFunction(), Types.INT)); KeyGroupRangeInputSplit split = format.createInputSplits(1)[0]; KeyedStateReaderFunction<Integer, Integer> userFunction = new DoubleReaderFunction(); @@ -127,7 +129,7 @@ public class KeyedStateInputFormatTest { OperatorState operatorState = new OperatorState(operatorID, 1, 128); operatorState.putState(0, state); - KeyedStateInputFormat<?, ?> format = new KeyedStateInputFormat<>(operatorState, new MemoryStateBackend(), Types.INT, new ReaderFunction()); + KeyedStateInputFormat<?, ?, ?> format = new KeyedStateInputFormat<>(operatorState, new MemoryStateBackend(), new 
KeyedStateReaderOperator<>(new ReaderFunction(), Types.INT)); KeyGroupRangeInputSplit split = format.createInputSplits(1)[0]; KeyedStateReaderFunction<Integer, Integer> userFunction = new InvalidReaderFunction(); @@ -145,7 +147,7 @@ public class KeyedStateInputFormatTest { OperatorState operatorState = new OperatorState(operatorID, 1, 128); operatorState.putState(0, state); - KeyedStateInputFormat<?, ?> format = new KeyedStateInputFormat<>(operatorState, new MemoryStateBackend(), Types.INT, new TimerReaderFunction()); + KeyedStateInputFormat<?, ?, ?> format = new KeyedStateInputFormat<>(operatorState, new MemoryStateBackend(), new KeyedStateReaderOperator<>(new TimerReaderFunction(), Types.INT)); KeyGroupRangeInputSplit split = format.createInputSplits(1)[0]; KeyedStateReaderFunction<Integer, Integer> userFunction = new TimerReaderFunction(); @@ -157,11 +159,10 @@ public class KeyedStateInputFormatTest { @Nonnull private List<Integer> readInputSplit(KeyGroupRangeInputSplit split, KeyedStateReaderFunction<Integer, Integer> userFunction) throws IOException { - KeyedStateInputFormat<Integer, Integer> format = new KeyedStateInputFormat<>( + KeyedStateInputFormat<Integer, VoidNamespace, Integer> format = new KeyedStateInputFormat<>( new OperatorState(OperatorIDGenerator.fromUid("uid"), 1, 4), new MemoryStateBackend(), - Types.INT, - userFunction); + new KeyedStateReaderOperator<>(userFunction, Types.INT)); List<Integer> data = new ArrayList<>(); diff --git a/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/utils/SavepointTestBase.java b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/utils/SavepointTestBase.java new file mode 100644 index 0000000..0d25682 --- /dev/null +++ b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/utils/SavepointTestBase.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license 
agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.state.api.utils; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.JobSubmissionResult; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.client.ClientUtils; +import org.apache.flink.client.program.ClusterClient; +import org.apache.flink.runtime.jobgraph.JobGraph; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.source.FromElementsFunction; +import org.apache.flink.streaming.api.functions.source.SourceFunction; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.util.AbstractID; + +import java.io.IOException; +import java.util.Collection; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; + +/** + * A test base that includes utilities for taking a savepoint. 
+ */
+public abstract class SavepointTestBase extends AbstractTestBase {
+
+	/** Upper bound, in minutes, for the source to drain and the savepoint to complete. */
+	private static final long SAVEPOINT_TIMEOUT_MINUTES = 5;
+
+	/**
+	 * Submits the job produced by {@code jobGraphFactory}, waits for the wrapped source to finish
+	 * emitting {@code data}, triggers a savepoint and returns the savepoint path.
+	 *
+	 * <p>The job is asked to cancel in every case; the cancellation future is not awaited.
+	 *
+	 * @param data The elements the source emits; must be non-empty and its first element non-null.
+	 * @param jobGraphFactory Builds the streaming topology around the supplied source and returns
+	 *                        the environment the job graph is taken from.
+	 * @param <T> The element type of the source.
+	 * @return The path reported by the cluster client for the completed savepoint.
+	 * @throws IllegalArgumentException If {@code data} is null, empty, or starts with a null element.
+	 * @throws RuntimeException If submission, waiting, or the savepoint fails; the cause is attached.
+	 */
+	public <T> String takeSavepoint(Collection<T> data, Function<SourceFunction<T>, StreamExecutionEnvironment> jobGraphFactory) throws Exception {
+
+		// NOTE(review): this environment is never used to build the job graph (the factory returns
+		// its own), so disabling the closure cleaner here looks like it has no effect on the
+		// submitted job - confirm and remove, or move the setting into the factory.
+		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+		env.getConfig().disableClosureCleaner();
+
+		WaitingSource<T> waitingSource = createSource(data);
+
+		JobGraph jobGraph = jobGraphFactory.apply(waitingSource).getStreamGraph().getJobGraph();
+		JobID jobId = jobGraph.getJobID();
+
+		ClusterClient<?> client = miniClusterResource.getClusterClient();
+
+		try {
+			JobSubmissionResult result = ClientUtils.submitJob(client, jobGraph);
+
+			// Block (off-thread) until the inner source has emitted everything, then savepoint.
+			return CompletableFuture
+				.runAsync(waitingSource::awaitSource)
+				.thenCompose(ignore -> triggerSavepoint(client, result.getJobID()))
+				.get(SAVEPOINT_TIMEOUT_MINUTES, TimeUnit.MINUTES);
+		} catch (Exception e) {
+			if (e instanceof InterruptedException) {
+				// Wrapping must not hide the interruption from the calling thread.
+				Thread.currentThread().interrupt();
+			}
+			throw new RuntimeException("Failed to take savepoint", e);
+		} finally {
+			client.cancel(jobId);
+		}
+	}
+
+	/**
+	 * Wraps {@code data} in a {@link FromElementsFunction} and then in a {@link WaitingSource}.
+	 * The type information is extracted from the first element of the collection.
+	 */
+	private <T> WaitingSource<T> createSource(Collection<T> data) throws Exception {
+		// Fail with a descriptive message instead of a bare NoSuchElementException from next().
+		if (data == null || data.isEmpty()) {
+			throw new IllegalArgumentException("Collection must not be empty");
+		}
+
+		T first = data.iterator().next();
+		if (first == null) {
+			throw new IllegalArgumentException("Collection must not contain null elements");
+		}
+
+		TypeInformation<T> typeInfo = TypeExtractor.getForObject(first);
+		SourceFunction<T> inner = new FromElementsFunction<>(typeInfo.createSerializer(new ExecutionConfig()), data);
+		return new WaitingSource<>(inner, typeInfo);
+	}
+
+	/**
+	 * Triggers a savepoint for {@code jobID} into a fresh temporary directory.
+	 * An {@link IOException} while creating the directory is rethrown unchecked so the method
+	 * can be used inside a {@link CompletableFuture} stage.
+	 */
+	private CompletableFuture<String> triggerSavepoint(ClusterClient<?> client, JobID jobID) throws RuntimeException {
+		try {
+			String dirPath = getTempDirPath(new AbstractID().toHexString());
+			return client.triggerSavepoint(jobID, dirPath);
+		} catch (IOException e) {
+			throw new RuntimeException(e);
+		}
+	}
+}
diff --git a/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/utils/WaitingSource.java
b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/utils/WaitingSource.java new file mode 100644 index 0000000..768cb33 --- /dev/null +++ b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/utils/WaitingSource.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.state.api.utils; + +import org.apache.flink.api.common.functions.RuntimeContext; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.ResultTypeQueryable; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.core.testutils.OneShotLatch; +import org.apache.flink.streaming.api.functions.source.RichSourceFunction; +import org.apache.flink.streaming.api.functions.source.SourceFunction; + +import java.util.HashMap; +import java.util.Map; +import java.util.UUID; + +/** + * A wrapper class that allows threads to block until the inner source completes. + * Its run method does not return until explicitly canceled so external processes can + * perform operations such as taking savepoints. + * + * @param <T> The output type of the inner source. 
+ */
+public class WaitingSource<T> extends RichSourceFunction<T> implements ResultTypeQueryable<T> {
+
+	/**
+	 * Latches keyed by {@link #guardId}. The registry is static so that it can be reached through
+	 * the id alone (presumably because the instance that runs in the cluster is a serialized copy
+	 * of the one the test holds - confirm). It is accessed from the constructing thread, the
+	 * source thread and the awaiting thread, hence the synchronized wrapper.
+	 */
+	private static final Map<String, OneShotLatch> guards = java.util.Collections.synchronizedMap(new HashMap<>());
+
+	private final SourceFunction<T> source;
+
+	private final TypeInformation<T> returnType;
+
+	private final String guardId;
+
+	private volatile boolean running;
+
+	/**
+	 * Wraps {@code source} and registers a latch that fires once the wrapped source has finished.
+	 *
+	 * @param source The source whose completion can be awaited.
+	 * @param returnType The type information reported by {@link #getProducedType()}.
+	 */
+	public WaitingSource(SourceFunction<T> source, TypeInformation<T> returnType) {
+		this.source = source;
+		this.returnType = returnType;
+		this.guardId = UUID.randomUUID().toString();
+
+		guards.put(guardId, new OneShotLatch());
+		this.running = true;
+	}
+
+	/** Stores the context on this function and forwards it to the inner source if it is rich. */
+	@Override
+	public void setRuntimeContext(RuntimeContext t) {
+		// Keep the wrapper's own context in sync so getRuntimeContext() works on it as well.
+		super.setRuntimeContext(t);
+		if (source instanceof RichSourceFunction) {
+			((RichSourceFunction<T>) source).setRuntimeContext(t);
+		}
+	}
+
+	/** Forwards {@code open} to the inner source if it is a rich function. */
+	@Override
+	public void open(Configuration parameters) throws Exception {
+		if (source instanceof RichSourceFunction) {
+			((RichSourceFunction<T>) source).open(parameters);
+		}
+	}
+
+	/** Forwards {@code close} to the inner source if it is a rich function. */
+	@Override
+	public void close() throws Exception {
+		if (source instanceof RichSourceFunction) {
+			((RichSourceFunction<T>) source).close();
+		}
+	}
+
+	/**
+	 * Runs the inner source to completion, releases every thread blocked in
+	 * {@link #awaitSource()} (also when the inner source fails), and then idles until
+	 * {@link #cancel()} is called.
+	 *
+	 * @throws IllegalStateException If no latch is registered under this source's id.
+	 */
+	@Override
+	public void run(SourceContext<T> ctx) throws Exception {
+		OneShotLatch latch = guards.get(guardId);
+		if (latch == null) {
+			// Fail up front: a null latch would otherwise raise an NPE inside the finally block
+			// and hide whatever the inner source threw.
+			throw new IllegalStateException("No latch registered for WaitingSource " + guardId);
+		}
+
+		try {
+			source.run(ctx);
+		} finally {
+			latch.trigger();
+		}
+
+		while (running) {
+			try {
+				Thread.sleep(100);
+			} catch (InterruptedException ignored) {
+				// Deliberately ignored: only cancel() ends the idle loop.
+			}
+		}
+	}
+
+	/** Cancels the inner source and lets {@link #run(SourceContext)} return. */
+	@Override
+	public void cancel() {
+		source.cancel();
+		running = false;
+	}
+
+	/**
+	 * This method blocks until the inner source has completed.
+	 *
+	 * @throws RuntimeException If the waiting thread is interrupted; the interrupt flag is restored
+	 *                          and the {@link InterruptedException} is attached as the cause.
+	 */
+	public void awaitSource() throws RuntimeException {
+		try {
+			guards.get(guardId).await();
+		} catch (InterruptedException e) {
+			Thread.currentThread().interrupt();
+			throw new RuntimeException("Failed to initialize source", e);
+		}
+	}
+
+	@Override
+	public TypeInformation<T> getProducedType() {
+		return returnType;
+	}
+}