http://git-wip-us.apache.org/repos/asf/flink/blob/b08669ab/flink-runtime/src/main/java/org/apache/flink/runtime/operators/BatchTask.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/BatchTask.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/BatchTask.java new file mode 100644 index 0000000..3f94109 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/BatchTask.java @@ -0,0 +1,1499 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.operators; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.accumulators.Accumulator; +import org.apache.flink.api.common.distributions.DataDistribution; +import org.apache.flink.api.common.functions.GroupCombineFunction; +import org.apache.flink.api.common.functions.Function; +import org.apache.flink.api.common.functions.Partitioner; +import org.apache.flink.api.common.functions.util.FunctionUtils; +import org.apache.flink.api.common.typeutils.TypeComparator; +import org.apache.flink.api.common.typeutils.TypeComparatorFactory; +import org.apache.flink.api.common.typeutils.TypeSerializerFactory; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.core.io.IOReadableWritable; +import org.apache.flink.runtime.accumulators.AccumulatorRegistry; +import org.apache.flink.runtime.broadcast.BroadcastVariableMaterialization; +import org.apache.flink.runtime.execution.CancelTaskException; +import org.apache.flink.runtime.execution.Environment; +import org.apache.flink.runtime.io.disk.iomanager.IOManager; +import org.apache.flink.runtime.io.network.api.reader.MutableReader; +import org.apache.flink.runtime.io.network.api.reader.MutableRecordReader; +import org.apache.flink.runtime.io.network.api.writer.ChannelSelector; +import org.apache.flink.runtime.io.network.api.writer.RecordWriter; +import org.apache.flink.runtime.io.network.partition.consumer.InputGate; +import org.apache.flink.runtime.io.network.partition.consumer.UnionInputGate; +import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; +import org.apache.flink.runtime.memory.MemoryManager; +import org.apache.flink.runtime.operators.chaining.ChainedDriver; +import org.apache.flink.runtime.operators.chaining.ExceptionInChainedStubException; +import org.apache.flink.runtime.operators.resettable.SpillingResettableMutableObjectIterator; +import org.apache.flink.runtime.operators.shipping.OutputCollector; +import org.apache.flink.runtime.operators.shipping.OutputEmitter; +import org.apache.flink.runtime.operators.shipping.RecordOutputCollector; +import org.apache.flink.runtime.operators.shipping.RecordOutputEmitter; +import org.apache.flink.runtime.operators.shipping.ShipStrategyType; +import org.apache.flink.runtime.operators.sort.CombiningUnilateralSortMerger; +import org.apache.flink.runtime.operators.sort.UnilateralSortMerger; +import org.apache.flink.runtime.operators.util.CloseableInputProvider; +import org.apache.flink.runtime.operators.util.DistributedRuntimeUDFContext; +import org.apache.flink.runtime.operators.util.LocalStrategy; +import org.apache.flink.runtime.operators.util.ReaderIterator; +import org.apache.flink.runtime.operators.util.TaskConfig; +import org.apache.flink.runtime.plugable.DeserializationDelegate; +import org.apache.flink.runtime.plugable.SerializationDelegate; +import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo; +import org.apache.flink.types.Record; +import org.apache.flink.util.Collector; +import org.apache.flink.util.InstantiationUtil; +import org.apache.flink.util.MutableObjectIterator; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** + * The base class for all batch tasks. Encapsulated common behavior and implements the main life-cycle + * of the user code. + */ +public class BatchTask<S extends Function, OT> extends AbstractInvokable implements TaskContext<S, OT> { + + protected static final Logger LOG = LoggerFactory.getLogger(BatchTask.class); + + // -------------------------------------------------------------------------------------------- + + /** + * The driver that invokes the user code (the stub implementation). The central driver in this task + * (further drivers may be chained behind this driver). + */ + protected volatile Driver<S, OT> driver; + + /** + * The instantiated user code of this task's main operator (driver). May be null if the operator has no udf. + */ + protected S stub; + + /** + * The udf's runtime context. + */ + protected DistributedRuntimeUDFContext runtimeUdfContext; + + /** + * The collector that forwards the user code's results. May forward to a channel or to chained drivers within + * this task. + */ + protected Collector<OT> output; + + /** + * The output writers for the data that this task forwards to the next task. The latest driver (the central, if no chained + * drivers exist, otherwise the last chained driver) produces its output to these writers. + */ + protected List<RecordWriter<?>> eventualOutputs; + + /** + * The input readers of this task. + */ + protected MutableReader<?>[] inputReaders; + + /** + * The input readers for the configured broadcast variables for this task. + */ + protected MutableReader<?>[] broadcastInputReaders; + + /** + * The inputs reader, wrapped in an iterator. Prior to the local strategies, etc... + */ + protected MutableObjectIterator<?>[] inputIterators; + + /** + * The indices of the iterative inputs. Empty, if the task is not iterative. + */ + protected int[] iterativeInputs; + + /** + * The indices of the iterative broadcast inputs. Empty, if non of the inputs is iteratve. + */ + protected int[] iterativeBroadcastInputs; + + /** + * The local strategies that are applied on the inputs. + */ + protected volatile CloseableInputProvider<?>[] localStrategies; + + /** + * The optional temp barriers on the inputs for dead-lock breaking. Are + * optionally resettable. + */ + protected volatile TempBarrier<?>[] tempBarriers; + + /** + * The resettable inputs in the case where no temp barrier is needed. + */ + protected volatile SpillingResettableMutableObjectIterator<?>[] resettableInputs; + + /** + * The inputs to the operator. Return the readers' data after the application of the local strategy + * and the temp-table barrier. + */ + protected MutableObjectIterator<?>[] inputs; + + /** + * The serializers for the input data type. + */ + protected TypeSerializerFactory<?>[] inputSerializers; + + /** + * The serializers for the broadcast input data types. + */ + protected TypeSerializerFactory<?>[] broadcastInputSerializers; + + /** + * The comparators for the central driver. + */ + protected TypeComparator<?>[] inputComparators; + + /** + * The task configuration with the setup parameters. + */ + protected TaskConfig config; + + /** + * A list of chained drivers, if there are any. + */ + protected ArrayList<ChainedDriver<?, ?>> chainedTasks; + + /** + * Certain inputs may be excluded from resetting. For example, the initial partial solution + * in an iteration head must not be reseted (it is read through the back channel), when all + * others are reseted. + */ + private boolean[] excludeFromReset; + + /** + * Flag indicating for each input whether it is cached and can be reseted. + */ + private boolean[] inputIsCached; + + /** + * flag indicating for each input whether it must be asynchronously materialized. + */ + private boolean[] inputIsAsyncMaterialized; + + /** + * The amount of memory per input that is dedicated to the materialization. + */ + private int[] materializationMemory; + + /** + * The flag that tags the task as still running. Checked periodically to abort processing. + */ + protected volatile boolean running = true; + + /** + * The accumulator map used in the RuntimeContext. + */ + protected Map<String, Accumulator<?,?>> accumulatorMap; + + // -------------------------------------------------------------------------------------------- + // Task Interface + // -------------------------------------------------------------------------------------------- + + + /** + * Initialization method. Runs in the execution graph setup phase in the JobManager + * and as a setup method on the TaskManager. + */ + @Override + public void registerInputOutput() throws Exception { + if (LOG.isDebugEnabled()) { + LOG.debug(formatLogString("Start registering input and output.")); + } + + // obtain task configuration (including stub parameters) + Configuration taskConf = getTaskConfiguration(); + this.config = new TaskConfig(taskConf); + + // now get the operator class which drives the operation + final Class<? extends Driver<S, OT>> driverClass = this.config.getDriver(); + this.driver = InstantiationUtil.instantiate(driverClass, Driver.class); + + // initialize the readers. + // this does not yet trigger any stream consuming or processing. + initInputReaders(); + initBroadcastInputReaders(); + + // initialize the writers. + initOutputs(); + + if (LOG.isDebugEnabled()) { + LOG.debug(formatLogString("Finished registering input and output.")); + } + } + + + /** + * The main work method. + */ + @Override + public void invoke() throws Exception { + + if (LOG.isDebugEnabled()) { + LOG.debug(formatLogString("Start task code.")); + } + + Environment env = getEnvironment(); + + this.runtimeUdfContext = createRuntimeContext(env.getTaskName()); + + // whatever happens in this scope, make sure that the local strategies are cleaned up! + // note that the initialization of the local strategies is in the try-finally block as well, + // so that the thread that creates them catches its own errors that may happen in that process. + // this is especially important, since there may be asynchronous closes (such as through canceling). + try { + // initialize the remaining data structures on the input and trigger the local processing + // the local processing includes building the dams / caches + try { + int numInputs = driver.getNumberOfInputs(); + int numComparators = driver.getNumberOfDriverComparators(); + int numBroadcastInputs = this.config.getNumBroadcastInputs(); + + initInputsSerializersAndComparators(numInputs, numComparators); + initBroadcastInputsSerializers(numBroadcastInputs); + + // set the iterative status for inputs and broadcast inputs + { + List<Integer> iterativeInputs = new ArrayList<Integer>(); + + for (int i = 0; i < numInputs; i++) { + final int numberOfEventsUntilInterrupt = getTaskConfig().getNumberOfEventsUntilInterruptInIterativeGate(i); + + if (numberOfEventsUntilInterrupt < 0) { + throw new IllegalArgumentException(); + } + else if (numberOfEventsUntilInterrupt > 0) { + this.inputReaders[i].setIterativeReader(); + iterativeInputs.add(i); + + if (LOG.isDebugEnabled()) { + LOG.debug(formatLogString("Input [" + i + "] reads in supersteps with [" + + + numberOfEventsUntilInterrupt + "] event(s) till next superstep.")); + } + } + } + this.iterativeInputs = asArray(iterativeInputs); + } + + { + List<Integer> iterativeBcInputs = new ArrayList<Integer>(); + + for (int i = 0; i < numBroadcastInputs; i++) { + final int numberOfEventsUntilInterrupt = getTaskConfig().getNumberOfEventsUntilInterruptInIterativeBroadcastGate(i); + + if (numberOfEventsUntilInterrupt < 0) { + throw new IllegalArgumentException(); + } + else if (numberOfEventsUntilInterrupt > 0) { + this.broadcastInputReaders[i].setIterativeReader(); + iterativeBcInputs.add(i); + + if (LOG.isDebugEnabled()) { + LOG.debug(formatLogString("Broadcast input [" + i + "] reads in supersteps with [" + + + numberOfEventsUntilInterrupt + "] event(s) till next superstep.")); + } + } + } + this.iterativeBroadcastInputs = asArray(iterativeBcInputs); + } + + initLocalStrategies(numInputs); + } + catch (Exception e) { + throw new RuntimeException("Initializing the input processing failed" + + (e.getMessage() == null ? "." : ": " + e.getMessage()), e); + } + + if (!this.running) { + if (LOG.isDebugEnabled()) { + LOG.debug(formatLogString("Task cancelled before task code was started.")); + } + return; + } + + // pre main-function initialization + initialize(); + + // read the broadcast variables. they will be released in the finally clause + for (int i = 0; i < this.config.getNumBroadcastInputs(); i++) { + final String name = this.config.getBroadcastInputName(i); + readAndSetBroadcastInput(i, name, this.runtimeUdfContext, 1 /* superstep one for the start */); + } + + // the work goes here + run(); + } + finally { + // clean up in any case! + closeLocalStrategiesAndCaches(); + + clearReaders(inputReaders); + clearWriters(eventualOutputs); + + } + + if (this.running) { + if (LOG.isDebugEnabled()) { + LOG.debug(formatLogString("Finished task code.")); + } + } else { + if (LOG.isDebugEnabled()) { + LOG.debug(formatLogString("Task code cancelled.")); + } + } + } + + @Override + public void cancel() throws Exception { + this.running = false; + + if (LOG.isDebugEnabled()) { + LOG.debug(formatLogString("Cancelling task code")); + } + + try { + if (this.driver != null) { + this.driver.cancel(); + } + } finally { + closeLocalStrategiesAndCaches(); + } + } + + // -------------------------------------------------------------------------------------------- + // Main Work Methods + // -------------------------------------------------------------------------------------------- + + protected void initialize() throws Exception { + // create the operator + try { + this.driver.setup(this); + } + catch (Throwable t) { + throw new Exception("The driver setup for '" + this.getEnvironment().getTaskName() + + "' , caused an error: " + t.getMessage(), t); + } + + // instantiate the UDF + try { + final Class<? super S> userCodeFunctionType = this.driver.getStubType(); + // if the class is null, the driver has no user code + if (userCodeFunctionType != null) { + this.stub = initStub(userCodeFunctionType); + } + } catch (Exception e) { + throw new RuntimeException("Initializing the UDF" + + (e.getMessage() == null ? "." : ": " + e.getMessage()), e); + } + } + + protected <X> void readAndSetBroadcastInput(int inputNum, String bcVarName, DistributedRuntimeUDFContext context, int superstep) throws IOException { + + if (LOG.isDebugEnabled()) { + LOG.debug(formatLogString("Setting broadcast variable '" + bcVarName + "'" + + (superstep > 1 ? ", superstep " + superstep : ""))); + } + + @SuppressWarnings("unchecked") + final TypeSerializerFactory<X> serializerFactory = (TypeSerializerFactory<X>) this.broadcastInputSerializers[inputNum]; + + final MutableReader<?> reader = this.broadcastInputReaders[inputNum]; + + BroadcastVariableMaterialization<X, ?> variable = getEnvironment().getBroadcastVariableManager().materializeBroadcastVariable(bcVarName, superstep, this, reader, serializerFactory); + context.setBroadcastVariable(bcVarName, variable); + } + + protected void releaseBroadcastVariables(String bcVarName, int superstep, DistributedRuntimeUDFContext context) { + if (LOG.isDebugEnabled()) { + LOG.debug(formatLogString("Releasing broadcast variable '" + bcVarName + "'" + + (superstep > 1 ? ", superstep " + superstep : ""))); + } + + getEnvironment().getBroadcastVariableManager().releaseReference(bcVarName, superstep, this); + context.clearBroadcastVariable(bcVarName); + } + + + protected void run() throws Exception { + // ---------------------------- Now, the actual processing starts ------------------------ + // check for asynchronous canceling + if (!this.running) { + return; + } + + boolean stubOpen = false; + + try { + // run the data preparation + try { + this.driver.prepare(); + } + catch (Throwable t) { + // if the preparation caused an error, clean up + // errors during clean-up are swallowed, because we have already a root exception + throw new Exception("The data preparation for task '" + this.getEnvironment().getTaskName() + + "' , caused an error: " + t.getMessage(), t); + } + + // check for canceling + if (!this.running) { + return; + } + + // start all chained tasks + BatchTask.openChainedTasks(this.chainedTasks, this); + + // open stub implementation + if (this.stub != null) { + try { + Configuration stubConfig = this.config.getStubParameters(); + FunctionUtils.openFunction(this.stub, stubConfig); + stubOpen = true; + } + catch (Throwable t) { + throw new Exception("The user defined 'open()' method caused an exception: " + t.getMessage(), t); + } + } + + // run the user code + this.driver.run(); + + // close. We close here such that a regular close throwing an exception marks a task as failed. + if (this.running && this.stub != null) { + FunctionUtils.closeFunction(this.stub); + stubOpen = false; + } + + this.output.close(); + + // close all chained tasks letting them report failure + BatchTask.closeChainedTasks(this.chainedTasks, this); + } + catch (Exception ex) { + // close the input, but do not report any exceptions, since we already have another root cause + if (stubOpen) { + try { + FunctionUtils.closeFunction(this.stub); + } + catch (Throwable t) { + // do nothing + } + } + + // if resettable driver invoke teardown + if (this.driver instanceof ResettableDriver) { + final ResettableDriver<?, ?> resDriver = (ResettableDriver<?, ?>) this.driver; + try { + resDriver.teardown(); + } catch (Throwable t) { + throw new Exception("Error while shutting down an iterative operator: " + t.getMessage(), t); + } + } + + BatchTask.cancelChainedTasks(this.chainedTasks); + + ex = ExceptionInChainedStubException.exceptionUnwrap(ex); + + if (ex instanceof CancelTaskException) { + // forward canceling exception + throw ex; + } + else if (this.running) { + // throw only if task was not cancelled. in the case of canceling, exceptions are expected + BatchTask.logAndThrowException(ex, this); + } + } + finally { + this.driver.cleanup(); + } + } + + protected void closeLocalStrategiesAndCaches() { + + // make sure that all broadcast variable references held by this task are released + if (LOG.isDebugEnabled()) { + LOG.debug(formatLogString("Releasing all broadcast variables.")); + } + + getEnvironment().getBroadcastVariableManager().releaseAllReferencesFromTask(this); + if (runtimeUdfContext != null) { + runtimeUdfContext.clearAllBroadcastVariables(); + } + + // clean all local strategies and caches/pipeline breakers. + + if (this.localStrategies != null) { + for (int i = 0; i < this.localStrategies.length; i++) { + if (this.localStrategies[i] != null) { + try { + this.localStrategies[i].close(); + } catch (Throwable t) { + LOG.error("Error closing local strategy for input " + i, t); + } + } + } + } + if (this.tempBarriers != null) { + for (int i = 0; i < this.tempBarriers.length; i++) { + if (this.tempBarriers[i] != null) { + try { + this.tempBarriers[i].close(); + } catch (Throwable t) { + LOG.error("Error closing temp barrier for input " + i, t); + } + } + } + } + if (this.resettableInputs != null) { + for (int i = 0; i < this.resettableInputs.length; i++) { + if (this.resettableInputs[i] != null) { + try { + this.resettableInputs[i].close(); + } catch (Throwable t) { + LOG.error("Error closing cache for input " + i, t); + } + } + } + } + } + + // -------------------------------------------------------------------------------------------- + // Task Setup and Teardown + // -------------------------------------------------------------------------------------------- + + /** + * @return the last output collector in the collector chain + */ + @SuppressWarnings("unchecked") + protected Collector<OT> getLastOutputCollector() { + int numChained = this.chainedTasks.size(); + return (numChained == 0) ? output : (Collector<OT>) chainedTasks.get(numChained - 1).getOutputCollector(); + } + + /** + * Sets the last output {@link Collector} of the collector chain of this {@link BatchTask}. + * <p> + * In case of chained tasks, the output collector of the last {@link ChainedDriver} is set. Otherwise it is the + * single collector of the {@link BatchTask}. + * + * @param newOutputCollector new output collector to set as last collector + */ + protected void setLastOutputCollector(Collector<OT> newOutputCollector) { + int numChained = this.chainedTasks.size(); + + if (numChained == 0) { + output = newOutputCollector; + return; + } + + chainedTasks.get(numChained - 1).setOutputCollector(newOutputCollector); + } + + public TaskConfig getLastTasksConfig() { + int numChained = this.chainedTasks.size(); + return (numChained == 0) ? config : chainedTasks.get(numChained - 1).getTaskConfig(); + } + + protected S initStub(Class<? super S> stubSuperClass) throws Exception { + try { + ClassLoader userCodeClassLoader = getUserCodeClassLoader(); + S stub = config.<S>getStubWrapper(userCodeClassLoader).getUserCodeObject(stubSuperClass, userCodeClassLoader); + // check if the class is a subclass, if the check is required + if (stubSuperClass != null && !stubSuperClass.isAssignableFrom(stub.getClass())) { + throw new RuntimeException("The class '" + stub.getClass().getName() + "' is not a subclass of '" + + stubSuperClass.getName() + "' as is required."); + } + FunctionUtils.setFunctionRuntimeContext(stub, this.runtimeUdfContext); + return stub; + } + catch (ClassCastException ccex) { + throw new Exception("The stub class is not a proper subclass of " + stubSuperClass.getName(), ccex); + } + } + + /** + * Creates the record readers for the number of inputs as defined by {@link #getNumTaskInputs()}. + * + * This method requires that the task configuration, the driver, and the user-code class loader are set. + */ + protected void initInputReaders() throws Exception { + final int numInputs = getNumTaskInputs(); + final MutableReader<?>[] inputReaders = new MutableReader<?>[numInputs]; + + int currentReaderOffset = 0; + + AccumulatorRegistry registry = getEnvironment().getAccumulatorRegistry(); + AccumulatorRegistry.Reporter reporter = registry.getReadWriteReporter(); + + for (int i = 0; i < numInputs; i++) { + // ---------------- create the input readers --------------------- + // in case where a logical input unions multiple physical inputs, create a union reader + final int groupSize = this.config.getGroupSize(i); + + if (groupSize == 1) { + // non-union case + inputReaders[i] = new MutableRecordReader<IOReadableWritable>(getEnvironment().getInputGate(currentReaderOffset)); + } else if (groupSize > 1){ + // union case + InputGate[] readers = new InputGate[groupSize]; + for (int j = 0; j < groupSize; ++j) { + readers[j] = getEnvironment().getInputGate(currentReaderOffset + j); + } + inputReaders[i] = new MutableRecordReader<IOReadableWritable>(new UnionInputGate(readers)); + } else { + throw new Exception("Illegal input group size in task configuration: " + groupSize); + } + + inputReaders[i].setReporter(reporter); + + currentReaderOffset += groupSize; + } + this.inputReaders = inputReaders; + + // final sanity check + if (currentReaderOffset != this.config.getNumInputs()) { + throw new Exception("Illegal configuration: Number of input gates and group sizes are not consistent."); + } + } + + /** + * Creates the record readers for the extra broadcast inputs as configured by {@link TaskConfig#getNumBroadcastInputs()}. + * + * This method requires that the task configuration, the driver, and the user-code class loader are set. + */ + protected void initBroadcastInputReaders() throws Exception { + final int numBroadcastInputs = this.config.getNumBroadcastInputs(); + final MutableReader<?>[] broadcastInputReaders = new MutableReader<?>[numBroadcastInputs]; + + int currentReaderOffset = config.getNumInputs(); + + for (int i = 0; i < this.config.getNumBroadcastInputs(); i++) { + // ---------------- create the input readers --------------------- + // in case where a logical input unions multiple physical inputs, create a union reader + final int groupSize = this.config.getBroadcastGroupSize(i); + if (groupSize == 1) { + // non-union case + broadcastInputReaders[i] = new MutableRecordReader<IOReadableWritable>(getEnvironment().getInputGate(currentReaderOffset)); + } else if (groupSize > 1){ + // union case + InputGate[] readers = new InputGate[groupSize]; + for (int j = 0; j < groupSize; ++j) { + readers[j] = getEnvironment().getInputGate(currentReaderOffset + j); + } + broadcastInputReaders[i] = new MutableRecordReader<IOReadableWritable>(new UnionInputGate(readers)); + } else { + throw new Exception("Illegal input group size in task configuration: " + groupSize); + } + + currentReaderOffset += groupSize; + } + this.broadcastInputReaders = broadcastInputReaders; + } + + /** + * Creates all the serializers and comparators. + */ + protected void initInputsSerializersAndComparators(int numInputs, int numComparators) throws Exception { + this.inputSerializers = new TypeSerializerFactory<?>[numInputs]; + this.inputComparators = numComparators > 0 ? new TypeComparator<?>[numComparators] : null; + this.inputIterators = new MutableObjectIterator<?>[numInputs]; + + ClassLoader userCodeClassLoader = getUserCodeClassLoader(); + + for (int i = 0; i < numInputs; i++) { + + final TypeSerializerFactory<?> serializerFactory = this.config.getInputSerializer(i, userCodeClassLoader); + this.inputSerializers[i] = serializerFactory; + + this.inputIterators[i] = createInputIterator(this.inputReaders[i], this.inputSerializers[i]); + } + + // ---------------- create the driver's comparators --------------------- + for (int i = 0; i < numComparators; i++) { + + if (this.inputComparators != null) { + final TypeComparatorFactory<?> comparatorFactory = this.config.getDriverComparator(i, userCodeClassLoader); + this.inputComparators[i] = comparatorFactory.createComparator(); + } + } + } + + /** + * Creates all the serializers and iterators for the broadcast inputs. + */ + protected void initBroadcastInputsSerializers(int numBroadcastInputs) throws Exception { + this.broadcastInputSerializers = new TypeSerializerFactory<?>[numBroadcastInputs]; + + ClassLoader userCodeClassLoader = getUserCodeClassLoader(); + + for (int i = 0; i < numBroadcastInputs; i++) { + // ---------------- create the serializer first --------------------- + final TypeSerializerFactory<?> serializerFactory = this.config.getBroadcastInputSerializer(i, userCodeClassLoader); + this.broadcastInputSerializers[i] = serializerFactory; + } + } + + /** + * + * NOTE: This method must be invoked after the invocation of {@code #initInputReaders()} and + * {@code #initInputSerializersAndComparators(int)}! + * + * @param numInputs + */ + protected void initLocalStrategies(int numInputs) throws Exception { + + final MemoryManager memMan = getMemoryManager(); + final IOManager ioMan = getIOManager(); + + this.localStrategies = new CloseableInputProvider<?>[numInputs]; + this.inputs = new MutableObjectIterator<?>[numInputs]; + this.excludeFromReset = new boolean[numInputs]; + this.inputIsCached = new boolean[numInputs]; + this.inputIsAsyncMaterialized = new boolean[numInputs]; + this.materializationMemory = new int[numInputs]; + + // set up the local strategies first, such that the can work before any temp barrier is created + for (int i = 0; i < numInputs; i++) { + initInputLocalStrategy(i); + } + + // we do another loop over the inputs, because we want to instantiate all + // sorters, etc before requesting the first input (as this call may block) + + // we have two types of materialized inputs, and both are replayable (can act as a cache) + // The first variant materializes in a different thread and hence + // acts as a pipeline breaker. this one should only be there, if a pipeline breaker is needed. + // the second variant spills to the side and will not read unless the result is also consumed + // in a pipelined fashion. + this.resettableInputs = new SpillingResettableMutableObjectIterator<?>[numInputs]; + this.tempBarriers = new TempBarrier<?>[numInputs]; + + for (int i = 0; i < numInputs; i++) { + final int memoryPages; + final boolean async = this.config.isInputAsynchronouslyMaterialized(i); + final boolean cached = this.config.isInputCached(i); + + this.inputIsAsyncMaterialized[i] = async; + this.inputIsCached[i] = cached; + + if (async || cached) { + memoryPages = memMan.computeNumberOfPages(this.config.getRelativeInputMaterializationMemory(i)); + if (memoryPages <= 0) { + throw new Exception("Input marked as materialized/cached, but no memory for materialization provided."); + } + this.materializationMemory[i] = memoryPages; + } else { + memoryPages = 0; + } + + if (async) { + @SuppressWarnings({ "unchecked", "rawtypes" }) + TempBarrier<?> barrier = new TempBarrier(this, getInput(i), this.inputSerializers[i], memMan, ioMan, memoryPages); + barrier.startReading(); + this.tempBarriers[i] = barrier; + this.inputs[i] = null; + } else if (cached) { + @SuppressWarnings({ "unchecked", "rawtypes" }) + SpillingResettableMutableObjectIterator<?> iter = new SpillingResettableMutableObjectIterator( + getInput(i), this.inputSerializers[i].getSerializer(), getMemoryManager(), getIOManager(), memoryPages, this); + this.resettableInputs[i] = iter; + this.inputs[i] = iter; + } + } + } + + protected void resetAllInputs() throws Exception { + + // first we need to make sure that caches consume remaining data + // NOTE: we need to do this before closing the local strategies + for (int i = 0; i < this.inputs.length; i++) { + + if (this.inputIsCached[i] && this.resettableInputs[i] != null) { + this.resettableInputs[i].consumeAndCacheRemainingData(); + } + } + + // close all local-strategies. they will either get re-initialized, or we have + // read them now and their data is cached + for (int i = 0; i < this.localStrategies.length; i++) { + if (this.localStrategies[i] != null) { + this.localStrategies[i].close(); + this.localStrategies[i] = null; + } + } + + final MemoryManager memMan = getMemoryManager(); + final IOManager ioMan = getIOManager(); + + // reset the caches, or re-run the input local strategy + for (int i = 0; i < this.inputs.length; i++) { + if (this.excludeFromReset[i]) { + if (this.tempBarriers[i] != null) { + this.tempBarriers[i].close(); + this.tempBarriers[i] = null; + } else if (this.resettableInputs[i] != null) { + this.resettableInputs[i].close(); + this.resettableInputs[i] = null; + } + } else { + // make sure the input is not available directly, but are lazily fetched again + this.inputs[i] = null; + + if (this.inputIsCached[i]) { + if (this.tempBarriers[i] != null) { + this.inputs[i] = this.tempBarriers[i].getIterator(); + } else if (this.resettableInputs[i] != null) { + this.resettableInputs[i].reset(); + this.inputs[i] = this.resettableInputs[i]; + } else { + throw new RuntimeException("Found a resettable input, but no temp barrier and no resettable iterator."); + } + } else { + // close the async barrier if there is one + if (this.tempBarriers[i] != null) { + this.tempBarriers[i].close(); + } + + // recreate the local strategy + initInputLocalStrategy(i); + + if (this.inputIsAsyncMaterialized[i]) { + final int pages = this.materializationMemory[i]; + @SuppressWarnings({ "unchecked", "rawtypes" }) + TempBarrier<?> barrier = new TempBarrier(this, getInput(i), this.inputSerializers[i], memMan, ioMan, pages); + barrier.startReading(); + this.tempBarriers[i] = barrier; + this.inputs[i] = null; + } + } + } + } + } + + protected void excludeFromReset(int inputNum) { + this.excludeFromReset[inputNum] = true; + } + + private void initInputLocalStrategy(int inputNum) throws Exception { + // check if there is already a strategy + if (this.localStrategies[inputNum] != null) { + throw new IllegalStateException(); + } + + // now set up the local strategy + final LocalStrategy localStrategy = this.config.getInputLocalStrategy(inputNum); + if (localStrategy != null) { + switch (localStrategy) { + case NONE: + // the input is as it is + this.inputs[inputNum] = this.inputIterators[inputNum]; + break; + case SORT: + @SuppressWarnings({ "rawtypes", "unchecked" }) + UnilateralSortMerger<?> sorter = new UnilateralSortMerger(getMemoryManager(), getIOManager(), + this.inputIterators[inputNum], this, this.inputSerializers[inputNum], getLocalStrategyComparator(inputNum), + this.config.getRelativeMemoryInput(inputNum), this.config.getFilehandlesInput(inputNum), + this.config.getSpillingThresholdInput(inputNum), this.getExecutionConfig().isObjectReuseEnabled()); + // set the input to null such that it will be lazily fetched from the input strategy + this.inputs[inputNum] = null; + this.localStrategies[inputNum] = sorter; + break; + case COMBININGSORT: + // sanity check this special case! + // this still breaks a bit of the abstraction! + // we should have nested configurations for the local strategies to solve that + if (inputNum != 0) { + throw new IllegalStateException("Performing combining sort outside a (group)reduce task!"); + } + + // instantiate ourselves a combiner. we should not use the stub, because the sort and the + // subsequent (group)reduce would otherwise share it multi-threaded + final Class<S> userCodeFunctionType = this.driver.getStubType(); + if (userCodeFunctionType == null) { + throw new IllegalStateException("Performing combining sort outside a reduce task!"); + } + final S localStub; + try { + localStub = initStub(userCodeFunctionType); + } catch (Exception e) { + throw new RuntimeException("Initializing the user code and the configuration failed" + + (e.getMessage() == null ? "." : ": " + e.getMessage()), e); + } + + if (!(localStub instanceof GroupCombineFunction)) { + throw new IllegalStateException("Performing combining sort outside a reduce task!"); + } + + @SuppressWarnings({ "rawtypes", "unchecked" }) + CombiningUnilateralSortMerger<?> cSorter = new CombiningUnilateralSortMerger( + (GroupCombineFunction) localStub, getMemoryManager(), getIOManager(), this.inputIterators[inputNum], + this, this.inputSerializers[inputNum], getLocalStrategyComparator(inputNum), + this.config.getRelativeMemoryInput(inputNum), this.config.getFilehandlesInput(inputNum), + this.config.getSpillingThresholdInput(inputNum), this.getExecutionConfig().isObjectReuseEnabled()); + cSorter.setUdfConfiguration(this.config.getStubParameters()); + + // set the input to null such that it will be lazily fetched from the input strategy + this.inputs[inputNum] = null; + this.localStrategies[inputNum] = cSorter; + break; + default: + throw new Exception("Unrecognized local strategy provided: " + localStrategy.name()); + } + } else { + // no local strategy in the config + this.inputs[inputNum] = this.inputIterators[inputNum]; + } + } + + private <T> TypeComparator<T> getLocalStrategyComparator(int inputNum) throws Exception { + TypeComparatorFactory<T> compFact = this.config.getInputComparator(inputNum, getUserCodeClassLoader()); + if (compFact == null) { + throw new Exception("Missing comparator factory for local strategy on input " + inputNum); + } + return compFact.createComparator(); + } + + protected MutableObjectIterator<?> createInputIterator(MutableReader<?> inputReader, TypeSerializerFactory<?> serializerFactory) { + @SuppressWarnings("unchecked") + MutableReader<DeserializationDelegate<?>> reader = (MutableReader<DeserializationDelegate<?>>) inputReader; + @SuppressWarnings({ "unchecked", "rawtypes" }) + final MutableObjectIterator<?> iter = new ReaderIterator(reader, serializerFactory.getSerializer()); + return iter; + } + + protected int getNumTaskInputs() { + return this.driver.getNumberOfInputs(); + } + + /** + * Creates a writer for each output. Creates an OutputCollector which forwards its input to all writers. + * The output collector applies the configured shipping strategies for each writer. + */ + protected void initOutputs() throws Exception { + this.chainedTasks = new ArrayList<ChainedDriver<?, ?>>(); + this.eventualOutputs = new ArrayList<RecordWriter<?>>(); + + ClassLoader userCodeClassLoader = getUserCodeClassLoader(); + + AccumulatorRegistry accumulatorRegistry = getEnvironment().getAccumulatorRegistry(); + AccumulatorRegistry.Reporter reporter = accumulatorRegistry.getReadWriteReporter(); + + this.accumulatorMap = accumulatorRegistry.getUserMap(); + + this.output = initOutputs(this, userCodeClassLoader, this.config, this.chainedTasks, this.eventualOutputs, + this.getExecutionConfig(), reporter, this.accumulatorMap); + } + + public DistributedRuntimeUDFContext createRuntimeContext(String taskName) { + Environment env = getEnvironment(); + + return new DistributedRuntimeUDFContext(taskName, env.getNumberOfSubtasks(), + env.getIndexInSubtaskGroup(), getUserCodeClassLoader(), getExecutionConfig(), + env.getDistributedCacheEntries(), this.accumulatorMap); + } + + // -------------------------------------------------------------------------------------------- + // Task Context Signature + // ------------------------------------------------------------------------------------------- + + @Override + public TaskConfig getTaskConfig() { + return this.config; + } + + @Override + public TaskManagerRuntimeInfo getTaskManagerInfo() { + return getEnvironment().getTaskManagerInfo(); + } + + @Override + public MemoryManager getMemoryManager() { + return getEnvironment().getMemoryManager(); + } + + @Override + public IOManager getIOManager() { + return getEnvironment().getIOManager(); + } + + @Override + public S getStub() { + return this.stub; + } + + @Override + public Collector<OT> getOutputCollector() { + return this.output; + } + + @Override + public AbstractInvokable getOwningNepheleTask() { + return this; + } + + @Override + public String formatLogString(String message) { + return constructLogString(message, getEnvironment().getTaskName(), this); + } + + @Override + public <X> MutableObjectIterator<X> getInput(int index) { + if (index < 0 || index > this.driver.getNumberOfInputs()) { + throw new IndexOutOfBoundsException(); + } + + // check for lazy assignment from input strategies + if (this.inputs[index] != null) { + @SuppressWarnings("unchecked") + MutableObjectIterator<X> in = (MutableObjectIterator<X>) this.inputs[index]; + return in; + } else { + final MutableObjectIterator<X> in; + try { + if (this.tempBarriers[index] != null) { + @SuppressWarnings("unchecked") + MutableObjectIterator<X> iter = (MutableObjectIterator<X>) this.tempBarriers[index].getIterator(); + in = iter; + } else if (this.localStrategies[index] != null) { + @SuppressWarnings("unchecked") + MutableObjectIterator<X> iter = (MutableObjectIterator<X>) this.localStrategies[index].getIterator(); + in = iter; + } else { + throw new RuntimeException("Bug: null input iterator, null temp barrier, and null local strategy."); + } + this.inputs[index] = in; + return in; + } catch (InterruptedException iex) { + throw new RuntimeException("Interrupted while waiting for input " + index + " to become available."); + } catch (IOException ioex) { + throw new RuntimeException("An I/O Exception occurred while obtaining input " + index + "."); + } + } + } + + + @Override + public <X> TypeSerializerFactory<X> getInputSerializer(int index) { + if (index < 0 || index >= this.driver.getNumberOfInputs()) { + throw new IndexOutOfBoundsException(); + } + + @SuppressWarnings("unchecked") + final TypeSerializerFactory<X> serializerFactory = (TypeSerializerFactory<X>) this.inputSerializers[index]; + return serializerFactory; + } + + + @Override + public <X> TypeComparator<X> getDriverComparator(int index) { + if (this.inputComparators == null) { + throw new IllegalStateException("Comparators have not been created!"); + } + else if (index < 0 || index >= this.driver.getNumberOfDriverComparators()) { + throw new IndexOutOfBoundsException(); + } + + @SuppressWarnings("unchecked") + final TypeComparator<X> comparator = (TypeComparator<X>) this.inputComparators[index]; + return comparator; + } + + // ============================================================================================ + // Static Utilities + // + // Utilities are consolidated here to ensure a uniform way of running, + // logging, exception handling, and error messages. + // ============================================================================================ + + // -------------------------------------------------------------------------------------------- + // Logging + // -------------------------------------------------------------------------------------------- + /** + * Utility function that composes a string for logging purposes. The string includes the given message, + * the given name of the task and the index in its subtask group as well as the number of instances + * that exist in its subtask group. + * + * @param message The main message for the log. + * @param taskName The name of the task. + * @param parent The nephele task that contains the code producing the message. + * + * @return The string for logging. + */ + public static String constructLogString(String message, String taskName, AbstractInvokable parent) { + return message + ": " + taskName + " (" + (parent.getEnvironment().getIndexInSubtaskGroup() + 1) + + '/' + parent.getEnvironment().getNumberOfSubtasks() + ')'; + } + + /** + * Prints an error message and throws the given exception. If the exception is of the type + * {@link ExceptionInChainedStubException} then the chain of contained exceptions is followed + * until an exception of a different type is found. + * + * @param ex The exception to be thrown. + * @param parent The parent task, whose information is included in the log message. + * @throws Exception Always thrown. + */ + public static void logAndThrowException(Exception ex, AbstractInvokable parent) throws Exception { + String taskName; + if (ex instanceof ExceptionInChainedStubException) { + do { + ExceptionInChainedStubException cex = (ExceptionInChainedStubException) ex; + taskName = cex.getTaskName(); + ex = cex.getWrappedException(); + } while (ex instanceof ExceptionInChainedStubException); + } else { + taskName = parent.getEnvironment().getTaskName(); + } + + if (LOG.isErrorEnabled()) { + LOG.error(constructLogString("Error in task code", taskName, parent), ex); + } + + throw ex; + } + + // -------------------------------------------------------------------------------------------- + // Result Shipping and Chained Tasks + // -------------------------------------------------------------------------------------------- + + /** + * Creates the {@link Collector} for the given task, as described by the given configuration. The + * output collector contains the writers that forward the data to the different tasks that the given task + * is connected to. Each writer applies a the partitioning as described in the configuration. + * + * @param task The task that the output collector is created for. + * @param config The configuration describing the output shipping strategies. + * @param cl The classloader used to load user defined types. + * @param eventualOutputs The output writers that this task forwards to the next task for each output. + * @param outputOffset The offset to start to get the writers for the outputs + * @param numOutputs The number of outputs described in the configuration. + * + * @return The OutputCollector that data produced in this task is submitted to. + */ + public static <T> Collector<T> getOutputCollector(AbstractInvokable task, TaskConfig config, ClassLoader cl, + List<RecordWriter<?>> eventualOutputs, int outputOffset, int numOutputs, AccumulatorRegistry.Reporter reporter) throws Exception + { + if (numOutputs == 0) { + return null; + } + + // get the factory for the serializer + final TypeSerializerFactory<T> serializerFactory = config.getOutputSerializer(cl); + + // special case the Record + if (serializerFactory.getDataType().equals(Record.class)) { + final List<RecordWriter<Record>> writers = new ArrayList<RecordWriter<Record>>(numOutputs); + + // create a writer for each output + for (int i = 0; i < numOutputs; i++) { + // create the OutputEmitter from output ship strategy + final ShipStrategyType strategy = config.getOutputShipStrategy(i); + final TypeComparatorFactory<?> compFact = config.getOutputComparator(i, cl); + final RecordOutputEmitter oe; + if (compFact == null) { + oe = new RecordOutputEmitter(strategy); + } else { + @SuppressWarnings("unchecked") + TypeComparator<Record> comparator = (TypeComparator<Record>) compFact.createComparator(); + if (!comparator.supportsCompareAgainstReference()) { + throw new Exception("Incompatibe serializer-/comparator factories."); + } + final DataDistribution distribution = config.getOutputDataDistribution(i, cl); + final Partitioner<?> partitioner = config.getOutputPartitioner(i, cl); + + oe = new RecordOutputEmitter(strategy, comparator, partitioner, distribution); + } + + // setup accumulator counters + final RecordWriter<Record> recordWriter = new RecordWriter<Record>(task.getEnvironment().getWriter(outputOffset + i), oe); + recordWriter.setReporter(reporter); + + writers.add(recordWriter); + } + if (eventualOutputs != null) { + eventualOutputs.addAll(writers); + } + + @SuppressWarnings("unchecked") + final Collector<T> outColl = (Collector<T>) new RecordOutputCollector(writers); + return outColl; + } + else { + // generic case + final List<RecordWriter<SerializationDelegate<T>>> writers = new ArrayList<RecordWriter<SerializationDelegate<T>>>(numOutputs); + + // create a writer for each output + for (int i = 0; i < numOutputs; i++) + { + // create the OutputEmitter from output ship strategy + final ShipStrategyType strategy = config.getOutputShipStrategy(i); + final TypeComparatorFactory<T> compFactory = config.getOutputComparator(i, cl); + + final ChannelSelector<SerializationDelegate<T>> oe; + if (compFactory == null) { + oe = new OutputEmitter<T>(strategy); + } + else { + final DataDistribution dataDist = config.getOutputDataDistribution(i, cl); + final Partitioner<?> partitioner = config.getOutputPartitioner(i, cl); + + final TypeComparator<T> comparator = compFactory.createComparator(); + oe = new OutputEmitter<T>(strategy, comparator, partitioner, dataDist); + } + + final RecordWriter<SerializationDelegate<T>> recordWriter = + new RecordWriter<SerializationDelegate<T>>(task.getEnvironment().getWriter(outputOffset + i), oe); + + // setup live accumulator counters + recordWriter.setReporter(reporter); + + writers.add(recordWriter); + } + if (eventualOutputs != null) { + eventualOutputs.addAll(writers); + } + return new OutputCollector<T>(writers, serializerFactory.getSerializer()); + } + } + + /** + * Creates a writer for each output. Creates an OutputCollector which forwards its input to all writers. + * The output collector applies the configured shipping strategy. + */ + @SuppressWarnings("unchecked") + public static <T> Collector<T> initOutputs(AbstractInvokable nepheleTask, ClassLoader cl, TaskConfig config, + List<ChainedDriver<?, ?>> chainedTasksTarget, + List<RecordWriter<?>> eventualOutputs, + ExecutionConfig executionConfig, + AccumulatorRegistry.Reporter reporter, + Map<String, Accumulator<?,?>> accumulatorMap) + throws Exception + { + final int numOutputs = config.getNumOutputs(); + + // check whether we got any chained tasks + final int numChained = config.getNumberOfChainedStubs(); + if (numChained > 0) { + // got chained stubs. that means that this one may only have a single forward connection + if (numOutputs != 1 || config.getOutputShipStrategy(0) != ShipStrategyType.FORWARD) { + throw new RuntimeException("Plan Generation Bug: Found a chained stub that is not connected via an only forward connection."); + } + + // instantiate each task + @SuppressWarnings("rawtypes") + Collector previous = null; + for (int i = numChained - 1; i >= 0; --i) + { + // get the task first + final ChainedDriver<?, ?> ct; + try { + Class<? extends ChainedDriver<?, ?>> ctc = config.getChainedTask(i); + ct = ctc.newInstance(); + } + catch (Exception ex) { + throw new RuntimeException("Could not instantiate chained task driver.", ex); + } + + // get the configuration for the task + final TaskConfig chainedStubConf = config.getChainedStubConfig(i); + final String taskName = config.getChainedTaskName(i); + + if (i == numChained - 1) { + // last in chain, instantiate the output collector for this task + previous = getOutputCollector(nepheleTask, chainedStubConf, cl, eventualOutputs, 0, chainedStubConf.getNumOutputs(), reporter); + } + + ct.setup(chainedStubConf, taskName, previous, nepheleTask, cl, executionConfig, accumulatorMap); + chainedTasksTarget.add(0, ct); + + previous = ct; + } + // the collector of the first in the chain is the collector for the nephele task + return (Collector<T>) previous; + } + // else + + // instantiate the output collector the default way from this configuration + return getOutputCollector(nepheleTask , config, cl, eventualOutputs, 0, numOutputs, reporter); + } + + // -------------------------------------------------------------------------------------------- + // User Code LifeCycle + // -------------------------------------------------------------------------------------------- + + /** + * Opens the given stub using its {@link org.apache.flink.api.common.functions.RichFunction#open(Configuration)} method. If the open call produces + * an exception, a new exception with a standard error message is created, using the encountered exception + * as its cause. + * + * @param stub The user code instance to be opened. + * @param parameters The parameters supplied to the user code. + * + * @throws Exception Thrown, if the user code's open method produces an exception. + */ + public static void openUserCode(Function stub, Configuration parameters) throws Exception { + try { + FunctionUtils.openFunction(stub, parameters); + } catch (Throwable t) { + throw new Exception("The user defined 'open(Configuration)' method in " + stub.getClass().toString() + " caused an exception: " + t.getMessage(), t); + } + } + + /** + * Closes the given stub using its {@link org.apache.flink.api.common.functions.RichFunction#close()} method. If the close call produces + * an exception, a new exception with a standard error message is created, using the encountered exception + * as its cause. + * + * @param stub The user code instance to be closed. + * + * @throws Exception Thrown, if the user code's close method produces an exception. + */ + public static void closeUserCode(Function stub) throws Exception { + try { + FunctionUtils.closeFunction(stub); + } catch (Throwable t) { + throw new Exception("The user defined 'close()' method caused an exception: " + t.getMessage(), t); + } + } + + // -------------------------------------------------------------------------------------------- + // Chained Task LifeCycle + // -------------------------------------------------------------------------------------------- + + /** + * Opens all chained tasks, in the order as they are stored in the array. The opening process + * creates a standardized log info message. + * + * @param tasks The tasks to be opened. + * @param parent The parent task, used to obtain parameters to include in the log message. + * @throws Exception Thrown, if the opening encounters an exception. + */ + public static void openChainedTasks(List<ChainedDriver<?, ?>> tasks, AbstractInvokable parent) throws Exception { + // start all chained tasks + for (int i = 0; i < tasks.size(); i++) { + final ChainedDriver<?, ?> task = tasks.get(i); + if (LOG.isDebugEnabled()) { + LOG.debug(constructLogString("Start task code", task.getTaskName(), parent)); + } + task.openTask(); + } + } + + /** + * Closes all chained tasks, in the order as they are stored in the array. The closing process + * creates a standardized log info message. + * + * @param tasks The tasks to be closed. + * @param parent The parent task, used to obtain parameters to include in the log message. + * @throws Exception Thrown, if the closing encounters an exception. + */ + public static void closeChainedTasks(List<ChainedDriver<?, ?>> tasks, AbstractInvokable parent) throws Exception { + for (int i = 0; i < tasks.size(); i++) { + final ChainedDriver<?, ?> task = tasks.get(i); + task.closeTask(); + + if (LOG.isDebugEnabled()) { + LOG.debug(constructLogString("Finished task code", task.getTaskName(), parent)); + } + } + } + + /** + * Cancels all tasks via their {@link ChainedDriver#cancelTask()} method. Any occurring exception + * and error is suppressed, such that the canceling method of every task is invoked in all cases. + * + * @param tasks The tasks to be canceled. + */ + public static void cancelChainedTasks(List<ChainedDriver<?, ?>> tasks) { + for (int i = 0; i < tasks.size(); i++) { + try { + tasks.get(i).cancelTask(); + } catch (Throwable t) { + // do nothing + } + } + } + + // -------------------------------------------------------------------------------------------- + // Miscellaneous Utilities + // -------------------------------------------------------------------------------------------- + + /** + * Instantiates a user code class from is definition in the task configuration. + * The class is instantiated without arguments using the null-ary constructor. Instantiation + * will fail if this constructor does not exist or is not public. + * + * @param <T> The generic type of the user code class. + * @param config The task configuration containing the class description. + * @param cl The class loader to be used to load the class. + * @param superClass The super class that the user code class extends or implements, for type checking. + * + * @return An instance of the user code class. + */ + public static <T> T instantiateUserCode(TaskConfig config, ClassLoader cl, Class<? super T> superClass) { + try { + T stub = config.<T>getStubWrapper(cl).getUserCodeObject(superClass, cl); + // check if the class is a subclass, if the check is required + if (superClass != null && !superClass.isAssignableFrom(stub.getClass())) { + throw new RuntimeException("The class '" + stub.getClass().getName() + "' is not a subclass of '" + + superClass.getName() + "' as is required."); + } + return stub; + } + catch (ClassCastException ccex) { + throw new RuntimeException("The UDF class is not a proper subclass of " + superClass.getName(), ccex); + } + } + + private static int[] asArray(List<Integer> list) { + int[] a = new int[list.size()]; + + int i = 0; + for (int val : list) { + a[i++] = val; + } + return a; + } + + public static void clearWriters(List<RecordWriter<?>> writers) { + for (RecordWriter<?> writer : writers) { + writer.clearBuffers(); + } + } + + public static void clearReaders(MutableReader<?>[] readers) { + for (MutableReader<?> reader : readers) { + reader.clearBuffers(); + } + } +}
http://git-wip-us.apache.org/repos/asf/flink/blob/b08669ab/flink-runtime/src/main/java/org/apache/flink/runtime/operators/CoGroupDriver.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/CoGroupDriver.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/CoGroupDriver.java index dfc4e91..665ab0e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/CoGroupDriver.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/CoGroupDriver.java @@ -43,12 +43,12 @@ import org.apache.flink.util.MutableObjectIterator; * * @see org.apache.flink.api.common.functions.CoGroupFunction */ -public class CoGroupDriver<IT1, IT2, OT> implements PactDriver<CoGroupFunction<IT1, IT2, OT>, OT> { +public class CoGroupDriver<IT1, IT2, OT> implements Driver<CoGroupFunction<IT1, IT2, OT>, OT> { private static final Logger LOG = LoggerFactory.getLogger(CoGroupDriver.class); - private PactTaskContext<CoGroupFunction<IT1, IT2, OT>, OT> taskContext; + private TaskContext<CoGroupFunction<IT1, IT2, OT>, OT> taskContext; private CoGroupTaskIterator<IT1, IT2> coGroupIterator; // the iterator that does the actual cogroup @@ -60,7 +60,7 @@ public class CoGroupDriver<IT1, IT2, OT> implements PactDriver<CoGroupFunction<I @Override - public void setup(PactTaskContext<CoGroupFunction<IT1, IT2, OT>, OT> context) { + public void setup(TaskContext<CoGroupFunction<IT1, IT2, OT>, OT> context) { this.taskContext = context; this.running = true; } http://git-wip-us.apache.org/repos/asf/flink/blob/b08669ab/flink-runtime/src/main/java/org/apache/flink/runtime/operators/CoGroupRawDriver.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/CoGroupRawDriver.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/CoGroupRawDriver.java index 7abad5d..0487904 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/CoGroupRawDriver.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/CoGroupRawDriver.java @@ -26,17 +26,17 @@ import org.apache.flink.util.MutableObjectIterator; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class CoGroupRawDriver<IT1, IT2, OT> implements PactDriver<CoGroupFunction<IT1, IT2, OT>, OT> { +public class CoGroupRawDriver<IT1, IT2, OT> implements Driver<CoGroupFunction<IT1, IT2, OT>, OT> { private static final Logger LOG = LoggerFactory.getLogger(CoGroupRawDriver.class); - private PactTaskContext<CoGroupFunction<IT1, IT2, OT>, OT> taskContext; + private TaskContext<CoGroupFunction<IT1, IT2, OT>, OT> taskContext; private SimpleIterable<IT1> coGroupIterator1; private SimpleIterable<IT2> coGroupIterator2; @Override - public void setup(PactTaskContext<CoGroupFunction<IT1, IT2, OT>, OT> context) { + public void setup(TaskContext<CoGroupFunction<IT1, IT2, OT>, OT> context) { this.taskContext = context; } http://git-wip-us.apache.org/repos/asf/flink/blob/b08669ab/flink-runtime/src/main/java/org/apache/flink/runtime/operators/CoGroupWithSolutionSetFirstDriver.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/CoGroupWithSolutionSetFirstDriver.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/CoGroupWithSolutionSetFirstDriver.java index 97d6e51..07fb352 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/CoGroupWithSolutionSetFirstDriver.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/CoGroupWithSolutionSetFirstDriver.java @@ -29,7 +29,7 @@ import org.apache.flink.api.common.typeutils.TypePairComparator; import org.apache.flink.api.common.typeutils.TypePairComparatorFactory; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.runtime.iterative.concurrent.SolutionSetBroker; -import org.apache.flink.runtime.iterative.task.AbstractIterativePactTask; +import org.apache.flink.runtime.iterative.task.AbstractIterativeTask; import org.apache.flink.runtime.operators.hash.CompactingHashTable; import org.apache.flink.runtime.operators.util.TaskConfig; import org.apache.flink.runtime.util.NonReusingKeyGroupedIterator; @@ -37,9 +37,9 @@ import org.apache.flink.runtime.util.ReusingKeyGroupedIterator; import org.apache.flink.runtime.util.SingleElementIterator; import org.apache.flink.util.Collector; -public class CoGroupWithSolutionSetFirstDriver<IT1, IT2, OT> implements ResettablePactDriver<CoGroupFunction<IT1, IT2, OT>, OT> { +public class CoGroupWithSolutionSetFirstDriver<IT1, IT2, OT> implements ResettableDriver<CoGroupFunction<IT1, IT2, OT>, OT> { - private PactTaskContext<CoGroupFunction<IT1, IT2, OT>, OT> taskContext; + private TaskContext<CoGroupFunction<IT1, IT2, OT>, OT> taskContext; private CompactingHashTable<IT1> hashTable; @@ -63,7 +63,7 @@ public class CoGroupWithSolutionSetFirstDriver<IT1, IT2, OT> implements Resettab // -------------------------------------------------------------------------------------------- @Override - public void setup(PactTaskContext<CoGroupFunction<IT1, IT2, OT>, OT> context) { + public void setup(TaskContext<CoGroupFunction<IT1, IT2, OT>, OT> context) { this.taskContext = context; this.running = true; } @@ -106,8 +106,8 @@ public class CoGroupWithSolutionSetFirstDriver<IT1, IT2, OT> implements Resettab final TypeComparator<IT1> solutionSetComparator; // grab a handle to the hash table from the iteration broker - if (taskContext instanceof AbstractIterativePactTask) { - AbstractIterativePactTask<?, ?> iterativeTaskContext = (AbstractIterativePactTask<?, ?>) taskContext; + if (taskContext instanceof AbstractIterativeTask) { + AbstractIterativeTask<?, ?> iterativeTaskContext = (AbstractIterativeTask<?, ?>) taskContext; String identifier = iterativeTaskContext.brokerKey(); Object table = SolutionSetBroker.instance().get(identifier); http://git-wip-us.apache.org/repos/asf/flink/blob/b08669ab/flink-runtime/src/main/java/org/apache/flink/runtime/operators/CoGroupWithSolutionSetSecondDriver.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/CoGroupWithSolutionSetSecondDriver.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/CoGroupWithSolutionSetSecondDriver.java index 9e8a81c..eec6597 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/CoGroupWithSolutionSetSecondDriver.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/CoGroupWithSolutionSetSecondDriver.java @@ -27,7 +27,7 @@ import org.apache.flink.api.common.typeutils.TypePairComparator; import org.apache.flink.api.common.typeutils.TypePairComparatorFactory; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.runtime.iterative.concurrent.SolutionSetBroker; -import org.apache.flink.runtime.iterative.task.AbstractIterativePactTask; +import org.apache.flink.runtime.iterative.task.AbstractIterativeTask; import org.apache.flink.runtime.operators.hash.CompactingHashTable; import org.apache.flink.runtime.operators.util.TaskConfig; import org.apache.flink.runtime.util.EmptyIterator; @@ -36,9 +36,9 @@ import org.apache.flink.runtime.util.ReusingKeyGroupedIterator; import org.apache.flink.runtime.util.SingleElementIterator; import org.apache.flink.util.Collector; -public class CoGroupWithSolutionSetSecondDriver<IT1, IT2, OT> implements ResettablePactDriver<CoGroupFunction<IT1, IT2, OT>, OT> { +public class CoGroupWithSolutionSetSecondDriver<IT1, IT2, OT> implements ResettableDriver<CoGroupFunction<IT1, IT2, OT>, OT> { - private PactTaskContext<CoGroupFunction<IT1, IT2, OT>, OT> taskContext; + private TaskContext<CoGroupFunction<IT1, IT2, OT>, OT> taskContext; private CompactingHashTable<IT2> hashTable; @@ -61,7 +61,7 @@ public class CoGroupWithSolutionSetSecondDriver<IT1, IT2, OT> implements Resetta // -------------------------------------------------------------------------------------------- @Override - public void setup(PactTaskContext<CoGroupFunction<IT1, IT2, OT>, OT> context) { + public void setup(TaskContext<CoGroupFunction<IT1, IT2, OT>, OT> context) { this.taskContext = context; this.running = true; } @@ -104,8 +104,8 @@ public class CoGroupWithSolutionSetSecondDriver<IT1, IT2, OT> implements Resetta final TypeComparator<IT2> solutionSetComparator; // grab a handle to the hash table from the iteration broker - if (taskContext instanceof AbstractIterativePactTask) { - AbstractIterativePactTask<?, ?> iterativeTaskContext = (AbstractIterativePactTask<?, ?>) taskContext; + if (taskContext instanceof AbstractIterativeTask) { + AbstractIterativeTask<?, ?> iterativeTaskContext = (AbstractIterativeTask<?, ?>) taskContext; String identifier = iterativeTaskContext.brokerKey(); Object table = SolutionSetBroker.instance().get(identifier); http://git-wip-us.apache.org/repos/asf/flink/blob/b08669ab/flink-runtime/src/main/java/org/apache/flink/runtime/operators/CollectorMapDriver.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/CollectorMapDriver.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/CollectorMapDriver.java index 7e6376f..60bc798 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/CollectorMapDriver.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/CollectorMapDriver.java @@ -40,19 +40,19 @@ import org.slf4j.LoggerFactory; * @param <OT> The mapper's output data type. */ @SuppressWarnings("deprecation") -public class CollectorMapDriver<IT, OT> implements PactDriver<GenericCollectorMap<IT, OT>, OT> { +public class CollectorMapDriver<IT, OT> implements Driver<GenericCollectorMap<IT, OT>, OT> { private static final Logger LOG = LoggerFactory.getLogger(CollectorMapDriver.class); - private PactTaskContext<GenericCollectorMap<IT, OT>, OT> taskContext; + private TaskContext<GenericCollectorMap<IT, OT>, OT> taskContext; private volatile boolean running; private boolean objectReuseEnabled = false; @Override - public void setup(PactTaskContext<GenericCollectorMap<IT, OT>, OT> context) { + public void setup(TaskContext<GenericCollectorMap<IT, OT>, OT> context) { this.taskContext = context; this.running = true; } http://git-wip-us.apache.org/repos/asf/flink/blob/b08669ab/flink-runtime/src/main/java/org/apache/flink/runtime/operators/CrossDriver.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/CrossDriver.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/CrossDriver.java index 011c980..c9d84b1 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/CrossDriver.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/CrossDriver.java @@ -41,12 +41,12 @@ import org.apache.flink.util.MutableObjectIterator; * * @see org.apache.flink.api.common.functions.CrossFunction */ -public class CrossDriver<T1, T2, OT> implements PactDriver<CrossFunction<T1, T2, OT>, OT> { +public class CrossDriver<T1, T2, OT> implements Driver<CrossFunction<T1, T2, OT>, OT> { private static final Logger LOG = LoggerFactory.getLogger(CrossDriver.class); - private PactTaskContext<CrossFunction<T1, T2, OT>, OT> taskContext; + private TaskContext<CrossFunction<T1, T2, OT>, OT> taskContext; private MemoryManager memManager; @@ -70,7 +70,7 @@ public class CrossDriver<T1, T2, OT> implements PactDriver<CrossFunction<T1, T2, @Override - public void setup(PactTaskContext<CrossFunction<T1, T2, OT>, OT> context) { + public void setup(TaskContext<CrossFunction<T1, T2, OT>, OT> context) { this.taskContext = context; this.running = true; } http://git-wip-us.apache.org/repos/asf/flink/blob/b08669ab/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DataSinkTask.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DataSinkTask.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DataSinkTask.java index 1002bae..d20bb89 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DataSinkTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DataSinkTask.java @@ -84,9 +84,8 @@ public class DataSinkTask<IT> extends AbstractInvokable { @Override public void registerInputOutput() { - if (LOG.isDebugEnabled()) { - LOG.debug(getLogString("Start registering input and output")); - } + + LOG.debug(getLogString("Start registering input and output")); // initialize OutputFormat initOutputFormat(); @@ -99,24 +98,18 @@ public class DataSinkTask<IT> extends AbstractInvokable { (e.getMessage() == null ? "." : ": " + e.getMessage()), e); } - if (LOG.isDebugEnabled()) { - LOG.debug(getLogString("Finished registering input and output")); - } + LOG.debug(getLogString("Finished registering input and output")); } @Override public void invoke() throws Exception { - if (LOG.isDebugEnabled()) { - LOG.debug(getLogString("Starting data sink operator")); - } + LOG.debug(getLogString("Starting data sink operator")); if(RichOutputFormat.class.isAssignableFrom(this.format.getClass())){ ((RichOutputFormat) this.format).setRuntimeContext(createRuntimeContext()); - if (LOG.isDebugEnabled()) { - LOG.debug(getLogString("Rich Sink detected. Initializing runtime context.")); - } + LOG.debug(getLogString("Rich Sink detected. Initializing runtime context.")); } ExecutionConfig executionConfig; @@ -131,9 +124,7 @@ public class DataSinkTask<IT> extends AbstractInvokable { LOG.warn("The execution config returned by the configuration was null"); executionConfig = new ExecutionConfig(); } - } catch (IOException e) { - throw new RuntimeException("Could not load ExecutionConfig from Job Configuration: " + e); - } catch (ClassNotFoundException e) { + } catch (IOException | ClassNotFoundException e) { throw new RuntimeException("Could not load ExecutionConfig from Job Configuration: " + e); } boolean objectReuseEnabled = executionConfig.isObjectReuseEnabled(); @@ -190,9 +181,7 @@ public class DataSinkTask<IT> extends AbstractInvokable { return; } - if (LOG.isDebugEnabled()) { - LOG.debug(getLogString("Starting to produce output")); - } + LOG.debug(getLogString("Starting to produce output")); // open format.open(this.getEnvironment().getIndexInSubtaskGroup(), this.getEnvironment().getNumberOfSubtasks()); @@ -268,18 +257,14 @@ public class DataSinkTask<IT> extends AbstractInvokable { } } - RegularPactTask.clearReaders(new MutableReader[]{inputReader}); + BatchTask.clearReaders(new MutableReader[]{inputReader}); } if (!this.taskCanceled) { - if (LOG.isDebugEnabled()) { - LOG.debug(getLogString("Finished data sink operator")); - } + LOG.debug(getLogString("Finished data sink operator")); } else { - if (LOG.isDebugEnabled()) { - LOG.debug(getLogString("Data sink operator cancelled")); - } + LOG.debug(getLogString("Data sink operator cancelled")); } } @@ -304,9 +289,7 @@ public class DataSinkTask<IT> extends AbstractInvokable { } } - if (LOG.isDebugEnabled()) { - LOG.debug(getLogString("Cancelling data sink operator")); - } + LOG.debug(getLogString("Cancelling data sink operator")); } /** @@ -402,7 +385,7 @@ public class DataSinkTask<IT> extends AbstractInvokable { * @return The string ready for logging. */ private String getLogString(String message) { - return RegularPactTask.constructLogString(message, this.getEnvironment().getTaskName(), this); + return BatchTask.constructLogString(message, this.getEnvironment().getTaskName(), this); } public DistributedRuntimeUDFContext createRuntimeContext() { http://git-wip-us.apache.org/repos/asf/flink/blob/b08669ab/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DataSourceTask.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DataSourceTask.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DataSourceTask.java index 5806e77..801e1a1 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DataSourceTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DataSourceTask.java @@ -80,9 +80,7 @@ public class DataSourceTask<OT> extends AbstractInvokable { public void registerInputOutput() { initInputFormat(); - if (LOG.isDebugEnabled()) { - LOG.debug(getLogString("Start registering input and output")); - } + LOG.debug(getLogString("Start registering input and output")); try { initOutputs(getUserCodeClassLoader()); @@ -91,24 +89,18 @@ public class DataSourceTask<OT> extends AbstractInvokable { ex.getMessage(), ex); } - if (LOG.isDebugEnabled()) { - LOG.debug(getLogString("Finished registering input and output")); - } + LOG.debug(getLogString("Finished registering input and output")); } @Override public void invoke() throws Exception { - if (LOG.isDebugEnabled()) { - LOG.debug(getLogString("Starting data source operator")); - } + LOG.debug(getLogString("Starting data source operator")); if(RichInputFormat.class.isAssignableFrom(this.format.getClass())){ ((RichInputFormat) this.format).setRuntimeContext(createRuntimeContext()); - if (LOG.isDebugEnabled()) { - LOG.debug(getLogString("Rich Source detected. Initializing runtime context.")); - } + LOG.debug(getLogString("Rich Source detected. Initializing runtime context.")); } ExecutionConfig executionConfig; @@ -123,23 +115,19 @@ public class DataSourceTask<OT> extends AbstractInvokable { LOG.warn("ExecutionConfig from job configuration is null. Creating empty config"); executionConfig = new ExecutionConfig(); } - } catch (IOException e) { - throw new RuntimeException("Could not load ExecutionConfig from Job Configuration: ", e); - } catch (ClassNotFoundException e) { + } catch (IOException | ClassNotFoundException e) { throw new RuntimeException("Could not load ExecutionConfig from Job Configuration: ", e); } boolean objectReuseEnabled = executionConfig.isObjectReuseEnabled(); - if (LOG.isDebugEnabled()) { - LOG.debug("DataSourceTask object reuse: " + (objectReuseEnabled ? "ENABLED" : "DISABLED") + "."); - } + LOG.debug("DataSourceTask object reuse: " + (objectReuseEnabled ? "ENABLED" : "DISABLED") + "."); final TypeSerializer<OT> serializer = this.serializerFactory.getSerializer(); try { // start all chained tasks - RegularPactTask.openChainedTasks(this.chainedTasks, this); + BatchTask.openChainedTasks(this.chainedTasks, this); // get input splits to read final Iterator<InputSplit> splitIterator = getInputSplits(); @@ -150,18 +138,14 @@ public class DataSourceTask<OT> extends AbstractInvokable { // get start and end final InputSplit split = splitIterator.next(); - if (LOG.isDebugEnabled()) { - LOG.debug(getLogString("Opening input split " + split.toString())); - } + LOG.debug(getLogString("Opening input split " + split.toString())); final InputFormat<OT, InputSplit> format = this.format; // open input format format.open(split); - if (LOG.isDebugEnabled()) { - LOG.debug(getLogString("Starting to read input from split " + split.toString())); - } + LOG.debug(getLogString("Starting to read input from split " + split.toString())); try { final Collector<OT> output = this.output; @@ -201,7 +185,7 @@ public class DataSourceTask<OT> extends AbstractInvokable { this.output.close(); // close all chained tasks letting them report failure - RegularPactTask.closeChainedTasks(this.chainedTasks, this); + BatchTask.closeChainedTasks(this.chainedTasks, this); } catch (Exception ex) { @@ -210,7 +194,7 @@ public class DataSourceTask<OT> extends AbstractInvokable { this.format.close(); } catch (Throwable ignored) {} - RegularPactTask.cancelChainedTasks(this.chainedTasks); + BatchTask.cancelChainedTasks(this.chainedTasks); ex = ExceptionInChainedStubException.exceptionUnwrap(ex); @@ -220,30 +204,24 @@ public class DataSourceTask<OT> extends AbstractInvokable { } else if (!this.taskCanceled) { // drop exception, if the task was canceled - RegularPactTask.logAndThrowException(ex, this); + BatchTask.logAndThrowException(ex, this); } } finally { - RegularPactTask.clearWriters(eventualOutputs); + BatchTask.clearWriters(eventualOutputs); } if (!this.taskCanceled) { - if (LOG.isDebugEnabled()) { - LOG.debug(getLogString("Finished data source operator")); - } + LOG.debug(getLogString("Finished data source operator")); } else { - if (LOG.isDebugEnabled()) { - LOG.debug(getLogString("Data source operator cancelled")); - } + LOG.debug(getLogString("Data source operator cancelled")); } } @Override public void cancel() throws Exception { this.taskCanceled = true; - if (LOG.isDebugEnabled()) { - LOG.debug(getLogString("Cancelling data source operator")); - } + LOG.debug(getLogString("Cancelling data source operator")); } /** @@ -303,7 +281,7 @@ public class DataSourceTask<OT> extends AbstractInvokable { final AccumulatorRegistry accumulatorRegistry = getEnvironment().getAccumulatorRegistry(); final AccumulatorRegistry.Reporter reporter = accumulatorRegistry.getReadWriteReporter(); - this.output = RegularPactTask.initOutputs(this, cl, this.config, this.chainedTasks, this.eventualOutputs, + this.output = BatchTask.initOutputs(this, cl, this.config, this.chainedTasks, this.eventualOutputs, getExecutionConfig(), reporter, getEnvironment().getAccumulatorRegistry().getUserMap()); } @@ -331,7 +309,7 @@ public class DataSourceTask<OT> extends AbstractInvokable { * @return The string ready for logging. */ private String getLogString(String message, String taskName) { - return RegularPactTask.constructLogString(message, taskName, this); + return BatchTask.constructLogString(message, taskName, this); } private Iterator<InputSplit> getInputSplits() { http://git-wip-us.apache.org/repos/asf/flink/blob/b08669ab/flink-runtime/src/main/java/org/apache/flink/runtime/operators/Driver.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/Driver.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/Driver.java new file mode 100644 index 0000000..8e8795e --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/Driver.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.operators; + +import org.apache.flink.api.common.functions.Function; + +/** + * The interface to be implemented by all drivers that run alone (or as the primary driver) in a task. + * A driver implements the actual code to perform a batch operation, like <i>map()</i>, + * <i>reduce()</i>, <i>join()</i>, or <i>coGroup()</i>. + * + * @see TaskContext + * + * @param <S> The type of stub driven by this driver. + * @param <OT> The data type of the records produced by this driver. + */ +public interface Driver<S extends Function, OT> { + + void setup(TaskContext<S, OT> context); + + /** + * Gets the number of inputs that the task has. + * + * @return The number of inputs. + */ + int getNumberOfInputs(); + + /** + * Gets the number of comparators required for this driver. + * + * @return The number of comparators required for this driver. + */ + int getNumberOfDriverComparators(); + + /** + * Gets the class of the stub type that is run by this task. For example, a <tt>MapTask</tt> should return + * <code>MapFunction.class</code>. + * + * @return The class of the stub type run by the task. + */ + Class<S> getStubType(); + + /** + * This method is called before the user code is opened. An exception thrown by this method + * signals failure of the task. + * + * @throws Exception Exceptions may be forwarded and signal task failure. + */ + void prepare() throws Exception; + + /** + * The main operation method of the task. It should call the user code with the data subsets until + * the input is depleted. + * + * @throws Exception Any exception thrown by this method signals task failure. Because exceptions in the user + * code typically signal situations where this instance in unable to proceed, exceptions + * from the user code should be forwarded. + */ + void run() throws Exception; + + /** + * This method is invoked in any case (clean termination and exception) at the end of the tasks operation. + * + * @throws Exception Exceptions may be forwarded. + */ + void cleanup() throws Exception; + + /** + * This method is invoked when the driver must aborted in mid processing. It is invoked asynchronously by a different thread. + * + * @throws Exception Exceptions may be forwarded. + */ + void cancel() throws Exception; +} http://git-wip-us.apache.org/repos/asf/flink/blob/b08669ab/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DriverStrategy.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DriverStrategy.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DriverStrategy.java index 0aceae2..bc7bee5 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DriverStrategy.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DriverStrategy.java @@ -115,7 +115,7 @@ public enum DriverStrategy { // -------------------------------------------------------------------------------------------- - private final Class<? extends PactDriver<?, ?>> driverClass; + private final Class<? extends Driver<?, ?>> driverClass; private final Class<? extends ChainedDriver<?, ?>> pushChainDriver; @@ -129,11 +129,11 @@ public enum DriverStrategy { @SuppressWarnings("unchecked") private DriverStrategy( - @SuppressWarnings("rawtypes") Class<? extends PactDriver> driverClass, + @SuppressWarnings("rawtypes") Class<? extends Driver> driverClass, @SuppressWarnings("rawtypes") Class<? extends ChainedDriver> pushChainDriverClass, DamBehavior dam, int numComparator) { - this.driverClass = (Class<? extends PactDriver<?, ?>>) driverClass; + this.driverClass = (Class<? extends Driver<?, ?>>) driverClass; this.pushChainDriver = (Class<? extends ChainedDriver<?, ?>>) pushChainDriverClass; this.numInputs = 1; this.dam1 = dam; @@ -143,11 +143,11 @@ public enum DriverStrategy { @SuppressWarnings("unchecked") private DriverStrategy( - @SuppressWarnings("rawtypes") Class<? extends PactDriver> driverClass, + @SuppressWarnings("rawtypes") Class<? extends Driver> driverClass, @SuppressWarnings("rawtypes") Class<? extends ChainedDriver> pushChainDriverClass, DamBehavior firstDam, DamBehavior secondDam, int numComparator) { - this.driverClass = (Class<? extends PactDriver<?, ?>>) driverClass; + this.driverClass = (Class<? extends Driver<?, ?>>) driverClass; this.pushChainDriver = (Class<? extends ChainedDriver<?, ?>>) pushChainDriverClass; this.numInputs = 2; this.dam1 = firstDam; @@ -157,7 +157,7 @@ public enum DriverStrategy { // -------------------------------------------------------------------------------------------- - public Class<? extends PactDriver<?, ?>> getDriverClass() { + public Class<? extends Driver<?, ?>> getDriverClass() { return this.driverClass; } http://git-wip-us.apache.org/repos/asf/flink/blob/b08669ab/flink-runtime/src/main/java/org/apache/flink/runtime/operators/FlatMapDriver.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/FlatMapDriver.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/FlatMapDriver.java index fd245bd..c29923b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/FlatMapDriver.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/FlatMapDriver.java @@ -39,18 +39,18 @@ import org.slf4j.LoggerFactory; * @param <IT> The mapper's input data type. * @param <OT> The mapper's output data type. */ -public class FlatMapDriver<IT, OT> implements PactDriver<FlatMapFunction<IT, OT>, OT> { +public class FlatMapDriver<IT, OT> implements Driver<FlatMapFunction<IT, OT>, OT> { private static final Logger LOG = LoggerFactory.getLogger(FlatMapDriver.class); - private PactTaskContext<FlatMapFunction<IT, OT>, OT> taskContext; + private TaskContext<FlatMapFunction<IT, OT>, OT> taskContext; private volatile boolean running; private boolean objectReuseEnabled = false; @Override - public void setup(PactTaskContext<FlatMapFunction<IT, OT>, OT> context) { + public void setup(TaskContext<FlatMapFunction<IT, OT>, OT> context) { this.taskContext = context; this.running = true; }