Port direct runner StatefulParDo to KeyedWorkItem
Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/1f018ab6 Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/1f018ab6 Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/1f018ab6 Branch: refs/heads/python-sdk Commit: 1f018ab69fdcc720a10e2aeb8ec1eea1c06e1cbc Parents: d040b7f Author: Kenneth Knowles <k...@google.com> Authored: Mon Dec 12 19:49:58 2016 -0800 Committer: Kenneth Knowles <k...@google.com> Committed: Tue Dec 20 11:19:07 2016 -0800 ---------------------------------------------------------------------- .../direct/KeyedPValueTrackingVisitor.java | 13 ++- .../direct/ParDoMultiOverrideFactory.java | 94 +++++++++++++++++--- .../direct/StatefulParDoEvaluatorFactory.java | 36 ++++---- .../direct/KeyedPValueTrackingVisitorTest.java | 69 ++++++++++++-- .../StatefulParDoEvaluatorFactoryTest.java | 51 +++++++---- 5 files changed, 205 insertions(+), 58 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1f018ab6/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java index e91a768..65c41e0 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java @@ -31,6 +31,7 @@ import org.apache.beam.sdk.Pipeline.PipelineVisitor; import org.apache.beam.sdk.runners.TransformHierarchy; import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.PValue; /** @@ -105,7 +106,15 @@ class KeyedPValueTrackingVisitor implements PipelineVisitor { } private static boolean isKeyPreserving(PTransform<?, ?> transform) { - // There are currently no key-preserving transforms; this lays the infrastructure for them - return false; + // This is a hacky check for what is considered key-preserving to the direct runner. + // The most obvious alternative would be a package-private marker interface, but + // better to make this obviously hacky so it is less likely to proliferate. Meanwhile + // we intend to allow explicit expression of key-preserving DoFn in the model. + if (transform instanceof ParDo.BoundMulti) { + ParDo.BoundMulti<?, ?> parDo = (ParDo.BoundMulti<?, ?>) transform; + return parDo.getFn() instanceof ParDoMultiOverrideFactory.ToKeyedWorkItem; + } else { + return false; + } } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1f018ab6/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java index c5bc069..2cea999 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java @@ -17,9 +17,15 @@ */ package org.apache.beam.runners.direct; +import static com.google.common.base.Preconditions.checkState; + +import org.apache.beam.runners.core.KeyedWorkItem; +import org.apache.beam.runners.core.KeyedWorkItemCoder; +import org.apache.beam.runners.core.KeyedWorkItems; import org.apache.beam.runners.core.SplittableParDo; import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.runners.PTransformOverrideFactory; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.GroupByKey; @@ -28,6 +34,8 @@ import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.ParDo.BoundMulti; import org.apache.beam.sdk.transforms.reflect.DoFnSignature; import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; @@ -84,16 +92,41 @@ class ParDoMultiOverrideFactory<InputT, OutputT> @Override public PCollectionTuple expand(PCollection<KV<K, InputT>> input) { - PCollectionTuple outputs = input - .apply("Group by key", GroupByKey.<K, InputT>create()) - .apply("Stateful ParDo", new StatefulParDo<>(underlyingParDo, input)); + // A KvCoder is required since this goes through GBK. Further, WindowedValueCoder + // is not registered by default, so we explicitly set the relevant coders. + checkState(input.getCoder() instanceof KvCoder, + "Input to a %s using state requires a %s, but the coder was %s", + ParDo.class.getSimpleName(), + KvCoder.class.getSimpleName(), + input.getCoder()); + KvCoder<K, InputT> kvCoder = (KvCoder<K, InputT>) input.getCoder(); + Coder<K> keyCoder = kvCoder.getKeyCoder(); + Coder<? extends BoundedWindow> windowCoder = + input.getWindowingStrategy().getWindowFn().windowCoder(); + + PCollectionTuple outputs = + input + // Stash the original timestamps, etc, for when it is fed to the user's DoFn + .apply("Reify timestamps", ParDo.of(new ReifyWindowedValueFn<K, InputT>())) + .setCoder(KvCoder.of(keyCoder, WindowedValue.getFullCoder(kvCoder, windowCoder))) + + // A full GBK to group by key _and_ window + .apply("Group by key", GroupByKey.<K, WindowedValue<KV<K, InputT>>>create()) + + // Adapt to KeyedWorkItem; that is how this runner delivers timers + .apply("To KeyedWorkItem", ParDo.of(new ToKeyedWorkItem<K, InputT>())) + .setCoder(KeyedWorkItemCoder.of(keyCoder, kvCoder, windowCoder)) + + // Explode the resulting iterable into elements that are exactly the ones from + // the input + .apply("Stateful ParDo", new StatefulParDo<>(underlyingParDo, input)); return outputs; } } static class StatefulParDo<K, InputT, OutputT> - extends PTransform<PCollection<? extends KV<K, Iterable<InputT>>>, PCollectionTuple> { + extends PTransform<PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>>, PCollectionTuple> { private final transient ParDo.BoundMulti<KV<K, InputT>, OutputT> underlyingParDo; private final transient PCollection<KV<K, InputT>> originalInput; @@ -110,21 +143,58 @@ class ParDoMultiOverrideFactory<InputT, OutputT> @Override public <T> Coder<T> getDefaultOutputCoder( - PCollection<? extends KV<K, Iterable<InputT>>> input, TypedPValue<T> output) + PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>> input, TypedPValue<T> output) throws CannotProvideCoderException { return underlyingParDo.getDefaultOutputCoder(originalInput, output); } - public PCollectionTuple expand(PCollection<? extends KV<K, Iterable<InputT>>> input) { + @Override + public PCollectionTuple expand(PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>> input) { - PCollectionTuple outputs = PCollectionTuple.ofPrimitiveOutputsInternal( - input.getPipeline(), - TupleTagList.of(underlyingParDo.getMainOutputTag()) - .and(underlyingParDo.getSideOutputTags().getAll()), - input.getWindowingStrategy(), - input.isBounded()); + PCollectionTuple outputs = + PCollectionTuple.ofPrimitiveOutputsInternal( + input.getPipeline(), + TupleTagList.of(underlyingParDo.getMainOutputTag()) + .and(underlyingParDo.getSideOutputTags().getAll()), + input.getWindowingStrategy(), + input.isBounded()); return outputs; } } + + /** + * A distinguished key-preserving {@link DoFn}. + * + * <p>This wraps the {@link GroupByKey} output in a {@link KeyedWorkItem} to be able to deliver + * timers. It also explodes them into single {@link KV KVs} since this is what the user's {@link + * DoFn} needs to process anyhow. + */ + static class ReifyWindowedValueFn<K, V> extends DoFn<KV<K, V>, KV<K, WindowedValue<KV<K, V>>>> { + @ProcessElement + public void processElement(final ProcessContext c, final BoundedWindow window) { + c.output( + KV.of( + c.element().getKey(), + WindowedValue.of(c.element(), c.timestamp(), window, c.pane()))); + } + } + + /** + * A runner-specific primitive that is just a key-preserving {@link ParDo}, but we do not have the + * machinery to detect or enforce that yet. + * + * <p>This wraps the {@link GroupByKey} output in a {@link KeyedWorkItem} to be able to deliver + * timers. It also explodes them into single {@link KV KVs} since this is what the user's {@link + * DoFn} needs to process anyhow. + */ + static class ToKeyedWorkItem<K, V> + extends DoFn<KV<K, Iterable<WindowedValue<KV<K, V>>>>, KeyedWorkItem<K, KV<K, V>>> { + + @ProcessElement + public void processElement(final ProcessContext c, final BoundedWindow window) { + final K key = c.element().getKey(); + c.output(KeyedWorkItems.elementsWorkItem(key, c.element().getValue())); + } + } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1f018ab6/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java index 1f64d9a..5f9d8f4 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java @@ -23,6 +23,8 @@ import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; import com.google.common.collect.Lists; import java.util.Collections; +import org.apache.beam.runners.core.KeyedWorkItem; +import org.apache.beam.runners.core.KeyedWorkItems; import org.apache.beam.runners.direct.DirectExecutionContext.DirectStepContext; import org.apache.beam.runners.direct.DirectRunner.CommittedBundle; import org.apache.beam.runners.direct.ParDoMultiOverrideFactory.StatefulParDo; @@ -77,12 +79,12 @@ final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements Transfo } @SuppressWarnings({"unchecked", "rawtypes"}) - private TransformEvaluator<KV<K, Iterable<InputT>>> createEvaluator( + private TransformEvaluator<KeyedWorkItem<K, KV<K, InputT>>> createEvaluator( AppliedPTransform< - PCollection<? extends KV<K, Iterable<InputT>>>, PCollectionTuple, + PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>>, PCollectionTuple, StatefulParDo<K, InputT, OutputT>> application, - CommittedBundle<KV<K, Iterable<InputT>>> inputBundle) + CommittedBundle<KeyedWorkItem<K, KV<K, InputT>>> inputBundle) throws Exception { final DoFn<KV<K, InputT>, OutputT> doFn = @@ -185,7 +187,7 @@ final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements Transfo @AutoValue abstract static class AppliedPTransformOutputKeyAndWindow<K, InputT, OutputT> { abstract AppliedPTransform< - PCollection<? extends KV<K, Iterable<InputT>>>, PCollectionTuple, + PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>>, PCollectionTuple, StatefulParDo<K, InputT, OutputT>> getTransform(); @@ -195,7 +197,7 @@ final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements Transfo static <K, InputT, OutputT> AppliedPTransformOutputKeyAndWindow<K, InputT, OutputT> create( AppliedPTransform< - PCollection<? extends KV<K, Iterable<InputT>>>, PCollectionTuple, + PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>>, PCollectionTuple, StatefulParDo<K, InputT, OutputT>> transform, StructuralKey<K> key, @@ -206,7 +208,7 @@ final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements Transfo } private static class StatefulParDoEvaluator<K, InputT> - implements TransformEvaluator<KV<K, Iterable<InputT>>> { + implements TransformEvaluator<KeyedWorkItem<K, KV<K, InputT>>> { private final TransformEvaluator<KV<K, InputT>> delegateEvaluator; @@ -215,20 +217,20 @@ final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements Transfo } @Override - public void processElement(WindowedValue<KV<K, Iterable<InputT>>> gbkResult) throws Exception { + public void processElement(WindowedValue<KeyedWorkItem<K, KV<K, InputT>>> gbkResult) + throws Exception { - for (InputT value : gbkResult.getValue().getValue()) { - delegateEvaluator.processElement( - gbkResult.withValue(KV.of(gbkResult.getValue().getKey(), value))); + for (WindowedValue<KV<K, InputT>> windowedValue : gbkResult.getValue().elementsIterable()) { + delegateEvaluator.processElement(windowedValue); } } @Override - public TransformResult<KV<K, Iterable<InputT>>> finishBundle() throws Exception { + public TransformResult<KeyedWorkItem<K, KV<K, InputT>>> finishBundle() throws Exception { TransformResult<KV<K, InputT>> delegateResult = delegateEvaluator.finishBundle(); - StepTransformResult.Builder<KV<K, Iterable<InputT>>> regroupedResult = - StepTransformResult.<KV<K, Iterable<InputT>>>withHold( + StepTransformResult.Builder<KeyedWorkItem<K, KV<K, InputT>>> regroupedResult = + StepTransformResult.<KeyedWorkItem<K, KV<K, InputT>>>withHold( delegateResult.getTransform(), delegateResult.getWatermarkHold()) .withTimerUpdate(delegateResult.getTimerUpdate()) .withAggregatorChanges(delegateResult.getAggregatorChanges()) @@ -240,12 +242,10 @@ final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements Transfo // outputs, but just make a bunch of singletons for (WindowedValue<?> untypedUnprocessed : delegateResult.getUnprocessedElements()) { WindowedValue<KV<K, InputT>> windowedKv = (WindowedValue<KV<K, InputT>>) untypedUnprocessed; - WindowedValue<KV<K, Iterable<InputT>>> pushedBack = + WindowedValue<KeyedWorkItem<K, KV<K, InputT>>> pushedBack = windowedKv.withValue( - KV.of( - windowedKv.getValue().getKey(), - (Iterable<InputT>) - Collections.singletonList(windowedKv.getValue().getValue()))); + KeyedWorkItems.elementsWorkItem( + windowedKv.getValue().getKey(), Collections.singleton(windowedKv))); regroupedResult.addUnprocessedElements(pushedBack); } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1f018ab6/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java index a357005..a1fb81b 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java @@ -22,8 +22,10 @@ import static org.hamcrest.Matchers.not; import static org.junit.Assert.assertThat; import java.util.Collections; +import org.apache.beam.runners.core.KeyedWorkItem; import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.coders.VoidCoder; import org.apache.beam.sdk.testing.TestPipeline; @@ -32,8 +34,12 @@ import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.Keys; import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; +import org.joda.time.Instant; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -41,9 +47,7 @@ import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -/** - * Tests for {@link KeyedPValueTrackingVisitor}. - */ +/** Tests for {@link KeyedPValueTrackingVisitor}. */ @RunWith(JUnit4.class) public class KeyedPValueTrackingVisitorTest { @Rule public ExpectedException thrown = ExpectedException.none(); @@ -61,8 +65,7 @@ public class KeyedPValueTrackingVisitorTest { @Test public void groupByKeyProducesKeyedOutput() { PCollection<KV<String, Iterable<Integer>>> keyed = - p.apply(Create.of(KV.of("foo", 3))) - .apply(GroupByKey.<String, Integer>create()); + p.apply(Create.of(KV.of("foo", 3))).apply(GroupByKey.<String, Integer>create()); p.traverseTopologically(visitor); assertThat(visitor.getKeyedPValues(), hasItem(keyed)); @@ -91,16 +94,66 @@ public class KeyedPValueTrackingVisitorTest { } @Test + public void unkeyedInputWithKeyPreserving() { + + PCollection<KV<String, Iterable<WindowedValue<KV<String, Integer>>>>> input = + p.apply( + Create.of( + KV.of( + "hello", + (Iterable<WindowedValue<KV<String, Integer>>>) + Collections.<WindowedValue<KV<String, Integer>>>emptyList())) + .withCoder( + KvCoder.of( + StringUtf8Coder.of(), + IterableCoder.of( + WindowedValue.getValueOnlyCoder( + KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of())))))); + + PCollection<KeyedWorkItem<String, KV<String, Integer>>> unkeyed = + input.apply(ParDo.of(new ParDoMultiOverrideFactory.ToKeyedWorkItem<String, Integer>())); + + p.traverseTopologically(visitor); + assertThat(visitor.getKeyedPValues(), not(hasItem(unkeyed))); + } + + @Test + public void keyedInputWithKeyPreserving() { + + PCollection<KV<String, WindowedValue<KV<String, Integer>>>> input = + p.apply( + Create.of( + KV.of( + "hello", + WindowedValue.of( + KV.of("hello", 3), + new Instant(0), + new IntervalWindow(new Instant(0), new Instant(9)), + PaneInfo.NO_FIRING))) + .withCoder( + KvCoder.of( + StringUtf8Coder.of(), + WindowedValue.getValueOnlyCoder( + KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()))))); + + PCollection<KeyedWorkItem<String, KV<String, Integer>>> keyed = + input + .apply(GroupByKey.<String, WindowedValue<KV<String, Integer>>>create()) + .apply(ParDo.of(new ParDoMultiOverrideFactory.ToKeyedWorkItem<String, Integer>())); + + p.traverseTopologically(visitor); + assertThat(visitor.getKeyedPValues(), hasItem(keyed)); + } + + @Test public void traverseMultipleTimesThrows() { p.apply( - Create.<KV<Integer, Void>>of( - KV.of(1, (Void) null), KV.of(2, (Void) null), KV.of(3, (Void) null)) + Create.of(KV.of(1, (Void) null), KV.of(2, (Void) null), KV.of(3, (Void) null)) .withCoder(KvCoder.of(VarIntCoder.of(), VoidCoder.of()))) .apply(GroupByKey.<Integer, Void>create()) .apply(Keys.<Integer>create()); p.traverseTopologically(visitor); - thrown.expect(IllegalStateException.class); thrown.expectMessage("already been finalized"); thrown.expectMessage(KeyedPValueTrackingVisitor.class.getSimpleName()); http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1f018ab6/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java index d312aa3..b88d5e0 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java @@ -27,12 +27,14 @@ import static org.mockito.Matchers.eq; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; -import com.google.common.collect.Lists; import java.io.Serializable; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import org.apache.beam.runners.core.KeyedWorkItem; +import org.apache.beam.runners.core.KeyedWorkItems; import org.apache.beam.runners.direct.DirectRunner.CommittedBundle; import org.apache.beam.runners.direct.DirectRunner.UncommittedBundle; import org.apache.beam.runners.direct.ParDoMultiOverrideFactory.StatefulParDo; @@ -136,7 +138,7 @@ public class StatefulParDoEvaluatorFactoryTest implements Serializable { new StatefulParDoEvaluatorFactory(mockEvaluationContext); AppliedPTransform< - PCollection<? extends KV<String, Iterable<Integer>>>, PCollectionTuple, + PCollection<? extends KeyedWorkItem<String, KV<String, Integer>>>, PCollectionTuple, StatefulParDo<String, Integer, Integer>> producingTransform = (AppliedPTransform) DirectGraphs.getProducer(produced); @@ -245,7 +247,7 @@ public class StatefulParDoEvaluatorFactoryTest implements Serializable { // This will be the stateful ParDo from the expansion AppliedPTransform< - PCollection<KV<String, Iterable<Integer>>>, PCollectionTuple, + PCollection<KeyedWorkItem<String, KV<String, Integer>>>, PCollectionTuple, StatefulParDo<String, Integer, Integer>> producingTransform = (AppliedPTransform) DirectGraphs.getProducer(produced); @@ -270,37 +272,50 @@ public class StatefulParDoEvaluatorFactoryTest implements Serializable { // A single bundle with some elements in the global window; it should register cleanup for the // global window state merely by having the evaluator created. The cleanup logic does not // depend on the window. - WindowedValue<KV<String, Iterable<Integer>>> gbkOutputElement = - WindowedValue.of( - KV.<String, Iterable<Integer>>of("hello", Lists.newArrayList(1, 13, 15)), - new Instant(3), - firstWindow, - PaneInfo.NO_FIRING); - CommittedBundle<KV<String, Iterable<Integer>>> inputBundle = + String key = "hello"; + WindowedValue<KV<String, Integer>> firstKv = WindowedValue.of( + KV.of(key, 1), + new Instant(3), + firstWindow, + PaneInfo.NO_FIRING); + + WindowedValue<KeyedWorkItem<String, KV<String, Integer>>> gbkOutputElement = + firstKv.withValue( + KeyedWorkItems.elementsWorkItem( + "hello", + ImmutableList.of( + firstKv, + firstKv.withValue(KV.of(key, 13)), + firstKv.withValue(KV.of(key, 15))))); + + CommittedBundle<KeyedWorkItem<String, KV<String, Integer>>> inputBundle = BUNDLE_FACTORY .createBundle(producingTransform.getInput()) .add(gbkOutputElement) .commit(Instant.now()); - TransformEvaluator<KV<String, Iterable<Integer>>> evaluator = + TransformEvaluator<KeyedWorkItem<String, KV<String, Integer>>> evaluator = factory.forApplication(producingTransform, inputBundle); + evaluator.processElement(gbkOutputElement); // This should push back every element as a KV<String, Iterable<Integer>> // in the appropriate window. Since the keys are equal they are single-threaded - TransformResult<KV<String, Iterable<Integer>>> result = evaluator.finishBundle(); + TransformResult<KeyedWorkItem<String, KV<String, Integer>>> result = + evaluator.finishBundle(); List<Integer> pushedBackInts = new ArrayList<>(); - for (WindowedValue<?> unprocessedElement : result.getUnprocessedElements()) { - WindowedValue<KV<String, Iterable<Integer>>> unprocessedKv = - (WindowedValue<KV<String, Iterable<Integer>>>) unprocessedElement; + for (WindowedValue<? extends KeyedWorkItem<String, KV<String, Integer>>> unprocessedElement : + result.getUnprocessedElements()) { assertThat( Iterables.getOnlyElement(unprocessedElement.getWindows()), equalTo((BoundedWindow) firstWindow)); - assertThat(unprocessedKv.getValue().getKey(), equalTo("hello")); - for (Integer i : unprocessedKv.getValue().getValue()) { - pushedBackInts.add(i); + + assertThat(unprocessedElement.getValue().key(), equalTo("hello")); + for (WindowedValue<KV<String, Integer>> windowedKv : + unprocessedElement.getValue().elementsIterable()) { + pushedBackInts.add(windowedKv.getValue().getValue()); } } assertThat(pushedBackInts, containsInAnyOrder(1, 13, 15));