johnyangk closed pull request #104: [NEMO-183] DAG-centric translation from Beam pipeline to IR DAG
URL: https://github.com/apache/incubator-nemo/pull/104
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
diff --git
a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/NemoPipelineRunner.java
b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/NemoPipelineRunner.java
index 3ced56545..2342dc28b 100644
---
a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/NemoPipelineRunner.java
+++
b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/NemoPipelineRunner.java
@@ -17,7 +17,8 @@
import edu.snu.nemo.client.JobLauncher;
import edu.snu.nemo.common.dag.DAG;
-import edu.snu.nemo.common.dag.DAGBuilder;
+import edu.snu.nemo.common.ir.edge.IREdge;
+import edu.snu.nemo.common.ir.vertex.IRVertex;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.PipelineRunner;
import org.apache.beam.sdk.options.PipelineOptions;
@@ -53,10 +54,11 @@ private NemoPipelineRunner(final NemoPipelineOptions
nemoPipelineOptions) {
* @return The result of the pipeline.
*/
public NemoPipelineResult run(final Pipeline pipeline) {
- final DAGBuilder builder = new DAGBuilder<>();
- final NemoPipelineVisitor nemoPipelineVisitor = new
NemoPipelineVisitor(builder, nemoPipelineOptions);
- pipeline.traverseTopologically(nemoPipelineVisitor);
- final DAG dag = builder.build();
+ final PipelineVisitor pipelineVisitor = new PipelineVisitor();
+ pipeline.traverseTopologically(pipelineVisitor);
+ final DAG<IRVertex, IREdge> dag =
PipelineTranslator.translate(pipelineVisitor.getConvertedPipeline(),
+ nemoPipelineOptions);
+
final NemoPipelineResult nemoPipelineResult = new NemoPipelineResult();
JobLauncher.launchDAG(dag);
return nemoPipelineResult;
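The change above replaces the old visitor-driven DAG construction (NemoPipelineVisitor plus a
shared DAGBuilder) with a two-step, DAG-centric flow: the pipeline is first captured as a
hierarchy of TransformVertex objects, and that hierarchy is then translated into the IR DAG.
A minimal sketch of the flow, using only classes that appear in this diff; the wrapper class
itself is hypothetical and is assumed to sit in edu.snu.nemo.compiler.frontend.beam so that
the frontend classes resolve without extra imports:

    import edu.snu.nemo.client.JobLauncher;
    import edu.snu.nemo.common.dag.DAG;
    import edu.snu.nemo.common.ir.edge.IREdge;
    import edu.snu.nemo.common.ir.vertex.IRVertex;
    import org.apache.beam.sdk.Pipeline;

    /** Hypothetical wrapper illustrating the translation flow introduced by this PR. */
    final class TranslationFlowSketch {
      private TranslationFlowSketch() {
      }

      static void runExample(final Pipeline pipeline, final NemoPipelineOptions options) {
        // 1. Walk the Beam transform hierarchy into a DAG of TransformVertex objects.
        final PipelineVisitor visitor = new PipelineVisitor();
        pipeline.traverseTopologically(visitor);

        // 2. Translate the captured hierarchy into the Nemo IR DAG.
        final DAG<IRVertex, IREdge> dag =
            PipelineTranslator.translate(visitor.getConvertedPipeline(), options);

        // 3. Launch the IR DAG, exactly as the runner did before this change.
        JobLauncher.launchDAG(dag);
      }
    }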
diff --git
a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/NemoPipelineVisitor.java
b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/NemoPipelineVisitor.java
deleted file mode 100644
index 38b07a3a5..000000000
---
a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/NemoPipelineVisitor.java
+++ /dev/null
@@ -1,304 +0,0 @@
-/*
- * Copyright (C) 2018 Seoul National University
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package edu.snu.nemo.compiler.frontend.beam;
-
-import edu.snu.nemo.common.Pair;
-import edu.snu.nemo.common.ir.edge.executionproperty.*;
-import edu.snu.nemo.common.ir.vertex.transform.Transform;
-import edu.snu.nemo.compiler.frontend.beam.coder.BeamDecoderFactory;
-import edu.snu.nemo.compiler.frontend.beam.coder.BeamEncoderFactory;
-import edu.snu.nemo.common.dag.DAGBuilder;
-import edu.snu.nemo.common.ir.edge.IREdge;
-import edu.snu.nemo.compiler.frontend.beam.source.BeamBoundedSourceVertex;
-import edu.snu.nemo.common.ir.vertex.IRVertex;
-import edu.snu.nemo.common.ir.vertex.LoopVertex;
-import edu.snu.nemo.common.ir.vertex.OperatorVertex;
-
-import edu.snu.nemo.compiler.frontend.beam.transform.*;
-
-import org.apache.beam.sdk.Pipeline;
-import org.apache.beam.sdk.coders.*;
-import org.apache.beam.sdk.io.Read;
-import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.runners.TransformHierarchy;
-import org.apache.beam.sdk.transforms.*;
-import org.apache.beam.sdk.transforms.windowing.Window;
-import org.apache.beam.sdk.values.PCollection;
-import org.apache.beam.sdk.values.PCollectionView;
-import org.apache.beam.sdk.values.PCollectionViews;
-import org.apache.beam.sdk.values.PValue;
-import org.apache.beam.sdk.values.TupleTag;
-
-import java.util.*;
-import java.util.stream.Collectors;
-
-/**
- * Visits every node in the beam dag to translate the BEAM program to the IR.
- */
-public final class NemoPipelineVisitor extends
Pipeline.PipelineVisitor.Defaults {
- private final DAGBuilder<IRVertex, IREdge> builder;
- private final Map<PValue, IRVertex> pValueToVertex;
- private final PipelineOptions options;
- // loopVertexStack keeps track of where the beam program is: whether it is
inside a composite transform or it is not.
- private final Stack<LoopVertex> loopVertexStack;
- private final Map<PValue, Pair<BeamEncoderFactory, BeamDecoderFactory>>
pValueToCoder;
- private final Map<IRVertex, Pair<BeamEncoderFactory, BeamDecoderFactory>>
sideInputCoder;
- private final Map<PValue, TupleTag> pValueToTag;
- private final Map<IRVertex, Set<PValue>> additionalInputs;
-
- /**
- * Constructor of the BEAM Visitor.
- *
- * @param builder DAGBuilder to build the DAG with.
- * @param options Pipeline options.
- */
- public NemoPipelineVisitor(final DAGBuilder<IRVertex, IREdge> builder, final
PipelineOptions options) {
- this.builder = builder;
- this.pValueToVertex = new HashMap<>();
- this.options = options;
- this.loopVertexStack = new Stack<>();
- this.pValueToCoder = new HashMap<>();
- this.sideInputCoder = new HashMap<>();
- this.pValueToTag = new HashMap<>();
- this.additionalInputs = new HashMap<>();
- }
-
- @Override
- public CompositeBehavior enterCompositeTransform(final
TransformHierarchy.Node beamNode) {
- if (beamNode.getTransform() instanceof LoopCompositeTransform) {
- final LoopVertex loopVertex = new LoopVertex(beamNode.getFullName());
- this.builder.addVertex(loopVertex, this.loopVertexStack);
- this.builder.removeVertex(loopVertex);
- this.loopVertexStack.push(loopVertex);
- }
- return CompositeBehavior.ENTER_TRANSFORM;
- }
-
- @Override
- public void leaveCompositeTransform(final TransformHierarchy.Node beamNode) {
- if (beamNode.getTransform() instanceof LoopCompositeTransform) {
- this.loopVertexStack.pop();
- }
- }
-
- @Override
- public void visitPrimitiveTransform(final TransformHierarchy.Node beamNode) {
-// Print if needed for development
-// LOG.info("visitp " + beamNode.getTransform());
- final IRVertex irVertex =
- convertToVertex(beamNode, builder, pValueToVertex, sideInputCoder,
pValueToTag, additionalInputs,
- options, loopVertexStack);
- beamNode.getOutputs().values().stream().filter(v -> v instanceof
PCollection).map(v -> (PCollection) v)
- .forEach(output -> pValueToCoder.put(output,
- Pair.of(new BeamEncoderFactory(output.getCoder()), new
BeamDecoderFactory(output.getCoder()))));
-
- beamNode.getOutputs().values().forEach(output ->
pValueToVertex.put(output, irVertex));
- final Set<PValue> additionalInputsForThisVertex =
additionalInputs.getOrDefault(irVertex, new HashSet<>());
- beamNode.getInputs().values().stream().filter(pValueToVertex::containsKey)
- .filter(pValue -> !additionalInputsForThisVertex.contains(pValue))
- .forEach(pValue -> {
- final boolean isAdditionalOutput = pValueToTag.containsKey(pValue);
- final IRVertex src = pValueToVertex.get(pValue);
- final IREdge edge = new IREdge(getEdgeCommunicationPattern(src,
irVertex), src, irVertex, false);
- final Pair<BeamEncoderFactory, BeamDecoderFactory> coderPair =
pValueToCoder.get(pValue);
- edge.setProperty(EncoderProperty.of(coderPair.left()));
- edge.setProperty(DecoderProperty.of(coderPair.right()));
- edge.setProperty(KeyExtractorProperty.of(new BeamKeyExtractor()));
- // Apply AdditionalOutputTatProperty to edges that corresponds to
additional outputs.
- if (isAdditionalOutput) {
-
edge.setProperty(AdditionalOutputTagProperty.of(pValueToTag.get(pValue).getId()));
- }
- this.builder.connectVertices(edge);
- });
- }
-
- /**
- * Convert Beam node to IR vertex.
- *
- * @param beamNode input beam node.
- * @param builder the DAG builder to add the vertex to.
- * @param pValueToVertex PValue to Vertex map.
- * @param sideInputCoder Side input EncoderFactory and DecoderFactory map.
- * @param pValueToTag PValue to Tag map.
- * @param additionalInputs additional inputs.
- * @param options pipeline options.
- * @param loopVertexStack Stack to get the current loop vertex that the
operator vertex will be assigned to.
- * @param <I> input type.
- * @param <O> output type.
- * @return newly created vertex.
- */
- private static <I, O> IRVertex
- convertToVertex(final TransformHierarchy.Node beamNode,
- final DAGBuilder<IRVertex, IREdge> builder,
- final Map<PValue, IRVertex> pValueToVertex,
- final Map<IRVertex, Pair<BeamEncoderFactory,
BeamDecoderFactory>> sideInputCoder,
- final Map<PValue, TupleTag> pValueToTag,
- final Map<IRVertex, Set<PValue>> additionalInputs,
- final PipelineOptions options,
- final Stack<LoopVertex> loopVertexStack) {
- final PTransform beamTransform = beamNode.getTransform();
- final IRVertex irVertex;
- if (beamTransform instanceof Read.Bounded) {
- final Read.Bounded<O> read = (Read.Bounded) beamTransform;
- irVertex = new BeamBoundedSourceVertex<>(read.getSource());
- builder.addVertex(irVertex, loopVertexStack);
- } else if (beamTransform instanceof GroupByKey) {
- irVertex = new OperatorVertex(new GroupByKeyTransform());
- builder.addVertex(irVertex, loopVertexStack);
- } else if (beamTransform instanceof View.CreatePCollectionView) {
- final View.CreatePCollectionView view = (View.CreatePCollectionView)
beamTransform;
- final CreateViewTransform transform = new
CreateViewTransform(view.getView());
- irVertex = new OperatorVertex(transform);
- pValueToVertex.put(view.getView(), irVertex);
- builder.addVertex(irVertex, loopVertexStack);
- // Coders for outgoing edges in CreateViewTransform.
- // Since outgoing PValues for CreateViewTransform is PCollectionView,
- // we cannot use PCollection::getEncoderFactory to obtain coders.
- final Coder beamInputCoder = beamNode.getInputs().values().stream()
- .filter(v -> v instanceof PCollection).map(v -> (PCollection)
v).findFirst()
- .orElseThrow(() -> new RuntimeException("No inputs provided to " +
beamNode.getFullName())).getCoder();
- beamNode.getOutputs().values().stream()
- .forEach(output ->
- sideInputCoder.put(irVertex,
getCoderPairForView(view.getView().getViewFn(), beamInputCoder)));
- } else if (beamTransform instanceof Window) {
- final Window<I> window = (Window<I>) beamTransform;
- final WindowTransform transform = new
WindowTransform(window.getWindowFn());
- irVertex = new OperatorVertex(transform);
- builder.addVertex(irVertex, loopVertexStack);
- } else if (beamTransform instanceof Window.Assign) {
- final Window.Assign<I> window = (Window.Assign<I>) beamTransform;
- final WindowTransform transform = new
WindowTransform(window.getWindowFn());
- irVertex = new OperatorVertex(transform);
- builder.addVertex(irVertex, loopVertexStack);
- } else if (beamTransform instanceof ParDo.SingleOutput) {
- final ParDo.SingleOutput<I, O> parDo = (ParDo.SingleOutput<I, O>)
beamTransform;
- final DoTransform transform = new DoTransform(parDo.getFn(), options);
- irVertex = new OperatorVertex(transform);
- additionalInputs.put(irVertex,
parDo.getAdditionalInputs().values().stream().collect(Collectors.toSet()));
- builder.addVertex(irVertex, loopVertexStack);
- connectSideInputs(builder, parDo.getSideInputs(), pValueToVertex,
sideInputCoder, irVertex);
- } else if (beamTransform instanceof ParDo.MultiOutput) {
- final ParDo.MultiOutput<I, O> parDo = (ParDo.MultiOutput<I, O>)
beamTransform;
- final DoTransform transform = new DoTransform(parDo.getFn(), options);
- irVertex = new OperatorVertex(transform);
- additionalInputs.put(irVertex,
parDo.getAdditionalInputs().values().stream().collect(Collectors.toSet()));
- if (parDo.getAdditionalOutputTags().size() > 0) {
- // Store PValue to additional tag id mapping.
- beamNode.getOutputs().entrySet().stream()
- .filter(kv -> !kv.getKey().equals(parDo.getMainOutputTag()))
- .forEach(kv -> pValueToTag.put(kv.getValue(), kv.getKey()));
- }
- builder.addVertex(irVertex, loopVertexStack);
- connectSideInputs(builder, parDo.getSideInputs(), pValueToVertex,
sideInputCoder, irVertex);
- } else if (beamTransform instanceof Flatten.PCollections) {
- irVertex = new OperatorVertex(new FlattenTransform());
- builder.addVertex(irVertex, loopVertexStack);
- } else {
- throw new UnsupportedOperationException(beamTransform.toString());
- }
- return irVertex;
- }
-
- /**
- * Connect side inputs to the vertex.
- *
- * @param builder the DAG builder to add the vertex to.
- * @param sideInputs side inputs.
- * @param pValueToVertex PValue to Vertex map.
- * @param coderMap Side input to Encoder/Decoder factory map.
- * @param irVertex wrapper for a user operation in the IR. (Where the
side input is headed to)
- */
- private static void connectSideInputs(final DAGBuilder<IRVertex, IREdge>
builder,
- final List<PCollectionView<?>>
sideInputs,
- final Map<PValue, IRVertex>
pValueToVertex,
- final Map<IRVertex,
Pair<BeamEncoderFactory, BeamDecoderFactory>> coderMap,
- final IRVertex irVertex) {
- sideInputs.stream().filter(pValueToVertex::containsKey)
- .forEach(pValue -> {
- final IRVertex src = pValueToVertex.get(pValue);
- final IREdge edge = new IREdge(getEdgeCommunicationPattern(src,
irVertex),
- src, irVertex, true);
- final Pair<BeamEncoderFactory, BeamDecoderFactory> coder =
coderMap.get(src);
- edge.setProperty(EncoderProperty.of(coder.left()));
- edge.setProperty(DecoderProperty.of(coder.right()));
- edge.setProperty(KeyExtractorProperty.of(new BeamKeyExtractor()));
- builder.connectVertices(edge);
- });
- }
-
- /**
- * Get appropriate encoder and decoder pair for {@link PCollectionView}.
- *
- * @param viewFn {@link ViewFn} from the corresponding {@link
View.CreatePCollectionView} transform
- * @param beamInputCoder Beam {@link Coder} for input value to {@link
View.CreatePCollectionView}
- * @return appropriate pair of {@link BeamEncoderFactory} and {@link
BeamDecoderFactory}
- */
- private static Pair<BeamEncoderFactory, BeamDecoderFactory>
getCoderPairForView(final ViewFn viewFn,
-
final Coder beamInputCoder) {
- final Coder beamOutputCoder;
- if (viewFn instanceof PCollectionViews.IterableViewFn) {
- beamOutputCoder = IterableCoder.of(beamInputCoder);
- } else if (viewFn instanceof PCollectionViews.ListViewFn) {
- beamOutputCoder = ListCoder.of(beamInputCoder);
- } else if (viewFn instanceof PCollectionViews.MapViewFn) {
- final KvCoder inputCoder = (KvCoder) beamInputCoder;
- beamOutputCoder = MapCoder.of(inputCoder.getKeyCoder(),
inputCoder.getValueCoder());
- } else if (viewFn instanceof PCollectionViews.MultimapViewFn) {
- final KvCoder inputCoder = (KvCoder) beamInputCoder;
- beamOutputCoder = MapCoder.of(inputCoder.getKeyCoder(),
IterableCoder.of(inputCoder.getValueCoder()));
- } else if (viewFn instanceof PCollectionViews.SingletonViewFn) {
- beamOutputCoder = beamInputCoder;
- } else {
- throw new UnsupportedOperationException("Unsupported viewFn: " +
viewFn.getClass());
- }
- return Pair.of(new BeamEncoderFactory(beamOutputCoder), new
BeamDecoderFactory(beamOutputCoder));
- }
-
- /**
- * Get the edge type for the src, dst vertex.
- *
- * @param src source vertex.
- * @param dst destination vertex.
- * @return the appropriate edge type.
- */
- private static CommunicationPatternProperty.Value
getEdgeCommunicationPattern(final IRVertex src,
-
final IRVertex dst) {
- final Class<?> constructUnionTableFn;
- try {
- constructUnionTableFn =
Class.forName("org.apache.beam.sdk.transforms.join.CoGroupByKey$ConstructUnionTableFn");
- } catch (final ClassNotFoundException e) {
- throw new RuntimeException(e);
- }
-
- final Transform srcTransform = src instanceof OperatorVertex ?
((OperatorVertex) src).getTransform() : null;
- final Transform dstTransform = dst instanceof OperatorVertex ?
((OperatorVertex) dst).getTransform() : null;
- final DoFn srcDoFn = srcTransform instanceof DoTransform ? ((DoTransform)
srcTransform).getDoFn() : null;
-
- if (srcDoFn != null && srcDoFn.getClass().equals(constructUnionTableFn)) {
- return CommunicationPatternProperty.Value.Shuffle;
- }
- if (srcTransform instanceof FlattenTransform) {
- return CommunicationPatternProperty.Value.OneToOne;
- }
- if (dstTransform instanceof GroupByKeyTransform) {
- return CommunicationPatternProperty.Value.Shuffle;
- }
- if (dstTransform instanceof CreateViewTransform) {
- return CommunicationPatternProperty.Value.BroadCast;
- }
- return CommunicationPatternProperty.Value.OneToOne;
- }
-}
diff --git
a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/PipelineTranslator.java
b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/PipelineTranslator.java
new file mode 100644
index 000000000..7e34ca234
--- /dev/null
+++
b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/PipelineTranslator.java
@@ -0,0 +1,544 @@
+/*
+ * Copyright (C) 2018 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.compiler.frontend.beam;
+
+import edu.snu.nemo.common.dag.DAG;
+import edu.snu.nemo.common.dag.DAGBuilder;
+import edu.snu.nemo.common.ir.edge.IREdge;
+import edu.snu.nemo.common.ir.edge.executionproperty.*;
+import edu.snu.nemo.common.ir.vertex.IRVertex;
+import edu.snu.nemo.common.ir.vertex.LoopVertex;
+import edu.snu.nemo.common.ir.vertex.OperatorVertex;
+import edu.snu.nemo.common.ir.vertex.transform.Transform;
+import edu.snu.nemo.compiler.frontend.beam.PipelineVisitor.*;
+import edu.snu.nemo.compiler.frontend.beam.coder.BeamDecoderFactory;
+import edu.snu.nemo.compiler.frontend.beam.coder.BeamEncoderFactory;
+import edu.snu.nemo.compiler.frontend.beam.source.BeamBoundedSourceVertex;
+import edu.snu.nemo.compiler.frontend.beam.transform.*;
+import org.apache.beam.sdk.coders.*;
+import org.apache.beam.sdk.io.Read;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.transforms.*;
+import org.apache.beam.sdk.transforms.windowing.Window;
+import org.apache.beam.sdk.transforms.windowing.WindowFn;
+import org.apache.beam.sdk.values.*;
+
+import java.lang.annotation.*;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Stack;
+import java.util.function.BiFunction;
+
+/**
+ * Converts a DAG of Beam transforms into a Nemo IR DAG.
+ * For a {@link PrimitiveTransformVertex}, it defines the mapping to the
corresponding {@link IRVertex}.
+ * For a {@link CompositeTransformVertex}, it defines how to set up and clear the
{@link TranslationContext}
+ * before translating the inner Beam transform hierarchy.
+ */
+public final class PipelineTranslator
+ implements BiFunction<CompositeTransformVertex, PipelineOptions,
DAG<IRVertex, IREdge>> {
+
+ private static final PipelineTranslator INSTANCE = new PipelineTranslator();
+
+ private final Map<Class<? extends PTransform>, Method>
primitiveTransformToTranslator = new HashMap<>();
+ private final Map<Class<? extends PTransform>, Method>
compositeTransformToTranslator = new HashMap<>();
+
+ /**
+ * Static translator method.
+ * @param pipeline Top-level Beam transform hierarchy, usually given by
{@link PipelineVisitor}
+ * @param pipelineOptions {@link PipelineOptions}
+ * @return Nemo IR DAG
+ */
+ public static DAG<IRVertex, IREdge> translate(final CompositeTransformVertex
pipeline,
+ final PipelineOptions
pipelineOptions) {
+ return INSTANCE.apply(pipeline, pipelineOptions);
+ }
+
+ /**
+ * Creates the translator, while building a map between {@link PTransform}s
and the corresponding translators.
+ */
+ private PipelineTranslator() {
+ for (final Method translator : getClass().getDeclaredMethods()) {
+ final PrimitiveTransformTranslator primitive =
translator.getAnnotation(PrimitiveTransformTranslator.class);
+ final CompositeTransformTranslator composite =
translator.getAnnotation(CompositeTransformTranslator.class);
+ if (primitive != null) {
+ for (final Class<? extends PTransform> transform : primitive.value()) {
+ if (primitiveTransformToTranslator.containsKey(transform)) {
+ throw new RuntimeException(String.format("Translator for primitive
transform %s is"
+ + "already registered: %s", transform,
primitiveTransformToTranslator.get(transform)));
+ }
+ primitiveTransformToTranslator.put(transform, translator);
+ }
+ }
+ if (composite != null) {
+ for (final Class<? extends PTransform> transform : composite.value()) {
+ if (compositeTransformToTranslator.containsKey(transform)) {
+ throw new RuntimeException(String.format("Translator for composite
transform %s is"
+ + "already registered: %s", transform,
compositeTransformToTranslator.get(transform)));
+ }
+ compositeTransformToTranslator.put(transform, translator);
+ }
+ }
+ }
+ }
+
+ @PrimitiveTransformTranslator(Read.Bounded.class)
+ private static void boundedReadTranslator(final TranslationContext ctx,
+ final PrimitiveTransformVertex
transformVertex,
+ final Read.Bounded<?> transform) {
+ final IRVertex vertex = new
BeamBoundedSourceVertex<>(transform.getSource());
+ ctx.addVertex(vertex);
+ transformVertex.getNode().getInputs().values().forEach(input ->
ctx.addEdgeTo(vertex, input, false));
+ transformVertex.getNode().getOutputs().values().forEach(output ->
ctx.registerMainOutputFrom(vertex, output));
+ }
+
+ @PrimitiveTransformTranslator(ParDo.SingleOutput.class)
+ private static void parDoSingleOutputTranslator(final TranslationContext ctx,
+ final
PrimitiveTransformVertex transformVertex,
+ final ParDo.SingleOutput<?,
?> transform) {
+ final DoTransform doTransform = new DoTransform(transform.getFn(),
ctx.pipelineOptions);
+ final IRVertex vertex = new OperatorVertex(doTransform);
+ ctx.addVertex(vertex);
+ transformVertex.getNode().getInputs().values().stream()
+ .filter(input ->
!transform.getAdditionalInputs().values().contains(input))
+ .forEach(input -> ctx.addEdgeTo(vertex, input, false));
+ transform.getSideInputs().forEach(input -> ctx.addEdgeTo(vertex, input,
true));
+ transformVertex.getNode().getOutputs().values().forEach(output ->
ctx.registerMainOutputFrom(vertex, output));
+ }
+
+ @PrimitiveTransformTranslator(ParDo.MultiOutput.class)
+ private static void parDoMultiOutputTranslator(final TranslationContext ctx,
+ final
PrimitiveTransformVertex transformVertex,
+ final ParDo.MultiOutput<?, ?>
transform) {
+ final DoTransform doTransform = new DoTransform(transform.getFn(),
ctx.pipelineOptions);
+ final IRVertex vertex = new OperatorVertex(doTransform);
+ ctx.addVertex(vertex);
+ transformVertex.getNode().getInputs().values().stream()
+ .filter(input ->
!transform.getAdditionalInputs().values().contains(input))
+ .forEach(input -> ctx.addEdgeTo(vertex, input, false));
+ transform.getSideInputs().forEach(input -> ctx.addEdgeTo(vertex, input,
true));
+ transformVertex.getNode().getOutputs().entrySet().stream()
+ .filter(pValueWithTupleTag ->
pValueWithTupleTag.getKey().equals(transform.getMainOutputTag()))
+ .forEach(pValueWithTupleTag -> ctx.registerMainOutputFrom(vertex,
pValueWithTupleTag.getValue()));
+ transformVertex.getNode().getOutputs().entrySet().stream()
+ .filter(pValueWithTupleTag ->
!pValueWithTupleTag.getKey().equals(transform.getMainOutputTag()))
+ .forEach(pValueWithTupleTag ->
ctx.registerAdditionalOutputFrom(vertex, pValueWithTupleTag.getValue(),
+ pValueWithTupleTag.getKey()));
+ }
+
+ @PrimitiveTransformTranslator(GroupByKey.class)
+ private static void groupByKeyTranslator(final TranslationContext ctx,
+ final PrimitiveTransformVertex
transformVertex,
+ final GroupByKey<?, ?> transform) {
+ final IRVertex vertex = new OperatorVertex(new GroupByKeyTransform());
+ ctx.addVertex(vertex);
+ transformVertex.getNode().getInputs().values().forEach(input ->
ctx.addEdgeTo(vertex, input, false));
+ transformVertex.getNode().getOutputs().values().forEach(output ->
ctx.registerMainOutputFrom(vertex, output));
+ }
+
+ @PrimitiveTransformTranslator({Window.class, Window.Assign.class})
+ private static void windowTranslator(final TranslationContext ctx,
+ final PrimitiveTransformVertex
transformVertex,
+ final PTransform<?, ?> transform) {
+ final WindowFn windowFn;
+ if (transform instanceof Window) {
+ windowFn = ((Window) transform).getWindowFn();
+ } else if (transform instanceof Window.Assign) {
+ windowFn = ((Window.Assign) transform).getWindowFn();
+ } else {
+ throw new UnsupportedOperationException(String.format("%s is not
supported", transform));
+ }
+ final IRVertex vertex = new OperatorVertex(new WindowTransform(windowFn));
+ ctx.addVertex(vertex);
+ transformVertex.getNode().getInputs().values().forEach(input ->
ctx.addEdgeTo(vertex, input, false));
+ transformVertex.getNode().getOutputs().values().forEach(output ->
ctx.registerMainOutputFrom(vertex, output));
+ }
+
+ @PrimitiveTransformTranslator(View.CreatePCollectionView.class)
+ private static void createPCollectionViewTranslator(final TranslationContext
ctx,
+ final
PrimitiveTransformVertex transformVertex,
+ final
View.CreatePCollectionView<?, ?> transform) {
+ final IRVertex vertex = new OperatorVertex(new
CreateViewTransform<>(transform.getView()));
+ ctx.addVertex(vertex);
+ transformVertex.getNode().getInputs().values().forEach(input ->
ctx.addEdgeTo(vertex, input, false));
+ ctx.registerMainOutputFrom(vertex, transform.getView());
+ transformVertex.getNode().getOutputs().values().forEach(output ->
ctx.registerMainOutputFrom(vertex, output));
+ }
+
+ @PrimitiveTransformTranslator(Flatten.PCollections.class)
+ private static void flattenTranslator(final TranslationContext ctx,
+ final PrimitiveTransformVertex
transformVertex,
+ final Flatten.PCollections<?>
transform) {
+ final IRVertex vertex = new OperatorVertex(new FlattenTransform());
+ ctx.addVertex(vertex);
+ transformVertex.getNode().getInputs().values().forEach(input ->
ctx.addEdgeTo(vertex, input, false));
+ transformVertex.getNode().getOutputs().values().forEach(output ->
ctx.registerMainOutputFrom(vertex, output));
+ }
+
+ /**
+ * Default translator for CompositeTransforms. Translates inner DAG without
modifying {@link TranslationContext}.
+ *
+ * @param ctx provides translation context
+ * @param transformVertex the given CompositeTransform to translate
+ * @param transform transform which can be obtained from {@code
transformVertex}
+ */
+ @CompositeTransformTranslator(PTransform.class)
+ private static void topologicalTranslator(final TranslationContext ctx,
+ final CompositeTransformVertex
transformVertex,
+ final PTransform<?, ?> transform) {
+ transformVertex.getDAG().topologicalDo(ctx::translate);
+ }
+
+ /**
+ * Translator for Combine transform. Implements local combining before
shuffling key-value pairs.
+ *
+ * @param ctx provides translation context
+ * @param transformVertex the given CompositeTransform to translate
+ * @param transform transform which can be obtained from {@code
transformVertex}
+ */
+ @CompositeTransformTranslator({Combine.Globally.class, Combine.PerKey.class,
Combine.GroupedValues.class})
+ private static void combineTranslator(final TranslationContext ctx,
+ final CompositeTransformVertex
transformVertex,
+ final PTransform<?, ?> transform) {
+ final List<TransformVertex> topologicalOrdering =
transformVertex.getDAG().getTopologicalSort();
+ final TransformVertex first = topologicalOrdering.get(0);
+ final TransformVertex last =
topologicalOrdering.get(topologicalOrdering.size() - 1);
+
+ if (first.getNode().getTransform() instanceof GroupByKey) {
+ // Translate the given CompositeTransform under OneToOneEdge-enforced
context.
+ final TranslationContext oneToOneEdgeContext = new
TranslationContext(ctx,
+ OneToOneCommunicationPatternSelector.INSTANCE);
+ transformVertex.getDAG().topologicalDo(oneToOneEdgeContext::translate);
+
+ // Attempt to translate the CompositeTransform again.
+ // Add GroupByKey, which is the first transform in the given
CompositeTransform.
+ // Make sure it consumes the output from the last vertex in
OneToOneEdge-translated hierarchy.
+ final IRVertex groupByKey = new OperatorVertex(new
GroupByKeyTransform());
+ ctx.addVertex(groupByKey);
+ last.getNode().getOutputs().values().forEach(outputFromCombiner
+ -> ctx.addEdgeTo(groupByKey, outputFromCombiner, false));
+ first.getNode().getOutputs().values()
+ .forEach(outputFromGroupByKey ->
ctx.registerMainOutputFrom(groupByKey, outputFromGroupByKey));
+
+ // Translate the remaining vertices.
+ topologicalOrdering.stream().skip(1).forEach(ctx::translate);
+ } else {
+ transformVertex.getDAG().topologicalDo(ctx::translate);
+ }
+ }
+
+ /**
+ * Pushes the loop vertex to the stack before translating the inner DAG, and
pops it after the translation.
+ *
+ * @param ctx provides translation context
+ * @param transformVertex the given CompositeTransform to translate
+ * @param transform transform which can be obtained from {@code
transformVertex}
+ */
+ @CompositeTransformTranslator(LoopCompositeTransform.class)
+ private static void loopTranslator(final TranslationContext ctx,
+ final CompositeTransformVertex
transformVertex,
+ final LoopCompositeTransform<?, ?>
transform) {
+ final LoopVertex loopVertex = new
LoopVertex(transformVertex.getNode().getFullName());
+ ctx.builder.addVertex(loopVertex, ctx.loopVertexStack);
+ ctx.builder.removeVertex(loopVertex);
+ ctx.loopVertexStack.push(loopVertex);
+ topologicalTranslator(ctx, transformVertex, transform);
+ ctx.loopVertexStack.pop();
+ }
+
+ @Override
+ public DAG<IRVertex, IREdge> apply(final CompositeTransformVertex pipeline,
final PipelineOptions pipelineOptions) {
+ final TranslationContext ctx = new TranslationContext(pipeline,
primitiveTransformToTranslator,
+ compositeTransformToTranslator,
DefaultCommunicationPatternSelector.INSTANCE, pipelineOptions);
+ ctx.translate(pipeline);
+ return ctx.builder.build();
+ }
+
+ /**
+ * Annotates translator for PrimitiveTransform.
+ */
+ @Target(ElementType.METHOD)
+ @Retention(RetentionPolicy.RUNTIME)
+ private @interface PrimitiveTransformTranslator {
+ Class<? extends PTransform>[] value();
+ }
+
+ /**
+ * Annotates translator for CompositeTransform.
+ */
+ @Target(ElementType.METHOD)
+ @Retention(RetentionPolicy.RUNTIME)
+ private @interface CompositeTransformTranslator {
+ Class<? extends PTransform>[] value();
+ }
+
+ /**
+ * Translation context.
+ */
+ private static final class TranslationContext {
+ private final CompositeTransformVertex pipeline;
+ private final PipelineOptions pipelineOptions;
+ private final DAGBuilder<IRVertex, IREdge> builder;
+ private final Map<PValue, IRVertex> pValueToProducer;
+ private final Map<PValue, TupleTag<?>> pValueToTag;
+ private final Stack<LoopVertex> loopVertexStack;
+ private final BiFunction<IRVertex, IRVertex,
CommunicationPatternProperty.Value> communicationPatternSelector;
+
+ private final Map<Class<? extends PTransform>, Method>
primitiveTransformToTranslator;
+ private final Map<Class<? extends PTransform>, Method>
compositeTransformToTranslator;
+
+ /**
+ * @param pipeline the pipeline to translate
+ * @param primitiveTransformToTranslator provides translators for
PrimitiveTransform
+ * @param compositeTransformToTranslator provides translators for
CompositeTransform
+ * @param selector provides {@link CommunicationPatternProperty.Value} for
IR edges
+ * @param pipelineOptions {@link PipelineOptions}
+ */
+ private TranslationContext(final CompositeTransformVertex pipeline,
+ final Map<Class<? extends PTransform>, Method>
primitiveTransformToTranslator,
+ final Map<Class<? extends PTransform>, Method>
compositeTransformToTranslator,
+ final BiFunction<IRVertex, IRVertex,
CommunicationPatternProperty.Value> selector,
+ final PipelineOptions pipelineOptions) {
+ this.pipeline = pipeline;
+ this.builder = new DAGBuilder<>();
+ this.pValueToProducer = new HashMap<>();
+ this.pValueToTag = new HashMap<>();
+ this.loopVertexStack = new Stack<>();
+ this.primitiveTransformToTranslator = primitiveTransformToTranslator;
+ this.compositeTransformToTranslator = compositeTransformToTranslator;
+ this.communicationPatternSelector = selector;
+ this.pipelineOptions = pipelineOptions;
+ }
+
+ /**
+ * Copy constructor, except for setting different
CommunicationPatternProperty selector.
+ *
+ * @param ctx the original {@link TranslationContext}
+ * @param selector provides {@link CommunicationPatternProperty.Value} for
IR edges
+ */
+ private TranslationContext(final TranslationContext ctx,
+ final BiFunction<IRVertex, IRVertex,
CommunicationPatternProperty.Value> selector) {
+ this.pipeline = ctx.pipeline;
+ this.pipelineOptions = ctx.pipelineOptions;
+ this.builder = ctx.builder;
+ this.pValueToProducer = ctx.pValueToProducer;
+ this.pValueToTag = ctx.pValueToTag;
+ this.loopVertexStack = ctx.loopVertexStack;
+ this.primitiveTransformToTranslator = ctx.primitiveTransformToTranslator;
+ this.compositeTransformToTranslator = ctx.compositeTransformToTranslator;
+
+ this.communicationPatternSelector = selector;
+ }
+
+ /**
+ * Selects appropriate translator to translate the given hierarchy.
+ *
+ * @param transformVertex the Beam transform hierarchy to translate
+ */
+ private void translate(final TransformVertex transformVertex) {
+ final boolean isComposite = transformVertex instanceof
CompositeTransformVertex;
+ final PTransform<?, ?> transform =
transformVertex.getNode().getTransform();
+ if (transform == null) {
+ // root node
+ topologicalTranslator(this, (CompositeTransformVertex)
transformVertex, null);
+ return;
+ }
+
+ Class<?> clazz = transform.getClass();
+ while (true) {
+ final Method translator = (isComposite ?
compositeTransformToTranslator : primitiveTransformToTranslator)
+ .get(clazz);
+ if (translator == null) {
+ if (clazz.getSuperclass() != null) {
+ clazz = clazz.getSuperclass();
+ continue;
+ }
+ throw new UnsupportedOperationException(String.format("%s transform
%s is not supported",
+ isComposite ? "Composite" : "Primitive",
transform.getClass().getCanonicalName()));
+ } else {
+ try {
+ translator.setAccessible(true);
+ translator.invoke(null, this, transformVertex, transform);
+ break;
+ } catch (final IllegalAccessException e) {
+ throw new RuntimeException(e);
+ } catch (final InvocationTargetException | RuntimeException e) {
+ throw new RuntimeException(String.format(
+ "Translator %s have failed to translate %s", translator,
transform), e);
+ }
+ }
+ }
+ }
+
+ /**
+ * Add IR vertex to the builder.
+ *
+ * @param vertex IR vertex to add
+ */
+ private void addVertex(final IRVertex vertex) {
+ builder.addVertex(vertex, loopVertexStack);
+ }
+
+ /**
+ * Add IR edge to the builder.
+ *
+ * @param dst the destination IR vertex.
+ * @param input the {@link PValue} {@code dst} consumes
+ * @param isSideInput whether it is sideInput or not.
+ */
+ private void addEdgeTo(final IRVertex dst, final PValue input, final
boolean isSideInput) {
+ final IRVertex src = pValueToProducer.get(input);
+ if (src == null) {
+ try {
+ throw new RuntimeException(String.format("Cannot find a vertex that
emits pValue %s, "
+ + "while PTransform %s is known to produce it.", input,
pipeline.getPrimitiveProducerOf(input)));
+ } catch (final RuntimeException e) {
+ throw new RuntimeException(String.format("Cannot find a vertex that
emits pValue %s, "
+ + "and the corresponding PTransform was not found", input));
+ }
+ }
+ final CommunicationPatternProperty.Value communicationPattern =
communicationPatternSelector.apply(src, dst);
+ if (communicationPattern == null) {
+ throw new RuntimeException(String.format("%s have failed to determine
communication pattern "
+ + "for an edge from %s to %s", communicationPatternSelector, src,
dst));
+ }
+ final IREdge edge = new IREdge(communicationPattern, src, dst,
isSideInput);
+ final Coder<?> coder;
+ if (input instanceof PCollection) {
+ coder = ((PCollection) input).getCoder();
+ } else if (input instanceof PCollectionView) {
+ coder = getCoderForView((PCollectionView) input);
+ } else {
+ coder = null;
+ }
+ if (coder == null) {
+ throw new RuntimeException(String.format("While adding an edge from
%s to %s, the coder for PValue %s cannot "
+ + "be determined", src, dst, input));
+ }
+ edge.setProperty(EncoderProperty.of(new BeamEncoderFactory<>(coder)));
+ edge.setProperty(DecoderProperty.of(new BeamDecoderFactory<>(coder)));
+ if (pValueToTag.containsKey(input)) {
+
edge.setProperty(AdditionalOutputTagProperty.of(pValueToTag.get(input).getId()));
+ }
+ edge.setProperty(KeyExtractorProperty.of(new BeamKeyExtractor()));
+ builder.connectVertices(edge);
+ }
+
+ /**
+ * Registers a {@link PValue} as a main output from the specified {@link
IRVertex}.
+ *
+ * @param irVertex the IR vertex
+ * @param output the {@link PValue} {@code irVertex} emits as main output
+ */
+ private void registerMainOutputFrom(final IRVertex irVertex, final PValue
output) {
+ pValueToProducer.put(output, irVertex);
+ }
+
+ /**
+ * Registers a {@link PValue} as an additional output from the specified
{@link IRVertex}.
+ *
+ * @param irVertex the IR vertex
+ * @param output the {@link PValue} {@code irVertex} emits as additional
output
+ * @param tag the {@link TupleTag} associated with this additional output
+ */
+ private void registerAdditionalOutputFrom(final IRVertex irVertex, final
PValue output, final TupleTag<?> tag) {
+ pValueToTag.put(output, tag);
+ pValueToProducer.put(output, irVertex);
+ }
+
+ /**
+ * Get appropriate coder for {@link PCollectionView}.
+ *
+ * @param view {@link PCollectionView} from the corresponding {@link
View.CreatePCollectionView} transform
+ * @return appropriate {@link Coder} for {@link PCollectionView}
+ */
+ private Coder<?> getCoderForView(final PCollectionView view) {
+ final PrimitiveTransformVertex src =
pipeline.getPrimitiveProducerOf(view);
+ final Coder<?> baseCoder = src.getNode().getInputs().values().stream()
+ .filter(v -> v instanceof PCollection).map(v -> (PCollection)
v).findFirst()
+ .orElseThrow(() -> new RuntimeException(String.format("No incoming
PCollection to %s", src)))
+ .getCoder();
+ final ViewFn viewFn = view.getViewFn();
+ if (viewFn instanceof PCollectionViews.IterableViewFn) {
+ return IterableCoder.of(baseCoder);
+ } else if (viewFn instanceof PCollectionViews.ListViewFn) {
+ return ListCoder.of(baseCoder);
+ } else if (viewFn instanceof PCollectionViews.MapViewFn) {
+ final KvCoder<?, ?> inputCoder = (KvCoder) baseCoder;
+ return MapCoder.of(inputCoder.getKeyCoder(),
inputCoder.getValueCoder());
+ } else if (viewFn instanceof PCollectionViews.MultimapViewFn) {
+ final KvCoder<?, ?> inputCoder = (KvCoder) baseCoder;
+ return MapCoder.of(inputCoder.getKeyCoder(),
IterableCoder.of(inputCoder.getValueCoder()));
+ } else if (viewFn instanceof PCollectionViews.SingletonViewFn) {
+ return baseCoder;
+ } else {
+ throw new UnsupportedOperationException(String.format("Unsupported
viewFn %s", viewFn.getClass()));
+ }
+ }
+ }
+
+ /**
+ * Default implementation for {@link CommunicationPatternProperty.Value}
selector.
+ */
+ private static final class DefaultCommunicationPatternSelector
+ implements BiFunction<IRVertex, IRVertex,
CommunicationPatternProperty.Value> {
+
+ private static final DefaultCommunicationPatternSelector INSTANCE = new
DefaultCommunicationPatternSelector();
+
+ @Override
+ public CommunicationPatternProperty.Value apply(final IRVertex src, final
IRVertex dst) {
+ final Class<?> constructUnionTableFn;
+ try {
+ constructUnionTableFn =
Class.forName("org.apache.beam.sdk.transforms.join.CoGroupByKey$ConstructUnionTableFn");
+ } catch (final ClassNotFoundException e) {
+ throw new RuntimeException(e);
+ }
+
+ final Transform srcTransform = src instanceof OperatorVertex ?
((OperatorVertex) src).getTransform() : null;
+ final Transform dstTransform = dst instanceof OperatorVertex ?
((OperatorVertex) dst).getTransform() : null;
+ final DoFn srcDoFn = srcTransform instanceof DoTransform ?
((DoTransform) srcTransform).getDoFn() : null;
+
+ if (srcDoFn != null && srcDoFn.getClass().equals(constructUnionTableFn))
{
+ return CommunicationPatternProperty.Value.Shuffle;
+ }
+ if (srcTransform instanceof FlattenTransform) {
+ return CommunicationPatternProperty.Value.OneToOne;
+ }
+ if (dstTransform instanceof GroupByKeyTransform) {
+ return CommunicationPatternProperty.Value.Shuffle;
+ }
+ if (dstTransform instanceof CreateViewTransform) {
+ return CommunicationPatternProperty.Value.BroadCast;
+ }
+ return CommunicationPatternProperty.Value.OneToOne;
+ }
+ }
+
+ /**
+ * A {@link CommunicationPatternProperty.Value} selector which always emits
OneToOne.
+ */
+ private static final class OneToOneCommunicationPatternSelector
+ implements BiFunction<IRVertex, IRVertex,
CommunicationPatternProperty.Value> {
+ private static final OneToOneCommunicationPatternSelector INSTANCE = new
OneToOneCommunicationPatternSelector();
+ @Override
+ public CommunicationPatternProperty.Value apply(final IRVertex src, final
IRVertex dst) {
+ return CommunicationPatternProperty.Value.OneToOne;
+ }
+ }
+}
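The translator above discovers its per-transform handlers reflectively: the private constructor
scans PipelineTranslator's own static methods for @PrimitiveTransformTranslator and
@CompositeTransformTranslator annotations and indexes them by transform class, and translation
then walks a transform's superclasses until a registered translator is found. Supporting an
additional primitive transform therefore only needs one more annotated static method inside the
class. A hedged sketch of that pattern, where MyTransform and MyNemoTransform are hypothetical
stand-ins that do not appear in this diff:

    @PrimitiveTransformTranslator(MyTransform.class)
    private static void myTransformTranslator(final TranslationContext ctx,
                                               final PrimitiveTransformVertex transformVertex,
                                               final MyTransform<?, ?> transform) {
      // MyNemoTransform is a hypothetical edu.snu.nemo Transform implementation.
      final IRVertex vertex = new OperatorVertex(new MyNemoTransform());
      ctx.addVertex(vertex);
      // Main inputs become ordinary (non-side-input) IR edges into the new vertex.
      transformVertex.getNode().getInputs().values()
          .forEach(input -> ctx.addEdgeTo(vertex, input, false));
      // Register every output PValue so downstream translators can locate this producer.
      transformVertex.getNode().getOutputs().values()
          .forEach(output -> ctx.registerMainOutputFrom(vertex, output));
    }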
diff --git
a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/PipelineVisitor.java
b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/PipelineVisitor.java
new file mode 100644
index 000000000..55be86562
--- /dev/null
+++
b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/PipelineVisitor.java
@@ -0,0 +1,296 @@
+/*
+ * Copyright (C) 2018 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.compiler.frontend.beam;
+
+import edu.snu.nemo.common.dag.DAG;
+import edu.snu.nemo.common.dag.DAGBuilder;
+import edu.snu.nemo.common.dag.Edge;
+import edu.snu.nemo.common.dag.Vertex;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.runners.TransformHierarchy;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.values.PValue;
+
+import java.util.*;
+
+/**
+ * Traverses the given Beam pipeline to construct a DAG of Beam
Transforms,
+ * while preserving the hierarchy of CompositeTransforms.
+ * A level of hierarchy is established when a CompositeTransform is expanded into other
CompositeTransforms or PrimitiveTransforms:
+ * the enclosing CompositeTransform becomes the 'enclosingVertex', holding the
inner transforms as an embedded DAG.
+ * This DAG is later translated by {@link PipelineTranslator} into a Nemo
IR DAG.
+ */
+public final class PipelineVisitor extends Pipeline.PipelineVisitor.Defaults {
+
+ private static final String TRANSFORM = "Transform-";
+ private static final String DATAFLOW = "Dataflow-";
+
+ private final Stack<CompositeTransformVertex> compositeTransformVertexStack
= new Stack<>();
+ private CompositeTransformVertex rootVertex = null;
+ private int nextIdx = 0;
+
+ @Override
+ public void visitPrimitiveTransform(final TransformHierarchy.Node node) {
+ final PrimitiveTransformVertex vertex = new PrimitiveTransformVertex(node,
compositeTransformVertexStack.peek());
+ compositeTransformVertexStack.peek().addVertex(vertex);
+ vertex.getPValuesConsumed()
+ .forEach(pValue -> {
+ final TransformVertex dst = getDestinationOfDataFlowEdge(vertex,
pValue);
+ dst.enclosingVertex.addDataFlow(new
DataFlowEdge(dst.enclosingVertex.getProducerOf(pValue), dst));
+ });
+ }
+
+ @Override
+ public CompositeBehavior enterCompositeTransform(final
TransformHierarchy.Node node) {
+ final CompositeTransformVertex vertex;
+ if (compositeTransformVertexStack.isEmpty()) {
+ // There is always a top-level CompositeTransform that encompasses the
entire Beam pipeline.
+ vertex = new CompositeTransformVertex(node, null);
+ } else {
+ vertex = new CompositeTransformVertex(node,
compositeTransformVertexStack.peek());
+ }
+ compositeTransformVertexStack.push(vertex);
+ return CompositeBehavior.ENTER_TRANSFORM;
+ }
+
+ @Override
+ public void leaveCompositeTransform(final TransformHierarchy.Node node) {
+ final CompositeTransformVertex vertex =
compositeTransformVertexStack.pop();
+ vertex.build();
+ if (compositeTransformVertexStack.isEmpty()) {
+ // The vertex is the root.
+ if (rootVertex != null) {
+ throw new RuntimeException("The visitor already have traversed a Beam
pipeline. "
+ + "Re-using a visitor is not allowed.");
+ }
+ rootVertex = vertex;
+ } else {
+ // The CompositeTransformVertex is ready; adding it to its enclosing
vertex.
+ compositeTransformVertexStack.peek().addVertex(vertex);
+ }
+ }
+
+ /**
+ * @return A vertex representing the top-level CompositeTransform.
+ */
+ public CompositeTransformVertex getConvertedPipeline() {
+ if (rootVertex == null) {
+ throw new RuntimeException("The visitor have not fully traversed through
a Beam pipeline.");
+ }
+ return rootVertex;
+ }
+
+ /**
+ * Represents a {@link org.apache.beam.sdk.transforms.PTransform} as a
vertex in DAG.
+ */
+ public abstract class TransformVertex extends Vertex {
+ private final TransformHierarchy.Node node;
+ private final CompositeTransformVertex enclosingVertex;
+
+ /**
+ * @param node the corresponding Beam node
+ * @param enclosingVertex the vertex for the transform which inserted this
transform as its expansion,
+ * or {@code null}
+ */
+ private TransformVertex(final TransformHierarchy.Node node, final
CompositeTransformVertex enclosingVertex) {
+ super(String.format("%s%d", TRANSFORM, nextIdx++));
+ this.node = node;
+ this.enclosingVertex = enclosingVertex;
+ }
+
+ /**
+ * @return Collection of {@link PValue}s this transform emits.
+ */
+ public abstract Collection<PValue> getPValuesProduced();
+
+ /**
+ * Searches within {@code this} to find a transform that produces the
given {@link PValue}.
+ *
+ * @param pValue a {@link PValue}
+ * @return the {@link TransformVertex} whose {@link
org.apache.beam.sdk.transforms.PTransform}
+ * produces the given {@code pValue}
+ */
+ public abstract PrimitiveTransformVertex getPrimitiveProducerOf(final
PValue pValue);
+
+ /**
+ * @return the corresponding Beam node.
+ */
+ public TransformHierarchy.Node getNode() {
+ return node;
+ }
+
+ /**
+ * @return the enclosing {@link CompositeTransformVertex} if any, {@code
null} otherwise.
+ */
+ public CompositeTransformVertex getEnclosingVertex() {
+ return enclosingVertex;
+ }
+ }
+
+ /**
+ * Represents a transform hierarchy for primitive transform.
+ */
+ public final class PrimitiveTransformVertex extends TransformVertex {
+ private final List<PValue> pValuesProduced = new ArrayList<>();
+ private final List<PValue> pValuesConsumed = new ArrayList<>();
+
+ private PrimitiveTransformVertex(final TransformHierarchy.Node node,
+ final CompositeTransformVertex
enclosingVertex) {
+ super(node, enclosingVertex);
+ if (node.getTransform() instanceof View.CreatePCollectionView) {
+ pValuesProduced.add(((View.CreatePCollectionView)
node.getTransform()).getView());
+ }
+ if (node.getTransform() instanceof ParDo.SingleOutput) {
+ pValuesConsumed.addAll(((ParDo.SingleOutput)
node.getTransform()).getSideInputs());
+ }
+ if (node.getTransform() instanceof ParDo.MultiOutput) {
+ pValuesConsumed.addAll(((ParDo.MultiOutput)
node.getTransform()).getSideInputs());
+ }
+ pValuesProduced.addAll(getNode().getOutputs().values());
+ pValuesConsumed.addAll(getNode().getInputs().values());
+ }
+
+ @Override
+ public Collection<PValue> getPValuesProduced() {
+ return pValuesProduced;
+ }
+
+ @Override
+ public PrimitiveTransformVertex getPrimitiveProducerOf(final PValue
pValue) {
+ if (!getPValuesProduced().contains(pValue)) {
+ throw new RuntimeException();
+ }
+ return this;
+ }
+
+ /**
+ * @return collection of {@link PValue} this transform consumes.
+ */
+ public Collection<PValue> getPValuesConsumed() {
+ return pValuesConsumed;
+ }
+ }
+ /**
+ * Represents a transform hierarchy for composite transform.
+ */
+ public final class CompositeTransformVertex extends TransformVertex {
+ private final Map<PValue, TransformVertex> pValueToProducer = new
HashMap<>();
+ private final Collection<DataFlowEdge> dataFlowEdges = new ArrayList<>();
+ private final DAGBuilder<TransformVertex, DataFlowEdge> builder = new
DAGBuilder<>();
+ private DAG<TransformVertex, DataFlowEdge> dag = null;
+
+ private CompositeTransformVertex(final TransformHierarchy.Node node,
+ final CompositeTransformVertex
enclosingVertex) {
+ super(node, enclosingVertex);
+ }
+
+ /**
+ * Finalize this vertex and make it ready to be added to another {@link
CompositeTransformVertex}.
+ */
+ private void build() {
+ if (dag != null) {
+ throw new RuntimeException("DAG already have been built.");
+ }
+ dataFlowEdges.forEach(builder::connectVertices);
+ dag = builder.build();
+ }
+
+ /**
+ * Add a {@link TransformVertex}.
+ *
+ * @param vertex the vertex to add
+ */
+ private void addVertex(final TransformVertex vertex) {
+ vertex.getPValuesProduced().forEach(value -> pValueToProducer.put(value,
vertex));
+ builder.addVertex(vertex);
+ }
+
+ /**
+ * Add a {@link DataFlowEdge}.
+ *
+ * @param dataFlowEdge the edge to add
+ */
+ private void addDataFlow(final DataFlowEdge dataFlowEdge) {
+ dataFlowEdges.add(dataFlowEdge);
+ }
+
+ @Override
+ public Collection<PValue> getPValuesProduced() {
+ return pValueToProducer.keySet();
+ }
+
+ /**
+ * Get a direct child of this vertex which produces the given {@link
PValue}.
+ *
+ * @param pValue the {@link PValue} to search
+ * @return the direct child of this vertex which produces {@code pValue}
+ */
+ public TransformVertex getProducerOf(final PValue pValue) {
+ final TransformVertex vertex = pValueToProducer.get(pValue);
+ if (vertex == null) {
+ throw new RuntimeException();
+ }
+ return vertex;
+ }
+
+ @Override
+ public PrimitiveTransformVertex getPrimitiveProducerOf(final PValue
pValue) {
+ return getProducerOf(pValue).getPrimitiveProducerOf(pValue);
+ }
+
+ /**
+ * @return DAG of Beam hierarchy
+ */
+ public DAG<TransformVertex, DataFlowEdge> getDAG() {
+ return dag;
+ }
+ }
+
+ /**
+ * Represents data flow from a transform to another transform.
+ */
+ public final class DataFlowEdge extends Edge<TransformVertex> {
+ /**
+ * @param src source vertex
+ * @param dst destination vertex
+ */
+ private DataFlowEdge(final TransformVertex src, final TransformVertex dst)
{
+ super(String.format("%s%d", DATAFLOW, nextIdx++), src, dst);
+ }
+ }
+
+ /**
+ * @param primitiveConsumer a {@link PrimitiveTransformVertex} which
consumes {@code pValue}
+ * @param pValue the specified {@link PValue}
+ * @return the closest {@link TransformVertex} to {@code primitiveConsumer},
+ * which is equal to or encloses {@code primitiveConsumer} and can
be the destination vertex of
+ * data flow edge from the producer of {@code pValue} to {@code
primitiveConsumer}.
+ */
+ private TransformVertex getDestinationOfDataFlowEdge(final
PrimitiveTransformVertex primitiveConsumer,
+ final PValue pValue) {
+ TransformVertex current = primitiveConsumer;
+ while (true) {
+ if (current.getEnclosingVertex().getPValuesProduced().contains(pValue)) {
+ return current;
+ }
+ current = current.getEnclosingVertex();
+ if (current.getEnclosingVertex() == null) {
+ throw new RuntimeException(String.format("Cannot find producer of %s",
pValue));
+ }
+ }
+ }
+}
diff --git
a/compiler/test/src/test/java/edu/snu/nemo/compiler/frontend/beam/BeamFrontendALSTest.java
b/compiler/test/src/test/java/edu/snu/nemo/compiler/frontend/beam/BeamFrontendALSTest.java
index 67f146a55..12c82446a 100644
---
a/compiler/test/src/test/java/edu/snu/nemo/compiler/frontend/beam/BeamFrontendALSTest.java
+++
b/compiler/test/src/test/java/edu/snu/nemo/compiler/frontend/beam/BeamFrontendALSTest.java
@@ -38,22 +38,22 @@ public void testALSDAG() throws Exception {
final DAG<IRVertex, IREdge> producedDAG = CompilerTestUtil.compileALSDAG();
assertEquals(producedDAG.getTopologicalSort(),
producedDAG.getTopologicalSort());
- assertEquals(38, producedDAG.getVertices().size());
+ assertEquals(42, producedDAG.getVertices().size());
// producedDAG.getTopologicalSort().forEach(v ->
System.out.println(v.getId()));
- final IRVertex vertex4 = producedDAG.getTopologicalSort().get(6);
- assertEquals(1, producedDAG.getIncomingEdgesOf(vertex4).size());
- assertEquals(1, producedDAG.getIncomingEdgesOf(vertex4.getId()).size());
- assertEquals(4, producedDAG.getOutgoingEdgesOf(vertex4).size());
+ final IRVertex vertex11 = producedDAG.getTopologicalSort().get(5);
+ assertEquals(1, producedDAG.getIncomingEdgesOf(vertex11).size());
+ assertEquals(1, producedDAG.getIncomingEdgesOf(vertex11.getId()).size());
+ assertEquals(4, producedDAG.getOutgoingEdgesOf(vertex11).size());
- final IRVertex vertex13 = producedDAG.getTopologicalSort().get(11);
- assertEquals(1, producedDAG.getIncomingEdgesOf(vertex13).size());
- assertEquals(1, producedDAG.getIncomingEdgesOf(vertex13.getId()).size());
- assertEquals(1, producedDAG.getOutgoingEdgesOf(vertex13).size());
+ final IRVertex vertex17 = producedDAG.getTopologicalSort().get(10);
+ assertEquals(1, producedDAG.getIncomingEdgesOf(vertex17).size());
+ assertEquals(1, producedDAG.getIncomingEdgesOf(vertex17.getId()).size());
+ assertEquals(1, producedDAG.getOutgoingEdgesOf(vertex17).size());
- final IRVertex vertex14 = producedDAG.getTopologicalSort().get(12);
- assertEquals(2, producedDAG.getIncomingEdgesOf(vertex14).size());
- assertEquals(2, producedDAG.getIncomingEdgesOf(vertex14.getId()).size());
- assertEquals(1, producedDAG.getOutgoingEdgesOf(vertex14).size());
+ final IRVertex vertex18 = producedDAG.getTopologicalSort().get(16);
+ assertEquals(2, producedDAG.getIncomingEdgesOf(vertex18).size());
+ assertEquals(2, producedDAG.getIncomingEdgesOf(vertex18.getId()).size());
+ assertEquals(1, producedDAG.getOutgoingEdgesOf(vertex18).size());
}
}
diff --git
a/compiler/test/src/test/java/edu/snu/nemo/compiler/frontend/beam/BeamFrontendMLRTest.java
b/compiler/test/src/test/java/edu/snu/nemo/compiler/frontend/beam/BeamFrontendMLRTest.java
index 0cb3a2670..7f4d5910d 100644
---
a/compiler/test/src/test/java/edu/snu/nemo/compiler/frontend/beam/BeamFrontendMLRTest.java
+++
b/compiler/test/src/test/java/edu/snu/nemo/compiler/frontend/beam/BeamFrontendMLRTest.java
@@ -38,21 +38,21 @@ public void testMLRDAG() throws Exception {
final DAG<IRVertex, IREdge> producedDAG = CompilerTestUtil.compileMLRDAG();
assertEquals(producedDAG.getTopologicalSort(),
producedDAG.getTopologicalSort());
- assertEquals(36, producedDAG.getVertices().size());
+ assertEquals(42, producedDAG.getVertices().size());
- final IRVertex vertex3 = producedDAG.getTopologicalSort().get(0);
- assertEquals(0, producedDAG.getIncomingEdgesOf(vertex3).size());
- assertEquals(0, producedDAG.getIncomingEdgesOf(vertex3.getId()).size());
- assertEquals(3, producedDAG.getOutgoingEdgesOf(vertex3).size());
+ final IRVertex vertex1 = producedDAG.getTopologicalSort().get(5);
+ assertEquals(0, producedDAG.getIncomingEdgesOf(vertex1).size());
+ assertEquals(0, producedDAG.getIncomingEdgesOf(vertex1.getId()).size());
+ assertEquals(3, producedDAG.getOutgoingEdgesOf(vertex1).size());
- final IRVertex vertex13 = producedDAG.getTopologicalSort().get(11);
- assertEquals(1, producedDAG.getIncomingEdgesOf(vertex13).size());
- assertEquals(1, producedDAG.getIncomingEdgesOf(vertex13.getId()).size());
- assertEquals(1, producedDAG.getOutgoingEdgesOf(vertex13).size());
+ final IRVertex vertex15 = producedDAG.getTopologicalSort().get(13);
+ assertEquals(1, producedDAG.getIncomingEdgesOf(vertex15).size());
+ assertEquals(1, producedDAG.getIncomingEdgesOf(vertex15.getId()).size());
+ assertEquals(1, producedDAG.getOutgoingEdgesOf(vertex15).size());
- final IRVertex vertex19 = producedDAG.getTopologicalSort().get(17);
- assertEquals(2, producedDAG.getIncomingEdgesOf(vertex19).size());
- assertEquals(2, producedDAG.getIncomingEdgesOf(vertex19.getId()).size());
- assertEquals(1, producedDAG.getOutgoingEdgesOf(vertex19).size());
+ final IRVertex vertex21 = producedDAG.getTopologicalSort().get(19);
+ assertEquals(2, producedDAG.getIncomingEdgesOf(vertex21).size());
+ assertEquals(2, producedDAG.getIncomingEdgesOf(vertex21.getId()).size());
+ assertEquals(1, producedDAG.getOutgoingEdgesOf(vertex21).size());
}
}
diff --git
a/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/composite/TransientResourceCompositePassTest.java
b/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/composite/TransientResourceCompositePassTest.java
index e035241a9..97a5cd27a 100644
---
a/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/composite/TransientResourceCompositePassTest.java
+++
b/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/composite/TransientResourceCompositePassTest.java
@@ -53,37 +53,37 @@ public void testTransientResourcePass() throws Exception {
final IRVertex vertex1 = processedDAG.getTopologicalSort().get(0);
assertEquals(ResourcePriorityProperty.TRANSIENT, vertex1.getPropertyValue(ResourcePriorityProperty.class).get());
- final IRVertex vertex5 = processedDAG.getTopologicalSort().get(1);
- assertEquals(ResourcePriorityProperty.TRANSIENT, vertex5.getPropertyValue(ResourcePriorityProperty.class).get());
- processedDAG.getIncomingEdgesOf(vertex5).forEach(irEdge -> {
+ final IRVertex vertex2 = processedDAG.getTopologicalSort().get(11);
+ assertEquals(ResourcePriorityProperty.TRANSIENT, vertex2.getPropertyValue(ResourcePriorityProperty.class).get());
+ processedDAG.getIncomingEdgesOf(vertex2).forEach(irEdge -> {
assertEquals(DataStoreProperty.Value.MemoryStore, irEdge.getPropertyValue(DataStoreProperty.class).get());
assertEquals(DataFlowProperty.Value.Pull, irEdge.getPropertyValue(DataFlowProperty.class).get());
});
- final IRVertex vertex6 = processedDAG.getTopologicalSort().get(2);
- assertEquals(ResourcePriorityProperty.RESERVED, vertex6.getPropertyValue(ResourcePriorityProperty.class).get());
- processedDAG.getIncomingEdgesOf(vertex6).forEach(irEdge -> {
+ final IRVertex vertex5 = processedDAG.getTopologicalSort().get(14);
+ assertEquals(ResourcePriorityProperty.RESERVED, vertex5.getPropertyValue(ResourcePriorityProperty.class).get());
+ processedDAG.getIncomingEdgesOf(vertex5).forEach(irEdge -> {
assertEquals(DataStoreProperty.Value.LocalFileStore, irEdge.getPropertyValue(DataStoreProperty.class).get());
assertEquals(DataFlowProperty.Value.Push, irEdge.getPropertyValue(DataFlowProperty.class).get());
});
- final IRVertex vertex4 = processedDAG.getTopologicalSort().get(6);
- assertEquals(ResourcePriorityProperty.RESERVED, vertex4.getPropertyValue(ResourcePriorityProperty.class).get());
- processedDAG.getIncomingEdgesOf(vertex4).forEach(irEdge -> {
+ final IRVertex vertex11 = processedDAG.getTopologicalSort().get(5);
+ assertEquals(ResourcePriorityProperty.RESERVED, vertex11.getPropertyValue(ResourcePriorityProperty.class).get());
+ processedDAG.getIncomingEdgesOf(vertex11).forEach(irEdge -> {
assertEquals(DataStoreProperty.Value.MemoryStore, irEdge.getPropertyValue(DataStoreProperty.class).get());
assertEquals(DataFlowProperty.Value.Pull, irEdge.getPropertyValue(DataFlowProperty.class).get());
});
- final IRVertex vertex13 = processedDAG.getTopologicalSort().get(11);
- assertEquals(ResourcePriorityProperty.TRANSIENT, vertex13.getPropertyValue(ResourcePriorityProperty.class).get());
- processedDAG.getIncomingEdgesOf(vertex13).forEach(irEdge -> {
+ final IRVertex vertex17 = processedDAG.getTopologicalSort().get(10);
+ assertEquals(ResourcePriorityProperty.TRANSIENT, vertex17.getPropertyValue(ResourcePriorityProperty.class).get());
+ processedDAG.getIncomingEdgesOf(vertex17).forEach(irEdge -> {
assertEquals(DataStoreProperty.Value.LocalFileStore, irEdge.getPropertyValue(DataStoreProperty.class).get());
assertEquals(DataFlowProperty.Value.Pull, irEdge.getPropertyValue(DataFlowProperty.class).get());
});
- final IRVertex vertex15 = processedDAG.getTopologicalSort().get(13);
- assertEquals(ResourcePriorityProperty.RESERVED, vertex15.getPropertyValue(ResourcePriorityProperty.class).get());
- processedDAG.getIncomingEdgesOf(vertex15).forEach(irEdge -> {
+ final IRVertex vertex19 = processedDAG.getTopologicalSort().get(17);
+ assertEquals(ResourcePriorityProperty.RESERVED, vertex19.getPropertyValue(ResourcePriorityProperty.class).get());
+ processedDAG.getIncomingEdgesOf(vertex19).forEach(irEdge -> {
assertEquals(DataStoreProperty.Value.LocalFileStore, irEdge.getPropertyValue(DataStoreProperty.class).get());
assertEquals(DataFlowProperty.Value.Push, irEdge.getPropertyValue(DataFlowProperty.class).get());
});
diff --git
a/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopExtractionPassTest.java
b/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopExtractionPassTest.java
index acccd73f3..7e0c42500 100644
---
a/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopExtractionPassTest.java
+++
b/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopExtractionPassTest.java
@@ -45,6 +45,6 @@ public void setUp() throws Exception {
public void testLoopGrouping() {
final DAG<IRVertex, IREdge> processedDAG = new LoopExtractionPass().apply(compiledDAG);
- assertEquals(9, processedDAG.getTopologicalSort().size());
+ assertEquals(13, processedDAG.getTopologicalSort().size());
}
}
diff --git
a/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopInvariantCodeMotionALSInefficientTest.java
b/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopInvariantCodeMotionALSInefficientTest.java
index 64d5b7451..2162463d8 100644
---
a/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopInvariantCodeMotionALSInefficientTest.java
+++
b/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopInvariantCodeMotionALSInefficientTest.java
@@ -45,7 +45,7 @@ public void setUp() throws Exception {
@Test
public void testForInefficientALSDAG() throws Exception {
- final long expectedNumOfVertices = groupedDAG.getVertices().size() + 3;
+ final long expectedNumOfVertices = groupedDAG.getVertices().size() + 5;
final DAG<IRVertex, IREdge> processedDAG = LoopOptimizations.getLoopInvariantCodeMotionPass()
.apply(groupedDAG);
diff --git
a/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopInvariantCodeMotionPassTest.java
b/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopInvariantCodeMotionPassTest.java
index 07210ee07..5cd7928f2 100644
---
a/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopInvariantCodeMotionPassTest.java
+++
b/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LoopInvariantCodeMotionPassTest.java
@@ -59,32 +59,32 @@ public void setUp() throws Exception {
assertTrue(alsLoopOpt.isPresent());
final LoopVertex alsLoop = alsLoopOpt.get();
- final IRVertex vertex7 = groupedDAG.getTopologicalSort().get(3);
- final IRVertex vertex14 = alsLoop.getDAG().getTopologicalSort().get(4);
+ final IRVertex vertex6 = groupedDAG.getTopologicalSort().get(11);
+ final IRVertex vertex18 = alsLoop.getDAG().getTopologicalSort().get(4);
- final Set<IREdge> oldDAGIncomingEdges = alsLoop.getDagIncomingEdges().get(vertex14);
- final List<IREdge> newDAGIncomingEdge = groupedDAG.getIncomingEdgesOf(vertex7);
+ final Set<IREdge> oldDAGIncomingEdges = alsLoop.getDagIncomingEdges().get(vertex18);
+ final List<IREdge> newDAGIncomingEdge = groupedDAG.getIncomingEdgesOf(vertex6);
- alsLoop.getDagIncomingEdges().remove(vertex14);
- alsLoop.getDagIncomingEdges().putIfAbsent(vertex7, new HashSet<>());
- newDAGIncomingEdge.forEach(alsLoop.getDagIncomingEdges().get(vertex7)::add);
+ alsLoop.getDagIncomingEdges().remove(vertex18);
+ alsLoop.getDagIncomingEdges().putIfAbsent(vertex6, new HashSet<>());
+ newDAGIncomingEdge.forEach(alsLoop.getDagIncomingEdges().get(vertex6)::add);
- alsLoop.getNonIterativeIncomingEdges().remove(vertex14);
- alsLoop.getNonIterativeIncomingEdges().putIfAbsent(vertex7, new HashSet<>());
- newDAGIncomingEdge.forEach(alsLoop.getNonIterativeIncomingEdges().get(vertex7)::add);
+ alsLoop.getNonIterativeIncomingEdges().remove(vertex18);
+ alsLoop.getNonIterativeIncomingEdges().putIfAbsent(vertex6, new HashSet<>());
+ newDAGIncomingEdge.forEach(alsLoop.getNonIterativeIncomingEdges().get(vertex6)::add);
- alsLoop.getBuilder().addVertex(vertex7);
+ alsLoop.getBuilder().addVertex(vertex6);
oldDAGIncomingEdges.forEach(alsLoop.getBuilder()::connectVertices);
final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
groupedDAG.topologicalDo(v -> {
- if (!v.equals(vertex7) && !v.equals(alsLoop)) {
+ if (!v.equals(vertex6) && !v.equals(alsLoop)) {
builder.addVertex(v);
groupedDAG.getIncomingEdgesOf(v).forEach(builder::connectVertices);
} else if (v.equals(alsLoop)) {
builder.addVertex(v);
groupedDAG.getIncomingEdgesOf(v).forEach(e -> {
- if (!e.getSrc().equals(vertex7)) {
+ if (!e.getSrc().equals(vertex6)) {
builder.connectVertices(e);
} else {
final Optional<IREdge> incomingEdge = newDAGIncomingEdge.stream().findFirst();
diff --git
a/examples/beam/src/main/java/edu/snu/nemo/examples/beam/WordCount.java
b/examples/beam/src/main/java/edu/snu/nemo/examples/beam/WordCount.java
index 88ed4aeb9..3d3c556a3 100644
--- a/examples/beam/src/main/java/edu/snu/nemo/examples/beam/WordCount.java
+++ b/examples/beam/src/main/java/edu/snu/nemo/examples/beam/WordCount.java
@@ -56,8 +56,7 @@ public static void main(final String[] args) {
return KV.of(documentId, count);
}
}))
- .apply(GroupByKey.<String, Long>create())
- .apply(Combine.<String, Long, Long>groupedValues(Sum.ofLongs()))
+ .apply(Sum.longsPerKey())
.apply(MapElements.<KV<String, Long>, String>via(new SimpleFunction<KV<String, Long>, String>() {
@Override
public String apply(final KV<String, Long> kv) {
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services