This is an automated email from the ASF dual-hosted git repository.
ajothomas pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/samza.git
The following commit(s) were added to refs/heads/master by this push:
new 63c86b5e6 Add new configuration allowing to keep processing when there
are fatal exceptions or timeout (#1708)
63c86b5e6 is described below
commit 63c86b5e661b0a0a9a0b33a0851379b8786ee36e
Author: Haolan Ye <[email protected]>
AuthorDate: Mon Nov 25 10:56:07 2024 -0800
Add new configuration allowing to keep processing when there are fatal
exceptions or timeout (#1708)
---
.../java/org/apache/samza/config/TaskConfig.java | 27 +++
.../org/apache/samza/container/TaskInstance.scala | 57 ++++-
.../samza/container/TaskInstanceMetrics.scala | 6 +-
.../apache/samza/container/TestTaskInstance.scala | 249 +++++++++++++++++++--
4 files changed, 313 insertions(+), 26 deletions(-)
diff --git a/samza-core/src/main/java/org/apache/samza/config/TaskConfig.java
b/samza-core/src/main/java/org/apache/samza/config/TaskConfig.java
index 0f168be18..276f3812f 100644
--- a/samza-core/src/main/java/org/apache/samza/config/TaskConfig.java
+++ b/samza-core/src/main/java/org/apache/samza/config/TaskConfig.java
@@ -65,6 +65,21 @@ public class TaskConfig extends MapConfig {
public static final String COMMIT_TIMEOUT_MS = "task.commit.timeout.ms";
static final long DEFAULT_COMMIT_TIMEOUT_MS =
Duration.ofMinutes(30).toMillis();
+ // Flag to indicate whether to skip commit during failures (exceptions or
timeouts)
+ // The number of allowed successive commit exceptions and timeouts are
controlled by the following two configs.
+ public static final String SKIP_COMMIT_DURING_FAILURES_ENABLED =
"task.commit.skip.commit.during.failures.enabled";
+ private static final boolean DEFAULT_SKIP_COMMIT_DURING_FAILURES_ENABLED =
false;
+
+ // Maximum number of allowed successive commit exceptions.
+ // If the number of successive commit exceptions exceeds this limit, the
task will be shut down.
+ public static final String SKIP_COMMIT_EXCEPTION_MAX_LIMIT =
"task.commit.skip.commit.exception.max.limit";
+ private static final int DEFAULT_SKIP_COMMIT_EXCEPTION_MAX_LIMIT = 5;
+
+ // Maximum number of allowed successive commit timeouts.
+ // If the number of successive commit timeouts exceeds this limit, the task
will be shut down.
+ public static final String SKIP_COMMIT_TIMEOUT_MAX_LIMIT =
"task.commit.skip.commit.timeout.max.limit";
+ private static final int DEFAULT_SKIP_COMMIT_TIMEOUT_MAX_LIMIT = 2;
+
// how long to wait for a clean shutdown
public static final String TASK_SHUTDOWN_MS = "task.shutdown.ms";
static final long DEFAULT_TASK_SHUTDOWN_MS = 30000L;
@@ -418,4 +433,16 @@ public class TaskConfig extends MapConfig {
public double getWatermarkQuorumSizePercentage() {
return getDouble(WATERMARK_QUORUM_SIZE_PERCENTAGE,
DEFAULT_WATERMARK_QUORUM_SIZE_PERCENTAGE);
}
+
+ public boolean getSkipCommitDuringFailuresEnabled() {
+ return getBoolean(SKIP_COMMIT_DURING_FAILURES_ENABLED,
DEFAULT_SKIP_COMMIT_DURING_FAILURES_ENABLED);
+ }
+
+ public int getSkipCommitExceptionMaxLimit() {
+ return getInt(SKIP_COMMIT_EXCEPTION_MAX_LIMIT,
DEFAULT_SKIP_COMMIT_EXCEPTION_MAX_LIMIT);
+ }
+
+ public int getSkipCommitTimeoutMaxLimit() {
+ return getInt(SKIP_COMMIT_TIMEOUT_MAX_LIMIT,
DEFAULT_SKIP_COMMIT_TIMEOUT_MAX_LIMIT);
+ }
}
diff --git
a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
index 70d9ca380..f5d13106f 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
@@ -38,7 +38,7 @@ import
org.apache.samza.util.ScalaJavaUtil.JavaOptionals.toRichOptional
import org.apache.samza.util.{Logging, ReflectionUtil, ScalaJavaUtil}
import java.util
-import java.util.concurrent.atomic.AtomicReference
+import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
import java.util.function.BiConsumer
import java.util.function.Function
import scala.collection.JavaConversions._
@@ -133,8 +133,13 @@ class TaskInstance(
val checkpointWriteVersions = new
TaskConfig(config).getCheckpointWriteVersions
@volatile var lastCommitStartTimeMs = System.currentTimeMillis()
+ val commitExceptionCounter = new AtomicInteger(0)
+ val commitTimeoutCounter = new AtomicInteger(0)
val commitMaxDelayMs = taskConfig.getCommitMaxDelayMs
val commitTimeoutMs = taskConfig.getCommitTimeoutMs
+ val skipCommitDuringFailureEnabled =
taskConfig.getSkipCommitDuringFailuresEnabled
+ val skipCommitExceptionMaxLimit = taskConfig.getSkipCommitExceptionMaxLimit
+ val skipCommitTimeoutMaxLimit = taskConfig.getSkipCommitTimeoutMaxLimit
val commitInProgress = new Semaphore(1)
val commitException = new AtomicReference[Exception]()
@@ -312,10 +317,22 @@ class TaskInstance(
val commitStartNs = System.nanoTime()
// first check if there were any unrecoverable errors during the async
stage of the pending commit
- // and if so, shut down the container.
+ // If there is unrecoverable error, increment the metric and the counter.
+ // Shutdown the container in the following scenarios:
+ // 1. skipCommitDuringFailureEnabled is not enabled
+ // 2. skipCommitDuringFailureEnabled is enabled but the number of
exceptions exceeded the max count
+ // Otherwise, ignore the exception.
if (commitException.get() != null) {
- throw new SamzaException("Unrecoverable error during pending commit for
taskName: %s." format taskName,
- commitException.get())
+ metrics.commitExceptions.inc()
+ commitExceptionCounter.incrementAndGet()
+ if (!skipCommitDuringFailureEnabled || commitExceptionCounter.get() >
skipCommitExceptionMaxLimit) {
+ throw new SamzaException("Unrecoverable error during pending commit
for taskName: %s. Exception Counter: %s"
+ format (taskName, commitExceptionCounter.get()),
commitException.get())
+ } else {
+ warn("Ignored the commit failure for taskName %s. Exception Counter:
%s."
+ format (taskName, commitExceptionCounter.get()),
commitException.get())
+ commitException.set(null)
+ }
}
// if no commit is in progress for this task, continue with this commit.
@@ -328,7 +345,7 @@ class TaskInstance(
if (timeSinceLastCommit < commitMaxDelayMs) {
info("Skipping commit for taskName: %s since another commit is in
progress. " +
"%s ms have elapsed since the pending commit started." format
(taskName, timeSinceLastCommit))
- metrics.commitsSkipped.set(metrics.commitsSkipped.getValue + 1)
+ metrics.commitsSkipped.inc()
return
} else {
warn("Blocking processing for taskName: %s until in-flight commit is
complete. " +
@@ -336,13 +353,28 @@ class TaskInstance(
"which is greater than the max allowed commit delay: %s."
format (taskName, timeSinceLastCommit, commitMaxDelayMs))
+ // Wait for the previous commit to complete within the timeout.
+ // If it doesn't complete within the timeout, increment metric and the
counter.
+ // Shutdown the container in the following scenarios:
+ // 1. skipCommitDuringFailureEnabled is not enabled
+ // 2. skipCommitDuringFailureEnabled is enabled but the number of
timeouts exceeded the max count
+ // Otherwise, ignore the timeout.
if (!commitInProgress.tryAcquire(commitTimeoutMs,
TimeUnit.MILLISECONDS)) {
val timeSinceLastCommit = System.currentTimeMillis() -
lastCommitStartTimeMs
- metrics.commitsTimedOut.set(metrics.commitsTimedOut.getValue + 1)
- throw new SamzaException("Timeout waiting for pending commit for
taskName: %s to finish. " +
- "%s ms have elapsed since the pending commit started. Max allowed
commit delay is %s ms " +
- "and commit timeout beyond that is %s ms" format (taskName,
timeSinceLastCommit,
- commitMaxDelayMs, commitTimeoutMs))
+ metrics.commitsTimedOut.inc()
+ commitTimeoutCounter.incrementAndGet()
+ if (!skipCommitDuringFailureEnabled || commitTimeoutCounter.get() >
skipCommitTimeoutMaxLimit) {
+ throw new SamzaException("Timeout waiting for pending commit for
taskName: %s to finish. " +
+ "%s ms have elapsed since the pending commit started. Max
allowed commit delay is %s ms " +
+ "and commit timeout beyond that is %s ms. Timeout Counter: %s"
format (taskName, timeSinceLastCommit,
+ commitMaxDelayMs, commitTimeoutMs, commitTimeoutCounter.get()))
+ } else {
+ warn("Ignoring commit timeout for taskName: %s. %s ms have elapsed
since another commit started. " +
+ "Max allowed commit delay is %s ms and commit timeout beyond
that is %s ms. Timeout Counter: %s."
+ format (taskName, timeSinceLastCommit, commitMaxDelayMs,
commitTimeoutMs, commitTimeoutCounter.get()))
+ commitInProgress.release()
+ return
+ }
}
}
}
@@ -426,7 +458,7 @@ class TaskInstance(
}
})
- metrics.lastCommitNs.set(System.nanoTime() - commitStartNs)
+ metrics.lastCommitNs.set(System.nanoTime())
metrics.commitSyncNs.update(System.nanoTime() - commitStartNs)
debug("Finishing sync stage of commit for taskName: %s checkpointId: %s"
format (taskName, checkpointId))
}
@@ -531,8 +563,11 @@ class TaskInstance(
"Saved exception under Caused By.", commitException.get())
}
} else {
+ commitExceptionCounter.set(0)
+ commitTimeoutCounter.set(0)
metrics.commitAsyncNs.update(System.nanoTime() - asyncStageStartNs)
metrics.commitNs.update(System.nanoTime() - commitStartNs)
+ metrics.lastCommitAsyncTimestamp.set(System.nanoTime())
}
} finally {
// release the permit indicating that previous commit is complete.
diff --git
a/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala
b/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala
index 54d366525..02674fb7e 100644
---
a/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala
+++
b/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala
@@ -38,10 +38,12 @@ class TaskInstanceMetrics(
val pendingMessages = newGauge("pending-messages", 0)
val messagesInFlight = newGauge("messages-in-flight", 0)
val asyncCallbackCompleted = newCounter("async-callback-complete-calls")
- val commitsTimedOut = newGauge("commits-timed-out", 0)
- val commitsSkipped = newGauge("commits-skipped", 0)
+ val commitsTimedOut = newCounter("commits-timed-out")
+ val commitsSkipped = newCounter("commits-skipped")
+ val commitExceptions = newCounter("commit-exceptions")
val commitNs = newTimer("commit-ns")
val lastCommitNs = newGauge("last-commit-ns", 0L)
+ val lastCommitAsyncTimestamp = newGauge("last-async-commit-timestamp", 0L)
val commitSyncNs = newTimer("commit-sync-ns")
val commitAsyncNs = newTimer("commit-async-ns")
val snapshotNs = newTimer("snapshot-ns")
diff --git
a/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
b/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
index 6afec52e7..ff52b4006 100644
---
a/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
+++
b/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
@@ -277,7 +277,7 @@ class TestTaskInstance extends AssertionsForJUnit with
MockitoSugar {
when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
val cleanUpTimer = mock[Timer]
when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
- val skippedCounter = mock[Gauge[Int]]
+ val skippedCounter = mock[Counter]
when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
val inputOffsets = new util.HashMap[SystemStreamPartition, String]()
inputOffsets.put(SYSTEM_STREAM_PARTITION,"4")
@@ -370,7 +370,7 @@ class TestTaskInstance extends AssertionsForJUnit with
MockitoSugar {
when(this.metrics.commitAsyncNs).thenReturn(commitAsyncTimer)
val uploadTimer = mock[Timer]
when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
- val skippedCounter = mock[Gauge[Int]]
+ val skippedCounter = mock[Counter]
when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
val inputOffsets = Map(SYSTEM_STREAM_PARTITION -> "4").asJava
@@ -431,7 +431,7 @@ class TestTaskInstance extends AssertionsForJUnit with
MockitoSugar {
when(this.metrics.commitAsyncNs).thenReturn(commitAsyncTimer)
val uploadTimer = mock[Timer]
when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
- val skippedCounter = mock[Gauge[Int]]
+ val skippedCounter = mock[Counter]
when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
val lastCommitGauge = mock[Gauge[Long]]
when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge)
@@ -504,10 +504,12 @@ class TestTaskInstance extends AssertionsForJUnit with
MockitoSugar {
when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
val cleanUpTimer = mock[Timer]
when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
- val skippedCounter = mock[Gauge[Int]]
+ val skippedCounter = mock[Counter]
when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
val lastCommitGauge = mock[Gauge[Long]]
when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge)
+ val commitExceptionsGauge = mock[Counter]
+ when(this.metrics.commitExceptions).thenReturn(commitExceptionsGauge)
val inputOffsets = new util.HashMap[SystemStreamPartition, String]()
inputOffsets.put(SYSTEM_STREAM_PARTITION,"4")
@@ -556,10 +558,12 @@ class TestTaskInstance extends AssertionsForJUnit with
MockitoSugar {
when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
val cleanUpTimer = mock[Timer]
when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
- val skippedCounter = mock[Gauge[Int]]
+ val skippedCounter = mock[Counter]
when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
val lastCommitGauge = mock[Gauge[Long]]
when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge)
+ val commitExceptionsGauge = mock[Counter]
+ when(this.metrics.commitExceptions).thenReturn(commitExceptionsGauge)
val inputOffsets = new util.HashMap[SystemStreamPartition, String]()
inputOffsets.put(SYSTEM_STREAM_PARTITION,"4")
@@ -608,10 +612,12 @@ class TestTaskInstance extends AssertionsForJUnit with
MockitoSugar {
when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
val cleanUpTimer = mock[Timer]
when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
- val skippedCounter = mock[Gauge[Int]]
+ val skippedCounter = mock[Counter]
when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
val lastCommitGauge = mock[Gauge[Long]]
when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge)
+ val commitExceptionsGauge = mock[Counter]
+ when(this.metrics.commitExceptions).thenReturn(commitExceptionsGauge)
val inputOffsets = new util.HashMap[SystemStreamPartition, String]()
inputOffsets.put(SYSTEM_STREAM_PARTITION,"4")
@@ -661,10 +667,12 @@ class TestTaskInstance extends AssertionsForJUnit with
MockitoSugar {
when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
val cleanUpTimer = mock[Timer]
when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
- val skippedCounter = mock[Gauge[Int]]
+ val skippedCounter = mock[Counter]
when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
val lastCommitGauge = mock[Gauge[Long]]
when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge)
+ val commitExceptionsGauge = mock[Counter]
+ when(this.metrics.commitExceptions).thenReturn(commitExceptionsGauge)
val inputOffsets = new util.HashMap[SystemStreamPartition, String]()
inputOffsets.put(SYSTEM_STREAM_PARTITION,"4")
@@ -714,10 +722,12 @@ class TestTaskInstance extends AssertionsForJUnit with
MockitoSugar {
when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
val cleanUpTimer = mock[Timer]
when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
- val skippedCounter = mock[Gauge[Int]]
+ val skippedCounter = mock[Counter]
when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
val lastCommitGauge = mock[Gauge[Long]]
when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge)
+ val commitExceptionsGauge = mock[Counter]
+ when(this.metrics.commitExceptions).thenReturn(commitExceptionsGauge)
val inputOffsets = new util.HashMap[SystemStreamPartition, String]()
inputOffsets.put(SYSTEM_STREAM_PARTITION,"4")
@@ -768,7 +778,7 @@ class TestTaskInstance extends AssertionsForJUnit with
MockitoSugar {
when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
val cleanUpTimer = mock[Timer]
when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
- val skippedCounter = mock[Gauge[Int]]
+ val skippedCounter = mock[Counter]
when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
val lastCommitGauge = mock[Gauge[Long]]
when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge)
@@ -828,7 +838,7 @@ class TestTaskInstance extends AssertionsForJUnit with
MockitoSugar {
when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
val cleanUpTimer = mock[Timer]
when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
- val skippedCounter = mock[Gauge[Int]]
+ val skippedCounter = mock[Counter]
when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
val lastCommitGauge = mock[Gauge[Long]]
when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge)
@@ -859,7 +869,7 @@ class TestTaskInstance extends AssertionsForJUnit with
MockitoSugar {
taskInstance.commit
- verify(skippedCounter).set(1)
+ verify(skippedCounter, times(1)).inc()
verify(commitsCounter, times(1)).inc() // should only have been
incremented once on the initial commit
verify(snapshotTimer).update(anyLong())
@@ -884,7 +894,7 @@ class TestTaskInstance extends AssertionsForJUnit with
MockitoSugar {
when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
val cleanUpTimer = mock[Timer]
when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
- val skippedCounter = mock[Gauge[Int]]
+ val skippedCounter = mock[Counter]
when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
val lastCommitGauge = mock[Gauge[Long]]
when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge)
@@ -947,7 +957,7 @@ class TestTaskInstance extends AssertionsForJUnit with
MockitoSugar {
when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
val cleanUpTimer = mock[Timer]
when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
- val skippedCounter = mock[Gauge[Int]]
+ val skippedCounter = mock[Counter]
when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
val lastCommitGauge = mock[Gauge[Long]]
when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge)
@@ -1004,6 +1014,208 @@ class TestTaskInstance extends AssertionsForJUnit with
MockitoSugar {
verify(snapshotTimer, times(2)).update(anyLong())
}
+ @Test
+ def testSkipExceptionFromFirstCommitAndContinueSecondCommit(): Unit = {
+ val commitsCounter = mock[Counter]
+ when(this.metrics.commits).thenReturn(commitsCounter)
+ val snapshotTimer = mock[Timer]
+ when(this.metrics.snapshotNs).thenReturn(snapshotTimer)
+ val uploadTimer = mock[Timer]
+ when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
+ val commitTimer = mock[Timer]
+ when(this.metrics.commitNs).thenReturn(commitTimer)
+ val commitSyncTimer = mock[Timer]
+ when(this.metrics.commitSyncNs).thenReturn(commitSyncTimer)
+ val commitAsyncTimer = mock[Timer]
+ when(this.metrics.commitAsyncNs).thenReturn(commitAsyncTimer)
+ val cleanUpTimer = mock[Timer]
+ when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
+ val skippedCounter = mock[Counter]
+ when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
+ val lastCommitGauge = mock[Gauge[Long]]
+ when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge)
+ val commitExceptionCounter = mock[Counter]
+ when(this.metrics.commitExceptions).thenReturn(commitExceptionCounter)
+
+ val taskConfigsMap = new util.HashMap[String, String]()
+ taskConfigsMap.put("task.commit.ms", "-1")
+ taskConfigsMap.put("task.commit.max.delay.ms", "-1")
+ taskConfigsMap.put("task.commit.timeout.ms", "2000000")
+ // skip commit if exception occurs during the commit
+ taskConfigsMap.put("task.commit.skip.commit.during.failures.enabled",
"true")
+ // should throw exception if second commit exception occurs
+ taskConfigsMap.put("task.commit.skip.commit.exception.max.limit", "1")
+ when(this.jobContext.getConfig).thenReturn(new MapConfig(taskConfigsMap))
+ setupTaskInstance(None, ForkJoinPool.commonPool())
+
+ val inputOffsets = new util.HashMap[SystemStreamPartition, String]()
+ inputOffsets.put(SYSTEM_STREAM_PARTITION, "4")
+ val stateCheckpointMarkers: util.Map[String, String] = new
util.HashMap[String, String]()
+
when(this.offsetManager.getLastProcessedOffsets(TASK_NAME)).thenReturn(inputOffsets)
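+ // Intended to let the second commit proceed without exceptions. NOTE(review): the failing stub registered below uses the same matchers and appears to override this one - confirm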
+ when(this.taskCommitManager.upload(any(), any()))
+ .thenReturn(CompletableFuture.completedFuture(
+
Collections.singletonMap(KafkaStateCheckpointMarker.KAFKA_STATE_BACKEND_FACTORY_NAME,
stateCheckpointMarkers)))
+ // exception during the first commit
+ when(this.taskCommitManager.upload(any(), any()))
+ .thenReturn(FutureUtil.failedFuture[util.Map[String, util.Map[String,
String]]](new RuntimeException))
+
+ // First commit fails but should not throw exception
+ taskInstance.commit
+ verify(commitsCounter).inc()
+ verify(snapshotTimer).update(anyLong())
+ verifyZeroInteractions(uploadTimer)
+ verifyZeroInteractions(commitTimer)
+ verifyZeroInteractions(skippedCounter)
+ waitForCommitExceptionIsSet(100, 5)
+ // Second commit should succeed
+ taskInstance.commit
+ verify(commitsCounter, times(2)).inc() // should only have been
incremented twice - once for each commit
+ verify(commitExceptionCounter).inc()
+ }
+
+ @Test
+ def testCommitThrowsIfAllowSkipCommitButExceptionCountReachMaxLimit(): Unit
= {
+ val commitsCounter = mock[Counter]
+ when(this.metrics.commits).thenReturn(commitsCounter)
+ val snapshotTimer = mock[Timer]
+ when(this.metrics.snapshotNs).thenReturn(snapshotTimer)
+ val uploadTimer = mock[Timer]
+ when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
+ val commitTimer = mock[Timer]
+ when(this.metrics.commitNs).thenReturn(commitTimer)
+ val commitSyncTimer = mock[Timer]
+ when(this.metrics.commitSyncNs).thenReturn(commitSyncTimer)
+ val commitAsyncTimer = mock[Timer]
+ when(this.metrics.commitAsyncNs).thenReturn(commitAsyncTimer)
+ val cleanUpTimer = mock[Timer]
+ when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
+ val skippedCounter = mock[Counter]
+ when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
+ val lastCommitGauge = mock[Gauge[Long]]
+ when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge)
+ val commitExceptionCounter = mock[Counter]
+ when(this.metrics.commitExceptions).thenReturn(commitExceptionCounter)
+
+ val taskConfigsMap = new util.HashMap[String, String]()
+ taskConfigsMap.put("task.commit.ms", "-1")
+ taskConfigsMap.put("task.commit.max.delay.ms", "-1")
+ taskConfigsMap.put("task.commit.timeout.ms", "2000000")
+ // skip commit if exception occurs during the commit
+ taskConfigsMap.put("task.commit.skip.commit.during.failures.enabled",
"true")
+ // should throw exception if second commit exception occurs
+ taskConfigsMap.put("task.commit.skip.commit.exception.max.limit", "1")
+ when(this.jobContext.getConfig).thenReturn(new MapConfig(taskConfigsMap))
+ setupTaskInstance(None, ForkJoinPool.commonPool())
+
+ val inputOffsets = new util.HashMap[SystemStreamPartition, String]()
+ inputOffsets.put(SYSTEM_STREAM_PARTITION, "4")
+
when(this.offsetManager.getLastProcessedOffsets(TASK_NAME)).thenReturn(inputOffsets)
+ // exception for commits
+ when(this.taskCommitManager.upload(any(), any()))
+ .thenReturn(FutureUtil.failedFuture[util.Map[String, util.Map[String,
String]]](new RuntimeException))
+
+ // First commit fails but should not throw exception
+ taskInstance.commit
+ waitForCommitExceptionIsSet(100, 5)
+ // Second commit fails but should not throw exception
+ taskInstance.commit
+ verify(commitExceptionCounter).inc()
+ verify(commitsCounter, times(2)).inc()
+ verify(snapshotTimer, times(2)).update(anyLong())
+ verifyZeroInteractions(uploadTimer)
+ verifyZeroInteractions(commitTimer)
+ verifyZeroInteractions(skippedCounter)
+ waitForCommitExceptionIsSet(100, 5)
+ // third commit should fail as the commit exception counter is greater
than the max limit
+ try {
+ taskInstance.commit
+ fail("Should have thrown an exception if exception count reached the max
limit.")
+ } catch {
+ case e: Exception =>
+ // expected
+ }
+ verify(commitExceptionCounter, times(2)).inc()
+ verify(commitsCounter, times(2)).inc()
+ }
+
+ @Test
+ def testCommitThrowsIfAllowSkipTimeoutButTimeoutCountReachMaxLimit(): Unit =
{
+ val commitsCounter = mock[Counter]
+ when(this.metrics.commits).thenReturn(commitsCounter)
+ val snapshotTimer = mock[Timer]
+ when(this.metrics.snapshotNs).thenReturn(snapshotTimer)
+ val commitTimer = mock[Timer]
+ when(this.metrics.commitNs).thenReturn(commitTimer)
+ val commitSyncTimer = mock[Timer]
+ when(this.metrics.commitSyncNs).thenReturn(commitSyncTimer)
+ val commitAsyncTimer = mock[Timer]
+ when(this.metrics.commitAsyncNs).thenReturn(commitAsyncTimer)
+ val uploadTimer = mock[Timer]
+ when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
+ val cleanUpTimer = mock[Timer]
+ when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
+ val skippedCounter = mock[Counter]
+ when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
+ val commitsTimedOutCounter = mock[Counter]
+ when(this.metrics.commitsTimedOut).thenReturn(commitsTimedOutCounter)
+ val lastCommitGauge = mock[Gauge[Long]]
+ when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge)
+ val commitExceptionCounter = mock[Counter]
+ when(this.metrics.commitExceptions).thenReturn(commitExceptionCounter)
+
+ val inputOffsets = new util.HashMap[SystemStreamPartition, String]()
+ inputOffsets.put(SYSTEM_STREAM_PARTITION,"4")
+ val changelogSSP = new SystemStreamPartition(new SystemStream(SYSTEM_NAME,
"test-changelog-stream"), new Partition(0))
+
+ val stateCheckpointMarkers: util.Map[String, String] = new
util.HashMap[String, String]()
+ val stateCheckpointMarker = KafkaStateCheckpointMarker.serialize(new
KafkaStateCheckpointMarker(changelogSSP, "5"))
+ stateCheckpointMarkers.put("storeName", stateCheckpointMarker)
+
when(this.offsetManager.getLastProcessedOffsets(TASK_NAME)).thenReturn(inputOffsets)
+
+ val snapshotSCMs =
ImmutableMap.of(KafkaStateCheckpointMarker.KAFKA_STATE_BACKEND_FACTORY_NAME,
stateCheckpointMarkers)
+ when(this.taskCommitManager.snapshot(any())).thenReturn(snapshotSCMs)
+ val snapshotSCMFuture: CompletableFuture[util.Map[String, util.Map[String,
String]]] =
+ CompletableFuture.completedFuture(snapshotSCMs)
+
+ when(this.taskCommitManager.upload(any(),
Matchers.eq(snapshotSCMs))).thenReturn(snapshotSCMFuture) // kafka is no-op
+
+ val cleanUpFuture = new CompletableFuture[Void]()
+ when(this.taskCommitManager.cleanUp(any(),
any())).thenReturn(cleanUpFuture)
+
+ // use a separate executor to perform async operations on to test caller
thread blocking behavior
+ val taskConfigsMap = new util.HashMap[String, String]()
+ taskConfigsMap.put("task.commit.ms", "-1")
+ // "block" immediately if previous commit async stage not complete
+ taskConfigsMap.put("task.commit.max.delay.ms", "-1")
+ taskConfigsMap.put("task.commit.timeout.ms", "0") // throw exception
immediately if blocked
+ taskConfigsMap.put("task.commit.skip.commit.during.failures.enabled",
"true")
+ // should throw exception if second commit timeout occurs
+ taskConfigsMap.put("task.commit.skip.commit.timeout.max.limit", "1")
+ when(this.jobContext.getConfig).thenReturn(new MapConfig(taskConfigsMap))
// override default behavior
+
+ setupTaskInstance(None, ForkJoinPool.commonPool())
+
+ taskInstance.commit // async stage will not complete until cleanUpFuture
is completed
+ taskInstance.commit // second commit hits the commit timeout and releases the
semaphore
+
+ verifyZeroInteractions(commitExceptionCounter)
+ verifyZeroInteractions(skippedCounter)
+ verify(commitsTimedOutCounter).inc()
+ verify(commitsCounter, times(1)).inc() // should only have been
incremented once now - second commit was skipped
+ taskInstance.commit // third commit should proceed without any issues and
acquire the semaphore
+ try {
+ taskInstance.commit // fourth commit should throw exception as the
timeout count reached the max limit
+ fail("Should have thrown an exception due to exceeding timeout limit.")
+ } catch {
+ case e: Exception =>
+ // expected
+ }
+ verify(commitsTimedOutCounter, times(2)).inc() // incremented twice
(second and fourth commit)
+ verify(commitsCounter, times(2)).inc() // incremented twice (first and
third commit)
+ cleanUpFuture.complete(null) // just to unblock shared executor
+ }
+
/**
* Given that no application task context factory is provided, then no
lifecycle calls should be made.
@@ -1091,6 +1303,17 @@ class TestTaskInstance extends AssertionsForJUnit with
MockitoSugar {
externalContextOption = Some(this.externalContext), elasticityFactor =
elasticityFactor)
}
+ private def waitForCommitExceptionIsSet(sleepTimeInMs: Int, maxRetry: Int):
Unit = {
+ var retries = 0
+ while (taskInstance.commitException.get() == null && retries < maxRetry) {
+ retries += 1
+ Thread.sleep(sleepTimeInMs)
+ }
+ if (taskInstance.commitException.get() == null) {
+ fail("Should have set the commit exception.")
+ }
+ }
+
/**
* Task type which has all task traits, which can be mocked.
*/