shpark closed pull request #81: [NEMO-151] Add OutputWriters for additional 
tagged outputs
URL: https://github.com/apache/incubator-nemo/pull/81
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this pull request comes from a fork, GitHub does not display its
diff once it has been merged, so the full diff is reproduced below:

diff --git a/common/src/main/java/edu/snu/nemo/common/ContextImpl.java 
b/common/src/main/java/edu/snu/nemo/common/ContextImpl.java
index d1774d3b6..bfe8c0655 100644
--- a/common/src/main/java/edu/snu/nemo/common/ContextImpl.java
+++ b/common/src/main/java/edu/snu/nemo/common/ContextImpl.java
@@ -25,7 +25,7 @@
  */
 public final class ContextImpl implements Transform.Context {
   private final Map sideInputs;
-  private final Map additionalTagOutputs;
+  private final Map<String, String> additionalTagOutputs;
   private String data;
 
   /**
@@ -44,7 +44,7 @@ public Map getSideInputs() {
   }
 
   @Override
-  public Map getAdditionalTagOutputs() {
+  public Map<String, String> getAdditionalTagOutputs() {
     return this.additionalTagOutputs;
   }
 
diff --git 
a/common/src/main/java/edu/snu/nemo/common/ir/vertex/executionproperty/AdditionalTagOutputProperty.java
 
b/common/src/main/java/edu/snu/nemo/common/ir/edge/executionproperty/AdditionalOutputTagProperty.java
similarity index 54%
rename from 
common/src/main/java/edu/snu/nemo/common/ir/vertex/executionproperty/AdditionalTagOutputProperty.java
rename to 
common/src/main/java/edu/snu/nemo/common/ir/edge/executionproperty/AdditionalOutputTagProperty.java
index dd99100d1..db1b3717e 100644
--- 
a/common/src/main/java/edu/snu/nemo/common/ir/vertex/executionproperty/AdditionalTagOutputProperty.java
+++ 
b/common/src/main/java/edu/snu/nemo/common/ir/edge/executionproperty/AdditionalOutputTagProperty.java
@@ -13,30 +13,29 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package edu.snu.nemo.common.ir.vertex.executionproperty;
+package edu.snu.nemo.common.ir.edge.executionproperty;
 
-import edu.snu.nemo.common.ir.executionproperty.VertexExecutionProperty;
-
-import java.util.HashMap;
+import edu.snu.nemo.common.ir.executionproperty.EdgeExecutionProperty;
 
 /**
- * AdditionalOutput Execution Property for vertex that outputs additional 
outputs.
+ * Additional Output Tag Execution Property for edge that contains tag for 
additional outputs.
  */
-public final class AdditionalTagOutputProperty extends 
VertexExecutionProperty<HashMap<String, String>> {
+public final class AdditionalOutputTagProperty extends 
EdgeExecutionProperty<String> {
+
   /**
    * Constructor.
-   * @param value map of tag to IRVertex id.
+   * @param value tag id of additional output.
    */
-  private AdditionalTagOutputProperty(final HashMap<String, String> value) {
+  private AdditionalOutputTagProperty(final String value) {
     super(value);
   }
 
   /**
    * Static method exposing constructor.
-   * @param value map of tag to IRVertex id.
+   * @param value tag id of additional output.
    * @return the newly created execution property.
    */
-  public static AdditionalTagOutputProperty of(final HashMap<String, String> 
value) {
-    return new AdditionalTagOutputProperty(value);
+  public static AdditionalOutputTagProperty of(final String value) {
+    return new AdditionalOutputTagProperty(value);
   }
 }
diff --git 
a/common/src/main/java/edu/snu/nemo/common/ir/vertex/transform/Transform.java 
b/common/src/main/java/edu/snu/nemo/common/ir/vertex/transform/Transform.java
index d41f9ea68..871d08b68 100644
--- 
a/common/src/main/java/edu/snu/nemo/common/ir/vertex/transform/Transform.java
+++ 
b/common/src/main/java/edu/snu/nemo/common/ir/vertex/transform/Transform.java
@@ -61,7 +61,7 @@ default Object getTag() {
      * @return sideInputs.
      */
     Map getSideInputs();
-    Map getAdditionalTagOutputs();
+    Map<String, String> getAdditionalTagOutputs();
 
     /**
      * Put serialized data to send to the executor.
diff --git 
a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/NemoPipelineVisitor.java
 
b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/NemoPipelineVisitor.java
index ecf6c79ff..ebb06146f 100644
--- 
a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/NemoPipelineVisitor.java
+++ 
b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/NemoPipelineVisitor.java
@@ -16,10 +16,11 @@
 package edu.snu.nemo.compiler.frontend.beam;
 
 import edu.snu.nemo.common.Pair;
+import 
edu.snu.nemo.common.ir.edge.executionproperty.AdditionalOutputTagProperty;
 import 
edu.snu.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
 import edu.snu.nemo.common.ir.edge.executionproperty.DecoderProperty;
 import edu.snu.nemo.common.ir.edge.executionproperty.EncoderProperty;
-import 
edu.snu.nemo.common.ir.vertex.executionproperty.AdditionalTagOutputProperty;
+import edu.snu.nemo.common.ir.edge.executionproperty.KeyExtractorProperty;
 import edu.snu.nemo.common.ir.vertex.transform.Transform;
 import edu.snu.nemo.compiler.frontend.beam.coder.BeamDecoderFactory;
 import edu.snu.nemo.compiler.frontend.beam.coder.BeamEncoderFactory;
@@ -29,7 +30,6 @@
 import edu.snu.nemo.common.ir.vertex.IRVertex;
 import edu.snu.nemo.common.ir.vertex.LoopVertex;
 import edu.snu.nemo.common.ir.vertex.OperatorVertex;
-import edu.snu.nemo.common.ir.edge.executionproperty.KeyExtractorProperty;
 
 import edu.snu.nemo.compiler.frontend.beam.transform.*;
 
@@ -110,29 +110,18 @@ public void visitPrimitiveTransform(final 
TransformHierarchy.Node beamNode) {
 
     beamNode.getInputs().values().stream().filter(pValueToVertex::containsKey)
         .forEach(pValue -> {
+          final boolean isAdditionalOutput = pValueToTag.containsKey(pValue);
           final IRVertex src = pValueToVertex.get(pValue);
-          final IREdge edge = new IREdge(getEdgeCommunicationPattern(src, 
irVertex), src, irVertex);
+          final IREdge edge = new IREdge(getEdgeCommunicationPattern(src, 
irVertex), src, irVertex, false);
           final Pair<BeamEncoderFactory, BeamDecoderFactory> coderPair = 
pValueToCoder.get(pValue);
           edge.setProperty(EncoderProperty.of(coderPair.left()));
           edge.setProperty(DecoderProperty.of(coderPair.right()));
           edge.setProperty(KeyExtractorProperty.of(new BeamKeyExtractor()));
-          this.builder.connectVertices(edge);
-        });
-
-    // This exclusively updates execution property of vertices with additional 
tagged outputs.
-    beamNode.getInputs().values().stream().filter(pValueToTag::containsKey)
-        .forEach(pValue -> {
-          final IRVertex src = pValueToVertex.get(pValue);
-          final TupleTag tag = pValueToTag.get(pValue);
-          final HashMap<String, String> tagToVertex = new HashMap<>();
-          tagToVertex.put(tag.getId(), irVertex.getId());
-          if 
(!src.getPropertyValue(AdditionalTagOutputProperty.class).isPresent()) {
-            src.setProperty(AdditionalTagOutputProperty.of(tagToVertex));
-          } else {
-            final HashMap<String, String> prev = 
src.getPropertyValue(AdditionalTagOutputProperty.class).get();
-            prev.putAll(tagToVertex);
-            src.setProperty(AdditionalTagOutputProperty.of(prev));
+          // Apply AdditionalOutputTagProperty to edges that correspond to 
additional outputs.
+          if (isAdditionalOutput) {
+            
edge.setProperty(AdditionalOutputTagProperty.of(pValueToTag.get(pValue).getId()));
           }
+          this.builder.connectVertices(edge);
         });
   }
 
@@ -235,8 +224,7 @@ private static void connectSideInputs(final 
DAGBuilder<IRVertex, IREdge> builder
     sideInputs.stream().filter(pValueToVertex::containsKey)
         .forEach(pValue -> {
           final IRVertex src = pValueToVertex.get(pValue);
-          final IREdge edge = new IREdge(getEdgeCommunicationPattern(src, 
irVertex),
-              src, irVertex, true);
+          final IREdge edge = new IREdge(getEdgeCommunicationPattern(src, 
irVertex), src, irVertex, true);
           final Pair<BeamEncoderFactory, BeamDecoderFactory> coder = 
pValueToCoder.get(pValue);
           edge.setProperty(EncoderProperty.of(coder.left()));
           edge.setProperty(DecoderProperty.of(coder.right()));
diff --git 
a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LargeShuffleRelayReshapingPass.java
 
b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LargeShuffleRelayReshapingPass.java
index dcbece7be..0e476ffa1 100644
--- 
a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LargeShuffleRelayReshapingPass.java
+++ 
b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/LargeShuffleRelayReshapingPass.java
@@ -60,6 +60,7 @@ public LargeShuffleRelayReshapingPass() {
             // before the vertex receiving shuffled data.
             final OperatorVertex iFileMergerVertex = new OperatorVertex(new 
RelayTransform());
             
iFileMergerVertex.getExecutionProperties().put(SkipSerDesProperty.of());
+
             builder.addVertex(iFileMergerVertex);
             final IREdge newEdgeToMerger = new 
IREdge(CommunicationPatternProperty.Value.Shuffle,
                 edge.getSrc(), iFileMergerVertex, edge.isSideInput());
diff --git 
a/examples/beam/src/main/java/edu/snu/nemo/examples/beam/PartitionWordsByLength.java
 
b/examples/beam/src/main/java/edu/snu/nemo/examples/beam/PartitionWordsByLength.java
index 5613e3ca6..e97b33b46 100644
--- 
a/examples/beam/src/main/java/edu/snu/nemo/examples/beam/PartitionWordsByLength.java
+++ 
b/examples/beam/src/main/java/edu/snu/nemo/examples/beam/PartitionWordsByLength.java
@@ -20,10 +20,7 @@
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
-import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.transforms.FlatMapElements;
-import org.apache.beam.sdk.transforms.MapElements;
-import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.*;
 import org.apache.beam.sdk.values.*;
 
 import java.util.Arrays;
@@ -52,11 +49,13 @@ public static void main(final String[] args) {
 
     // {} here is required for preserving type information.
     // Please see https://stackoverflow.com/a/48431397 for details.
-    final TupleTag<String> shortWordsTag = new TupleTag<String>() {
+    final TupleTag<KV<Integer, String>> shortWordsTag = new 
TupleTag<KV<Integer, String>>("short") {
     };
-    final TupleTag<Integer> longWordsTag = new TupleTag<Integer>() {
+    final TupleTag<KV<Integer, String>> longWordsTag = new 
TupleTag<KV<Integer, String>>("long") {
     };
-    final TupleTag<String> veryLongWordsTag = new TupleTag<String>() {
+    final TupleTag<String> veryLongWordsTag = new TupleTag<String>("very 
long") {
+    };
+    final TupleTag<String> veryVeryLongWordsTag = new TupleTag<String>("very 
very long") {
     };
 
     final Pipeline p = Pipeline.create(options);
@@ -70,27 +69,48 @@ public static void main(final String[] args) {
           @ProcessElement
           public void processElement(final ProcessContext c) {
             String word = c.element();
-            if (word.length() < 5) {
-              c.output(shortWordsTag, word);
-            } else if (word.length() < 8) {
-              c.output(longWordsTag, word.length());
+            if (word.length() < 6) {
+              c.output(shortWordsTag, KV.of(word.length(), word));
+            } else if (word.length() < 11) {
+              c.output(longWordsTag, KV.of(word.length(), word));
+            } else if (word.length() > 12) {
+              c.output(veryVeryLongWordsTag, word);
             } else {
-              c.output(veryLongWordsTag, word);
+              c.output(word);
             }
           }
         }).withOutputTags(veryLongWordsTag, TupleTagList
-            .of(longWordsTag)
-            .and(shortWordsTag)));
+            .of(shortWordsTag).and(longWordsTag).and(veryVeryLongWordsTag)));
 
-    PCollection<String> shortWords = results.get(shortWordsTag);
-    PCollection<String> longWordLengths = results
-        .get(longWordsTag)
-        .apply(MapElements.into(TypeDescriptors.strings()).via(i -> 
Integer.toString(i)));
+    PCollection<String> shortWords = results.get(shortWordsTag)
+        .apply(GroupByKey.create())
+        .apply(MapElements.via(new FormatLines()));
+    PCollection<String> longWords = results.get(longWordsTag)
+        .apply(GroupByKey.create())
+        .apply(MapElements.via(new FormatLines()));
     PCollection<String> veryLongWords = results.get(veryLongWordsTag);
+    PCollection<String> veryveryLongWords = results.get(veryVeryLongWordsTag);
 
     GenericSourceSink.write(shortWords, outputFilePath + "_short");
-    GenericSourceSink.write(longWordLengths, outputFilePath + "_long");
+    GenericSourceSink.write(longWords, outputFilePath + "_long");
     GenericSourceSink.write(veryLongWords, outputFilePath + "_very_long");
+    GenericSourceSink.write(veryveryLongWords, outputFilePath + 
"_very_very_long");
     p.run();
   }
+
+  /**
+   * Formats a key-value pair to a string.
+   */
+  static class FormatLines extends SimpleFunction<KV<Integer, 
Iterable<String>>, String> {
+    @Override
+    public String apply(final KV<Integer, Iterable<String>> input) {
+      final int length = input.getKey();
+      final StringBuilder sb = new StringBuilder();
+      for (final String word : input.getValue()) {
+        sb.append(length).append(": ").append(word).append('\n');
+      }
+
+      return sb.toString();
+    }
+  }
 }
diff --git 
a/examples/beam/src/test/java/edu/snu/nemo/examples/beam/PartitionWordsByLengthITCase.java
 
b/examples/beam/src/test/java/edu/snu/nemo/examples/beam/PartitionWordsByLengthITCase.java
index 96c83a74a..607d2dea8 100644
--- 
a/examples/beam/src/test/java/edu/snu/nemo/examples/beam/PartitionWordsByLengthITCase.java
+++ 
b/examples/beam/src/test/java/edu/snu/nemo/examples/beam/PartitionWordsByLengthITCase.java
@@ -19,6 +19,7 @@
 import edu.snu.nemo.common.test.ArgBuilder;
 import edu.snu.nemo.common.test.ExampleTestUtil;
 import edu.snu.nemo.examples.beam.policy.DefaultPolicyParallelismFive;
+import edu.snu.nemo.examples.beam.policy.LargeShufflePolicyParallelismFive;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
@@ -46,9 +47,8 @@
   @Before
   public void setUp() throws Exception {
     builder = new ArgBuilder()
-      .addResourceJson(executorResourceFileName)
-      .addUserMain(PartitionWordsByLength.class.getCanonicalName())
-      .addUserArgs(inputFilePath, outputFilePath);
+        .addUserMain(PartitionWordsByLength.class.getCanonicalName())
+        .addUserArgs(inputFilePath, outputFilePath);
   }
 
   @After
@@ -65,8 +65,18 @@ public void tearDown() throws Exception {
   @Test (timeout = TIMEOUT)
   public void test() throws Exception {
     JobLauncher.main(builder
-      .addJobId(PartitionWordsByLength.class.getSimpleName())
-      
.addOptimizationPolicy(DefaultPolicyParallelismFive.class.getCanonicalName())
-      .build());
+        .addResourceJson(executorResourceFileName)
+        .addJobId(PartitionWordsByLengthITCase.class.getSimpleName())
+        
.addOptimizationPolicy(DefaultPolicyParallelismFive.class.getCanonicalName())
+        .build());
+  }
+
+  @Test (timeout = TIMEOUT)
+  public void testSailfish() throws Exception {
+    JobLauncher.main(builder
+        .addResourceJson(executorResourceFileName)
+        .addJobId(PartitionWordsByLengthITCase.class.getSimpleName() + 
"_sailfish")
+        
.addOptimizationPolicy(LargeShufflePolicyParallelismFive.class.getCanonicalName())
+        .build());
   }
 }
diff --git a/examples/resources/expected_output_tag_long 
b/examples/resources/expected_output_tag_long
index 91dea2c76..78c88bfda 100644
--- a/examples/resources/expected_output_tag_long
+++ b/examples/resources/expected_output_tag_long
@@ -1,2 +1,11 @@
-6
-6
+6: foobar
+6: barbaz
+
+7: abcdefg
+
+8: fooipsum
+
+9: foobarbaz
+
+10: ipsumlorem
+
diff --git a/examples/resources/expected_output_tag_short 
b/examples/resources/expected_output_tag_short
index 72594ed96..d9d497b49 100644
--- a/examples/resources/expected_output_tag_short
+++ b/examples/resources/expected_output_tag_short
@@ -1,3 +1,15 @@
-foo
-bar
-qux
+1: a
+
+2: to
+
+3: foo
+3: bar
+3: qux
+
+4: that
+4: this
+
+5: ipsum
+5: dolor
+5: loren
+
diff --git a/examples/resources/expected_output_tag_very_long 
b/examples/resources/expected_output_tag_very_long
index 22a28156a..f562b43e6 100644
--- a/examples/resources/expected_output_tag_very_long
+++ b/examples/resources/expected_output_tag_very_long
@@ -1,2 +1,2 @@
-foobarbaz
-ipsumlorem
+foobarbazqux
+bazquxfoobar
diff --git a/examples/resources/expected_output_tag_very_very_long 
b/examples/resources/expected_output_tag_very_very_long
new file mode 100644
index 000000000..5fa149acf
--- /dev/null
+++ b/examples/resources/expected_output_tag_very_very_long
@@ -0,0 +1,2 @@
+bazquxfoobarfoobaz
+bazquxfoobarfoobar
diff --git a/examples/resources/test_input_tag 
b/examples/resources/test_input_tag
index 0cd417beb..64934b8f6 100644
--- a/examples/resources/test_input_tag
+++ b/examples/resources/test_input_tag
@@ -1,7 +1,20 @@
+a
+to
 foo
 bar
+that
+ipsum
+dolor
 foobar
 barbaz
+abcdefg
+fooipsum
 foobarbaz
+bazquxfoobarfoobaz
+loren
+this
+foobarbazqux
+bazquxfoobar
+bazquxfoobarfoobar
+qux
 ipsumlorem
-qux
\ No newline at end of file
diff --git 
a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGenerator.java
 
b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGenerator.java
index fa95926f8..364a9383b 100644
--- 
a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGenerator.java
+++ 
b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGenerator.java
@@ -25,7 +25,6 @@
 import 
edu.snu.nemo.common.ir.vertex.executionproperty.DynamicOptimizationProperty;
 import edu.snu.nemo.common.ir.vertex.executionproperty.ParallelismProperty;
 import edu.snu.nemo.common.ir.vertex.executionproperty.ScheduleGroupProperty;
-import 
edu.snu.nemo.common.ir.vertex.executionproperty.AdditionalTagOutputProperty;
 import edu.snu.nemo.conf.JobConf;
 import edu.snu.nemo.common.dag.DAG;
 import edu.snu.nemo.common.dag.DAGBuilder;
@@ -59,7 +58,6 @@ private PhysicalPlanGenerator(final StagePartitioner 
stagePartitioner,
     this.dagDirectory = dagDirectory;
     this.stagePartitioner = stagePartitioner;
     stagePartitioner.addIgnoredPropertyKey(DynamicOptimizationProperty.class);
-    stagePartitioner.addIgnoredPropertyKey(AdditionalTagOutputProperty.class);
   }
 
   /**
diff --git 
a/runtime/common/src/test/java/edu/snu/nemo/runtime/common/plan/StagePartitionerTest.java
 
b/runtime/common/src/test/java/edu/snu/nemo/runtime/common/plan/StagePartitionerTest.java
index 48a47d0e7..7939cbd6d 100644
--- 
a/runtime/common/src/test/java/edu/snu/nemo/runtime/common/plan/StagePartitionerTest.java
+++ 
b/runtime/common/src/test/java/edu/snu/nemo/runtime/common/plan/StagePartitionerTest.java
@@ -22,7 +22,6 @@
 import edu.snu.nemo.common.ir.vertex.IRVertex;
 import edu.snu.nemo.common.ir.vertex.OperatorVertex;
 import edu.snu.nemo.common.ir.vertex.executionproperty.*;
-import 
edu.snu.nemo.common.ir.vertex.executionproperty.ResourcePriorityProperty;
 import org.apache.reef.tang.Tang;
 import org.apache.reef.tang.exceptions.InjectionException;
 import org.junit.Before;
@@ -47,7 +46,6 @@
   public void setup() throws InjectionException {
     stagePartitioner = 
Tang.Factory.getTang().newInjector().getInstance(StagePartitioner.class);
     stagePartitioner.addIgnoredPropertyKey(DynamicOptimizationProperty.class);
-    stagePartitioner.addIgnoredPropertyKey(AdditionalTagOutputProperty.class);
   }
 
   /**
diff --git 
a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/TaskExecutor.java
 
b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/TaskExecutor.java
index 75207302d..809d74787 100644
--- 
a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/TaskExecutor.java
+++ 
b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/TaskExecutor.java
@@ -20,8 +20,8 @@
 import edu.snu.nemo.common.Pair;
 import edu.snu.nemo.common.dag.DAG;
 import edu.snu.nemo.common.ir.Readable;
+import 
edu.snu.nemo.common.ir.edge.executionproperty.AdditionalOutputTagProperty;
 import edu.snu.nemo.common.ir.vertex.*;
-import 
edu.snu.nemo.common.ir.vertex.executionproperty.AdditionalTagOutputProperty;
 import edu.snu.nemo.common.ir.vertex.transform.Transform;
 import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
 import edu.snu.nemo.runtime.common.comm.ControlMessage;
@@ -154,21 +154,26 @@ public TaskExecutor(final Task task,
           .map(RuntimeEdge::isSideInput)
           .collect(Collectors.toList());
 
-      final Map<String, String> additionalOutputMap = irVertex
-          .getPropertyValue(AdditionalTagOutputProperty.class).orElse(new 
HashMap<>());
+      final Map<String, String> additionalOutputMap =
+          getAdditionalOutputMap(irVertex, task.getTaskOutgoingEdges(), 
irVertexDag);
       final List<Boolean> isToAdditionalTagOutputs = children.stream()
           .map(harness -> harness.getIRVertex().getId())
           .map(additionalOutputMap::containsValue)
           .collect(Collectors.toList());
 
       // Handle writes
-      final List<OutputWriter> childrenTaskWriters = getChildrenTaskWriters(
-          taskIndex, irVertex, task.getTaskOutgoingEdges(), 
dataTransferFactory); // Children-task write
+      // Main output children task writes
+      final List<OutputWriter> mainChildrenTaskWriters = 
getMainChildrenTaskWriters(
+          taskIndex, irVertex, task.getTaskOutgoingEdges(), 
dataTransferFactory, additionalOutputMap);
+      // Additional output children task writes
+      final Map<String, OutputWriter> additionalChildrenTaskWriters = 
getAdditionalChildrenTaskWriters(
+          taskIndex, irVertex, task.getTaskOutgoingEdges(), 
dataTransferFactory, additionalOutputMap);
       final List<String> additionalOutputVertices = new 
ArrayList<>(additionalOutputMap.values());
       final OutputCollectorImpl oci = new 
OutputCollectorImpl(additionalOutputVertices);
+      // intra-vertex writes
       final VertexHarness vertexHarness = new VertexHarness(irVertex, oci, 
children,
-          isToSideInputs, isToAdditionalTagOutputs,
-          childrenTaskWriters, new ContextImpl(sideInputMap, 
additionalOutputMap)); // Intra-vertex write
+          isToSideInputs, isToAdditionalTagOutputs, mainChildrenTaskWriters, 
additionalChildrenTaskWriters,
+          new ContextImpl(sideInputMap, additionalOutputMap));
       prepareTransform(vertexHarness);
       vertexIdToHarness.put(irVertex.getId(), vertexHarness);
 
@@ -222,10 +227,11 @@ private void processElementRecursively(final 
VertexHarness vertexHarness, final
     }
 
     // Recursively process all of the additional output elements.
-    vertexHarness.getAdditionalTagOutputChildren().keySet().forEach(tag -> {
-      while (!outputCollector.isEmpty(tag)) {
-        final Object element = outputCollector.remove(tag);
-        handleAdditionalOutputElement(vertexHarness, element, tag); // 
Recursion
+    
vertexHarness.getContext().getAdditionalTagOutputs().values().forEach(value -> {
+      final String dstVertexId = (String) value;
+      while (!outputCollector.isEmpty(dstVertexId)) {
+        final Object element = outputCollector.remove(dstVertexId);
+        handleAdditionalOutputElement(vertexHarness, element, dstVertexId); // 
Recursion
       }
     });
   }
@@ -317,28 +323,49 @@ private void doExecute() {
 
   private void finalizeVertex(final VertexHarness vertexHarness) {
     closeTransform(vertexHarness);
-    while (!vertexHarness.getOutputCollector().isEmpty()) {
-      final Object element = vertexHarness.getOutputCollector().remove();
+    final OutputCollectorImpl outputCollector = 
vertexHarness.getOutputCollector();
+
+    // handle main outputs
+    while (!outputCollector.isEmpty()) {
+      final Object element = outputCollector.remove();
       handleMainOutputElement(vertexHarness, element);
     }
+
+    // handle additional tagged outputs
+    vertexHarness.getAdditionalTagOutputChildren().keySet().forEach(tag -> {
+      while (!outputCollector.isEmpty(tag)) {
+        final Object element = outputCollector.remove(tag);
+        handleAdditionalOutputElement(vertexHarness, element, tag);
+      }
+    });
     finalizeOutputWriters(vertexHarness);
   }
 
   private void handleMainOutputElement(final VertexHarness harness, final 
Object element) {
-    harness.getWritersToChildrenTasks().forEach(outputWriter -> {
+    // writes to children tasks
+    harness.getWritersToMainChildrenTasks().forEach(outputWriter -> {
       outputWriter.write(element);
     });
+    // writes to side input children tasks
     if (harness.getSideInputChildren().size() > 0) {
       sideInputMap.put(((OperatorVertex) 
harness.getIRVertex()).getTransform().getTag(), element);
     }
+    // process elements in the next vertices within a task
     harness.getNonSideInputChildren().forEach(child -> 
processElementRecursively(child, element));
   }
 
   private void handleAdditionalOutputElement(final VertexHarness harness, 
final Object element, final String tag) {
-    // Inter-task writes are currently not supported.
+    // writes to additional children tasks
+    harness.getWritersToAdditionalChildrenTasks().entrySet().stream()
+        .filter(kv -> kv.getKey().equals(tag))
+        .forEach(kv -> {
+          kv.getValue().write(element);
+        });
+    // writes to side input children tasks
     if (harness.getSideInputChildren().size() > 0) {
       sideInputMap.put(((OperatorVertex) 
harness.getIRVertex()).getTransform().getTag(), element);
     }
+    // process elements in the next vertices within a task
     harness.getAdditionalTagOutputChildren().entrySet().stream()
         .filter(kv -> kv.getKey().equals(tag))
         .forEach(kv -> processElementRecursively(kv.getValue(), element));
@@ -354,7 +381,7 @@ private boolean handleDataFetchers(final List<DataFetcher> 
fetchers) {
       // For this looping of available fetchers.
       int finishedFetcherIndex = NONE_FINISHED;
       for (int i = 0; i < availableFetchers.size(); i++) {
-        final DataFetcher dataFetcher = fetchers.get(i);
+        final DataFetcher dataFetcher = availableFetchers.get(i);
         final Object element;
         try {
           element = dataFetcher.fetchDataElement();
@@ -393,6 +420,31 @@ private boolean handleDataFetchers(final List<DataFetcher> 
fetchers) {
 
   ////////////////////////////////////////////// Helper methods for setting up 
initial data structures
 
+  private Map<String, String> getAdditionalOutputMap(final IRVertex irVertex,
+                                                     final List<StageEdge> 
outEdgesToChildrenTasks,
+                                                     final DAG<IRVertex, 
RuntimeEdge<IRVertex>> irVertexDag) {
+    final Map<String, String> additionalOutputMap = new HashMap<>();
+
+    // Add all intra-task additional tags to additional output map.
+    irVertexDag.getOutgoingEdgesOf(irVertex.getId())
+        .stream()
+        .filter(edge -> 
edge.getPropertyValue(AdditionalOutputTagProperty.class).isPresent())
+        .map(edge ->
+            
Pair.of(edge.getPropertyValue(AdditionalOutputTagProperty.class).get(), 
edge.getDst().getId()))
+        .forEach(pair -> additionalOutputMap.put(pair.left(), pair.right()));
+
+    // Add all inter-task additional tags to additional output map.
+    outEdgesToChildrenTasks
+        .stream()
+        .filter(edge -> edge.getSrcIRVertex().getId().equals(irVertex.getId()))
+        .filter(edge -> 
edge.getPropertyValue(AdditionalOutputTagProperty.class).isPresent())
+        .map(edge ->
+            
Pair.of(edge.getPropertyValue(AdditionalOutputTagProperty.class).get(), 
edge.getDstIRVertex().getId()))
+        .forEach(pair -> additionalOutputMap.put(pair.left(), pair.right()));
+
+    return additionalOutputMap;
+  }
+
   private Optional<Readable> getSourceVertexReader(final IRVertex irVertex,
                                                    final Map<String, Readable> 
irVertexIdToReadable) {
     if (irVertex instanceof SourceVertex) {
@@ -418,18 +470,58 @@ private boolean handleDataFetchers(final 
List<DataFetcher> fetchers) {
         .collect(Collectors.toList());
   }
 
-  private List<OutputWriter> getChildrenTaskWriters(final int taskIndex,
-                                                    final IRVertex irVertex,
-                                                    final List<StageEdge> 
outEdgesToChildrenTasks,
-                                                    final DataTransferFactory 
dataTransferFactory) {
+  /**
+   * Return inter-task OutputWriters, for single output or output associated 
with main tag.
+   * @param taskIndex               current task index
+   * @param irVertex                source irVertex
+   * @param outEdgesToChildrenTasks outgoing edges to child tasks
+   * @param dataTransferFactory     dataTransferFactory
+   * @param taggedOutputs           tag to vertex id map
+   * @return OutputWriters for main children tasks
+   */
+  private List<OutputWriter> getMainChildrenTaskWriters(final int taskIndex,
+                                                        final IRVertex 
irVertex,
+                                                        final List<StageEdge> 
outEdgesToChildrenTasks,
+                                                        final 
DataTransferFactory dataTransferFactory,
+                                                        final Map<String, 
String> taggedOutputs) {
     return outEdgesToChildrenTasks
         .stream()
         .filter(outEdge -> 
outEdge.getSrcIRVertex().getId().equals(irVertex.getId()))
+        .filter(outEdge -> 
!taggedOutputs.containsValue(outEdge.getDstIRVertex().getId()))
         .map(outEdgeForThisVertex -> dataTransferFactory
             .createWriter(irVertex, taskIndex, 
outEdgeForThisVertex.getDstIRVertex(), outEdgeForThisVertex))
         .collect(Collectors.toList());
   }
 
+  /**
+   * Return inter-task OutputWriters associated with additional output tags.
+   * @param taskIndex               current task index
+   * @param irVertex                source irVertex
+   * @param outEdgesToChildrenTasks outgoing edges to child tasks
+   * @param dataTransferFactory     dataTransferFactory
+   * @param taggedOutputs           tag to vertex id map
+   * @return additional children vertex id to OutputWriters map.
+   */
+  private Map<String, OutputWriter> getAdditionalChildrenTaskWriters(final int 
taskIndex,
+                                                                     final 
IRVertex irVertex,
+                                                                     final 
List<StageEdge> outEdgesToChildrenTasks,
+                                                                     final 
DataTransferFactory dataTransferFactory,
+                                                                     final 
Map<String, String> taggedOutputs) {
+    final Map<String, OutputWriter> additionalChildrenTaskWriters = new 
HashMap<>();
+
+    outEdgesToChildrenTasks
+        .stream()
+        .filter(outEdge -> 
outEdge.getSrcIRVertex().getId().equals(irVertex.getId()))
+        .filter(outEdge -> 
taggedOutputs.containsValue(outEdge.getDstIRVertex().getId()))
+        .forEach(outEdgeForThisVertex -> {
+          
additionalChildrenTaskWriters.put(outEdgeForThisVertex.getDstIRVertex().getId(),
+              dataTransferFactory.createWriter(irVertex, taskIndex, 
outEdgeForThisVertex.getDstIRVertex(),
+                  outEdgeForThisVertex));
+        });
+
+    return additionalChildrenTaskWriters;
+  }
+
   private List<VertexHarness> getChildrenHarnesses(final IRVertex irVertex,
                                                    final DAG<IRVertex, RuntimeEdge<IRVertex>> irVertexDag,
                                                    final Map<String, VertexHarness> vertexIdToHarness) {
@@ -486,7 +578,15 @@ private void setIRVertexPutOnHold(final MetricCollectionBarrierVertex irVertex)
   private void finalizeOutputWriters(final VertexHarness vertexHarness) {
     final List<Long> writtenBytesList = new ArrayList<>();
 
-    vertexHarness.getWritersToChildrenTasks().forEach(outputWriter -> {
+    // finalize OutputWriters for main children
+    vertexHarness.getWritersToMainChildrenTasks().forEach(outputWriter -> {
+      outputWriter.close();
+      final Optional<Long> writtenBytes = outputWriter.getWrittenBytes();
+      writtenBytes.ifPresent(writtenBytesList::add);
+    });
+
+    // finalize OutputWriters for additional tagged children
+    vertexHarness.getWritersToAdditionalChildrenTasks().values().forEach(outputWriter -> {
       outputWriter.close();
       final Optional<Long> writtenBytes = outputWriter.getWrittenBytes();
       writtenBytes.ifPresent(writtenBytesList::add);
diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/VertexHarness.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/VertexHarness.java
index c5f9a7850..c79b53020 100644
--- a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/VertexHarness.java
+++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/VertexHarness.java
@@ -38,14 +38,16 @@
   private final List<VertexHarness> sideInputChildren;
   private final List<VertexHarness> nonSideInputChildren;
   private final Map<String, VertexHarness> additionalTagOutputChildren;
-  private final List<OutputWriter> writersToChildrenTasks;
+  private final List<OutputWriter> writersToMainChildrenTasks;
+  private final Map<String, OutputWriter> writersToAdditionalChildrenTasks;
 
   VertexHarness(final IRVertex irVertex,
                 final OutputCollectorImpl outputCollector,
                 final List<VertexHarness> children,
                 final List<Boolean> isSideInputs,
                 final List<Boolean> isAdditionalTagOutputs,
-                final List<OutputWriter> writersToChildrenTasks,
+                final List<OutputWriter> writersToMainChildrenTasks,
+                final Map<String, OutputWriter> writersToAdditionalChildrenTasks,
                 final Transform.Context context) {
     this.irVertex = irVertex;
     this.outputCollector = outputCollector;
@@ -71,7 +73,8 @@
     this.sideInputChildren = sides;
     this.nonSideInputChildren = nonSides;
     this.additionalTagOutputChildren = tagged;
-    this.writersToChildrenTasks = writersToChildrenTasks;
+    this.writersToMainChildrenTasks = writersToMainChildrenTasks;
+    this.writersToAdditionalChildrenTasks = writersToAdditionalChildrenTasks;
     this.context = context;
   }
 
@@ -111,10 +114,17 @@ OutputCollectorImpl getOutputCollector() {
   }
 
   /**
-   * @return OutputWriters of this irVertex. (empty if none exists)
+   * @return OutputWriters for main outputs of this irVertex. (empty if none exists)
    */
-  List<OutputWriter> getWritersToChildrenTasks() {
-    return writersToChildrenTasks;
+  List<OutputWriter> getWritersToMainChildrenTasks() {
+    return writersToMainChildrenTasks;
+  }
+
+  /**
+   * @return OutputWriters for additional tagged outputs of this irVertex. (empty if none exists)
+   */
+  Map<String, OutputWriter> getWritersToAdditionalChildrenTasks() {
+    return writersToAdditionalChildrenTasks;
   }
 
   /**
diff --git a/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/task/TaskExecutorTest.java b/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/task/TaskExecutorTest.java
index 51e6c83e3..e2fea6cd2 100644
--- a/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/task/TaskExecutorTest.java
+++ b/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/task/TaskExecutorTest.java
@@ -20,16 +20,19 @@
 import edu.snu.nemo.common.dag.DAG;
 import edu.snu.nemo.common.dag.DAGBuilder;
 import edu.snu.nemo.common.ir.Readable;
+import edu.snu.nemo.common.ir.edge.IREdge;
+import edu.snu.nemo.common.ir.edge.executionproperty.AdditionalOutputTagProperty;
+import edu.snu.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
 import edu.snu.nemo.common.ir.edge.executionproperty.DataStoreProperty;
 import edu.snu.nemo.common.ir.executionproperty.VertexExecutionProperty;
 import edu.snu.nemo.common.ir.vertex.InMemorySourceVertex;
 import edu.snu.nemo.common.ir.vertex.OperatorVertex;
-import edu.snu.nemo.common.ir.vertex.executionproperty.AdditionalTagOutputProperty;
 import edu.snu.nemo.common.ir.vertex.transform.Transform;
 import edu.snu.nemo.common.ir.executionproperty.ExecutionPropertyMap;
 import edu.snu.nemo.common.ir.vertex.IRVertex;
 import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
 import edu.snu.nemo.runtime.common.message.PersistentConnectionToMasterMap;
+import edu.snu.nemo.runtime.common.plan.Stage;
 import edu.snu.nemo.runtime.common.plan.Task;
 import edu.snu.nemo.runtime.common.plan.StageEdge;
 import edu.snu.nemo.runtime.common.plan.RuntimeEdge;
@@ -54,8 +57,6 @@
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.ArgumentMatchers.anyString;
@@ -67,7 +68,7 @@
  */
 @RunWith(PowerMockRunner.class)
 @PrepareForTest({InputReader.class, OutputWriter.class, DataTransferFactory.class,
-    TaskStateManager.class, StageEdge.class, PersistentConnectionToMasterMap.class})
+    TaskStateManager.class, StageEdge.class, PersistentConnectionToMasterMap.class, Stage.class, IREdge.class})
 public final class TaskExecutorTest {
   private static final int DATA_SIZE = 100;
   private static final ExecutionPropertyMap<VertexExecutionProperty> TASK_EXECUTION_PROPERTY_MAP
@@ -278,21 +279,21 @@ public void testAdditionalOutputs() throws Exception {
     final IRVertex bonusVertex1 = new OperatorVertex(new RelayTransform());
     final IRVertex bonusVertex2 = new OperatorVertex(new RelayTransform());
 
-    // Tag to vertex map. Mock tags are used.
-    HashMap<String, String> tagToVertex = new HashMap<>();
-    tagToVertex.put("bonus1", bonusVertex1.getId());
-    tagToVertex.put("bonus2", bonusVertex2.getId());
+    final RuntimeEdge<IRVertex> edge1 = createEdge(routerVertex, mainVertex, false, "edge-1");
+    final RuntimeEdge<IRVertex> edge2 = createEdge(routerVertex, bonusVertex1, false, "edge-2");
+    final RuntimeEdge<IRVertex> edge3 = createEdge(routerVertex, bonusVertex2, false, "edge-3");
 
-    routerVertex.setProperty(AdditionalTagOutputProperty.of(tagToVertex));
+    edge2.getExecutionProperties().put(AdditionalOutputTagProperty.of("bonus1"));
+    edge3.getExecutionProperties().put(AdditionalOutputTagProperty.of("bonus2"));
 
     final DAG<IRVertex, RuntimeEdge<IRVertex>> taskDag = new DAGBuilder<IRVertex, RuntimeEdge<IRVertex>>()
         .addVertex(routerVertex)
         .addVertex(mainVertex)
         .addVertex(bonusVertex1)
         .addVertex(bonusVertex2)
-        .connectVertices(createEdge(routerVertex, mainVertex, false, "edge-1"))
-        .connectVertices(createEdge(routerVertex, bonusVertex1, false, "edge-2"))
-        .connectVertices(createEdge(routerVertex, bonusVertex2, false, "edge-3"))
+        .connectVertices(edge1)
+        .connectVertices(edge2)
+        .connectVertices(edge3)
         .buildWithoutSourceSinkCheck();
 
     final Task task = new Task(
@@ -344,17 +345,23 @@ public void testAdditionalOutputs() throws Exception {
   }
 
   private StageEdge mockStageEdgeFrom(final IRVertex irVertex) {
-    final StageEdge edge = mock(StageEdge.class);
-    when(edge.getSrcIRVertex()).thenReturn(irVertex);
-    when(edge.getDstIRVertex()).thenReturn(new OperatorVertex(new RelayTransform()));
-    return edge;
+    return new StageEdge("runtime incoming edge id",
+        ExecutionPropertyMap.of(mock(IREdge.class), CommunicationPatternProperty.Value.OneToOne),
+        irVertex,
+        new OperatorVertex(new RelayTransform()),
+        mock(Stage.class),
+        mock(Stage.class),
+        false);
   }
 
   private StageEdge mockStageEdgeTo(final IRVertex irVertex) {
-    final StageEdge edge = mock(StageEdge.class);
-    when(edge.getSrcIRVertex()).thenReturn(new OperatorVertex(new RelayTransform()));
-    when(edge.getDstIRVertex()).thenReturn(irVertex);
-    return edge;
+    return new StageEdge("runtime outgoing edge id",
+        ExecutionPropertyMap.of(mock(IREdge.class), CommunicationPatternProperty.Value.OneToOne),
+        new OperatorVertex(new RelayTransform()),
+        irVertex,
+        mock(Stage.class),
+        mock(Stage.class),
+        false);
   }
 
   /**


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to