Repository: beam Updated Branches: refs/heads/master 34b38ef95 -> 9cc8018b3
http://git-wip-us.apache.org/repos/asf/beam/blob/8766b03e/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java index 31307cc..ccf84b2 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java @@ -332,20 +332,58 @@ final class StreamingTransformTranslator { }; } + private static <InputT, OutputT> TransformEvaluator<ParDo.Bound<InputT, OutputT>> parDo() { + return new TransformEvaluator<ParDo.Bound<InputT, OutputT>>() { + @Override + public void evaluate(final ParDo.Bound<InputT, OutputT> transform, + final EvaluationContext context) { + final DoFn<InputT, OutputT> doFn = transform.getFn(); + rejectSplittable(doFn); + rejectStateAndTimers(doFn); + final SparkRuntimeContext runtimeContext = context.getRuntimeContext(); + final WindowingStrategy<?, ?> windowingStrategy = + context.getInput(transform).getWindowingStrategy(); + final SparkPCollectionView pviews = context.getPViews(); + + @SuppressWarnings("unchecked") + UnboundedDataset<InputT> unboundedDataset = + ((UnboundedDataset<InputT>) context.borrowDataset(transform)); + JavaDStream<WindowedValue<InputT>> dStream = unboundedDataset.getDStream(); + + final String stepName = context.getCurrentTransform().getFullName(); + + JavaDStream<WindowedValue<OutputT>> outStream = + dStream.transform(new Function<JavaRDD<WindowedValue<InputT>>, + JavaRDD<WindowedValue<OutputT>>>() { + @Override + public JavaRDD<WindowedValue<OutputT>> call(JavaRDD<WindowedValue<InputT>> rdd) throws + Exception { + final JavaSparkContext jsc = new JavaSparkContext(rdd.context()); + final Accumulator<NamedAggregators> aggAccum = + SparkAggregators.getNamedAggregators(jsc); + final Accumulator<SparkMetricsContainer> metricsAccum = + MetricsAccumulator.getInstance(); + final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs = + TranslationUtils.getSideInputs(transform.getSideInputs(), + jsc, pviews); + return rdd.mapPartitions( + new DoFnFunction<>(aggAccum, metricsAccum, stepName, doFn, runtimeContext, + sideInputs, windowingStrategy)); + } + }); + + context.putDataset(transform, + new UnboundedDataset<>(outStream, unboundedDataset.getStreamSources())); + } + }; + } + private static <InputT, OutputT> TransformEvaluator<ParDo.BoundMulti<InputT, OutputT>> multiDo() { return new TransformEvaluator<ParDo.BoundMulti<InputT, OutputT>>() { - public void evaluate( - final ParDo.BoundMulti<InputT, OutputT> transform, final EvaluationContext context) { - if (transform.getSideOutputTags().size() == 0) { - evaluateSingle(transform, context); - } else { - evaluateMulti(transform, context); - } - } - - private void evaluateMulti( - final ParDo.BoundMulti<InputT, OutputT> transform, final EvaluationContext context) { + @Override + public void evaluate(final ParDo.BoundMulti<InputT, OutputT> transform, + final EvaluationContext context) { final DoFn<InputT, OutputT> doFn = transform.getFn(); rejectSplittable(doFn); rejectStateAndTimers(doFn); @@ -389,60 +427,10 @@ final class StreamingTransformTranslator { JavaDStream<WindowedValue<Object>> values = (JavaDStream<WindowedValue<Object>>) (JavaDStream<?>) TranslationUtils.dStreamValues(filtered); - context.putDataset( - e.getValue(), new UnboundedDataset<>(values, unboundedDataset.getStreamSources())); + context.putDataset(e.getValue(), + new UnboundedDataset<>(values, unboundedDataset.getStreamSources())); } } - - private void evaluateSingle( - final ParDo.BoundMulti<InputT, OutputT> transform, final EvaluationContext context) { - final DoFn<InputT, OutputT> doFn = transform.getFn(); - rejectSplittable(doFn); - rejectStateAndTimers(doFn); - final SparkRuntimeContext runtimeContext = context.getRuntimeContext(); - final WindowingStrategy<?, ?> windowingStrategy = - context.getInput(transform).getWindowingStrategy(); - final SparkPCollectionView pviews = context.getPViews(); - - @SuppressWarnings("unchecked") - UnboundedDataset<InputT> unboundedDataset = - ((UnboundedDataset<InputT>) context.borrowDataset(transform)); - JavaDStream<WindowedValue<InputT>> dStream = unboundedDataset.getDStream(); - - final String stepName = context.getCurrentTransform().getFullName(); - - JavaDStream<WindowedValue<OutputT>> outStream = - dStream.transform( - new Function<JavaRDD<WindowedValue<InputT>>, JavaRDD<WindowedValue<OutputT>>>() { - @Override - public JavaRDD<WindowedValue<OutputT>> call(JavaRDD<WindowedValue<InputT>> rdd) - throws Exception { - final JavaSparkContext jsc = new JavaSparkContext(rdd.context()); - final Accumulator<NamedAggregators> aggAccum = - SparkAggregators.getNamedAggregators(jsc); - final Accumulator<SparkMetricsContainer> metricsAccum = - MetricsAccumulator.getInstance(); - final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> - sideInputs = - TranslationUtils.getSideInputs(transform.getSideInputs(), jsc, pviews); - return rdd.mapPartitions( - new DoFnFunction<>( - aggAccum, - metricsAccum, - stepName, - doFn, - runtimeContext, - sideInputs, - windowingStrategy)); - } - }); - - PCollection<OutputT> output = - (PCollection<OutputT>) - Iterables.getOnlyElement(context.getOutputs(transform)).getValue(); - context.putDataset( - output, new UnboundedDataset<>(outStream, unboundedDataset.getStreamSources())); - } }; } @@ -487,6 +475,7 @@ final class StreamingTransformTranslator { EVALUATORS.put(Read.Unbounded.class, readUnbounded()); EVALUATORS.put(GroupByKey.class, groupByKey()); EVALUATORS.put(Combine.GroupedValues.class, combineGrouped()); + EVALUATORS.put(ParDo.Bound.class, parDo()); EVALUATORS.put(ParDo.BoundMulti.class, multiDo()); EVALUATORS.put(ConsoleIO.Write.Unbound.class, print()); EVALUATORS.put(CreateStream.class, createFromQueue()); http://git-wip-us.apache.org/repos/asf/beam/blob/8766b03e/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java index d66633b..b181a04 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java @@ -83,7 +83,7 @@ public class TrackStreamingSourcesTest { p.apply(emptyStream).apply(ParDo.of(new PassthroughFn<>())); - p.traverseTopologically(new StreamingSourceTracker(jssc, p, ParDo.BoundMulti.class, 0)); + p.traverseTopologically(new StreamingSourceTracker(jssc, p, ParDo.Bound.class, 0)); assertThat(StreamingSourceTracker.numAssertions, equalTo(1)); } @@ -111,7 +111,7 @@ public class TrackStreamingSourcesTest { PCollectionList.of(pcol1).and(pcol2).apply(Flatten.<Integer>pCollections()); flattened.apply(ParDo.of(new PassthroughFn<>())); - p.traverseTopologically(new StreamingSourceTracker(jssc, p, ParDo.BoundMulti.class, 0, 1)); + p.traverseTopologically(new StreamingSourceTracker(jssc, p, ParDo.Bound.class, 0, 1)); assertThat(StreamingSourceTracker.numAssertions, equalTo(1)); } http://git-wip-us.apache.org/repos/asf/beam/blob/8766b03e/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java index 9225231..19c5a2d 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java @@ -738,8 +738,12 @@ public class ParDo { @Override public PCollection<OutputT> expand(PCollection<? extends InputT> input) { - TupleTag<OutputT> mainOutput = new TupleTag<>(); - return input.apply(withOutputTags(mainOutput, TupleTagList.empty())).get(mainOutput); + validateWindowType(input, fn); + return PCollection.<OutputT>createPrimitiveOutputInternal( + input.getPipeline(), + input.getWindowingStrategy(), + input.isBounded()) + .setTypeDescriptor(getFn().getOutputTypeDescriptor()); } @Override