This is an automated email from the ASF dual-hosted git repository. echauchot pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new 2cd38984a35 [Spark Dataset runner] Reduce binary size of Java serialized task related for ParDo translation (#24543) 2cd38984a35 is described below commit 2cd38984a354c76ada42cb51f13a398babaf1b76 Author: Moritz Mack <mm...@talend.com> AuthorDate: Mon Dec 19 14:13:08 2022 +0100 [Spark Dataset runner] Reduce binary size of Java serialized task related for ParDo translation (#24543) * [Spark Dataset runner] Reduce binary size of Java serialized broadcasted task related for ParDo translation (related to #23845) --- .../batch/DoFnMapPartitionsFactory.java | 204 ---------------- .../batch/DoFnPartitionIteratorFactory.java | 272 +++++++++++++++++++++ .../translation/batch/ParDoTranslatorBatch.java | 125 ++++------ 3 files changed, 323 insertions(+), 278 deletions(-) diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnMapPartitionsFactory.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnMapPartitionsFactory.java deleted file mode 100644 index a53e5ca3a79..00000000000 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnMapPartitionsFactory.java +++ /dev/null @@ -1,204 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.spark.structuredstreaming.translation.batch; - -import static java.util.stream.Collectors.toCollection; -import static java.util.stream.Collectors.toMap; -import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.scalaIterator; -import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists.newArrayListWithCapacity; - -import java.io.Serializable; -import java.util.ArrayDeque; -import java.util.Deque; -import java.util.List; -import java.util.Map; -import java.util.function.Supplier; -import org.apache.beam.runners.core.DoFnRunner; -import org.apache.beam.runners.core.DoFnRunners; -import org.apache.beam.runners.core.DoFnRunners.OutputManager; -import org.apache.beam.runners.core.SideInputReader; -import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsAccumulator; -import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.CachedSideInputReader; -import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.NoOpStepContext; -import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1; -import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun2; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.DoFnSchemaInformation; -import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; -import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PCollectionView; -import org.apache.beam.sdk.values.TupleTag; -import org.apache.beam.sdk.values.WindowingStrategy; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.AbstractIterator; -import org.apache.spark.api.java.function.MapPartitionsFunction; -import org.checkerframework.checker.nullness.qual.NonNull; -import scala.collection.Iterator; - -/** - * Encapsulates a {@link DoFn} inside a Spark {@link - * org.apache.spark.api.java.function.MapPartitionsFunction}. - */ -class DoFnMapPartitionsFactory<InT, OutT> implements Serializable { - private final String stepName; - - private final DoFn<InT, OutT> doFn; - private final DoFnSchemaInformation doFnSchema; - private final Supplier<PipelineOptions> options; - - private final Coder<InT> coder; - private final WindowingStrategy<?, ?> windowingStrategy; - private final TupleTag<OutT> mainOutput; - private final List<TupleTag<?>> additionalOutputs; - private final Map<TupleTag<?>, Coder<?>> outputCoders; - - private final Map<String, PCollectionView<?>> sideInputs; - private final SideInputReader sideInputReader; - - DoFnMapPartitionsFactory( - String stepName, - DoFn<InT, OutT> doFn, - DoFnSchemaInformation doFnSchema, - Supplier<PipelineOptions> options, - PCollection<InT> input, - TupleTag<OutT> mainOutput, - Map<TupleTag<?>, PCollection<?>> outputs, - Map<String, PCollectionView<?>> sideInputs, - SideInputReader sideInputReader) { - this.stepName = stepName; - this.doFn = doFn; - this.doFnSchema = doFnSchema; - this.options = options; - this.coder = input.getCoder(); - this.windowingStrategy = input.getWindowingStrategy(); - this.mainOutput = mainOutput; - this.additionalOutputs = additionalOutputs(outputs, mainOutput); - this.outputCoders = outputCoders(outputs); - this.sideInputs = sideInputs; - this.sideInputReader = sideInputReader; - } - - /** Create the {@link MapPartitionsFunction} using the provided output function. */ - <OutputT extends @NonNull Object> Fun1<Iterator<WindowedValue<InT>>, Iterator<OutputT>> create( - Fun2<TupleTag<?>, WindowedValue<?>, OutputT> outputFn) { - return it -> - it.hasNext() - ? scalaIterator(new DoFnPartitionIt<>(outputFn, it)) - : (Iterator<OutputT>) Iterator.empty(); - } - - // FIXME Add support for TimerInternals.TimerData - /** - * Partition iterator that lazily processes each element from the (input) iterator on demand - * producing zero, one or more output elements as output (via an internal buffer). - * - * <p>When initializing the iterator for a partition {@code setup} followed by {@code startBundle} - * is called. - */ - private class DoFnPartitionIt<FnInT extends InT, OutputT> extends AbstractIterator<OutputT> { - private final Deque<OutputT> buffer; - private final DoFnRunner<InT, OutT> doFnRunner; - private final Iterator<WindowedValue<FnInT>> partitionIt; - - private boolean isBundleFinished; - - DoFnPartitionIt( - Fun2<TupleTag<?>, WindowedValue<?>, OutputT> outputFn, - Iterator<WindowedValue<FnInT>> partitionIt) { - this.buffer = new ArrayDeque<>(); - this.doFnRunner = metricsRunner(simpleRunner(outputFn, buffer)); - this.partitionIt = partitionIt; - // Before starting to iterate over the partition, invoke setup and then startBundle - DoFnInvokers.tryInvokeSetupFor(doFn, options.get()); - try { - doFnRunner.startBundle(); - } catch (RuntimeException re) { - DoFnInvokers.invokerFor(doFn).invokeTeardown(); - throw re; - } - } - - @Override - protected OutputT computeNext() { - try { - while (true) { - if (!buffer.isEmpty()) { - return buffer.remove(); - } - if (partitionIt.hasNext()) { - // grab the next element and process it. - doFnRunner.processElement((WindowedValue<InT>) partitionIt.next()); - } else { - if (!isBundleFinished) { - isBundleFinished = true; - doFnRunner.finishBundle(); - continue; // finishBundle can produce more output - } - DoFnInvokers.invokerFor(doFn).invokeTeardown(); - return endOfData(); - } - } - } catch (RuntimeException re) { - DoFnInvokers.invokerFor(doFn).invokeTeardown(); - throw re; - } - } - } - - private <OutputT> DoFnRunner<InT, OutT> simpleRunner( - Fun2<TupleTag<?>, WindowedValue<?>, OutputT> outputFn, Deque<OutputT> buffer) { - OutputManager outputManager = - new OutputManager() { - @Override - public <T> void output(TupleTag<T> tag, WindowedValue<T> output) { - buffer.add(outputFn.apply(tag, output)); - } - }; - return DoFnRunners.simpleRunner( - options.get(), - doFn, - CachedSideInputReader.of(sideInputReader, sideInputs.values()), - outputManager, - mainOutput, - additionalOutputs, - new NoOpStepContext(), - coder, - outputCoders, - windowingStrategy, - doFnSchema, - sideInputs); - } - - private DoFnRunner<InT, OutT> metricsRunner(DoFnRunner<InT, OutT> runner) { - return new DoFnRunnerWithMetrics<>(stepName, runner, MetricsAccumulator.getInstance()); - } - - private static List<TupleTag<?>> additionalOutputs( - Map<TupleTag<?>, PCollection<?>> outputs, TupleTag<?> mainOutput) { - return outputs.keySet().stream() - .filter(t -> !t.equals(mainOutput)) - .collect(toCollection(() -> newArrayListWithCapacity(outputs.size() - 1))); - } - - private static Map<TupleTag<?>, Coder<?>> outputCoders(Map<TupleTag<?>, PCollection<?>> outputs) { - return outputs.entrySet().stream() - .collect(toMap(Map.Entry::getKey, e -> e.getValue().getCoder())); - } -} diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnPartitionIteratorFactory.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnPartitionIteratorFactory.java new file mode 100644 index 00000000000..c760efd229c --- /dev/null +++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnPartitionIteratorFactory.java @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.scalaIterator; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple; +import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; + +import java.io.Serializable; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Deque; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; +import org.apache.beam.runners.core.DoFnRunner; +import org.apache.beam.runners.core.DoFnRunners; +import org.apache.beam.runners.core.SideInputReader; +import org.apache.beam.runners.core.construction.ParDoTranslation; +import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsAccumulator; +import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.CachedSideInputReader; +import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.NoOpStepContext; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFnSchemaInformation; +import org.apache.beam.sdk.transforms.ParDo.MultiOutput; +import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.AbstractIterator; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; +import org.checkerframework.checker.nullness.qual.NonNull; +import scala.Function1; +import scala.Tuple2; +import scala.collection.Iterator; + +/** + * Abstract factory to create a {@link DoFnPartitionIt DoFn partition iterator} using a customizable + * {@link DoFnRunners.OutputManager}. + */ +abstract class DoFnPartitionIteratorFactory<InT, FnOutT, OutT extends @NonNull Object> + implements Function1<Iterator<WindowedValue<InT>>, Iterator<OutT>>, Serializable { + private final String stepName; + private final DoFn<InT, FnOutT> doFn; + private final DoFnSchemaInformation doFnSchema; + private final Supplier<PipelineOptions> options; + private final Coder<InT> coder; + private final WindowingStrategy<?, ?> windowingStrategy; + private final TupleTag<FnOutT> mainOutput; + private final List<TupleTag<?>> additionalOutputs; + private final Map<TupleTag<?>, Coder<?>> outputCoders; + private final Map<String, PCollectionView<?>> sideInputs; + private final SideInputReader sideInputReader; + + private DoFnPartitionIteratorFactory( + AppliedPTransform<PCollection<? extends InT>, ?, MultiOutput<InT, FnOutT>> appliedPT, + Supplier<PipelineOptions> options, + PCollection<InT> input, + SideInputReader sideInputReader) { + this.stepName = appliedPT.getFullName(); + this.doFn = appliedPT.getTransform().getFn(); + this.doFnSchema = ParDoTranslation.getSchemaInformation(appliedPT); + this.options = options; + this.coder = input.getCoder(); + this.windowingStrategy = input.getWindowingStrategy(); + this.mainOutput = appliedPT.getTransform().getMainOutputTag(); + this.additionalOutputs = additionalOutputs(appliedPT.getTransform()); + this.outputCoders = outputCoders(appliedPT.getOutputs()); + this.sideInputs = appliedPT.getTransform().getSideInputs(); + this.sideInputReader = sideInputReader; + } + + /** + * {@link DoFnPartitionIteratorFactory} emitting a single output of type {@link WindowedValue} of + * {@link OutT}. + */ + static <InT, OutT> DoFnPartitionIteratorFactory<InT, ?, WindowedValue<OutT>> singleOutput( + AppliedPTransform<PCollection<? extends InT>, ?, MultiOutput<InT, OutT>> appliedPT, + Supplier<PipelineOptions> options, + PCollection<InT> input, + SideInputReader sideInputReader) { + return new SingleOut<>(appliedPT, options, input, sideInputReader); + } + + /** + * {@link DoFnPartitionIteratorFactory} emitting multiple outputs encoded as tuple of column index + * and {@link WindowedValue} of {@link OutT}, where column index corresponds to the index of a + * {@link TupleTag#getId()} in {@code tagColIdx}. + */ + static <InT, FnOutT, OutT> + DoFnPartitionIteratorFactory<InT, ?, Tuple2<Integer, WindowedValue<OutT>>> multiOutput( + AppliedPTransform<PCollection<? extends InT>, ?, MultiOutput<InT, FnOutT>> appliedPT, + Supplier<PipelineOptions> options, + PCollection<InT> input, + SideInputReader sideInputReader, + Map<String, Integer> tagColIdx) { + return new MultiOut<>(appliedPT, options, input, sideInputReader, tagColIdx); + } + + @Override + public Iterator<OutT> apply(Iterator<WindowedValue<InT>> it) { + return it.hasNext() + ? scalaIterator(new DoFnPartitionIt(it)) + : (Iterator<OutT>) Iterator.empty(); + } + + /** Output manager emitting outputs of type {@link OutT} to the buffer. */ + abstract DoFnRunners.OutputManager outputManager(Deque<OutT> buffer); + + /** + * {@link DoFnPartitionIteratorFactory} emitting a single output of type {@link WindowedValue} of + * {@link OutT}. + */ + private static class SingleOut<InT, OutT> + extends DoFnPartitionIteratorFactory<InT, OutT, WindowedValue<OutT>> { + private SingleOut( + AppliedPTransform<PCollection<? extends InT>, ?, MultiOutput<InT, OutT>> appliedPT, + Supplier<PipelineOptions> options, + PCollection<InT> input, + SideInputReader sideInputReader) { + super(appliedPT, options, input, sideInputReader); + } + + @Override + DoFnRunners.OutputManager outputManager(Deque<WindowedValue<OutT>> buffer) { + return new DoFnRunners.OutputManager() { + @Override + public <T> void output(TupleTag<T> tag, WindowedValue<T> output) { + buffer.add((WindowedValue<OutT>) output); + } + }; + } + } + + /** + * {@link DoFnPartitionIteratorFactory} emitting multiple outputs encoded as tuple of column index + * and {@link WindowedValue} of {@link OutT}, where column index corresponds to the index of a + * {@link TupleTag#getId()} in {@link #tagColIdx}. + */ + private static class MultiOut<InT, FnOutT, OutT> + extends DoFnPartitionIteratorFactory<InT, FnOutT, Tuple2<Integer, WindowedValue<OutT>>> { + private final Map<String, Integer> tagColIdx; + + public MultiOut( + AppliedPTransform<PCollection<? extends InT>, ?, MultiOutput<InT, FnOutT>> appliedPT, + Supplier<PipelineOptions> options, + PCollection<InT> input, + SideInputReader sideInputReader, + Map<String, Integer> tagColIdx) { + super(appliedPT, options, input, sideInputReader); + this.tagColIdx = tagColIdx; + } + + @Override + DoFnRunners.OutputManager outputManager(Deque<Tuple2<Integer, WindowedValue<OutT>>> buffer) { + return new DoFnRunners.OutputManager() { + @Override + public <T> void output(TupleTag<T> tag, WindowedValue<T> output) { + Integer columnIdx = checkStateNotNull(tagColIdx.get(tag.getId()), "Unknown tag %s", tag); + buffer.add(tuple(columnIdx, (WindowedValue<OutT>) output)); + } + }; + } + } + + // FIXME Add support for TimerInternals.TimerData + /** + * Partition iterator that lazily processes each element from the (input) iterator on demand + * producing zero, one or more output elements as output (via an internal buffer). + * + * <p>When initializing the iterator for a partition {@code setup} followed by {@code startBundle} + * is called. + */ + private class DoFnPartitionIt extends AbstractIterator<OutT> { + private final Deque<OutT> buffer = new ArrayDeque<>(); + private final DoFnRunner<InT, ?> doFnRunner = metricsRunner(simpleRunner(buffer)); + private final Iterator<WindowedValue<InT>> partitionIt; + private boolean isBundleFinished; + + private DoFnPartitionIt(Iterator<WindowedValue<InT>> partitionIt) { + this.partitionIt = partitionIt; + // Before starting to iterate over the partition, invoke setup and then startBundle + DoFnInvokers.tryInvokeSetupFor(doFn, options.get()); + try { + doFnRunner.startBundle(); + } catch (RuntimeException re) { + DoFnInvokers.invokerFor(doFn).invokeTeardown(); + throw re; + } + } + + @Override + protected OutT computeNext() { + try { + while (true) { + if (!buffer.isEmpty()) { + return buffer.remove(); + } + if (partitionIt.hasNext()) { + // grab the next element and process it. + doFnRunner.processElement(partitionIt.next()); + } else { + if (!isBundleFinished) { + isBundleFinished = true; + doFnRunner.finishBundle(); + continue; // finishBundle can produce more output + } + DoFnInvokers.invokerFor(doFn).invokeTeardown(); + return endOfData(); + } + } + } catch (RuntimeException re) { + DoFnInvokers.invokerFor(doFn).invokeTeardown(); + throw re; + } + } + } + + private DoFnRunner<InT, FnOutT> simpleRunner(Deque<OutT> buffer) { + return DoFnRunners.simpleRunner( + options.get(), + (DoFn<InT, FnOutT>) doFn, + CachedSideInputReader.of(sideInputReader, sideInputs.values()), + outputManager(buffer), + mainOutput, + additionalOutputs, + new NoOpStepContext(), + coder, + outputCoders, + windowingStrategy, + doFnSchema, + sideInputs); + } + + private DoFnRunner<InT, FnOutT> metricsRunner(DoFnRunner<InT, FnOutT> runner) { + return new DoFnRunnerWithMetrics<>(stepName, runner, MetricsAccumulator.getInstance()); + } + + private static Map<TupleTag<?>, Coder<?>> outputCoders(Map<TupleTag<?>, PCollection<?>> outputs) { + Map<TupleTag<?>, Coder<?>> coders = Maps.newHashMapWithExpectedSize(outputs.size()); + for (Map.Entry<TupleTag<?>, PCollection<?>> e : outputs.entrySet()) { + coders.put(e.getKey(), e.getValue().getCoder()); + } + return coders; + } + + private static List<TupleTag<?>> additionalOutputs(MultiOutput<?, ?> transform) { + List<TupleTag<?>> tags = transform.getAdditionalOutputTags().getAll(); + return tags.isEmpty() ? Collections.emptyList() : new ArrayList<>(tags); + } +} diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java index d1e069c82d0..3083ff5101b 100644 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java +++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java @@ -17,33 +17,29 @@ */ package org.apache.beam.runners.spark.structuredstreaming.translation.batch; -import static java.util.stream.Collectors.toList; import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.oneOfEncoder; -import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1; -import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.emptyList; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.listOf; import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; import static org.apache.spark.sql.functions.col; import static org.apache.spark.storage.StorageLevel.MEMORY_ONLY; import java.io.IOException; -import java.util.AbstractMap.SimpleImmutableEntry; import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Map.Entry; -import javax.annotation.Nullable; import org.apache.beam.runners.core.DoFnRunners; import org.apache.beam.runners.core.SideInputReader; -import org.apache.beam.runners.core.construction.ParDoTranslation; import org.apache.beam.runners.spark.SparkCommonPipelineOptions; import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.SideInputValues; import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.SparkSideInputReader; +import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.DoFnSchemaInformation; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.reflect.DoFnSignature; import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; @@ -52,18 +48,15 @@ import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Streams; import org.apache.spark.broadcast.Broadcast; import org.apache.spark.rdd.RDD; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Encoder; import org.apache.spark.sql.TypedColumn; import org.apache.spark.storage.StorageLevel; -import scala.Function1; import scala.Tuple2; -import scala.collection.Iterator; +import scala.collection.TraversableOnce; import scala.reflect.ClassTag; /** @@ -115,40 +108,27 @@ class ParDoTranslatorBatch<InputT, OutputT> @Override public void translate(ParDo.MultiOutput<InputT, OutputT> transform, Context cxt) throws IOException { - String stepName = cxt.getCurrentTransform().getFullName(); - - TupleTag<OutputT> mainOutputTag = transform.getMainOutputTag(); - - DoFnSchemaInformation doFnSchema = - ParDoTranslation.getSchemaInformation(cxt.getCurrentTransform()); PCollection<InputT> input = (PCollection<InputT>) cxt.getInput(); - Map<String, PCollectionView<?>> sideInputs = transform.getSideInputs(); Map<TupleTag<?>, PCollection<?>> outputs = cxt.getOutputs(); - DoFnMapPartitionsFactory<InputT, OutputT> factory = - new DoFnMapPartitionsFactory<>( - stepName, - transform.getFn(), - doFnSchema, - cxt.getOptionsSupplier(), - input, - mainOutputTag, - outputs, - sideInputs, - createSideInputReader(sideInputs.values(), cxt)); - Dataset<WindowedValue<InputT>> inputDs = cxt.getDataset(input); + SideInputReader sideInputReader = + createSideInputReader(transform.getSideInputs().values(), cxt); + if (outputs.size() > 1) { // In case of multiple outputs / tags, map each tag to a column by index. // At the end split the result into multiple datasets selecting one column each. - Map<TupleTag<?>, Integer> tags = ImmutableMap.copyOf(zipwithIndex(outputs.keySet())); - - List<Encoder<WindowedValue<Object>>> encoders = - createEncoders(outputs, (Iterable<TupleTag<?>>) tags.keySet(), cxt); + Map<String, Integer> tagColIdx = tagsColumnIndex((Collection<TupleTag<?>>) outputs.keySet()); + List<Encoder<WindowedValue<Object>>> encoders = createEncoders(outputs, tagColIdx, cxt); - Function1<Iterator<WindowedValue<InputT>>, Iterator<Tuple2<Integer, WindowedValue<Object>>>> - doFnMapper = factory.create((tag, v) -> tuple(tags.get(tag), (WindowedValue<Object>) v)); + DoFnPartitionIteratorFactory<InputT, ?, Tuple2<Integer, WindowedValue<Object>>> doFnMapper = + DoFnPartitionIteratorFactory.multiOutput( + cxt.getCurrentTransform(), + cxt.getOptionsSupplier(), + input, + sideInputReader, + tagColIdx); // FIXME What's the strategy to unpersist Datasets / RDDs? @@ -169,18 +149,13 @@ class ParDoTranslatorBatch<InputT, OutputT> allTagsRDD.persist(); // divide into separate output datasets per tag - for (Entry<TupleTag<?>, Integer> e : tags.entrySet()) { - TupleTag<Object> key = (TupleTag<Object>) e.getKey(); - Integer id = e.getValue(); - + for (TupleTag<?> tag : outputs.keySet()) { + int colIdx = checkStateNotNull(tagColIdx.get(tag.getId()), "Unknown tag"); RDD<WindowedValue<Object>> rddByTag = - allTagsRDD - .filter(fun1(t -> t._1.equals(id))) - .map(fun1(Tuple2::_2), WINDOWED_VALUE_CTAG); - + allTagsRDD.flatMap(selectByColumnIdx(colIdx), WINDOWED_VALUE_CTAG); cxt.putDataset( - cxt.getOutput(key), - cxt.getSparkSession().createDataset(rddByTag, encoders.get(id)), + cxt.getOutput((TupleTag) tag), + cxt.getSparkSession().createDataset(rddByTag, encoders.get(colIdx)), false); } } else { @@ -190,40 +165,51 @@ class ParDoTranslatorBatch<InputT, OutputT> allTagsDS.persist(storageLevel); // divide into separate output datasets per tag - for (Entry<TupleTag<?>, Integer> e : tags.entrySet()) { - TupleTag<Object> key = (TupleTag<Object>) e.getKey(); - Integer id = e.getValue(); - + for (TupleTag<?> tag : outputs.keySet()) { + int colIdx = checkStateNotNull(tagColIdx.get(tag.getId()), "Unknown tag"); // Resolve specific column matching the tuple tag (by id) TypedColumn<Tuple2<Integer, WindowedValue<Object>>, WindowedValue<Object>> col = - (TypedColumn) col(id.toString()).as(encoders.get(id)); + (TypedColumn) col(Integer.toString(colIdx)).as(encoders.get(colIdx)); - cxt.putDataset(cxt.getOutput(key), allTagsDS.filter(col.isNotNull()).select(col), false); + cxt.putDataset( + cxt.getOutput((TupleTag) tag), allTagsDS.filter(col.isNotNull()).select(col), false); } } } else { - PCollection<OutputT> output = cxt.getOutput(mainOutputTag); + PCollection<OutputT> output = cxt.getOutput(transform.getMainOutputTag()); + DoFnPartitionIteratorFactory<InputT, ?, WindowedValue<OutputT>> doFnMapper = + DoFnPartitionIteratorFactory.singleOutput( + cxt.getCurrentTransform(), cxt.getOptionsSupplier(), input, sideInputReader); + Dataset<WindowedValue<OutputT>> mainDS = - inputDs.mapPartitions( - factory.create((tag, value) -> (WindowedValue<OutputT>) value), - cxt.windowedEncoder(output.getCoder())); + inputDs.mapPartitions(doFnMapper, cxt.windowedEncoder(output.getCoder())); cxt.putDataset(output, mainDS); } } - private List<Encoder<WindowedValue<Object>>> createEncoders( - Map<TupleTag<?>, PCollection<?>> outputs, Iterable<TupleTag<?>> columns, Context ctx) { - return Streams.stream(columns) - .map(tag -> ctx.windowedEncoder(getCoder(outputs.get(tag), tag))) - .collect(toList()); + static <T> Fun1<Tuple2<Integer, T>, TraversableOnce<T>> selectByColumnIdx(int idx) { + return t -> idx == t._1 ? listOf(t._2) : emptyList(); } - private Coder<Object> getCoder(@Nullable PCollection<?> pc, TupleTag<?> tag) { - if (pc == null) { - throw new NullPointerException("No PCollection for tag " + tag); + private Map<String, Integer> tagsColumnIndex(Collection<TupleTag<?>> tags) { + Map<String, Integer> index = Maps.newHashMapWithExpectedSize(tags.size()); + for (TupleTag<?> tag : tags) { + index.put(tag.getId(), index.size()); } - return (Coder<Object>) pc.getCoder(); + return index; + } + + /** List of encoders matching the order of tagIds. */ + private List<Encoder<WindowedValue<Object>>> createEncoders( + Map<TupleTag<?>, PCollection<?>> outputs, Map<String, Integer> tagIdColIdx, Context ctx) { + ArrayList<Encoder<WindowedValue<Object>>> encoders = new ArrayList<>(outputs.size()); + for (Entry<TupleTag<?>, PCollection<?>> e : outputs.entrySet()) { + Encoder<WindowedValue<Object>> enc = ctx.windowedEncoder((Coder) e.getValue().getCoder()); + int colIdx = checkStateNotNull(tagIdColIdx.get(e.getKey().getId())); + encoders.add(colIdx, enc); + } + return encoders; } private <T> SideInputReader createSideInputReader( @@ -242,13 +228,4 @@ class ParDoTranslatorBatch<InputT, OutputT> } return SparkSideInputReader.create(broadcasts); } - - private static <T> Collection<Entry<T, Integer>> zipwithIndex(Collection<T> col) { - ArrayList<Entry<T, Integer>> zipped = new ArrayList<>(col.size()); - int i = 0; - for (T t : col) { - zipped.add(new SimpleImmutableEntry<>(t, i++)); - } - return zipped; - } }