This is an automated email from the ASF dual-hosted git repository.

echauchot pushed a commit to branch spark-runner_structured-streaming
in repository https://gitbox.apache.org/repos/asf/beam.git

commit fb3aa34b55bdb21afc030ce618a27dcf2623f40e
Author: Etienne Chauchot <echauc...@apache.org>
AuthorDate: Fri Oct 11 16:18:11 2019 +0200

    Add missing WindowedValue encoder call in ParDo
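    
    In the single-output branch and in pruneOutputFilteredByTag, the output
    Dataset was still created with the generic
    EncoderHelpers.windowedValueEncoder(). Build a WindowedValue coder from the
    output PCollection coder and the window coder of the input
    WindowingStrategy, and wrap it with EncoderHelpers.fromBeamCoder, as is
    already done for MultiOuputCoder in the multi-output path.
    
    A minimal sketch of the pattern (illustration only; identifiers are the
    ones used in this translator, exact formatting may differ from the diff
    below):
    
        // window coder taken from the main input's windowing strategy
        Coder<? extends BoundedWindow> windowCoder =
            windowingStrategy.getWindowFn().windowCoder();
        // full coder for the windowed output values
        Coder<WindowedValue<?>> windowedValueCoder =
            (Coder<WindowedValue<?>>)
                (Coder<?>) WindowedValue.getFullCoder(outputCoder, windowCoder);
        // derive the Spark Encoder from the Beam coder instead of the generic encoder
        Dataset<WindowedValue<?>> outputDataset =
            allOutputs.map(
                (MapFunction<Tuple2<TupleTag<?>, WindowedValue<?>>, WindowedValue<?>>)
                    value -> value._2,
                EncoderHelpers.fromBeamCoder(windowedValueCoder));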
---
 .../translation/batch/ParDoTranslatorBatch.java    | 37 ++++++++++++----------
 1 file changed, 21 insertions(+), 16 deletions(-)

diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
index e73d38e..f5a109e 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
@@ -93,6 +93,8 @@ class ParDoTranslatorBatch<InputT, OutputT>
     List<TupleTag<?>> outputTags = new ArrayList<>(outputs.keySet());
     WindowingStrategy<?, ?> windowingStrategy =
         ((PCollection<InputT>) input).getWindowingStrategy();
+    Coder<InputT> inputCoder = ((PCollection<InputT>) input).getCoder();
+    Coder<? extends BoundedWindow> windowCoder = windowingStrategy.getWindowFn().windowCoder();
 
     // construct a map from side input to WindowingStrategy so that
     // the DoFn runner can map main-input windows to side input windows
@@ -105,7 +107,6 @@ class ParDoTranslatorBatch<InputT, OutputT>
     SideInputBroadcast broadcastStateData = createBroadcastSideInputs(sideInputs, context);
 
     Map<TupleTag<?>, Coder<?>> outputCoderMap = context.getOutputCoders();
-    Coder<InputT> inputCoder = ((PCollection<InputT>) input).getCoder();
     MetricsContainerStepMapAccumulator metricsAccum = MetricsAccumulator.getInstance();
 
     List<TupleTag<?>> additionalOutputTags = new ArrayList<>();
@@ -131,24 +132,25 @@ class ParDoTranslatorBatch<InputT, OutputT>
             broadcastStateData,
             doFnSchemaInformation);
 
-    MultiOuputCoder multipleOutputCoder =
-        MultiOuputCoder.of(
-            SerializableCoder.of(TupleTag.class),
-            outputCoderMap,
-            windowingStrategy.getWindowFn().windowCoder());
+    MultiOuputCoder multipleOutputCoder = MultiOuputCoder
+        .of(SerializableCoder.of(TupleTag.class), outputCoderMap, windowCoder);
     Dataset<Tuple2<TupleTag<?>, WindowedValue<?>>> allOutputs =
         inputDataSet.mapPartitions(doFnWrapper, EncoderHelpers.fromBeamCoder(multipleOutputCoder));
     if (outputs.entrySet().size() > 1) {
       allOutputs.persist();
       for (Map.Entry<TupleTag<?>, PValue> output : outputs.entrySet()) {
-        pruneOutputFilteredByTag(context, allOutputs, output);
+        pruneOutputFilteredByTag(context, allOutputs, output, windowCoder);
       }
     } else {
+      Coder<OutputT> outputCoder = ((PCollection<OutputT>) outputs.get(mainOutputTag)).getCoder();
+      Coder<WindowedValue<?>> windowedValueCoder =
+          (Coder<WindowedValue<?>>)
+              (Coder<?>) WindowedValue.getFullCoder(outputCoder, windowCoder);
       Dataset<WindowedValue<?>> outputDataset =
           allOutputs.map(
               (MapFunction<Tuple2<TupleTag<?>, WindowedValue<?>>, WindowedValue<?>>)
                   value -> value._2,
-              EncoderHelpers.windowedValueEncoder());
+              EncoderHelpers.fromBeamCoder(windowedValueCoder));
       context.putDatasetWildcard(outputs.entrySet().iterator().next().getValue(), outputDataset);
     }
   }
@@ -159,14 +161,14 @@ class ParDoTranslatorBatch<InputT, OutputT>
         JavaSparkContext.fromSparkContext(context.getSparkSession().sparkContext());
 
     SideInputBroadcast sideInputBroadcast = new SideInputBroadcast();
-    for (PCollectionView<?> input : sideInputs) {
+    for (PCollectionView<?> sideInput : sideInputs) {
       Coder<? extends BoundedWindow> windowCoder =
-          input.getPCollection().getWindowingStrategy().getWindowFn().windowCoder();
+          sideInput.getPCollection().getWindowingStrategy().getWindowFn().windowCoder();
+
       Coder<WindowedValue<?>> windowedValueCoder =
           (Coder<WindowedValue<?>>)
-              (Coder<?>) WindowedValue.getFullCoder(input.getPCollection().getCoder(), windowCoder);
-
-      Dataset<WindowedValue<?>> broadcastSet = context.getSideInputDataSet(input);
+              (Coder<?>) WindowedValue.getFullCoder(sideInput.getPCollection().getCoder(), windowCoder);
+      Dataset<WindowedValue<?>> broadcastSet = context.getSideInputDataSet(sideInput);
       List<WindowedValue<?>> valuesList = broadcastSet.collectAsList();
       List<byte[]> codedValues = new ArrayList<>();
       for (WindowedValue<?> v : valuesList) {
@@ -174,7 +176,7 @@ class ParDoTranslatorBatch<InputT, OutputT>
       }
 
       sideInputBroadcast.add(
-          input.getTagInternal().getId(), jsc.broadcast(codedValues), windowedValueCoder);
+          sideInput.getTagInternal().getId(), jsc.broadcast(codedValues), windowedValueCoder);
     }
     return sideInputBroadcast;
   }
@@ -213,14 +215,17 @@ class ParDoTranslatorBatch<InputT, OutputT>
   private void pruneOutputFilteredByTag(
       TranslationContext context,
       Dataset<Tuple2<TupleTag<?>, WindowedValue<?>>> allOutputs,
-      Map.Entry<TupleTag<?>, PValue> output) {
+      Map.Entry<TupleTag<?>, PValue> output, Coder<? extends BoundedWindow> windowCoder) {
     Dataset<Tuple2<TupleTag<?>, WindowedValue<?>>> filteredDataset =
         allOutputs.filter(new DoFnFilterFunction(output.getKey()));
+    Coder<WindowedValue<?>> windowedValueCoder =
+        (Coder<WindowedValue<?>>)
+            (Coder<?>) WindowedValue.getFullCoder(((PCollection<OutputT>)output.getValue()).getCoder(), windowCoder);
     Dataset<WindowedValue<?>> outputDataset =
         filteredDataset.map(
             (MapFunction<Tuple2<TupleTag<?>, WindowedValue<?>>, WindowedValue<?>>)
                 value -> value._2,
-            EncoderHelpers.windowedValueEncoder());
+            EncoderHelpers.fromBeamCoder(windowedValueCoder));
     context.putDatasetWildcard(output.getValue(), outputDataset);
   }
 
