kennknowles closed pull request #4024: [BEAM-2304] Allow declared state to be accessed as a superclass.
URL: https://github.com/apache/beam/pull/4024
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this is a foreign pull request (from a fork), the diff is supplied below, since GitHub would not otherwise display it:

diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
index c54c44f2d58..52607833f71 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
@@ -888,8 +888,8 @@ private static Parameter analyzeExtraParameter(
           id);
 
       paramErrors.checkArgument(
-          stateDecl.stateType().equals(stateType),
-          "reference to %s %s with different type %s",
+          stateDecl.stateType().isSubtypeOf(stateType),
+          "data type of reference to %s %s must be a supertype of %s",
           StateId.class.getSimpleName(),
           id,
           formatType(stateDecl.stateType()));
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
index 03e310463f1..b2cc7dca98a 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
@@ -68,6 +68,7 @@
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.state.BagState;
 import org.apache.beam.sdk.state.CombiningState;
+import org.apache.beam.sdk.state.GroupingState;
 import org.apache.beam.sdk.state.MapState;
 import org.apache.beam.sdk.state.ReadableState;
 import org.apache.beam.sdk.state.SetState;
@@ -2532,6 +2533,40 @@ public void processElement(
     pipeline.run();
   }
 
+  @Test
+  @Category({ValidatesRunner.class, UsesStatefulParDo.class})
+  public void testCombiningStateParameterSuperclass() {
+    final String stateId = "foo";
+
+    DoFn<KV<Integer, Integer>, String> fn =
+        new DoFn<KV<Integer, Integer>, String>() {
+          private static final int EXPECTED_SUM = 8;
+
+          @StateId(stateId)
+          private final StateSpec<CombiningState<Integer, int[], Integer>> state =
+              StateSpecs.combining(Sum.ofIntegers());
+
+          @ProcessElement
+          public void processElement(ProcessContext c,
+              @StateId(stateId) GroupingState<Integer, Integer> state) {
+            state.add(c.element().getValue());
+            Integer currentValue = state.read();
+            if (currentValue == EXPECTED_SUM) {
+              c.output("right on");
+            }
+          }
+        };
+
+    PCollection<String> output =
+        pipeline
+            .apply(Create.of(KV.of(123, 4), KV.of(123, 7), KV.of(123, -3)))
+            .apply(ParDo.of(fn));
+
+    // There should only be one moment at which the sum is exactly 8
+    PAssert.that(output).containsInAnyOrder("right on");
+    pipeline.run();
+  }
+
   @Test
   @Category({ValidatesRunner.class, UsesStatefulParDo.class})
   public void testBagStateSideInput() {
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java
index 70c8dfdb312..a961203ffed 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java
@@ -30,6 +30,8 @@
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.coders.VarLongCoder;
 import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.state.CombiningState;
+import org.apache.beam.sdk.state.GroupingState;
 import org.apache.beam.sdk.state.StateSpec;
 import org.apache.beam.sdk.state.StateSpecs;
 import org.apache.beam.sdk.state.TimeDomain;
@@ -39,6 +41,7 @@
 import org.apache.beam.sdk.state.ValueState;
 import org.apache.beam.sdk.state.WatermarkHoldState;
 import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.Sum;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.ProcessContextParameter;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.StateParameter;
@@ -649,7 +652,7 @@ public void testStateParameterWrongStateType() throws Exception {
     thrown.expect(IllegalArgumentException.class);
     thrown.expectMessage("WatermarkHoldState");
     thrown.expectMessage("reference to");
-    thrown.expectMessage("different type");
+    thrown.expectMessage("supertype");
     thrown.expectMessage("ValueState");
     thrown.expectMessage("my-id");
     thrown.expectMessage("myProcessElement");
@@ -673,7 +676,7 @@ public void testStateParameterWrongGenericType() throws Exception {
     thrown.expect(IllegalArgumentException.class);
     thrown.expectMessage("ValueState<String>");
     thrown.expectMessage("reference to");
-    thrown.expectMessage("different type");
+    thrown.expectMessage("supertype");
     thrown.expectMessage("ValueState<Integer>");
     thrown.expectMessage("my-id");
     thrown.expectMessage("myProcessElement");
@@ -692,6 +695,19 @@ public void myProcessElement(
             }.getClass());
   }
 
+  @Test
+  public void testGoodStateParameterSuperclassStateType() throws Exception {
+    DoFnSignatures.getSignature(new DoFn<KV<String, Integer>, Long>() {
+      @StateId("my-id")
+      private final StateSpec<CombiningState<Integer, int[], Integer>> state =
+          StateSpecs.combining(Sum.ofIntegers());
+
+      @ProcessElement public void myProcessElement(
+          ProcessContext context,
+          @StateId("my-id") GroupingState<Integer, Integer> groupingState) {}
+    }.getClass());
+  }
+
   @Test
   public void testSimpleStateIdAnonymousDoFn() throws Exception {
     DoFnSignature sig =


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to