This is an automated email from the ASF dual-hosted git repository.

mmack pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 35596b7050c [AWS SQS] Support strict expiration of SQS batches when writing (#27484)
35596b7050c is described below

commit 35596b7050ca3f947792021ae735c469cb6b3198
Author: Moritz Mack <mm...@talend.com>
AuthorDate: Mon Aug 14 09:05:53 2023 +0200

    [AWS SQS] Support strict expiration of SQS batches when writing (#27484)
---
 .../org/apache/beam/sdk/io/aws2/sqs/SqsIO.java     | 295 +++++++++++++++------
 .../sdk/io/aws2/sqs/SqsIOWriteBatchesTest.java     |  87 ++++++
 2 files changed, 298 insertions(+), 84 deletions(-)

diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsIO.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsIO.java
index db918aa680c..f7f767ab85e 100644
--- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsIO.java
+++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsIO.java
@@ -18,6 +18,7 @@
 package org.apache.beam.sdk.io.aws2.sqs;
 
 import static java.util.Collections.EMPTY_LIST;
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
 import static 
org.apache.beam.sdk.io.aws2.common.ClientBuilderFactory.buildClient;
 import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
 import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
@@ -27,10 +28,19 @@ import com.google.auto.value.AutoValue;
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.ConcurrentModificationException;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.CancellationException;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ScheduledFuture;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.BiConsumer;
 import java.util.function.BiFunction;
 import java.util.function.Consumer;
@@ -61,6 +71,7 @@ import org.apache.beam.sdk.values.TupleTag;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
+import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
 import org.checkerframework.checker.nullness.qual.NonNull;
 import org.checkerframework.checker.nullness.qual.Nullable;
 import org.checkerframework.dataflow.qual.Pure;
@@ -152,6 +163,7 @@ public class SqsIO {
         .concurrentRequests(WriteBatches.DEFAULT_CONCURRENCY)
         .batchSize(WriteBatches.MAX_BATCH_SIZE)
         .batchTimeout(WriteBatches.DEFAULT_BATCH_TIMEOUT)
+        .strictTimeouts(false)
         .build();
   }
 
@@ -289,6 +301,8 @@ public class SqsIO {
 
     abstract @Pure Duration batchTimeout();
 
+    abstract @Pure boolean strictTimeouts();
+
     abstract @Pure int batchSize();
 
     abstract @Pure ClientConfiguration clientConfiguration();
@@ -311,6 +325,8 @@ public class SqsIO {
 
       abstract Builder<T> batchTimeout(Duration duration);
 
+      abstract Builder<T> strictTimeouts(boolean strict);
+
       abstract Builder<T> batchSize(int batchSize);
 
       abstract Builder<T> clientConfiguration(ClientConfiguration config);
@@ -363,10 +379,20 @@ public class SqsIO {
     /**
      * The duration to accumulate records before timing out, default is 3 secs.
      *
-     * <p>Timeouts will be checked upon arrival of new messages.
+     * <p>By default timeouts will be checked upon arrival of records.
      */
     public WriteBatches<T> withBatchTimeout(Duration timeout) {
-      return builder().batchTimeout(timeout).build();
+      return withBatchTimeout(timeout, false);
+    }
+
+    /**
+     * The duration to accumulate records before timing out, default is 3 secs.
+     *
+     * <p>By default timeouts will be checked upon arrival of records. If 
using {@code strict}
+     * enforcement, timeouts will be check by a separate thread.
+     */
+    public WriteBatches<T> withBatchTimeout(Duration timeout, boolean strict) {
+      return builder().batchTimeout(timeout).strictTimeouts(strict).build();
     }
 
     /** Dynamic record based destination to write to. */
@@ -546,12 +572,18 @@ public class SqsIO {
     }
 
     private static class BatchHandler<T> implements AutoCloseable {
+      private static final int CHECKS_PER_TIMEOUT_PERIOD = 5;
+      public static final int EXPIRATION_CHECK_TIMEOUT_SECS = 3;
+
       private final WriteBatches<T> spec;
       private final SqsAsyncClient sqs;
       private final Batches batches;
       private final EntryMapperFn<T> entryMapper;
       private final AsyncBatchWriteHandler<SendMessageBatchRequestEntry, 
BatchResultErrorEntry>
           handler;
+      private final @Nullable ScheduledExecutorService scheduler;
+
+      private @MonotonicNonNull ScheduledFuture<?> expirationCheck = null;
 
       BatchHandler(WriteBatches<T> spec, EntryMapperFn<T> entryMapper, 
AwsOptions options) {
         this.spec = spec;
@@ -567,8 +599,10 @@ public class SqsIO {
                 error -> error.code(),
                 record -> record.id(),
                 error -> error.id());
+        this.scheduler =
+            spec.strictTimeouts() ? 
Executors.newSingleThreadScheduledExecutor() : null;
         if (spec.queueUrl() != null) {
-          this.batches = new Single(spec.queueUrl());
+          this.batches = new Single();
         } else if (spec.dynamicDestination() != null) {
           this.batches = new Dynamic(spec.dynamicDestination());
         } else {
@@ -585,6 +619,13 @@ public class SqsIO {
 
       public void startBundle() {
         handler.reset();
+        if (scheduler != null && spec.strictTimeouts()) {
+          long timeout = spec.batchTimeout().getMillis();
+          long period = timeout / CHECKS_PER_TIMEOUT_PERIOD;
+          expirationCheck =
+              scheduler.scheduleWithFixedDelay(
+                  () -> batches.submitExpired(false), timeout, period, 
MILLISECONDS);
+        }
       }
 
       public void process(T msg) {
@@ -592,18 +633,21 @@ public class SqsIO {
         Batch batch = batches.getLocked(msg);
         batch.add(entry);
         if (batch.size() >= spec.batchSize() || batch.isExpired()) {
-          writeEntries(batch, true);
+          submitEntries(batch, true);
         } else {
           checkState(batch.lock(false)); // unlock to continue writing to batch
         }
 
-        // check timeouts synchronously on arrival of new messages
-        batches.writeExpired(true);
+        if (scheduler == null) {
+          // check for expired batches synchronously
+          batches.submitExpired(true);
+        }
       }
 
-      private void writeEntries(Batch batch, boolean throwPendingFailures) {
+      /** Submit entries of a {@link Batch} to the async write handler. */
+      private void submitEntries(Batch batch, boolean throwFailures) {
         try {
-          handler.batchWrite(batch.queue, batch.getAndClear(), 
throwPendingFailures);
+          handler.batchWrite(batch.queue, batch.getAndClose(), throwFailures);
         } catch (RuntimeException e) {
           throw e;
         } catch (Throwable e) {
@@ -612,32 +656,54 @@ public class SqsIO {
       }
 
       public void finishBundle() throws Throwable {
-        batches.writeAll();
+        if (expirationCheck != null) {
+          expirationCheck.cancel(false);
+          while (true) {
+            try {
+              expirationCheck.get(EXPIRATION_CHECK_TIMEOUT_SECS, 
TimeUnit.SECONDS);
+            } catch (TimeoutException e) {
+              LOG.warn("Waiting for timeout check to complete");
+            } catch (CancellationException e) {
+              break; // scheduled checks completed after cancellation
+            }
+          }
+        }
+        // safe to write remaining batches without risking to encounter locked 
ones
+        checkState(batches.submitAll());
         handler.waitForCompletion();
       }
 
       @Override
       public void close() throws Exception {
         sqs.close();
+        if (scheduler != null) {
+          scheduler.shutdown();
+        }
       }
 
       /**
        * Batch(es) of a single fixed or several dynamic queues.
        *
-       * <p>{@link #getLocked} is meant to support atomic writes from multiple 
threads if using an
-       * appropriate thread-safe implementation. This is necessary to later 
support strict timeouts
-       * (see below).
+       * <p>A {@link Batch} can only ever be modified from the single runner 
thread.
        *
-       * <p>For simplicity, check for expired messages after appending to a 
batch. For strict
-       * enforcement of timeouts, {@link #writeExpired} would have to be 
periodically called using a
-       * scheduler and requires also a thread-safe impl of {@link 
Batch#lock(boolean)}.
+       * <p>In case of strict timeouts, a batch may be submitted to the write 
handler by periodic
+       * expiration checks using a scheduler. Otherwise, and by default, this 
is done after
+       * appending to a batch. {@link Batch#lock(boolean)} prevents concurrent 
access to a batch
+       * between threads. Once a batch was locked by an expiration check, it 
must always be
+       * submitted to the write handler.
        */
+      @NotThreadSafe
       private abstract class Batches {
         private int nextId = 0; // only ever used from one "runner" thread
 
         abstract int maxBatches();
 
-        /** Next batch entry id is guaranteed to be unique for all open 
batches. */
+        /**
+         * Next batch entry id is guaranteed to be unique for all open batches.
+         *
+         * <p>This method is not thread-safe and may only ever be called from 
the single runner
+         * thread.
+         */
         String nextId() {
           if (nextId >= (spec.batchSize() * maxBatches())) {
             nextId = 0;
@@ -645,24 +711,40 @@ public class SqsIO {
           return Integer.toString(nextId++);
         }
 
-        /** Get existing or new locked batch that can be written to. */
+        /**
+         * Get an existing or new locked batch to append new messages.
+         *
+         * <p>This method is not thread-safe and may only ever be called from 
a single runner
+         * thread. If this encounters a locked batch, it assumes the {@link 
Batch} is currently
+         * written to SQS and creates a new one.
+         */
         abstract Batch getLocked(T record);
 
-        /** Write all remaining batches (that can be locked). */
-        abstract void writeAll();
-
-        /** Write all expired batches (that can be locked). */
-        abstract void writeExpired(boolean throwPendingFailures);
-
-        /** Create a new locked batch that is ready for writing. */
-        Batch createLocked(String queue) {
-          return new Batch(queue, spec.batchSize(), spec.batchTimeout());
-        }
-
-        /** Write a batch if it can be locked. */
-        protected boolean writeLocked(Batch batch, boolean 
throwPendingFailures) {
-          if (batch.lock(true)) {
-            writeEntries(batch, throwPendingFailures);
+        /**
+         * Submit all remaining batches (that can be locked) to the write 
handler.
+         *
+         * @return {@code true} if successful for all batches.
+         */
+        abstract boolean submitAll();
+
+        /**
+         * Submit all expired batches (that can be locked) to the write 
handler.
+         *
+         * <p>This is the only method that may be invoked from a thread other 
than the runner
+         * thread.
+         */
+        abstract void submitExpired(boolean throwFailures);
+
+        /**
+         * Submit a batch to the write handler if it can be locked.
+         *
+         * @return {@code true} if successful (or closed).
+         */
+        protected boolean lockAndSubmit(Batch batch, boolean throwFailures) {
+          if (batch.isClosed()) {
+            return true; // nothing to submit
+          } else if (batch.lock(true)) {
+            submitEntries(batch, throwFailures);
             return true;
           }
           return false;
@@ -672,11 +754,7 @@ public class SqsIO {
       /** Batch of a single, fixed queue. */
       @NotThreadSafe
       private class Single extends Batches {
-        private Batch batch;
-
-        Single(String queue) {
-          this.batch = new Batch(queue, EMPTY_LIST, Batch.NEVER); // locked
-        }
+        private @Nullable Batch batch;
 
         @Override
         int maxBatches() {
@@ -685,18 +763,21 @@ public class SqsIO {
 
         @Override
         Batch getLocked(T record) {
-          return batch.lock(true) ? batch : (batch = 
createLocked(batch.queue));
+          if (batch == null || !batch.lock(true)) {
+            batch = Batch.createLocked(checkStateNotNull(spec.queueUrl()), 
spec);
+          }
+          return batch;
         }
 
         @Override
-        void writeAll() {
-          writeLocked(batch, true);
+        boolean submitAll() {
+          return batch == null || lockAndSubmit(batch, true);
         }
 
         @Override
-        void writeExpired(boolean throwPendingFailures) {
-          if (batch.isExpired()) {
-            writeLocked(batch, throwPendingFailures);
+        void submitExpired(boolean throwFailures) {
+          if (batch != null && batch.isExpired()) {
+            lockAndSubmit(batch, throwFailures);
           }
         }
       }
@@ -709,8 +790,9 @@ public class SqsIO {
             (queue, batch) -> batch != null && batch.lock(true) ? batch : 
createLocked(queue);
 
         private final Map<@NonNull String, Batch> batches = new HashMap<>();
+        private final AtomicBoolean submitExpiredRunning = new 
AtomicBoolean(false);
+        private final AtomicReference<Instant> nextTimeout = new 
AtomicReference<>(Batch.NEVER);
         private final DynamicDestination<T> destination;
-        private Instant nextTimeout = Batch.NEVER;
 
         Dynamic(DynamicDestination<T> destination) {
           this.destination = destination;
@@ -727,77 +809,118 @@ public class SqsIO {
         }
 
         @Override
-        void writeAll() {
-          batches.values().forEach(batch -> writeLocked(batch, true));
+        boolean submitAll() {
+          AtomicBoolean res = new AtomicBoolean(true);
+          batches.values().forEach(batch -> res.compareAndSet(true, 
lockAndSubmit(batch, true)));
           batches.clear();
-          nextTimeout = Batch.NEVER;
+          nextTimeout.set(Batch.NEVER);
+          return res.get();
         }
 
-        private void writeExpired(Batch batch) {
-          if (!batch.isExpired() || !writeLocked(batch, true)) {
-            // find next timeout for remaining, unwritten batches
-            if (batch.timeout.isBefore(nextTimeout)) {
-              nextTimeout = batch.timeout;
-            }
+        private void updateNextTimeout(Batch batch) {
+          Instant prev;
+          do {
+            prev = nextTimeout.get();
+          } while (batch.expirationTime.isBefore(prev)
+              && !nextTimeout.compareAndSet(prev, batch.expirationTime));
+        }
+
+        private void submitExpired(Batch batch, boolean throwFailures) {
+          if (!batch.isClosed() && (!batch.isExpired() || 
!lockAndSubmit(batch, throwFailures))) {
+            updateNextTimeout(batch);
           }
         }
 
         @Override
-        void writeExpired(boolean throwPendingFailures) {
-          if (nextTimeout.isBeforeNow()) {
-            nextTimeout = Batch.NEVER;
-            batches.values().forEach(this::writeExpired);
+        void submitExpired(boolean throwFailures) {
+          Instant timeout = nextTimeout.get();
+          if (timeout.isBeforeNow()) {
+            // prevent concurrent checks for expired batches
+            if (submitExpiredRunning.compareAndSet(false, true)) {
+              try {
+                nextTimeout.set(Batch.NEVER);
+                batches.values().forEach(b -> submitExpired(b, throwFailures));
+              } catch (ConcurrentModificationException e) {
+                // Can happen rarely when adding a new dynamic destination and 
is expected.
+                // Reset old timeout to repeat check asap.
+                nextTimeout.set(timeout);
+              } finally {
+                submitExpiredRunning.set(false);
+              }
+            }
           }
         }
 
-        @Override
         Batch createLocked(String queue) {
-          Batch batch = super.createLocked(queue);
-          if (batch.timeout.isBefore(nextTimeout)) {
-            nextTimeout = batch.timeout;
-          }
+          Batch batch = Batch.createLocked(queue, spec);
+          updateNextTimeout(batch);
           return batch;
         }
       }
     }
 
-    /**
-     * Batch of entries of a queue.
-     *
-     * <p>Overwrite {@link #lock} with a thread-safe implementation to support 
concurrent usage.
-     */
+    /** Batch of entries of a queue. */
     @NotThreadSafe
-    private static final class Batch {
+    private abstract static class Batch {
       private static final Instant NEVER = 
Instant.ofEpochMilli(Long.MAX_VALUE);
+
       private final String queue;
-      private Instant timeout;
+      private final Instant expirationTime;
       private List<SendMessageBatchRequestEntry> entries;
 
-      Batch(String queue, int size, Duration bufferedTime) {
-        this(queue, new ArrayList<>(size), Instant.now().plus(bufferedTime));
+      static Batch createLocked(String queue, SqsIO.WriteBatches<?> spec) {
+        return spec.strictTimeouts()
+            ? new BatchWithAtomicLock(queue, spec.batchSize(), 
spec.batchTimeout())
+            : new BatchWithNoopLock(queue, spec.batchSize(), 
spec.batchTimeout());
       }
 
-      Batch(String queue, List<SendMessageBatchRequestEntry> entries, Instant 
timeout) {
-        this.queue = queue;
-        this.entries = entries;
-        this.timeout = timeout;
+      /** A {@link Batch} with a noop lock that just rejects un/locking if 
closed. */
+      private static class BatchWithNoopLock extends Batch {
+        BatchWithNoopLock(String queue, int size, Duration timeout) {
+          super(queue, size, timeout);
+        }
+
+        @Override
+        boolean lock(boolean lock) {
+          return !isClosed(); // always un/lock unless closed
+        }
       }
 
-      /** Attempt to un/lock this batch and return if successful. */
-      boolean lock(boolean lock) {
-        // thread unsafe dummy impl that rejects locking batches after 
getAndClear
-        return !NEVER.equals(timeout) || !lock;
+      /** A {@link Batch} supporting atomic locking for concurrent usage. */
+      private static class BatchWithAtomicLock extends Batch {
+        private final AtomicBoolean locked = new AtomicBoolean(true); // 
always lock on creation
+
+        BatchWithAtomicLock(String queue, int size, Duration timeout) {
+          super(queue, size, timeout);
+        }
+
+        @Override
+        boolean lock(boolean lock) {
+          return !isClosed() && locked.compareAndSet(!lock, lock);
+        }
+      }
+
+      private Batch(String queue, int size, Duration timeout) {
+        this.queue = queue;
+        this.entries = new ArrayList<>(size);
+        this.expirationTime = Instant.now().plus(timeout);
       }
 
-      /** Get and clear entries for writing. */
-      List<SendMessageBatchRequestEntry> getAndClear() {
+      /** Attempt to un/lock this batch, if closed this always fails. */
+      abstract boolean lock(boolean lock);
+
+      /**
+       * Get and clear entries for submission to the write handler.
+       *
+       * <p>The batch must be locked and kept locked, it can't be modified 
anymore.
+       */
+      List<SendMessageBatchRequestEntry> getAndClose() {
         List<SendMessageBatchRequestEntry> res = entries;
         entries = EMPTY_LIST;
-        timeout = NEVER;
         return res;
       }
 
-      /** Add entry to this batch. */
+      /** Append entry (only use if locked!). */
       void add(SendMessageBatchRequestEntry entry) {
         entries.add(entry);
       }
@@ -807,7 +930,11 @@ public class SqsIO {
       }
 
       boolean isExpired() {
-        return timeout.isBeforeNow();
+        return expirationTime.isBeforeNow();
+      }
+
+      boolean isClosed() {
+        return entries == EMPTY_LIST;
       }
     }
   }
diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsIOWriteBatchesTest.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsIOWriteBatchesTest.java
index 0dc0719cc47..e92720bfb5a 100644
--- a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsIOWriteBatchesTest.java
+++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsIOWriteBatchesTest.java
@@ -17,6 +17,7 @@
  */
 package org.apache.beam.sdk.io.aws2.sqs;
 
+import static java.lang.Math.sqrt;
 import static java.nio.charset.StandardCharsets.UTF_8;
 import static java.util.concurrent.CompletableFuture.completedFuture;
 import static java.util.concurrent.CompletableFuture.supplyAsync;
@@ -26,14 +27,17 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Pr
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.joda.time.Duration.millis;
+import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoMoreInteractions;
 import static org.mockito.Mockito.when;
 
 import java.util.Arrays;
+import java.util.HashSet;
 import java.util.Map;
 import java.util.Set;
 import java.util.stream.Collectors;
@@ -45,6 +49,7 @@ import org.apache.beam.sdk.io.aws2.common.AsyncBatchWriteHandler;
 import org.apache.beam.sdk.io.aws2.common.ClientConfiguration;
 import org.apache.beam.sdk.io.aws2.common.RetryConfiguration;
 import org.apache.beam.sdk.io.aws2.sqs.SqsIO.WriteBatches;
+import org.apache.beam.sdk.io.aws2.sqs.SqsIO.WriteBatches.DynamicDestination;
 import org.apache.beam.sdk.io.aws2.sqs.SqsIO.WriteBatches.EntryMapperFn;
 import org.apache.beam.sdk.schemas.SchemaRegistry;
 import org.apache.beam.sdk.testing.ExpectedLogs;
@@ -54,6 +59,7 @@ import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.ParDo;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Streams;
+import org.apache.commons.lang3.RandomUtils;
 import org.joda.time.Duration;
 import org.junit.Before;
 import org.junit.Rule;
@@ -263,6 +269,29 @@ public class SqsIOWriteBatchesTest {
     verify(sqs).sendMessageBatch(request("queue", entries[3], entries[4]));
   }
 
+  @Test
+  public void testWriteBatchesWithStrictTimeout() {
+    when(sqs.sendMessageBatch(any(SendMessageBatchRequest.class)))
+        
.thenReturn(completedFuture(SendMessageBatchResponse.builder().build()));
+
+    p.apply(Create.of(5))
+        .apply(ParDo.of(new CreateMessages()))
+        .apply(
+            // simulate delay between messages > batch timeout
+            SqsIO.<String>writeBatches()
+                .withEntryMapper(withDelay(millis(100), SET_MESSAGE_BODY))
+                .withBatchTimeout(millis(150), true)
+                .to("queue"));
+
+    p.run().waitUntilFinish();
+
+    SendMessageBatchRequestEntry[] entries = entries(range(0, 5));
+    // using strict timeouts batches, batches are timed out by a separate 
thread
+    verify(sqs).sendMessageBatch(request("queue", entries[0], entries[1]));
+    verify(sqs).sendMessageBatch(request("queue", entries[2], entries[3]));
+    verify(sqs).sendMessageBatch(request("queue", entries[4]));
+  }
+
   @Test
   public void testWriteBatchesToDynamic() {
     
when(sqs.sendMessageBatch(anyRequest())).thenReturn(completedFuture(SUCCESS));
@@ -315,6 +344,64 @@ public class SqsIOWriteBatchesTest {
     verify(sqs).sendMessageBatch(request("even", entries[4]));
   }
 
+  @Test
+  public void testWriteBatchesToDynamicWithStrictTimeout() {
+    when(sqs.sendMessageBatch(any(SendMessageBatchRequest.class)))
+        
.thenReturn(completedFuture(SendMessageBatchResponse.builder().build()));
+
+    p.apply(Create.of(5))
+        .apply(ParDo.of(new CreateMessages()))
+        .apply(
+            // simulate delay between messages > batch timeout
+            SqsIO.<String>writeBatches()
+                .withEntryMapper(withDelay(millis(100), SET_MESSAGE_BODY))
+                .withBatchTimeout(millis(150), true)
+                .to(msg -> Integer.valueOf(msg) % 2 == 0 ? "even" : "uneven"));
+
+    p.run().waitUntilFinish();
+
+    SendMessageBatchRequestEntry[] entries = entries(range(0, 5));
+    // using strict timeouts batches, batches are timed out by a separate 
thread before any 2nd
+    // entry
+    verify(sqs).sendMessageBatch(request("even", entries[0]));
+    verify(sqs).sendMessageBatch(request("uneven", entries[1]));
+    verify(sqs).sendMessageBatch(request("even", entries[2]));
+    verify(sqs).sendMessageBatch(request("uneven", entries[3]));
+    verify(sqs).sendMessageBatch(request("even", entries[4]));
+  }
+
+  @Test
+  public void testWriteBatchesToDynamicWithStrictTimeoutAtHighVolume() {
+    when(sqs.sendMessageBatch(any(SendMessageBatchRequest.class)))
+        
.thenReturn(completedFuture(SendMessageBatchResponse.builder().build()));
+
+    // Use sqrt to change the rate of newly created dynamic destinations over 
time
+    DynamicDestination<String> dynamicDestination =
+        msg -> String.valueOf(RandomUtils.nextInt(0, (int) (1 + 
sqrt(Integer.valueOf(msg)))));
+
+    p.apply(Create.of(100000))
+        .apply(ParDo.of(new CreateMessages()))
+        .apply(
+            SqsIO.<String>writeBatches()
+                .withEntryMapper(SET_MESSAGE_BODY)
+                .withBatchTimeout(millis(10), true)
+                .to(dynamicDestination));
+
+    p.run().waitUntilFinish();
+
+    ArgumentCaptor<SendMessageBatchRequest> reqCaptor =
+        ArgumentCaptor.forClass(SendMessageBatchRequest.class);
+    verify(sqs, atLeastOnce()).sendMessageBatch(reqCaptor.capture());
+
+    Set<String> capturedMessages = new HashSet<>();
+    for (SendMessageBatchRequest req : reqCaptor.getAllValues()) {
+      for (SendMessageBatchRequestEntry entry : req.entries()) {
+        assertTrue("duplicate message", 
capturedMessages.add(entry.messageBody()));
+      }
+    }
+    assertEquals("Invalid message count", 100000, capturedMessages.size());
+  }
+
   private SendMessageBatchRequest anyRequest() {
     return any();
   }

Reply via email to