This is an automated email from the ASF dual-hosted git repository.

boyuanz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 98ee1f1  Cache UnboundedReader per CheckpointMark in SDF Wrapper DoFn.
     new b6243e7  Merge pull request #13592 from [BEAM-11403] Cache UnboundedReader per UnboundedSourceRestriction in SDF Wrapper DoFn.
98ee1f1 is described below

commit 98ee1f178a9e80f4694f86775c06a54ecf82abb8
Author: Boyuan Zhang <boyu...@google.com>
AuthorDate: Mon Dec 21 15:13:32 2020 -0800

    Cache UnboundedReader per CheckpointMark in SDF Wrapper DoFn.
---
 .../src/main/java/org/apache/beam/sdk/io/Read.java |  96 +++++++++++----
 .../org/apache/beam/sdk/testing/TestPipeline.java  |  45 +++++++
 .../test/java/org/apache/beam/sdk/io/ReadTest.java | 130 +++++++++++++++++++++
 3 files changed, 247 insertions(+), 24 deletions(-)

diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java
index e2f7a8f..4982066 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java
@@ -17,6 +17,7 @@
  */
 package org.apache.beam.sdk.io;
 
+import static 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull;
 import static 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
 
 import com.google.auto.value.AutoValue;
@@ -27,6 +28,7 @@ import java.io.Serializable;
 import java.util.Arrays;
 import java.util.List;
 import java.util.NoSuchElementException;
+import java.util.concurrent.TimeUnit;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.CoderException;
 import org.apache.beam.sdk.coders.InstantCoder;
@@ -60,6 +62,9 @@ import org.apache.beam.sdk.values.TypeDescriptor;
 import org.apache.beam.sdk.values.ValueWithRecordId;
 import org.apache.beam.sdk.values.ValueWithRecordId.StripIdsDoFn;
 import org.apache.beam.sdk.values.ValueWithRecordId.ValueWithRecordIdCoder;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.Cache;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.CacheBuilder;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.RemovalCause;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.RemovalListener;
 import org.checkerframework.checker.nullness.qual.Nullable;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
@@ -439,6 +444,8 @@ public class Read {
     private static final Logger LOG = 
LoggerFactory.getLogger(UnboundedSourceAsSDFWrapperFn.class);
     private static final int DEFAULT_BUNDLE_FINALIZATION_LIMIT_MINS = 10;
     private final Coder<CheckpointT> checkpointCoder;
+    // Started readers left behind by trySplit() so the tracker of the residual restriction can
+    // resume from them instead of creating a new reader. Keys come from the restriction coder's
+    // structural value of (source, checkpoint) with the watermark ignored. Built in setUp().
+    private Cache<Object, UnboundedReader<OutputT>> cachedReaders;
+    // Coder for the restriction type, resolved in setUp() and used to derive reader-cache keys.
+    private Coder<UnboundedSourceRestriction<OutputT, CheckpointT>> 
restrictionCoder;
 
     private UnboundedSourceAsSDFWrapperFn(Coder<CheckpointT> checkpointCoder) {
       this.checkpointCoder = checkpointCoder;
@@ -450,6 +457,27 @@ public class Read {
       return UnboundedSourceRestriction.create(element, null, 
BoundedWindow.TIMESTAMP_MIN_VALUE);
     }
 
+    /**
+     * Resolves the restriction coder and creates this instance's cache of started readers that
+     * residual restrictions can resume from.
+     */
+    @Setup
+    public void setUp() throws Exception {
+      restrictionCoder = restrictionCoder();
+      // A reader taken back out of the cache by tryClaim() is removed via Cache#invalidate, which
+      // reports RemovalCause.EXPLICIT; that reader is live again and must stay open. Every other
+      // removal (expiry, size eviction, or a put() that REPLACES an older reader stored under the
+      // same key) drops the only reference to that reader, so it is closed here. Checking
+      // wasEvicted() alone would miss REPLACED and leak the replaced reader. This is safe because a
+      // reader is never in the cache twice: tryClaim() invalidates the entry when it takes a reader,
+      // and trySplit() swaps its own reader for an empty one right after caching it.
+      RemovalListener<Object, UnboundedReader<OutputT>> closeDroppedReader =
+          removalNotification -> {
+            if (removalNotification.getCause() == RemovalCause.EXPLICIT) {
+              return;
+            }
+            UnboundedReader<OutputT> droppedReader = removalNotification.getValue();
+            if (droppedReader == null) {
+              return;
+            }
+            try {
+              droppedReader.close();
+            } catch (IOException e) {
+              LOG.warn("Failed to close UnboundedReader.", e);
+            }
+          };
+      cachedReaders =
+          CacheBuilder.newBuilder()
+              .expireAfterWrite(1, TimeUnit.MINUTES)
+              .maximumSize(100)
+              .removalListener(closeDroppedReader)
+              .build();
+    }
+
     @SplitRestriction
     public void splitRestriction(
         @Restriction UnboundedSourceRestriction<OutputT, CheckpointT> 
restriction,
@@ -488,7 +516,10 @@ public class Read {
         restrictionTracker(
             @Restriction UnboundedSourceRestriction<OutputT, CheckpointT> 
restriction,
             PipelineOptions pipelineOptions) {
-      return new UnboundedSourceAsSDFRestrictionTracker(restriction, 
pipelineOptions);
+      checkNotNull(restrictionCoder);
+      checkNotNull(cachedReaders);
+      return new UnboundedSourceAsSDFRestrictionTracker(
+          restriction, pipelineOptions, cachedReaders, restrictionCoder);
     }
 
     @ProcessElement
@@ -756,22 +787,47 @@ public class Read {
       private final PipelineOptions pipelineOptions;
       private UnboundedSource.UnboundedReader<OutputT> currentReader;
       private boolean readerHasBeenStarted;
+      // Reader cache owned by the enclosing DoFn (passed in from restrictionTracker()). tryClaim()
+      // takes a started reader out of it; trySplit() stores the still-open reader under the
+      // residual restriction's key instead of closing it.
+      private Cache<Object, UnboundedReader<OutputT>> cachedReaders;
+      // Used by createCacheKey() to turn (source, checkpoint) into a cache key.
+      private Coder<UnboundedSourceRestriction<OutputT, CheckpointT>> 
restrictionCoder;
 
+      /**
+       * Creates a tracker for {@code initialRestriction}.
+       *
+       * @param initialRestriction the restriction this tracker claims positions in
+       * @param pipelineOptions options passed to {@code createReader} when no cached reader exists
+       * @param cachedReaders DoFn-level cache consulted by tryClaim() and filled by trySplit()
+       * @param restrictionCoder coder whose structural value is used as the cache key
+       */
       UnboundedSourceAsSDFRestrictionTracker(
           UnboundedSourceRestriction<OutputT, CheckpointT> initialRestriction,
-          PipelineOptions pipelineOptions) {
+          PipelineOptions pipelineOptions,
+          Cache<Object, UnboundedReader<OutputT>> cachedReaders,
+          Coder<UnboundedSourceRestriction<OutputT, CheckpointT>> 
restrictionCoder) {
         this.initialRestriction = initialRestriction;
         this.pipelineOptions = pipelineOptions;
+        this.cachedReaders = cachedReaders;
+        this.restrictionCoder = restrictionCoder;
+      }
+
+      /**
+       * Builds the key under which a started reader for ({@code source}, {@code checkpoint}) is
+       * cached.
+       *
+       * <p>The watermark plays no role in whether a reader can be reused, so it is pinned to a
+       * constant and does not take part in the key.
+       */
+      private Object createCacheKey(
+          UnboundedSource<OutputT, CheckpointT> source, CheckpointT checkpoint) {
+        checkNotNull(restrictionCoder);
+        UnboundedSourceRestriction<OutputT, CheckpointT> watermarkAgnosticRestriction =
+            UnboundedSourceRestriction.create(
+                source, checkpoint, BoundedWindow.TIMESTAMP_MIN_VALUE);
+        return restrictionCoder.structuralValue(watermarkAgnosticRestriction);
       }
 
       @Override
       public boolean tryClaim(UnboundedSourceValue<OutputT>[] position) {
         try {
           if (currentReader == null) {
-            currentReader =
-                initialRestriction
-                    .getSource()
-                    .createReader(pipelineOptions, 
initialRestriction.getCheckpoint());
+            Object cacheKey =
+                createCacheKey(initialRestriction.getSource(), 
initialRestriction.getCheckpoint());
+            currentReader = cachedReaders.getIfPresent(cacheKey);
+            if (currentReader == null) {
+              currentReader =
+                  initialRestriction
+                      .getSource()
+                      .createReader(pipelineOptions, 
initialRestriction.getCheckpoint());
+            } else {
+              // If the reader is from cache, then we know that the reader has 
been started.
+              // We also remove this cache entry to avoid eviction.
+              readerHasBeenStarted = true;
+              cachedReaders.invalidate(cacheKey);
+            }
           }
           if (currentReader instanceof 
EmptyUnboundedSource.EmptyUnboundedReader) {
             return false;
@@ -804,17 +860,6 @@ public class Read {
         }
       }
 
-      @Override
-      protected void finalize() throws Throwable {
-        if (currentReader != null) {
-          try {
-            currentReader.close();
-          } catch (IOException e) {
-            LOG.error("Failed to close UnboundedReader due to failure 
processing bundle.", e);
-          }
-        }
-      }
-
       /** The value is invalid if {@link #tryClaim} has ever thrown an 
exception. */
       @Override
       public UnboundedSourceRestriction<OutputT, CheckpointT> 
currentRestriction() {
@@ -858,14 +903,17 @@ public class Read {
                 UnboundedSourceRestriction.create(
                     EmptyUnboundedSource.INSTANCE, null, 
BoundedWindow.TIMESTAMP_MAX_VALUE),
                 currentRestriction);
-        try {
-          currentReader.close();
-        } catch (IOException e) {
-          LOG.warn("Failed to close UnboundedReader.", e);
-        } finally {
-          currentReader =
-              EmptyUnboundedSource.INSTANCE.createReader(null, 
currentRestriction.getCheckpoint());
+
+        if (!(currentReader instanceof 
EmptyUnboundedSource.EmptyUnboundedReader)) {
+          // We only put the reader into the cache when we know it possibly 
will be reused by
+          // residuals.
+          cachedReaders.put(
+              createCacheKey(currentRestriction.getSource(), 
currentRestriction.getCheckpoint()),
+              currentReader);
         }
+
+        currentReader =
+            EmptyUnboundedSource.INSTANCE.createReader(null, 
currentRestriction.getCheckpoint());
         return result;
       }
 
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipeline.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipeline.java
index f613d6b..581c5f1 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipeline.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipeline.java
@@ -334,6 +334,51 @@ public class TestPipeline extends Pipeline implements TestRule {
     return run(getOptions());
   }
 
+  /**
+   * Runs this {@link TestPipeline} with additional command-line pipeline option args.
+   *
+   * <p>This is useful when using {@link PipelineOptions#as(Class)} directly would introduce a
+   * circular dependency.
+   *
+   * <p>Options are built much like {@link #testingPipelineOptions}. If the system property named by
+   * {@code PROPERTY_BEAM_TEST_PIPELINE_OPTIONS} is set, it is parsed into an argument list and
+   * {@code additionalArgs} are appended to it. Otherwise the options are built from {@code
+   * additionalArgs} alone, and the {@code PROPERTY_USE_DEFAULT_DUMMY_RUNNER} property may still
+   * switch the runner to {@link CrashingRunner}.
+   *
+   * @param additionalArgs extra {@code --name=value} arguments applied on top of the test options
+   * @return the result of {@link #run(PipelineOptions)}
+   * @throws RuntimeException wrapping the {@link IOException} if the system property cannot be
+   *     parsed
+   */
+  public PipelineResult runWithAdditionalOptionArgs(List<String> additionalArgs) {
+    try {
+      @Nullable
+      String beamTestPipelineOptions = System.getProperty(PROPERTY_BEAM_TEST_PIPELINE_OPTIONS);
+      PipelineOptions options;
+      if (Strings.isNullOrEmpty(beamTestPipelineOptions)) {
+        // No test options were provided: the caller's args must still be applied, otherwise this
+        // method would silently behave like run().
+        options = PipelineOptionsFactory.fromArgs(additionalArgs.toArray(new String[0])).create();
+        // With no provided options, check whether a dummy runner should be used.
+        String useDefaultDummy = System.getProperty(PROPERTY_USE_DEFAULT_DUMMY_RUNNER);
+        if (!Strings.isNullOrEmpty(useDefaultDummy) && Boolean.valueOf(useDefaultDummy)) {
+          options.setRunner(CrashingRunner.class);
+        }
+      } else {
+        List<String> args = MAPPER.readValue(beamTestPipelineOptions, List.class);
+        args.addAll(additionalArgs);
+        options =
+            PipelineOptionsFactory.fromArgs(args.toArray(new String[0]))
+                .as(TestPipelineOptions.class);
+      }
+      options.setStableUniqueNames(CheckEnabled.ERROR);
+
+      FileSystems.setDefaultPipelineOptions(options);
+      return run(options);
+    } catch (IOException e) {
+      throw new RuntimeException(
+          "Unable to instantiate test options from system property "
+              + PROPERTY_BEAM_TEST_PIPELINE_OPTIONS
+              + ":"
+              + System.getProperty(PROPERTY_BEAM_TEST_PIPELINE_OPTIONS),
+          e);
+    }
+  }
+
   /** Like {@link #run} but with the given potentially modified options. */
   @Override
   public PipelineResult run(PipelineOptions options) {
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/ReadTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/ReadTest.java
index 8a1c51e..5f77b07 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/ReadTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/ReadTest.java
@@ -24,17 +24,32 @@ import static org.hamcrest.MatcherAssert.assertThat;
 import java.io.IOException;
 import java.io.Serializable;
 import java.util.List;
+import java.util.NoSuchElementException;
+import java.util.stream.Collectors;
+import java.util.stream.LongStream;
+import org.apache.beam.sdk.coders.AvroCoder;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.io.CountingSource.CounterMark;
 import org.apache.beam.sdk.io.UnboundedSource.CheckpointMark;
+import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader;
 import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.testing.NeedsRunner;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.display.DisplayData;
 import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.values.PCollection;
+import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 import org.checkerframework.checker.nullness.qual.Nullable;
 import org.joda.time.Duration;
+import org.joda.time.Instant;
 import org.junit.Rule;
 import org.junit.Test;
+import org.junit.experimental.categories.Category;
 import org.junit.rules.ExpectedException;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -47,6 +62,7 @@ import org.junit.runners.JUnit4;
 })
 public class ReadTest implements Serializable {
   @Rule public transient ExpectedException thrown = ExpectedException.none();
+  @Rule public final transient TestPipeline pipeline = TestPipeline.create();
 
   @Test
   public void testInstantiationOfBoundedSourceAsSDFWrapper() {
@@ -109,6 +125,23 @@ public class ReadTest implements Serializable {
     assertThat(unboundedDisplayData, hasDisplayItem("maxReadTime", 
maxReadTime));
   }
 
+  /**
+   * Reads from a source that refuses to create readers from a checkpoint, so the pipeline only
+   * succeeds if every resumed restriction reuses the reader cached by the SDF wrapper.
+   */
+  @Test
+  @Category(NeedsRunner.class)
+  public void testUnboundedSdfWrapperCacheStartedReaders() throws Exception {
+    long numElements = 1000L;
+    List<Long> expected =
+        LongStream.rangeClosed(1L, numElements).boxed().collect(Collectors.toList());
+
+    PCollection<Long> input =
+        pipeline.apply(Read.from(new ExpectCacheUnboundedSource(numElements)));
+    PAssert.that(input).containsInAnyOrder(expected);
+
+    // A single worker thread keeps every residual on the same DoFn instance, so the reader it
+    // cached can actually be found again. DirectOptions cannot be referenced from this module
+    // (circular dependency), hence the raw command-line arg.
+    List<String> singleThreadArgs = ImmutableList.of("--targetParallelism=1");
+    pipeline.runWithAdditionalOptionArgs(singleThreadArgs).waitUntilFinish();
+  }
+
   private abstract static class CustomBoundedSource extends 
BoundedSource<String> {
     @Override
     public List<? extends BoundedSource<String>> split(
@@ -139,6 +172,103 @@ public class ReadTest implements Serializable {
 
   private static class SerializableBoundedSource extends CustomBoundedSource {}
 
+  /**
+   * Unsplittable {@link UnboundedSource} read by {@link ExpectCacheReader}. Creating a reader from
+   * a non-null checkpoint fails, which makes any run that does not reuse cached readers fail.
+   */
+  private static class ExpectCacheUnboundedSource extends UnboundedSource<Long, CounterMark> {
+
+    private final long numElements;
+
+    ExpectCacheUnboundedSource(long numElements) {
+      this.numElements = numElements;
+    }
+
+    @Override
+    public Coder<Long> getOutputCoder() {
+      return VarLongCoder.of();
+    }
+
+    @Override
+    public Coder<CounterMark> getCheckpointMarkCoder() {
+      return AvroCoder.of(CounterMark.class);
+    }
+
+    /** Never splits: the single source is returned as its only split. */
+    @Override
+    public List<? extends UnboundedSource<Long, CounterMark>> split(
+        int desiredNumSplits, PipelineOptions options) throws Exception {
+      return ImmutableList.of(this);
+    }
+
+    /** Readers may only be created without a checkpoint; resumed readers must come from cache. */
+    @Override
+    public UnboundedReader<Long> createReader(
+        PipelineOptions options, @Nullable CounterMark checkpointMark) throws IOException {
+      if (checkpointMark == null) {
+        return new ExpectCacheReader(this, null);
+      }
+      throw new IOException("The reader should be retrieved from cache instead of a new one");
+    }
+  }
+
+  /**
+   * Reader for {@link ExpectCacheUnboundedSource} that counts upward from the checkpoint (or from
+   * zero), emitting {@code 1..numElements}. The watermark is {@link
+   * BoundedWindow#TIMESTAMP_MIN_VALUE} until the counter passes the end, then {@link
+   * BoundedWindow#TIMESTAMP_MAX_VALUE}. Each element is stamped with the current watermark.
+   */
+  private static class ExpectCacheReader extends UnboundedReader<Long> {
+    private final ExpectCacheUnboundedSource source;
+    private long current;
+
+    ExpectCacheReader(ExpectCacheUnboundedSource source, CounterMark checkpointMark) {
+      this.source = source;
+      this.current = checkpointMark == null ? 0L : checkpointMark.getLastEmitted();
+    }
+
+    /** True once the counter has moved past the last element of the source. */
+    private boolean isExhausted() {
+      return current > source.numElements;
+    }
+
+    @Override
+    public boolean start() throws IOException {
+      return advance();
+    }
+
+    @Override
+    public boolean advance() throws IOException {
+      current++;
+      return !isExhausted();
+    }
+
+    @Override
+    public Long getCurrent() throws NoSuchElementException {
+      return current;
+    }
+
+    @Override
+    public Instant getCurrentTimestamp() throws NoSuchElementException {
+      return getWatermark();
+    }
+
+    @Override
+    public Instant getWatermark() {
+      return isExhausted() ? BoundedWindow.TIMESTAMP_MAX_VALUE : BoundedWindow.TIMESTAMP_MIN_VALUE;
+    }
+
+    /** Returns {@code null} until at least one element has been read. */
+    @Override
+    public CheckpointMark getCheckpointMark() {
+      return current > 0 ? new CounterMark(current, BoundedWindow.TIMESTAMP_MIN_VALUE) : null;
+    }
+
+    @Override
+    public UnboundedSource<Long, ?> getCurrentSource() {
+      return source;
+    }
+
+    @Override
+    public void close() throws IOException {}
+  }
+
   private abstract static class CustomUnboundedSource
       extends UnboundedSource<String, NoOpCheckpointMark> {
     @Override

Reply via email to