seojangho closed pull request #31: [NEMO-97] Refactor TaskExecutor and fix a sideinput bug URL: https://github.com/apache/incubator-nemo/pull/31
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: diff --git a/common/src/main/java/edu/snu/nemo/common/ContextImpl.java b/common/src/main/java/edu/snu/nemo/common/ContextImpl.java index 2e85afad..bb5aa21b 100644 --- a/common/src/main/java/edu/snu/nemo/common/ContextImpl.java +++ b/common/src/main/java/edu/snu/nemo/common/ContextImpl.java @@ -23,18 +23,18 @@ * Transform Context Implementation. */ public final class ContextImpl implements Transform.Context { - private final Map<Transform, Object> sideInputs; + private final Map sideInputs; /** * Constructor of Context Implementation. * @param sideInputs side inputs. */ - public ContextImpl(final Map<Transform, Object> sideInputs) { + public ContextImpl(final Map sideInputs) { this.sideInputs = sideInputs; } @Override - public Map<Transform, Object> getSideInputs() { + public Map getSideInputs() { return this.sideInputs; } } diff --git a/common/src/main/java/edu/snu/nemo/common/ir/Readable.java b/common/src/main/java/edu/snu/nemo/common/ir/Readable.java index 9bca623a..8f856a46 100644 --- a/common/src/main/java/edu/snu/nemo/common/ir/Readable.java +++ b/common/src/main/java/edu/snu/nemo/common/ir/Readable.java @@ -15,6 +15,7 @@ */ package edu.snu.nemo.common.ir; +import java.io.IOException; import java.io.Serializable; import java.util.List; @@ -27,9 +28,9 @@ * Method to read data from the source. * * @return an {@link Iterable} of the data read by the readable. - * @throws Exception exception while reading data. + * @throws IOException exception while reading data. */ - Iterable<O> read() throws Exception; + Iterable<O> read() throws IOException; /** * Returns the list of locations where this readable resides. diff --git a/common/src/main/java/edu/snu/nemo/common/ir/vertex/transform/Transform.java b/common/src/main/java/edu/snu/nemo/common/ir/vertex/transform/Transform.java index 0f7b9d51..95fa539f 100644 --- a/common/src/main/java/edu/snu/nemo/common/ir/vertex/transform/Transform.java +++ b/common/src/main/java/edu/snu/nemo/common/ir/vertex/transform/Transform.java @@ -45,6 +45,13 @@ */ void close(); + /** + * @return tag + */ + default Object getTag() { + return null; + } + /** * Context of the transform. */ @@ -52,6 +59,6 @@ /** * @return sideInputs. 
*/ - Map<Transform, Object> getSideInputs(); + Map getSideInputs(); } } diff --git a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/source/BeamBoundedSourceVertex.java b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/source/BeamBoundedSourceVertex.java index 1143a1a8..e309d3f6 100644 --- a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/source/BeamBoundedSourceVertex.java +++ b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/source/BeamBoundedSourceVertex.java @@ -17,6 +17,7 @@ import edu.snu.nemo.common.ir.Readable; +import java.io.IOException; import java.lang.reflect.Field; import java.util.ArrayList; import java.util.Arrays; @@ -96,7 +97,7 @@ public String propertiesToJSON() { } @Override - public Iterable<T> read() throws Exception { + public Iterable<T> read() throws IOException { final ArrayList<T> elements = new ArrayList<>(); try (BoundedSource.BoundedReader<T> reader = boundedSource.createReader(null)) { for (boolean available = reader.start(); available; available = reader.advance()) { diff --git a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/transform/CreateViewTransform.java b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/transform/CreateViewTransform.java index b342595b..dbfa004f 100644 --- a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/transform/CreateViewTransform.java +++ b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/transform/CreateViewTransform.java @@ -60,6 +60,7 @@ public void onData(final I element) { * get the Tag of the Transform. * @return the PCollectionView of the transform. */ + @Override public PCollectionView getTag() { return this.pCollectionView; } diff --git a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/transform/DoTransform.java b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/transform/DoTransform.java index 883d401e..b5d9690e 100644 --- a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/transform/DoTransform.java +++ b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/transform/DoTransform.java @@ -33,7 +33,6 @@ import org.joda.time.Instant; import java.io.IOException; -import java.util.HashMap; import java.util.Map; /** @@ -46,7 +45,6 @@ private final DoFn doFn; private final ObjectMapper mapper; private final String serializedOptions; - private Map<PCollectionView, Object> sideInputs; private OutputCollector<O> outputCollector; private StartBundleContext startBundleContext; private FinishBundleContext finishBundleContext; @@ -72,11 +70,9 @@ public DoTransform(final DoFn doFn, final PipelineOptions options) { @Override public void prepare(final Context context, final OutputCollector<O> oc) { this.outputCollector = oc; - this.sideInputs = new HashMap<>(); - context.getSideInputs().forEach((k, v) -> this.sideInputs.put(((CreateViewTransform) k).getTag(), v)); this.startBundleContext = new StartBundleContext(doFn, serializedOptions); this.finishBundleContext = new FinishBundleContext(doFn, outputCollector, serializedOptions); - this.processContext = new ProcessContext(doFn, outputCollector, sideInputs, serializedOptions); + this.processContext = new ProcessContext(doFn, outputCollector, context.getSideInputs(), serializedOptions); this.invoker = DoFnInvokers.invokerFor(doFn); invoker.invokeSetup(); invoker.invokeStartBundle(startBundleContext); 
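[The hunks above are the heart of the side-input fix: Transform gains a default getTag() that returns null, CreateViewTransform overrides it to return its PCollectionView, and DoTransform stops rebuilding the map through the unchecked ((CreateViewTransform) k).getTag() cast, forwarding context.getSideInputs() as-is instead. A minimal sketch of how an executor could key the side-input map by tag before constructing the Context; buildSideInputMap and collected are illustrative names, not part of this diff:

  // Sketch only (assumed helper, not in this PR): key each side input by the
  // source transform's tag so DoTransform can consume the map without casting.
  // Requires java.util.HashMap and java.util.Map.
  private static Map buildSideInputMap(final Map<Transform, Object> collected) {
    final Map sideInputMap = new HashMap();
    collected.forEach((srcTransform, sideInput) -> {
      // PCollectionView for CreateViewTransform, null for other transforms.
      final Object tag = srcTransform.getTag();
      sideInputMap.put(tag != null ? tag : srcTransform, sideInput);
    });
    return sideInputMap;
  }

With the map keyed by PCollectionView, ProcessContext can serve sideInput(view) lookups directly, which is why the raw Map type now suffices in Transform.Context.]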
@@ -195,7 +191,7 @@ public void output(final O output, final Instant instant, final BoundedWindow bo implements DoFnInvoker.ArgumentProvider<I, O> { private I input; private final OutputCollector<O> outputCollector; - private final Map<PCollectionView, Object> sideInputs; + private final Map sideInputs; private final ObjectMapper mapper; private final PipelineOptions options; @@ -209,7 +205,7 @@ public void output(final O output, final Instant instant, final BoundedWindow bo */ ProcessContext(final DoFn<I, O> fn, final OutputCollector<O> outputCollector, - final Map<PCollectionView, Object> sideInputs, + final Map sideInputs, final String serializedOptions) { fn.super(); this.outputCollector = outputCollector; diff --git a/compiler/frontend/spark/src/main/java/edu/snu/nemo/compiler/frontend/spark/source/SparkDatasetBoundedSourceVertex.java b/compiler/frontend/spark/src/main/java/edu/snu/nemo/compiler/frontend/spark/source/SparkDatasetBoundedSourceVertex.java index 0746be5b..3b058071 100644 --- a/compiler/frontend/spark/src/main/java/edu/snu/nemo/compiler/frontend/spark/source/SparkDatasetBoundedSourceVertex.java +++ b/compiler/frontend/spark/src/main/java/edu/snu/nemo/compiler/frontend/spark/source/SparkDatasetBoundedSourceVertex.java @@ -23,6 +23,8 @@ import org.apache.spark.rdd.RDD; import scala.collection.JavaConverters; +import javax.naming.OperationNotSupportedException; +import java.io.IOException; import java.util.*; /** @@ -105,12 +107,18 @@ private SparkDatasetBoundedSourceReadable(final Partition partition, } @Override - public Iterable<T> read() throws Exception { + public Iterable<T> read() throws IOException { // for setting up the same environment in the executors. final SparkSession spark = SparkSession.builder() .config(sessionInitialConf) .getOrCreate(); - final Dataset<T> dataset = SparkSession.initializeDataset(spark, commands); + final Dataset<T> dataset; + + try { + dataset = SparkSession.initializeDataset(spark, commands); + } catch (final OperationNotSupportedException e) { + throw new IllegalStateException(e); + } // Spark does lazy evaluation: it doesn't load the full dataset, but only the partition it is asked for. final RDD<T> rdd = dataset.sparkRDD(); diff --git a/compiler/frontend/spark/src/main/java/edu/snu/nemo/compiler/frontend/spark/source/SparkTextFileBoundedSourceVertex.java b/compiler/frontend/spark/src/main/java/edu/snu/nemo/compiler/frontend/spark/source/SparkTextFileBoundedSourceVertex.java index 5fab7944..9b2fd380 100644 --- a/compiler/frontend/spark/src/main/java/edu/snu/nemo/compiler/frontend/spark/source/SparkTextFileBoundedSourceVertex.java +++ b/compiler/frontend/spark/src/main/java/edu/snu/nemo/compiler/frontend/spark/source/SparkTextFileBoundedSourceVertex.java @@ -21,6 +21,7 @@ import org.apache.spark.rdd.RDD; import scala.collection.JavaConverters; +import java.io.IOException; import java.util.*; /** @@ -109,7 +110,7 @@ private SparkTextFileBoundedSourceReadable(final Partition partition, } @Override - public Iterable<String> read() throws Exception { + public Iterable<String> read() throws IOException { // for setting up the same environment in the executors. 
final SparkContext sparkContext = SparkContext.getOrCreate(sparkConf); diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/message/ncs/NcsMessageContext.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/message/ncs/NcsMessageContext.java index 0a71d714..ea478b7c 100644 --- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/message/ncs/NcsMessageContext.java +++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/message/ncs/NcsMessageContext.java @@ -49,7 +49,6 @@ public String getSenderId() { @Override @SuppressWarnings("squid:S2095") public <U> void reply(final U replyMessage) { - LOG.debug("[REPLY]: {}", replyMessage); final Connection connection = connectionFactory.newConnection(idFactory.getNewInstance(senderId)); try { connection.open(); diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/message/ncs/NcsMessageEnvironment.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/message/ncs/NcsMessageEnvironment.java index 5dfcd6c0..296a19e5 100644 --- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/message/ncs/NcsMessageEnvironment.java +++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/message/ncs/NcsMessageEnvironment.java @@ -36,15 +36,11 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.Future; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * Message environment for NCS. */ public final class NcsMessageEnvironment implements MessageEnvironment { - private static final Logger LOG = LoggerFactory.getLogger(NcsMessageEnvironment.class.getName()); - private static final String NCS_CONN_FACTORY_ID = "NCS_CONN_FACTORY_ID"; private final NetworkConnectionService networkConnectionService; @@ -124,7 +120,6 @@ public void close() throws Exception { public void onNext(final Message<ControlMessage.Message> messages) { final ControlMessage.Message controlMessage = extractSingleMessage(messages); - LOG.debug("[RECEIVED]: msg={}", controlMessage); final MessageType messageType = getMsgType(controlMessage); switch (messageType) { case Send: diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/message/ncs/NcsMessageSender.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/message/ncs/NcsMessageSender.java index 5d1c61a5..517fe549 100644 --- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/message/ncs/NcsMessageSender.java +++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/message/ncs/NcsMessageSender.java @@ -42,15 +42,11 @@ @Override public void send(final ControlMessage.Message message) { - LOG.debug("[SEND]: msg.id={}, msg.listenerId={}", - message.getId(), message.getListenerId()); connection.write(message); } @Override public CompletableFuture<ControlMessage.Message> request(final ControlMessage.Message message) { - LOG.debug("[REQUEST]: msg.id={}, msg.listenerId={}", - message.getId(), message.getListenerId()); final CompletableFuture<ControlMessage.Message> future = replyFutureMap.beforeRequest(message.getId()); connection.write(message); return future; diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/Task.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/Task.java index 5cd8b165..c5a1c3d3 100644 --- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/Task.java +++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/Task.java @@ -16,7 +16,6 @@ package edu.snu.nemo.runtime.common.plan; import 
edu.snu.nemo.common.ir.Readable; -import edu.snu.nemo.runtime.common.RuntimeIdGenerator; import java.io.Serializable; import java.util.List; @@ -28,7 +27,6 @@ public final class Task implements Serializable { private final String jobId; private final String taskId; - private final int taskIdx; private final List<StageEdge> taskIncomingEdges; private final List<StageEdge> taskOutgoingEdges; private final int attemptIdx; @@ -58,7 +56,6 @@ public Task(final String jobId, final Map<String, Readable> irVertexIdToReadable) { this.jobId = jobId; this.taskId = taskId; - this.taskIdx = RuntimeIdGenerator.getIndexFromTaskId(taskId); this.attemptIdx = attemptIdx; this.containerType = containerType; this.serializedIRDag = serializedIRDag; @@ -88,13 +85,6 @@ public String getTaskId() { return taskId; } - /** - * @return the idx of the task. - */ - public int getTaskIdx() { - return taskIdx; - } - /** * @return the incoming edges of the task. */ diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/Executor.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/Executor.java index 287ca010..8b2925d5 100644 --- a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/Executor.java +++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/Executor.java @@ -31,6 +31,7 @@ import edu.snu.nemo.runtime.common.plan.Task; import edu.snu.nemo.runtime.executor.data.SerializerManager; import edu.snu.nemo.runtime.executor.datatransfer.DataTransferFactory; +import edu.snu.nemo.runtime.executor.task.TaskExecutor; import org.apache.commons.lang3.SerializationUtils; import org.apache.reef.tang.annotations.Parameter; diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/TaskExecutor.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/TaskExecutor.java deleted file mode 100644 index 4a9b1eeb..00000000 --- a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/TaskExecutor.java +++ /dev/null @@ -1,757 +0,0 @@ -/* - * Copyright (C) 2017 Seoul National University - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package edu.snu.nemo.runtime.executor; - -import edu.snu.nemo.common.ContextImpl; -import edu.snu.nemo.common.Pair; -import edu.snu.nemo.common.dag.DAG; -import edu.snu.nemo.common.exception.BlockFetchException; -import edu.snu.nemo.common.exception.BlockWriteException; -import edu.snu.nemo.common.ir.Readable; -import edu.snu.nemo.common.ir.vertex.*; -import edu.snu.nemo.common.ir.vertex.transform.Transform; -import edu.snu.nemo.runtime.common.plan.Task; -import edu.snu.nemo.runtime.common.plan.StageEdge; -import edu.snu.nemo.runtime.common.plan.RuntimeEdge; -import edu.snu.nemo.runtime.common.state.TaskState; -import edu.snu.nemo.runtime.executor.data.DataUtil; -import edu.snu.nemo.runtime.executor.datatransfer.*; - -import java.util.*; -import java.util.concurrent.*; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.stream.Collectors; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Executes a task. - */ -public final class TaskExecutor { - // Static variables - private static final Logger LOG = LoggerFactory.getLogger(TaskExecutor.class.getName()); - private static final String ITERATORID_PREFIX = "ITERATOR_"; - private static final AtomicInteger ITERATORID_GENERATOR = new AtomicInteger(0); - - // From Task - private final DAG<IRVertex, RuntimeEdge<IRVertex>> irVertexDag; - private final String taskId; - private final int taskIdx; - private final TaskStateManager taskStateManager; - private final List<StageEdge> stageIncomingEdges; - private final List<StageEdge> stageOutgoingEdges; - private Map<String, Readable> irVertexIdToReadable; - - // Other parameters - private final DataTransferFactory channelFactory; - private final MetricCollector metricCollector; - - // Data structures - private final Map<InputReader, List<IRVertexDataHandler>> inputReaderToDataHandlersMap; - private final Map<String, Iterator> idToSrcIteratorMap; - private final Map<String, List<IRVertexDataHandler>> srcIteratorIdToDataHandlersMap; - private final Map<String, List<IRVertexDataHandler>> iteratorIdToDataHandlersMap; - private final LinkedBlockingQueue<Pair<String, DataUtil.IteratorWithNumBytes>> partitionQueue; - private List<IRVertexDataHandler> irVertexDataHandlers; - private Map<OutputCollectorImpl, List<IRVertexDataHandler>> outputToChildrenDataHandlersMap; - private final Set<String> finishedVertexIds; - - // For metrics - private long serBlockSize; - private long encodedBlockSize; - - // Misc - private boolean isExecuted; - private String irVertexIdPutOnHold; - private int numPartitions; - - - /** - * Constructor. - * @param task Task with information needed during execution. - * @param irVertexDag A DAG of vertices. - * @param taskStateManager State manager for this Task. - * @param channelFactory For reading from/writing to data to other Stages. - * @param metricMessageSender For sending metric with execution stats to Master. - */ - public TaskExecutor(final Task task, - final DAG<IRVertex, RuntimeEdge<IRVertex>> irVertexDag, - final TaskStateManager taskStateManager, - final DataTransferFactory channelFactory, - final MetricMessageSender metricMessageSender) { - // Information from the Task. - this.irVertexDag = irVertexDag; - this.taskId = task.getTaskId(); - this.taskIdx = task.getTaskIdx(); - this.stageIncomingEdges = task.getTaskIncomingEdges(); - this.stageOutgoingEdges = task.getTaskOutgoingEdges(); - this.irVertexIdToReadable = task.getIrVertexIdToReadable(); - - // Other parameters. 
- this.taskStateManager = taskStateManager; - this.channelFactory = channelFactory; - this.metricCollector = new MetricCollector(metricMessageSender); - - // Initialize data structures. - this.inputReaderToDataHandlersMap = new ConcurrentHashMap<>(); - this.idToSrcIteratorMap = new HashMap<>(); - this.srcIteratorIdToDataHandlersMap = new HashMap<>(); - this.iteratorIdToDataHandlersMap = new ConcurrentHashMap<>(); - this.partitionQueue = new LinkedBlockingQueue<>(); - this.outputToChildrenDataHandlersMap = new HashMap<>(); - this.irVertexDataHandlers = new ArrayList<>(); - this.finishedVertexIds = new HashSet<>(); - - // Metrics - this.serBlockSize = 0; - this.encodedBlockSize = 0; - - // Misc - this.isExecuted = false; - this.irVertexIdPutOnHold = null; - this.numPartitions = 0; - - initialize(); - } - - /** - * Initializes this Task before execution. - * 1) Create and connect reader/writers for both inter-Task data and intra-Task data. - * 2) Prepares Transforms if needed. - */ - private void initialize() { - // Initialize data handlers for each IRVertex. - irVertexDag.topologicalDo(irVertex -> irVertexDataHandlers.add(new IRVertexDataHandler(irVertex))); - - // Initialize data transfer. - // Construct a pointer-based DAG of irVertexDataHandlers that are used for data transfer. - // 'Pointer-based' means that it isn't Map/List-based in getting the data structure or parent/children - // to avoid element-wise extra overhead of calculating hash values(HashMap) or iterating Lists. - irVertexDag.topologicalDo(irVertex -> { - final Set<StageEdge> inEdgesFromOtherStages = getInEdgesFromOtherStages(irVertex); - final Set<StageEdge> outEdgesToOtherStages = getOutEdgesToOtherStages(irVertex); - final IRVertexDataHandler dataHandler = getIRVertexDataHandler(irVertex); - - // Set data handlers of children irVertices. - // This forms a pointer-based DAG of irVertexDataHandlers. - final List<IRVertexDataHandler> childrenDataHandlers = new ArrayList<>(); - irVertexDag.getChildren(irVertex.getId()).forEach(child -> - childrenDataHandlers.add(getIRVertexDataHandler(child))); - dataHandler.setChildrenDataHandler(childrenDataHandlers); - - // Add InputReaders for inter-stage data transfer - inEdgesFromOtherStages.forEach(stageEdge -> { - final InputReader inputReader = channelFactory.createReader( - taskIdx, stageEdge.getSrcVertex(), stageEdge); - - // For InputReaders that have side input, collect them separately. - if (inputReader.isSideInputReader()) { - dataHandler.addSideInputFromOtherStages(inputReader); - } else { - inputReaderToDataHandlersMap.putIfAbsent(inputReader, new ArrayList<>()); - inputReaderToDataHandlersMap.get(inputReader).add(dataHandler); - } - }); - - // Add OutputWriters for inter-stage data transfer - outEdgesToOtherStages.forEach(stageEdge -> { - final OutputWriter outputWriter = channelFactory.createWriter( - irVertex, taskIdx, stageEdge.getDstVertex(), stageEdge); - dataHandler.addOutputWriter(outputWriter); - }); - - // Add InputPipes for intra-stage data transfer - addInputFromThisStage(irVertex, dataHandler); - - // Add OutputPipe for intra-stage data transfer - setOutputCollector(irVertex, dataHandler); - }); - - // Prepare Transforms if needed. - irVertexDag.topologicalDo(irVertex -> { - if (irVertex instanceof OperatorVertex) { - final Transform transform = ((OperatorVertex) irVertex).getTransform(); - final Map<Transform, Object> sideInputMap = new HashMap<>(); - final IRVertexDataHandler dataHandler = getIRVertexDataHandler(irVertex); - // Check and collect side inputs. 
- if (!dataHandler.getSideInputFromOtherStages().isEmpty()) { - sideInputFromOtherStages(irVertex, sideInputMap); - } - if (!dataHandler.getSideInputFromThisStage().isEmpty()) { - sideInputFromThisStage(irVertex, sideInputMap); - } - - final Transform.Context transformContext = new ContextImpl(sideInputMap); - final OutputCollectorImpl outputCollector = dataHandler.getOutputCollector(); - transform.prepare(transformContext, outputCollector); - } - }); - } - - /** - * Collect all inter-stage incoming edges of this vertex. - * - * @param irVertex the IRVertex whose inter-stage incoming edges to be collected. - * @return the collected incoming edges. - */ - private Set<StageEdge> getInEdgesFromOtherStages(final IRVertex irVertex) { - return stageIncomingEdges.stream().filter( - stageInEdge -> stageInEdge.getDstVertex().getId().equals(irVertex.getId())) - .collect(Collectors.toSet()); - } - - /** - * Collect all inter-stage outgoing edges of this vertex. - * - * @param irVertex the IRVertex whose inter-stage outgoing edges to be collected. - * @return the collected outgoing edges. - */ - private Set<StageEdge> getOutEdgesToOtherStages(final IRVertex irVertex) { - return stageOutgoingEdges.stream().filter( - stageInEdge -> stageInEdge.getSrcVertex().getId().equals(irVertex.getId())) - .collect(Collectors.toSet()); - } - - /** - * Add input OutputCollectors to each {@link IRVertex}. - * Input OutputCollector denotes all the OutputCollectors of intra-Stage dependencies. - * - * @param irVertex the IRVertex to add input OutputCollectors to. - */ - private void addInputFromThisStage(final IRVertex irVertex, final IRVertexDataHandler dataHandler) { - List<IRVertex> parentVertices = irVertexDag.getParents(irVertex.getId()); - if (parentVertices != null) { - parentVertices.forEach(parent -> { - final OutputCollectorImpl parentOutputCollector = getIRVertexDataHandler(parent).getOutputCollector(); - if (parentOutputCollector.hasSideInputFor(irVertex.getId())) { - dataHandler.addSideInputFromThisStage(parentOutputCollector); - } else { - dataHandler.addInputFromThisStages(parentOutputCollector); - } - }); - } - } - - /** - * Add outputCollectors to each {@link IRVertex}. - * @param irVertex the IRVertex to add output outputCollectors to. - */ - private void setOutputCollector(final IRVertex irVertex, final IRVertexDataHandler dataHandler) { - final OutputCollectorImpl outputCollector = new OutputCollectorImpl(); - irVertexDag.getOutgoingEdgesOf(irVertex).forEach(outEdge -> { - if (outEdge.isSideInput()) { - outputCollector.setSideInputRuntimeEdge(outEdge); - outputCollector.setAsSideInputFor(irVertex.getId()); - } - }); - - dataHandler.setOutputCollector(outputCollector); - } - - /** - * Check that this irVertex has OutputWriter for inter-stage data. - * - * @param irVertex the irVertex to check whether it has OutputWriters. - * @return true if the irVertex has OutputWriters. - */ - private boolean hasOutputWriter(final IRVertex irVertex) { - return !getIRVertexDataHandler(irVertex).getOutputWriters().isEmpty(); - } - - private void setIRVertexPutOnHold(final MetricCollectionBarrierVertex irVertex) { - irVertexIdPutOnHold = irVertex.getId(); - } - - /** - * Finalize the output write of this Task. - * As element-wise output write is done and the block is in memory, - * flush the block into the designated data store and commit it. - * - * @param irVertex the IRVertex with OutputWriter to flush and commit output block. 
- */ - private void writeAndCloseOutputWriters(final IRVertex irVertex) { - final List<Long> writtenBytesList = new ArrayList<>(); - final Map<String, Object> metric = new HashMap<>(); - metricCollector.beginMeasurement(irVertex.getId(), metric); - final long writeStartTime = System.currentTimeMillis(); - - getIRVertexDataHandler(irVertex).getOutputWriters().forEach(outputWriter -> { - outputWriter.close(); - final Optional<Long> writtenBytes = outputWriter.getWrittenBytes(); - writtenBytes.ifPresent(writtenBytesList::add); - }); - - final long writeEndTime = System.currentTimeMillis(); - metric.put("OutputWriteTime(ms)", writeEndTime - writeStartTime); - putWrittenBytesMetric(writtenBytesList, metric); - metricCollector.endMeasurement(irVertex.getId(), metric); - } - - /** - * Get input iterator from BoundedSource and bind it with id. - */ - private void prepareInputFromSource() { - irVertexDag.topologicalDo(irVertex -> { - if (irVertex instanceof SourceVertex) { - try { - final String iteratorId = generateIteratorId(); - final Readable readable = irVertexIdToReadable.get(irVertex.getId()); - if (readable == null) { - throw new RuntimeException(irVertex.toString()); - } - final Iterator iterator = readable.read().iterator(); - idToSrcIteratorMap.putIfAbsent(iteratorId, iterator); - srcIteratorIdToDataHandlersMap.putIfAbsent(iteratorId, new ArrayList<>()); - srcIteratorIdToDataHandlersMap.get(iteratorId).add(getIRVertexDataHandler(irVertex)); - } catch (final BlockFetchException ex) { - taskStateManager.onTaskStateChanged(TaskState.State.FAILED_RECOVERABLE, - Optional.empty(), Optional.of(TaskState.RecoverableFailureCause.INPUT_READ_FAILURE)); - LOG.error("{} Execution Failed (Recoverable: input read failure)! Exception: {}", - taskId, ex.toString()); - } catch (final Exception e) { - taskStateManager.onTaskStateChanged(TaskState.State.FAILED_UNRECOVERABLE, - Optional.empty(), Optional.empty()); - LOG.error("{} Execution Failed! Exception: {}", taskId, e.toString()); - throw new RuntimeException(e); - } - } - }); - } - - /** - * Get input iterator from other stages received in the form of CompletableFuture - * and bind it with id. - */ - private void prepareInputFromOtherStages() { - inputReaderToDataHandlersMap.forEach((inputReader, dataHandlers) -> { - final List<CompletableFuture<DataUtil.IteratorWithNumBytes>> futures = inputReader.read(); - numPartitions += futures.size(); - - // Add consumers which will push iterator when the futures are complete. - futures.forEach(compFuture -> compFuture.whenComplete((iterator, exception) -> { - if (exception != null) { - throw new BlockFetchException(exception); - } - - final String iteratorId = generateIteratorId(); - if (iteratorIdToDataHandlersMap.containsKey(iteratorId)) { - throw new RuntimeException("iteratorIdToDataHandlersMap already contains " + iteratorId); - } else { - iteratorIdToDataHandlersMap.computeIfAbsent(iteratorId, absentIteratorId -> dataHandlers); - try { - partitionQueue.put(Pair.of(iteratorId, iterator)); - } catch (final InterruptedException e) { - Thread.currentThread().interrupt(); - throw new BlockFetchException(e); - } - } - })); - }); - } - - /** - * Check whether all vertices in this Task are finished. - * - * @return true if all vertices are finished. 
- */ - private boolean finishedAllVertices() { - // Total number of Tasks - int vertexNum = irVertexDataHandlers.size(); - int finishedVertexNum = finishedVertexIds.size(); - return finishedVertexNum == vertexNum; - } - - /** - * Initialize the very first map of OutputCollector-children irVertex DAG. - * In each map entry, the OutputCollector contains input data to be propagated through - * the children irVertex DAG. - */ - private void initializeOutputToChildrenDataHandlersMap() { - srcIteratorIdToDataHandlersMap.values().forEach(dataHandlers -> - dataHandlers.forEach(dataHandler -> { - outputToChildrenDataHandlersMap.putIfAbsent(dataHandler.getOutputCollector(), dataHandler.getChildren()); - })); - iteratorIdToDataHandlersMap.values().forEach(dataHandlers -> - dataHandlers.forEach(dataHandler -> { - outputToChildrenDataHandlersMap.putIfAbsent(dataHandler.getOutputCollector(), dataHandler.getChildren()); - })); - } - - /** - * Update the map of OutputCollector-children irVertex DAG. - */ - private void updateOutputToChildrenDataHandlersMap() { - Map<OutputCollectorImpl, List<IRVertexDataHandler>> currentMap = outputToChildrenDataHandlersMap; - Map<OutputCollectorImpl, List<IRVertexDataHandler>> updatedMap = new HashMap<>(); - - currentMap.values().forEach(dataHandlers -> - dataHandlers.forEach(dataHandler -> { - updatedMap.putIfAbsent(dataHandler.getOutputCollector(), dataHandler.getChildren()); - }) - ); - - outputToChildrenDataHandlersMap = updatedMap; - } - - /** - * Update the map of OutputCollector-children irVertex DAG. - * - * @param irVertex the IRVertex with the transform to close. - */ - private void closeTransform(final IRVertex irVertex) { - if (irVertex instanceof OperatorVertex) { - Transform transform = ((OperatorVertex) irVertex).getTransform(); - transform.close(); - } - } - - /** - * As a preprocessing of side input data, get inter stage side input - * and form a map of source transform-side input. - * - * @param irVertex the IRVertex which receives side input from other stages. - * @param sideInputMap the map of source transform-side input to build. - */ - private void sideInputFromOtherStages(final IRVertex irVertex, final Map<Transform, Object> sideInputMap) { - getIRVertexDataHandler(irVertex).getSideInputFromOtherStages().forEach(sideInputReader -> { - try { - final DataUtil.IteratorWithNumBytes sideInputIterator = sideInputReader.read().get(0).get(); - final Object sideInput = getSideInput(sideInputIterator); - final RuntimeEdge inEdge = sideInputReader.getRuntimeEdge(); - final Transform srcTransform; - if (inEdge instanceof StageEdge) { - srcTransform = ((OperatorVertex) ((StageEdge) inEdge).getSrcVertex()).getTransform(); - } else { - srcTransform = ((OperatorVertex) inEdge.getSrc()).getTransform(); - } - sideInputMap.put(srcTransform, sideInput); - - // Collect metrics on block size if possible. 
- try { - serBlockSize += sideInputIterator.getNumSerializedBytes(); - } catch (final DataUtil.IteratorWithNumBytes.NumBytesNotSupportedException e) { - serBlockSize = -1; - } - try { - encodedBlockSize += sideInputIterator.getNumEncodedBytes(); - } catch (final DataUtil.IteratorWithNumBytes.NumBytesNotSupportedException e) { - encodedBlockSize = -1; - } - } catch (final InterruptedException e) { - Thread.currentThread().interrupt(); - throw new BlockFetchException(e); - } catch (final ExecutionException e1) { - throw new RuntimeException("Failed while reading side input from other stages " + e1); - } - }); - } - - /** - * As a preprocessing of side input data, get intra stage side input - * and form a map of source transform-side input. - * Assumption: intra stage side input denotes a data element initially received - * via side input reader from other stages. - * - * @param irVertex the IRVertex which receives the data element marked as side input. - * @param sideInputMap the map of source transform-side input to build. - */ - private void sideInputFromThisStage(final IRVertex irVertex, final Map<Transform, Object> sideInputMap) { - getIRVertexDataHandler(irVertex).getSideInputFromThisStage().forEach(input -> { - // because sideInput is only 1 element in the outputCollector - Object sideInput = input.remove(); - final RuntimeEdge inEdge = input.getSideInputRuntimeEdge(); - final Transform srcTransform; - if (inEdge instanceof StageEdge) { - srcTransform = ((OperatorVertex) ((StageEdge) inEdge).getSrcVertex()).getTransform(); - } else { - srcTransform = ((OperatorVertex) inEdge.getSrc()).getTransform(); - } - sideInputMap.put(srcTransform, sideInput); - }); - } - - /** - * Executes the task. - */ - public void execute() { - final Map<String, Object> metric = new HashMap<>(); - metricCollector.beginMeasurement(taskId, metric); - long boundedSrcReadStartTime = 0; - long boundedSrcReadEndTime = 0; - long inputReadStartTime = 0; - long inputReadEndTime = 0; - if (isExecuted) { - throw new RuntimeException("Task {" + taskId + "} execution called again!"); - } - isExecuted = true; - taskStateManager.onTaskStateChanged(TaskState.State.EXECUTING, Optional.empty(), Optional.empty()); - LOG.info("{} Executing!", taskId); - - // Prepare input data from bounded source. - boundedSrcReadStartTime = System.currentTimeMillis(); - prepareInputFromSource(); - boundedSrcReadEndTime = System.currentTimeMillis(); - metric.put("BoundedSourceReadTime(ms)", boundedSrcReadEndTime - boundedSrcReadStartTime); - - // Prepare input data from other stages. - inputReadStartTime = System.currentTimeMillis(); - prepareInputFromOtherStages(); - - // Execute the IRVertex DAG. - try { - srcIteratorIdToDataHandlersMap.forEach((srcIteratorId, dataHandlers) -> { - Iterator iterator = idToSrcIteratorMap.get(srcIteratorId); - iterator.forEachRemaining(element -> { - for (final IRVertexDataHandler dataHandler : dataHandlers) { - runTask(dataHandler, element); - } - }); - }); - - // Process data from other stages. 
- for (int currPartition = 0; currPartition < numPartitions; currPartition++) { - Pair<String, DataUtil.IteratorWithNumBytes> idToIteratorPair = partitionQueue.take(); - final String iteratorId = idToIteratorPair.left(); - final DataUtil.IteratorWithNumBytes iterator = idToIteratorPair.right(); - List<IRVertexDataHandler> dataHandlers = iteratorIdToDataHandlersMap.get(iteratorId); - iterator.forEachRemaining(element -> { - for (final IRVertexDataHandler dataHandler : dataHandlers) { - runTask(dataHandler, element); - } - }); - - // Collect metrics on block size if possible. - try { - serBlockSize += iterator.getNumSerializedBytes(); - } catch (final DataUtil.IteratorWithNumBytes.NumBytesNotSupportedException e) { - serBlockSize = -1; - } catch (final IllegalStateException e) { - LOG.error("Failed to get the number of bytes of serialized data - the data is not ready yet ", e); - } - try { - encodedBlockSize += iterator.getNumEncodedBytes(); - } catch (final DataUtil.IteratorWithNumBytes.NumBytesNotSupportedException e) { - encodedBlockSize = -1; - } catch (final IllegalStateException e) { - LOG.error("Failed to get the number of bytes of encoded data - the data is not ready yet ", e); - } - } - inputReadEndTime = System.currentTimeMillis(); - metric.put("InputReadTime(ms)", inputReadEndTime - inputReadStartTime); - - // Process intra-Task data. - // Intra-Task data comes from outputCollectors of this Task's vertices. - initializeOutputToChildrenDataHandlersMap(); - while (!finishedAllVertices()) { - outputToChildrenDataHandlersMap.forEach((outputCollector, childrenDataHandlers) -> { - // Get the vertex that has this outputCollector as its output outputCollector - final IRVertex outputProducer = irVertexDataHandlers.stream() - .filter(dataHandler -> dataHandler.getOutputCollector() == outputCollector) - .findFirst().get().getIRVertex(); - - // Before consuming the output of outputProducer as input, - // close transform if it is OperatorTransform. - closeTransform(outputProducer); - - // Set outputProducer as finished. - finishedVertexIds.add(outputProducer.getId()); - - while (!outputCollector.isEmpty()) { - final Object element = outputCollector.remove(); - - // Pass outputProducer's output to its children tasks recursively. - if (!childrenDataHandlers.isEmpty()) { - for (final IRVertexDataHandler childDataHandler : childrenDataHandlers) { - runTask(childDataHandler, element); - } - } - - // Write element-wise to OutputWriters if any and close the OutputWriters. - if (hasOutputWriter(outputProducer)) { - // If outputCollector isn't empty(if closeTransform produced some output), - // write them element-wise to OutputWriters. - List<OutputWriter> outputWritersOfTask = - getIRVertexDataHandler(outputProducer).getOutputWriters(); - outputWritersOfTask.forEach(outputWriter -> outputWriter.write(element)); - } - } - - if (hasOutputWriter(outputProducer)) { - writeAndCloseOutputWriters(outputProducer); - } - }); - updateOutputToChildrenDataHandlersMap(); - } - } catch (final BlockWriteException ex2) { - taskStateManager.onTaskStateChanged(TaskState.State.FAILED_RECOVERABLE, - Optional.empty(), Optional.of(TaskState.RecoverableFailureCause.OUTPUT_WRITE_FAILURE)); - LOG.error("{} Execution Failed (Recoverable: output write failure)! Exception: {}", - taskId, ex2.toString()); - } catch (final Exception e) { - taskStateManager.onTaskStateChanged(TaskState.State.FAILED_UNRECOVERABLE, - Optional.empty(), Optional.empty()); - LOG.error("{} Execution Failed! 
Exception: {}", - taskId, e.toString()); - throw new RuntimeException(e); - } - - // Put Task-unit metrics. - final boolean available = serBlockSize >= 0; - putReadBytesMetric(available, serBlockSize, encodedBlockSize, metric); - metricCollector.endMeasurement(taskId, metric); - if (irVertexIdPutOnHold == null) { - taskStateManager.onTaskStateChanged(TaskState.State.COMPLETE, Optional.empty(), Optional.empty()); - } else { - taskStateManager.onTaskStateChanged(TaskState.State.ON_HOLD, - Optional.of(irVertexIdPutOnHold), - Optional.empty()); - } - LOG.info("{} Complete!", taskId); - } - - /** - * Recursively executes a vertex with the input data element. - * - * @param dataHandler IRVertexDataHandler of a vertex to execute. - * @param dataElement input data element to process. - */ - private void runTask(final IRVertexDataHandler dataHandler, final Object dataElement) { - final IRVertex irVertex = dataHandler.getIRVertex(); - final OutputCollectorImpl outputCollector = dataHandler.getOutputCollector(); - - // Process element-wise depending on the vertex type - if (irVertex instanceof SourceVertex) { - if (dataElement == null) { // null used for Beam VoidCoders - final List<Object> nullForVoidCoder = Collections.singletonList(dataElement); - outputCollector.emit(nullForVoidCoder); - } else { - outputCollector.emit(dataElement); - } - } else if (irVertex instanceof OperatorVertex) { - final Transform transform = ((OperatorVertex) irVertex).getTransform(); - transform.onData(dataElement); - } else if (irVertex instanceof MetricCollectionBarrierVertex) { - if (dataElement == null) { // null used for Beam VoidCoders - final List<Object> nullForVoidCoder = Collections.singletonList(dataElement); - outputCollector.emit(nullForVoidCoder); - } else { - outputCollector.emit(dataElement); - } - setIRVertexPutOnHold((MetricCollectionBarrierVertex) irVertex); - } else { - throw new UnsupportedOperationException("This type of IRVertex is not supported"); - } - - // For the produced output - while (!outputCollector.isEmpty()) { - final Object element = outputCollector.remove(); - - // Pass output to its children recursively. - List<IRVertexDataHandler> childrenDataHandlers = dataHandler.getChildren(); - if (!childrenDataHandlers.isEmpty()) { - for (final IRVertexDataHandler childDataHandler : childrenDataHandlers) { - runTask(childDataHandler, element); - } - } - - // Write element-wise to OutputWriters if any - if (hasOutputWriter(irVertex)) { - List<OutputWriter> outputWritersOfTask = dataHandler.getOutputWriters(); - outputWritersOfTask.forEach(outputWriter -> outputWriter.write(element)); - } - } - } - - /** - * Generate a unique iterator id. - * - * @return the iterator id. - */ - private String generateIteratorId() { - return ITERATORID_PREFIX + ITERATORID_GENERATOR.getAndIncrement(); - } - - private IRVertexDataHandler getIRVertexDataHandler(final IRVertex irVertex) { - return irVertexDataHandlers.stream() - .filter(dataHandler -> dataHandler.getIRVertex() == irVertex) - .findFirst().get(); - } - - /** - * Puts read bytes metric if the input data size is known. - * - * @param serializedBytes size in serialized (encoded and optionally post-processed (e.g. compressed)) form - * @param encodedBytes size in encoded form - * @param metricMap the metric map to put written bytes metric. 
- */ - private static void putReadBytesMetric(final boolean available, - final long serializedBytes, - final long encodedBytes, - final Map<String, Object> metricMap) { - if (available) { - if (serializedBytes != encodedBytes) { - metricMap.put("ReadBytes(raw)", serializedBytes); - } - metricMap.put("ReadBytes", encodedBytes); - } - } - - /** - * Puts written bytes metric if the output data size is known. - * - * @param writtenBytesList the list of written bytes. - * @param metricMap the metric map to put written bytes metric. - */ - private static void putWrittenBytesMetric(final List<Long> writtenBytesList, - final Map<String, Object> metricMap) { - if (!writtenBytesList.isEmpty()) { - long totalWrittenBytes = 0; - for (final Long writtenBytes : writtenBytesList) { - totalWrittenBytes += writtenBytes; - } - metricMap.put("WrittenBytes", totalWrittenBytes); - } - } - - /** - * Get sideInput from data from {@link InputReader}. - * - * @param iterator data from {@link InputReader#read()} - * @return The corresponding sideInput - */ - private static Object getSideInput(final DataUtil.IteratorWithNumBytes iterator) { - final List copy = new ArrayList(); - iterator.forEachRemaining(copy::add); - if (copy.size() == 1) { - return copy.get(0); - } else { - if (copy.get(0) instanceof Iterable) { - final List collect = new ArrayList(); - copy.forEach(element -> ((Iterable) element).iterator().forEachRemaining(collect::add)); - return collect; - } else if (copy.get(0) instanceof Map) { - final Map collect = new HashMap(); - copy.forEach(element -> { - final Set keySet = ((Map) element).keySet(); - keySet.forEach(key -> collect.put(key, ((Map) element).get(key))); - }); - return collect; - } else { - return copy; - } - } - } -} diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/IRVertexDataHandler.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/IRVertexDataHandler.java deleted file mode 100644 index 84b7a8e6..00000000 --- a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/IRVertexDataHandler.java +++ /dev/null @@ -1,162 +0,0 @@ -/* - * Copyright (C) 2018 Seoul National University - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package edu.snu.nemo.runtime.executor.datatransfer; - -import edu.snu.nemo.common.ir.vertex.IRVertex; - -import java.util.ArrayList; -import java.util.List; - -/** - * Per-Task data handler. - * This is a wrapper class that handles data transfer of a Task. - * As Task input is processed element-wise, Task output element percolates down - * through the DAG of children TaskDataHandlers. 
- */ -public final class IRVertexDataHandler { - private final IRVertex irVertex; - private List<IRVertexDataHandler> children; - private final List<OutputCollectorImpl> inputFromThisStage; - private final List<InputReader> sideInputFromOtherStages; - private final List<OutputCollectorImpl> sideInputFromThisStage; - private OutputCollectorImpl outputCollector; - private final List<OutputWriter> outputWriters; - - /** - * IRVertexDataHandler Constructor. - * - * @param irVertex Task of this IRVertexDataHandler. - */ - public IRVertexDataHandler(final IRVertex irVertex) { - this.irVertex = irVertex; - this.children = new ArrayList<>(); - this.inputFromThisStage = new ArrayList<>(); - this.sideInputFromOtherStages = new ArrayList<>(); - this.sideInputFromThisStage = new ArrayList<>(); - this.outputCollector = null; - this.outputWriters = new ArrayList<>(); - } - - /** - * Get the irVertex that owns this IRVertexDataHandler. - * - * @return irVertex of this IRVertexDataHandler. - */ - public IRVertex getIRVertex() { - return irVertex; - } - - /** - * Get a DAG of children tasks' TaskDataHandlers. - * - * @return DAG of children tasks' TaskDataHandlers. - */ - public List<IRVertexDataHandler> getChildren() { - return children; - } - - /** - * Get side input from other Task. - * - * @return InputReader that has side input. - */ - public List<InputReader> getSideInputFromOtherStages() { - return sideInputFromOtherStages; - } - - /** - * Get intra-Task side input from parent tasks. - * Just like normal intra-Task inputs, intra-Task side inputs are - * collected in parent tasks' OutputCollectors. - * - * @return OutputCollectors of all parent tasks which are marked as having side input. - */ - public List<OutputCollectorImpl> getSideInputFromThisStage() { - return sideInputFromThisStage; - } - - /** - * Get OutputCollector of this irVertex. - * - * @return OutputCollector of this irVertex. - */ - public OutputCollectorImpl getOutputCollector() { - return outputCollector; - } - - /** - * Get OutputWriters of this irVertex. - * - * @return OutputWriters of this irVertex. - */ - public List<OutputWriter> getOutputWriters() { - return outputWriters; - } - - /** - * Set a DAG of children tasks' DataHandlers. - * - * @param childrenDataHandler list of children TaskDataHandlers. - */ - public void setChildrenDataHandler(final List<IRVertexDataHandler> childrenDataHandler) { - children = childrenDataHandler; - } - - /** - * Add OutputCollector of a parent irVertex that will provide intra-stage input. - * - * @param input OutputCollector of a parent irVertex. - */ - public void addInputFromThisStages(final OutputCollectorImpl input) { - inputFromThisStage.add(input); - } - - /** - * Add InputReader that will provide inter-stage side input. - * - * @param sideInputReader InputReader that will provide inter-stage side input. - */ - public void addSideInputFromOtherStages(final InputReader sideInputReader) { - sideInputFromOtherStages.add(sideInputReader); - } - - /** - * Add OutputCollector of a parent irVertex that will provide intra-stage side input. - * - * @param ocAsSideInput OutputCollector of a parent irVertex with side input. - */ - public void addSideInputFromThisStage(final OutputCollectorImpl ocAsSideInput) { - sideInputFromThisStage.add(ocAsSideInput); - } - - /** - * Set OutputCollector of this irVertex. - * - * @param oc OutputCollector of this irVertex. 
- */ - public void setOutputCollector(final OutputCollectorImpl oc) { - outputCollector = oc; - } - - /** - * Add OutputWriter of this irVertex. - * - * @param outputWriter OutputWriter of this irVertex. - */ - public void addOutputWriter(final OutputWriter outputWriter) { - outputWriters.add(outputWriter); - } -} diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/InputReader.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/InputReader.java index a01e58c4..176acdd2 100644 --- a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/InputReader.java +++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/InputReader.java @@ -154,8 +154,8 @@ private String getBlockId(final int taskIdx) { return RuntimeIdGenerator.generateBlockId(duplicateEdgeId, taskIdx); } - public String getSrcIrVertexId() { - return srcVertex.getId(); + public IRVertex getSrcIrVertex() { + return srcVertex; } public boolean isSideInputReader() { diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/OutputCollectorImpl.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/OutputCollectorImpl.java index 0588769a..16697d6e 100644 --- a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/OutputCollectorImpl.java +++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/OutputCollectorImpl.java @@ -16,11 +16,9 @@ package edu.snu.nemo.runtime.executor.datatransfer; import edu.snu.nemo.common.ir.OutputCollector; -import edu.snu.nemo.runtime.common.plan.RuntimeEdge; import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.List; +import java.util.Queue; /** * OutputCollector implementation. @@ -28,17 +26,13 @@ * @param <O> output type. */ public final class OutputCollectorImpl<O> implements OutputCollector<O> { - private final ArrayDeque<O> outputQueue; - private RuntimeEdge sideInputRuntimeEdge; - private List<String> sideInputReceivers; + private final Queue<O> outputQueue; /** * Constructor of a new OutputCollectorImpl. */ public OutputCollectorImpl() { - this.outputQueue = new ArrayDeque<>(); - this.sideInputRuntimeEdge = null; - this.sideInputReceivers = new ArrayList<>(); + this.outputQueue = new ArrayDeque<>(1); } @Override @@ -69,50 +63,4 @@ public O remove() { public boolean isEmpty() { return outputQueue.isEmpty(); } - - /** - * Return the size of this OutputCollector. - * - * @return the total number of elements in this OutputCollector. - */ - public int size() { - return outputQueue.size(); - } - - /** - * Mark this edge as side input so that TaskExecutor can retrieve - * source transform using it. - * - * @param edge the RuntimeEdge to mark as side input. - */ - public void setSideInputRuntimeEdge(final RuntimeEdge edge) { - sideInputRuntimeEdge = edge; - } - - /** - * Get the RuntimeEdge marked as side input. - * - * @return the RuntimeEdge marked as side input. - */ - public RuntimeEdge getSideInputRuntimeEdge() { - return sideInputRuntimeEdge; - } - - /** - * Set this OutputCollector as having side input for the given child task. - * - * @param physicalTaskId the id of child task whose side input will be put into this OutputCollector. - */ - public void setAsSideInputFor(final String physicalTaskId) { - sideInputReceivers.add(physicalTaskId); - } - - /** - * Check if this OutputCollector has side input for the given child task. 
- * - * @return true if it contains side input for child task of the given id. - */ - public boolean hasSideInputFor(final String physicalTaskId) { - return sideInputReceivers.contains(physicalTaskId); - } } diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/DataFetcher.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/DataFetcher.java new file mode 100644 index 00000000..3dbc6890 --- /dev/null +++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/DataFetcher.java @@ -0,0 +1,72 @@ +/* + * Copyright (C) 2018 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package edu.snu.nemo.runtime.executor.task; + +import edu.snu.nemo.common.ir.vertex.IRVertex; + +import java.io.IOException; +import java.util.Map; + +/** + * An abstraction for fetching data from task-external sources. + */ +abstract class DataFetcher { + private final IRVertex dataSource; + private final VertexHarness child; + private final Map<String, Object> metricMap; + private final boolean isToSideInput; + private final boolean isFromSideInput; + + DataFetcher(final IRVertex dataSource, + final VertexHarness child, + final Map<String, Object> metricMap, + final boolean isFromSideInput, + final boolean isToSideInput) { + this.dataSource = dataSource; + this.child = child; + this.metricMap = metricMap; + this.isToSideInput = isToSideInput; + this.isFromSideInput = isFromSideInput; + } + + /** + * Can block until the next data element becomes available. + * + * @return null if there's no more data element. + * @throws IOException while fetching data + */ + abstract Object fetchDataElement() throws IOException; + + protected Map<String, Object> getMetricMap() { + return metricMap; + } + + VertexHarness getChild() { + return child; + } + + public IRVertex getDataSource() { + return dataSource; + } + + boolean isFromSideInput() { + return isFromSideInput; + } + + boolean isToSideInput() { + return isToSideInput; + } +} diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/ParentTaskDataFetcher.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/ParentTaskDataFetcher.java new file mode 100644 index 00000000..2abb3b7f --- /dev/null +++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/ParentTaskDataFetcher.java @@ -0,0 +1,142 @@ +/* + * Copyright (C) 2018 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package edu.snu.nemo.runtime.executor.task; + +import edu.snu.nemo.common.exception.BlockFetchException; +import edu.snu.nemo.common.ir.vertex.IRVertex; +import edu.snu.nemo.runtime.executor.data.DataUtil; +import edu.snu.nemo.runtime.executor.datatransfer.InputReader; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.concurrent.NotThreadSafe; +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.LinkedBlockingQueue; + +/** + * Fetches data from parent tasks. + */ +@NotThreadSafe +class ParentTaskDataFetcher extends DataFetcher { + private static final Logger LOG = LoggerFactory.getLogger(ParentTaskDataFetcher.class); + + private final InputReader readersForParentTask; + private final LinkedBlockingQueue<DataUtil.IteratorWithNumBytes> dataQueue; + + // Non-finals (lazy fetching) + private boolean hasFetchStarted; + private int expectedNumOfIterators; + private DataUtil.IteratorWithNumBytes currentIterator; + private int currentIteratorIndex; + private boolean noElementAtAll = true; + + ParentTaskDataFetcher(final IRVertex dataSource, + final InputReader readerForParentTask, + final VertexHarness child, + final Map<String, Object> metricMap, + final boolean isFromSideInput, + final boolean isToSideInput) { + super(dataSource, child, metricMap, isFromSideInput, isToSideInput); + this.readersForParentTask = readerForParentTask; + this.hasFetchStarted = false; + this.dataQueue = new LinkedBlockingQueue<>(); + } + + private void handleMetric(final DataUtil.IteratorWithNumBytes iterator) { + long serBytes = 0; + long encodedBytes = 0; + try { + serBytes += iterator.getNumSerializedBytes(); + } catch (final DataUtil.IteratorWithNumBytes.NumBytesNotSupportedException e) { + serBytes = -1; + } catch (final IllegalStateException e) { + LOG.error("Failed to get the number of bytes of serialized data - the data is not ready yet ", e); + } + try { + encodedBytes += iterator.getNumEncodedBytes(); + } catch (final DataUtil.IteratorWithNumBytes.NumBytesNotSupportedException e) { + encodedBytes = -1; + } catch (final IllegalStateException e) { + LOG.error("Failed to get the number of bytes of encoded data - the data is not ready yet ", e); + } + if (serBytes != encodedBytes) { + getMetricMap().put("ReadBytes(raw)", serBytes); + } + getMetricMap().put("ReadBytes", encodedBytes); + } + + /** + * Blocking call. 
+ */ + private void fetchInBackground() { + final List<CompletableFuture<DataUtil.IteratorWithNumBytes>> futures = readersForParentTask.read(); + this.expectedNumOfIterators = futures.size(); + + futures.forEach(compFuture -> compFuture.whenComplete((iterator, exception) -> { + if (exception != null) { + throw new BlockFetchException(exception); + } + + try { + dataQueue.put(iterator); // can block here + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + throw new BlockFetchException(e); + } + })); + } + + @Override + Object fetchDataElement() throws IOException { + try { + if (!hasFetchStarted) { + fetchInBackground(); + hasFetchStarted = true; + this.currentIterator = dataQueue.take(); + this.currentIteratorIndex = 1; + } + + if (this.currentIterator.hasNext()) { + noElementAtAll = false; + return this.currentIterator.next(); + } else { + // This iterator is done, proceed to the next iterator + if (currentIteratorIndex == expectedNumOfIterators) { + // No more iterator left + if (noElementAtAll) { + // This shouldn't normally happen, except for cases such as when Beam's VoidCoder is used. + noElementAtAll = false; + return Void.TYPE; + } else { + // This whole fetcher's done + return null; + } + } else { + handleMetric(currentIterator); + // Try the next iterator + this.currentIteratorIndex += 1; + this.currentIterator = dataQueue.take(); + return fetchDataElement(); + } + } + } catch (InterruptedException exception) { + throw new IOException(exception); + } + } +} diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/SourceVertexDataFetcher.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/SourceVertexDataFetcher.java new file mode 100644 index 00000000..998df63e --- /dev/null +++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/SourceVertexDataFetcher.java @@ -0,0 +1,58 @@ +/* + * Copyright (C) 2018 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package edu.snu.nemo.runtime.executor.task; + +import edu.snu.nemo.common.ir.Readable; +import edu.snu.nemo.common.ir.vertex.IRVertex; + +import java.io.IOException; +import java.util.Iterator; +import java.util.Map; + +/** + * Fetches data from a data source. 
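+ * The source is read lazily: the Readable is materialized into an iterator on the first + * fetchDataElement() call, and elements are then returned from that iterator one by one.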
+ */ +class SourceVertexDataFetcher extends DataFetcher { + private final Readable readable; + + // Non-finals (lazy fetching) + private Iterator iterator; + + SourceVertexDataFetcher(final IRVertex dataSource, + final Readable readable, + final VertexHarness child, + final Map<String, Object> metricMap, + final boolean isFromSideInput, + final boolean isToSideInput) { + super(dataSource, child, metricMap, isFromSideInput, isToSideInput); + this.readable = readable; + } + + @Override + Object fetchDataElement() throws IOException { + if (iterator == null) { + final long start = System.currentTimeMillis(); + iterator = this.readable.read().iterator(); + getMetricMap().put("BoundedSourceReadTime(ms)", System.currentTimeMillis() - start); + } + + if (iterator.hasNext()) { + return iterator.next(); + } else { + return null; + } + } +} diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/TaskExecutor.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/TaskExecutor.java new file mode 100644 index 00000000..27931a03 --- /dev/null +++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/TaskExecutor.java @@ -0,0 +1,450 @@ +/* + * Copyright (C) 2017 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package edu.snu.nemo.runtime.executor.task; + +import com.google.common.collect.Lists; +import edu.snu.nemo.common.ContextImpl; +import edu.snu.nemo.common.Pair; +import edu.snu.nemo.common.dag.DAG; +import edu.snu.nemo.common.ir.Readable; +import edu.snu.nemo.common.ir.vertex.*; +import edu.snu.nemo.common.ir.vertex.transform.Transform; +import edu.snu.nemo.runtime.common.RuntimeIdGenerator; +import edu.snu.nemo.runtime.common.plan.Task; +import edu.snu.nemo.runtime.common.plan.StageEdge; +import edu.snu.nemo.runtime.common.plan.RuntimeEdge; +import edu.snu.nemo.runtime.common.state.TaskState; +import edu.snu.nemo.runtime.executor.MetricCollector; +import edu.snu.nemo.runtime.executor.MetricMessageSender; +import edu.snu.nemo.runtime.executor.TaskStateManager; +import edu.snu.nemo.runtime.executor.datatransfer.*; + +import java.io.IOException; +import java.util.*; +import java.util.stream.Collectors; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.concurrent.NotThreadSafe; + +/** + * Executes a task. + * Should be accessed by a single thread. 
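+ * Data is consumed through DataFetchers (task-external reads), and each fetched element is pushed + * down a pointer-based DAG of VertexHarnesses for intra-task processing and child-task writes.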
+ */ +@NotThreadSafe +public final class TaskExecutor { + private static final Logger LOG = LoggerFactory.getLogger(TaskExecutor.class.getName()); + private static final int NONE_FINISHED = -1; + + // Essential information + private boolean isExecuted; + private final String taskId; + private final TaskStateManager taskStateManager; + private final List<DataFetcher> dataFetchers; + private final List<VertexHarness> sortedHarnesses; + private final Map sideInputMap; + + // Metrics information + private final Map<String, Object> metricMap; + private final MetricCollector metricCollector; + + // Dynamic optimization + private String idOfVertexPutOnHold; + + /** + * Constructor. + * @param task Task with information needed during execution. + * @param irVertexDag A DAG of vertices. + * @param taskStateManager State manager for this Task. + * @param dataTransferFactory For reading data from, and writing data to, other tasks. + * @param metricMessageSender For sending metrics with execution stats to Master. + */ + public TaskExecutor(final Task task, + final DAG<IRVertex, RuntimeEdge<IRVertex>> irVertexDag, + final TaskStateManager taskStateManager, + final DataTransferFactory dataTransferFactory, + final MetricMessageSender metricMessageSender) { + // Essential information + this.isExecuted = false; + this.taskId = task.getTaskId(); + this.taskStateManager = taskStateManager; + + // Metrics information + this.metricMap = new HashMap<>(); + this.metricCollector = new MetricCollector(metricMessageSender); + + // Dynamic optimization + // Assigning null is very bad, but we are keeping this for now + this.idOfVertexPutOnHold = null; + + // Prepare data structures + this.sideInputMap = new HashMap(); + final Pair<List<DataFetcher>, List<VertexHarness>> pair = prepare(task, irVertexDag, dataTransferFactory); + this.dataFetchers = pair.left(); + this.sortedHarnesses = pair.right(); + } + + /** + * Converts the DAG of vertices into a pointer-based DAG of vertex harnesses. + * This conversion is necessary for constructing concrete data channels for each vertex's inputs and outputs. + * + * - Source vertex read: Explicitly handled (SourceVertexDataFetcher) + * - Sink vertex write: Implicitly handled within the vertex + * + * - Parent-task read: Explicitly handled (ParentTaskDataFetcher) + * - Children-task write: Explicitly handled (VertexHarness) + * + * - Intra-task read: Implicitly handled when performing intra-task writes + * - Intra-task write: Explicitly handled (VertexHarness) + * + * For element-wise data processing, we traverse vertex harnesses from the roots to the leaves for each element. + * This means that overheads associated with jumping from one harness to the other should be minimal. + * For example, we should never perform an expensive hash operation to traverse the harnesses. + * + * @param task task. + * @param irVertexDag A DAG of vertices. + * @param dataTransferFactory For creating the readers and writers. + * @return fetchers and harnesses. + */ + private Pair<List<DataFetcher>, List<VertexHarness>> prepare(final Task task, + final DAG<IRVertex, RuntimeEdge<IRVertex>> irVertexDag, + final DataTransferFactory dataTransferFactory) { + final int taskIndex = RuntimeIdGenerator.getIndexFromTaskId(task.getTaskId()); + + // Traverse in a reverse-topological order to ensure that each visited vertex's children vertices exist.
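+ // e.g., for a DAG v1 -> v2 -> v3, the traversal order is [v3, v2, v1], so the harness of each + // child already exists by the time its parent's harness is built and wired to it.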
+ final List<IRVertex> reverseTopologicallySorted = Lists.reverse(irVertexDag.getTopologicalSort()); + + // Create a harness for each vertex + final List<DataFetcher> dataFetcherList = new ArrayList<>(); + final Map<String, VertexHarness> vertexIdToHarness = new HashMap<>(); + reverseTopologicallySorted.forEach(irVertex -> { + final List<VertexHarness> children = getChildrenHarnesses(irVertex, irVertexDag, vertexIdToHarness); + final Optional<Readable> sourceReader = getSourceVertexReader(irVertex, task.getIrVertexIdToReadable()); + if (sourceReader.isPresent() != irVertex instanceof SourceVertex) { + throw new IllegalStateException(irVertex.toString()); + } + + final List<Boolean> isToSideInputs = children.stream() + .map(VertexHarness::getIRVertex) + .map(childVertex -> irVertexDag.getEdgeBetween(irVertex.getId(), childVertex.getId())) + .map(RuntimeEdge::isSideInput) + .collect(Collectors.toList()); + + // Handle writes + final List<OutputWriter> childrenTaskWriters = getChildrenTaskWriters( + taskIndex, irVertex, task.getTaskOutgoingEdges(), dataTransferFactory); // Children-task write + final VertexHarness vertexHarness = new VertexHarness(irVertex, new OutputCollectorImpl(), children, + isToSideInputs, childrenTaskWriters, new ContextImpl(sideInputMap)); // Intra-vertex write + prepareTransform(vertexHarness); + vertexIdToHarness.put(irVertex.getId(), vertexHarness); + + // Handle reads + final boolean isToSideInput = isToSideInputs.stream().anyMatch(bool -> bool); + if (irVertex instanceof SourceVertex) { + dataFetcherList.add(new SourceVertexDataFetcher(irVertex, sourceReader.get(), vertexHarness, metricMap, + false, isToSideInput)); // Source vertex read + } + final List<InputReader> parentTaskReaders = + getParentTaskReaders(taskIndex, irVertex, task.getTaskIncomingEdges(), dataTransferFactory); + parentTaskReaders.forEach(parentTaskReader -> { + final boolean isFromSideInput = parentTaskReader.isSideInputReader(); + dataFetcherList.add(new ParentTaskDataFetcher(parentTaskReader.getSrcIrVertex(), parentTaskReader, + vertexHarness, metricMap, isFromSideInput, isToSideInput)); // Parent-task read + }); + }); + + final List<VertexHarness> sortedHarnessList = irVertexDag.getTopologicalSort() + .stream() + .map(vertex -> vertexIdToHarness.get(vertex.getId())) + .collect(Collectors.toList()); + + return Pair.of(dataFetcherList, sortedHarnessList); + } + + /** + * Recursively process a data element down the DAG dependency. + * @param vertexHarness VertexHarness of a vertex to execute. + * @param dataElement input data element to process. + */ + private void processElementRecursively(final VertexHarness vertexHarness, final Object dataElement) { + final IRVertex irVertex = vertexHarness.getIRVertex(); + final OutputCollectorImpl outputCollector = vertexHarness.getOutputCollector(); + if (irVertex instanceof SourceVertex) { + outputCollector.emit(dataElement); + } else if (irVertex instanceof OperatorVertex) { + final Transform transform = ((OperatorVertex) irVertex).getTransform(); + transform.onData(dataElement); + } else if (irVertex instanceof MetricCollectionBarrierVertex) { + outputCollector.emit(dataElement); + setIRVertexPutOnHold((MetricCollectionBarrierVertex) irVertex); + } else { + throw new UnsupportedOperationException("This type of IRVertex is not supported"); + } + + // Given a single input element, a vertex can produce many output elements. + // Here, we recursively process all of the output elements. 
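+ // e.g., a flatMap-style transform may emit zero or more elements for a single input; each + // emitted element is drained from the collector and handed to handleOutputElement below.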
+ while (!outputCollector.isEmpty()) { + final Object element = outputCollector.remove(); + handleOutputElement(vertexHarness, element); // Recursion + } + } + + /** + * Execute a task, while handling unrecoverable errors and exceptions. + */ + public void execute() { + try { + doExecute(); + } catch (Throwable throwable) { + // ANY uncaught throwable is reported to the master + taskStateManager.onTaskStateChanged(TaskState.State.FAILED_UNRECOVERABLE, Optional.empty(), Optional.empty()); + throwable.printStackTrace(); + } + } + + /** + * The task is executed in the following three phases. + * - Phase 1: Consume task-external side-input data + * - Phase 2: Consume task-external input data + * - Phase 3: Finalize task-internal states and data elements + */ + private void doExecute() { + // Housekeeping stuff + if (isExecuted) { + throw new RuntimeException("Task {" + taskId + "} execution called again"); + } + LOG.info("{} started", taskId); + taskStateManager.onTaskStateChanged(TaskState.State.EXECUTING, Optional.empty(), Optional.empty()); + metricCollector.beginMeasurement(taskId, metricMap); + + // Phase 1: Consume task-external side-input related data. + final Map<Boolean, List<DataFetcher>> sideInputRelated = dataFetchers.stream() + .collect(Collectors.partitioningBy(fetcher -> fetcher.isFromSideInput() || fetcher.isToSideInput())); + if (!handleDataFetchers(sideInputRelated.get(true))) { + return; + } + final Set<VertexHarness> finalizeLater = sideInputRelated.get(false).stream() + .map(DataFetcher::getChild) + .flatMap(vertex -> getAllReachables(vertex).stream()) + .collect(Collectors.toSet()); + for (final VertexHarness vertexHarness : sortedHarnesses) { + if (!finalizeLater.contains(vertexHarness)) { + finalizeVertex(vertexHarness); // finalize early to materialize intra-task side inputs. + } + } + + // Phase 2: Consume task-external input data.
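+ // By this point, Phase 1 has materialized every side input into sideInputMap, so transforms + // running in this phase can look them up by tag through Context#getSideInputs().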
+ if (!handleDataFetchers(sideInputRelated.get(false))) { + return; + } + + // Phase 3: Finalize task-internal states and elements + for (final VertexHarness vertexHarness : sortedHarnesses) { + if (finalizeLater.contains(vertexHarness)) { + finalizeVertex(vertexHarness); + } + } + + // Miscellaneous: Metrics, DynOpt, etc + metricCollector.endMeasurement(taskId, metricMap); + if (idOfVertexPutOnHold == null) { + taskStateManager.onTaskStateChanged(TaskState.State.COMPLETE, Optional.empty(), Optional.empty()); + LOG.info("{} completed", taskId); + } else { + taskStateManager.onTaskStateChanged(TaskState.State.ON_HOLD, + Optional.of(idOfVertexPutOnHold), + Optional.empty()); + LOG.info("{} on hold", taskId); + } + } + + private List<VertexHarness> getAllReachables(final VertexHarness src) { + final List<VertexHarness> result = new ArrayList<>(); + result.add(src); + result.addAll(src.getNonSideInputChildren().stream() + .flatMap(child -> getAllReachables(child).stream()).collect(Collectors.toList())); + result.addAll(src.getSideInputChildren().stream() + .flatMap(child -> getAllReachables(child).stream()).collect(Collectors.toList())); + return result; + } + + private void finalizeVertex(final VertexHarness vertexHarness) { + closeTransform(vertexHarness); + while (!vertexHarness.getOutputCollector().isEmpty()) { + final Object element = vertexHarness.getOutputCollector().remove(); + handleOutputElement(vertexHarness, element); + } + finalizeOutputWriters(vertexHarness); + } + + private void handleOutputElement(final VertexHarness vertexHarness, final Object element) { + vertexHarness.getWritersToChildrenTasks().forEach(outputWriter -> outputWriter.write(element)); + if (vertexHarness.getSideInputChildren().size() > 0) { + sideInputMap.put(((OperatorVertex) vertexHarness.getIRVertex()).getTransform().getTag(), element); + } + vertexHarness.getNonSideInputChildren().forEach(child -> processElementRecursively(child, element)); + } + + /** + * @param fetchers to handle. + * @return false if an IOException occurred while fetching data. + */ + private boolean handleDataFetchers(final List<DataFetcher> fetchers) { + final List<DataFetcher> availableFetchers = new ArrayList<>(fetchers); + while (!availableFetchers.isEmpty()) { // empty means we've consumed all task-external input data + // Declared per pass, so that a stale index from a previous pass cannot trigger a spurious removal. + int finishedFetcherIndex = NONE_FINISHED; + for (int i = 0; i < availableFetchers.size(); i++) { + final DataFetcher dataFetcher = availableFetchers.get(i); // index into the list we iterate over and remove from + final Object element; + try { + element = dataFetcher.fetchDataElement(); + } catch (IOException e) { + taskStateManager.onTaskStateChanged(TaskState.State.FAILED_RECOVERABLE, + Optional.empty(), Optional.of(TaskState.RecoverableFailureCause.INPUT_READ_FAILURE)); + LOG.error("{} Execution Failed (Recoverable: input read failure)!
Exception: {}", taskId, e.toString()); + return false; + } + + if (element == null) { + finishedFetcherIndex = i; + break; + } else { + if (dataFetcher.isFromSideInput()) { + sideInputMap.put(((OperatorVertex) dataFetcher.getDataSource()).getTransform().getTag(), element); + } else { + processElementRecursively(dataFetcher.getChild(), element); + } + } + } + + // Remove the finished fetcher from the list + if (finishedFetcherIndex != NONE_FINISHED) { + availableFetchers.remove(finishedFetcherIndex); + } + } + return true; + } + + ////////////////////////////////////////////// Helper methods for setting up initial data structures + + private Optional<Readable> getSourceVertexReader(final IRVertex irVertex, + final Map<String, Readable> irVertexIdToReadable) { + if (irVertex instanceof SourceVertex) { + final Readable readable = irVertexIdToReadable.get(irVertex.getId()); + if (readable == null) { + throw new IllegalStateException(irVertex.toString()); + } + return Optional.of(readable); + } else { + return Optional.empty(); + } + } + + private List<InputReader> getParentTaskReaders(final int taskIndex, + final IRVertex irVertex, + final List<StageEdge> inEdgesFromParentTasks, + final DataTransferFactory dataTransferFactory) { + return inEdgesFromParentTasks + .stream() + .filter(inEdge -> inEdge.getDstVertex().getId().equals(irVertex.getId())) + .map(inEdgeForThisVertex -> dataTransferFactory + .createReader(taskIndex, inEdgeForThisVertex.getSrcVertex(), inEdgeForThisVertex)) + .collect(Collectors.toList()); + } + + private List<OutputWriter> getChildrenTaskWriters(final int taskIndex, + final IRVertex irVertex, + final List<StageEdge> outEdgesToChildrenTasks, + final DataTransferFactory dataTransferFactory) { + return outEdgesToChildrenTasks + .stream() + .filter(outEdge -> outEdge.getSrcVertex().getId().equals(irVertex.getId())) + .map(outEdgeForThisVertex -> dataTransferFactory + .createWriter(irVertex, taskIndex, outEdgeForThisVertex.getDstVertex(), outEdgeForThisVertex)) + .collect(Collectors.toList()); + } + + private List<VertexHarness> getChildrenHarnesses(final IRVertex irVertex, + final DAG<IRVertex, RuntimeEdge<IRVertex>> irVertexDag, + final Map<String, VertexHarness> vertexIdToHarness) { + final List<VertexHarness> childrenHandlers = irVertexDag.getChildren(irVertex.getId()) + .stream() + .map(IRVertex::getId) + .map(vertexIdToHarness::get) + .collect(Collectors.toList()); + if (childrenHandlers.stream().anyMatch(harness -> harness == null)) { + // Sanity check: there shouldn't be a null harness. 
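+ // A null harness would mean a child was visited before its parent despite the + // reverse-topological traversal in prepare(), i.e. the DAG is malformed.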
+ throw new IllegalStateException(childrenHandlers.toString()); + } + return childrenHandlers; + } + + ////////////////////////////////////////////// Transform-specific helper methods + + private void prepareTransform(final VertexHarness vertexHarness) { + final IRVertex irVertex = vertexHarness.getIRVertex(); + if (irVertex instanceof OperatorVertex) { + final Transform transform = ((OperatorVertex) irVertex).getTransform(); + transform.prepare(vertexHarness.getContext(), vertexHarness.getOutputCollector()); + } + } + + private void closeTransform(final VertexHarness vertexHarness) { + final IRVertex irVertex = vertexHarness.getIRVertex(); + if (irVertex instanceof OperatorVertex) { + Transform transform = ((OperatorVertex) irVertex).getTransform(); + transform.close(); + } + } + + ////////////////////////////////////////////// Misc + + private void setIRVertexPutOnHold(final MetricCollectionBarrierVertex irVertex) { + idOfVertexPutOnHold = irVertex.getId(); + } + + /** + * Finalize the output write of this vertex. + * As element-wise output write is done and the block is in memory, + * flush the block into the designated data store and commit it. + * @param vertexHarness harness. + */ + private void finalizeOutputWriters(final VertexHarness vertexHarness) { + final List<Long> writtenBytesList = new ArrayList<>(); + final Map<String, Object> metric = new HashMap<>(); + final IRVertex irVertex = vertexHarness.getIRVertex(); + + metricCollector.beginMeasurement(irVertex.getId(), metric); + final long writeStartTime = System.currentTimeMillis(); + + vertexHarness.getWritersToChildrenTasks().forEach(outputWriter -> { + outputWriter.close(); + final Optional<Long> writtenBytes = outputWriter.getWrittenBytes(); + writtenBytes.ifPresent(writtenBytesList::add); + }); + + final long writeEndTime = System.currentTimeMillis(); + metric.put("OutputWriteTime(ms)", writeEndTime - writeStartTime); + if (!writtenBytesList.isEmpty()) { + long totalWrittenBytes = 0; + for (final Long writtenBytes : writtenBytesList) { + totalWrittenBytes += writtenBytes; + } + metricMap.put("WrittenBytes", totalWrittenBytes); + } + metricCollector.endMeasurement(irVertex.getId(), metric); + } + +} diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/VertexHarness.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/VertexHarness.java new file mode 100644 index 00000000..2d915c44 --- /dev/null +++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/VertexHarness.java @@ -0,0 +1,108 @@ +/* + * Copyright (C) 2018 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package edu.snu.nemo.runtime.executor.task; + +import edu.snu.nemo.common.ir.vertex.IRVertex; +import edu.snu.nemo.common.ir.vertex.transform.Transform; +import edu.snu.nemo.runtime.executor.datatransfer.OutputCollectorImpl; +import edu.snu.nemo.runtime.executor.datatransfer.OutputWriter; + +import java.util.ArrayList; +import java.util.List; + +/** + * Captures the relationship between a non-source IRVertex's outputCollector and its children vertices. + */ +final class VertexHarness { + // IRVertex and transform-specific information + private final IRVertex irVertex; + private final OutputCollectorImpl outputCollector; + private final Transform.Context context; + + // These lists can be empty + private final List<VertexHarness> sideInputChildren; + private final List<VertexHarness> nonSideInputChildren; + private final List<OutputWriter> writersToChildrenTasks; + + VertexHarness(final IRVertex irVertex, + final OutputCollectorImpl outputCollector, + final List<VertexHarness> children, + final List<Boolean> isSideInputs, + final List<OutputWriter> writersToChildrenTasks, + final Transform.Context context) { + this.irVertex = irVertex; + this.outputCollector = outputCollector; + if (children.size() != isSideInputs.size()) { + throw new IllegalStateException(irVertex.toString()); + } + final List<VertexHarness> sides = new ArrayList<>(); + final List<VertexHarness> nonSides = new ArrayList<>(); + for (int i = 0; i < children.size(); i++) { + final VertexHarness child = children.get(i); + if (isSideInputs.get(i)) { + sides.add(child); + } else { + nonSides.add(child); + } + } + this.sideInputChildren = sides; + this.nonSideInputChildren = nonSides; + this.writersToChildrenTasks = writersToChildrenTasks; + this.context = context; + } + + /** + * @return irVertex of this VertexHarness. + */ + IRVertex getIRVertex() { + return irVertex; + } + + /** + * @return OutputCollector of this irVertex. + */ + OutputCollectorImpl getOutputCollector() { + return outputCollector; + } + + /** + * @return list of non-sideinput children. (empty if none exist) + */ + List<VertexHarness> getNonSideInputChildren() { + return nonSideInputChildren; + } + + /** + * @return list of sideinput children. (empty if none exist) + */ + List<VertexHarness> getSideInputChildren() { + return sideInputChildren; + } + + /** + * @return OutputWriters of this irVertex. (empty if none exist) + */ + List<OutputWriter> getWritersToChildrenTasks() { + return writersToChildrenTasks; + } + + /** + * @return context. + */ + Transform.Context getContext() { + return context; + } +} diff --git a/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/task/TaskExecutorTest.java b/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/task/TaskExecutorTest.java new file mode 100644 index 00000000..f2c0082e --- /dev/null +++ b/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/task/TaskExecutorTest.java @@ -0,0 +1,426 @@ +/* + * Copyright (C) 2017 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package edu.snu.nemo.runtime.executor.task; + +import edu.snu.nemo.common.Pair; +import edu.snu.nemo.common.ir.OutputCollector; +import edu.snu.nemo.common.coder.Coder; +import edu.snu.nemo.common.dag.DAG; +import edu.snu.nemo.common.dag.DAGBuilder; +import edu.snu.nemo.common.ir.Readable; +import edu.snu.nemo.common.ir.vertex.InMemorySourceVertex; +import edu.snu.nemo.common.ir.vertex.OperatorVertex; +import edu.snu.nemo.common.ir.vertex.transform.Transform; +import edu.snu.nemo.common.ir.edge.executionproperty.DataStoreProperty; +import edu.snu.nemo.common.ir.executionproperty.ExecutionPropertyMap; +import edu.snu.nemo.common.ir.vertex.IRVertex; +import edu.snu.nemo.runtime.common.RuntimeIdGenerator; +import edu.snu.nemo.runtime.common.plan.Task; +import edu.snu.nemo.runtime.common.plan.StageEdge; +import edu.snu.nemo.runtime.common.plan.RuntimeEdge; +import edu.snu.nemo.runtime.executor.MetricMessageSender; +import edu.snu.nemo.runtime.executor.TaskStateManager; +import edu.snu.nemo.runtime.executor.data.DataUtil; +import edu.snu.nemo.runtime.executor.datatransfer.DataTransferFactory; +import edu.snu.nemo.runtime.executor.datatransfer.InputReader; +import edu.snu.nemo.runtime.executor.datatransfer.OutputWriter; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; + +import java.io.IOException; +import java.util.*; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.*; + +/** + * Tests {@link TaskExecutor}. + */ +@RunWith(PowerMockRunner.class) +@PrepareForTest({InputReader.class, OutputWriter.class, DataTransferFactory.class, + TaskStateManager.class, StageEdge.class}) +public final class TaskExecutorTest { + private static final int DATA_SIZE = 100; + private static final String CONTAINER_TYPE = "CONTAINER_TYPE"; + private static final int SOURCE_PARALLELISM = 5; + private List<Integer> elements; + private Map<String, List> vertexIdToOutputData; + private DataTransferFactory dataTransferFactory; + private TaskStateManager taskStateManager; + private MetricMessageSender metricMessageSender; + private AtomicInteger stageId; + + private String generateTaskId() { + return RuntimeIdGenerator.generateTaskId(0, + RuntimeIdGenerator.generateStageId(stageId.getAndIncrement())); + } + + @Before + public void setUp() throws Exception { + elements = getRangedNumList(0, DATA_SIZE); + stageId = new AtomicInteger(1); + + // Mock a TaskStateManager. It accumulates the state change into a list. + taskStateManager = mock(TaskStateManager.class); + + // Mock a DataTransferFactory. 
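+ // Readers created by the factory split the test data across SOURCE_PARALLELISM futures, and + // writers record everything written into vertexIdToOutputData for the assertions below.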
+ vertexIdToOutputData = new HashMap<>(); + dataTransferFactory = mock(DataTransferFactory.class); + when(dataTransferFactory.createReader(anyInt(), any(), any())).then(new ParentTaskReaderAnswer()); + when(dataTransferFactory.createWriter(any(), anyInt(), any(), any())).then(new ChildTaskWriterAnswer()); + + // Mock a MetricMessageSender. + metricMessageSender = mock(MetricMessageSender.class); + doNothing().when(metricMessageSender).send(anyString(), anyString()); + doNothing().when(metricMessageSender).close(); + } + + private boolean checkEqualElements(final List<Integer> left, final List<Integer> right) { + Collections.sort(left); + Collections.sort(right); + return left.equals(right); + } + + /** + * Test source vertex data fetching. + */ + @Test(timeout=5000) + public void testSourceVertexDataFetching() throws Exception { + final IRVertex sourceIRVertex = new InMemorySourceVertex<>(elements); + + final Readable readable = new Readable() { + @Override + public Iterable read() throws IOException { + return elements; + } + @Override + public List<String> getLocations() { + throw new UnsupportedOperationException(); + } + }; + final Map<String, Readable> vertexIdToReadable = new HashMap<>(); + vertexIdToReadable.put(sourceIRVertex.getId(), readable); + + final DAG<IRVertex, RuntimeEdge<IRVertex>> taskDag = + new DAGBuilder<IRVertex, RuntimeEdge<IRVertex>>() + .addVertex(sourceIRVertex) + .buildWithoutSourceSinkCheck(); + + final Task task = + new Task( + "testSourceVertexDataFetching", + generateTaskId(), + 0, + CONTAINER_TYPE, + new byte[0], + Collections.emptyList(), + Collections.singletonList(mockStageEdgeFrom(sourceIRVertex)), + vertexIdToReadable); + + // Execute the task. + final TaskExecutor taskExecutor = new TaskExecutor( + task, taskDag, taskStateManager, dataTransferFactory, metricMessageSender); + taskExecutor.execute(); + + // Check the output. + assertTrue(checkEqualElements(elements, vertexIdToOutputData.get(sourceIRVertex.getId()))); + } + + /** + * Test parent task data fetching. + */ + @Test(timeout=5000) + public void testParentTaskDataFetching() throws Exception { + final IRVertex vertex = new OperatorVertex(new RelayTransform()); + + final DAG<IRVertex, RuntimeEdge<IRVertex>> taskDag = new DAGBuilder<IRVertex, RuntimeEdge<IRVertex>>() + .addVertex(vertex) + .buildWithoutSourceSinkCheck(); + + final Task task = new Task( + "testSourceVertexDataFetching", + generateTaskId(), + 0, + CONTAINER_TYPE, + new byte[0], + Collections.singletonList(mockStageEdgeTo(vertex)), + Collections.singletonList(mockStageEdgeFrom(vertex)), + Collections.emptyMap()); + + // Execute the task. + final TaskExecutor taskExecutor = new TaskExecutor( + task, taskDag, taskStateManager, dataTransferFactory, metricMessageSender); + taskExecutor.execute(); + + // Check the output. + assertTrue(checkEqualElements(elements, vertexIdToOutputData.get(vertex.getId()))); + } + + /** + * The DAG of the task to test looks like: + * parent task -> task (vertex 1 -> vertex 2) -> child task + * + * The output data from the parent task will be split according to the source parallelism through {@link ParentTaskReaderAnswer}. + * Because of this, vertex 1 will process multiple partitions and also emit data multiple times. + * On the other hand, vertex 2 will receive the output data once and produce a single output.
+ */ + @Test(timeout=5000) + public void testTwoOperators() throws Exception { + final IRVertex operatorIRVertex1 = new OperatorVertex(new RelayTransform()); + final IRVertex operatorIRVertex2 = new OperatorVertex(new RelayTransform()); + + final DAG<IRVertex, RuntimeEdge<IRVertex>> taskDag = new DAGBuilder<IRVertex, RuntimeEdge<IRVertex>>() + .addVertex(operatorIRVertex1) + .addVertex(operatorIRVertex2) + .connectVertices(createEdge(operatorIRVertex1, operatorIRVertex2, false)) + .buildWithoutSourceSinkCheck(); + + final Task task = new Task( + "testSourceVertexDataFetching", + generateTaskId(), + 0, + CONTAINER_TYPE, + new byte[0], + Collections.singletonList(mockStageEdgeTo(operatorIRVertex1)), + Collections.singletonList(mockStageEdgeFrom(operatorIRVertex2)), + Collections.emptyMap()); + + // Execute the task. + final TaskExecutor taskExecutor = new TaskExecutor( + task, taskDag, taskStateManager, dataTransferFactory, metricMessageSender); + taskExecutor.execute(); + + // Check the output. + assertTrue(checkEqualElements(elements, vertexIdToOutputData.get(operatorIRVertex2.getId()))); + } + + @Test(timeout=5000) + public void testTwoOperatorsWithSideInput() throws Exception { + final Object tag = new Object(); + final Transform singleListTransform = new CreateSingleListTransform(); + final IRVertex operatorIRVertex1 = new OperatorVertex(singleListTransform); + final IRVertex operatorIRVertex2 = new OperatorVertex(new SideInputPairTransform(singleListTransform.getTag())); + + final DAG<IRVertex, RuntimeEdge<IRVertex>> taskDag = new DAGBuilder<IRVertex, RuntimeEdge<IRVertex>>() + .addVertex(operatorIRVertex1) + .addVertex(operatorIRVertex2) + .connectVertices(createEdge(operatorIRVertex1, operatorIRVertex2, true)) + .buildWithoutSourceSinkCheck(); + + final Task task = new Task( + "testSourceVertexDataFetching", + generateTaskId(), + 0, + CONTAINER_TYPE, + new byte[0], + Arrays.asList(mockStageEdgeTo(operatorIRVertex1), mockStageEdgeTo(operatorIRVertex2)), + Collections.singletonList(mockStageEdgeFrom(operatorIRVertex2)), + Collections.emptyMap()); + + // Execute the task. + final TaskExecutor taskExecutor = new TaskExecutor( + task, taskDag, taskStateManager, dataTransferFactory, metricMessageSender); + taskExecutor.execute(); + + // Check the output. 
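+ // Each output pair should carry the materialized side input (the full list) on the left + // and one of the original elements on the right.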
+ final List<Pair<List<Integer>, Integer>> pairs = vertexIdToOutputData.get(operatorIRVertex2.getId()); + final List<Integer> values = pairs.stream().map(Pair::right).collect(Collectors.toList()); + assertTrue(checkEqualElements(elements, values)); + assertTrue(pairs.stream().map(Pair::left).allMatch(sideInput -> checkEqualElements(sideInput, values))); + } + + private RuntimeEdge<IRVertex> createEdge(final IRVertex src, + final IRVertex dst, + final boolean isSideInput) { + final String runtimeIREdgeId = "Runtime edge between operator tasks"; + final Coder coder = Coder.DUMMY_CODER; + final ExecutionPropertyMap edgeProperties = new ExecutionPropertyMap(runtimeIREdgeId); + edgeProperties.put(DataStoreProperty.of(DataStoreProperty.Value.MemoryStore)); + return new RuntimeEdge<>(runtimeIREdgeId, edgeProperties, src, dst, coder, isSideInput); + } + + private StageEdge mockStageEdgeFrom(final IRVertex irVertex) { + final StageEdge edge = mock(StageEdge.class); + when(edge.getSrcVertex()).thenReturn(irVertex); + when(edge.getDstVertex()).thenReturn(new OperatorVertex(new RelayTransform())); + return edge; + } + + private StageEdge mockStageEdgeTo(final IRVertex irVertex) { + final StageEdge edge = mock(StageEdge.class); + when(edge.getSrcVertex()).thenReturn(new OperatorVertex(new RelayTransform())); + when(edge.getDstVertex()).thenReturn(irVertex); + return edge; + } + + /** + * Represents the {@link Answer} that returns an inter-stage {@link InputReader}, + * which will have multiple iterables according to the source parallelism. + */ + private class ParentTaskReaderAnswer implements Answer<InputReader> { + @Override + public InputReader answer(final InvocationOnMock invocationOnMock) throws Throwable { + final List<CompletableFuture<DataUtil.IteratorWithNumBytes>> inputFutures = new ArrayList<>(SOURCE_PARALLELISM); + final int elementsPerSource = DATA_SIZE / SOURCE_PARALLELISM; + for (int i = 0; i < SOURCE_PARALLELISM; i++) { + inputFutures.add(CompletableFuture.completedFuture( + DataUtil.IteratorWithNumBytes.of(elements.subList(i * elementsPerSource, (i + 1) * elementsPerSource) + .iterator()))); + } + final InputReader inputReader = mock(InputReader.class); + when(inputReader.read()).thenReturn(inputFutures); + when(inputReader.isSideInputReader()).thenReturn(false); + when(inputReader.getSourceParallelism()).thenReturn(SOURCE_PARALLELISM); + return inputReader; + } + } + + /** + * Represents the {@link Answer} that returns an {@link OutputWriter}, + * which stores the written data in the map from vertex id to output data. + */ + private class ChildTaskWriterAnswer implements Answer<OutputWriter> { + @Override + public OutputWriter answer(final InvocationOnMock invocationOnMock) throws Throwable { + final Object[] args = invocationOnMock.getArguments(); + final IRVertex vertex = (IRVertex) args[0]; + final OutputWriter outputWriter = mock(OutputWriter.class); + doAnswer(new Answer() { + @Override + public Object answer(final InvocationOnMock invocationOnMock) throws Throwable { + final Object[] args = invocationOnMock.getArguments(); + final Object dataToWrite = args[0]; + vertexIdToOutputData.computeIfAbsent(vertex.getId(), emptyVertexId -> new ArrayList<>()); + vertexIdToOutputData.get(vertex.getId()).add(dataToWrite); + return null; + } + }).when(outputWriter).write(any()); + return outputWriter; + } + } + + /** + * Simple identity function for testing. + * @param <T> input/output type.
+ */ + private class RelayTransform<T> implements Transform<T, T> { + private OutputCollector<T> outputCollector; + + @Override + public void prepare(final Context context, final OutputCollector<T> outputCollector) { + this.outputCollector = outputCollector; + } + + @Override + public void onData(final Object element) { + outputCollector.emit((T) element); + } + + @Override + public void close() { + // Do nothing. + } + } + + /** + * Creates a view. + * @param <T> input type. + */ + private class CreateSingleListTransform<T> implements Transform<T, List<T>> { + private List<T> list; + private OutputCollector<List<T>> outputCollector; + private final Object tag = new Object(); + + @Override + public void prepare(final Context context, final OutputCollector<List<T>> outputCollector) { + this.list = new ArrayList<>(); + this.outputCollector = outputCollector; + } + + @Override + public void onData(final Object element) { + list.add((T) element); + } + + @Override + public void close() { + outputCollector.emit(list); + } + + @Override + public Object getTag() { + return tag; + } + } + + /** + * Pairs data element with a side input. + * @param <T> input/output type. + */ + private class SideInputPairTransform<T> implements Transform<T, T> { + private final Object sideInputTag; + private Context context; + private OutputCollector<T> outputCollector; + + public SideInputPairTransform(final Object sideInputTag) { + this.sideInputTag = sideInputTag; + } + + @Override + public void prepare(final Context context, final OutputCollector<T> outputCollector) { + this.context = context; + this.outputCollector = outputCollector; + } + + @Override + public void onData(final Object element) { + final Object sideInput = context.getSideInputs().get(sideInputTag); + outputCollector.emit((T) Pair.of(sideInput, element)); + } + + @Override + public void close() { + // Do nothing. + } + } + + /** + * Gets a list of integer pair elements in range. + * @param start value of the range (inclusive). + * @param end value of the range (exclusive). + * @return the list of elements. + */ + private List<Integer> getRangedNumList(final int start, final int end) { + final List<Integer> numList = new ArrayList<>(end - start); + IntStream.range(start, end).forEach(number -> numList.add(number)); + return numList; + } +} diff --git a/tests/src/test/java/edu/snu/nemo/tests/runtime/executor/TaskExecutorTest.java b/tests/src/test/java/edu/snu/nemo/tests/runtime/executor/TaskExecutorTest.java deleted file mode 100644 index ffbfef68..00000000 --- a/tests/src/test/java/edu/snu/nemo/tests/runtime/executor/TaskExecutorTest.java +++ /dev/null @@ -1,280 +0,0 @@ -/* - * Copyright (C) 2017 Seoul National University - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package edu.snu.nemo.tests.runtime.executor; - -import edu.snu.nemo.common.ir.OutputCollector; -import edu.snu.nemo.common.coder.Coder; -import edu.snu.nemo.common.dag.DAG; -import edu.snu.nemo.common.dag.DAGBuilder; -import edu.snu.nemo.common.ir.Readable; -import edu.snu.nemo.common.ir.vertex.OperatorVertex; -import edu.snu.nemo.common.ir.vertex.transform.Transform; -import edu.snu.nemo.common.ir.edge.executionproperty.DataStoreProperty; -import edu.snu.nemo.common.ir.executionproperty.ExecutionPropertyMap; -import edu.snu.nemo.common.ir.vertex.IRVertex; -import edu.snu.nemo.compiler.optimizer.examples.EmptyComponents; -import edu.snu.nemo.runtime.common.RuntimeIdGenerator; -import edu.snu.nemo.runtime.common.plan.Task; -import edu.snu.nemo.runtime.common.plan.StageEdge; -import edu.snu.nemo.runtime.common.plan.RuntimeEdge; -import edu.snu.nemo.runtime.executor.MetricMessageSender; -import edu.snu.nemo.runtime.executor.TaskExecutor; -import edu.snu.nemo.runtime.executor.TaskStateManager; -import edu.snu.nemo.runtime.executor.data.DataUtil; -import edu.snu.nemo.runtime.executor.datatransfer.DataTransferFactory; -import edu.snu.nemo.runtime.executor.datatransfer.InputReader; -import edu.snu.nemo.runtime.executor.datatransfer.OutputWriter; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; - -import java.util.*; -import java.util.concurrent.CompletableFuture; - -import static edu.snu.nemo.tests.runtime.RuntimeTestUtil.getRangedNumList; -import static org.junit.Assert.assertEquals; -import static org.mockito.ArgumentMatchers.anyInt; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Matchers.any; -import static org.mockito.Mockito.*; - -/** - * Tests {@link TaskExecutor}. - */ -@RunWith(PowerMockRunner.class) -@PrepareForTest({InputReader.class, OutputWriter.class, DataTransferFactory.class, - TaskStateManager.class, StageEdge.class}) -public final class TaskExecutorTest { - private static final int DATA_SIZE = 100; - private static final String CONTAINER_TYPE = "CONTAINER_TYPE"; - private static final int SOURCE_PARALLELISM = 5; - private List elements; - private Map<String, List<Object>> vertexIdToOutputData; - private DataTransferFactory dataTransferFactory; - private TaskStateManager taskStateManager; - private MetricMessageSender metricMessageSender; - - @Before - public void setUp() throws Exception { - elements = getRangedNumList(0, DATA_SIZE); - - // Mock a TaskStateManager. It accumulates the state change into a list. - taskStateManager = mock(TaskStateManager.class); - - // Mock a DataTransferFactory. - vertexIdToOutputData = new HashMap<>(); - dataTransferFactory = mock(DataTransferFactory.class); - when(dataTransferFactory.createReader(anyInt(), any(), any())).then(new InterStageReaderAnswer()); - when(dataTransferFactory.createWriter(any(), anyInt(), any(), any())).then(new WriterAnswer()); - - // Mock a MetricMessageSender. - metricMessageSender = mock(MetricMessageSender.class); - doNothing().when(metricMessageSender).send(anyString(), anyString()); - doNothing().when(metricMessageSender).close(); - } - - /** - * Test the {@link edu.snu.nemo.common.ir.vertex.SourceVertex} processing in {@link TaskExecutor}. 
- */ - @Test(timeout=5000) - public void testSourceVertex() throws Exception { - final IRVertex sourceIRVertex = new EmptyComponents.EmptySourceVertex("empty"); - final String stageId = RuntimeIdGenerator.generateStageId(0); - - final Readable readable = new Readable() { - @Override - public Iterable read() throws Exception { - return elements; - } - @Override - public List<String> getLocations() { - throw new UnsupportedOperationException(); - } - }; - final Map<String, Readable> vertexIdToReadable = new HashMap<>(); - vertexIdToReadable.put(sourceIRVertex.getId(), readable); - - final DAG<IRVertex, RuntimeEdge<IRVertex>> taskDag = - new DAGBuilder<IRVertex, RuntimeEdge<IRVertex>>().addVertex(sourceIRVertex).buildWithoutSourceSinkCheck(); - final StageEdge stageOutEdge = mock(StageEdge.class); - when(stageOutEdge.getSrcVertex()).thenReturn(sourceIRVertex); - final String taskId = RuntimeIdGenerator.generateTaskId(0, stageId); - final Task task = - new Task( - "testSourceVertex", - taskId, - 0, - CONTAINER_TYPE, - new byte[0], - Collections.emptyList(), - Collections.singletonList(stageOutEdge), - vertexIdToReadable); - - // Execute the task. - final TaskExecutor taskExecutor = new TaskExecutor( - task, taskDag, taskStateManager, dataTransferFactory, metricMessageSender); - taskExecutor.execute(); - - // Check the output. - assertEquals(100, vertexIdToOutputData.get(sourceIRVertex.getId()).size()); - assertEquals(elements.get(0), vertexIdToOutputData.get(sourceIRVertex.getId()).get(0)); - } - - /** - * Test the {@link edu.snu.nemo.common.ir.vertex.OperatorVertex} processing in {@link TaskExecutor}. - * - * The DAG of the task to test will looks like: - * operator task 1 -> operator task 2 - * - * The output data from upstream stage will be split - * according to source parallelism through {@link InterStageReaderAnswer}. - * Because of this, the operator task 1 will process multiple partitions and emit data in multiple times also. - * On the other hand, operator task 2 will receive the output data once and produce a single output. - */ - @Test(timeout=5000) - public void testOperatorVertex() throws Exception { - final IRVertex operatorIRVertex1 = new OperatorVertex(new SimpleTransform()); - final IRVertex operatorIRVertex2 = new OperatorVertex(new SimpleTransform()); - final String runtimeIREdgeId = "Runtime edge between operator tasks"; - - final String stageId = RuntimeIdGenerator.generateStageId(1); - - final Coder coder = Coder.DUMMY_CODER; - ExecutionPropertyMap edgeProperties = new ExecutionPropertyMap(runtimeIREdgeId); - edgeProperties.put(DataStoreProperty.of(DataStoreProperty.Value.MemoryStore)); - final DAG<IRVertex, RuntimeEdge<IRVertex>> taskDag = new DAGBuilder<IRVertex, RuntimeEdge<IRVertex>>() - .addVertex(operatorIRVertex1) - .addVertex(operatorIRVertex2) - .connectVertices(new RuntimeEdge<IRVertex>( - runtimeIREdgeId, edgeProperties, operatorIRVertex1, operatorIRVertex2, coder)) - .buildWithoutSourceSinkCheck(); - final String taskId = RuntimeIdGenerator.generateTaskId(0, stageId); - final StageEdge stageInEdge = mock(StageEdge.class); - when(stageInEdge.getDstVertex()).thenReturn(operatorIRVertex1); - final StageEdge stageOutEdge = mock(StageEdge.class); - when(stageOutEdge.getSrcVertex()).thenReturn(operatorIRVertex2); - final Task task = - new Task( - "testSourceVertex", - taskId, - 0, - CONTAINER_TYPE, - new byte[0], - Collections.singletonList(stageInEdge), - Collections.singletonList(stageOutEdge), - Collections.emptyMap()); - - // Execute the task. 
- final TaskExecutor taskExecutor = new TaskExecutor( - task, taskDag, taskStateManager, dataTransferFactory, metricMessageSender); - taskExecutor.execute(); - - // Check the output. - assertEquals(100, vertexIdToOutputData.get(operatorIRVertex2.getId()).size()); - } - - /** - * Represents the answer return an intra-stage {@link InputReader}, - * which will have a single iterable from the upstream task. - */ - private class IntraStageReaderAnswer implements Answer<InputReader> { - @Override - public InputReader answer(final InvocationOnMock invocationOnMock) throws Throwable { - // Read the data. - final List<CompletableFuture<DataUtil.IteratorWithNumBytes>> inputFutures = new ArrayList<>(); - inputFutures.add(CompletableFuture.completedFuture( - DataUtil.IteratorWithNumBytes.of(elements.iterator()))); - final InputReader inputReader = mock(InputReader.class); - when(inputReader.read()).thenReturn(inputFutures); - when(inputReader.isSideInputReader()).thenReturn(false); - when(inputReader.getSourceParallelism()).thenReturn(1); - return inputReader; - } - } - - /** - * Represents the answer return an inter-stage {@link InputReader}, - * which will have multiple iterable according to the source parallelism. - */ - private class InterStageReaderAnswer implements Answer<InputReader> { - @Override - public InputReader answer(final InvocationOnMock invocationOnMock) throws Throwable { - final List<CompletableFuture<DataUtil.IteratorWithNumBytes>> inputFutures = new ArrayList<>(SOURCE_PARALLELISM); - final int elementsPerSource = DATA_SIZE / SOURCE_PARALLELISM; - for (int i = 0; i < SOURCE_PARALLELISM; i++) { - inputFutures.add(CompletableFuture.completedFuture( - DataUtil.IteratorWithNumBytes.of(elements.subList(i * elementsPerSource, (i + 1) * elementsPerSource) - .iterator()))); - } - final InputReader inputReader = mock(InputReader.class); - when(inputReader.read()).thenReturn(inputFutures); - when(inputReader.isSideInputReader()).thenReturn(false); - when(inputReader.getSourceParallelism()).thenReturn(SOURCE_PARALLELISM); - return inputReader; - } - } - - /** - * Represents the answer return a {@link OutputWriter}, - * which will stores the data to the map between task id and output data. - */ - private class WriterAnswer implements Answer<OutputWriter> { - @Override - public OutputWriter answer(final InvocationOnMock invocationOnMock) throws Throwable { - final Object[] args = invocationOnMock.getArguments(); - final IRVertex vertex = (IRVertex) args[0]; - final OutputWriter outputWriter = mock(OutputWriter.class); - doAnswer(new Answer() { - @Override - public Object answer(final InvocationOnMock invocationOnMock) throws Throwable { - final Object[] args = invocationOnMock.getArguments(); - final Object dataToWrite = args[0]; - vertexIdToOutputData.computeIfAbsent(vertex.getId(), emptyTaskId -> new ArrayList<>()); - vertexIdToOutputData.get(vertex.getId()).add(dataToWrite); - return null; - } - }).when(outputWriter).write(any()); - return outputWriter; - } - } - - /** - * Simple {@link Transform} for testing. - * @param <T> input/output type. - */ - private class SimpleTransform<T> implements Transform<T, T> { - private OutputCollector<T> outputCollector; - - @Override - public void prepare(final Context context, final OutputCollector<T> outputCollector) { - this.outputCollector = outputCollector; - } - - @Override - public void onData(final Object element) { - outputCollector.emit((T) element); - } - - @Override - public void close() { - // Do nothing. 
- } - } -} ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services