This is an automated email from the ASF dual-hosted git repository.

yhu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 72027e11d1f Fix DebeziumIO resuming from worker restart (#37689)
72027e11d1f is described below

commit 72027e11d1f2f613e53302394534b204182cc32c
Author: Yi Hu <[email protected]>
AuthorDate: Wed Feb 25 17:49:01 2026 -0500

    Fix DebeziumIO resuming from worker restart (#37689)
    
    * Fix DebeziumIO resuming from worker restart
    
    * Move startTime recording into setup to fix NPE in restarted worker
    
    * Fix DebeziumIO poll loop not exiting when record list isn't empty
    
    * Make pollTimeout configurable
    
    * Include first poll in stopWatch
    
    * Fix pipeline terminate when poll returns null, which is valid per spec
    
    * Adjust MaxNumRecord logic after previous fix. Previously, if there are
      in total N=MaxNumRecord records, the pipeline would not finish until an
      (N+1)th record appeared. The test relied on a null poll return to actually finish.
---
 sdks/java/io/debezium/src/README.md                |   7 +-
 .../org/apache/beam/io/debezium/DebeziumIO.java    |  28 ++--
 .../beam/io/debezium/KafkaSourceConsumerFn.java    | 151 +++++++++++----------
 .../io/debezium/KafkaSourceConsumerFnTest.java     |  89 ++++++++++--
 .../apache/beam/io/debezium/OffsetTrackerTest.java |  25 +---
 5 files changed, 174 insertions(+), 126 deletions(-)

diff --git a/sdks/java/io/debezium/src/README.md 
b/sdks/java/io/debezium/src/README.md
index e56ac370b70..53521321885 100644
--- a/sdks/java/io/debezium/src/README.md
+++ b/sdks/java/io/debezium/src/README.md
@@ -155,12 +155,7 @@ There are two ways of initializing KSC:
 *  Restricted by number of records
 *  Restricted by amount of time (minutes)
 
-By default, DebeziumIO initializes it with the former, though user may choose 
the latter by setting the amount of minutes as a parameter:
-
-|Function|Param|Description|
-|-|-|-|
-|`KafkaSourceConsumerFn(connectorClass, recordMapper, maxRecords)`|_Class, 
SourceRecordMapper, Int_|Restrict run by number of records (Default).|
-|`KafkaSourceConsumerFn(connectorClass, recordMapper, timeToRun)`|_Class, 
SourceRecordMapper, Long_|Restrict run by amount of time (in minutes).|
+By default, DebeziumIO initializes it with the former, though the user may choose 
the latter by setting the amount of minutes as a parameter of the DebeziumIO.Read 
transform.
 
 ### Requirements and Supported versions
 
diff --git 
a/sdks/java/io/debezium/src/main/java/org/apache/beam/io/debezium/DebeziumIO.java
 
b/sdks/java/io/debezium/src/main/java/org/apache/beam/io/debezium/DebeziumIO.java
index b38c035adf2..ebf91a4a095 100644
--- 
a/sdks/java/io/debezium/src/main/java/org/apache/beam/io/debezium/DebeziumIO.java
+++ 
b/sdks/java/io/debezium/src/main/java/org/apache/beam/io/debezium/DebeziumIO.java
@@ -63,11 +63,6 @@ import org.slf4j.LoggerFactory;
  *
  * <h3>Usage example</h3>
  *
- * <p>Support is currently experimental. One of the known issues is that the 
connector does not
- * preserve the offset on a worker crash or restart, causing it to retrieve 
all the data from the
- * beginning again. See <a 
href="https://github.com/apache/beam/issues/28248";>Issue #28248</a> for
- * details.
- *
  * <p>Connect to a Debezium - MySQL database and run a Pipeline
  *
  * <pre>
@@ -147,6 +142,8 @@ public class DebeziumIO {
 
     abstract @Nullable Long getMaxTimeToRun();
 
+    abstract @Nullable Long getPollingTimeout();
+
     abstract @Nullable Coder<T> getCoder();
 
     abstract Builder<T> toBuilder();
@@ -163,6 +160,8 @@ public class DebeziumIO {
 
       abstract Builder<T> setMaxTimeToRun(Long miliseconds);
 
+      abstract Builder<T> setPollingTimeout(Long miliseconds);
+
       abstract Read<T> build();
     }
 
@@ -222,12 +221,18 @@ public class DebeziumIO {
       return toBuilder().setMaxTimeToRun(miliseconds).build();
     }
 
+    /**
+     * Sets the timeout in milliseconds for consumer polling request in the 
{@link
+     * KafkaSourceConsumerFn}. A lower timeout optimizes for latency. Increase 
the timeout if the
+     * consumer is not fetching any records. The default is 1000 milliseconds.
+     */
+    public Read<T> withPollingTimeout(Long miliseconds) {
+      return toBuilder().setPollingTimeout(miliseconds).build();
+    }
+
     protected Schema getRecordSchema() {
       KafkaSourceConsumerFn<T> fn =
-          new KafkaSourceConsumerFn<>(
-              getConnectorConfiguration().getConnectorClass().get(),
-              getFormatFunction(),
-              getMaxNumberOfRecords());
+          new 
KafkaSourceConsumerFn<>(getConnectorConfiguration().getConnectorClass().get(), 
this);
       fn.register(
           new KafkaSourceConsumerFn.OffsetTracker(
               new KafkaSourceConsumerFn.OffsetHolder(null, null, 0)));
@@ -267,10 +272,7 @@ public class DebeziumIO {
           .apply(
               ParDo.of(
                   new KafkaSourceConsumerFn<>(
-                      getConnectorConfiguration().getConnectorClass().get(),
-                      getFormatFunction(),
-                      getMaxNumberOfRecords(),
-                      getMaxTimeToRun())))
+                      getConnectorConfiguration().getConnectorClass().get(), 
this)))
           .setCoder(getCoder());
     }
   }
diff --git 
a/sdks/java/io/debezium/src/main/java/org/apache/beam/io/debezium/KafkaSourceConsumerFn.java
 
b/sdks/java/io/debezium/src/main/java/org/apache/beam/io/debezium/KafkaSourceConsumerFn.java
index 00d7e6ac741..fb4c2f21458 100644
--- 
a/sdks/java/io/debezium/src/main/java/org/apache/beam/io/debezium/KafkaSourceConsumerFn.java
+++ 
b/sdks/java/io/debezium/src/main/java/org/apache/beam/io/debezium/KafkaSourceConsumerFn.java
@@ -29,6 +29,7 @@ import io.debezium.relational.history.SchemaHistoryException;
 import java.io.IOException;
 import java.io.Serializable;
 import java.lang.reflect.InvocationTargetException;
+import java.time.Duration;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
@@ -48,6 +49,8 @@ import 
org.apache.beam.sdk.transforms.splittabledofn.SplitResult;
 import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimator;
 import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Stopwatch;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Streams;
@@ -60,7 +63,6 @@ import org.apache.kafka.connect.source.SourceTaskContext;
 import org.apache.kafka.connect.storage.OffsetStorageReader;
 import org.checkerframework.checker.nullness.qual.Nullable;
 import org.joda.time.DateTime;
-import org.joda.time.Duration;
 import org.joda.time.Instant;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -90,54 +92,37 @@ import org.slf4j.LoggerFactory;
 public class KafkaSourceConsumerFn<T> extends DoFn<Map<String, String>, T> {
   private static final Logger LOG = 
LoggerFactory.getLogger(KafkaSourceConsumerFn.class);
   public static final String BEAM_INSTANCE_PROPERTY = "beam.parent.instance";
+  private static final Long DEFAULT_POLLING_TIMEOUT = 1000L;
 
   private final Class<? extends SourceConnector> connectorClass;
+  private final DebeziumIO.Read<T> spec;
   private final SourceRecordMapper<T> fn;
+  private final Long pollingTimeOut;
 
-  private final Long millisecondsToRun;
-  private final Integer maxRecords;
-
-  private static DateTime startTime;
+  private transient DateTime startTime;
   private static final Map<String, RestrictionTracker<OffsetHolder, 
Map<String, Object>>>
       restrictionTrackers = new ConcurrentHashMap<>();
 
-  /**
-   * Initializes the SDF with a time limit.
-   *
-   * @param connectorClass Supported Debezium connector class
-   * @param fn a SourceRecordMapper
-   * @param maxRecords Maximum number of records to fetch before finishing.
-   * @param millisecondsToRun Maximum time to run (in milliseconds)
-   */
-  @SuppressWarnings("unchecked")
-  KafkaSourceConsumerFn(
-      Class<?> connectorClass,
-      SourceRecordMapper<T> fn,
-      Integer maxRecords,
-      Long millisecondsToRun) {
-    this.connectorClass = (Class<? extends SourceConnector>) connectorClass;
-    this.fn = fn;
-    this.maxRecords = maxRecords;
-    this.millisecondsToRun = millisecondsToRun;
-  }
-
   /**
    * Initializes the SDF to be run indefinitely.
    *
    * @param connectorClass Supported Debezium connector class
-   * @param fn a SourceRecordMapper
-   * @param maxRecords Maximum number of records to fetch before finishing.
+   * @param spec a DebeziumIO.Read transform
    */
-  KafkaSourceConsumerFn(Class<?> connectorClass, SourceRecordMapper<T> fn, 
Integer maxRecords) {
-    this(connectorClass, fn, maxRecords, null);
+  KafkaSourceConsumerFn(Class<?> connectorClass, DebeziumIO.Read<T> spec) {
+    // this(connectorClass, fn, maxRecords, null);
+    this.connectorClass = (Class<? extends SourceConnector>) connectorClass;
+    this.spec = spec;
+    this.fn = spec.getFormatFunction();
+    this.pollingTimeOut =
+        MoreObjects.firstNonNull(spec.getPollingTimeout(), 
DEFAULT_POLLING_TIMEOUT);
   }
 
   @SuppressFBWarnings("ST_WRITE_TO_STATIC_FROM_INSTANCE_METHOD")
   @GetInitialRestriction
   public OffsetHolder getInitialRestriction(@Element Map<String, String> 
unused)
       throws IOException {
-    KafkaSourceConsumerFn.startTime = new DateTime();
-    return new OffsetHolder(null, null, null, this.maxRecords, 
this.millisecondsToRun);
+    return new OffsetHolder(null, null, null, spec.getMaxNumberOfRecords(), 
spec.getMaxTimeToRun());
   }
 
   @NewTracker
@@ -211,6 +196,11 @@ public class KafkaSourceConsumerFn<T> extends 
DoFn<Map<String, String>, T> {
     return timestamp;
   }
 
+  @Setup
+  public void setup() {
+    startTime = DateTime.now();
+  }
+
   /**
    * Process the retrieved element and format it for output. Update all pending
    *
@@ -222,39 +212,61 @@ public class KafkaSourceConsumerFn<T> extends 
DoFn<Map<String, String>, T> {
    *     continue processing after 1 second. Otherwise, if we've reached a 
limit of elements, to
    *     stop processing.
    */
-  @DoFn.ProcessElement
+  @ProcessElement
   public ProcessContinuation process(
       @Element Map<String, String> element,
       RestrictionTracker<OffsetHolder, Map<String, Object>> tracker,
-      OutputReceiver<T> receiver)
-      throws Exception {
+      OutputReceiver<T> receiver) {
+
+    if (spec.getMaxNumberOfRecords() != null
+        && tracker.currentRestriction().fetchedRecords != null
+        && tracker.currentRestriction().fetchedRecords >= 
spec.getMaxNumberOfRecords()) {
+      return ProcessContinuation.stop();
+    }
+
     Map<String, String> configuration = new HashMap<>(element);
 
     // Adding the current restriction to the class object to be found by the 
database history
     register(tracker);
     configuration.put(BEAM_INSTANCE_PROPERTY, this.getHashCode());
 
-    SourceConnector connector = 
connectorClass.getDeclaredConstructor().newInstance();
-    connector.start(configuration);
-
-    SourceTask task = (SourceTask) 
connector.taskClass().getDeclaredConstructor().newInstance();
+    SourceConnector connector;
+    SourceTask task;
+    try {
+      connector = connectorClass.getDeclaredConstructor().newInstance();
+      connector.start(configuration);
+      task = (SourceTask) 
connector.taskClass().getDeclaredConstructor().newInstance();
+    } catch (InvocationTargetException
+        | InstantiationException
+        | IllegalAccessException
+        | NoSuchMethodException e) {
+      throw new RuntimeException(e);
+    }
 
+    Duration remainingTimeout = Duration.ofMillis(pollingTimeOut);
     try {
       Map<String, ?> consumerOffset = tracker.currentRestriction().offset;
       LOG.debug("--------- Consumer offset from Debezium Tracker: {}", 
consumerOffset);
 
-      task.initialize(new 
BeamSourceTaskContext(tracker.currentRestriction().offset));
+      task.initialize(new BeamSourceTaskContext(consumerOffset));
       task.start(connector.taskConfigs(1).get(0));
+      final Stopwatch pollTimer = Stopwatch.createUnstarted();
 
-      List<SourceRecord> records = task.poll();
+      while (Duration.ZERO.compareTo(remainingTimeout) < 0) {
+        pollTimer.reset().start();
+        List<SourceRecord> records = task.poll();
 
-      if (records == null) {
-        LOG.debug("-------- Pulled records null");
-        return ProcessContinuation.stop();
-      }
+        try {
+          remainingTimeout = remainingTimeout.minus(pollTimer.elapsed());
+        } catch (ArithmeticException e) {
+          remainingTimeout = Duration.ZERO;
+        }
+
+        if (records == null || records.isEmpty()) {
+          LOG.debug("-------- Pulled records null or empty");
+          break;
+        }
 
-      LOG.debug("-------- {} records found", records.size());
-      while (records != null && !records.isEmpty()) {
         for (SourceRecord record : records) {
           LOG.debug("-------- Record found: {}", record);
 
@@ -272,7 +284,6 @@ public class KafkaSourceConsumerFn<T> extends 
DoFn<Map<String, String>, T> {
           receiver.outputWithTimestamp(json, recordInstant);
         }
         task.commit();
-        records = task.poll();
       }
     } catch (Exception ex) {
       throw new RuntimeException("Error occurred when consuming changes from 
Database. ", ex);
@@ -283,12 +294,14 @@ public class KafkaSourceConsumerFn<T> extends 
DoFn<Map<String, String>, T> {
       task.stop();
     }
 
-    long elapsedTime = System.currentTimeMillis() - 
KafkaSourceConsumerFn.startTime.getMillis();
-    if (millisecondsToRun != null && millisecondsToRun > 0 && elapsedTime >= 
millisecondsToRun) {
-      return ProcessContinuation.stop();
-    } else {
-      return 
ProcessContinuation.resume().withResumeDelay(Duration.standardSeconds(1));
+    if (spec.getMaxTimeToRun() != null && spec.getMaxTimeToRun() > 0) {
+      long elapsedTime = System.currentTimeMillis() - startTime.getMillis();
+      if (elapsedTime >= spec.getMaxTimeToRun()) {
+        return ProcessContinuation.stop();
+      }
     }
+    return ProcessContinuation.resume()
+        
.withResumeDelay(org.joda.time.Duration.millis(remainingTimeout.toMillis()));
   }
 
   public String getHashCode() {
@@ -418,17 +431,8 @@ public class KafkaSourceConsumerFn<T> extends 
DoFn<Map<String, String>, T> {
     /**
      * Overriding {@link #tryClaim} in order to stop fetching records from the 
database.
      *
-     * <p>This works on two different ways:
-     *
-     * <h3>Number of records</h3>
-     *
-     * <p>This is the default behavior. Once the specified number of records 
has been reached, it
-     * will stop fetching them.
-     *
-     * <h3>Time based</h3>
-     *
-     * User may specify the amount of time the connector to be kept alive. 
Please see {@link
-     * KafkaSourceConsumerFn} for more details on this.
+     * <p>If a maximum number of records has been set, once the specified number of 
records has been reached,
+     * it will stop fetching them.
      *
      * @param position Currently not used
      * @return boolean
@@ -436,23 +440,20 @@ public class KafkaSourceConsumerFn<T> extends 
DoFn<Map<String, String>, T> {
     @Override
     public boolean tryClaim(Map<String, Object> position) {
       LOG.debug("-------------- Claiming {} used to have: {}", position, 
restriction.offset);
-      long elapsedTime = System.currentTimeMillis() - startTime.getMillis();
       int fetchedRecords =
-          this.restriction.fetchedRecords == null ? 0 : 
this.restriction.fetchedRecords + 1;
+          this.restriction.fetchedRecords == null ? 0 : 
this.restriction.fetchedRecords;
       LOG.debug("------------Fetched records {} / {}", fetchedRecords, 
this.restriction.maxRecords);
-      LOG.debug(
-          "-------------- Time running: {} / {}", elapsedTime, 
(this.restriction.millisToRun));
       this.restriction.offset = position;
-      this.restriction.fetchedRecords = fetchedRecords;
       LOG.debug("-------------- History: {}", this.restriction.history);
 
-      // If we've reached the maximum number of records OR the maximum time, 
we reject
-      // the attempt to claim.
-      // If we've reached neither, then we continue approve the claim.
-      return (this.restriction.maxRecords == null || fetchedRecords < 
this.restriction.maxRecords)
-          && (this.restriction.millisToRun == null
-              || this.restriction.millisToRun == -1
-              || elapsedTime < this.restriction.millisToRun);
+      // If we've reached the maximum number of records, we reject the attempt 
to claim.
+      // Otherwise, we approve the claim.
+      boolean claimed =
+          (this.restriction.maxRecords == null || fetchedRecords < 
this.restriction.maxRecords);
+      if (claimed) {
+        this.restriction.fetchedRecords = fetchedRecords + 1;
+      }
+      return claimed;
     }
 
     @Override
diff --git 
a/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/KafkaSourceConsumerFnTest.java
 
b/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/KafkaSourceConsumerFnTest.java
index 1df50b5e9ac..354e2589753 100644
--- 
a/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/KafkaSourceConsumerFnTest.java
+++ 
b/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/KafkaSourceConsumerFnTest.java
@@ -18,21 +18,27 @@
 package org.apache.beam.io.debezium;
 
 import com.google.common.testing.EqualsTester;
+import java.io.IOException;
 import java.io.Serializable;
 import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
+import java.util.ListIterator;
 import java.util.Map;
+import javax.annotation.concurrent.NotThreadSafe;
 import org.apache.beam.io.debezium.KafkaSourceConsumerFn.OffsetHolder;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.MapCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestOutputReceiver;
 import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
 import org.apache.beam.sdk.values.PCollection;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
@@ -48,6 +54,7 @@ import org.apache.kafka.connect.source.SourceRecord;
 import org.apache.kafka.connect.source.SourceTask;
 import org.apache.kafka.connect.source.SourceTaskContext;
 import org.checkerframework.checker.nullness.qual.Nullable;
+import org.junit.After;
 import org.junit.Assert;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -56,6 +63,19 @@ import org.junit.runners.JUnit4;
 @RunWith(JUnit4.class)
 public class KafkaSourceConsumerFnTest implements Serializable {
 
+  static <T> DebeziumIO.Read<T> getSpec(SourceRecordMapper<T> fn, Integer 
maxRecords) {
+    DebeziumIO.Read<T> transform = DebeziumIO.<T>read().withFormatFunction(fn);
+    if (maxRecords > 0) {
+      transform = transform.withMaxNumberOfRecords(maxRecords);
+    }
+    return transform;
+  }
+
+  @After
+  public void cleanUp() {
+    CounterTask.resetCountTask();
+  }
+
   @Test
   public void testKafkaSourceConsumerFn() {
     Map<String, String> config =
@@ -76,9 +96,10 @@ public class KafkaSourceConsumerFnTest implements 
Serializable {
                 ParDo.of(
                     new KafkaSourceConsumerFn<>(
                         CounterSourceConnector.class,
-                        sourceRecord ->
-                            ((Struct) 
sourceRecord.value()).getInt64("value").intValue(),
-                        10)))
+                        getSpec(
+                            sourceRecord ->
+                                ((Struct) 
sourceRecord.value()).getInt64("value").intValue(),
+                            10))))
             .setCoder(VarIntCoder.of());
 
     PAssert.that(counts).containsInAnyOrder(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
@@ -104,8 +125,10 @@ public class KafkaSourceConsumerFnTest implements 
Serializable {
             ParDo.of(
                 new KafkaSourceConsumerFn<>(
                     CounterSourceConnector.class,
-                    sourceRecord -> ((Struct) 
sourceRecord.value()).getInt64("value").intValue(),
-                    1)))
+                    getSpec(
+                        sourceRecord ->
+                            ((Struct) 
sourceRecord.value()).getInt64("value").intValue(),
+                        1))))
         .setCoder(VarIntCoder.of());
 
     pipeline.run().waitUntilFinish();
@@ -159,6 +182,36 @@ public class KafkaSourceConsumerFnTest implements 
Serializable {
             null));
     tester.testEquals();
   }
+
+  @Test(timeout = 2000)
+  public void testMaxTimeToRun() throws IOException {
+    KafkaSourceConsumerFn<Integer> kafkaSourceConsumerFn =
+        new KafkaSourceConsumerFn<>(
+            CounterSourceConnector.class,
+            KafkaSourceConsumerFnTest.getSpec(
+                    sourceRecord -> ((Struct) 
sourceRecord.value()).getInt64("value").intValue(), 0)
+                .withPollingTimeout(100L)
+                .withMaxTimeToRun(500L)); // Run for 0.5 s
+    kafkaSourceConsumerFn.setup();
+    OffsetHolder initialRestriction = 
kafkaSourceConsumerFn.getInitialRestriction(null);
+    RestrictionTracker<OffsetHolder, Map<String, Object>> tracker =
+        kafkaSourceConsumerFn.newTracker(initialRestriction);
+    Map<String, String> config =
+        ImmutableMap.of("from", "1", "delay", "0.4", "sleep", "1", "topic", 
"any");
+    TestOutputReceiver<Integer> receiver = new TestOutputReceiver<>();
+    while (true) {
+      DoFn.ProcessContinuation continuation =
+          kafkaSourceConsumerFn.process(config, tracker, receiver);
+      if (continuation == DoFn.ProcessContinuation.stop()) {
+        break;
+      }
+    }
+    // Check results are in order
+    ListIterator<Integer> it = receiver.getOutputs().listIterator();
+    while (it.hasNext()) {
+      Assert.assertEquals(it.nextIndex(), it.next() - 1);
+    }
+  }
 }
 
 class CounterSourceConnector extends SourceConnector {
@@ -173,9 +226,15 @@ class CounterSourceConnector extends SourceConnector {
     protected static ConfigDef configDef() {
       return new ConfigDef()
           .define("from", ConfigDef.Type.INT, ConfigDef.Importance.HIGH, 
"Number to start from")
-          .define("to", ConfigDef.Type.INT, ConfigDef.Importance.HIGH, "Number 
to go to")
+          .define("to", ConfigDef.Type.INT, -1, ConfigDef.Importance.HIGH, 
"Number to go to")
           .define(
               "delay", ConfigDef.Type.DOUBLE, ConfigDef.Importance.HIGH, "Time 
between each event")
+          .define(
+              "sleep",
+              ConfigDef.Type.INT,
+              0,
+              ConfigDef.Importance.MEDIUM,
+              "Millis to sleep in each poll")
           .define(
               "topic",
               ConfigDef.Type.STRING,
@@ -205,8 +264,9 @@ class CounterSourceConnector extends SourceConnector {
     return Collections.singletonList(
         ImmutableMap.of(
             "from", this.connectorConfig.props.get("from"),
-            "to", this.connectorConfig.props.get("to"),
+            "to", this.connectorConfig.props.getOrDefault("to", "-1"),
             "delay", this.connectorConfig.props.get("delay"),
+            "sleep", this.connectorConfig.props.getOrDefault("sleep", "0"),
             "topic", this.connectorConfig.props.get("topic")));
   }
 
@@ -224,11 +284,13 @@ class CounterSourceConnector extends SourceConnector {
   }
 }
 
+@NotThreadSafe
 class CounterTask extends SourceTask {
   private static int countStopTasks = 0;
   private String topic = "";
   private Integer from = 0;
   private Integer to = 0;
+  private Integer sleep = 0;
   private Double delay = 0.0;
 
   private Long start = System.currentTimeMillis();
@@ -266,8 +328,9 @@ class CounterTask extends SourceTask {
   public void start(Map<String, String> props) {
     this.topic = props.getOrDefault("topic", "");
     this.from = Integer.parseInt(props.getOrDefault("from", "0"));
-    this.to = Integer.parseInt(props.getOrDefault("to", "0"));
+    this.to = Integer.parseInt(props.getOrDefault("to", "-1"));
     this.delay = Double.parseDouble(props.getOrDefault("delay", "0"));
+    this.sleep = Integer.parseInt(props.getOrDefault("sleep", "0"));
 
     if (this.lastOffset != null) {
       return;
@@ -296,7 +359,7 @@ class CounterTask extends SourceTask {
     Long secondsSinceStart = (callTime - this.start) / 1000;
     Long recordsToOutput = Math.round(Math.floor(secondsSinceStart / 
this.delay));
 
-    while (this.last < this.to) {
+    while (this.to == -1 || this.last < this.to) {
       this.last = this.last + 1;
       Map<String, Integer> sourcePartition = 
Collections.singletonMap(PARTITION_FIELD, 1);
       Map<String, Long> sourceOffset =
@@ -316,7 +379,9 @@ class CounterTask extends SourceTask {
         break;
       }
     }
-
+    if (this.sleep > 0) {
+      Thread.sleep(this.sleep);
+    }
     return records;
   }
 
@@ -328,4 +393,8 @@ class CounterTask extends SourceTask {
   public static int getCountTasks() {
     return CounterTask.countStopTasks;
   }
+
+  public static void resetCountTask() {
+    CounterTask.countStopTasks = 0;
+  }
 }
diff --git 
a/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/OffsetTrackerTest.java
 
b/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/OffsetTrackerTest.java
index dc4338ac048..c4ddfbef4f2 100644
--- 
a/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/OffsetTrackerTest.java
+++ 
b/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/OffsetTrackerTest.java
@@ -37,7 +37,9 @@ public class OffsetTrackerTest implements Serializable {
     Map<String, Object> position = new HashMap<>();
     KafkaSourceConsumerFn<String> kafkaSourceConsumerFn =
         new KafkaSourceConsumerFn<String>(
-            MySqlConnector.class, new 
SourceRecordJson.SourceRecordJsonMapper(), maxNumRecords);
+            MySqlConnector.class,
+            KafkaSourceConsumerFnTest.getSpec(
+                new SourceRecordJson.SourceRecordJsonMapper(), maxNumRecords));
     KafkaSourceConsumerFn.OffsetHolder restriction =
         kafkaSourceConsumerFn.getInitialRestriction(new HashMap<>());
     KafkaSourceConsumerFn.OffsetTracker tracker =
@@ -48,25 +50,4 @@ public class OffsetTrackerTest implements Serializable {
     }
     assertFalse("OffsetTracker should stop", tracker.tryClaim(position));
   }
-
-  @Test
-  public void testRestrictByAmountOfTime() throws IOException, 
InterruptedException {
-    Map<String, Object> position = new HashMap<>();
-    KafkaSourceConsumerFn<String> kafkaSourceConsumerFn =
-        new KafkaSourceConsumerFn<String>(
-            MySqlConnector.class,
-            new SourceRecordJson.SourceRecordJsonMapper(),
-            100000,
-            500L); // Run for 500 ms
-    KafkaSourceConsumerFn.OffsetHolder restriction =
-        kafkaSourceConsumerFn.getInitialRestriction(new HashMap<>());
-    KafkaSourceConsumerFn.OffsetTracker tracker =
-        new KafkaSourceConsumerFn.OffsetTracker(restriction);
-
-    assertTrue(tracker.tryClaim(position));
-
-    Thread.sleep(1000); // Sleep for a whole 2 seconds
-
-    assertFalse(tracker.tryClaim(position));
-  }
 }

Reply via email to