This is an automated email from the ASF dual-hosted git repository.

iemejia pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new c45a50f  [BEAM-6935] Spark portable runner: implement side inputs
     new 16b58ad  Merge pull request #8220: [BEAM-6935] Spark portable runner: implement side inputs
c45a50f is described below

commit c45a50fb092171ce4fa5f8b0758a584911d4f50d
Author: Kyle Weaver <kcwea...@google.com>
AuthorDate: Thu Mar 28 19:16:51 2019 -0700

    [BEAM-6935] Spark portable runner: implement side inputs
---
 .../functions/FlinkExecutableStageFunction.java    |  4 +-
 .../translation/BatchSideInputHandlerFactory.java} | 35 ++++++-------
 .../BatchSideInputHandlerFactoryTest.java}         | 40 +++++++--------
 .../runners/spark/translation/BoundedDataset.java  |  9 ++++
 .../SparkBatchPortablePipelineTranslator.java      | 47 +++++++++++++++--
 .../translation/SparkExecutableStageFunction.java  | 59 +++++++++++++++++++---
 .../runners/spark/SparkPortableExecutionTest.java  | 36 +++++++++----
 .../SparkExecutableStageFunctionTest.java          | 15 +++---
 8 files changed, 181 insertions(+), 64 deletions(-)

diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java
index e7dafa8..c02aa65 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java
@@ -54,6 +54,7 @@ import org.apache.beam.runners.fnexecution.control.StageBundleFactory;
 import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
 import org.apache.beam.runners.fnexecution.state.StateRequestHandler;
 import org.apache.beam.runners.fnexecution.state.StateRequestHandlers;
+import org.apache.beam.runners.fnexecution.translation.BatchSideInputHandlerFactory;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.io.FileSystems;
@@ -167,7 +168,8 @@ public class FlinkExecutableStageFunction<InputT> extends AbstractRichFunction
       RuntimeContext runtimeContext) {
     final StateRequestHandler sideInputHandler;
     StateRequestHandlers.SideInputHandlerFactory sideInputHandlerFactory =
-        FlinkBatchSideInputHandlerFactory.forStage(executableStage, runtimeContext);
+        BatchSideInputHandlerFactory.forStage(
+            executableStage, runtimeContext::getBroadcastVariable);
     try {
       sideInputHandler =
           StateRequestHandlers.forSideInputHandlerFactory(
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkBatchSideInputHandlerFactory.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/translation/BatchSideInputHandlerFactory.java
similarity index 87%
rename from runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkBatchSideInputHandlerFactory.java
rename to runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/translation/BatchSideInputHandlerFactory.java
index 798c32b..5460898 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkBatchSideInputHandlerFactory.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/translation/BatchSideInputHandlerFactory.java
@@ -15,7 +15,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.beam.runners.flink.translation.functions;
+package org.apache.beam.runners.fnexecution.translation;
 
 import static org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkArgument;
 
@@ -43,24 +43,25 @@ import org.apache.beam.sdk.values.KV;
 import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableMap;
 import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableMultimap;
 import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Multimap;
-import org.apache.flink.api.common.functions.RuntimeContext;
 
-/**
- * {@link StateRequestHandler} that uses a Flink {@link RuntimeContext} to access Flink broadcast
- * variable that represent side inputs.
- */
-class FlinkBatchSideInputHandlerFactory implements SideInputHandlerFactory {
+/** {@link StateRequestHandler} that uses a {@link SideInputGetter} to access side inputs. */
+public class BatchSideInputHandlerFactory implements SideInputHandlerFactory {
 
   // Map from side input id to global PCollection id.
   private final Map<SideInputId, PCollectionNode> sideInputToCollection;
-  private final RuntimeContext runtimeContext;
+  private final SideInputGetter sideInputGetter;
+
+  /** Returns the value for the side input with the given PCollection id from the runner. */
+  public interface SideInputGetter {
+    <T> List<T> getSideInput(String pCollectionId);
+  }
 
   /**
    * Creates a new state handler for the given stage. Note that this requires a traversal of the
    * stage itself, so this should only be called once per stage rather than once per bundle.
    */
-  static FlinkBatchSideInputHandlerFactory forStage(
-      ExecutableStage stage, RuntimeContext runtimeContext) {
+  public static BatchSideInputHandlerFactory forStage(
+      ExecutableStage stage, SideInputGetter sideInputGetter) {
     ImmutableMap.Builder<SideInputId, PCollectionNode> sideInputBuilder = ImmutableMap.builder();
     for (SideInputReference sideInput : stage.getSideInputs()) {
       sideInputBuilder.put(
@@ -70,13 +71,13 @@ class FlinkBatchSideInputHandlerFactory implements SideInputHandlerFactory {
               .build(),
           sideInput.collection());
     }
-    return new FlinkBatchSideInputHandlerFactory(sideInputBuilder.build(), runtimeContext);
+    return new BatchSideInputHandlerFactory(sideInputBuilder.build(), sideInputGetter);
   }
 
-  private FlinkBatchSideInputHandlerFactory(
-      Map<SideInputId, PCollectionNode> sideInputToCollection, RuntimeContext runtimeContext) {
+  private BatchSideInputHandlerFactory(
+      Map<SideInputId, PCollectionNode> sideInputToCollection, SideInputGetter sideInputGetter) {
     this.sideInputToCollection = sideInputToCollection;
-    this.runtimeContext = runtimeContext;
+    this.sideInputGetter = sideInputGetter;
   }
 
   @Override
@@ -96,7 +97,7 @@ class FlinkBatchSideInputHandlerFactory implements SideInputHandlerFactory {
       @SuppressWarnings("unchecked") // T == V
       Coder<V> outputCoder = (Coder<V>) elementCoder;
       return forIterableSideInput(
-          runtimeContext.getBroadcastVariable(collectionNode.getId()), outputCoder, windowCoder);
+          sideInputGetter.getSideInput(collectionNode.getId()), outputCoder, windowCoder);
     } else if (PTransformTranslation.MULTIMAP_SIDE_INPUT.equals(accessPattern.getUrn())
         || Materializations.MULTIMAP_MATERIALIZATION_URN.equals(accessPattern.getUrn())) {
       // TODO: Remove non standard URN.
@@ -104,7 +105,7 @@ class FlinkBatchSideInputHandlerFactory implements SideInputHandlerFactory {
       @SuppressWarnings("unchecked") // T == KV<?, V>
       KvCoder<?, V> kvCoder = (KvCoder<?, V>) elementCoder;
       return forMultimapSideInput(
-          runtimeContext.getBroadcastVariable(collectionNode.getId()),
+          sideInputGetter.getSideInput(collectionNode.getId()),
           kvCoder.getKeyCoder(),
           kvCoder.getValueCoder(),
           windowCoder);
@@ -202,7 +203,7 @@ class FlinkBatchSideInputHandlerFactory implements SideInputHandlerFactory {
   @AutoValue
   abstract static class SideInputKey {
     static SideInputKey of(Object key, Object window) {
-      return new AutoValue_FlinkBatchSideInputHandlerFactory_SideInputKey(key, window);
+      return new AutoValue_BatchSideInputHandlerFactory_SideInputKey(key, window);
     }
 
     @Nullable
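
For context on the relocation above: the factory is now runner-agnostic, and a batch runner plugs in its own side-input lookup through the new SideInputGetter hook. A minimal wiring sketch, mirroring the Flink and Spark call sites in this commit (variable names are illustrative; forSideInputHandlerFactory can throw IOException and is wrapped in a try/catch at the real call sites):

    // Sketch: a runner supplies whatever returns the materialized side-input
    // elements for a PCollection id, e.g. a Flink broadcast-variable lookup.
    StateRequestHandlers.SideInputHandlerFactory sideInputHandlerFactory =
        BatchSideInputHandlerFactory.forStage(
            executableStage, runtimeContext::getBroadcastVariable);
    StateRequestHandler sideInputHandler =
        StateRequestHandlers.forSideInputHandlerFactory(
            ProcessBundleDescriptors.getSideInputs(executableStage), sideInputHandlerFactory);
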
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkBatchSideInputHandlerFactoryTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/translation/BatchSideInputHandlerFactoryTest.java
similarity index 89%
rename from runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkBatchSideInputHandlerFactoryTest.java
rename to runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/translation/BatchSideInputHandlerFactoryTest.java
index 897289f..f664aa9 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkBatchSideInputHandlerFactoryTest.java
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/translation/BatchSideInputHandlerFactoryTest.java
@@ -15,7 +15,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.beam.runners.flink.translation.functions;
+package org.apache.beam.runners.fnexecution.translation;
 
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.contains;
@@ -50,7 +50,6 @@ import org.apache.beam.sdk.transforms.windowing.IntervalWindow.IntervalWindowCod
 import org.apache.beam.sdk.transforms.windowing.PaneInfo;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.KV;
-import org.apache.flink.api.common.functions.RuntimeContext;
 import org.joda.time.DateTime;
 import org.joda.time.DateTimeZone;
 import org.joda.time.Instant;
@@ -63,9 +62,9 @@ import org.junit.runners.JUnit4;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
-/** Tests for {@link FlinkBatchSideInputHandlerFactory}. */
+/** Tests for {@link BatchSideInputHandlerFactory}. */
 @RunWith(JUnit4.class)
-public class FlinkBatchSideInputHandlerFactoryTest {
+public class BatchSideInputHandlerFactoryTest {
 
   private static final String TRANSFORM_ID = "transform-id";
   private static final String SIDE_INPUT_NAME = "side-input";
@@ -87,7 +86,7 @@ public class FlinkBatchSideInputHandlerFactoryTest {
 
   @Rule public ExpectedException thrown = ExpectedException.none();
 
-  @Mock private RuntimeContext context;
+  @Mock private BatchSideInputHandlerFactory.SideInputGetter context;
 
   @Before
   public void setUpMocks() {
@@ -97,8 +96,7 @@ public class FlinkBatchSideInputHandlerFactoryTest {
   @Test
   public void invalidSideInputThrowsException() {
     ExecutableStage stage = createExecutableStage(Collections.emptyList());
-    FlinkBatchSideInputHandlerFactory factory =
-        FlinkBatchSideInputHandlerFactory.forStage(stage, context);
+    BatchSideInputHandlerFactory factory = BatchSideInputHandlerFactory.forStage(stage, context);
     thrown.expect(instanceOf(IllegalArgumentException.class));
     factory.forSideInput(
         "transform-id",
@@ -110,8 +108,8 @@ public class FlinkBatchSideInputHandlerFactoryTest {
 
   @Test
   public void emptyResultForEmptyCollection() {
-    FlinkBatchSideInputHandlerFactory factory =
-        FlinkBatchSideInputHandlerFactory.forStage(EXECUTABLE_STAGE, context);
+    BatchSideInputHandlerFactory factory =
+        BatchSideInputHandlerFactory.forStage(EXECUTABLE_STAGE, context);
     SideInputHandler<Integer, GlobalWindow> handler =
         factory.forSideInput(
             TRANSFORM_ID,
@@ -127,12 +125,12 @@ public class FlinkBatchSideInputHandlerFactoryTest {
 
   @Test
   public void singleElementForCollection() {
-    when(context.getBroadcastVariable(COLLECTION_ID))
+    when(context.getSideInput(COLLECTION_ID))
         .thenReturn(
             Arrays.asList(WindowedValue.valueInGlobalWindow(KV.<Void, Integer>of(null, 3))));
 
-    FlinkBatchSideInputHandlerFactory factory =
-        FlinkBatchSideInputHandlerFactory.forStage(EXECUTABLE_STAGE, context);
+    BatchSideInputHandlerFactory factory =
+        BatchSideInputHandlerFactory.forStage(EXECUTABLE_STAGE, context);
     SideInputHandler<Integer, GlobalWindow> handler =
         factory.forSideInput(
             TRANSFORM_ID,
@@ -146,15 +144,15 @@ public class FlinkBatchSideInputHandlerFactoryTest {
 
   @Test
   public void groupsValuesByKey() {
-    when(context.getBroadcastVariable(COLLECTION_ID))
+    when(context.getSideInput(COLLECTION_ID))
         .thenReturn(
             Arrays.asList(
                 WindowedValue.valueInGlobalWindow(KV.of("foo", 2)),
                 WindowedValue.valueInGlobalWindow(KV.of("bar", 3)),
                 WindowedValue.valueInGlobalWindow(KV.of("foo", 5))));
 
-    FlinkBatchSideInputHandlerFactory factory =
-        FlinkBatchSideInputHandlerFactory.forStage(EXECUTABLE_STAGE, context);
+    BatchSideInputHandlerFactory factory =
+        BatchSideInputHandlerFactory.forStage(EXECUTABLE_STAGE, context);
     SideInputHandler<Integer, GlobalWindow> handler =
         factory.forSideInput(
             TRANSFORM_ID,
@@ -173,7 +171,7 @@ public class FlinkBatchSideInputHandlerFactoryTest {
     Instant instantC = new DateTime(2018, 1, 1, 1, 3, DateTimeZone.UTC).toInstant();
     IntervalWindow windowA = new IntervalWindow(instantA, instantB);
     IntervalWindow windowB = new IntervalWindow(instantB, instantC);
-    when(context.getBroadcastVariable(COLLECTION_ID))
+    when(context.getSideInput(COLLECTION_ID))
         .thenReturn(
             Arrays.asList(
                 WindowedValue.of(KV.of("foo", 1), instantA, windowA, PaneInfo.NO_FIRING),
@@ -183,8 +181,8 @@ public class FlinkBatchSideInputHandlerFactoryTest {
                 WindowedValue.of(KV.of("bar", 5), instantB, windowB, PaneInfo.NO_FIRING),
                 WindowedValue.of(KV.of("foo", 6), instantB, windowB, PaneInfo.NO_FIRING)));
 
-    FlinkBatchSideInputHandlerFactory factory =
-        FlinkBatchSideInputHandlerFactory.forStage(EXECUTABLE_STAGE, context);
+    BatchSideInputHandlerFactory factory =
+        BatchSideInputHandlerFactory.forStage(EXECUTABLE_STAGE, context);
     SideInputHandler<Integer, IntervalWindow> handler =
         factory.forSideInput(
             TRANSFORM_ID,
@@ -205,7 +203,7 @@ public class FlinkBatchSideInputHandlerFactoryTest {
     Instant instantC = new DateTime(2018, 1, 1, 1, 3, DateTimeZone.UTC).toInstant();
     IntervalWindow windowA = new IntervalWindow(instantA, instantB);
     IntervalWindow windowB = new IntervalWindow(instantB, instantC);
-    when(context.getBroadcastVariable(COLLECTION_ID))
+    when(context.getSideInput(COLLECTION_ID))
         .thenReturn(
             Arrays.asList(
                 WindowedValue.of(1, instantA, windowA, PaneInfo.NO_FIRING),
@@ -213,8 +211,8 @@ public class FlinkBatchSideInputHandlerFactoryTest {
                 WindowedValue.of(3, instantB, windowB, PaneInfo.NO_FIRING),
                 WindowedValue.of(4, instantB, windowB, PaneInfo.NO_FIRING)));
 
-    FlinkBatchSideInputHandlerFactory factory =
-        FlinkBatchSideInputHandlerFactory.forStage(EXECUTABLE_STAGE, context);
+    BatchSideInputHandlerFactory factory =
+        BatchSideInputHandlerFactory.forStage(EXECUTABLE_STAGE, context);
     SideInputHandler<Integer, IntervalWindow> handler =
         factory.forSideInput(
             TRANSFORM_ID,
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/BoundedDataset.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/BoundedDataset.java
index 1e620e7..c81c5f4 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/BoundedDataset.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/BoundedDataset.java
@@ -46,6 +46,7 @@ public class BoundedDataset<T> implements Dataset {
   private Iterable<WindowedValue<T>> windowedValues;
   private Coder<T> coder;
   private JavaRDD<WindowedValue<T>> rdd;
+  private List<byte[]> clientBytes;
 
   BoundedDataset(JavaRDD<WindowedValue<T>> rdd) {
     this.rdd = rdd;
@@ -69,6 +70,14 @@ public class BoundedDataset<T> implements Dataset {
     return rdd;
   }
 
+  List<byte[]> getBytes(WindowedValue.WindowedValueCoder<T> wvCoder) {
+    if (clientBytes == null) {
+      JavaRDDLike<byte[], ?> bytesRDD = rdd.map(CoderHelpers.toByteFunction(wvCoder));
+      clientBytes = bytesRDD.collect();
+    }
+    return clientBytes;
+  }
+
   Iterable<WindowedValue<T>> getValues(PCollection<T> pcollection) {
     if (windowedValues == null) {
       WindowFn<?, ?> windowFn = pcollection.getWindowingStrategy().getWindowFn();
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java
index c65caa4..82557ae 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java
@@ -22,10 +22,12 @@ import static org.apache.beam.runners.fnexecution.translation.PipelineTranslator
 
 import java.io.IOException;
 import java.util.Collections;
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import javax.annotation.Nullable;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.SideInputId;
 import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
 import org.apache.beam.runners.core.SystemReduceFn;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
@@ -54,6 +56,8 @@ import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Iterables;
 import org.apache.spark.HashPartitioner;
 import org.apache.spark.Partitioner;
 import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.broadcast.Broadcast;
+import scala.Tuple2;
 
 /** Translates a bounded portable pipeline into a Spark job. */
 public class SparkBatchPortablePipelineTranslator {
@@ -163,7 +167,7 @@ public class SparkBatchPortablePipelineTranslator {
     context.pushDataset(getOutputId(transformNode), new BoundedDataset<>(groupedByKeyAndWindow));
   }
 
-  private static <InputT, OutputT> void translateExecutableStage(
+  private static <InputT, OutputT, SideInputT> void translateExecutableStage(
       PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
 
     RunnerApi.ExecutableStagePayload stagePayload;
@@ -180,8 +184,22 @@ public class SparkBatchPortablePipelineTranslator {
     Map<String, String> outputs = transformNode.getTransform().getOutputsMap();
     BiMap<String, Integer> outputMap = createOutputMap(outputs.values());
 
-    SparkExecutableStageFunction<InputT> function =
-        new SparkExecutableStageFunction<>(stagePayload, context.jobInfo, outputMap);
+    ImmutableMap.Builder<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>>
+        broadcastVariablesBuilder = ImmutableMap.builder();
+    for (SideInputId sideInputId : stagePayload.getSideInputsList()) {
+      RunnerApi.Components components = stagePayload.getComponents();
+      String collectionId =
+          components
+              .getTransformsOrThrow(sideInputId.getTransformId())
+              .getInputsOrThrow(sideInputId.getLocalName());
+      Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>> tuple2 =
+          broadcastSideInput(collectionId, components, context);
+      broadcastVariablesBuilder.put(collectionId, tuple2);
+    }
+
+    SparkExecutableStageFunction<InputT, SideInputT> function =
+        new SparkExecutableStageFunction<>(
+            stagePayload, context.jobInfo, outputMap, broadcastVariablesBuilder.build());
     JavaRDD<RawUnionValue> staged = inputRdd.mapPartitions(function);
 
     for (String outputId : outputs.values()) {
@@ -191,6 +209,29 @@ public class SparkBatchPortablePipelineTranslator {
     }
   }
 
+  /**
+   * Collect and serialize the data and then broadcast the result. *This can be expensive.*
+   *
+   * @return Spark broadcast variable and coder to decode its contents
+   */
+  private static <T> Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<T>> broadcastSideInput(
+      String collectionId, RunnerApi.Components components, SparkTranslationContext context) {
+    PCollection collection = components.getPcollectionsOrThrow(collectionId);
+    @SuppressWarnings("unchecked")
+    BoundedDataset<T> dataset = (BoundedDataset<T>) context.popDataset(collectionId);
+    PCollectionNode collectionNode = PipelineNode.pCollection(collectionId, collection);
+    WindowedValueCoder<T> coder;
+    try {
+      coder =
+          (WindowedValueCoder<T>) WireCoders.instantiateRunnerWireCoder(collectionNode, components);
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+    List<byte[]> bytes = dataset.getBytes(coder);
+    Broadcast<List<byte[]>> broadcast = context.getSparkContext().broadcast(bytes);
+    return new Tuple2<>(broadcast, coder);
+  }
+
   @Nullable
   private static Partitioner getPartitioner(SparkTranslationContext context) {
     Long bundleSize =
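
To make the "expensive" note above concrete: each side-input PCollection is collected to the driver as coder-serialized bytes and shipped to the executors once as a Spark broadcast variable. A rough sketch of that round trip (jsc stands in for the JavaSparkContext that SparkTranslationContext exposes; the surrounding scaffolding is illustrative):

    // Driver side: materialize the side input and broadcast it once.
    List<byte[]> bytes = dataset.getBytes(coder);               // collect() of the serialized RDD
    Broadcast<List<byte[]>> broadcast = jsc.broadcast(bytes);   // shipped to every executor
    // Executors later read broadcast.value() and decode it with the same coder.
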
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunction.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunction.java
index 93250bc..e9ff511 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunction.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunction.java
@@ -17,12 +17,15 @@
  */
 package org.apache.beam.runners.spark.translation;
 
+import java.io.IOException;
 import java.io.Serializable;
 import java.util.EnumMap;
 import java.util.Iterator;
+import java.util.List;
 import java.util.Locale;
 import java.util.Map;
 import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.stream.Collectors;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleProgressResponse;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
@@ -33,17 +36,23 @@ import org.apache.beam.runners.fnexecution.control.BundleProgressHandler;
 import org.apache.beam.runners.fnexecution.control.DefaultJobBundleFactory;
 import org.apache.beam.runners.fnexecution.control.JobBundleFactory;
 import org.apache.beam.runners.fnexecution.control.OutputReceiverFactory;
+import org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors;
 import org.apache.beam.runners.fnexecution.control.RemoteBundle;
 import org.apache.beam.runners.fnexecution.control.StageBundleFactory;
 import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
 import org.apache.beam.runners.fnexecution.state.StateRequestHandler;
 import org.apache.beam.runners.fnexecution.state.StateRequestHandlers;
+import org.apache.beam.runners.fnexecution.translation.BatchSideInputHandlerFactory;
+import org.apache.beam.runners.spark.coders.CoderHelpers;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.transforms.join.RawUnionValue;
 import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder;
 import org.apache.spark.api.java.function.FlatMapFunction;
+import org.apache.spark.broadcast.Broadcast;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import scala.Tuple2;
 
 /**
  * Spark function that passes its input through an SDK-executed {@link
@@ -54,7 +63,7 @@ import org.slf4j.LoggerFactory;
  * The resulting data set should be further processed by a {@link
  * SparkExecutableStageExtractionFunction}.
  */
-public class SparkExecutableStageFunction<InputT>
+public class SparkExecutableStageFunction<InputT, SideInputT>
     implements FlatMapFunction<Iterator<WindowedValue<InputT>>, RawUnionValue> {
 
   private static final Logger LOG = LoggerFactory.getLogger(SparkExecutableStageFunction.class);
@@ -62,21 +71,27 @@ public class SparkExecutableStageFunction<InputT>
   private final RunnerApi.ExecutableStagePayload stagePayload;
   private final Map<String, Integer> outputMap;
   private final JobBundleFactoryCreator jobBundleFactoryCreator;
+  // map from pCollection id to tuple of serialized bytes and coder to decode the bytes
+  private final Map<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>>
+      sideInputs;
 
   SparkExecutableStageFunction(
       RunnerApi.ExecutableStagePayload stagePayload,
       JobInfo jobInfo,
-      Map<String, Integer> outputMap) {
-    this(stagePayload, outputMap, () -> DefaultJobBundleFactory.create(jobInfo));
+      Map<String, Integer> outputMap,
+      Map<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>> sideInputs) {
+    this(stagePayload, outputMap, () -> DefaultJobBundleFactory.create(jobInfo), sideInputs);
   }
 
   SparkExecutableStageFunction(
       RunnerApi.ExecutableStagePayload stagePayload,
       Map<String, Integer> outputMap,
-      JobBundleFactoryCreator jobBundleFactoryCreator) {
+      JobBundleFactoryCreator jobBundleFactoryCreator,
+      Map<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>> sideInputs) {
     this.stagePayload = stagePayload;
     this.outputMap = outputMap;
     this.jobBundleFactoryCreator = jobBundleFactoryCreator;
+    this.sideInputs = sideInputs;
   }
 
   @Override
@@ -86,10 +101,8 @@ public class SparkExecutableStageFunction<InputT>
     try (StageBundleFactory stageBundleFactory = jobBundleFactory.forStage(executableStage)) {
       ConcurrentLinkedQueue<RawUnionValue> collector = new ConcurrentLinkedQueue<>();
       ReceiverFactory receiverFactory = new ReceiverFactory(collector, outputMap);
-      EnumMap<TypeCase, StateRequestHandler> handlers = new EnumMap<>(StateKey.TypeCase.class);
-      // TODO add state request handlers
       StateRequestHandler stateRequestHandler =
-          StateRequestHandlers.delegateBasedUponType(handlers);
+          getStateRequestHandler(executableStage, stageBundleFactory.getProcessBundleDescriptor());
       SparkBundleProgressHandler bundleProgressHandler = new SparkBundleProgressHandler();
       try (RemoteBundle bundle =
           stageBundleFactory.getBundle(
@@ -109,6 +122,38 @@ public class SparkExecutableStageFunction<InputT>
     }
   }
 
+  private StateRequestHandler getStateRequestHandler(
+      ExecutableStage executableStage,
+      ProcessBundleDescriptors.ExecutableProcessBundleDescriptor processBundleDescriptor) {
+    EnumMap<TypeCase, StateRequestHandler> handlerMap = new EnumMap<>(StateKey.TypeCase.class);
+    final StateRequestHandler sideInputHandler;
+    StateRequestHandlers.SideInputHandlerFactory sideInputHandlerFactory =
+        BatchSideInputHandlerFactory.forStage(
+            executableStage,
+            new BatchSideInputHandlerFactory.SideInputGetter() {
+              @Override
+              public <T> List<T> getSideInput(String pCollectionId) {
+                Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>> tuple2 =
+                    sideInputs.get(pCollectionId);
+                Broadcast<List<byte[]>> broadcast = tuple2._1;
+                WindowedValueCoder<SideInputT> coder = tuple2._2;
+                return (List<T>)
+                    broadcast.value().stream()
+                        .map(bytes -> CoderHelpers.fromByteArray(bytes, coder))
+                        .collect(Collectors.toList());
+              }
+            });
+    try {
+      sideInputHandler =
+          StateRequestHandlers.forSideInputHandlerFactory(
+              ProcessBundleDescriptors.getSideInputs(executableStage), sideInputHandlerFactory);
+    } catch (IOException e) {
+      throw new RuntimeException("Failed to setup state handler", e);
+    }
+    handlerMap.put(StateKey.TypeCase.MULTIMAP_SIDE_INPUT, sideInputHandler);
+    return StateRequestHandlers.delegateBasedUponType(handlerMap);
+  }
+
   interface JobBundleFactoryCreator extends Serializable {
     JobBundleFactory create();
   }
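
The executor-side half of the broadcast: the anonymous SideInputGetter above reads the broadcast bytes and decodes them on demand with the coder captured at translation time. Condensed from the patch (the surrounding scaffolding is illustrative):

    Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>> tuple2 =
        sideInputs.get(pCollectionId);
    // Decode every broadcast element back into a WindowedValue.
    List<WindowedValue<SideInputT>> decoded =
        tuple2._1.value().stream()
            .map(bytes -> CoderHelpers.fromByteArray(bytes, tuple2._2))
            .collect(Collectors.toList());
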
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPortableExecutionTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPortableExecutionTest.java
index ad97ec0..38bdd1f 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPortableExecutionTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPortableExecutionTest.java
@@ -34,9 +34,11 @@ import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.Impulse;
 import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.View;
 import org.apache.beam.sdk.transforms.WithKeys;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.vendor.guava.v20_0.com.google.common.util.concurrent.ListeningExecutorService;
 import org.apache.beam.vendor.guava.v20_0.com.google.common.util.concurrent.MoreExecutors;
 import org.junit.AfterClass;
@@ -80,6 +82,20 @@ public class SparkPortableExecutionTest implements Serializable {
         .setDefaultEnvironmentType(Environments.ENVIRONMENT_EMBEDDED);
 
     Pipeline p = Pipeline.create(options);
+
+    final PCollectionView<Integer> view =
+        p.apply("impulse23", Impulse.create())
+            .apply(
+                "create23",
+                ParDo.of(
+                    new DoFn<byte[], Integer>() {
+                      @ProcessElement
+                      public void process(ProcessContext context) {
+                        context.output(23);
+                      }
+                    }))
+            .apply(View.asSingleton());
+
     PCollection<KV<String, Iterable<Long>>> result =
         p.apply("impulse", Impulse.create())
             .apply(
@@ -108,15 +124,17 @@ public class SparkPortableExecutionTest implements Serializable {
             .apply(
                 "print",
                 ParDo.of(
-                    new DoFn<KV<String, Iterable<Long>>, KV<String, Long>>() {
-                      @ProcessElement
-                      public void process(ProcessContext context) {
-                        LOG.info("Output element: {}", context.element());
-                        for (Long i : context.element().getValue()) {
-                          context.output(KV.of(context.element().getKey(), i));
-                        }
-                      }
-                    }))
+                        new DoFn<KV<String, Iterable<Long>>, KV<String, Long>>() {
+                          @ProcessElement
+                          public void process(ProcessContext context) {
+                            LOG.info("Side input: {}", context.sideInput(view));
+                            LOG.info("Output element: {}", context.element());
+                            for (Long i : context.element().getValue()) {
+                              context.output(KV.of(context.element().getKey(), i));
+                            }
+                          }
+                        })
+                    .withSideInputs(view))
             // Second GBK forces the output to be materialized
             .apply("gbk", GroupByKey.create());
 
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunctionTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunctionTest.java
index 8f1bdca..bba1ea4 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunctionTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunctionTest.java
@@ -89,14 +89,14 @@ public class SparkExecutableStageFunctionTest {
 
   @Test(expected = Exception.class)
   public void sdkErrorsSurfaceOnClose() throws Exception {
-    SparkExecutableStageFunction<Integer> function = getFunction(Collections.emptyMap());
+    SparkExecutableStageFunction<Integer, ?> function = getFunction(Collections.emptyMap());
     doThrow(new Exception()).when(remoteBundle).close();
     function.call(Collections.emptyIterator());
   }
 
   @Test
   public void expectedInputsAreSent() throws Exception {
-    SparkExecutableStageFunction<Integer> function = getFunction(Collections.emptyMap());
+    SparkExecutableStageFunction<Integer, ?> function = getFunction(Collections.emptyMap());
 
     RemoteBundle bundle = Mockito.mock(RemoteBundle.class);
     when(stageBundleFactory.getBundle(any(), any(), any())).thenReturn(bundle);
@@ -178,7 +178,7 @@ public class SparkExecutableStageFunctionTest {
         };
     when(jobBundleFactory.forStage(any())).thenReturn(stageBundleFactory);
 
-    SparkExecutableStageFunction<Integer> function = getFunction(outputTagMap);
+    SparkExecutableStageFunction<Integer, ?> function = getFunction(outputTagMap);
     Iterator<RawUnionValue> iterator = function.call(Collections.emptyIterator());
     Iterable<RawUnionValue> iterable = () -> iterator;
 
@@ -190,14 +190,17 @@ public class SparkExecutableStageFunctionTest {
 
   @Test
   public void testStageBundleClosed() throws Exception {
-    SparkExecutableStageFunction<Integer> function = getFunction(Collections.emptyMap());
+    SparkExecutableStageFunction<Integer, ?> function = getFunction(Collections.emptyMap());
     function.call(Collections.emptyIterator());
     verify(stageBundleFactory).getBundle(any(), any(), any());
+    verify(stageBundleFactory).getProcessBundleDescriptor();
     verify(stageBundleFactory).close();
     verifyNoMoreInteractions(stageBundleFactory);
   }
 
-  private <T> SparkExecutableStageFunction<T> getFunction(Map<String, Integer> outputMap) {
-    return new SparkExecutableStageFunction<>(stagePayload, outputMap, jobBundleFactoryCreator);
+  private <InputT, SideInputT> SparkExecutableStageFunction<InputT, SideInputT> getFunction(
+      Map<String, Integer> outputMap) {
+    return new SparkExecutableStageFunction<>(
+        stagePayload, outputMap, jobBundleFactoryCreator, Collections.emptyMap());
   }
 }
