This is an automated email from the ASF dual-hosted git repository.

boyuanz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new bc955de Lengthprefix any input coder for an ProcessBundleDescriptor. new 5419c3b Merge pull request #13120 from [BEAM-10940] Lengthprefix any input coder for an ProcessBundleDescriptor. bc955de is described below commit bc955ded10e0a054d437adf5c7117004de978d46 Author: Boyuan Zhang <boyu...@google.com> AuthorDate: Wed Oct 14 13:47:55 2020 -0700 Lengthprefix any input coder for an ProcessBundleDescriptor. --- .../control/ProcessBundleDescriptors.java | 36 +++----- .../control/ProcessBundleDescriptorsTest.java | 101 +++++++++++++++++++++ 2 files changed, 113 insertions(+), 24 deletions(-) diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java index e76c130..ac3b882 100644 --- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java @@ -37,7 +37,6 @@ import org.apache.beam.model.pipeline.v1.RunnerApi.Components; import org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.WireCoderSetting; import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection; import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform; -import org.apache.beam.runners.core.construction.ModelCoders; import org.apache.beam.runners.core.construction.RehydratedComponents; import org.apache.beam.runners.core.construction.Timer; import org.apache.beam.runners.core.construction.graph.ExecutableStage; @@ -59,7 +58,6 @@ import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder; import org.apache.beam.sdk.values.KV; import 
org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.InvalidProtocolBufferException; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableTable; @@ -141,9 +139,7 @@ public class ProcessBundleDescriptors { Map<String, Map<String, TimerSpec>> timerSpecs = forTimerSpecs(stage, components); - if (bagUserStateSpecs.size() > 0 || timerSpecs.size() > 0) { - lengthPrefixKeyCoder(stage.getInputPCollection().getId(), components); - } + lengthPrefixAnyInputCoder(stage.getInputPCollection().getId(), components); // Copy data from components to ProcessBundleDescriptor. ProcessBundleDescriptor.Builder bundleDescriptorBuilder = @@ -174,26 +170,18 @@ public class ProcessBundleDescriptors { } /** - * Patches the input coder of a stateful transform to ensure that the byte representation of a key - * used to partition the input element at the Runner, matches the key byte representation received - * for state requests and timers from the SDK Harness. Stateful transforms always have a KvCoder - * as input. + * Patches the input coder of the transform to ensure that the byte representation of input used + * at the Runner, matches the byte representation received from the SDK Harness. 
*/ - private static void lengthPrefixKeyCoder( - String inputColId, Components.Builder componentsBuilder) { - RunnerApi.PCollection pcollection = componentsBuilder.getPcollectionsOrThrow(inputColId); - RunnerApi.Coder kvCoder = componentsBuilder.getCodersOrThrow(pcollection.getCoderId()); - Preconditions.checkState( - ModelCoders.KV_CODER_URN.equals(kvCoder.getSpec().getUrn()), - "Stateful executable stages must use a KV coder, but is: %s", - kvCoder.getSpec().getUrn()); - String keyCoderId = ModelCoders.getKvCoderComponents(kvCoder).keyCoderId(); - // Retain the original coder, but wrap in LengthPrefixCoder - String newKeyCoderId = - LengthPrefixUnknownCoders.addLengthPrefixedCoder(keyCoderId, componentsBuilder, false); - // Replace old key coder with LengthPrefixCoder<old_key_coder> - kvCoder = kvCoder.toBuilder().setComponentCoderIds(0, newKeyCoderId).build(); - componentsBuilder.putCoders(pcollection.getCoderId(), kvCoder); + private static void lengthPrefixAnyInputCoder( + String inputPCollectionId, Components.Builder componentsBuilder) { + RunnerApi.PCollection pcollection = + componentsBuilder.getPcollectionsOrThrow(inputPCollectionId); + String newInputCoderId = + LengthPrefixUnknownCoders.addLengthPrefixedCoder( + pcollection.getCoderId(), componentsBuilder, false); + componentsBuilder.putPcollections( + inputPCollectionId, pcollection.toBuilder().setCoderId(newInputCoderId).build()); } private static Map<String, Coder<WindowedValue<?>>> addStageOutputs( diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptorsTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptorsTest.java index 98fe899..9337c63 100644 --- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptorsTest.java +++ 
b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptorsTest.java @@ -29,11 +29,14 @@ import org.apache.beam.model.pipeline.v1.RunnerApi; import org.apache.beam.runners.core.construction.CoderTranslation; import org.apache.beam.runners.core.construction.ModelCoderRegistrar; import org.apache.beam.runners.core.construction.ModelCoders; +import org.apache.beam.runners.core.construction.PTransformTranslation; import org.apache.beam.runners.core.construction.PipelineTranslation; import org.apache.beam.runners.core.construction.graph.ExecutableStage; import org.apache.beam.runners.core.construction.graph.FusedPipeline; import org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser; import org.apache.beam.runners.core.construction.graph.PipelineNode; +import org.apache.beam.runners.core.construction.graph.ProtoOverrides; +import org.apache.beam.runners.core.construction.graph.SplittableParDoExpander; import org.apache.beam.runners.core.construction.graph.TimerReference; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.Coder; @@ -48,9 +51,12 @@ import org.apache.beam.sdk.state.Timer; import org.apache.beam.sdk.state.TimerSpec; import org.apache.beam.sdk.state.TimerSpecs; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.ProcessContext; +import org.apache.beam.sdk.transforms.DoFn.ProcessElement; import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.Impulse; import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import org.apache.beam.sdk.values.KV; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Optional; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; @@ -151,6 +157,99 @@ public class ProcessBundleDescriptorsTest implements Serializable { ensureLengthPrefixed(timerKeyCoder, originalKeyCoder, 
pbsCoderMap); } + @Test + public void testLengthPrefixingOfInputCoderExecutableStage() throws Exception { + Pipeline p = Pipeline.create(); + Coder<Void> voidCoder = VoidCoder.of(); + assertThat(ModelCoderRegistrar.isKnownCoder(voidCoder), is(false)); + p.apply("impulse", Impulse.create()) + .apply( + ParDo.of( + new DoFn<byte[], Void>() { + @ProcessElement + public void process(ProcessContext ctxt) {} + })) + .setCoder(voidCoder) + .apply( + ParDo.of( + new DoFn<Void, Void>() { + @ProcessElement + public void processElement( + ProcessContext context, RestrictionTracker<Void, Void> tracker) {} + + @GetInitialRestriction + public Void getInitialRestriction() { + return null; + } + + @NewTracker + public SomeTracker newTracker(@Restriction Void restriction) { + return null; + } + })) + .setCoder(voidCoder); + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + RunnerApi.Pipeline pipelineWithSdfExpanded = + ProtoOverrides.updateTransform( + PTransformTranslation.PAR_DO_TRANSFORM_URN, + pipelineProto, + SplittableParDoExpander.createSizedReplacement()); + FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineWithSdfExpanded); + Optional<ExecutableStage> optionalStage = + Iterables.tryFind( + fused.getFusedStages(), + (ExecutableStage stage) -> + stage.getTransforms().stream() + .anyMatch( + transform -> + transform + .getTransform() + .getSpec() + .getUrn() + .equals( + PTransformTranslation + .SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN))); + checkState( + optionalStage.isPresent(), + "Expected a stage with SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN."); + + ExecutableStage stage = optionalStage.get(); + PipelineNode.PCollectionNode inputPCollection = stage.getInputPCollection(); + Map<String, RunnerApi.Coder> stageCoderMap = stage.getComponents().getCodersMap(); + RunnerApi.Coder originalMainInputCoder = + stageCoderMap.get(inputPCollection.getPCollection().getCoderId()); + + BeamFnApi.ProcessBundleDescriptor pbd = + 
ProcessBundleDescriptors.fromExecutableStage( + "test_stage", stage, Endpoints.ApiServiceDescriptor.getDefaultInstance()) + .getProcessBundleDescriptor(); + Map<String, RunnerApi.Coder> pbsCoderMap = pbd.getCodersMap(); + + RunnerApi.Coder pbsMainInputCoder = + pbsCoderMap.get(pbd.getPcollectionsOrThrow(inputPCollection.getId()).getCoderId()); + + RunnerApi.Coder kvCoder = + pbsCoderMap.get(ModelCoders.getKvCoderComponents(pbsMainInputCoder).keyCoderId()); + RunnerApi.Coder keyCoder = + pbsCoderMap.get(ModelCoders.getKvCoderComponents(kvCoder).keyCoderId()); + RunnerApi.Coder valueKvCoder = + pbsCoderMap.get(ModelCoders.getKvCoderComponents(kvCoder).valueCoderId()); + RunnerApi.Coder valueCoder = + pbsCoderMap.get(ModelCoders.getKvCoderComponents(valueKvCoder).keyCoderId()); + + RunnerApi.Coder originalKvCoder = + stageCoderMap.get(ModelCoders.getKvCoderComponents(originalMainInputCoder).keyCoderId()); + RunnerApi.Coder originalKeyCoder = + stageCoderMap.get(ModelCoders.getKvCoderComponents(originalKvCoder).keyCoderId()); + RunnerApi.Coder originalvalueKvCoder = + stageCoderMap.get(ModelCoders.getKvCoderComponents(originalKvCoder).valueCoderId()); + RunnerApi.Coder originalvalueCoder = + stageCoderMap.get(ModelCoders.getKvCoderComponents(originalvalueKvCoder).keyCoderId()); + + ensureLengthPrefixed(keyCoder, originalKeyCoder, pbsCoderMap); + ensureLengthPrefixed(valueCoder, originalvalueCoder, pbsCoderMap); + } + private static void ensureLengthPrefixed( RunnerApi.Coder coder, RunnerApi.Coder originalCoder, @@ -160,4 +259,6 @@ public class ProcessBundleDescriptorsTest implements Serializable { String lengthPrefixedWrappedCoderId = coder.getComponentCoderIds(0); assertThat(pbsCoderMap.get(lengthPrefixedWrappedCoderId), is(originalCoder)); } + + private abstract static class SomeTracker extends RestrictionTracker<Void, Void> {} }