http://git-wip-us.apache.org/repos/asf/flink/blob/4809f536/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java index 59ecd15..718c0c7 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java @@ -24,13 +24,12 @@ import org.apache.flink.api.common.state.State; import org.apache.flink.api.common.state.StateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.functions.KeySelector; -import org.apache.flink.metrics.Counter; import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.core.fs.FSDataOutputStream; -import org.apache.flink.runtime.state.AsynchronousKvStateSnapshot; +import org.apache.flink.metrics.Counter; import org.apache.flink.metrics.MetricGroup; -import org.apache.flink.runtime.state.AbstractStateBackend; -import org.apache.flink.runtime.state.KvStateSnapshot; +import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.KeyedStateBackend; import org.apache.flink.runtime.state.VoidNamespace; import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.streaming.api.graph.StreamConfig; @@ -38,14 +37,9 @@ import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.operators.Triggerable; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.tasks.StreamTask; -import org.apache.flink.util.InstantiationUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.DataInputStream; -import java.io.DataOutputStream; -import java.util.HashMap; -import java.util.Set; import java.util.concurrent.ScheduledFuture; /** @@ -96,9 +90,10 @@ public abstract class AbstractStreamOperator<OUT> private transient KeySelector<?, ?> stateKeySelector1; private transient KeySelector<?, ?> stateKeySelector2; - /** The state backend that stores the state and checkpoints for this task */ - private transient AbstractStateBackend stateBackend; - protected MetricGroup metrics; + /** Backend for keyed state. This might be empty if we're not on a keyed stream. */ + private transient KeyedStateBackend<?> keyedStateBackend; + + protected transient MetricGroup metrics; // ------------------------------------------------------------------------ // Life Cycle @@ -116,16 +111,6 @@ public abstract class AbstractStreamOperator<OUT> stateKeySelector1 = config.getStatePartitioner(0, getUserCodeClassloader()); stateKeySelector2 = config.getStatePartitioner(1, getUserCodeClassloader()); - - try { - TypeSerializer<Object> keySerializer = config.getStateKeySerializer(getUserCodeClassloader()); - // if the keySerializer is null we still need to create the state backend - // for the non-partitioned state features it provides, such as the state output streams - String operatorIdentifier = getClass().getSimpleName() + "_" + config.getVertexID() + "_" + runtimeContext.getIndexOfThisSubtask(); - stateBackend = container.createStateBackend(operatorIdentifier, keySerializer); - } catch (Exception e) { - throw new RuntimeException("Could not initialize state backend. ", e); - } } public MetricGroup getMetricGroup() { @@ -141,7 +126,27 @@ public abstract class AbstractStreamOperator<OUT> * @throws Exception An exception in this method causes the operator to fail. */ @Override - public void open() throws Exception {} + public void open() throws Exception { + try { + TypeSerializer<Object> keySerializer = config.getStateKeySerializer(getUserCodeClassloader()); + // create a keyed state backend if there is keyed state, as indicated by the presence of a key serializer + if (null != keySerializer) { + ExecutionConfig execConf = container.getEnvironment().getExecutionConfig();; + + KeyGroupRange subTaskKeyGroupRange = KeyGroupRange.computeKeyGroupRangeForOperatorIndex( + container.getEnvironment().getTaskInfo().getNumberOfKeyGroups(), + container.getEnvironment().getTaskInfo().getNumberOfParallelSubtasks(), + container.getIndexInSubtaskGroup()); + + keyedStateBackend = container.createKeyedStateBackend( + keySerializer, + container.getConfiguration().getKeyGroupAssigner(getUserCodeClassloader()), + subTaskKeyGroupRange); + } + } catch (Exception e) { + throw new RuntimeException("Could not initialize keyed state backend.", e); + } + } /** * This method is called after all records have been added to the operators via the methods @@ -166,69 +171,22 @@ public abstract class AbstractStreamOperator<OUT> * that the operator has acquired. */ @Override - public void dispose() { - if (stateBackend != null) { - try { - stateBackend.close(); - stateBackend.discardState(); - } catch (Exception e) { - throw new RuntimeException("Error while closing/disposing state backend.", e); - } + public void dispose() throws Exception { + if (keyedStateBackend != null) { + keyedStateBackend.close(); } } - - // ------------------------------------------------------------------------ - // Checkpointing - // ------------------------------------------------------------------------ @Override public void snapshotState(FSDataOutputStream out, long checkpointId, - long timestamp) throws Exception { - - HashMap<String, KvStateSnapshot<?, ?, ?, ?, ?>> keyedState = - stateBackend.snapshotPartitionedState(checkpointId,timestamp); - - // Materialize asynchronous snapshots, if any - if (keyedState != null) { - Set<String> keys = keyedState.keySet(); - for (String key: keys) { - if (keyedState.get(key) instanceof AsynchronousKvStateSnapshot) { - AsynchronousKvStateSnapshot<?, ?, ?, ?, ?> asyncHandle = (AsynchronousKvStateSnapshot<?, ?, ?, ?, ?>) keyedState.get(key); - keyedState.put(key, asyncHandle.materialize()); - } - } - } - - byte[] serializedSnapshot = InstantiationUtil.serializeObject(keyedState); - - DataOutputStream dos = new DataOutputStream(out); - dos.writeInt(serializedSnapshot.length); - dos.write(serializedSnapshot); - - dos.flush(); - - } + long timestamp) throws Exception {} @Override - public void restoreState(FSDataInputStream in) throws Exception { - DataInputStream dis = new DataInputStream(in); - int size = dis.readInt(); - byte[] serializedSnapshot = new byte[size]; - dis.readFully(serializedSnapshot); - - HashMap<String, KvStateSnapshot> keyedState = - InstantiationUtil.deserializeObject(serializedSnapshot, getUserCodeClassloader()); + public void restoreState(FSDataInputStream in) throws Exception {} - stateBackend.injectKeyValueStateSnapshots(keyedState); - } - @Override - public void notifyOfCompletedCheckpoint(long checkpointId) throws Exception { - if (stateBackend != null) { - stateBackend.notifyOfCompletedCheckpoint(checkpointId); - } - } + public void notifyOfCompletedCheckpoint(long checkpointId) throws Exception {} // ------------------------------------------------------------------------ // Properties and Services @@ -265,13 +223,14 @@ public abstract class AbstractStreamOperator<OUT> return runtimeContext; } - public AbstractStateBackend getStateBackend() { - return stateBackend; + @SuppressWarnings("rawtypes, unchecked") + public <K> KeyedStateBackend<K> getStateBackend() { + return (KeyedStateBackend<K>) keyedStateBackend; } /** - * Register a timer callback. At the specified time the {@link Triggerable} will be invoked. - * This call is guaranteed to not happen concurrently with method calls on the operator. + * Register a timer callback. At the specified time the provided {@link Triggerable} will + * be invoked. This call is guaranteed to not happen concurrently with method calls on the operator. * * @param time The absolute time in milliseconds. * @param target The target to be triggered. @@ -291,7 +250,7 @@ public abstract class AbstractStreamOperator<OUT> * @throws Exception Thrown, if the state backend cannot create the key/value state. */ protected <S extends State> S getPartitionedState(StateDescriptor<S, ?> stateDescriptor) throws Exception { - return getStateBackend().getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, stateDescriptor); + return getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, stateDescriptor); } /** @@ -302,13 +261,13 @@ public abstract class AbstractStreamOperator<OUT> */ @SuppressWarnings("unchecked") protected <S extends State, N> S getPartitionedState(N namespace, TypeSerializer<N> namespaceSerializer, StateDescriptor<S, ?> stateDescriptor) throws Exception { - if (stateBackend != null) { - return stateBackend.getPartitionedState( - namespace, - namespaceSerializer, - stateDescriptor); + if (keyedStateBackend != null) { + return keyedStateBackend.getPartitionedState( + namespace, + namespaceSerializer, + stateDescriptor); } else { - throw new RuntimeException("Cannot create partitioned state. The key grouped state " + + throw new RuntimeException("Cannot create partitioned state. The keyed state " + "backend has not been set. This indicates that the operator is not " + "partitioned/keyed."); } @@ -335,15 +294,16 @@ public abstract class AbstractStreamOperator<OUT> @SuppressWarnings({"unchecked", "rawtypes"}) public void setKeyContext(Object key) { - if (stateBackend != null) { + if (keyedStateBackend != null) { try { - stateBackend.setCurrentKey(key); + // need to work around type restrictions + @SuppressWarnings("unchecked,rawtypes") + KeyedStateBackend rawBackend = (KeyedStateBackend) keyedStateBackend; + + rawBackend.setCurrentKey(key); } catch (Exception e) { throw new RuntimeException("Exception occurred while setting the current key context.", e); } - } else { - throw new RuntimeException("Could not set the current key context, because the " + - "AbstractStateBackend has not been initialized."); } }
http://git-wip-us.apache.org/repos/asf/flink/blob/4809f536/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java index b1bc531..6ac73e7 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java @@ -18,8 +18,6 @@ package org.apache.flink.streaming.api.operators; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; import java.io.Serializable; import org.apache.flink.annotation.PublicEvolving; @@ -35,6 +33,7 @@ import org.apache.flink.streaming.api.checkpoint.Checkpointed; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.tasks.StreamTask; +import org.apache.flink.util.InstantiationUtil; import static java.util.Objects.requireNonNull; @@ -49,7 +48,9 @@ import static java.util.Objects.requireNonNull; * The type of the user function */ @PublicEvolving -public abstract class AbstractUdfStreamOperator<OUT, F extends Function> extends AbstractStreamOperator<OUT> implements OutputTypeConfigurable<OUT> { +public abstract class AbstractUdfStreamOperator<OUT, F extends Function> + extends AbstractStreamOperator<OUT> + implements OutputTypeConfigurable<OUT> { private static final long serialVersionUID = 1L; @@ -100,16 +101,11 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function> extends } @Override - public void dispose() { + public void dispose() throws Exception { super.dispose(); if (!functionsClosed) { functionsClosed = true; - try { - FunctionUtils.closeFunction(userFunction); - } - catch (Throwable t) { - LOG.error("Exception while closing user function while failing or canceling task", t); - } + FunctionUtils.closeFunction(userFunction); } } @@ -130,9 +126,7 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function> extends udfState = chkFunction.snapshotState(checkpointId, timestamp); if (udfState != null) { out.write(1); - ObjectOutputStream os = new ObjectOutputStream(out); - os.writeObject(udfState); - os.flush(); + InstantiationUtil.serializeObject(out, udfState); } else { out.write(0); } @@ -153,8 +147,7 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function> extends int hasUdfState = in.read(); if (hasUdfState == 1) { - ObjectInputStream ois = new ObjectInputStream(in); - Serializable functionState = (Serializable) ois.readObject(); + Serializable functionState = InstantiationUtil.deserializeObject(in, getUserCodeClassloader()); if (functionState != null) { try { chkFunction.restoreState(functionState); http://git-wip-us.apache.org/repos/asf/flink/blob/4809f536/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java index 3411a60..f1e8160 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java @@ -20,9 +20,9 @@ package org.apache.flink.streaming.api.operators; import java.io.Serializable; import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.core.fs.FSDataOutputStream; import org.apache.flink.metrics.MetricGroup; import org.apache.flink.core.fs.FSDataInputStream; -import org.apache.flink.core.fs.FSDataOutputStream; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.tasks.StreamTask; @@ -84,7 +84,7 @@ public interface StreamOperator<OUT> extends Serializable { * This method is expected to make a thorough effort to release all resources * that the operator has acquired. */ - void dispose(); + void dispose() throws Exception; // ------------------------------------------------------------------------ // state snapshots @@ -92,8 +92,7 @@ public interface StreamOperator<OUT> extends Serializable { /** * Called to draw a state snapshot from the operator. This method snapshots the operator state - * (if the operator is stateful) and the key/value state (if it is being used and has been - * initialized). + * (if the operator is stateful). * * @param out The stream to which we have to write our state. * @param checkpointId The ID of the checkpoint. http://git-wip-us.apache.org/repos/asf/flink/blob/4809f536/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java index 8d074cc..35d1108 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java @@ -24,7 +24,7 @@ import org.apache.flink.core.fs.FSDataOutputStream; import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; import org.apache.flink.runtime.io.disk.InputViewIterator; -import org.apache.flink.runtime.state.AbstractStateBackend; +import org.apache.flink.runtime.state.CheckpointStreamFactory; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.util.ReusingMutableToRegularIteratorWrapper; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; @@ -56,9 +56,10 @@ public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<I protected static final Logger LOG = LoggerFactory.getLogger(GenericWriteAheadSink.class); private final CheckpointCommitter committer; - private transient AbstractStateBackend.CheckpointStateOutputStream out; + private transient CheckpointStreamFactory.CheckpointStateOutputStream out; protected final TypeSerializer<IN> serializer; private final String id; + private transient CheckpointStreamFactory checkpointStreamFactory; private ExactlyOnceState state = new ExactlyOnceState(); @@ -76,6 +77,8 @@ public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<I committer.setOperatorSubtaskId(getRuntimeContext().getIndexOfThisSubtask()); committer.open(); cleanState(); + checkpointStreamFactory = + getContainingTask().createCheckpointStreamFactory(this); } public void close() throws Exception { @@ -184,9 +187,9 @@ public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<I @Override public void processElement(StreamRecord<IN> element) throws Exception { IN value = element.getValue(); - //generate initial operator state + // generate initial operator state if (out == null) { - out = getStateBackend().createCheckpointStateOutputStream(0, 0); + out = checkpointStreamFactory.createCheckpointStateOutputStream(0, 0); } serializer.serialize(value, new DataOutputViewStreamWrapper(out)); } http://git-wip-us.apache.org/repos/asf/flink/blob/4809f536/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AbstractAlignedProcessingTimeWindowOperator.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AbstractAlignedProcessingTimeWindowOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AbstractAlignedProcessingTimeWindowOperator.java index 2c95099..e74dd87 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AbstractAlignedProcessingTimeWindowOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/AbstractAlignedProcessingTimeWindowOperator.java @@ -125,7 +125,7 @@ public abstract class AbstractAlignedProcessingTimeWindowOperator<KEY, IN, OUT, // decide when to first compute the window and when to slide it // the values should align with the start of time (that is, the UNIX epoch, not the big bang) - final long now = System.currentTimeMillis(); + final long now = getRuntimeContext().getCurrentProcessingTime(); nextEvaluationTime = now + windowSlide - (now % windowSlide); nextSlideTime = now + paneSize - (now % paneSize); @@ -178,7 +178,7 @@ public abstract class AbstractAlignedProcessingTimeWindowOperator<KEY, IN, OUT, } @Override - public void dispose() { + public void dispose() throws Exception { super.dispose(); // acquire the lock during shutdown, to prevent trigger calls at the same time http://git-wip-us.apache.org/repos/asf/flink/blob/4809f536/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java index dbdd660..25ec519 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java @@ -277,7 +277,7 @@ public class WindowOperator<K, IN, ACC, OUT, W extends Window> } @Override - public void dispose() { + public void dispose() throws Exception { super.dispose(); timestampedCollector = null; watermarkTimers = null; http://git-wip-us.apache.org/repos/asf/flink/blob/4809f536/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java index 074257c..02579aa 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java @@ -19,6 +19,7 @@ package org.apache.flink.streaming.runtime.tasks; import org.apache.flink.annotation.Internal; import org.apache.flink.api.common.accumulators.Accumulator; +import org.apache.flink.api.common.state.KeyGroupAssigner; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.configuration.Configuration; @@ -30,7 +31,10 @@ import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.jobgraph.tasks.StatefulTask; import org.apache.flink.runtime.state.AbstractStateBackend; import org.apache.flink.runtime.state.ChainedStateHandle; +import org.apache.flink.runtime.state.CheckpointStreamFactory; +import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyGroupsStateHandle; +import org.apache.flink.runtime.state.KeyedStateBackend; import org.apache.flink.runtime.state.StateBackendFactory; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.state.filesystem.FsStateBackend; @@ -40,7 +44,6 @@ import org.apache.flink.runtime.taskmanager.DispatcherThreadFactory; import org.apache.flink.runtime.util.event.EventListener; import org.apache.flink.streaming.api.TimeCharacteristic; import org.apache.flink.streaming.api.graph.StreamConfig; -import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.Output; import org.apache.flink.streaming.api.operators.StreamOperator; import org.apache.flink.streaming.runtime.io.RecordWriterOutput; @@ -50,6 +53,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.Closeable; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -57,6 +61,9 @@ import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.RunnableFuture; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.ScheduledThreadPoolExecutor; @@ -130,6 +137,12 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> /** The class loader used to load dynamic classes of a job */ private ClassLoader userClassLoader; + /** Our state backend. We use this to create checkpoint streams and a keyed state backend. */ + private AbstractStateBackend stateBackend; + + /** Keyed state backend for the head operator, if it is keyed. There can only ever be one. */ + private KeyedStateBackend<?> keyedStateBackend; + /** * The internal {@link TimeServiceProvider} used to define the current * processing time (default = {@code System.currentTimeMillis()}) and @@ -152,7 +165,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> private volatile AsynchronousException asyncException; /** The currently active background materialization threads */ - private final Set<Closeable> cancelables = new HashSet<Closeable>(); + private final Set<Closeable> cancelables = new HashSet<>(); /** Flag to mark the task "in operation", in which case check * needs to be initialized to true, so that early cancel() before invoke() behaves correctly */ @@ -163,6 +176,8 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> private long lastCheckpointSize = 0; + private ExecutorService asyncOperationsThreadPool; + // ------------------------------------------------------------------------ // Life cycle methods for specific implementations // ------------------------------------------------------------------------ @@ -205,9 +220,14 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> // -------- Initialize --------- LOG.debug("Initializing {}", getName()); + asyncOperationsThreadPool = Executors.newCachedThreadPool(); + userClassLoader = getUserCodeClassLoader(); configuration = new StreamConfig(getTaskConfiguration()); + + stateBackend = createStateBackend(); + accumulatorMap = getEnvironment().getAccumulatorRegistry().getUserMap(); // if the clock is not already set, then assign a default TimeServiceProvider @@ -252,8 +272,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> // first order of business is to give operators back their state restoreState(); lazyRestoreChainedOperatorState = null; // GC friendliness - lazyRestoreKeyGroupStates = null; // GC friendliness - + // we need to make sure that any triggers scheduled in open() cannot be // executed before all operators are opened synchronized (lock) { @@ -292,6 +311,9 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> // still let the computation fail tryDisposeAllOperators(); disposed = true; + + // Don't forget to check and throw exceptions that happened in async thread one last time + checkTimerException(); } finally { // clean up everything we initialized @@ -307,16 +329,17 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> LOG.error("Could not shut down timer service", t); } } - + // stop all asynchronous checkpoint threads try { closeAllClosables(); + shutdownAsyncThreads(); } catch (Throwable t) { // catch and log the exception to not replace the original exception LOG.error("Could not shut down async checkpoint threads", t); } - + // release the output resources. this method should never fail. if (operatorChain != null) { operatorChain.releaseOutputs(); @@ -330,7 +353,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> // catch and log the exception to not replace the original exception LOG.error("Error during cleanup of stream task", t); } - + // if the operators were not disposed before, do a hard dispose if (!disposed) { disposeAllOperators(); @@ -414,6 +437,12 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> } } + private void shutdownAsyncThreads() throws Exception { + if (!asyncOperationsThreadPool.isShutdown()) { + asyncOperationsThreadPool.shutdownNow(); + } + } + /** * Execute the operator-specific {@link StreamOperator#dispose()} method in each * of the operators in the chain of this {@link StreamTask}. </b> Disposing happens @@ -558,69 +587,6 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> cancelables.remove(lazyRestoreChainedOperatorState); } } -// if (lazyRestoreState != null || lazyRestoreKeyGroupStates != null) { -// -// LOG.info("Restoring checkpointed state to task {}", getName()); -// -// final StreamOperator<?>[] allOperators = operatorChain.getAllOperators(); -// StreamOperatorNonPartitionedState[] nonPartitionedStates; -// -// final List<Map<Integer, PartitionedStateSnapshot>> keyGroupStates = new ArrayList<Map<Integer, PartitionedStateSnapshot>>(allOperators.length); -// -// for (int i = 0; i < allOperators.length; i++) { -// keyGroupStates.add(new HashMap<Integer, PartitionedStateSnapshot>()); -// } -// -// if (lazyRestoreState != null) { -// try { -// nonPartitionedStates = lazyRestoreState.get(getUserCodeClassLoader()); -// -// // be GC friendly -// lazyRestoreState = null; -// } catch (Exception e) { -// throw new Exception("Could not restore checkpointed non-partitioned state.", e); -// } -// } else { -// nonPartitionedStates = new StreamOperatorNonPartitionedState[allOperators.length]; -// } -// -// if (lazyRestoreKeyGroupStates != null) { -// try { -// // construct key groups state for operators -// for (Map.Entry<Integer, ChainedStateHandle> lazyRestoreKeyGroupState : lazyRestoreKeyGroupStates.entrySet()) { -// int keyGroupId = lazyRestoreKeyGroupState.getKey(); -// -// Map<Integer, PartitionedStateSnapshot> chainedKeyGroupStates = lazyRestoreKeyGroupState.getValue().get(getUserCodeClassLoader()); -// -// for (Map.Entry<Integer, PartitionedStateSnapshot> chainedKeyGroupState : chainedKeyGroupStates.entrySet()) { -// int chainedIndex = chainedKeyGroupState.getKey(); -// -// Map<Integer, PartitionedStateSnapshot> keyGroupState; -// -// keyGroupState = keyGroupStates.get(chainedIndex); -// keyGroupState.put(keyGroupId, chainedKeyGroupState.getValue()); -// } -// } -// -// lazyRestoreKeyGroupStates = null; -// -// } catch (Exception e) { -// throw new Exception("Could not restore checkpointed partitioned state.", e); -// } -// } -// -// for (int i = 0; i < nonPartitionedStates.length; i++) { -// StreamOperatorNonPartitionedState nonPartitionedState = nonPartitionedStates[i]; -// StreamOperator<?> operator = allOperators[i]; -// KeyGroupsStateHandle partitionedState = new KeyGroupsStateHandle(keyGroupStates.get(i)); -// StreamOperatorState operatorState = new StreamOperatorState(partitionedState, nonPartitionedState); -// -// if (operator != null) { -// LOG.debug("Restore state of task {} in chain ({}).", i, getName()); -// operator.restoreState(operatorState, recoveryTimestamp); -// } -// } -// } } @Override @@ -658,8 +624,15 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> StreamOperator<?> operator = allOperators[i]; if (operator != null) { - AbstractStateBackend.CheckpointStateOutputStream outStream = - ((AbstractStreamOperator) operator).getStateBackend().createCheckpointStateOutputStream(checkpointId, timestamp); + CheckpointStreamFactory streamFactory = + stateBackend.createStreamFactory( + getEnvironment().getJobID(), + createOperatorIdentifier( + operator, + configuration.getVertexID())); + + CheckpointStreamFactory.CheckpointStateOutputStream outStream = + streamFactory.createCheckpointStateOutputStream(checkpointId, timestamp); operator.snapshotState(outStream, checkpointId, timestamp); @@ -667,22 +640,37 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> } } - if (!isRunning) { - // Rethrow the cancel exception because some state backends could swallow - // exceptions and seem to exit cleanly. - throw new CancelTaskException(); + RunnableFuture<KeyGroupsStateHandle> keyGroupsStateHandleFuture = null; + + if (keyedStateBackend != null) { + CheckpointStreamFactory streamFactory = stateBackend.createStreamFactory( + getEnvironment().getJobID(), + createOperatorIdentifier( + headOperator, + configuration.getVertexID())); + keyGroupsStateHandleFuture = keyedStateBackend.snapshot( + checkpointId, + timestamp, + streamFactory); } - ChainedStateHandle<StreamStateHandle> states = new ChainedStateHandle<>(nonPartitionedStates); - List<KeyGroupsStateHandle> keyedStates = Collections.<KeyGroupsStateHandle>emptyList(); + ChainedStateHandle<StreamStateHandle> chainedStateHandles = new ChainedStateHandle<>(nonPartitionedStates); + LOG.debug("Finished synchronous checkpoints for checkpoint {} on task {}", checkpointId, getName()); - if (states.isEmpty() && keyedStates.isEmpty()) { - getEnvironment().acknowledgeCheckpoint(checkpointId); - } else { - this.lastCheckpointSize = states.getStateSize(); - getEnvironment().acknowledgeCheckpoint(checkpointId, states, keyedStates); + AsyncCheckpointRunnable asyncCheckpointRunnable = new AsyncCheckpointRunnable( + "checkpoint-" + checkpointId + "-" + timestamp, + this, + cancelables, + chainedStateHandles, + keyGroupsStateHandleFuture, + checkpointId); + + synchronized (cancelables) { + cancelables.add(asyncCheckpointRunnable); } + + asyncOperationsThreadPool.submit(asyncCheckpointRunnable); return true; } else { return false; @@ -712,7 +700,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> // State backend // ------------------------------------------------------------------------ - public AbstractStateBackend createStateBackend(String operatorIdentifier, TypeSerializer<?> keySerializer) throws Exception { + private AbstractStateBackend createStateBackend() throws Exception { AbstractStateBackend stateBackend = configuration.getStateBackend(getUserCodeClassLoader()); if (stateBackend != null) { @@ -732,7 +720,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> switch (backendName) { case "jobmanager": LOG.info("State backend is set to heap memory (checkpoint to jobmanager)"); - stateBackend = MemoryStateBackend.create(); + stateBackend = new MemoryStateBackend(); break; case "filesystem": @@ -760,10 +748,69 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> } } } - stateBackend.initializeForJob(getEnvironment(), operatorIdentifier, keySerializer); return stateBackend; } + public <K> KeyedStateBackend<K> createKeyedStateBackend( + TypeSerializer<K> keySerializer, + KeyGroupAssigner<K> keyGroupAssigner, + KeyGroupRange keyGroupRange) throws Exception { + + if (keyedStateBackend != null) { + throw new RuntimeException("The keyed state backend can only be created once."); + } + + String operatorIdentifier = createOperatorIdentifier( + headOperator, + configuration.getVertexID()); + + if (lazyRestoreKeyGroupStates != null) { + keyedStateBackend = stateBackend.restoreKeyedStateBackend( + getEnvironment(), + getEnvironment().getJobID(), + operatorIdentifier, + keySerializer, + keyGroupAssigner, + keyGroupRange, + lazyRestoreKeyGroupStates, + getEnvironment().getTaskKvStateRegistry()); + + lazyRestoreKeyGroupStates = null; // GC friendliness + } else { + keyedStateBackend = stateBackend.createKeyedStateBackend( + getEnvironment(), + getEnvironment().getJobID(), + operatorIdentifier, + keySerializer, + keyGroupAssigner, + keyGroupRange, + getEnvironment().getTaskKvStateRegistry()); + } + + return (KeyedStateBackend<K>) keyedStateBackend; + } + + /** + * This is only visible because + * {@link org.apache.flink.streaming.runtime.operators.GenericWriteAheadSink} uses the + * checkpoint stream factory to write write-ahead logs. <b>This should not be used for + * anything else.</b> + */ + public CheckpointStreamFactory createCheckpointStreamFactory(StreamOperator operator) throws IOException { + return stateBackend.createStreamFactory( + getEnvironment().getJobID(), + createOperatorIdentifier( + operator, + configuration.getVertexID())); + + } + + private String createOperatorIdentifier(StreamOperator operator, int vertexId) { + return operator.getClass().getSimpleName() + + "_" + vertexId + + "_" + getEnvironment().getTaskInfo().getIndexOfThisSubtask(); + } + /** * Registers a timer. */ @@ -852,77 +899,83 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>> // ------------------------------------------------------------------------ -// private static class AsyncCheckpointThread extends Thread implements Closeable { -// -// private final StreamTask<?, ?> owner; -// -// private final Set<Closeable> cancelables; -// -// private final StreamTaskState[] states; -// -// private final long checkpointId; -// -// AsyncCheckpointThread(String name, StreamTask<?, ?> owner, Set<Closeable> cancelables, -// StreamTaskState[] states, long checkpointId) { -// super(name); -// setDaemon(true); -// -// this.owner = owner; -// this.cancelables = cancelables; -// this.states = states; -// this.checkpointId = checkpointId; -// } -// -// @Override -// public void run() { -// try { -// for (StreamTaskState state : states) { -// if (state != null) { -// if (state.getFunctionState() instanceof AsynchronousStateHandle) { -// AsynchronousStateHandle<Serializable> asyncState = (AsynchronousStateHandle<Serializable>) state.getFunctionState(); -// state.setFunctionState(asyncState.materialize()); -// } -// if (state.getOperatorState() instanceof AsynchronousStateHandle) { -// AsynchronousStateHandle<?> asyncState = (AsynchronousStateHandle<?>) state.getOperatorState(); -// state.setOperatorState(asyncState.materialize()); -// } -// if (state.getKvStates() != null) { -// Set<String> keys = state.getKvStates().keySet(); -// HashMap<String, KvStateSnapshot<?, ?, ?, ?, ?>> kvStates = state.getKvStates(); -// for (String key: keys) { -// if (kvStates.get(key) instanceof AsynchronousKvStateSnapshot) { -// AsynchronousKvStateSnapshot<?, ?, ?, ?, ?> asyncHandle = (AsynchronousKvStateSnapshot<?, ?, ?, ?, ?>) kvStates.get(key); -// kvStates.put(key, asyncHandle.materialize()); -// } -// } -// } -// -// } -// } -// StreamTaskStateList allStates = new StreamTaskStateList(states); -// owner.lastCheckpointSize = allStates.getStateSize(); -// owner.getEnvironment().acknowledgeCheckpoint(checkpointId, allStates); -// -// LOG.debug("Finished asynchronous checkpoints for checkpoint {} on task {}", checkpointId, getName()); -// } -// catch (Exception e) { -// if (owner.isRunning()) { -// LOG.error("Caught exception while materializing asynchronous checkpoints.", e); -// } -// if (owner.asyncException == null) { -// owner.asyncException = new AsynchronousException(e); -// } -// } -// finally { -// synchronized (cancelables) { -// cancelables.remove(this); -// } -// } -// } -// -// @Override -// public void close() { -// interrupt(); -// } -// } + private static class AsyncCheckpointRunnable implements Runnable, Closeable { + + private final StreamTask<?, ?> owner; + + private final Set<Closeable> cancelables; + + private final ChainedStateHandle<StreamStateHandle> chainedStateHandles; + + private final RunnableFuture<KeyGroupsStateHandle> keyGroupsStateHandleFuture; + + private final long checkpointId; + + private final String name; + + AsyncCheckpointRunnable( + String name, + StreamTask<?, ?> owner, + Set<Closeable> cancelables, + ChainedStateHandle<StreamStateHandle> chainedStateHandles, + RunnableFuture<KeyGroupsStateHandle> keyGroupsStateHandleFuture, + long checkpointId) { + + this.name = name; + this.owner = owner; + this.cancelables = cancelables; + this.chainedStateHandles = chainedStateHandles; + this.keyGroupsStateHandleFuture = keyGroupsStateHandleFuture; + this.checkpointId = checkpointId; + } + + @Override + public void run() { + try { + + List<KeyGroupsStateHandle> keyedStates = Collections.emptyList(); + + if (keyGroupsStateHandleFuture != null) { + + if (!keyGroupsStateHandleFuture.isDone()) { + //TODO this currently works because we only have one RunnableFuture + keyGroupsStateHandleFuture.run(); + } + + KeyGroupsStateHandle keyGroupsStateHandle = this.keyGroupsStateHandleFuture.get(); + if (keyGroupsStateHandle != null) { + keyedStates = Arrays.asList(keyGroupsStateHandle); + } + } + + if (chainedStateHandles.isEmpty() && keyedStates.isEmpty()) { + owner.getEnvironment().acknowledgeCheckpoint(checkpointId); + } else { + owner. getEnvironment().acknowledgeCheckpoint(checkpointId, chainedStateHandles, keyedStates); + } + + LOG.debug("Finished asynchronous checkpoints for checkpoint {} on task {}", checkpointId, name); + } + catch (Exception e) { + if (owner.isRunning()) { + LOG.error("Caught exception while materializing asynchronous checkpoints.", e); + } + if (owner.asyncException == null) { + owner.asyncException = new AsynchronousException(e); + } + } + finally { + synchronized (cancelables) { + cancelables.remove(this); + } + } + } + + @Override + public void close() { + if (keyGroupsStateHandleFuture != null) { + keyGroupsStateHandleFuture.cancel(true); + } + } + } } http://git-wip-us.apache.org/repos/asf/flink/blob/4809f536/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamGroupedFoldTest.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamGroupedFoldTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamGroupedFoldTest.java index f6e7e6b..68a2bb2 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamGroupedFoldTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamGroupedFoldTest.java @@ -27,6 +27,7 @@ import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.configuration.Configuration; import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; import org.apache.flink.streaming.util.TestHarnessUtil; @@ -69,8 +70,8 @@ public class StreamGroupedFoldTest { StreamGroupedFold<Integer, String, String> operator = new StreamGroupedFold<>(new MyFolder(), "100"); operator.setOutputType(BasicTypeInfo.STRING_TYPE_INFO, new ExecutionConfig()); - OneInputStreamOperatorTestHarness<Integer, String> testHarness = new OneInputStreamOperatorTestHarness<>(operator); - testHarness.configureForKeyedStream(keySelector, BasicTypeInfo.STRING_TYPE_INFO); + OneInputStreamOperatorTestHarness<Integer, String> testHarness = + new KeyedOneInputStreamOperatorTestHarness<>(operator, keySelector, BasicTypeInfo.STRING_TYPE_INFO); long initialTime = 0L; ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>(); @@ -107,10 +108,9 @@ public class StreamGroupedFoldTest { new TestOpenCloseFoldFunction(), "init"); operator.setOutputType(BasicTypeInfo.STRING_TYPE_INFO, new ExecutionConfig()); - OneInputStreamOperatorTestHarness<Integer, String> testHarness = new OneInputStreamOperatorTestHarness<>(operator); - testHarness.configureForKeyedStream(keySelector, BasicTypeInfo.INT_TYPE_INFO); - - + OneInputStreamOperatorTestHarness<Integer, String> testHarness = + new KeyedOneInputStreamOperatorTestHarness<>(operator, keySelector, BasicTypeInfo.INT_TYPE_INFO); + long initialTime = 0L; testHarness.open(); http://git-wip-us.apache.org/repos/asf/flink/blob/4809f536/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamGroupedReduceTest.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamGroupedReduceTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamGroupedReduceTest.java index 6cb46c9..0f304a0 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamGroupedReduceTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamGroupedReduceTest.java @@ -28,6 +28,7 @@ import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.configuration.Configuration; import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; import org.apache.flink.streaming.util.TestHarnessUtil; import org.junit.Assert; @@ -52,8 +53,8 @@ public class StreamGroupedReduceTest { StreamGroupedReduce<Integer> operator = new StreamGroupedReduce<>(new MyReducer(), IntSerializer.INSTANCE); - OneInputStreamOperatorTestHarness<Integer, Integer> testHarness = new OneInputStreamOperatorTestHarness<>(operator); - testHarness.configureForKeyedStream(keySelector, BasicTypeInfo.INT_TYPE_INFO); + OneInputStreamOperatorTestHarness<Integer, Integer> testHarness = + new KeyedOneInputStreamOperatorTestHarness<>(operator, keySelector, BasicTypeInfo.INT_TYPE_INFO); long initialTime = 0L; ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>(); @@ -84,8 +85,8 @@ public class StreamGroupedReduceTest { StreamGroupedReduce<Integer> operator = new StreamGroupedReduce<>(new TestOpenCloseReduceFunction(), IntSerializer.INSTANCE); - OneInputStreamOperatorTestHarness<Integer, Integer> testHarness = new OneInputStreamOperatorTestHarness<>(operator); - testHarness.configureForKeyedStream(keySelector, BasicTypeInfo.INT_TYPE_INFO); + OneInputStreamOperatorTestHarness<Integer, Integer> testHarness = + new KeyedOneInputStreamOperatorTestHarness<>(operator, keySelector, BasicTypeInfo.INT_TYPE_INFO); long initialTime = 0L; http://git-wip-us.apache.org/repos/asf/flink/blob/4809f536/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java index 3a88d94..d3b7ff9 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java @@ -19,6 +19,7 @@ package org.apache.flink.streaming.api.operators; import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.TaskInfo; import org.apache.flink.api.common.accumulators.Accumulator; import org.apache.flink.api.common.functions.ReduceFunction; @@ -28,13 +29,21 @@ import org.apache.flink.api.common.state.ReducingStateDescriptor; import org.apache.flink.api.common.state.StateDescriptor; import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.base.IntSerializer; import org.apache.flink.api.common.typeutils.base.StringSerializer; import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer; import org.apache.flink.core.fs.Path; import org.apache.flink.runtime.execution.Environment; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.operators.testutils.DummyEnvironment; +import org.apache.flink.runtime.query.KvStateRegistry; +import org.apache.flink.runtime.state.HashKeyGroupAssigner; +import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.KeyedStateBackend; import org.apache.flink.runtime.state.VoidNamespace; import org.apache.flink.runtime.state.VoidNamespaceSerializer; -import org.apache.flink.runtime.state.memory.MemListState; +import org.apache.flink.runtime.state.heap.HeapListState; +import org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.junit.Test; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -181,11 +190,18 @@ public class StreamingRuntimeContextTest { @Override public ListState<String> answer(InvocationOnMock invocationOnMock) throws Throwable { ListStateDescriptor<String> descr = - (ListStateDescriptor<String>) invocationOnMock.getArguments()[0]; - MemListState<String, VoidNamespace, String> listState = new MemListState<>( - StringSerializer.INSTANCE, VoidNamespaceSerializer.INSTANCE, descr); - listState.setCurrentNamespace(VoidNamespace.INSTANCE); - return listState; + (ListStateDescriptor<String>) invocationOnMock.getArguments()[0]; + KeyedStateBackend<Integer> backend = new MemoryStateBackend().createKeyedStateBackend( + new DummyEnvironment("test_task", 1, 0), + new JobID(), + "test_op", + IntSerializer.INSTANCE, + new HashKeyGroupAssigner<Integer>(1), + new KeyGroupRange(0, 0), + new KvStateRegistry().createTaskRegistry(new JobID(), + new JobVertexID())); + backend.setCurrentKey(0); + return backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, descr); } }); @@ -196,7 +212,7 @@ public class StreamingRuntimeContextTest { Environment env = mock(Environment.class); when(env.getUserClassLoader()).thenReturn(StreamingRuntimeContextTest.class.getClassLoader()); when(env.getDistributedCacheEntries()).thenReturn(Collections.<String, Future<Path>>emptyMap()); - when(env.getTaskInfo()).thenReturn(new TaskInfo("test task", 0, 1, 1)); + when(env.getTaskInfo()).thenReturn(new TaskInfo("test task", 1, 0, 1, 1)); return env; } } http://git-wip-us.apache.org/repos/asf/flink/blob/4809f536/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java index 3b201dc..f4ac5b2 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java @@ -333,21 +333,6 @@ public class StreamOperatorChainingTest { when(mockTask.getEnvironment()).thenReturn(env); when(mockTask.getExecutionConfig()).thenReturn(new ExecutionConfig().enableObjectReuse()); - try { - doAnswer(new Answer<AbstractStateBackend>() { - @Override - public AbstractStateBackend answer(InvocationOnMock invocationOnMock) throws Throwable { - final String operatorIdentifier = (String) invocationOnMock.getArguments()[0]; - final TypeSerializer<?> keySerializer = (TypeSerializer<?>) invocationOnMock.getArguments()[1]; - MemoryStateBackend backend = MemoryStateBackend.create(); - backend.initializeForJob(env, operatorIdentifier, keySerializer); - return backend; - } - }).when(mockTask).createStateBackend(any(String.class), any(TypeSerializer.class)); - } catch (Exception e) { - throw new RuntimeException(e.getMessage(), e); - } - return mockTask; } http://git-wip-us.apache.org/repos/asf/flink/blob/4809f536/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java index 2cb1809..40a6c79 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java @@ -21,10 +21,9 @@ package org.apache.flink.streaming.runtime.operators.windowing; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.TaskInfo; import org.apache.flink.api.common.accumulators.Accumulator; -import org.apache.flink.api.common.state.KeyGroupAssigner; import org.apache.flink.api.common.state.ValueState; import org.apache.flink.api.common.state.ValueStateDescriptor; -import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeutils.base.IntSerializer; import org.apache.flink.api.common.typeutils.base.StringSerializer; import org.apache.flink.api.java.ClosureCleaner; @@ -33,9 +32,6 @@ import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup; -import org.apache.flink.runtime.state.AbstractStateBackend; -import org.apache.flink.runtime.state.memory.MemoryStateBackend; -import org.apache.flink.runtime.state.HashKeyGroupAssigner; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo; import org.apache.flink.streaming.api.functions.windowing.RichWindowFunction; @@ -46,11 +42,12 @@ import org.apache.flink.streaming.api.windowing.windows.TimeWindow; import org.apache.flink.streaming.runtime.operators.Triggerable; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.tasks.StreamTask; +import org.apache.flink.streaming.runtime.tasks.TestTimeServiceProvider; +import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; import org.apache.flink.util.Collector; import org.junit.After; -import org.junit.Ignore; import org.junit.Test; import org.mockito.invocation.InvocationOnMock; @@ -79,7 +76,6 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @SuppressWarnings({"serial", "SynchronizationOnLocalVariableOrMethodParameter"}) -@Ignore public class AccumulatingAlignedProcessingTimeWindowOperatorTest { @SuppressWarnings("unchecked") @@ -203,7 +199,7 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest { op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, StringSerializer.INSTANCE, StringSerializer.INSTANCE, 5000, 1000); - op.setup(mockTask, createTaskConfig(identitySelector, IntSerializer.INSTANCE, new HashKeyGroupAssigner<Object>(10)), mockOut); + op.setup(mockTask, new StreamConfig(new Configuration()), mockOut); op.open(); assertTrue(op.getNextSlideTime() % 1000 == 0); assertTrue(op.getNextEvaluationTime() % 1000 == 0); @@ -255,7 +251,7 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest { IntSerializer.INSTANCE, IntSerializer.INSTANCE, windowSize, windowSize); - op.setup(mockTask, createTaskConfig(identitySelector, IntSerializer.INSTANCE, new HashKeyGroupAssigner<Object>(10)), out); + op.setup(mockTask, new StreamConfig(new Configuration()), out); op.open(); final int numElements = 1000; @@ -306,7 +302,7 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest { validatingIdentityFunction, identitySelector, IntSerializer.INSTANCE, IntSerializer.INSTANCE, 150, 50); - op.setup(mockTask, createTaskConfig(identitySelector, IntSerializer.INSTANCE, new HashKeyGroupAssigner<Object>(10)), out); + op.setup(mockTask, new StreamConfig(new Configuration()), out); op.open(); final int numElements = 1000; @@ -367,7 +363,7 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest { validatingIdentityFunction, identitySelector, IntSerializer.INSTANCE, IntSerializer.INSTANCE, 50, 50); - op.setup(mockTask, createTaskConfig(identitySelector, IntSerializer.INSTANCE, new HashKeyGroupAssigner<Object>(10)), out); + op.setup(mockTask, new StreamConfig(new Configuration()), out); op.open(); synchronized (lock) { @@ -423,7 +419,7 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest { validatingIdentityFunction, identitySelector, IntSerializer.INSTANCE, IntSerializer.INSTANCE, 150, 50); - op.setup(mockTask, createTaskConfig(identitySelector, IntSerializer.INSTANCE, new HashKeyGroupAssigner<Object>(10)), out); + op.setup(mockTask, new StreamConfig(new Configuration()), out); op.open(); synchronized (lock) { @@ -457,67 +453,58 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest { @Test public void checkpointRestoreWithPendingWindowTumbling() { - final ScheduledExecutorService timerService = Executors.newSingleThreadScheduledExecutor(); try { final int windowSize = 200; - final CollectingOutput<Integer> out = new CollectingOutput<>(windowSize); - final Object lock = new Object(); - final StreamTask<?, ?> mockTask = createMockTaskWithTimer(timerService, lock); - // tumbling window that triggers every 50 milliseconds + // tumbling window that triggers every 200 milliseconds AccumulatingProcessingTimeWindowOperator<Integer, Integer, Integer> op = new AccumulatingProcessingTimeWindowOperator<>( validatingIdentityFunction, identitySelector, IntSerializer.INSTANCE, IntSerializer.INSTANCE, windowSize, windowSize); - OneInputStreamOperatorTestHarness<Integer, Integer> testHarness = - new OneInputStreamOperatorTestHarness<>(op); + TestTimeServiceProvider timerService = new TestTimeServiceProvider(); + + OneInputStreamOperatorTestHarness<Integer, Integer> testHarness = + new OneInputStreamOperatorTestHarness<>(op, new ExecutionConfig(), timerService); testHarness.setup(); testHarness.open(); + timerService.setCurrentTime(0); + // inject some elements final int numElementsFirst = 700; final int numElements = 1000; for (int i = 0; i < numElementsFirst; i++) { - synchronized (lock) { - op.processElement(new StreamRecord<Integer>(i)); - } - Thread.sleep(1); + testHarness.processElement(new StreamRecord<>(i)); } // draw a snapshot and dispose the window - StreamStateHandle state; - List<Integer> resultAtSnapshot; - synchronized (lock) { - int beforeSnapShot = out.getElements().size(); - state = testHarness.snapshot(1L, System.currentTimeMillis()); - resultAtSnapshot = new ArrayList<>(out.getElements()); - int afterSnapShot = out.getElements().size(); - assertEquals("operator performed computation during snapshot", beforeSnapShot, afterSnapShot); - assertTrue(afterSnapShot <= numElementsFirst); - } + System.out.println("GOT: " + testHarness.getOutput()); + int beforeSnapShot = testHarness.getOutput().size(); + StreamStateHandle state = testHarness.snapshot(1L, System.currentTimeMillis()); + List<Integer> resultAtSnapshot = extractFromStreamRecords(testHarness.getOutput()); + int afterSnapShot = testHarness.getOutput().size(); + assertEquals("operator performed computation during snapshot", beforeSnapShot, afterSnapShot); + assertTrue(afterSnapShot <= numElementsFirst); // inject some random elements, which should not show up in the state for (int i = 0; i < 300; i++) { - synchronized (lock) { - op.processElement(new StreamRecord<Integer>(i + numElementsFirst)); - } - Thread.sleep(1); + testHarness.processElement(new StreamRecord<>(i + numElementsFirst)); } op.dispose(); // re-create the operator and restore the state - final CollectingOutput<Integer> out2 = new CollectingOutput<>(windowSize); op = new AccumulatingProcessingTimeWindowOperator<>( validatingIdentityFunction, identitySelector, IntSerializer.INSTANCE, IntSerializer.INSTANCE, windowSize, windowSize); testHarness = - new OneInputStreamOperatorTestHarness<>(op); + new OneInputStreamOperatorTestHarness<>(op, new ExecutionConfig(), timerService); + testHarness.setup(); testHarness.restore(state); @@ -525,18 +512,16 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest { // inject some more elements for (int i = numElementsFirst; i < numElements; i++) { - synchronized (lock) { - op.processElement(new StreamRecord<Integer>(i)); - } - Thread.sleep(1); + testHarness.processElement(new StreamRecord<>(i)); } - - out2.waitForNElements(numElements - resultAtSnapshot.size(), 60_000); + timerService.setCurrentTime(400); // get and verify the result - List<Integer> finalResult = new ArrayList<>(resultAtSnapshot); - finalResult.addAll(out2.getElements()); + List<Integer> finalResult = new ArrayList<>(); + finalResult.addAll(resultAtSnapshot); + List<Integer> finalPartialResult = extractFromStreamRecords(testHarness.getOutput()); + finalResult.addAll(finalPartialResult); assertEquals(numElements, finalResult.size()); Collections.sort(finalResult); @@ -548,22 +533,16 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest { e.printStackTrace(); fail(e.getMessage()); } - finally { - timerService.shutdown(); - } } @Test public void checkpointRestoreWithPendingWindowSliding() { - final ScheduledExecutorService timerService = Executors.newSingleThreadScheduledExecutor(); try { final int factor = 4; final int windowSlide = 50; final int windowSize = factor * windowSlide; - - final CollectingOutput<Integer> out = new CollectingOutput<>(windowSlide); - final Object lock = new Object(); - final StreamTask<?, ?> mockTask = createMockTaskWithTimer(timerService, lock); + + TestTimeServiceProvider timerService = new TestTimeServiceProvider(); // sliding window (200 msecs) every 50 msecs AccumulatingProcessingTimeWindowOperator<Integer, Integer, Integer> op = @@ -573,7 +552,9 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest { windowSize, windowSlide); OneInputStreamOperatorTestHarness<Integer, Integer> testHarness = - new OneInputStreamOperatorTestHarness<>(op); + new OneInputStreamOperatorTestHarness<>(op, new ExecutionConfig(), timerService); + + timerService.setCurrentTime(0); testHarness.setup(); testHarness.open(); @@ -583,44 +564,32 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest { final int numElementsFirst = 700; for (int i = 0; i < numElementsFirst; i++) { - synchronized (lock) { - op.processElement(new StreamRecord<Integer>(i)); - } - Thread.sleep(1); + testHarness.processElement(new StreamRecord<>(i)); } // draw a snapshot - StreamStateHandle state; - List<Integer> resultAtSnapshot; - synchronized (lock) { - int beforeSnapShot = out.getElements().size(); - state = testHarness.snapshot(1L, System.currentTimeMillis()); - resultAtSnapshot = new ArrayList<>(out.getElements()); - int afterSnapShot = out.getElements().size(); - assertEquals("operator performed computation during snapshot", beforeSnapShot, afterSnapShot); - } - + List<Integer> resultAtSnapshot = extractFromStreamRecords(testHarness.getOutput()); + int beforeSnapShot = testHarness.getOutput().size(); + StreamStateHandle state = testHarness.snapshot(1L, System.currentTimeMillis()); + int afterSnapShot = testHarness.getOutput().size(); + assertEquals("operator performed computation during snapshot", beforeSnapShot, afterSnapShot); + assertTrue(resultAtSnapshot.size() <= factor * numElementsFirst); // inject the remaining elements - these should not influence the snapshot for (int i = numElementsFirst; i < numElements; i++) { - synchronized (lock) { - op.processElement(new StreamRecord<Integer>(i)); - } - Thread.sleep(1); + testHarness.processElement(new StreamRecord<>(i)); } op.dispose(); // re-create the operator and restore the state - final CollectingOutput<Integer> out2 = new CollectingOutput<>(windowSlide); op = new AccumulatingProcessingTimeWindowOperator<>( validatingIdentityFunction, identitySelector, IntSerializer.INSTANCE, IntSerializer.INSTANCE, windowSize, windowSlide); - testHarness = - new OneInputStreamOperatorTestHarness<>(op); + testHarness = new OneInputStreamOperatorTestHarness<>(op, new ExecutionConfig(), timerService); testHarness.setup(); testHarness.restore(state); @@ -629,29 +598,24 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest { // inject again the remaining elements for (int i = numElementsFirst; i < numElements; i++) { - synchronized (lock) { - op.processElement(new StreamRecord<Integer>(i)); - } - Thread.sleep(1); + testHarness.processElement(new StreamRecord<>(i)); } - // for a deterministic result, we need to wait until all pending triggers - // have fired and emitted their results - long deadline = System.currentTimeMillis() + 120000; - do { - Thread.sleep(20); - } - while (resultAtSnapshot.size() + out2.getElements().size() < factor * numElements - && System.currentTimeMillis() < deadline); + timerService.setCurrentTime(50); + timerService.setCurrentTime(100); + timerService.setCurrentTime(150); + timerService.setCurrentTime(200); + timerService.setCurrentTime(250); + timerService.setCurrentTime(300); + timerService.setCurrentTime(350); - synchronized (lock) { - op.close(); - } + testHarness.close(); op.dispose(); // get and verify the result List<Integer> finalResult = new ArrayList<>(resultAtSnapshot); - finalResult.addAll(out2.getElements()); + List<Integer> finalPartialResult = extractFromStreamRecords(testHarness.getOutput()); + finalResult.addAll(finalPartialResult); assertEquals(factor * numElements, finalResult.size()); Collections.sort(finalResult); @@ -663,19 +627,12 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest { e.printStackTrace(); fail(e.getMessage()); } - finally { - timerService.shutdown(); - } } @Test public void testKeyValueStateInWindowFunction() { - final ScheduledExecutorService timerService = Executors.newSingleThreadScheduledExecutor(); try { - final CollectingOutput<Integer> out = new CollectingOutput<>(50); - final Object lock = new Object(); - final StreamTask<?, ?> mockTask = createMockTaskWithTimer(timerService, lock); - + StatefulFunction.globalCounts.clear(); // tumbling window that triggers every 20 milliseconds @@ -684,26 +641,28 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest { new StatefulFunction(), identitySelector, IntSerializer.INSTANCE, IntSerializer.INSTANCE, 50, 50); - op.setup(mockTask, createTaskConfig(identitySelector, IntSerializer.INSTANCE, new HashKeyGroupAssigner<Object>(10)), out); - op.open(); + TestTimeServiceProvider timerService = new TestTimeServiceProvider(); - synchronized (lock) { - op.processElement(new StreamRecord<Integer>(1)); - op.processElement(new StreamRecord<Integer>(2)); - } - out.waitForNElements(2, 60000); + OneInputStreamOperatorTestHarness<Integer, Integer> testHarness = + new KeyedOneInputStreamOperatorTestHarness<>(op, new ExecutionConfig(), timerService, identitySelector, BasicTypeInfo.INT_TYPE_INFO); - synchronized (lock) { - op.processElement(new StreamRecord<Integer>(1)); - op.processElement(new StreamRecord<Integer>(2)); - op.processElement(new StreamRecord<Integer>(1)); - op.processElement(new StreamRecord<Integer>(1)); - op.processElement(new StreamRecord<Integer>(2)); - op.processElement(new StreamRecord<Integer>(2)); - } - out.waitForNElements(8, 60000); + testHarness.open(); - List<Integer> result = out.getElements(); + timerService.setCurrentTime(0); + + testHarness.processElement(new StreamRecord<>(1)); + testHarness.processElement(new StreamRecord<>(2)); + + op.processElement(new StreamRecord<>(1)); + op.processElement(new StreamRecord<>(2)); + op.processElement(new StreamRecord<>(1)); + op.processElement(new StreamRecord<>(1)); + op.processElement(new StreamRecord<>(2)); + op.processElement(new StreamRecord<>(2)); + + timerService.setCurrentTime(1000); + + List<Integer> result = extractFromStreamRecords(testHarness.getOutput()); assertEquals(8, result.size()); Collections.sort(result); @@ -712,18 +671,13 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest { assertEquals(4, StatefulFunction.globalCounts.get(1).intValue()); assertEquals(4, StatefulFunction.globalCounts.get(2).intValue()); - synchronized (lock) { - op.close(); - } + testHarness.close(); op.dispose(); } catch (Exception e) { e.printStackTrace(); fail(e.getMessage()); } - finally { - timerService.shutdown(); - } } // ------------------------------------------------------------------------ @@ -793,27 +747,11 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest { when(mockTaskManagerRuntimeInfo.getConfiguration()).thenReturn(configuration); final Environment env = mock(Environment.class); - when(env.getTaskInfo()).thenReturn(new TaskInfo("Test task name", 0, 1, 0)); + when(env.getTaskInfo()).thenReturn(new TaskInfo("Test task name", 1, 0, 1, 0)); when(env.getUserClassLoader()).thenReturn(AggregatingAlignedProcessingTimeWindowOperatorTest.class.getClassLoader()); when(env.getMetricGroup()).thenReturn(new UnregisteredTaskMetricsGroup()); when(task.getEnvironment()).thenReturn(env); - - try { - doAnswer(new Answer<AbstractStateBackend>() { - @Override - public AbstractStateBackend answer(InvocationOnMock invocationOnMock) throws Throwable { - final String operatorIdentifier = (String) invocationOnMock.getArguments()[0]; - final TypeSerializer<?> keySerializer = (TypeSerializer<?>) invocationOnMock.getArguments()[1]; - MemoryStateBackend backend = MemoryStateBackend.create(); - backend.initializeForJob(env, operatorIdentifier, keySerializer); - return backend; - } - }).when(task).createStateBackend(any(String.class), any(TypeSerializer.class)); - } catch (Exception e) { - e.printStackTrace(); - } - return task; } @@ -846,11 +784,14 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest { return mockTask; } - private static StreamConfig createTaskConfig(KeySelector<?, ?> partitioner, TypeSerializer<?> keySerializer, KeyGroupAssigner<?> keyGroupAssigner) { - StreamConfig cfg = new StreamConfig(new Configuration()); - cfg.setStatePartitioner(0, partitioner); - cfg.setStateKeySerializer(keySerializer); - cfg.setKeyGroupAssigner(keyGroupAssigner); - return cfg; + @SuppressWarnings({"unchecked", "rawtypes"}) + private <T> List<T> extractFromStreamRecords(Iterable<Object> input) { + List<T> result = new ArrayList<>(); + for (Object in : input) { + if (in instanceof StreamRecord) { + result.add((T) ((StreamRecord) in).getValue()); + } + } + return result; } }