This is an automated email from the ASF dual-hosted git repository.

iemejia pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
     new c45a50f  [BEAM-6935] Spark portable runner: implement side inputs
     new 16b58ad  Merge pull request #8220: [BEAM-6935] Spark portable runner: implement side inputs
c45a50f is described below

commit c45a50fb092171ce4fa5f8b0758a584911d4f50d
Author: Kyle Weaver <kcwea...@google.com>
AuthorDate: Thu Mar 28 19:16:51 2019 -0700

    [BEAM-6935] Spark portable runner: implement side inputs
---
 .../functions/FlinkExecutableStageFunction.java    |  4 +-
 .../translation/BatchSideInputHandlerFactory.java} | 35 ++++++-------
 .../BatchSideInputHandlerFactoryTest.java}         | 40 +++++++--------
 .../runners/spark/translation/BoundedDataset.java  |  9 ++++
 .../SparkBatchPortablePipelineTranslator.java      | 47 +++++++++++++++--
 .../translation/SparkExecutableStageFunction.java  | 59 +++++++++++++++++++---
 .../runners/spark/SparkPortableExecutionTest.java  | 36 +++++++++----
 .../SparkExecutableStageFunctionTest.java          | 15 +++---
 8 files changed, 181 insertions(+), 64 deletions(-)

diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java
index e7dafa8..c02aa65 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java
@@ -54,6 +54,7 @@ import org.apache.beam.runners.fnexecution.control.StageBundleFactory;
 import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
 import org.apache.beam.runners.fnexecution.state.StateRequestHandler;
 import org.apache.beam.runners.fnexecution.state.StateRequestHandlers;
+import org.apache.beam.runners.fnexecution.translation.BatchSideInputHandlerFactory;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.io.FileSystems;
@@ -167,7 +168,8 @@ public class FlinkExecutableStageFunction<InputT> extends AbstractRichFunction
       RuntimeContext runtimeContext) {
     final StateRequestHandler sideInputHandler;
     StateRequestHandlers.SideInputHandlerFactory sideInputHandlerFactory =
-        FlinkBatchSideInputHandlerFactory.forStage(executableStage, runtimeContext);
+        BatchSideInputHandlerFactory.forStage(
+            executableStage, runtimeContext::getBroadcastVariable);
     try {
       sideInputHandler =
           StateRequestHandlers.forSideInputHandlerFactory(
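A note on the Flink change above: BatchSideInputHandlerFactory.forStage (introduced by this patch, below) takes a SideInputGetter whose only method declares its own type parameter. A method reference such as runtimeContext::getBroadcastVariable can implement such a generic method, while a lambda cannot, which is why the call site reads the way it does. A minimal self-contained sketch of that mechanism follows; FakeRuntimeContext and the other names are illustrative, not part of the patch:

    import java.util.Arrays;
    import java.util.Collections;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    public class SideInputGetterSketch {

      /** Mirrors BatchSideInputHandlerFactory.SideInputGetter from this patch. */
      interface SideInputGetter {
        <T> List<T> getSideInput(String pCollectionId);
      }

      /** Hypothetical stand-in for Flink's RuntimeContext and its generic lookup. */
      static class FakeRuntimeContext {
        private final Map<String, List<?>> broadcastVariables = new HashMap<>();

        void put(String name, List<?> values) {
          broadcastVariables.put(name, values);
        }

        @SuppressWarnings("unchecked")
        <T> List<T> getBroadcastVariable(String name) {
          return (List<T>) broadcastVariables.getOrDefault(name, Collections.emptyList());
        }
      }

      public static void main(String[] args) {
        FakeRuntimeContext context = new FakeRuntimeContext();
        context.put("side-pcollection", Arrays.asList(1, 2, 3));
        // A method reference can implement the generic interface method;
        // a lambda cannot, because getSideInput declares its own type parameter.
        SideInputGetter getter = context::getBroadcastVariable;
        List<Integer> values = getter.getSideInput("side-pcollection");
        System.out.println(values); // prints [1, 2, 3]
      }
    }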
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkBatchSideInputHandlerFactory.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/translation/BatchSideInputHandlerFactory.java
similarity index 87%
rename from runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkBatchSideInputHandlerFactory.java
rename to runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/translation/BatchSideInputHandlerFactory.java
index 798c32b..5460898 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkBatchSideInputHandlerFactory.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/translation/BatchSideInputHandlerFactory.java
@@ -15,7 +15,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.beam.runners.flink.translation.functions;
+package org.apache.beam.runners.fnexecution.translation;
 
 import static org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkArgument;
 
@@ -43,24 +43,25 @@ import org.apache.beam.sdk.values.KV;
 import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableMap;
 import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableMultimap;
 import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Multimap;
-import org.apache.flink.api.common.functions.RuntimeContext;
 
-/**
- * {@link StateRequestHandler} that uses a Flink {@link RuntimeContext} to access Flink broadcast
- * variable that represent side inputs.
- */
-class FlinkBatchSideInputHandlerFactory implements SideInputHandlerFactory {
+/** {@link StateRequestHandler} that uses a {@link SideInputGetter} to access side inputs. */
+public class BatchSideInputHandlerFactory implements SideInputHandlerFactory {
 
   // Map from side input id to global PCollection id.
   private final Map<SideInputId, PCollectionNode> sideInputToCollection;
-  private final RuntimeContext runtimeContext;
+  private final SideInputGetter sideInputGetter;
+
+  /** Returns the value for the side input with the given PCollection id from the runner. */
+  public interface SideInputGetter {
+    <T> List<T> getSideInput(String pCollectionId);
+  }
 
   /**
    * Creates a new state handler for the given stage. Note that this requires a traversal of the
    * stage itself, so this should only be called once per stage rather than once per bundle.
    */
-  static FlinkBatchSideInputHandlerFactory forStage(
-      ExecutableStage stage, RuntimeContext runtimeContext) {
+  public static BatchSideInputHandlerFactory forStage(
+      ExecutableStage stage, SideInputGetter sideInputGetter) {
     ImmutableMap.Builder<SideInputId, PCollectionNode> sideInputBuilder = ImmutableMap.builder();
     for (SideInputReference sideInput : stage.getSideInputs()) {
       sideInputBuilder.put(
@@ -70,13 +71,13 @@ class FlinkBatchSideInputHandlerFactory implements SideInputHandlerFactory {
               .build(),
           sideInput.collection());
     }
-    return new FlinkBatchSideInputHandlerFactory(sideInputBuilder.build(), runtimeContext);
+    return new BatchSideInputHandlerFactory(sideInputBuilder.build(), sideInputGetter);
   }
 
-  private FlinkBatchSideInputHandlerFactory(
-      Map<SideInputId, PCollectionNode> sideInputToCollection, RuntimeContext runtimeContext) {
+  private BatchSideInputHandlerFactory(
+      Map<SideInputId, PCollectionNode> sideInputToCollection, SideInputGetter sideInputGetter) {
     this.sideInputToCollection = sideInputToCollection;
-    this.runtimeContext = runtimeContext;
+    this.sideInputGetter = sideInputGetter;
   }
 
   @Override
@@ -96,7 +97,7 @@ class FlinkBatchSideInputHandlerFactory implements SideInputHandlerFactory {
       @SuppressWarnings("unchecked") // T == V
       Coder<V> outputCoder = (Coder<V>) elementCoder;
       return forIterableSideInput(
-          runtimeContext.getBroadcastVariable(collectionNode.getId()), outputCoder, windowCoder);
+          sideInputGetter.getSideInput(collectionNode.getId()), outputCoder, windowCoder);
     } else if (PTransformTranslation.MULTIMAP_SIDE_INPUT.equals(accessPattern.getUrn())
         || Materializations.MULTIMAP_MATERIALIZATION_URN.equals(accessPattern.getUrn())) {
       // TODO: Remove non standard URN.
@@ -104,7 +105,7 @@
       @SuppressWarnings("unchecked") // T == KV<?, V>
       KvCoder<?, V> kvCoder = (KvCoder<?, V>) elementCoder;
       return forMultimapSideInput(
-          runtimeContext.getBroadcastVariable(collectionNode.getId()),
+          sideInputGetter.getSideInput(collectionNode.getId()),
          kvCoder.getKeyCoder(),
          kvCoder.getValueCoder(),
          windowCoder);
@@ -202,7 +203,7 @@ class FlinkBatchSideInputHandlerFactory implements SideInputHandlerFactory {
   @AutoValue
   abstract static class SideInputKey {
     static SideInputKey of(Object key, Object window) {
-      return new AutoValue_FlinkBatchSideInputHandlerFactory_SideInputKey(key, window);
+      return new AutoValue_BatchSideInputHandlerFactory_SideInputKey(key, window);
     }
 
     @Nullable
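The multimap access pattern handled above (and exercised by the groupsValuesByKey test below) boils down to grouping the materialized KV elements by key within each window, so that a lookup returns every value recorded for that key. Reduced to plain Java over ordinary map entries, with hypothetical class names:

    import java.util.AbstractMap.SimpleEntry;
    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    public class MultimapGroupingSketch {
      public static void main(String[] args) {
        // Stand-ins for the WindowedValue<KV<String, Integer>> elements of one window.
        List<Map.Entry<String, Integer>> elements =
            Arrays.asList(
                new SimpleEntry<>("foo", 2),
                new SimpleEntry<>("bar", 3),
                new SimpleEntry<>("foo", 5));

        // Group values by key so a lookup like get("foo") returns all of [2, 5].
        Map<String, List<Integer>> grouped = new HashMap<>();
        for (Map.Entry<String, Integer> e : elements) {
          grouped.computeIfAbsent(e.getKey(), k -> new ArrayList<>()).add(e.getValue());
        }

        System.out.println(grouped.get("foo")); // [2, 5]
        System.out.println(grouped.get("bar")); // [3]
      }
    }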
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkBatchSideInputHandlerFactoryTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/translation/BatchSideInputHandlerFactoryTest.java
similarity index 89%
rename from runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkBatchSideInputHandlerFactoryTest.java
rename to runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/translation/BatchSideInputHandlerFactoryTest.java
index 897289f..f664aa9 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkBatchSideInputHandlerFactoryTest.java
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/translation/BatchSideInputHandlerFactoryTest.java
@@ -15,7 +15,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.beam.runners.flink.translation.functions;
+package org.apache.beam.runners.fnexecution.translation;
 
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.contains;
@@ -50,7 +50,6 @@ import org.apache.beam.sdk.transforms.windowing.IntervalWindow.IntervalWindowCod
 import org.apache.beam.sdk.transforms.windowing.PaneInfo;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.KV;
-import org.apache.flink.api.common.functions.RuntimeContext;
 import org.joda.time.DateTime;
 import org.joda.time.DateTimeZone;
 import org.joda.time.Instant;
@@ -63,9 +62,9 @@ import org.junit.runners.JUnit4;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
-/** Tests for {@link FlinkBatchSideInputHandlerFactory}. */
+/** Tests for {@link BatchSideInputHandlerFactory}. */
 @RunWith(JUnit4.class)
-public class FlinkBatchSideInputHandlerFactoryTest {
+public class BatchSideInputHandlerFactoryTest {
   private static final String TRANSFORM_ID = "transform-id";
   private static final String SIDE_INPUT_NAME = "side-input";
@@ -87,7 +86,7 @@ public class FlinkBatchSideInputHandlerFactoryTest {
 
   @Rule public ExpectedException thrown = ExpectedException.none();
 
-  @Mock private RuntimeContext context;
+  @Mock private BatchSideInputHandlerFactory.SideInputGetter context;
 
   @Before
   public void setUpMocks() {
@@ -97,8 +96,7 @@
   @Test
   public void invalidSideInputThrowsException() {
     ExecutableStage stage = createExecutableStage(Collections.emptyList());
-    FlinkBatchSideInputHandlerFactory factory =
-        FlinkBatchSideInputHandlerFactory.forStage(stage, context);
+    BatchSideInputHandlerFactory factory = BatchSideInputHandlerFactory.forStage(stage, context);
     thrown.expect(instanceOf(IllegalArgumentException.class));
     factory.forSideInput(
         "transform-id",
@@ -110,8 +108,8 @@
 
   @Test
   public void emptyResultForEmptyCollection() {
-    FlinkBatchSideInputHandlerFactory factory =
-        FlinkBatchSideInputHandlerFactory.forStage(EXECUTABLE_STAGE, context);
+    BatchSideInputHandlerFactory factory =
+        BatchSideInputHandlerFactory.forStage(EXECUTABLE_STAGE, context);
     SideInputHandler<Integer, GlobalWindow> handler =
         factory.forSideInput(
             TRANSFORM_ID,
@@ -127,12 +125,12 @@
 
   @Test
   public void singleElementForCollection() {
-    when(context.getBroadcastVariable(COLLECTION_ID))
+    when(context.getSideInput(COLLECTION_ID))
         .thenReturn(
            Arrays.asList(WindowedValue.valueInGlobalWindow(KV.<Void, Integer>of(null, 3))));
 
-    FlinkBatchSideInputHandlerFactory factory =
-        FlinkBatchSideInputHandlerFactory.forStage(EXECUTABLE_STAGE, context);
+    BatchSideInputHandlerFactory factory =
+        BatchSideInputHandlerFactory.forStage(EXECUTABLE_STAGE, context);
     SideInputHandler<Integer, GlobalWindow> handler =
         factory.forSideInput(
            TRANSFORM_ID,
@@ -146,15 +144,15 @@
 
   @Test
   public void groupsValuesByKey() {
-    when(context.getBroadcastVariable(COLLECTION_ID))
+    when(context.getSideInput(COLLECTION_ID))
         .thenReturn(
            Arrays.asList(
                WindowedValue.valueInGlobalWindow(KV.of("foo", 2)),
                WindowedValue.valueInGlobalWindow(KV.of("bar", 3)),
                WindowedValue.valueInGlobalWindow(KV.of("foo", 5))));
 
-    FlinkBatchSideInputHandlerFactory factory =
-        FlinkBatchSideInputHandlerFactory.forStage(EXECUTABLE_STAGE, context);
+    BatchSideInputHandlerFactory factory =
+        BatchSideInputHandlerFactory.forStage(EXECUTABLE_STAGE, context);
     SideInputHandler<Integer, GlobalWindow> handler =
         factory.forSideInput(
            TRANSFORM_ID,
@@ -173,7 +171,7 @@
     Instant instantC = new DateTime(2018, 1, 1, 1, 3, DateTimeZone.UTC).toInstant();
     IntervalWindow windowA = new IntervalWindow(instantA, instantB);
     IntervalWindow windowB = new IntervalWindow(instantB, instantC);
-    when(context.getBroadcastVariable(COLLECTION_ID))
+    when(context.getSideInput(COLLECTION_ID))
         .thenReturn(
            Arrays.asList(
                WindowedValue.of(KV.of("foo", 1), instantA, windowA, PaneInfo.NO_FIRING),
@@ -183,8 +181,8 @@
                WindowedValue.of(KV.of("bar", 5), instantB, windowB, PaneInfo.NO_FIRING),
                WindowedValue.of(KV.of("foo", 6), instantB, windowB, PaneInfo.NO_FIRING)));
 
-    FlinkBatchSideInputHandlerFactory factory =
-        FlinkBatchSideInputHandlerFactory.forStage(EXECUTABLE_STAGE, context);
+    BatchSideInputHandlerFactory factory =
+        BatchSideInputHandlerFactory.forStage(EXECUTABLE_STAGE, context);
     SideInputHandler<Integer, IntervalWindow> handler =
         factory.forSideInput(
            TRANSFORM_ID,
@@ -205,7 +203,7 @@
     Instant instantC = new DateTime(2018, 1, 1, 1, 3, DateTimeZone.UTC).toInstant();
     IntervalWindow windowA = new IntervalWindow(instantA, instantB);
     IntervalWindow windowB = new IntervalWindow(instantB, instantC);
-    when(context.getBroadcastVariable(COLLECTION_ID))
+    when(context.getSideInput(COLLECTION_ID))
         .thenReturn(
            Arrays.asList(
                WindowedValue.of(1, instantA, windowA, PaneInfo.NO_FIRING),
                WindowedValue.of(2, instantA, windowA, PaneInfo.NO_FIRING),
                WindowedValue.of(3, instantB, windowB, PaneInfo.NO_FIRING),
                WindowedValue.of(4, instantB, windowB, PaneInfo.NO_FIRING)));
 
-    FlinkBatchSideInputHandlerFactory factory =
-        FlinkBatchSideInputHandlerFactory.forStage(EXECUTABLE_STAGE, context);
+    BatchSideInputHandlerFactory factory =
+        BatchSideInputHandlerFactory.forStage(EXECUTABLE_STAGE, context);
     SideInputHandler<Integer, IntervalWindow> handler =
         factory.forSideInput(
            TRANSFORM_ID,
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/BoundedDataset.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/BoundedDataset.java
index 1e620e7..c81c5f4 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/BoundedDataset.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/BoundedDataset.java
@@ -46,6 +46,7 @@ public class BoundedDataset<T> implements Dataset {
   private Iterable<WindowedValue<T>> windowedValues;
   private Coder<T> coder;
   private JavaRDD<WindowedValue<T>> rdd;
+  private List<byte[]> clientBytes;
 
   BoundedDataset(JavaRDD<WindowedValue<T>> rdd) {
     this.rdd = rdd;
@@ -69,6 +70,14 @@ public class BoundedDataset<T> implements Dataset {
     return rdd;
   }
 
+  List<byte[]> getBytes(WindowedValue.WindowedValueCoder<T> wvCoder) {
+    if (clientBytes == null) {
+      JavaRDDLike<byte[], ?> bytesRDD = rdd.map(CoderHelpers.toByteFunction(wvCoder));
+      clientBytes = bytesRDD.collect();
+    }
+    return clientBytes;
+  }
+
   Iterable<WindowedValue<T>> getValues(PCollection<T> pcollection) {
     if (windowedValues == null) {
       WindowFn<?, ?> windowFn = pcollection.getWindowingStrategy().getWindowFn();
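getBytes above encodes every element to a byte array with the stage's wire coder before collect(), so only plain byte[] values, which are always Java-serializable, travel from the executors to the driver even when the element type itself is not serializable. A round-trip sketch of that coder-based encoding using Beam SDK utilities, with VarIntCoder standing in for the stage's WindowedValueCoder:

    import org.apache.beam.sdk.coders.VarIntCoder;
    import org.apache.beam.sdk.util.CoderUtils;

    public class CoderRoundTripSketch {
      public static void main(String[] args) throws Exception {
        VarIntCoder coder = VarIntCoder.of();
        // Encode on the producer side, ship the bytes, decode on the consumer side.
        byte[] bytes = CoderUtils.encodeToByteArray(coder, 23);
        Integer decoded = CoderUtils.decodeFromByteArray(coder, bytes);
        System.out.println(decoded); // prints 23
      }
    }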
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java
index c65caa4..82557ae 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java
@@ -22,10 +22,12 @@ import static org.apache.beam.runners.fnexecution.translation.PipelineTranslator
 
 import java.io.IOException;
 import java.util.Collections;
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import javax.annotation.Nullable;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.SideInputId;
 import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
 import org.apache.beam.runners.core.SystemReduceFn;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
@@ -54,6 +56,8 @@ import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Iterables;
 import org.apache.spark.HashPartitioner;
 import org.apache.spark.Partitioner;
 import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.broadcast.Broadcast;
+import scala.Tuple2;
 
 /** Translates a bounded portable pipeline into a Spark job. */
 public class SparkBatchPortablePipelineTranslator {
@@ -163,7 +167,7 @@ public class SparkBatchPortablePipelineTranslator {
     context.pushDataset(getOutputId(transformNode), new BoundedDataset<>(groupedByKeyAndWindow));
   }
 
-  private static <InputT, OutputT> void translateExecutableStage(
+  private static <InputT, OutputT, SideInputT> void translateExecutableStage(
       PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
 
     RunnerApi.ExecutableStagePayload stagePayload;
@@ -180,8 +184,22 @@ public class SparkBatchPortablePipelineTranslator {
     Map<String, String> outputs = transformNode.getTransform().getOutputsMap();
     BiMap<String, Integer> outputMap = createOutputMap(outputs.values());
 
-    SparkExecutableStageFunction<InputT> function =
-        new SparkExecutableStageFunction<>(stagePayload, context.jobInfo, outputMap);
+    ImmutableMap.Builder<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>>
+        broadcastVariablesBuilder = ImmutableMap.builder();
+    for (SideInputId sideInputId : stagePayload.getSideInputsList()) {
+      RunnerApi.Components components = stagePayload.getComponents();
+      String collectionId =
+          components
+              .getTransformsOrThrow(sideInputId.getTransformId())
+              .getInputsOrThrow(sideInputId.getLocalName());
+      Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>> tuple2 =
+          broadcastSideInput(collectionId, components, context);
+      broadcastVariablesBuilder.put(collectionId, tuple2);
+    }
+
+    SparkExecutableStageFunction<InputT, SideInputT> function =
+        new SparkExecutableStageFunction<>(
+            stagePayload, context.jobInfo, outputMap, broadcastVariablesBuilder.build());
     JavaRDD<RawUnionValue> staged = inputRdd.mapPartitions(function);
 
     for (String outputId : outputs.values()) {
@@ -191,6 +209,29 @@ public class SparkBatchPortablePipelineTranslator {
     }
   }
 
+  /**
+   * Collect and serialize the data and then broadcast the result. This can be expensive.
+   *
+   * @return Spark broadcast variable and coder to decode its contents
+   */
+  private static <T> Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<T>> broadcastSideInput(
+      String collectionId, RunnerApi.Components components, SparkTranslationContext context) {
+    PCollection collection = components.getPcollectionsOrThrow(collectionId);
+    @SuppressWarnings("unchecked")
+    BoundedDataset<T> dataset = (BoundedDataset<T>) context.popDataset(collectionId);
+    PCollectionNode collectionNode = PipelineNode.pCollection(collectionId, collection);
+    WindowedValueCoder<T> coder;
+    try {
+      coder =
+          (WindowedValueCoder<T>) WireCoders.instantiateRunnerWireCoder(collectionNode, components);
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+    List<byte[]> bytes = dataset.getBytes(coder);
+    Broadcast<List<byte[]>> broadcast = context.getSparkContext().broadcast(bytes);
+    return new Tuple2<>(broadcast, coder);
+  }
+
   @Nullable
   private static Partitioner getPartitioner(SparkTranslationContext context) {
     Long bundleSize =
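broadcastSideInput above is Spark's standard collect-then-broadcast pattern: the side input dataset is materialized on the driver, then shipped once per executor as a read-only Broadcast variable instead of being recomputed or reshuffled per task. A stripped-down sketch without any Beam types; all names here are illustrative:

    import java.util.Arrays;
    import java.util.List;
    import org.apache.spark.SparkConf;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.broadcast.Broadcast;

    public class BroadcastSketch {
      public static void main(String[] args) {
        SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("broadcast-sketch");
        try (JavaSparkContext sc = new JavaSparkContext(conf)) {
          // "Side input": materialized on the driver, then shipped once per executor.
          List<Integer> sideInput = sc.parallelize(Arrays.asList(1, 2, 3)).collect();
          Broadcast<List<Integer>> broadcast = sc.broadcast(sideInput);

          // Tasks over the main input read the broadcast value instead of shuffling it.
          List<Integer> result =
              sc.parallelize(Arrays.asList(10, 20))
                  .map(x -> x + broadcast.value().size())
                  .collect();
          System.out.println(result); // prints [13, 23]
        }
      }
    }

The javadoc's warning carries over: the collected side input must fit in driver and executor memory, trading scalability of the side input for cheap repeated access.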
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunction.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunction.java
index 93250bc..e9ff511 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunction.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunction.java
@@ -17,12 +17,15 @@
  */
 package org.apache.beam.runners.spark.translation;
 
+import java.io.IOException;
 import java.io.Serializable;
 import java.util.EnumMap;
 import java.util.Iterator;
+import java.util.List;
 import java.util.Locale;
 import java.util.Map;
 import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.stream.Collectors;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleProgressResponse;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
@@ -33,17 +36,23 @@ import org.apache.beam.runners.fnexecution.control.BundleProgressHandler;
 import org.apache.beam.runners.fnexecution.control.DefaultJobBundleFactory;
 import org.apache.beam.runners.fnexecution.control.JobBundleFactory;
 import org.apache.beam.runners.fnexecution.control.OutputReceiverFactory;
+import org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors;
 import org.apache.beam.runners.fnexecution.control.RemoteBundle;
 import org.apache.beam.runners.fnexecution.control.StageBundleFactory;
 import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
 import org.apache.beam.runners.fnexecution.state.StateRequestHandler;
 import org.apache.beam.runners.fnexecution.state.StateRequestHandlers;
+import org.apache.beam.runners.fnexecution.translation.BatchSideInputHandlerFactory;
+import org.apache.beam.runners.spark.coders.CoderHelpers;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.transforms.join.RawUnionValue;
 import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder;
 import org.apache.spark.api.java.function.FlatMapFunction;
+import org.apache.spark.broadcast.Broadcast;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import scala.Tuple2;
 
 /**
  * Spark function that passes its input through an SDK-executed {@link
@@ -54,7 +63,7 @@ import org.slf4j.LoggerFactory;
  * The resulting data set should be further processed by a {@link
  * SparkExecutableStageExtractionFunction}.
  */
-public class SparkExecutableStageFunction<InputT>
+public class SparkExecutableStageFunction<InputT, SideInputT>
     implements FlatMapFunction<Iterator<WindowedValue<InputT>>, RawUnionValue> {
 
   private static final Logger LOG = LoggerFactory.getLogger(SparkExecutableStageFunction.class);
@@ -62,21 +71,27 @@ public class SparkExecutableStageFunction<InputT>
   private final RunnerApi.ExecutableStagePayload stagePayload;
   private final Map<String, Integer> outputMap;
   private final JobBundleFactoryCreator jobBundleFactoryCreator;
+  // map from pCollection id to tuple of serialized bytes and coder to decode the bytes
+  private final Map<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>>
+      sideInputs;
 
   SparkExecutableStageFunction(
       RunnerApi.ExecutableStagePayload stagePayload,
       JobInfo jobInfo,
-      Map<String, Integer> outputMap) {
-    this(stagePayload, outputMap, () -> DefaultJobBundleFactory.create(jobInfo));
+      Map<String, Integer> outputMap,
+      Map<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>> sideInputs) {
+    this(stagePayload, outputMap, () -> DefaultJobBundleFactory.create(jobInfo), sideInputs);
   }
 
   SparkExecutableStageFunction(
       RunnerApi.ExecutableStagePayload stagePayload,
       Map<String, Integer> outputMap,
-      JobBundleFactoryCreator jobBundleFactoryCreator) {
+      JobBundleFactoryCreator jobBundleFactoryCreator,
+      Map<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>> sideInputs) {
     this.stagePayload = stagePayload;
     this.outputMap = outputMap;
     this.jobBundleFactoryCreator = jobBundleFactoryCreator;
+    this.sideInputs = sideInputs;
   }
 
   @Override
@@ -86,10 +101,8 @@ public class SparkExecutableStageFunction<InputT>
     try (StageBundleFactory stageBundleFactory = jobBundleFactory.forStage(executableStage)) {
       ConcurrentLinkedQueue<RawUnionValue> collector = new ConcurrentLinkedQueue<>();
       ReceiverFactory receiverFactory = new ReceiverFactory(collector, outputMap);
-      EnumMap<TypeCase, StateRequestHandler> handlers = new EnumMap<>(StateKey.TypeCase.class);
-      // TODO add state request handlers
       StateRequestHandler stateRequestHandler =
-          StateRequestHandlers.delegateBasedUponType(handlers);
+          getStateRequestHandler(executableStage, stageBundleFactory.getProcessBundleDescriptor());
       SparkBundleProgressHandler bundleProgressHandler = new SparkBundleProgressHandler();
       try (RemoteBundle bundle =
           stageBundleFactory.getBundle(
@@ -109,6 +122,38 @@
     }
   }
 
+  private StateRequestHandler getStateRequestHandler(
+      ExecutableStage executableStage,
+      ProcessBundleDescriptors.ExecutableProcessBundleDescriptor processBundleDescriptor) {
+    EnumMap<TypeCase, StateRequestHandler> handlerMap = new EnumMap<>(StateKey.TypeCase.class);
+    final StateRequestHandler sideInputHandler;
+    StateRequestHandlers.SideInputHandlerFactory sideInputHandlerFactory =
+        BatchSideInputHandlerFactory.forStage(
+            executableStage,
+            new BatchSideInputHandlerFactory.SideInputGetter() {
+              @Override
+              public <T> List<T> getSideInput(String pCollectionId) {
+                Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>> tuple2 =
+                    sideInputs.get(pCollectionId);
+                Broadcast<List<byte[]>> broadcast = tuple2._1;
+                WindowedValueCoder<SideInputT> coder = tuple2._2;
+                return (List<T>)
+                    broadcast.value().stream()
+                        .map(bytes -> CoderHelpers.fromByteArray(bytes, coder))
+                        .collect(Collectors.toList());
+              }
+            });
+    try {
+      sideInputHandler =
+          StateRequestHandlers.forSideInputHandlerFactory(
+              ProcessBundleDescriptors.getSideInputs(executableStage), sideInputHandlerFactory);
+    } catch (IOException e) {
+      throw new RuntimeException("Failed to setup state handler", e);
+    }
+    handlerMap.put(StateKey.TypeCase.MULTIMAP_SIDE_INPUT, sideInputHandler);
+    return StateRequestHandlers.delegateBasedUponType(handlerMap);
+  }
+
   interface JobBundleFactoryCreator extends Serializable {
     JobBundleFactory create();
   }
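getStateRequestHandler above registers the side input handler only under MULTIMAP_SIDE_INPUT and relies on StateRequestHandlers.delegateBasedUponType to route each incoming state request by its type tag; the SideInputGetter is an anonymous class because its method is generic and therefore cannot be a lambda. The routing idea, reduced to a plain EnumMap dispatch with hypothetical names:

    import java.util.EnumMap;
    import java.util.function.Function;

    public class TypeDispatchSketch {

      enum TypeCase { MULTIMAP_SIDE_INPUT, BAG_USER_STATE }

      public static void main(String[] args) {
        // One handler per state key type; requests route by their enum tag.
        EnumMap<TypeCase, Function<String, String>> handlerMap = new EnumMap<>(TypeCase.class);
        handlerMap.put(TypeCase.MULTIMAP_SIDE_INPUT, key -> "side input for " + key);

        TypeCase requestType = TypeCase.MULTIMAP_SIDE_INPUT;
        Function<String, String> handler = handlerMap.get(requestType);
        if (handler == null) {
          throw new IllegalStateException("No handler registered for " + requestType);
        }
        System.out.println(handler.apply("pcollection-1"));
      }
    }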
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPortableExecutionTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPortableExecutionTest.java
index ad97ec0..38bdd1f 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPortableExecutionTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPortableExecutionTest.java
@@ -34,9 +34,11 @@ import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.Impulse;
 import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.View;
 import org.apache.beam.sdk.transforms.WithKeys;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.vendor.guava.v20_0.com.google.common.util.concurrent.ListeningExecutorService;
 import org.apache.beam.vendor.guava.v20_0.com.google.common.util.concurrent.MoreExecutors;
 import org.junit.AfterClass;
@@ -80,6 +82,20 @@ public class SparkPortableExecutionTest implements Serializable {
             .setDefaultEnvironmentType(Environments.ENVIRONMENT_EMBEDDED);
 
     Pipeline p = Pipeline.create(options);
+
+    final PCollectionView<Integer> view =
+        p.apply("impulse23", Impulse.create())
+            .apply(
+                "create23",
+                ParDo.of(
+                    new DoFn<byte[], Integer>() {
+                      @ProcessElement
+                      public void process(ProcessContext context) {
+                        context.output(23);
+                      }
+                    }))
+            .apply(View.asSingleton());
+
     PCollection<KV<String, Iterable<Long>>> result =
         p.apply("impulse", Impulse.create())
             .apply(
@@ -108,15 +124,17 @@ public class SparkPortableExecutionTest implements Serializable {
             .apply(
                 "print",
                 ParDo.of(
-                    new DoFn<KV<String, Iterable<Long>>, KV<String, Long>>() {
-                      @ProcessElement
-                      public void process(ProcessContext context) {
-                        LOG.info("Output element: {}", context.element());
-                        for (Long i : context.element().getValue()) {
-                          context.output(KV.of(context.element().getKey(), i));
-                        }
-                      }
-                    }))
+                        new DoFn<KV<String, Iterable<Long>>, KV<String, Long>>() {
+                          @ProcessElement
+                          public void process(ProcessContext context) {
+                            LOG.info("Side input: {}", context.sideInput(view));
+                            LOG.info("Output element: {}", context.element());
+                            for (Long i : context.element().getValue()) {
+                              context.output(KV.of(context.element().getKey(), i));
+                            }
+                          }
+                        })
+                    .withSideInputs(view))
             // Second GBK forces the output to be materialized
             .apply("gbk", GroupByKey.create());
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunctionTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunctionTest.java
index 8f1bdca..bba1ea4 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunctionTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunctionTest.java
@@ -89,14 +89,14 @@ public class SparkExecutableStageFunctionTest {
 
   @Test(expected = Exception.class)
   public void sdkErrorsSurfaceOnClose() throws Exception {
-    SparkExecutableStageFunction<Integer> function = getFunction(Collections.emptyMap());
+    SparkExecutableStageFunction<Integer, ?> function = getFunction(Collections.emptyMap());
     doThrow(new Exception()).when(remoteBundle).close();
     function.call(Collections.emptyIterator());
   }
 
   @Test
   public void expectedInputsAreSent() throws Exception {
-    SparkExecutableStageFunction<Integer> function = getFunction(Collections.emptyMap());
+    SparkExecutableStageFunction<Integer, ?> function = getFunction(Collections.emptyMap());
     RemoteBundle bundle = Mockito.mock(RemoteBundle.class);
     when(stageBundleFactory.getBundle(any(), any(), any())).thenReturn(bundle);
 
@@ -178,7 +178,7 @@ public class SparkExecutableStageFunctionTest {
         };
     when(jobBundleFactory.forStage(any())).thenReturn(stageBundleFactory);
 
-    SparkExecutableStageFunction<Integer> function = getFunction(outputTagMap);
+    SparkExecutableStageFunction<Integer, ?> function = getFunction(outputTagMap);
     Iterator<RawUnionValue> iterator = function.call(Collections.emptyIterator());
     Iterable<RawUnionValue> iterable = () -> iterator;
 
@@ -190,14 +190,17 @@ public class SparkExecutableStageFunctionTest {
 
   @Test
   public void testStageBundleClosed() throws Exception {
-    SparkExecutableStageFunction<Integer> function = getFunction(Collections.emptyMap());
+    SparkExecutableStageFunction<Integer, ?> function = getFunction(Collections.emptyMap());
     function.call(Collections.emptyIterator());
     verify(stageBundleFactory).getBundle(any(), any(), any());
+    verify(stageBundleFactory).getProcessBundleDescriptor();
     verify(stageBundleFactory).close();
     verifyNoMoreInteractions(stageBundleFactory);
   }
 
-  private <T> SparkExecutableStageFunction<T> getFunction(Map<String, Integer> outputMap) {
-    return new SparkExecutableStageFunction<>(stagePayload, outputMap, jobBundleFactoryCreator);
+  private <InputT, SideInputT> SparkExecutableStageFunction<InputT, SideInputT> getFunction(
+      Map<String, Integer> outputMap) {
+    return new SparkExecutableStageFunction<>(
+        stagePayload, outputMap, jobBundleFactoryCreator, Collections.emptyMap());
   }
 }
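For reference, the side input shape the modified SparkPortableExecutionTest exercises, condensed into a standalone Beam pipeline. Create stands in for the Impulse-plus-ParDo construction used in the test, and runner selection is left to the options; names are illustrative:

    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.options.PipelineOptionsFactory;
    import org.apache.beam.sdk.transforms.Create;
    import org.apache.beam.sdk.transforms.DoFn;
    import org.apache.beam.sdk.transforms.ParDo;
    import org.apache.beam.sdk.transforms.View;
    import org.apache.beam.sdk.values.PCollectionView;

    public class SideInputPipelineSketch {
      public static void main(String[] args) {
        Pipeline p = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());

        // One branch produces a singleton view...
        PCollectionView<Integer> view =
            p.apply("side", Create.of(23)).apply(View.asSingleton());

        // ...another branch consumes it via withSideInputs.
        p.apply("main", Create.of(1L, 2L))
            .apply(
                "consume",
                ParDo.of(
                        new DoFn<Long, Long>() {
                          @ProcessElement
                          public void process(ProcessContext context) {
                            // Reads the materialized singleton for every element.
                            context.output(context.element() + context.sideInput(view));
                          }
                        })
                    .withSideInputs(view));

        p.run().waitUntilFinish();
      }
    }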