Revert "Removes code for wrapping DoFn as an OldDoFn"

This reverts commit a22de15012c51e8b7e31143021f0a298e093bf51.
Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/a12fd8c5 Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/a12fd8c5 Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/a12fd8c5 Branch: refs/heads/gearpump-runner Commit: a12fd8c580d3b1ea46c5be951f39046bfa0dacf3 Parents: abdbee6 Author: Eugene Kirpichov <kirpic...@google.com> Authored: Fri Dec 16 15:26:28 2016 -0800 Committer: Eugene Kirpichov <kirpic...@google.com> Committed: Fri Dec 16 16:39:20 2016 -0800 ---------------------------------------------------------------------- .../apache/beam/runners/core/DoFnAdapters.java | 150 ++++++++++ .../org/apache/beam/sdk/transforms/OldDoFn.java | 295 ++++++++++++++++++- .../sdk/transforms/reflect/DoFnInvokers.java | 141 ++++++++- .../transforms/reflect/DoFnInvokersTest.java | 36 +++ 4 files changed, 611 insertions(+), 11 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a12fd8c5/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnAdapters.java ---------------------------------------------------------------------- diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnAdapters.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnAdapters.java index 0f5624f..a4002da 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnAdapters.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnAdapters.java @@ -18,6 +18,8 @@ package org.apache.beam.runners.core; import java.io.IOException; +import java.util.Collection; +import javax.annotation.Nullable; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.transforms.Aggregator; import org.apache.beam.sdk.transforms.AggregatorRetriever; @@ -39,6 +41,7 @@ import org.apache.beam.sdk.util.Timer; import 
org.apache.beam.sdk.util.state.State; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TypeDescriptor; import org.joda.time.Duration; import org.joda.time.Instant; @@ -53,6 +56,18 @@ public class DoFnAdapters { /** Should not be instantiated. */ private DoFnAdapters() {} + /** + * If this is an {@link OldDoFn} produced via {@link #toOldDoFn}, returns the class of the + * original {@link DoFn}, otherwise returns {@code fn.getClass()}. + */ + public static Class<?> getDoFnClass(OldDoFn<?, ?> fn) { + if (fn instanceof SimpleDoFnAdapter) { + return ((SimpleDoFnAdapter<?, ?>) fn).fn.getClass(); + } else { + return fn.getClass(); + } + } + /** Creates an {@link OldDoFn} that delegates to the {@link DoFn}. */ @SuppressWarnings({"unchecked", "rawtypes"}) public static <InputT, OutputT> OldDoFn<InputT, OutputT> toOldDoFn(DoFn<InputT, OutputT> fn) { @@ -64,6 +79,126 @@ public class DoFnAdapters { } } + /** Creates a {@link OldDoFn.ProcessContext} from a {@link DoFn.ProcessContext}. 
*/ + public static <InputT, OutputT> OldDoFn<InputT, OutputT>.ProcessContext adaptProcessContext( + OldDoFn<InputT, OutputT> fn, + final DoFn<InputT, OutputT>.ProcessContext c, + final DoFnInvoker.ArgumentProvider<InputT, OutputT> extra) { + return fn.new ProcessContext() { + @Override + public InputT element() { + return c.element(); + } + + @Override + public <T> T sideInput(PCollectionView<T> view) { + return c.sideInput(view); + } + + @Override + public Instant timestamp() { + return c.timestamp(); + } + + @Override + public BoundedWindow window() { + return extra.window(); + } + + @Override + public PaneInfo pane() { + return c.pane(); + } + + @Override + public WindowingInternals<InputT, OutputT> windowingInternals() { + return extra.windowingInternals(); + } + + @Override + public PipelineOptions getPipelineOptions() { + return c.getPipelineOptions(); + } + + @Override + public void output(OutputT output) { + c.output(output); + } + + @Override + public void outputWithTimestamp(OutputT output, Instant timestamp) { + c.outputWithTimestamp(output, timestamp); + } + + @Override + public <T> void sideOutput(TupleTag<T> tag, T output) { + c.sideOutput(tag, output); + } + + @Override + public <T> void sideOutputWithTimestamp(TupleTag<T> tag, T output, Instant timestamp) { + c.sideOutputWithTimestamp(tag, output, timestamp); + } + + @Override + protected <AggInputT, AggOutputT> Aggregator<AggInputT, AggOutputT> createAggregatorInternal( + String name, CombineFn<AggInputT, ?, AggOutputT> combiner) { + return c.createAggregator(name, combiner); + } + }; + } + + /** Creates a {@link OldDoFn.ProcessContext} from a {@link DoFn.ProcessContext}. 
*/ + public static <InputT, OutputT> OldDoFn<InputT, OutputT>.Context adaptContext( + OldDoFn<InputT, OutputT> fn, + final DoFn<InputT, OutputT>.Context c) { + return fn.new Context() { + @Override + public PipelineOptions getPipelineOptions() { + return c.getPipelineOptions(); + } + + @Override + public void output(OutputT output) { + c.output(output); + } + + @Override + public void outputWithTimestamp(OutputT output, Instant timestamp) { + c.outputWithTimestamp(output, timestamp); + } + + @Override + public <T> void sideOutput(TupleTag<T> tag, T output) { + c.sideOutput(tag, output); + } + + @Override + public <T> void sideOutputWithTimestamp(TupleTag<T> tag, T output, Instant timestamp) { + c.sideOutputWithTimestamp(tag, output, timestamp); + } + + @Override + protected <AggInputT, AggOutputT> Aggregator<AggInputT, AggOutputT> createAggregatorInternal( + String name, CombineFn<AggInputT, ?, AggOutputT> combiner) { + return c.createAggregator(name, combiner); + } + }; + } + + /** + * If the fn was created using {@link #toOldDoFn}, returns the original {@link DoFn}. Otherwise, + * returns {@code null}. + */ + @Nullable + public static <InputT, OutputT> DoFn<InputT, OutputT> getDoFn(OldDoFn<InputT, OutputT> fn) { + if (fn instanceof SimpleDoFnAdapter) { + return ((SimpleDoFnAdapter<InputT, OutputT>) fn).fn; + } else { + return null; + } + } + /** * Wraps a {@link DoFn} that doesn't require access to {@link BoundedWindow} as an {@link * OldDoFn}. 
@@ -106,6 +241,21 @@ public class DoFnAdapters { } @Override + protected TypeDescriptor<InputT> getInputTypeDescriptor() { + return fn.getInputTypeDescriptor(); + } + + @Override + protected TypeDescriptor<OutputT> getOutputTypeDescriptor() { + return fn.getOutputTypeDescriptor(); + } + + @Override + Collection<Aggregator<?, ?>> getAggregators() { + return fn.getAggregators(); + } + + @Override public Duration getAllowedTimestampSkew() { return fn.getAllowedTimestampSkew(); } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a12fd8c5/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/OldDoFn.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/OldDoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/OldDoFn.java index 7b04533..d1bb42b 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/OldDoFn.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/OldDoFn.java @@ -38,6 +38,7 @@ import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.util.WindowingInternals; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TypeDescriptor; import org.joda.time.Duration; import org.joda.time.Instant; @@ -70,6 +71,21 @@ import org.joda.time.Instant; */ @Deprecated public abstract class OldDoFn<InputT, OutputT> implements Serializable, HasDisplayData { + + public DoFn<InputT, OutputT> toDoFn() { + DoFn<InputT, OutputT> doFn = DoFnAdapters.getDoFn(this); + if (doFn != null) { + return doFn; + } + if (this instanceof RequiresWindowAccess) { + // No parameters as it just accesses `this` + return new AdaptedRequiresWindowAccessDoFn(); + } else { + // No parameters as it just accesses `this` + return new AdaptedDoFn(); + } + } + /** * Information accessible to all methods in this {@code OldDoFn}. 
* Used primarily to output elements. @@ -318,7 +334,7 @@ public abstract class OldDoFn<InputT, OutputT> implements Serializable, HasDispl this(new HashMap<String, DelegatingAggregator<?, ?>>()); } - public OldDoFn(Map<String, DelegatingAggregator<?, ?>> aggregators) { + OldDoFn(Map<String, DelegatingAggregator<?, ?>> aggregators) { this.aggregators = aggregators; } @@ -403,6 +419,32 @@ public abstract class OldDoFn<InputT, OutputT> implements Serializable, HasDispl ///////////////////////////////////////////////////////////////////////////// /** + * Returns a {@link TypeDescriptor} capturing what is known statically + * about the input type of this {@code OldDoFn} instance's most-derived + * class. + * + * <p>See {@link #getOutputTypeDescriptor} for more discussion. + */ + protected TypeDescriptor<InputT> getInputTypeDescriptor() { + return new TypeDescriptor<InputT>(getClass()) {}; + } + + /** + * Returns a {@link TypeDescriptor} capturing what is known statically + * about the output type of this {@code OldDoFn} instance's + * most-derived class. + * + * <p>In the normal case of a concrete {@code OldDoFn} subclass with + * no generic type parameters of its own (including anonymous inner + * classes), this will be a complete non-generic type, which is good + * for choosing a default output {@code Coder<OutputT>} for the output + * {@code PCollection<OutputT>}. + */ + protected TypeDescriptor<OutputT> getOutputTypeDescriptor() { + return new TypeDescriptor<OutputT>(getClass()) {}; + } + + /** * Returns an {@link Aggregator} with aggregation logic specified by the * {@link CombineFn} argument. The name provided must be unique across * {@link Aggregator}s created within the OldDoFn. 
Aggregators can only be created @@ -462,4 +504,255 @@ public abstract class OldDoFn<InputT, OutputT> implements Serializable, HasDispl Collection<Aggregator<?, ?>> getAggregators() { return Collections.<Aggregator<?, ?>>unmodifiableCollection(aggregators.values()); } + + /** + * A {@link Context} for an {@link OldDoFn} via a context for a proper {@link DoFn}. + */ + private class AdaptedContext extends Context { + + private final DoFn<InputT, OutputT>.Context newContext; + + public AdaptedContext( + DoFn<InputT, OutputT>.Context newContext) { + this.newContext = newContext; + super.setupDelegateAggregators(); + } + + @Override + public PipelineOptions getPipelineOptions() { + return newContext.getPipelineOptions(); + } + + @Override + public void output(OutputT output) { + newContext.output(output); + } + + @Override + public void outputWithTimestamp(OutputT output, Instant timestamp) { + newContext.outputWithTimestamp(output, timestamp); + } + + @Override + public <T> void sideOutput(TupleTag<T> tag, T output) { + newContext.sideOutput(tag, output); + } + + @Override + public <T> void sideOutputWithTimestamp(TupleTag<T> tag, T output, Instant timestamp) { + newContext.sideOutputWithTimestamp(tag, output, timestamp); + } + + @Override + protected <AggInputT, AggOutputT> Aggregator<AggInputT, AggOutputT> createAggregatorInternal( + String name, CombineFn<AggInputT, ?, AggOutputT> combiner) { + return newContext.createAggregator(name, combiner); + } + } + + /** + * A {@link ProcessContext} for an {@link OldDoFn} via a context for a proper {@link DoFn}. 
+ */ + private class AdaptedProcessContext extends ProcessContext { + + private final DoFn<InputT, OutputT>.ProcessContext newContext; + + public AdaptedProcessContext(DoFn<InputT, OutputT>.ProcessContext newContext) { + this.newContext = newContext; + } + + @Override + public InputT element() { + return newContext.element(); + } + + @Override + public <T> T sideInput(PCollectionView<T> view) { + return newContext.sideInput(view); + } + + @Override + public Instant timestamp() { + return newContext.timestamp(); + } + + @Override + public BoundedWindow window() { + throw new UnsupportedOperationException(String.format( + "%s.%s.windowingInternals() is no longer supported. Please convert your %s to a %s", + OldDoFn.class.getSimpleName(), + OldDoFn.ProcessContext.class.getSimpleName(), + OldDoFn.class.getSimpleName(), + DoFn.class.getSimpleName())); + } + + @Override + public PaneInfo pane() { + return newContext.pane(); + } + + @Override + public WindowingInternals<InputT, OutputT> windowingInternals() { + throw new UnsupportedOperationException(String.format( + "%s.%s.windowingInternals() is no longer supported. 
Please convert your %s to a %s", + OldDoFn.class.getSimpleName(), + OldDoFn.ProcessContext.class.getSimpleName(), + OldDoFn.class.getSimpleName(), + DoFn.class.getSimpleName())); + } + + @Override + public PipelineOptions getPipelineOptions() { + return newContext.getPipelineOptions(); + } + + @Override + public void output(OutputT output) { + newContext.output(output); + } + + @Override + public void outputWithTimestamp(OutputT output, Instant timestamp) { + newContext.outputWithTimestamp(output, timestamp); + } + + @Override + public <T> void sideOutput(TupleTag<T> tag, T output) { + newContext.sideOutput(tag, output); + } + + @Override + public <T> void sideOutputWithTimestamp(TupleTag<T> tag, T output, Instant timestamp) { + newContext.sideOutputWithTimestamp(tag, output, timestamp); + } + + @Override + protected <AggInputT, AggOutputT> Aggregator<AggInputT, AggOutputT> createAggregatorInternal( + String name, CombineFn<AggInputT, ?, AggOutputT> combiner) { + return newContext.createAggregator(name, combiner); + } + } + + private class AdaptedDoFn extends DoFn<InputT, OutputT> { + + @Setup + public void setup() throws Exception { + OldDoFn.this.setup(); + } + + @StartBundle + public void startBundle(Context c) throws Exception { + OldDoFn.this.startBundle(OldDoFn.this.new AdaptedContext(c)); + } + + @ProcessElement + public void processElement(ProcessContext c) throws Exception { + OldDoFn.this.processElement(OldDoFn.this.new AdaptedProcessContext(c)); + } + + @FinishBundle + public void finishBundle(Context c) throws Exception { + OldDoFn.this.finishBundle(OldDoFn.this.new AdaptedContext(c)); + } + + @Teardown + public void teardown() throws Exception { + OldDoFn.this.teardown(); + } + + @Override + public Duration getAllowedTimestampSkew() { + return OldDoFn.this.getAllowedTimestampSkew(); + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + OldDoFn.this.populateDisplayData(builder); + } + + @Override + public 
TypeDescriptor<InputT> getInputTypeDescriptor() { + return OldDoFn.this.getInputTypeDescriptor(); + } + + @Override + Collection<Aggregator<?, ?>> getAggregators() { + return OldDoFn.this.getAggregators(); + } + + @Override + public TypeDescriptor<OutputT> getOutputTypeDescriptor() { + return OldDoFn.this.getOutputTypeDescriptor(); + } + } + + /** + * A {@link ProcessContext} for an {@link OldDoFn} that implements + * {@link RequiresWindowAccess}, via a context for a proper {@link DoFn}. + */ + private class AdaptedRequiresWindowAccessProcessContext extends AdaptedProcessContext { + + private final BoundedWindow window; + + public AdaptedRequiresWindowAccessProcessContext( + DoFn<InputT, OutputT>.ProcessContext newContext, + BoundedWindow window) { + super(newContext); + this.window = window; + } + + @Override + public BoundedWindow window() { + return window; + } + } + + private class AdaptedRequiresWindowAccessDoFn extends DoFn<InputT, OutputT> { + + @Setup + public void setup() throws Exception { + OldDoFn.this.setup(); + } + + @StartBundle + public void startBundle(Context c) throws Exception { + OldDoFn.this.startBundle(OldDoFn.this.new AdaptedContext(c)); + } + + @ProcessElement + public void processElement(ProcessContext c, BoundedWindow window) throws Exception { + OldDoFn.this.processElement( + OldDoFn.this.new AdaptedRequiresWindowAccessProcessContext(c, window)); + } + + @FinishBundle + public void finishBundle(Context c) throws Exception { + OldDoFn.this.finishBundle(OldDoFn.this.new AdaptedContext(c)); + } + + @Teardown + public void teardown() throws Exception { + OldDoFn.this.teardown(); + } + + @Override + public Duration getAllowedTimestampSkew() { + return OldDoFn.this.getAllowedTimestampSkew(); + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + OldDoFn.this.populateDisplayData(builder); + } + + @Override + public TypeDescriptor<InputT> getInputTypeDescriptor() { + return OldDoFn.this.getInputTypeDescriptor(); + } 
+ + @Override + public TypeDescriptor<OutputT> getOutputTypeDescriptor() { + return OldDoFn.this.getOutputTypeDescriptor(); + } + } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a12fd8c5/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java index b141d51..50a7082 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java @@ -18,7 +18,13 @@ package org.apache.beam.sdk.transforms.reflect; import java.io.Serializable; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderRegistry; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFnAdapters; +import org.apache.beam.sdk.transforms.OldDoFn; +import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; +import org.apache.beam.sdk.util.UserCodeException; /** Static utilities for working with {@link DoFnInvoker}. */ public class DoFnInvokers { @@ -36,22 +42,137 @@ public class DoFnInvokers { return ByteBuddyDoFnInvokerFactory.only().newByteBuddyInvoker(fn); } + private DoFnInvokers() {} + /** - * Temporarily retained for compatibility with Dataflow worker. - * TODO: delete this when Dataflow worker is fixed to call {@link #invokerFor(DoFn)}. + * Returns a {@link DoFnInvoker} for the given {@link Object}, which should be either a {@link + * DoFn} or an {@link OldDoFn}. The expected use would be to deserialize a user's function as an + * {@link Object} and then pass it to this method, so there is no need to statically specify what + * sort of object it is. * - * @deprecated Use {@link #invokerFor(DoFn)}. 
+ * @deprecated this is to be used only as a migration path for decoupling upgrades */ - @SuppressWarnings("unchecked") @Deprecated - public static <InputT, OutputT> DoFnInvoker<InputT, OutputT> invokerFor( - Serializable deserializedFn) { + public static DoFnInvoker<?, ?> invokerFor(Serializable deserializedFn) { if (deserializedFn instanceof DoFn) { - return invokerFor((DoFn<InputT, OutputT>) deserializedFn); + return invokerFor((DoFn<?, ?>) deserializedFn); + } else if (deserializedFn instanceof OldDoFn) { + return new OldDoFnInvoker<>((OldDoFn<?, ?>) deserializedFn); + } else { + throw new IllegalArgumentException( + String.format( + "Cannot create a %s for %s; it should be either a %s or an %s.", + DoFnInvoker.class.getSimpleName(), + deserializedFn.toString(), + DoFn.class.getSimpleName(), + OldDoFn.class.getSimpleName())); } - throw new UnsupportedOperationException( - "Only DoFn supported, was: " + deserializedFn.getClass()); } - private DoFnInvokers() {} + /** @deprecated use {@link DoFnInvokers#invokerFor(DoFn)}. */ + @Deprecated public static final DoFnInvokers INSTANCE = new DoFnInvokers(); + + /** @deprecated use {@link DoFnInvokers#invokerFor(DoFn)}. 
*/ + @Deprecated + public <InputT, OutputT> DoFnInvoker<InputT, OutputT> invokerFor(Object deserializedFn) { + return (DoFnInvoker<InputT, OutputT>) DoFnInvokers.invokerFor((Serializable) deserializedFn); + } + + + static class OldDoFnInvoker<InputT, OutputT> implements DoFnInvoker<InputT, OutputT> { + + private final OldDoFn<InputT, OutputT> fn; + + public OldDoFnInvoker(OldDoFn<InputT, OutputT> fn) { + this.fn = fn; + } + + @Override + public DoFn.ProcessContinuation invokeProcessElement( + ArgumentProvider<InputT, OutputT> extra) { + // The outer DoFn is immaterial - it exists only to avoid typing InputT and OutputT repeatedly + DoFn<InputT, OutputT>.ProcessContext newCtx = + extra.processContext(new DoFn<InputT, OutputT>() {}); + OldDoFn<InputT, OutputT>.ProcessContext oldCtx = + DoFnAdapters.adaptProcessContext(fn, newCtx, extra); + try { + fn.processElement(oldCtx); + return DoFn.ProcessContinuation.stop(); + } catch (Throwable exc) { + throw UserCodeException.wrap(exc); + } + } + + @Override + public void invokeOnTimer(String timerId, ArgumentProvider<InputT, OutputT> arguments) { + throw new UnsupportedOperationException( + String.format("Timers are not supported for %s", OldDoFn.class.getSimpleName())); + } + + @Override + public void invokeStartBundle(DoFn.Context c) { + OldDoFn<InputT, OutputT>.Context oldCtx = DoFnAdapters.adaptContext(fn, c); + try { + fn.startBundle(oldCtx); + } catch (Throwable exc) { + throw UserCodeException.wrap(exc); + } + } + + @Override + public void invokeFinishBundle(DoFn.Context c) { + OldDoFn<InputT, OutputT>.Context oldCtx = DoFnAdapters.adaptContext(fn, c); + try { + fn.finishBundle(oldCtx); + } catch (Throwable exc) { + throw UserCodeException.wrap(exc); + } + } + + @Override + public void invokeSetup() { + try { + fn.setup(); + } catch (Throwable exc) { + throw UserCodeException.wrap(exc); + } + } + + @Override + public void invokeTeardown() { + try { + fn.teardown(); + } catch (Throwable exc) { + throw 
UserCodeException.wrap(exc); + } + } + + @Override + public <RestrictionT> RestrictionT invokeGetInitialRestriction(InputT element) { + throw new UnsupportedOperationException("OldDoFn is not splittable"); + } + + @Override + public <RestrictionT> Coder<RestrictionT> invokeGetRestrictionCoder( + CoderRegistry coderRegistry) { + throw new UnsupportedOperationException("OldDoFn is not splittable"); + } + + @Override + public <RestrictionT> void invokeSplitRestriction( + InputT element, RestrictionT restriction, DoFn.OutputReceiver<RestrictionT> receiver) { + throw new UnsupportedOperationException("OldDoFn is not splittable"); + } + + @Override + public <RestrictionT, TrackerT extends RestrictionTracker<RestrictionT>> + TrackerT invokeNewTracker(RestrictionT restriction) { + throw new UnsupportedOperationException("OldDoFn is not splittable"); + } + + @Override + public DoFn<InputT, OutputT> getFn() { + throw new UnsupportedOperationException("getFn is not supported for OldDoFn"); + } + } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a12fd8c5/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java index 4c6bee1..4233b39 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java @@ -25,6 +25,7 @@ import static org.junit.Assert.assertSame; import static org.junit.Assert.assertThat; import static org.mockito.Matchers.eq; import static org.mockito.Matchers.same; +import static org.mockito.Mockito.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; 
@@ -731,4 +732,39 @@ public class DoFnInvokersTest { invoker.invokeOnTimer(timerId, mockArgumentProvider); assertThat(fn.window, equalTo(testWindow)); } + + private class OldDoFnIdentity extends OldDoFn<String, String> { + public void processElement(ProcessContext c) {} + } + + @Test + public void testOldDoFnProcessElement() throws Exception { + new DoFnInvokers.OldDoFnInvoker<>(mockOldDoFn) + .invokeProcessElement(mockArgumentProvider); + verify(mockOldDoFn).processElement(any(OldDoFn.ProcessContext.class)); + } + + @Test + public void testOldDoFnStartBundle() throws Exception { + new DoFnInvokers.OldDoFnInvoker<>(mockOldDoFn).invokeStartBundle(mockProcessContext); + verify(mockOldDoFn).startBundle(any(OldDoFn.Context.class)); + } + + @Test + public void testOldDoFnFinishBundle() throws Exception { + new DoFnInvokers.OldDoFnInvoker<>(mockOldDoFn).invokeFinishBundle(mockProcessContext); + verify(mockOldDoFn).finishBundle(any(OldDoFn.Context.class)); + } + + @Test + public void testOldDoFnSetup() throws Exception { + new DoFnInvokers.OldDoFnInvoker<>(mockOldDoFn).invokeSetup(); + verify(mockOldDoFn).setup(); + } + + @Test + public void testOldDoFnTeardown() throws Exception { + new DoFnInvokers.OldDoFnInvoker<>(mockOldDoFn).invokeTeardown(); + verify(mockOldDoFn).teardown(); + } }