This is an automated email from the ASF dual-hosted git repository.

echauchot pushed a commit to branch spark-runner_structured-streaming
in repository https://gitbox.apache.org/repos/asf/beam.git

commit 4602f8367d9d6c38574010fdff8b044cedf4d286
Author: Etienne Chauchot <echauc...@apache.org>
AuthorDate: Fri Jun 14 14:50:45 2019 +0200

    Output data after combine
---
 .../translation/batch/AggregatorCombinerGlobally.java      | 14 ++++++++------
 .../translation/batch/CombineGloballyTranslatorBatch.java  |  7 +++++--
 2 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorCombinerGlobally.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorCombinerGlobally.java
index 0d13218..6996165 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorCombinerGlobally.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorCombinerGlobally.java
@@ -44,7 +44,7 @@ import scala.Tuple2;
  * */
 
 class AggregatorCombinerGlobally<InputT, AccumT, OutputT, W extends BoundedWindow>
-    extends Aggregator<WindowedValue<InputT>, Iterable<WindowedValue<AccumT>>, WindowedValue<OutputT>> {
+    extends Aggregator<WindowedValue<InputT>, Iterable<WindowedValue<AccumT>>, Iterable<WindowedValue<OutputT>>> {
 
   private final Combine.CombineFn<InputT, AccumT, OutputT> combineFn;
   private WindowingStrategy<InputT, W> windowingStrategy;
@@ -123,10 +123,12 @@ class AggregatorCombinerGlobally<InputT, AccumT, OutputT, W extends BoundedWindo
     return null;
   }
 
-  @Override public WindowedValue<OutputT> finish(Iterable<WindowedValue<AccumT>> reduction) {
-    // TODO
-    //    return combineFn.extractOutput(reduction);
-    return null;
+  @Override public Iterable<WindowedValue<OutputT>> finish(Iterable<WindowedValue<AccumT>> reduction) {
+    List<WindowedValue<OutputT>> result = new ArrayList<>();
+    for (WindowedValue<AccumT> windowedValue: reduction) {
+      result.add(windowedValue.withValue(combineFn.extractOutput(windowedValue.getValue())));
+    }
+    return result;
   }
 
   @Override public Encoder<Iterable<WindowedValue<AccumT>>> bufferEncoder() {
@@ -134,7 +136,7 @@ class AggregatorCombinerGlobally<InputT, AccumT, OutputT, W extends BoundedWindo
     return EncoderHelpers.genericEncoder();
   }
 
-  @Override public Encoder<WindowedValue<OutputT>> outputEncoder() {
+  @Override public Encoder<Iterable<WindowedValue<OutputT>>> outputEncoder() {
     // TODO replace with outputCoder if possible
     return EncoderHelpers.genericEncoder();
   }
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTranslatorBatch.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTranslatorBatch.java
index f18572b..fb9e1dd 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTranslatorBatch.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTranslatorBatch.java
@@ -21,12 +21,12 @@ import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTr
 import org.apache.beam.runners.spark.structuredstreaming.translation.TranslationContext;
 import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
 import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.RowHelpers;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.WindowingHelpers;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.WindowingStrategy;
+import org.apache.spark.api.java.function.FlatMapFunction;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 
@@ -57,9 +57,12 @@ class CombineGloballyTranslatorBatch<InputT, AccumT, OutputT>
     Dataset<Row> combinedRowDataset =
         inputDataset.agg(new AggregatorCombinerGlobally<>(combineFn, windowingStrategy).toColumn());
 
-    Dataset<WindowedValue<OutputT>> outputDataset =
+    Dataset<Iterable<WindowedValue<OutputT>>> accumulatedDataset =
         combinedRowDataset.map(
             RowHelpers.extractObjectFromRowMapFunction(), EncoderHelpers.windowedValueEncoder());
+    Dataset<WindowedValue<OutputT>> outputDataset = accumulatedDataset.flatMap(
+        (FlatMapFunction<Iterable<WindowedValue<OutputT>>, WindowedValue<OutputT>>)
+            windowedValues -> windowedValues.iterator(), EncoderHelpers.windowedValueEncoder());
     context.putDataset(output, outputDataset);
   }
 }

Reply via email to