This is an automated email from the ASF dual-hosted git repository.

echauchot pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 2cd38984a35 [Spark Dataset runner] Reduce binary size of Java serialized task related for ParDo translation (#24543)
2cd38984a35 is described below

commit 2cd38984a354c76ada42cb51f13a398babaf1b76
Author: Moritz Mack <mm...@talend.com>
AuthorDate: Mon Dec 19 14:13:08 2022 +0100

    [Spark Dataset runner] Reduce binary size of Java serialized task related for ParDo translation (#24543)
    
    * [Spark Dataset runner] Reduce binary size of Java serialized broadcasted task related for ParDo translation (related to #23845)
---
 .../batch/DoFnMapPartitionsFactory.java            | 204 ----------------
 .../batch/DoFnPartitionIteratorFactory.java        | 272 +++++++++++++++++++++
 .../translation/batch/ParDoTranslatorBatch.java    | 125 ++++------
 3 files changed, 323 insertions(+), 278 deletions(-)
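
For context on the change below: the diff replaces a factory that returned capturing lambdas with a self-contained Serializable function class. In general, a Java lambda or inner class that touches instance state captures its enclosing object, and serializing it (as Spark does for broadcasted tasks) drags that object along. The following minimal sketch illustrates that general pitfall and fix; HeavyFactory, LeanMapper, and SerFn are hypothetical names for illustration only, not code from this commit.

    // Hypothetical illustration only; not part of this commit.
    import java.io.Serializable;

    public class ClosureSizeSketch {
      /** Serializable function type, as Spark requires for task closures. */
      interface SerFn<T, R> extends Serializable {
        R apply(T t);
      }

      /** Before: the lambda reads an instance field, so it captures the whole factory. */
      static class HeavyFactory implements Serializable {
        private final byte[] heavyState = new byte[1 << 20]; // dead weight in every serialized task

        SerFn<Integer, Integer> createMapper() {
          return x -> x + heavyState.length; // captures HeavyFactory.this, serializing heavyState too
        }
      }

      /** After: a standalone Serializable function carrying only the fields it needs. */
      static class LeanMapper implements SerFn<Integer, Integer> {
        private final int offset;

        LeanMapper(int offset) {
          this.offset = offset;
        }

        @Override
        public Integer apply(Integer x) {
          return x + offset; // same behavior, much smaller serialized footprint
        }
      }
    }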

diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnMapPartitionsFactory.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnMapPartitionsFactory.java
deleted file mode 100644
index a53e5ca3a79..00000000000
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnMapPartitionsFactory.java
+++ /dev/null
@@ -1,204 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
-
-import static java.util.stream.Collectors.toCollection;
-import static java.util.stream.Collectors.toMap;
-import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.scalaIterator;
-import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists.newArrayListWithCapacity;
-
-import java.io.Serializable;
-import java.util.ArrayDeque;
-import java.util.Deque;
-import java.util.List;
-import java.util.Map;
-import java.util.function.Supplier;
-import org.apache.beam.runners.core.DoFnRunner;
-import org.apache.beam.runners.core.DoFnRunners;
-import org.apache.beam.runners.core.DoFnRunners.OutputManager;
-import org.apache.beam.runners.core.SideInputReader;
-import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsAccumulator;
-import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.CachedSideInputReader;
-import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.NoOpStepContext;
-import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1;
-import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun2;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
-import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
-import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.values.PCollection;
-import org.apache.beam.sdk.values.PCollectionView;
-import org.apache.beam.sdk.values.TupleTag;
-import org.apache.beam.sdk.values.WindowingStrategy;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.AbstractIterator;
-import org.apache.spark.api.java.function.MapPartitionsFunction;
-import org.checkerframework.checker.nullness.qual.NonNull;
-import scala.collection.Iterator;
-
-/**
- * Encapsulates a {@link DoFn} inside a Spark {@link
- * org.apache.spark.api.java.function.MapPartitionsFunction}.
- */
-class DoFnMapPartitionsFactory<InT, OutT> implements Serializable {
-  private final String stepName;
-
-  private final DoFn<InT, OutT> doFn;
-  private final DoFnSchemaInformation doFnSchema;
-  private final Supplier<PipelineOptions> options;
-
-  private final Coder<InT> coder;
-  private final WindowingStrategy<?, ?> windowingStrategy;
-  private final TupleTag<OutT> mainOutput;
-  private final List<TupleTag<?>> additionalOutputs;
-  private final Map<TupleTag<?>, Coder<?>> outputCoders;
-
-  private final Map<String, PCollectionView<?>> sideInputs;
-  private final SideInputReader sideInputReader;
-
-  DoFnMapPartitionsFactory(
-      String stepName,
-      DoFn<InT, OutT> doFn,
-      DoFnSchemaInformation doFnSchema,
-      Supplier<PipelineOptions> options,
-      PCollection<InT> input,
-      TupleTag<OutT> mainOutput,
-      Map<TupleTag<?>, PCollection<?>> outputs,
-      Map<String, PCollectionView<?>> sideInputs,
-      SideInputReader sideInputReader) {
-    this.stepName = stepName;
-    this.doFn = doFn;
-    this.doFnSchema = doFnSchema;
-    this.options = options;
-    this.coder = input.getCoder();
-    this.windowingStrategy = input.getWindowingStrategy();
-    this.mainOutput = mainOutput;
-    this.additionalOutputs = additionalOutputs(outputs, mainOutput);
-    this.outputCoders = outputCoders(outputs);
-    this.sideInputs = sideInputs;
-    this.sideInputReader = sideInputReader;
-  }
-
-  /** Create the {@link MapPartitionsFunction} using the provided output function. */
-  <OutputT extends @NonNull Object> Fun1<Iterator<WindowedValue<InT>>, Iterator<OutputT>> create(
-      Fun2<TupleTag<?>, WindowedValue<?>, OutputT> outputFn) {
-    return it ->
-        it.hasNext()
-            ? scalaIterator(new DoFnPartitionIt<>(outputFn, it))
-            : (Iterator<OutputT>) Iterator.empty();
-  }
-
-  // FIXME Add support for TimerInternals.TimerData
-  /**
-   * Partition iterator that lazily processes each element from the (input) iterator on demand
-   * producing zero, one or more output elements as output (via an internal buffer).
-   *
-   * <p>When initializing the iterator for a partition {@code setup} followed by {@code startBundle}
-   * is called.
-   */
-  private class DoFnPartitionIt<FnInT extends InT, OutputT> extends AbstractIterator<OutputT> {
-    private final Deque<OutputT> buffer;
-    private final DoFnRunner<InT, OutT> doFnRunner;
-    private final Iterator<WindowedValue<FnInT>> partitionIt;
-
-    private boolean isBundleFinished;
-
-    DoFnPartitionIt(
-        Fun2<TupleTag<?>, WindowedValue<?>, OutputT> outputFn,
-        Iterator<WindowedValue<FnInT>> partitionIt) {
-      this.buffer = new ArrayDeque<>();
-      this.doFnRunner = metricsRunner(simpleRunner(outputFn, buffer));
-      this.partitionIt = partitionIt;
-      // Before starting to iterate over the partition, invoke setup and then startBundle
-      DoFnInvokers.tryInvokeSetupFor(doFn, options.get());
-      try {
-        doFnRunner.startBundle();
-      } catch (RuntimeException re) {
-        DoFnInvokers.invokerFor(doFn).invokeTeardown();
-        throw re;
-      }
-    }
-
-    @Override
-    protected OutputT computeNext() {
-      try {
-        while (true) {
-          if (!buffer.isEmpty()) {
-            return buffer.remove();
-          }
-          if (partitionIt.hasNext()) {
-            // grab the next element and process it.
-            doFnRunner.processElement((WindowedValue<InT>) partitionIt.next());
-          } else {
-            if (!isBundleFinished) {
-              isBundleFinished = true;
-              doFnRunner.finishBundle();
-              continue; // finishBundle can produce more output
-            }
-            DoFnInvokers.invokerFor(doFn).invokeTeardown();
-            return endOfData();
-          }
-        }
-      } catch (RuntimeException re) {
-        DoFnInvokers.invokerFor(doFn).invokeTeardown();
-        throw re;
-      }
-    }
-  }
-
-  private <OutputT> DoFnRunner<InT, OutT> simpleRunner(
-      Fun2<TupleTag<?>, WindowedValue<?>, OutputT> outputFn, Deque<OutputT> buffer) {
-    OutputManager outputManager =
-        new OutputManager() {
-          @Override
-          public <T> void output(TupleTag<T> tag, WindowedValue<T> output) {
-            buffer.add(outputFn.apply(tag, output));
-          }
-        };
-    return DoFnRunners.simpleRunner(
-        options.get(),
-        doFn,
-        CachedSideInputReader.of(sideInputReader, sideInputs.values()),
-        outputManager,
-        mainOutput,
-        additionalOutputs,
-        new NoOpStepContext(),
-        coder,
-        outputCoders,
-        windowingStrategy,
-        doFnSchema,
-        sideInputs);
-  }
-
-  private DoFnRunner<InT, OutT> metricsRunner(DoFnRunner<InT, OutT> runner) {
-    return new DoFnRunnerWithMetrics<>(stepName, runner, MetricsAccumulator.getInstance());
-  }
-
-  private static List<TupleTag<?>> additionalOutputs(
-      Map<TupleTag<?>, PCollection<?>> outputs, TupleTag<?> mainOutput) {
-    return outputs.keySet().stream()
-        .filter(t -> !t.equals(mainOutput))
-        .collect(toCollection(() -> newArrayListWithCapacity(outputs.size() - 1)));
-  }
-
-  private static Map<TupleTag<?>, Coder<?>> outputCoders(Map<TupleTag<?>, PCollection<?>> outputs) {
-    return outputs.entrySet().stream()
-        .collect(toMap(Map.Entry::getKey, e -> e.getValue().getCoder()));
-  }
-}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnPartitionIteratorFactory.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnPartitionIteratorFactory.java
new file mode 100644
index 00000000000..c760efd229c
--- /dev/null
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnPartitionIteratorFactory.java
@@ -0,0 +1,272 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
+
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.scalaIterator;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple;
+import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
+
+import java.io.Serializable;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Deque;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Supplier;
+import org.apache.beam.runners.core.DoFnRunner;
+import org.apache.beam.runners.core.DoFnRunners;
+import org.apache.beam.runners.core.SideInputReader;
+import org.apache.beam.runners.core.construction.ParDoTranslation;
+import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsAccumulator;
+import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.CachedSideInputReader;
+import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.NoOpStepContext;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.runners.AppliedPTransform;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
+import org.apache.beam.sdk.transforms.ParDo.MultiOutput;
+import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.WindowingStrategy;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.AbstractIterator;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import scala.Function1;
+import scala.Tuple2;
+import scala.collection.Iterator;
+
+/**
+ * Abstract factory to create a {@link DoFnPartitionIt DoFn partition iterator} using a customizable
+ * {@link DoFnRunners.OutputManager}.
+ */
+abstract class DoFnPartitionIteratorFactory<InT, FnOutT, OutT extends @NonNull Object>
+    implements Function1<Iterator<WindowedValue<InT>>, Iterator<OutT>>, Serializable {
+  private final String stepName;
+  private final DoFn<InT, FnOutT> doFn;
+  private final DoFnSchemaInformation doFnSchema;
+  private final Supplier<PipelineOptions> options;
+  private final Coder<InT> coder;
+  private final WindowingStrategy<?, ?> windowingStrategy;
+  private final TupleTag<FnOutT> mainOutput;
+  private final List<TupleTag<?>> additionalOutputs;
+  private final Map<TupleTag<?>, Coder<?>> outputCoders;
+  private final Map<String, PCollectionView<?>> sideInputs;
+  private final SideInputReader sideInputReader;
+
+  private DoFnPartitionIteratorFactory(
+      AppliedPTransform<PCollection<? extends InT>, ?, MultiOutput<InT, FnOutT>> appliedPT,
+      Supplier<PipelineOptions> options,
+      PCollection<InT> input,
+      SideInputReader sideInputReader) {
+    this.stepName = appliedPT.getFullName();
+    this.doFn = appliedPT.getTransform().getFn();
+    this.doFnSchema = ParDoTranslation.getSchemaInformation(appliedPT);
+    this.options = options;
+    this.coder = input.getCoder();
+    this.windowingStrategy = input.getWindowingStrategy();
+    this.mainOutput = appliedPT.getTransform().getMainOutputTag();
+    this.additionalOutputs = additionalOutputs(appliedPT.getTransform());
+    this.outputCoders = outputCoders(appliedPT.getOutputs());
+    this.sideInputs = appliedPT.getTransform().getSideInputs();
+    this.sideInputReader = sideInputReader;
+  }
+
+  /**
+   * {@link DoFnPartitionIteratorFactory} emitting a single output of type {@link WindowedValue} of
+   * {@link OutT}.
+   */
+  static <InT, OutT> DoFnPartitionIteratorFactory<InT, ?, WindowedValue<OutT>> singleOutput(
+      AppliedPTransform<PCollection<? extends InT>, ?, MultiOutput<InT, OutT>> appliedPT,
+      Supplier<PipelineOptions> options,
+      PCollection<InT> input,
+      SideInputReader sideInputReader) {
+    return new SingleOut<>(appliedPT, options, input, sideInputReader);
+  }
+
+  /**
+   * {@link DoFnPartitionIteratorFactory} emitting multiple outputs encoded as tuple of column index
+   * and {@link WindowedValue} of {@link OutT}, where column index corresponds to the index of a
+   * {@link TupleTag#getId()} in {@code tagColIdx}.
+   */
+  static <InT, FnOutT, OutT>
+      DoFnPartitionIteratorFactory<InT, ?, Tuple2<Integer, WindowedValue<OutT>>> multiOutput(
+          AppliedPTransform<PCollection<? extends InT>, ?, MultiOutput<InT, FnOutT>> appliedPT,
+          Supplier<PipelineOptions> options,
+          PCollection<InT> input,
+          SideInputReader sideInputReader,
+          Map<String, Integer> tagColIdx) {
+    return new MultiOut<>(appliedPT, options, input, sideInputReader, tagColIdx);
+  }
+
+  @Override
+  public Iterator<OutT> apply(Iterator<WindowedValue<InT>> it) {
+    return it.hasNext()
+        ? scalaIterator(new DoFnPartitionIt(it))
+        : (Iterator<OutT>) Iterator.empty();
+  }
+
+  /** Output manager emitting outputs of type {@link OutT} to the buffer. */
+  abstract DoFnRunners.OutputManager outputManager(Deque<OutT> buffer);
+
+  /**
+   * {@link DoFnPartitionIteratorFactory} emitting a single output of type {@link WindowedValue} of
+   * {@link OutT}.
+   */
+  private static class SingleOut<InT, OutT>
+      extends DoFnPartitionIteratorFactory<InT, OutT, WindowedValue<OutT>> {
+    private SingleOut(
+        AppliedPTransform<PCollection<? extends InT>, ?, MultiOutput<InT, OutT>> appliedPT,
+        Supplier<PipelineOptions> options,
+        PCollection<InT> input,
+        SideInputReader sideInputReader) {
+      super(appliedPT, options, input, sideInputReader);
+    }
+
+    @Override
+    DoFnRunners.OutputManager outputManager(Deque<WindowedValue<OutT>> buffer) {
+      return new DoFnRunners.OutputManager() {
+        @Override
+        public <T> void output(TupleTag<T> tag, WindowedValue<T> output) {
+          buffer.add((WindowedValue<OutT>) output);
+        }
+      };
+    }
+  }
+
+  /**
+   * {@link DoFnPartitionIteratorFactory} emitting multiple outputs encoded as tuple of column index
+   * and {@link WindowedValue} of {@link OutT}, where column index corresponds to the index of a
+   * {@link TupleTag#getId()} in {@link #tagColIdx}.
+   */
+  private static class MultiOut<InT, FnOutT, OutT>
+      extends DoFnPartitionIteratorFactory<InT, FnOutT, Tuple2<Integer, WindowedValue<OutT>>> {
+    private final Map<String, Integer> tagColIdx;
+
+    public MultiOut(
+        AppliedPTransform<PCollection<? extends InT>, ?, MultiOutput<InT, FnOutT>> appliedPT,
+        Supplier<PipelineOptions> options,
+        PCollection<InT> input,
+        SideInputReader sideInputReader,
+        Map<String, Integer> tagColIdx) {
+      super(appliedPT, options, input, sideInputReader);
+      this.tagColIdx = tagColIdx;
+    }
+
+    @Override
+    DoFnRunners.OutputManager outputManager(Deque<Tuple2<Integer, WindowedValue<OutT>>> buffer) {
+      return new DoFnRunners.OutputManager() {
+        @Override
+        public <T> void output(TupleTag<T> tag, WindowedValue<T> output) {
+          Integer columnIdx = checkStateNotNull(tagColIdx.get(tag.getId()), "Unknown tag %s", tag);
+          buffer.add(tuple(columnIdx, (WindowedValue<OutT>) output));
+        }
+      };
+    }
+  }
+
+  // FIXME Add support for TimerInternals.TimerData
+  /**
+   * Partition iterator that lazily processes each element from the (input) iterator on demand
+   * producing zero, one or more output elements as output (via an internal buffer).
+   *
+   * <p>When initializing the iterator for a partition {@code setup} followed by {@code startBundle}
+   * is called.
+   */
+  private class DoFnPartitionIt extends AbstractIterator<OutT> {
+    private final Deque<OutT> buffer = new ArrayDeque<>();
+    private final DoFnRunner<InT, ?> doFnRunner = metricsRunner(simpleRunner(buffer));
+    private final Iterator<WindowedValue<InT>> partitionIt;
+    private boolean isBundleFinished;
+
+    private DoFnPartitionIt(Iterator<WindowedValue<InT>> partitionIt) {
+      this.partitionIt = partitionIt;
+      // Before starting to iterate over the partition, invoke setup and then startBundle
+      DoFnInvokers.tryInvokeSetupFor(doFn, options.get());
+      try {
+        doFnRunner.startBundle();
+      } catch (RuntimeException re) {
+        DoFnInvokers.invokerFor(doFn).invokeTeardown();
+        throw re;
+      }
+    }
+
+    @Override
+    protected OutT computeNext() {
+      try {
+        while (true) {
+          if (!buffer.isEmpty()) {
+            return buffer.remove();
+          }
+          if (partitionIt.hasNext()) {
+            // grab the next element and process it.
+            doFnRunner.processElement(partitionIt.next());
+          } else {
+            if (!isBundleFinished) {
+              isBundleFinished = true;
+              doFnRunner.finishBundle();
+              continue; // finishBundle can produce more output
+            }
+            DoFnInvokers.invokerFor(doFn).invokeTeardown();
+            return endOfData();
+          }
+        }
+      } catch (RuntimeException re) {
+        DoFnInvokers.invokerFor(doFn).invokeTeardown();
+        throw re;
+      }
+    }
+  }
+
+  private DoFnRunner<InT, FnOutT> simpleRunner(Deque<OutT> buffer) {
+    return DoFnRunners.simpleRunner(
+        options.get(),
+        (DoFn<InT, FnOutT>) doFn,
+        CachedSideInputReader.of(sideInputReader, sideInputs.values()),
+        outputManager(buffer),
+        mainOutput,
+        additionalOutputs,
+        new NoOpStepContext(),
+        coder,
+        outputCoders,
+        windowingStrategy,
+        doFnSchema,
+        sideInputs);
+  }
+
+  private DoFnRunner<InT, FnOutT> metricsRunner(DoFnRunner<InT, FnOutT> runner) {
+    return new DoFnRunnerWithMetrics<>(stepName, runner, MetricsAccumulator.getInstance());
+  }
+
+  private static Map<TupleTag<?>, Coder<?>> outputCoders(Map<TupleTag<?>, PCollection<?>> outputs) {
+    Map<TupleTag<?>, Coder<?>> coders = Maps.newHashMapWithExpectedSize(outputs.size());
+    for (Map.Entry<TupleTag<?>, PCollection<?>> e : outputs.entrySet()) {
+      coders.put(e.getKey(), e.getValue().getCoder());
+    }
+    return coders;
+  }
+
+  private static List<TupleTag<?>> additionalOutputs(MultiOutput<?, ?> transform) {
+    List<TupleTag<?>> tags = transform.getAdditionalOutputTags().getAll();
+    return tags.isEmpty() ? Collections.emptyList() : new ArrayList<>(tags);
+  }
+}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
index d1e069c82d0..3083ff5101b 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
@@ -17,33 +17,29 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
 
-import static java.util.stream.Collectors.toList;
 import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.oneOfEncoder;
-import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1;
-import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.emptyList;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.listOf;
 import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
 import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
 import static org.apache.spark.sql.functions.col;
 import static org.apache.spark.storage.StorageLevel.MEMORY_ONLY;
 
 import java.io.IOException;
-import java.util.AbstractMap.SimpleImmutableEntry;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
-import javax.annotation.Nullable;
 import org.apache.beam.runners.core.DoFnRunners;
 import org.apache.beam.runners.core.SideInputReader;
-import org.apache.beam.runners.core.construction.ParDoTranslation;
 import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
 import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
 import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.SideInputValues;
 import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.SparkSideInputReader;
+import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
@@ -52,18 +48,15 @@ import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionTuple;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TupleTag;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Streams;
 import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.rdd.RDD;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Encoder;
 import org.apache.spark.sql.TypedColumn;
 import org.apache.spark.storage.StorageLevel;
-import scala.Function1;
 import scala.Tuple2;
-import scala.collection.Iterator;
+import scala.collection.TraversableOnce;
 import scala.reflect.ClassTag;
 
 /**
@@ -115,40 +108,27 @@ class ParDoTranslatorBatch<InputT, OutputT>
   @Override
   public void translate(ParDo.MultiOutput<InputT, OutputT> transform, Context cxt)
       throws IOException {
-    String stepName = cxt.getCurrentTransform().getFullName();
-
-    TupleTag<OutputT> mainOutputTag = transform.getMainOutputTag();
-
-    DoFnSchemaInformation doFnSchema =
-        ParDoTranslation.getSchemaInformation(cxt.getCurrentTransform());
 
     PCollection<InputT> input = (PCollection<InputT>) cxt.getInput();
-    Map<String, PCollectionView<?>> sideInputs = transform.getSideInputs();
     Map<TupleTag<?>, PCollection<?>> outputs = cxt.getOutputs();
 
-    DoFnMapPartitionsFactory<InputT, OutputT> factory =
-        new DoFnMapPartitionsFactory<>(
-            stepName,
-            transform.getFn(),
-            doFnSchema,
-            cxt.getOptionsSupplier(),
-            input,
-            mainOutputTag,
-            outputs,
-            sideInputs,
-            createSideInputReader(sideInputs.values(), cxt));
-
     Dataset<WindowedValue<InputT>> inputDs = cxt.getDataset(input);
+    SideInputReader sideInputReader =
+        createSideInputReader(transform.getSideInputs().values(), cxt);
+
     if (outputs.size() > 1) {
       // In case of multiple outputs / tags, map each tag to a column by index.
       // At the end split the result into multiple datasets selecting one column each.
-      Map<TupleTag<?>, Integer> tags = ImmutableMap.copyOf(zipwithIndex(outputs.keySet()));
-
-      List<Encoder<WindowedValue<Object>>> encoders =
-          createEncoders(outputs, (Iterable<TupleTag<?>>) tags.keySet(), cxt);
+      Map<String, Integer> tagColIdx = tagsColumnIndex((Collection<TupleTag<?>>) outputs.keySet());
+      List<Encoder<WindowedValue<Object>>> encoders = createEncoders(outputs, tagColIdx, cxt);
 
-      Function1<Iterator<WindowedValue<InputT>>, Iterator<Tuple2<Integer, WindowedValue<Object>>>>
-          doFnMapper = factory.create((tag, v) -> tuple(tags.get(tag), (WindowedValue<Object>) v));
+      DoFnPartitionIteratorFactory<InputT, ?, Tuple2<Integer, WindowedValue<Object>>> doFnMapper =
+          DoFnPartitionIteratorFactory.multiOutput(
+              cxt.getCurrentTransform(),
+              cxt.getOptionsSupplier(),
+              input,
+              sideInputReader,
+              tagColIdx);
 
       // FIXME What's the strategy to unpersist Datasets / RDDs?
 
@@ -169,18 +149,13 @@ class ParDoTranslatorBatch<InputT, OutputT>
         allTagsRDD.persist();
 
         // divide into separate output datasets per tag
-        for (Entry<TupleTag<?>, Integer> e : tags.entrySet()) {
-          TupleTag<Object> key = (TupleTag<Object>) e.getKey();
-          Integer id = e.getValue();
-
+        for (TupleTag<?> tag : outputs.keySet()) {
+          int colIdx = checkStateNotNull(tagColIdx.get(tag.getId()), "Unknown tag");
           RDD<WindowedValue<Object>> rddByTag =
-              allTagsRDD
-                  .filter(fun1(t -> t._1.equals(id)))
-                  .map(fun1(Tuple2::_2), WINDOWED_VALUE_CTAG);
-
+              allTagsRDD.flatMap(selectByColumnIdx(colIdx), WINDOWED_VALUE_CTAG);
           cxt.putDataset(
-              cxt.getOutput(key),
-              cxt.getSparkSession().createDataset(rddByTag, encoders.get(id)),
+              cxt.getOutput((TupleTag) tag),
+              cxt.getSparkSession().createDataset(rddByTag, encoders.get(colIdx)),
               false);
         }
       } else {
@@ -190,40 +165,51 @@ class ParDoTranslatorBatch<InputT, OutputT>
         allTagsDS.persist(storageLevel);
 
         // divide into separate output datasets per tag
-        for (Entry<TupleTag<?>, Integer> e : tags.entrySet()) {
-          TupleTag<Object> key = (TupleTag<Object>) e.getKey();
-          Integer id = e.getValue();
-
+        for (TupleTag<?> tag : outputs.keySet()) {
+          int colIdx = checkStateNotNull(tagColIdx.get(tag.getId()), "Unknown tag");
           // Resolve specific column matching the tuple tag (by id)
           TypedColumn<Tuple2<Integer, WindowedValue<Object>>, WindowedValue<Object>> col =
-              (TypedColumn) col(id.toString()).as(encoders.get(id));
+              (TypedColumn) col(Integer.toString(colIdx)).as(encoders.get(colIdx));
 
-          cxt.putDataset(cxt.getOutput(key), allTagsDS.filter(col.isNotNull()).select(col), false);
+          cxt.putDataset(
+              cxt.getOutput((TupleTag) tag), allTagsDS.filter(col.isNotNull()).select(col), false);
         }
       }
     } else {
-      PCollection<OutputT> output = cxt.getOutput(mainOutputTag);
+      PCollection<OutputT> output = cxt.getOutput(transform.getMainOutputTag());
+      DoFnPartitionIteratorFactory<InputT, ?, WindowedValue<OutputT>> doFnMapper =
+          DoFnPartitionIteratorFactory.singleOutput(
+              cxt.getCurrentTransform(), cxt.getOptionsSupplier(), input, sideInputReader);
+
       Dataset<WindowedValue<OutputT>> mainDS =
-          inputDs.mapPartitions(
-              factory.create((tag, value) -> (WindowedValue<OutputT>) value),
-              cxt.windowedEncoder(output.getCoder()));
+          inputDs.mapPartitions(doFnMapper, cxt.windowedEncoder(output.getCoder()));
 
       cxt.putDataset(output, mainDS);
     }
   }
 
-  private List<Encoder<WindowedValue<Object>>> createEncoders(
-      Map<TupleTag<?>, PCollection<?>> outputs, Iterable<TupleTag<?>> columns, Context ctx) {
-    return Streams.stream(columns)
-        .map(tag -> ctx.windowedEncoder(getCoder(outputs.get(tag), tag)))
-        .collect(toList());
+  static <T> Fun1<Tuple2<Integer, T>, TraversableOnce<T>> selectByColumnIdx(int idx) {
+    return t -> idx == t._1 ? listOf(t._2) : emptyList();
   }
 
-  private Coder<Object> getCoder(@Nullable PCollection<?> pc, TupleTag<?> tag) {
-    if (pc == null) {
-      throw new NullPointerException("No PCollection for tag " + tag);
+  private Map<String, Integer> tagsColumnIndex(Collection<TupleTag<?>> tags) {
+    Map<String, Integer> index = Maps.newHashMapWithExpectedSize(tags.size());
+    for (TupleTag<?> tag : tags) {
+      index.put(tag.getId(), index.size());
     }
-    return (Coder<Object>) pc.getCoder();
+    return index;
+  }
+
+  /** List of encoders matching the order of tagIds. */
+  private List<Encoder<WindowedValue<Object>>> createEncoders(
+      Map<TupleTag<?>, PCollection<?>> outputs, Map<String, Integer> tagIdColIdx, Context ctx) {
+    ArrayList<Encoder<WindowedValue<Object>>> encoders = new ArrayList<>(outputs.size());
+    for (Entry<TupleTag<?>, PCollection<?>> e : outputs.entrySet()) {
+      Encoder<WindowedValue<Object>> enc = ctx.windowedEncoder((Coder) e.getValue().getCoder());
+      int colIdx = checkStateNotNull(tagIdColIdx.get(e.getKey().getId()));
+      encoders.add(colIdx, enc);
+    }
+    return encoders;
   }
 
   private <T> SideInputReader createSideInputReader(
@@ -242,13 +228,4 @@ class ParDoTranslatorBatch<InputT, OutputT>
     }
     return SparkSideInputReader.create(broadcasts);
   }
-
-  private static <T> Collection<Entry<T, Integer>> zipwithIndex(Collection<T> col) {
-    ArrayList<Entry<T, Integer>> zipped = new ArrayList<>(col.size());
-    int i = 0;
-    for (T t : col) {
-      zipped.add(new SimpleImmutableEntry<>(t, i++));
-    }
-    return zipped;
-  }
 }
