kennknowles closed pull request #4024: [BEAM-2304] Allow declared state to be accessed as a superclass.
URL: https://github.com/apache/beam/pull/4024
This is a PR merged from a forked repository.
Because GitHub hides the original diff once such a PR is merged, the diff is
reproduced below for the sake of provenance.
As this is a foreign pull request (from a fork), the diff is supplied
below because GitHub would not otherwise display it:
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
index c54c44f2d58..52607833f71 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
@@ -888,8 +888,8 @@ private static Parameter analyzeExtraParameter(
id);
paramErrors.checkArgument(
- stateDecl.stateType().equals(stateType),
- "reference to %s %s with different type %s",
+ stateDecl.stateType().isSubtypeOf(stateType),
+ "data type of reference to %s %s must be a supertype of %s",
StateId.class.getSimpleName(),
id,
formatType(stateDecl.stateType()));
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
index 03e310463f1..b2cc7dca98a 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
@@ -68,6 +68,7 @@
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.CombiningState;
+import org.apache.beam.sdk.state.GroupingState;
import org.apache.beam.sdk.state.MapState;
import org.apache.beam.sdk.state.ReadableState;
import org.apache.beam.sdk.state.SetState;
@@ -2532,6 +2533,40 @@ public void processElement(
pipeline.run();
}
+ @Test
+ @Category({ValidatesRunner.class, UsesStatefulParDo.class})
+ public void testCombiningStateParameterSuperclass() {
+ final String stateId = "foo";
+
+ DoFn<KV<Integer, Integer>, String> fn =
+ new DoFn<KV<Integer, Integer>, String>() {
+ private static final int EXPECTED_SUM = 8;
+
+ @StateId(stateId)
+ private final StateSpec<CombiningState<Integer, int[], Integer>>
state =
+ StateSpecs.combining(Sum.ofIntegers());
+
+ @ProcessElement
+ public void processElement(ProcessContext c,
+ @StateId(stateId) GroupingState<Integer, Integer> state) {
+ state.add(c.element().getValue());
+ Integer currentValue = state.read();
+ if (currentValue == EXPECTED_SUM) {
+ c.output("right on");
+ }
+ }
+ };
+
+ PCollection<String> output =
+ pipeline
+ .apply(Create.of(KV.of(123, 4), KV.of(123, 7), KV.of(123, -3)))
+ .apply(ParDo.of(fn));
+
+ // There should only be one moment at which the sum is exactly 8
+ PAssert.that(output).containsInAnyOrder("right on");
+ pipeline.run();
+ }
+
@Test
@Category({ValidatesRunner.class, UsesStatefulParDo.class})
public void testBagStateSideInput() {
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java
index 70c8dfdb312..a961203ffed 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java
@@ -30,6 +30,8 @@
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.state.CombiningState;
+import org.apache.beam.sdk.state.GroupingState;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.StateSpecs;
import org.apache.beam.sdk.state.TimeDomain;
@@ -39,6 +41,7 @@
import org.apache.beam.sdk.state.ValueState;
import org.apache.beam.sdk.state.WatermarkHoldState;
import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.Sum;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter;
import
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.ProcessContextParameter;
import
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.StateParameter;
@@ -649,7 +652,7 @@ public void testStateParameterWrongStateType() throws Exception {
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("WatermarkHoldState");
thrown.expectMessage("reference to");
- thrown.expectMessage("different type");
+ thrown.expectMessage("supertype");
thrown.expectMessage("ValueState");
thrown.expectMessage("my-id");
thrown.expectMessage("myProcessElement");
@@ -673,7 +676,7 @@ public void testStateParameterWrongGenericType() throws Exception {
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("ValueState<String>");
thrown.expectMessage("reference to");
- thrown.expectMessage("different type");
+ thrown.expectMessage("supertype");
thrown.expectMessage("ValueState<Integer>");
thrown.expectMessage("my-id");
thrown.expectMessage("myProcessElement");
@@ -692,6 +695,19 @@ public void myProcessElement(
}.getClass());
}
+ @Test
+ public void testGoodStateParameterSuperclassStateType() throws Exception {
+ DoFnSignatures.getSignature(new DoFn<KV<String, Integer>, Long>() {
+ @StateId("my-id")
+ private final StateSpec<CombiningState<Integer, int[], Integer>> state =
+ StateSpecs.combining(Sum.ofIntegers());
+
+ @ProcessElement public void myProcessElement(
+ ProcessContext context,
+ @StateId("my-id") GroupingState<Integer, Integer> groupingState) {}
+ }.getClass());
+ }
+
@Test
public void testSimpleStateIdAnonymousDoFn() throws Exception {
DoFnSignature sig =
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log in to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services