This is an automated email from the ASF dual-hosted git repository.

lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new dc382bb  [BEAM-14129] Restructure SubscriptionPartitionLoader to use a manual SDF so its watermarks are reasonable given the polling semantics (#17103)
dc382bb is described below

commit dc382bb7f1bfd370536129f26c34b8a6fbf98d8c
Author: dpcollins-google <40498610+dpcollins-goo...@users.noreply.github.com>
AuthorDate: Tue Mar 22 20:27:18 2022 -0400

    [BEAM-14129] Restructure SubscriptionPartitionLoader to use a manual SDF so its watermarks are reasonable given the polling semantics (#17103)
    
    * Restructure SubscriptionPartitionLoader to use a manual SDF so its watermarks are reasonable given the polling semantics
    
    * Generate initial watermark state from input element timestamp
    
    * fixes
    
    * fixes
    
    * fixes
    
    * Update sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/internal/SubscriptionPartitionLoader.java
    
    Fix race
    
    Co-authored-by: Lukasz Cwik <lc...@google.com>
---
 .../internal/PerSubscriptionPartitionSdf.java      |   4 +-
 .../internal/SubscriptionPartitionLoader.java      | 151 ++++++++++++++++-----
 .../internal/SubscriptionPartitionLoaderTest.java  |   9 +-
 3 files changed, 124 insertions(+), 40 deletions(-)

diff --git 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/internal/PerSubscriptionPartitionSdf.java
 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/internal/PerSubscriptionPartitionSdf.java
index 22b1389..d387cf1 100644
--- 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/internal/PerSubscriptionPartitionSdf.java
+++ 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/internal/PerSubscriptionPartitionSdf.java
@@ -61,8 +61,8 @@ class PerSubscriptionPartitionSdf extends 
DoFn<SubscriptionPartition, SequencedM
   }
 
   @GetInitialWatermarkEstimatorState
-  public Instant getInitialWatermarkState() {
-    return Instant.EPOCH;
+  public Instant getInitialWatermarkState(@Timestamp Instant elementTimestamp) 
{
+    return elementTimestamp;
   }
 
   @NewWatermarkEstimator
diff --git 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/internal/SubscriptionPartitionLoader.java
 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/internal/SubscriptionPartitionLoader.java
index 3e38385..3a21f85 100644
--- 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/internal/SubscriptionPartitionLoader.java
+++ 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/internal/SubscriptionPartitionLoader.java
@@ -18,27 +18,27 @@
 package org.apache.beam.sdk.io.gcp.pubsublite.internal;
 
 import static 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
+import static 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
 
 import com.google.cloud.pubsublite.Partition;
 import com.google.cloud.pubsublite.PartitionLookupUtils;
 import com.google.cloud.pubsublite.SubscriptionPath;
 import com.google.cloud.pubsublite.TopicPath;
-import java.util.List;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-import org.apache.beam.sdk.transforms.Create;
-import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.testing.SerializableMatchers.SerializableSupplier;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.Impulse;
 import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.SerializableFunction;
-import org.apache.beam.sdk.transforms.Watch;
-import org.apache.beam.sdk.transforms.Watch.Growth.PollFn;
-import org.apache.beam.sdk.transforms.Watch.Growth.PollResult;
-import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.transforms.splittabledofn.ManualWatermarkEstimator;
+import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
+import org.apache.beam.sdk.transforms.splittabledofn.SplitResult;
+import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
-import org.apache.beam.sdk.values.TypeDescriptor;
 import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
-import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import org.checkerframework.checker.nullness.qual.Nullable;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
 
@@ -47,7 +47,104 @@ class SubscriptionPartitionLoader extends 
PTransform<PBegin, PCollection<Subscri
   private final SubscriptionPath subscription;
   private final SerializableFunction<TopicPath, Integer> getPartitionCount;
   private final Duration pollDuration;
-  private final boolean terminate;
+  private final SerializableSupplier<Boolean> terminate;
+
+  private class GeneratorFn extends DoFn<byte[], SubscriptionPartition> {
+    @ProcessElement
+    public ProcessContinuation processElement(
+        RestrictionTracker<Integer, Integer> restrictionTracker,
+        OutputReceiver<SubscriptionPartition> output,
+        ManualWatermarkEstimator<Instant> estimator) {
+      int previousCount = restrictionTracker.currentRestriction();
+      int newCount = getPartitionCount.apply(topic);
+      if (!restrictionTracker.tryClaim(newCount)) {
+        return ProcessContinuation.stop();
+      }
+      if (newCount > previousCount) {
+        for (int i = previousCount; i < newCount; ++i) {
+          output.outputWithTimestamp(
+              SubscriptionPartition.of(subscription, Partition.of(i)),
+              estimator.currentWatermark());
+        }
+      }
+      estimator.setWatermark(getWatermark());
+      return ProcessContinuation.resume().withResumeDelay(pollDuration);
+    }
+
+    @GetInitialWatermarkEstimatorState
+    public Instant getInitialWatermarkEstimatorState(@Timestamp Instant 
initial) {
+      checkArgument(initial.equals(BoundedWindow.TIMESTAMP_MIN_VALUE));
+      return initial;
+    }
+
+    @GetInitialRestriction
+    public Integer getInitialRestriction() {
+      return 0;
+    }
+
+    @NewTracker
+    public RestrictionTracker<Integer, Integer> newTracker(@Restriction 
Integer input) {
+      return new RestrictionTracker<Integer, Integer>() {
+        private boolean terminated = false;
+        private int position = input;
+
+        @Override
+        public boolean tryClaim(Integer newPosition) {
+          checkArgument(newPosition >= position);
+          if (terminated) {
+            return false;
+          }
+          if (terminate.get()) {
+            terminated = true;
+            return false;
+          }
+          position = newPosition;
+          return true;
+        }
+
+        @Override
+        public Integer currentRestriction() {
+          return position;
+        }
+
+        @Override
+        public @Nullable SplitResult<Integer> trySplit(double 
fractionOfRemainder) {
+          if (fractionOfRemainder != 0) {
+            return null;
+          }
+          if (terminated) {
+            return null;
+          }
+          terminated = true;
+          return SplitResult.of(position, position);
+        }
+
+        @Override
+        public void checkDone() throws IllegalStateException {
+          checkState(terminated);
+        }
+
+        @Override
+        public IsBounded isBounded() {
+          return IsBounded.UNBOUNDED;
+        }
+      };
+    }
+
+    @NewWatermarkEstimator
+    public ManualWatermarkEstimator<Instant> newWatermarkEstimator(
+        @WatermarkEstimatorState Instant state) {
+      return new WatermarkEstimators.Manual(state);
+    }
+
+    private Instant getWatermark() {
+      return Instant.now().minus(watermarkDelay());
+    }
+
+    private Duration watermarkDelay() {
+      return pollDuration.multipliedBy(3).dividedBy(2);
+    }
+  }
 
   SubscriptionPartitionLoader(TopicPath topic, SubscriptionPath subscription) {
     this(
@@ -55,7 +152,7 @@ class SubscriptionPartitionLoader extends PTransform<PBegin, 
PCollection<Subscri
         subscription,
         PartitionLookupUtils::numPartitions,
         Duration.standardMinutes(1),
-        false);
+        () -> false);
   }
 
   @VisibleForTesting
@@ -64,7 +161,7 @@ class SubscriptionPartitionLoader extends PTransform<PBegin, 
PCollection<Subscri
       SubscriptionPath subscription,
       SerializableFunction<TopicPath, Integer> getPartitionCount,
       Duration pollDuration,
-      boolean terminate) {
+      SerializableSupplier<Boolean> terminate) {
     this.topic = topic;
     this.subscription = subscription;
     this.getPartitionCount = getPartitionCount;
@@ -74,28 +171,8 @@ class SubscriptionPartitionLoader extends 
PTransform<PBegin, PCollection<Subscri
 
   @Override
   public PCollection<SubscriptionPartition> expand(PBegin input) {
-    PCollection<TopicPath> start = 
input.apply(Create.of(ImmutableList.of(topic)));
-    PCollection<KV<TopicPath, Partition>> partitions =
-        start.apply(
-            Watch.growthOf(
-                    new PollFn<TopicPath, Partition>() {
-                      @Override
-                      public PollResult<Partition> apply(TopicPath element, 
Context c) {
-                        checkArgument(element.equals(topic));
-                        int partitionCount = getPartitionCount.apply(element);
-                        List<Partition> partitions =
-                            IntStream.range(0, partitionCount)
-                                .mapToObj(Partition::of)
-                                .collect(Collectors.toList());
-                        return PollResult.incomplete(Instant.now(), partitions)
-                            .withWatermark(Instant.now());
-                      }
-                    })
-                .withPollInterval(pollDuration)
-                .withTerminationPerInput(
-                    terminate ? Watch.Growth.afterIterations(10) : 
Watch.Growth.never()));
-    return partitions.apply(
-        MapElements.into(TypeDescriptor.of(SubscriptionPartition.class))
-            .via(kv -> SubscriptionPartition.of(subscription, kv.getValue())));
+    return input
+        .apply("Impulse", Impulse.create())
+        .apply("Watch Partition Count", ParDo.of(new GeneratorFn()));
   }
 }
diff --git 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/internal/SubscriptionPartitionLoaderTest.java
 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/internal/SubscriptionPartitionLoaderTest.java
index 278a1f5..31b1ad3 100644
--- 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/internal/SubscriptionPartitionLoaderTest.java
+++ 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsublite/internal/SubscriptionPartitionLoaderTest.java
@@ -25,6 +25,7 @@ import com.google.cloud.pubsublite.Partition;
 import com.google.cloud.pubsublite.SubscriptionPath;
 import com.google.cloud.pubsublite.TopicPath;
 import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.SerializableMatchers.SerializableSupplier;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.values.PCollection;
@@ -41,6 +42,8 @@ import org.mockito.Mock;
 public class SubscriptionPartitionLoaderTest {
   @Rule public final transient TestPipeline pipeline = TestPipeline.create();
   @Mock SerializableFunction<TopicPath, Integer> getPartitionCount;
+
+  @Mock SerializableSupplier<Boolean> terminate;
   private SubscriptionPartitionLoader loader;
 
   @Before
@@ -48,18 +51,21 @@ public class SubscriptionPartitionLoaderTest {
     initMocks(this);
     FakeSerializable.Handle<SerializableFunction<TopicPath, Integer>> handle =
         FakeSerializable.put(getPartitionCount);
+    FakeSerializable.Handle<SerializableSupplier<Boolean>> terminateHandle =
+        FakeSerializable.put(terminate);
     loader =
         new SubscriptionPartitionLoader(
             example(TopicPath.class),
             example(SubscriptionPath.class),
             topic -> handle.get().apply(topic),
             Duration.millis(50),
-            true);
+            () -> terminateHandle.get().get());
   }
 
   @Test
   public void singleResult() {
     when(getPartitionCount.apply(example(TopicPath.class))).thenReturn(3);
+    when(terminate.get()).thenReturn(false).thenReturn(false).thenReturn(true);
     PCollection<SubscriptionPartition> output = pipeline.apply(loader);
     PAssert.that(output)
         .containsInAnyOrder(
@@ -72,6 +78,7 @@ public class SubscriptionPartitionLoaderTest {
   @Test
   public void addedResults() {
     
when(getPartitionCount.apply(example(TopicPath.class))).thenReturn(3).thenReturn(4);
+    when(terminate.get()).thenReturn(false).thenReturn(false).thenReturn(true);
     PCollection<SubscriptionPartition> output = pipeline.apply(loader);
     PAssert.that(output)
         .containsInAnyOrder(

Reply via email to