This is an automated email from the ASF dual-hosted git repository.
xtsong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git
The following commit(s) were added to refs/heads/main by this push:
new a4c1f07b [api][runtime][java] Add durable reconcile support with reconciler callables (#600)
a4c1f07b is described below
commit a4c1f07b1ab839d0f68fe5a16253bf131d5e9b57
Author: Joey Tong <[email protected]>
AuthorDate: Tue Apr 7 15:19:29 2026 +0800
[api][runtime][java] Add durable reconcile support with reconciler callables (#600)
---
.../flink/agents/api/context/DurableCallable.java | 29 +
.../flink/agents/api/context/RunnerContext.java | 6 +
.../agents/runtime/actionstate/ActionState.java | 10 +
.../agents/runtime/actionstate/CallResult.java | 108 +++-
.../runtime/context/JavaRunnerContextImpl.java | 39 +-
.../agents/runtime/context/RunnerContextImpl.java | 377 ++++++++++---
.../runtime/actionstate/ActionStateSerdeTest.java | 77 ++-
.../runtime/actionstate/ActionStateTest.java | 14 +
.../agents/runtime/actionstate/CallResultTest.java | 27 +-
.../context/DurableExecutionContextTest.java | 2 +-
...vaRunnerContextImplDurableExecuteAsyncTest.java | 251 +++++++++
.../RunnerContextImplDurableExecuteTest.java | 295 ++++++++++
.../runtime/context/TestDurableCallables.java | 110 ++++
.../operator/ActionExecutionOperatorTest.java | 627 ++++++++++++++++-----
14 files changed, 1713 insertions(+), 259 deletions(-)
diff --git a/api/src/main/java/org/apache/flink/agents/api/context/DurableCallable.java b/api/src/main/java/org/apache/flink/agents/api/context/DurableCallable.java
index fb77a747..73f062b1 100644
--- a/api/src/main/java/org/apache/flink/agents/api/context/DurableCallable.java
+++ b/api/src/main/java/org/apache/flink/agents/api/context/DurableCallable.java
@@ -17,6 +17,10 @@
*/
package org.apache.flink.agents.api.context;
+import javax.annotation.Nullable;
+
+import java.util.concurrent.Callable;
+
/**
* A callable interface for durable execution that requires a stable identifier.
*
@@ -46,4 +50,29 @@ public interface DurableCallable<T> {
* must be JSON-serializable.
*/
T call() throws Exception;
+
+ /**
+ * Returns an optional callable used to reconcile an in-flight durable call during recovery.
+ *
+ * <p>Return {@code null} to disable reconcile for this durable call and fall back to the
+ * existing durable execution behavior. During recovery, the runtime replays a previously
+ * completed durable result when one is available; otherwise it executes the original {@link
+ * #call()}.
+ *
+ * <p>If a reconcile callable is provided, the runtime invokes it only when recovery revisits
+ * this durable call and finds that the original execution result has not yet been persisted.
+ *
+ * <p>The reconcile callable should follow these rules:
+ *
+ * <ul>
+ * <li>Return the result to provide the recovered successful outcome for this durable call.
+ * The runtime persists and replays that recovered result.
+ * <li>Throw an exception if reconcile cannot provide a successful outcome. The exception is
+ * propagated to the caller, and the runtime does not persist a recovered terminal outcome
+ * for this durable call.
+ * </ul>
+ */
+ default @Nullable Callable<T> reconciler() {
+ return null;
+ }
}
diff --git a/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java b/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java
index 0810752a..06cd2b38 100644
--- a/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java
+++ b/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java
@@ -115,6 +115,9 @@ public interface RunnerContext {
* <p>The result will be stored and returned from cache during job recovery. The callable is
* executed synchronously, blocking the operator until completion.
*
+ * <p>If the callable provides a reconcile callable via {@link DurableCallable#reconciler()},
+ * recovery may invoke it for an in-flight durable call.
+ *
* <p>Access to memory and sendEvent are prohibited within the callable.
*/
<T> T durableExecute(DurableCallable<T> callable) throws Exception;
@@ -128,6 +131,9 @@ public interface RunnerContext {
*
* <p>The result will be stored and returned from cache during job recovery.
*
+ * <p>If the callable provides a reconcile callable via {@link DurableCallable#reconciler()},
+ * recovery may invoke it for an in-flight durable call.
+ *
* <p>Access to memory and sendEvent are prohibited within the callable.
*/
<T> T durableExecuteAsync(DurableCallable<T> callable) throws Exception;
diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionState.java b/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionState.java
index 031928ad..9fe551d9 100644
--- a/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionState.java
+++ b/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionState.java
@@ -122,6 +122,16 @@ public class ActionState {
callResults.add(callResult);
}
+ /**
+ * Replaces the call result at the specified index.
+ *
+ * @param index the index to replace
+ * @param callResult the call result to store at the index
+ */
+ public void replaceCallResult(int index, CallResult callResult) {
+ callResults.set(index, callResult);
+ }
+
/**
* Gets the call result at the specified index.
*
diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/CallResult.java b/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/CallResult.java
index cb9c5338..62906d81 100644
--- a/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/CallResult.java
+++ b/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/CallResult.java
@@ -18,6 +18,7 @@
package org.apache.flink.agents.runtime.actionstate;
import com.fasterxml.jackson.annotation.JsonIgnore;
+import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.Arrays;
import java.util.Objects;
@@ -34,6 +35,13 @@ import java.util.Objects;
*/
public class CallResult {
+ /** Persisted status of the durable call. */
+ private enum Status {
+ PENDING,
+ SUCCEEDED,
+ FAILED
+ }
+
/** Function identifier: module+qualname for Python, or method signature for Java. */
private final String functionId;
@@ -46,12 +54,17 @@ public class CallResult {
/** Serialized exception info if the call failed (null if the call succeeded). */
private final byte[] exceptionPayload;
+ /** Persisted status of the durable call. Null indicates legacy state written before status. */
+ @JsonProperty("status")
+ private final Status status;
+
/** Default constructor for deserialization. */
public CallResult() {
this.functionId = null;
this.argsDigest = null;
this.resultPayload = null;
this.exceptionPayload = null;
+ this.status = null;
}
/**
@@ -66,6 +79,7 @@ public class CallResult {
this.argsDigest = argsDigest;
this.resultPayload = resultPayload;
this.exceptionPayload = null;
+ this.status = Status.SUCCEEDED;
}
/**
@@ -78,23 +92,46 @@ public class CallResult {
*/
public CallResult(
String functionId, String argsDigest, byte[] resultPayload, byte[]
exceptionPayload) {
+ this(
+ functionId,
+ argsDigest,
+ resultPayload,
+ exceptionPayload,
+ exceptionPayload == null ? Status.SUCCEEDED : Status.FAILED);
+ }
+
+ /**
+ * Constructs a CallResult with explicit result, exception payloads, and status.
+ *
+ * @param functionId the function identifier
+ * @param argsDigest the digest of serialized arguments
+ * @param resultPayload the serialized return value (null if exception occurred or pending)
+ * @param exceptionPayload the serialized exception (null if call succeeded or pending)
+ * @param status the persisted call status
+ */
+ private CallResult(
+ String functionId,
+ String argsDigest,
+ byte[] resultPayload,
+ byte[] exceptionPayload,
+ Status status) {
this.functionId = functionId;
this.argsDigest = argsDigest;
this.resultPayload = resultPayload;
this.exceptionPayload = exceptionPayload;
+ this.status = status;
}
/**
- * Creates a CallResult for a failed function call.
+ * Creates a CallResult for an in-flight durable call whose terminal result has not yet been
+ * persisted.
*
* @param functionId the function identifier
* @param argsDigest the digest of serialized arguments
- * @param exceptionPayload the serialized exception
- * @return a new CallResult representing a failed call
+ * @return a new CallResult representing a pending call
*/
- public static CallResult ofException(
- String functionId, String argsDigest, byte[] exceptionPayload) {
- return new CallResult(functionId, argsDigest, null, exceptionPayload);
+ public static CallResult pending(String functionId, String argsDigest) {
+ return new CallResult(functionId, argsDigest, null, null,
Status.PENDING);
}
public String getFunctionId() {
@@ -113,6 +150,18 @@ public class CallResult {
return exceptionPayload;
}
+ /**
+ * Validates if this CallResult matches the given function identifier and arguments digest.
+ *
+ * @param functionId the function identifier to match
+ * @param argsDigest the arguments digest to match
+ * @return true if both functionId and argsDigest match, false otherwise
+ */
+ public boolean matches(String functionId, String argsDigest) {
+ return Objects.equals(this.functionId, functionId)
+ && Objects.equals(this.argsDigest, argsDigest);
+ }
+
/**
* Checks if this call result represents a successful execution.
*
@@ -120,19 +169,43 @@ public class CallResult {
*/
@JsonIgnore
public boolean isSuccess() {
- return exceptionPayload == null;
+ return getEffectiveStatus() == Status.SUCCEEDED;
+ }
+
+ /** Checks if this call result represents a failed execution. */
+ @JsonIgnore
+ public boolean isFailure() {
+ return getEffectiveStatus() == Status.FAILED;
+ }
+
+ /** Checks if this call result represents an in-flight execution. */
+ @JsonIgnore
+ public boolean isPending() {
+ return getEffectiveStatus() == Status.PENDING;
}
/**
- * Validates if this CallResult matches the given function identifier and arguments digest.
+ * Creates a CallResult matching legacy persisted data where {@code status} was absent.
*
- * @param functionId the function identifier to match
- * @param argsDigest the arguments digest to match
- * @return true if both functionId and argsDigest match, false otherwise
+ * <p>Used by backward-compatibility tests for legacy serialized state.
*/
- public boolean matches(String functionId, String argsDigest) {
- return Objects.equals(this.functionId, functionId)
- && Objects.equals(this.argsDigest, argsDigest);
+ static CallResult ofNullStatus(
+ String functionId, String argsDigest, byte[] resultPayload, byte[]
exceptionPayload) {
+ return new CallResult(functionId, argsDigest, resultPayload,
exceptionPayload, null);
+ }
+
+ /**
+ * Returns the effective status of the call result.
+ *
+ * <p>For legacy states written before {@code status} existed, the effective status is inferred
+ * from {@code exceptionPayload}.
+ */
+ @JsonIgnore
+ private Status getEffectiveStatus() {
+ if (status != null) {
+ return status;
+ }
+ return exceptionPayload == null ? Status.SUCCEEDED : Status.FAILED;
}
@Override
@@ -147,12 +220,13 @@ public class CallResult {
return Objects.equals(functionId, that.functionId)
&& Objects.equals(argsDigest, that.argsDigest)
&& Arrays.equals(resultPayload, that.resultPayload)
- && Arrays.equals(exceptionPayload, that.exceptionPayload);
+ && Arrays.equals(exceptionPayload, that.exceptionPayload)
+ && status == that.status;
}
@Override
public int hashCode() {
- int result = Objects.hash(functionId, argsDigest);
+ int result = Objects.hash(functionId, argsDigest, status);
result = 31 * result + Arrays.hashCode(resultPayload);
result = 31 * result + Arrays.hashCode(exceptionPayload);
return result;
@@ -171,6 +245,8 @@ public class CallResult {
+ (resultPayload != null ? resultPayload.length + " bytes" :
"null")
+ ", exceptionPayload="
+ (exceptionPayload != null ? exceptionPayload.length + "
bytes" : "null")
+ + ", status="
+ + status
+ '}';
}
}
diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/context/JavaRunnerContextImpl.java b/runtime/src/main/java/org/apache/flink/agents/runtime/context/JavaRunnerContextImpl.java
index ce5f3f11..ae0661ea 100644
--- a/runtime/src/main/java/org/apache/flink/agents/runtime/context/JavaRunnerContextImpl.java
+++ b/runtime/src/main/java/org/apache/flink/agents/runtime/context/JavaRunnerContextImpl.java
@@ -24,6 +24,7 @@ import org.apache.flink.agents.runtime.async.ContinuationActionExecutor;
import org.apache.flink.agents.runtime.async.ContinuationContext;
import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl;
+import java.util.concurrent.Callable;
import java.util.function.Supplier;
/**
@@ -31,7 +32,6 @@ import java.util.function.Supplier;
* execution support.
*/
public class JavaRunnerContextImpl extends RunnerContextImpl {
-
private final ContinuationActionExecutor continuationExecutor;
private ContinuationContext continuationContext;
@@ -60,14 +60,22 @@ public class JavaRunnerContextImpl extends RunnerContextImpl {
@Override
public <T> T durableExecuteAsync(DurableCallable<T> callable) throws
Exception {
- String functionId = callable.getId();
- String argsDigest = "";
-
- java.util.Optional<T> cachedResult =
- tryGetCachedResult(functionId, argsDigest,
callable.getResultClass());
- if (cachedResult.isPresent()) {
- return cachedResult.get();
+ if (durableExecutionContext != null) {
+ Callable<T> reconcileCallable = callable.reconciler();
+ if (reconcileCallable != null) {
+ return durableExecuteAsyncWithReconcile(callable,
reconcileCallable);
+ }
}
+ return durableExecuteCompletionOnly(callable, () ->
executeAsyncCallable(callable));
+ }
+
+ private <T> T durableExecuteAsyncWithReconcile(
+ DurableCallable<T> callable, Callable<T> reconcileCallable) throws
Exception {
+ return durableExecuteWithReconcile(
+ callable, reconcileCallable, () ->
executeAsyncCallable(callable));
+ }
+
+ private <T> T executeAsyncCallable(DurableCallable<T> callable) throws
Exception {
Supplier<T> wrappedSupplier =
() -> {
@@ -85,23 +93,14 @@ public class JavaRunnerContextImpl extends
RunnerContextImpl {
return innerResult;
};
- T result = null;
- Exception originalException = null;
try {
if (continuationExecutor == null || continuationContext == null) {
- result = wrappedSupplier.get();
+ return wrappedSupplier.get();
} else {
- result =
continuationExecutor.executeAsync(continuationContext, wrappedSupplier);
+ return continuationExecutor.executeAsync(continuationContext,
wrappedSupplier);
}
} catch (DurableExecutionRuntimeException e) {
- originalException = (Exception) e.getCause();
- }
-
- recordDurableCompletion(functionId, argsDigest, result,
originalException);
-
- if (originalException != null) {
- throw originalException;
+ throw (Exception) e.getCause();
}
- return result;
}
}
diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java
index dcba1555..f4d28d3b 100644
--- a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java
+++ b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java
@@ -52,7 +52,7 @@ import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
-import java.util.function.Supplier;
+import java.util.concurrent.Callable;
/**
* The implementation class of {@link RunnerContext}, which serves as the
execution context for
@@ -250,49 +250,41 @@ public class RunnerContextImpl implements RunnerContext {
return agentPlan.getActionConfigValue(actionName, key);
}
- protected <T> Optional<T> tryGetCachedResult(
- String functionId, String argsDigest, Class<T> resultClass) throws
Exception {
- Object[] cached = matchNextOrClearSubsequentCallResult(functionId,
argsDigest);
- if (cached != null && (Boolean) cached[0]) {
- byte[] resultPayload = (byte[]) cached[1];
- byte[] exceptionPayload = (byte[]) cached[2];
-
- if (exceptionPayload != null) {
- DurableExecutionException cachedException =
- OBJECT_MAPPER.readValue(exceptionPayload,
DurableExecutionException.class);
- throw cachedException.toException();
- } else if (resultPayload != null) {
- return Optional.of(OBJECT_MAPPER.readValue(resultPayload,
resultClass));
- } else {
- return Optional.of(null);
+ @Override
+ public <T> T durableExecute(DurableCallable<T> callable) throws Exception {
+ if (durableExecutionContext != null) {
+ Callable<T> reconcileCallable = callable.reconciler();
+ if (reconcileCallable != null) {
+ return durableExecuteSyncWithReconcile(callable,
reconcileCallable);
}
}
- return Optional.empty();
+ return durableExecuteCompletionOnly(callable, callable::call);
}
- protected void recordDurableCompletion(
- String functionId, String argsDigest, Object result, Exception
exception)
- throws Exception {
- byte[] resultPayload = null;
- byte[] exceptionPayload = null;
- if (exception != null) {
- exceptionPayload =
- OBJECT_MAPPER.writeValueAsBytes(
-
DurableExecutionException.fromException(exception));
- } else if (result != null) {
- resultPayload = OBJECT_MAPPER.writeValueAsBytes(result);
- }
- recordCallCompletion(functionId, argsDigest, resultPayload,
exceptionPayload);
+ @Override
+ public <T> T durableExecuteAsync(DurableCallable<T> callable) throws
Exception {
+ LOG.debug(
+ "Async durable execution is not supported in
RunnerContextImpl; falling back to durableExecute for {}",
+ callable.getId());
+ return durableExecute(callable);
}
- @Override
- public <T> T durableExecute(DurableCallable<T> callable) throws Exception {
- String functionId = callable.getId();
+ /**
+ * Executes a durable call using the completion-only state machine.
+ *
+ * @param durableCallable durable call that provides the durable execution identity and result
+ * metadata
+ * @param executionCallable concrete execution boundary for the current path, such as direct
+ * sync execution or Java-specific async execution
+ */
+ protected <T> T durableExecuteCompletionOnly(
+ DurableCallable<T> durableCallable, Callable<T> executionCallable)
throws Exception {
+ String functionId = durableCallable.getId();
// argsDigest is empty because DurableCallable encapsulates all
arguments internally
String argsDigest = "";
Optional<T> cachedResult =
- tryGetCachedResult(functionId, argsDigest,
callable.getResultClass());
+ tryGetCachedResult(functionId, argsDigest,
durableCallable.getResultClass());
if (cachedResult.isPresent()) {
return cachedResult.get();
}
@@ -300,7 +292,7 @@ public class RunnerContextImpl implements RunnerContext {
T result = null;
Exception exception = null;
try {
- result = callable.call();
+ result = executionCallable.call();
} catch (Exception e) {
exception = e;
}
@@ -313,54 +305,9 @@ public class RunnerContextImpl implements RunnerContext {
return result;
}
- @Override
- public <T> T durableExecuteAsync(DurableCallable<T> callable) throws
Exception {
- String functionId = callable.getId();
- // argsDigest is empty because DurableCallable encapsulates all
arguments internally
- String argsDigest = "";
-
- Optional<T> cachedResult =
- tryGetCachedResult(functionId, argsDigest,
callable.getResultClass());
- if (cachedResult.isPresent()) {
- return cachedResult.get();
- }
-
- Supplier<T> wrappedSupplier =
- () -> {
- T innerResult = null;
- Exception innerException = null;
- try {
- innerResult = callable.call();
- } catch (Exception e) {
- innerException = e;
- }
-
- if (innerException != null) {
- throw new
DurableExecutionRuntimeException(innerException);
- }
- return innerResult;
- };
-
- T result = null;
- Exception originalException = null;
- try {
- result = wrappedSupplier.get();
- } catch (DurableExecutionRuntimeException e) {
- originalException = (Exception) e.getCause();
- }
-
- recordDurableCompletion(functionId, argsDigest, result,
originalException);
-
- if (originalException != null) {
- throw originalException;
- }
- return result;
- }
-
- protected static class DurableExecutionRuntimeException extends
RuntimeException {
- DurableExecutionRuntimeException(Throwable cause) {
- super(cause);
- }
+ private <T> T durableExecuteSyncWithReconcile(
+ DurableCallable<T> callable, Callable<T> reconcileCallable) throws
Exception {
+ return durableExecuteWithReconcile(callable, reconcileCallable,
callable::call);
}
/** Serializable exception info for durable execution persistence. */
@@ -466,6 +413,163 @@ public class RunnerContextImpl implements RunnerContext {
}
}
+ /** Appends a pending durable call slot at the current call index. */
+ public void appendPendingCall(String functionId, String argsDigest) {
+ mailboxThreadChecker.run();
+ if (durableExecutionContext != null) {
+ durableExecutionContext.appendPendingCall(functionId, argsDigest);
+ }
+ }
+
+ /** Finalizes the pending durable call slot at the current call index. */
+ public void finalizeCurrentCall(
+ String functionId, String argsDigest, byte[] resultPayload, byte[]
exceptionPayload) {
+ mailboxThreadChecker.run();
+ if (durableExecutionContext != null) {
+ durableExecutionContext.finalizeCurrentCall(
+ functionId, argsDigest, resultPayload, exceptionPayload);
+ }
+ }
+
+ /**
+ * Clears persisted call results from the current call index onward and persists immediately.
+ */
+ public void clearCallResultsFromCurrentIndexAndPersist() {
+ mailboxThreadChecker.run();
+ if (durableExecutionContext != null) {
+
durableExecutionContext.clearCallResultsFromCurrentIndexAndPersist();
+ }
+ }
+
+ protected CallResult getCurrentCallResult() {
+ mailboxThreadChecker.run();
+ if (durableExecutionContext != null) {
+ return durableExecutionContext.getCurrentCallResult();
+ }
+ return null;
+ }
+
+ protected <T> Optional<T> tryGetCachedResult(
+ String functionId, String argsDigest, Class<T> resultClass) throws
Exception {
+ Object[] cached = matchNextOrClearSubsequentCallResult(functionId,
argsDigest);
+ if (cached != null && (Boolean) cached[0]) {
+ byte[] resultPayload = (byte[]) cached[1];
+ byte[] exceptionPayload = (byte[]) cached[2];
+
+ if (exceptionPayload != null) {
+ DurableExecutionException cachedException =
+ OBJECT_MAPPER.readValue(exceptionPayload,
DurableExecutionException.class);
+ throw cachedException.toException();
+ } else if (resultPayload != null) {
+ return Optional.of(OBJECT_MAPPER.readValue(resultPayload,
resultClass));
+ } else {
+ return Optional.of(null);
+ }
+ }
+ return Optional.empty();
+ }
+
+ protected void recordDurableCompletion(
+ String functionId, String argsDigest, Object result, Exception
exception)
+ throws Exception {
+ byte[] resultPayload = serializeDurableResult(result);
+ byte[] exceptionPayload = serializeDurableException(exception);
+ recordCallCompletion(functionId, argsDigest, resultPayload,
exceptionPayload);
+ }
+
+ /**
+ * Executes a durable call using the reconcile-enabled state machine.
+ *
+ * @param durableCallable durable call that provides the durable execution identity and result
+ * metadata
+ * @param reconcileCallable reconcile boundary used to recover a successful outcome from a
+ * pending durable call
+ * @param executionCallable concrete execution boundary for the current path when recovery
+ * starts or restarts the original durable call
+ */
+ protected <T> T durableExecuteWithReconcile(
+ DurableCallable<T> durableCallable,
+ Callable<T> reconcileCallable,
+ Callable<T> executionCallable)
+ throws Exception {
+ String functionId = durableCallable.getId();
+ String argsDigest = "";
+ Preconditions.checkState(
+ durableExecutionContext != null, "durableExecutionContext must
not be null");
+
+ CallResult current = getCurrentCallResult();
+
+ if (current == null) {
+ appendPendingCall(functionId, argsDigest);
+ return executeAndFinalizeCurrentCall(functionId, argsDigest,
executionCallable);
+ }
+
+ if (!current.matches(functionId, argsDigest)) {
+ clearCallResultsFromCurrentIndexAndPersist();
+ appendPendingCall(functionId, argsDigest);
+ return executeAndFinalizeCurrentCall(functionId, argsDigest,
executionCallable);
+ }
+
+ if (!current.isPending()) {
+ Optional<T> cachedResult =
+ tryGetCachedResult(functionId, argsDigest,
durableCallable.getResultClass());
+ if (cachedResult.isPresent()) {
+ return cachedResult.get();
+ }
+ throw new IllegalStateException(
+ String.format(
+ "Expected a terminal durable call result at index
%s for "
+ + "functionId=%s, argsDigest=%s",
+ durableExecutionContext.getCurrentCallIndex(),
functionId, argsDigest));
+ }
+
+ T reconcileResult = reconcileCallable.call();
+ finalizeCurrentCall(functionId, argsDigest,
serializeDurableResult(reconcileResult), null);
+ return reconcileResult;
+ }
+
+ protected <T> T executeAndFinalizeCurrentCall(
+ String functionId, String argsDigest, Callable<T> callSupplier)
throws Exception {
+ T result = null;
+ Exception exception = null;
+ try {
+ result = callSupplier.call();
+ } catch (Exception e) {
+ exception = e;
+ }
+
+ finalizeCurrentCall(
+ functionId,
+ argsDigest,
+ serializeDurableResult(result),
+ serializeDurableException(exception));
+
+ if (exception != null) {
+ throw exception;
+ }
+ return result;
+ }
+
+ protected byte[] serializeDurableResult(Object result) throws
JsonProcessingException {
+ if (result == null) {
+ return null;
+ }
+ return OBJECT_MAPPER.writeValueAsBytes(result);
+ }
+
+ protected byte[] serializeDurableException(Exception exception) throws
JsonProcessingException {
+ if (exception == null) {
+ return null;
+ }
+ return
OBJECT_MAPPER.writeValueAsBytes(DurableExecutionException.fromException(exception));
+ }
+
+ protected static class DurableExecutionRuntimeException extends
RuntimeException {
+ DurableExecutionRuntimeException(Throwable cause) {
+ super(cause);
+ }
+ }
+
/**
* Context for fine-grained durable execution within an action.
*
@@ -516,6 +620,17 @@ public class RunnerContextImpl implements RunnerContext {
return actionState;
}
+ /**
+ * Returns the call result at the current call index, or null if the current index does not
+ * yet have a persisted slot.
+ */
+ public CallResult getCurrentCallResult() {
+ if (currentCallIndex < recoveryCallResults.size()) {
+ return recoveryCallResults.get(currentCallIndex);
+ }
+ return null;
+ }
+
/**
* Matches the next call result for recovery, or clears subsequent
results if mismatch
* detected.
@@ -571,7 +686,8 @@ public class RunnerContextImpl implements RunnerContext {
new CallResult(functionId, argsDigest, resultPayload,
exceptionPayload);
actionState.addCallResult(callResult);
- persister.persist(key, sequenceNumber, action, event, actionState);
+ recoveryCallResults.add(callResult);
+ persistActionState();
LOG.debug(
"Recorded and persisted CallResult at index {}:
functionId={}, argsDigest={}",
@@ -582,11 +698,100 @@ public class RunnerContextImpl implements RunnerContext {
currentCallIndex++;
}
+ /**
+ * Appends and persists a pending slot for the current call index.
+ *
+ * <p>This reserves the current slot for a reconcilable durable call but does not advance
+ * {@code currentCallIndex}.
+ */
+ public void appendPendingCall(String functionId, String argsDigest) {
+ if (currentCallIndex != recoveryCallResults.size()) {
+ throw new IllegalStateException(
+ String.format(
+ "Cannot append pending call at index %s when a
persisted slot "
+ + "already exists",
+ currentCallIndex));
+ }
+
+ CallResult pending = CallResult.pending(functionId, argsDigest);
+ actionState.addCallResult(pending);
+ recoveryCallResults.add(pending);
+ persistActionState();
+
+ LOG.debug(
+ "Recorded and persisted pending CallResult at index {}:
functionId={}, "
+ + "argsDigest={}",
+ currentCallIndex,
+ functionId,
+ argsDigest);
+ }
+
+ /**
+ * Replaces the current persisted slot with a terminal call result and advances the current
+ * call index.
+ */
+ public void finalizeCurrentCall(
+ String functionId,
+ String argsDigest,
+ byte[] resultPayload,
+ byte[] exceptionPayload) {
+ CallResult current = getCurrentCallResult();
+ if (current == null) {
+ throw new IllegalStateException(
+ String.format(
+ "Cannot finalize current call at index %s
because no persisted "
+ + "slot exists",
+ currentCallIndex));
+ }
+ if (!current.matches(functionId, argsDigest)) {
+ throw new IllegalStateException(
+ String.format(
+ "Cannot finalize current call at index %s
because the persisted "
+ + "slot does not match functionId=%s,
argsDigest=%s",
+ currentCallIndex, functionId, argsDigest));
+ }
+ if (!current.isPending()) {
+ throw new IllegalStateException(
+ String.format(
+ "Cannot finalize current call at index %s
because the persisted "
+ + "slot is not pending",
+ currentCallIndex));
+ }
+
+ CallResult terminal =
+ new CallResult(functionId, argsDigest, resultPayload,
exceptionPayload);
+ actionState.replaceCallResult(currentCallIndex, terminal);
+ recoveryCallResults.set(currentCallIndex, terminal);
+ persistActionState();
+
+ LOG.debug(
+ "Finalized and persisted CallResult at index {}:
functionId={}, argsDigest={}",
+ currentCallIndex,
+ functionId,
+ argsDigest);
+
+ currentCallIndex++;
+ }
+
+ /**
+ * Clears persisted call results from the current index onward and persists the truncated
+ * state immediately.
+ */
+ public void clearCallResultsFromCurrentIndexAndPersist() {
+ clearCallResultsFromCurrentIndex();
+ persistActionState();
+ }
+
private void clearCallResultsFromCurrentIndex() {
actionState.clearCallResultsFrom(currentCallIndex);
recoveryCallResults =
- recoveryCallResults.subList(
- 0, Math.min(currentCallIndex,
recoveryCallResults.size()));
+ new ArrayList<>(
+ recoveryCallResults.subList(
+ 0, Math.min(currentCallIndex,
recoveryCallResults.size())));
+ }
+
+ private void persistActionState() {
+ persister.persist(key, sequenceNumber, action, event, actionState);
}
}
}
diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateSerdeTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateSerdeTest.java
index 74181d0f..ca510332 100644
--- a/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateSerdeTest.java
+++ b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateSerdeTest.java
@@ -23,7 +23,9 @@ import org.apache.flink.agents.api.OutputEvent;
import org.apache.flink.agents.api.context.MemoryUpdate;
import org.junit.jupiter.api.Test;
+import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
+import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -150,7 +152,7 @@ public class ActionStateSerdeTest {
// Add call results
CallResult result1 = new CallResult("module.func1", "digest1",
"result1".getBytes());
CallResult result2 =
- CallResult.ofException("module.func2", "digest2",
"exception".getBytes());
+ new CallResult("module.func2", "digest2", null,
"exception".getBytes());
originalState.addCallResult(result1);
originalState.addCallResult(result2);
@@ -175,7 +177,25 @@ public class ActionStateSerdeTest {
assertEquals("digest2", deserializedResult2.getArgsDigest());
assertNull(deserializedResult2.getResultPayload());
assertArrayEquals("exception".getBytes(),
deserializedResult2.getExceptionPayload());
- assertFalse(deserializedResult2.isSuccess());
+ assertTrue(deserializedResult2.isFailure());
+ }
+
+ @Test
+ public void testActionStateWithPendingCallResult() throws Exception {
+ InputEvent inputEvent = new InputEvent("test input");
+ ActionState originalState = new ActionState(inputEvent);
+ originalState.addCallResult(CallResult.pending("module.func",
"digest"));
+
+ ActionStateKafkaSeder seder = new ActionStateKafkaSeder();
+
+ byte[] serialized = seder.serialize("test-topic", originalState);
+ ActionState deserializedState = seder.deserialize("test-topic",
serialized);
+
+ assertEquals(1, deserializedState.getCallResultCount());
+ CallResult result = deserializedState.getCallResult(0);
+ assertTrue(result.isPending());
+ assertNull(result.getResultPayload());
+ assertNull(result.getExceptionPayload());
}
@Test
@@ -252,5 +272,58 @@ public class ActionStateSerdeTest {
assertEquals("digest", result.getArgsDigest());
assertNull(result.getResultPayload());
assertNull(result.getExceptionPayload());
+ assertTrue(result.isSuccess());
+ }
+
+ @Test
+ public void testDeserializeLegacyCallResultWithoutStatus() throws
Exception {
+ // Legacy JSON sample: unlike current serializer output, CallResult entries do not include
+ // `status`.
+ String legacySuccessPayload =
+
Base64.getEncoder().encodeToString("result".getBytes(StandardCharsets.UTF_8));
+ String legacyFailurePayload =
+
Base64.getEncoder().encodeToString("exception".getBytes(StandardCharsets.UTF_8));
+ String json =
+ "{"
+ + "\"taskEvent\":null,"
+ + "\"sensoryMemoryUpdates\":[],"
+ + "\"shortTermMemoryUpdates\":[],"
+ + "\"outputEvents\":[],"
+ + "\"callResults\":["
+ + "{"
+ + "\"functionId\":\"legacy.success\","
+ + "\"argsDigest\":\"digest-success\","
+ + "\"resultPayload\":\""
+ + legacySuccessPayload
+ + "\","
+ + "\"exceptionPayload\":null"
+ + "},"
+ + "{"
+ + "\"functionId\":\"legacy.failure\","
+ + "\"argsDigest\":\"digest-failure\","
+ + "\"resultPayload\":null,"
+ + "\"exceptionPayload\":\""
+ + legacyFailurePayload
+ + "\""
+ + "}"
+ + "],"
+ + "\"completed\":false"
+ + "}";
+
+ ActionStateKafkaSeder seder = new ActionStateKafkaSeder();
+ ActionState deserializedState =
+ seder.deserialize("test-topic",
json.getBytes(StandardCharsets.UTF_8));
+
+ assertEquals(2, deserializedState.getCallResultCount());
+
+ CallResult legacySuccess = deserializedState.getCallResult(0);
+ assertTrue(legacySuccess.isSuccess());
+ assertArrayEquals(
+ "result".getBytes(StandardCharsets.UTF_8),
legacySuccess.getResultPayload());
+
+ CallResult legacyFailure = deserializedState.getCallResult(1);
+ assertTrue(legacyFailure.isFailure());
+ assertArrayEquals(
+ "exception".getBytes(StandardCharsets.UTF_8),
legacyFailure.getExceptionPayload());
}
}
diff --git
a/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateTest.java
b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateTest.java
index aa00d119..17296c1f 100644
---
a/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateTest.java
+++
b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateTest.java
@@ -90,6 +90,20 @@ public class ActionStateTest {
assertEquals(result2, state.getCallResult(1));
}
+ @Test
+ public void testReplaceCallResult() {
+ ActionState state = new ActionState(new InputEvent("test"));
+ CallResult original = new CallResult("func1", "digest1",
"result1".getBytes());
+ CallResult replacement = CallResult.pending("func1", "digest1");
+
+ state.addCallResult(original);
+ state.replaceCallResult(0, replacement);
+
+ assertEquals(1, state.getCallResultCount());
+ assertEquals(replacement, state.getCallResult(0));
+ assertTrue(state.getCallResult(0).isPending());
+ }
+
@Test
public void testGetCallResultOutOfBounds() {
ActionState state = new ActionState(new InputEvent("test"));
diff --git
a/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/CallResultTest.java
b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/CallResultTest.java
index 11d8eb14..8e3a1af1 100644
---
a/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/CallResultTest.java
+++
b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/CallResultTest.java
@@ -45,13 +45,13 @@ public class CallResultTest {
String argsDigest = "abc123";
byte[] exceptionPayload = "exception".getBytes();
- CallResult result = CallResult.ofException(functionId, argsDigest,
exceptionPayload);
+ CallResult result = new CallResult(functionId, argsDigest, null,
exceptionPayload);
assertEquals(functionId, result.getFunctionId());
assertEquals(argsDigest, result.getArgsDigest());
assertNull(result.getResultPayload());
assertArrayEquals(exceptionPayload, result.getExceptionPayload());
- assertFalse(result.isSuccess());
+ assertTrue(result.isFailure());
}
@Test
@@ -70,6 +70,28 @@ public class CallResultTest {
assertTrue(result.isSuccess());
}
+ @Test
+ public void testPendingCallResult() {
+ CallResult result = CallResult.pending("my_module.my_function",
"abc123");
+
+ assertNull(result.getResultPayload());
+ assertNull(result.getExceptionPayload());
+ assertTrue(result.isPending());
+ }
+
+ @Test
+ public void testLegacyStatusInference() {
+ CallResult success =
+ CallResult.ofNullStatus(
+ "my_module.my_function", "abc123",
"result".getBytes(), null);
+ CallResult failure =
+ CallResult.ofNullStatus(
+ "my_module.my_function", "abc123", null,
"exception".getBytes());
+
+ assertTrue(success.isSuccess());
+ assertTrue(failure.isFailure());
+ }
+
@Test
public void testMatches() {
String functionId = "my_module.my_function";
@@ -142,6 +164,7 @@ public class CallResultTest {
String str = result.toString();
assertTrue(str.contains("null"));
+ assertTrue(str.contains("SUCCEEDED"));
}
@Test
diff --git
a/runtime/src/test/java/org/apache/flink/agents/runtime/context/DurableExecutionContextTest.java
b/runtime/src/test/java/org/apache/flink/agents/runtime/context/DurableExecutionContextTest.java
index 9e7c2de0..b659a77e 100644
---
a/runtime/src/test/java/org/apache/flink/agents/runtime/context/DurableExecutionContextTest.java
+++
b/runtime/src/test/java/org/apache/flink/agents/runtime/context/DurableExecutionContextTest.java
@@ -180,7 +180,7 @@ class DurableExecutionContextTest {
@Test
void testRecoveryWithExceptionPayload() {
byte[] exceptionPayload =
"exception_data".getBytes(StandardCharsets.UTF_8);
- actionState.addCallResult(CallResult.ofException("funcA", "digestA",
exceptionPayload));
+ actionState.addCallResult(new CallResult("funcA", "digestA", null,
exceptionPayload));
RunnerContextImpl.DurableExecutionContext context = createContext();
diff --git
a/runtime/src/test/java/org/apache/flink/agents/runtime/context/JavaRunnerContextImplDurableExecuteAsyncTest.java
b/runtime/src/test/java/org/apache/flink/agents/runtime/context/JavaRunnerContextImplDurableExecuteAsyncTest.java
new file mode 100644
index 00000000..ef3023ff
--- /dev/null
+++
b/runtime/src/test/java/org/apache/flink/agents/runtime/context/JavaRunnerContextImplDurableExecuteAsyncTest.java
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.agents.runtime.context;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.apache.flink.agents.api.Event;
+import org.apache.flink.agents.plan.AgentPlan;
+import org.apache.flink.agents.plan.actions.Action;
+import org.apache.flink.agents.runtime.actionstate.ActionState;
+import org.apache.flink.agents.runtime.actionstate.CallResult;
+import org.apache.flink.agents.runtime.async.ContinuationActionExecutor;
+import org.apache.flink.agents.runtime.async.ContinuationContext;
+import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl;
+import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import java.util.HashMap;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Supplier;
+
+import static org.junit.jupiter.api.Assertions.*;
+import static org.mockito.Mockito.mock;
+
+/** Unit tests for async durable execution in {@link JavaRunnerContextImpl}. */
+class JavaRunnerContextImplDurableExecuteAsyncTest {
+
+ private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
+
+ private FlinkAgentsMetricGroupImpl metricGroup;
+ private AtomicInteger persistCallCount;
+ private ActionState lastPersistedState;
+
+ @BeforeEach
+ void setUp() {
+ metricGroup =
+ new FlinkAgentsMetricGroupImpl(
+
UnregisteredMetricGroups.createUnregisteredOperatorMetricGroup());
+ persistCallCount = new AtomicInteger();
+ lastPersistedState = null;
+ }
+
+ @Test
+ void testDurableExecuteAsyncLegacyCall() throws Exception {
+ InspectingContinuationActionExecutor executor = new
InspectingContinuationActionExecutor();
+ JavaRunnerContextImpl context = createContext(new ActionState(null),
executor);
+ TestDurableCallable<String> callable =
+ new TestDurableCallable<>("legacy-async", String.class, () ->
"ok");
+
+ String result = context.durableExecuteAsync(callable);
+
+ assertEquals("ok", result);
+ assertEquals(1, callable.getCallCount());
+ assertEquals(1, executor.getExecuteAsyncCallCount());
+ assertEquals(1, persistCallCount.get());
+ CallResult persisted =
+
context.getDurableExecutionContext().getActionState().getCallResults().get(0);
+ assertTrue(persisted.isSuccess());
+ assertSame(context.getDurableExecutionContext().getActionState(),
lastPersistedState);
+ }
+
+ @Test
+ void testDurableExecuteAsyncReconcilableSuccessCall() throws Exception {
+ InspectingContinuationActionExecutor executor = new
InspectingContinuationActionExecutor();
+ JavaRunnerContextImpl context = createContext(new ActionState(null),
executor);
+ executor.setBeforeExecute(
+ () -> {
+ CallResult current =
+
context.getDurableExecutionContext().getCurrentCallResult();
+ assertNotNull(current);
+ assertTrue(current.isPending());
+ assertEquals(0,
context.getDurableExecutionContext().getCurrentCallIndex());
+ assertEquals(1, persistCallCount.get());
+ });
+ TestReconcilableCallable<String> callable =
+ new TestReconcilableCallable<>(
+ "recon-async",
+ String.class,
+ () -> "ok",
+ () -> fail("reconcile should not be called on initial
async execution"));
+
+ String result = context.durableExecuteAsync(callable);
+
+ assertEquals("ok", result);
+ assertEquals(1, callable.getCallCount());
+ assertEquals(0, callable.getReconcileCount());
+ assertEquals(1, executor.getExecuteAsyncCallCount());
+ assertEquals(2, persistCallCount.get());
+ CallResult persisted =
+
context.getDurableExecutionContext().getActionState().getCallResults().get(0);
+ assertTrue(persisted.isSuccess());
+ }
+
+ @Test
+ void testDurableExecuteAsyncReconcilableReplaySuccess() throws Exception {
+ InspectingContinuationActionExecutor executor = new
InspectingContinuationActionExecutor();
+ ActionState actionState = new ActionState(null);
+ actionState.addCallResult(
+ new CallResult("recon-async", "",
OBJECT_MAPPER.writeValueAsBytes("cached")));
+ JavaRunnerContextImpl context = createContext(actionState, executor);
+ TestReconcilableCallable<String> callable =
+ new TestReconcilableCallable<>(
+ "recon-async",
+ String.class,
+ () -> fail("call should not be executed"),
+ () -> fail("reconcile should not be called for
terminal slot"));
+
+ String result = context.durableExecuteAsync(callable);
+
+ assertEquals("cached", result);
+ assertEquals(0, callable.getCallCount());
+ assertEquals(0, callable.getReconcileCount());
+ assertEquals(0, executor.getExecuteAsyncCallCount());
+ assertEquals(0, persistCallCount.get());
+ }
+
+ @Test
+ void testDurableExecuteAsyncReconcilableReconcileSuccess() throws
Exception {
+ InspectingContinuationActionExecutor executor = new
InspectingContinuationActionExecutor();
+ ActionState actionState = new ActionState(null);
+ actionState.addCallResult(CallResult.pending("recon-async", ""));
+ JavaRunnerContextImpl context = createContext(actionState, executor);
+ TestReconcilableCallable<String> callable =
+ new TestReconcilableCallable<>(
+ "recon-async",
+ String.class,
+ () -> fail("call should not be executed"),
+ () -> "recovered");
+
+ String result = context.durableExecuteAsync(callable);
+
+ assertEquals("recovered", result);
+ assertEquals(0, callable.getCallCount());
+ assertEquals(1, callable.getReconcileCount());
+ assertEquals(0, executor.getExecuteAsyncCallCount());
+ assertEquals(1, persistCallCount.get());
+ CallResult persisted =
+
context.getDurableExecutionContext().getActionState().getCallResults().get(0);
+ assertTrue(persisted.isSuccess());
+ }
+
+ @Test
+ void testDurableExecuteAsyncReconcilableReconcileExceptionPropagates()
throws Exception {
+ InspectingContinuationActionExecutor executor = new
InspectingContinuationActionExecutor();
+ ActionState actionState = new ActionState(null);
+ actionState.addCallResult(CallResult.pending("recon-async", ""));
+ JavaRunnerContextImpl context = createContext(actionState, executor);
+ executor.setBeforeExecute(
+ () -> {
+ CallResult current =
+
context.getDurableExecutionContext().getCurrentCallResult();
+ assertNotNull(current);
+ assertTrue(current.isPending());
+ assertEquals(0, persistCallCount.get());
+ });
+ IllegalArgumentException failure = new
IllegalArgumentException("reconcile unavailable");
+ TestReconcilableCallable<String> callable =
+ new TestReconcilableCallable<>(
+ "recon-async",
+ String.class,
+ () -> fail("call should not be executed"),
+ () -> {
+ throw failure;
+ });
+
+ IllegalArgumentException thrown =
+ assertThrows(
+ IllegalArgumentException.class,
+ () -> context.durableExecuteAsync(callable));
+
+ assertSame(failure, thrown);
+ assertEquals(0, callable.getCallCount());
+ assertEquals(1, callable.getReconcileCount());
+ assertEquals(0, executor.getExecuteAsyncCallCount());
+ assertEquals(0, persistCallCount.get());
+ CallResult persisted =
+
context.getDurableExecutionContext().getActionState().getCallResults().get(0);
+ assertTrue(persisted.isPending());
+ assertEquals(0,
context.getDurableExecutionContext().getCurrentCallIndex());
+ }
+
+ private JavaRunnerContextImpl createContext(
+ ActionState actionState, ContinuationActionExecutor executor) {
+ JavaRunnerContextImpl context =
+ new JavaRunnerContextImpl(
+ metricGroup,
+ () -> {},
+ new AgentPlan(new HashMap<>(), new HashMap<>()),
+ null,
+ "test-job",
+ executor);
+ context.setContinuationContext(new ContinuationContext());
+ ActionStatePersister persister =
+ (key, sequenceNumber, action, event, state) -> {
+ persistCallCount.incrementAndGet();
+ lastPersistedState = state;
+ };
+ context.setDurableExecutionContext(
+ new RunnerContextImpl.DurableExecutionContext(
+ "test-key",
+ 1L,
+ mock(Action.class),
+ mock(Event.class),
+ actionState,
+ persister));
+ return context;
+ }
+
+ private static final class InspectingContinuationActionExecutor
+ extends ContinuationActionExecutor {
+ private Runnable beforeExecute;
+ private int executeAsyncCallCount;
+
+ private InspectingContinuationActionExecutor() {
+ super(1);
+ }
+
+ @Override
+ public <T> T executeAsync(ContinuationContext context, Supplier<T>
supplier) {
+ executeAsyncCallCount++;
+ if (beforeExecute != null) {
+ beforeExecute.run();
+ }
+ return supplier.get();
+ }
+
+ private void setBeforeExecute(Runnable beforeExecute) {
+ this.beforeExecute = beforeExecute;
+ }
+
+ private int getExecuteAsyncCallCount() {
+ return executeAsyncCallCount;
+ }
+ }
+}
diff --git
a/runtime/src/test/java/org/apache/flink/agents/runtime/context/RunnerContextImplDurableExecuteTest.java
b/runtime/src/test/java/org/apache/flink/agents/runtime/context/RunnerContextImplDurableExecuteTest.java
new file mode 100644
index 00000000..132a7a6e
--- /dev/null
+++
b/runtime/src/test/java/org/apache/flink/agents/runtime/context/RunnerContextImplDurableExecuteTest.java
@@ -0,0 +1,295 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.agents.runtime.context;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.apache.flink.agents.api.Event;
+import org.apache.flink.agents.plan.AgentPlan;
+import org.apache.flink.agents.plan.actions.Action;
+import org.apache.flink.agents.runtime.actionstate.ActionState;
+import org.apache.flink.agents.runtime.actionstate.CallResult;
+import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl;
+import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import java.util.HashMap;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.junit.jupiter.api.Assertions.*;
+import static org.mockito.Mockito.mock;
+
+/** Unit tests for durable execution in {@link RunnerContextImpl}. */
+class RunnerContextImplDurableExecuteTest {
+
+ private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
+
+ private FlinkAgentsMetricGroupImpl metricGroup;
+ private AtomicInteger persistCallCount;
+ private ActionState lastPersistedState;
+
+ @BeforeEach
+ void setUp() {
+ metricGroup =
+ new FlinkAgentsMetricGroupImpl(
+
UnregisteredMetricGroups.createUnregisteredOperatorMetricGroup());
+ persistCallCount = new AtomicInteger();
+ lastPersistedState = null;
+ }
+
+ @Test
+ void testDurableExecuteLegacyCall() throws Exception {
+ RunnerContextImpl context = createContext(new ActionState(null));
+ TestDurableCallable<String> callable =
+ new TestDurableCallable<>("legacy-call", String.class, () ->
"ok");
+
+ String result = context.durableExecute(callable);
+
+ assertEquals("ok", result);
+ assertEquals(1, callable.getCallCount());
+ assertEquals(1, persistCallCount.get());
+ assertEquals(1,
context.getDurableExecutionContext().getCurrentCallIndex());
+ assertEquals(1,
context.getDurableExecutionContext().getActionState().getCallResultCount());
+ CallResult persisted =
+
context.getDurableExecutionContext().getActionState().getCallResults().get(0);
+ assertTrue(persisted.isSuccess());
+ assertSame(context.getDurableExecutionContext().getActionState(),
lastPersistedState);
+ }
+
+ @Test
+ void testDurableExecuteReconcilableSuccessCall() throws Exception {
+ RunnerContextImpl context = createContext(new ActionState(null));
+ TestReconcilableCallable<String> callable =
+ new TestReconcilableCallable<>(
+ "recon-call",
+ String.class,
+ () -> "ok",
+ () -> fail("reconcile should not be called on initial
execution"));
+
+ String result = context.durableExecute(callable);
+
+ assertEquals("ok", result);
+ assertEquals(1, callable.getCallCount());
+ assertEquals(0, callable.getReconcileCount());
+ assertEquals(2, persistCallCount.get());
+ assertEquals(1,
context.getDurableExecutionContext().getCurrentCallIndex());
+ CallResult persisted =
+
context.getDurableExecutionContext().getActionState().getCallResults().get(0);
+ assertTrue(persisted.isSuccess());
+ }
+
+ @Test
+ void testDurableExecuteAsyncFallsBackToSyncExecution() throws Exception {
+ RunnerContextImpl context = createContext(new ActionState(null));
+ TestReconcilableCallable<String> callable =
+ new TestReconcilableCallable<>(
+ "recon-async",
+ String.class,
+ () -> "ok",
+ () -> fail("reconcile should not be called on initial
async fallback"));
+
+ String result = context.durableExecuteAsync(callable);
+
+ assertEquals("ok", result);
+ assertEquals(1, callable.getCallCount());
+ assertEquals(0, callable.getReconcileCount());
+ assertEquals(2, persistCallCount.get());
+ assertEquals(1,
context.getDurableExecutionContext().getCurrentCallIndex());
+ CallResult persisted =
+
context.getDurableExecutionContext().getActionState().getCallResults().get(0);
+ assertTrue(persisted.isSuccess());
+ }
+
+ @Test
+ void testDurableExecuteReconcilableFailCall() {
+ RunnerContextImpl context = createContext(new ActionState(null));
+ RuntimeException failure = new RuntimeException("call failed");
+ TestReconcilableCallable<String> callable =
+ new TestReconcilableCallable<>(
+ "recon-call",
+ String.class,
+ () -> {
+ throw failure;
+ },
+ () -> fail("reconcile should not be called on initial
execution"));
+
+ RuntimeException thrown =
+ assertThrows(RuntimeException.class, () ->
context.durableExecute(callable));
+
+ assertSame(failure, thrown);
+ assertEquals(1, callable.getCallCount());
+ assertEquals(0, callable.getReconcileCount());
+ assertEquals(2, persistCallCount.get());
+ CallResult persisted =
+
context.getDurableExecutionContext().getActionState().getCallResults().get(0);
+ assertTrue(persisted.isFailure());
+ }
+
+ @Test
+ void testDurableExecuteReconcilableReplaySuccess() throws Exception {
+ ActionState actionState = new ActionState(null);
+ actionState.addCallResult(
+ new CallResult("recon-call", "",
OBJECT_MAPPER.writeValueAsBytes("cached")));
+ RunnerContextImpl context = createContext(actionState);
+ TestReconcilableCallable<String> callable =
+ new TestReconcilableCallable<>(
+ "recon-call",
+ String.class,
+ () -> fail("call should not be re-executed"),
+ () -> fail("reconcile should not be called for
terminal slot"));
+
+ String result = context.durableExecute(callable);
+
+ assertEquals("cached", result);
+ assertEquals(0, callable.getCallCount());
+ assertEquals(0, callable.getReconcileCount());
+ assertEquals(0, persistCallCount.get());
+ assertEquals(1,
context.getDurableExecutionContext().getCurrentCallIndex());
+ }
+
+ @Test
+ void testDurableExecuteReconcilableReplayFailure() throws Exception {
+ ActionState actionState = new ActionState(null);
+ actionState.addCallResult(
+ new CallResult(
+ "recon-call",
+ "",
+ null,
+ OBJECT_MAPPER.writeValueAsBytes(
+
RunnerContextImpl.DurableExecutionException.fromException(
+ new IllegalStateException("cached
failure")))));
+ RunnerContextImpl context = createContext(actionState);
+ TestReconcilableCallable<String> callable =
+ new TestReconcilableCallable<>(
+ "recon-call",
+ String.class,
+ () -> fail("call should not be re-executed"),
+ () -> fail("reconcile should not be called for
terminal slot"));
+
+ Exception thrown = assertThrows(Exception.class, () ->
context.durableExecute(callable));
+
+ assertTrue(thrown.getMessage().contains("IllegalStateException"));
+ assertTrue(thrown.getMessage().contains("cached failure"));
+ assertEquals(0, callable.getCallCount());
+ assertEquals(0, callable.getReconcileCount());
+ assertEquals(0, persistCallCount.get());
+ assertEquals(1,
context.getDurableExecutionContext().getCurrentCallIndex());
+ }
+
+ @Test
+ void testDurableExecuteReconcilableReconcileSuccess() throws Exception {
+ ActionState actionState = new ActionState(null);
+ actionState.addCallResult(CallResult.pending("recon-call", ""));
+ RunnerContextImpl context = createContext(actionState);
+ TestReconcilableCallable<String> callable =
+ new TestReconcilableCallable<>(
+ "recon-call",
+ String.class,
+ () -> fail("call should not be re-executed"),
+ () -> "recovered");
+
+ String result = context.durableExecute(callable);
+
+ assertEquals("recovered", result);
+ assertEquals(0, callable.getCallCount());
+ assertEquals(1, callable.getReconcileCount());
+ assertEquals(1, persistCallCount.get());
+ CallResult persisted =
+
context.getDurableExecutionContext().getActionState().getCallResults().get(0);
+ assertTrue(persisted.isSuccess());
+ assertEquals(1,
context.getDurableExecutionContext().getCurrentCallIndex());
+ }
+
+ @Test
+ void testDurableExecuteReconcilableReconcileExceptionPropagates() throws
Exception {
+ ActionState actionState = new ActionState(null);
+ actionState.addCallResult(CallResult.pending("recon-call", ""));
+ RunnerContextImpl context = createContext(actionState);
+ IllegalStateException failure = new IllegalStateException("reconcile
unavailable");
+ TestReconcilableCallable<String> callable =
+ new TestReconcilableCallable<>(
+ "recon-call",
+ String.class,
+ () -> fail("call should not be re-executed"),
+ () -> {
+ throw failure;
+ });
+
+ IllegalStateException thrown =
+ assertThrows(IllegalStateException.class, () ->
context.durableExecute(callable));
+
+ assertSame(failure, thrown);
+ assertEquals(0, callable.getCallCount());
+ assertEquals(1, callable.getReconcileCount());
+ assertEquals(0, persistCallCount.get());
+ CallResult persisted =
+
context.getDurableExecutionContext().getActionState().getCallResults().get(0);
+ assertTrue(persisted.isPending());
+ assertEquals(0,
context.getDurableExecutionContext().getCurrentCallIndex());
+ }
+
+ @Test
+ void testDurableExecuteReconcilableMismatchStartsNewCall() throws
Exception {
+ ActionState actionState = new ActionState(null);
+ actionState.addCallResult(CallResult.pending("stale-call", ""));
+ RunnerContextImpl context = createContext(actionState);
+ TestReconcilableCallable<String> callable =
+ new TestReconcilableCallable<>(
+ "recon-call",
+ String.class,
+ () -> "fresh",
+ () -> fail("reconcile should not be called for
mismatched slot"));
+
+ String result = context.durableExecute(callable);
+
+ assertEquals("fresh", result);
+ assertEquals(1, callable.getCallCount());
+ assertEquals(0, callable.getReconcileCount());
+ assertEquals(3, persistCallCount.get());
+ assertEquals(1,
context.getDurableExecutionContext().getActionState().getCallResultCount());
+ CallResult persisted =
+
context.getDurableExecutionContext().getActionState().getCallResults().get(0);
+ assertEquals("recon-call", persisted.getFunctionId());
+ assertTrue(persisted.isSuccess());
+ }
+
+ private RunnerContextImpl createContext(ActionState actionState) {
+ RunnerContextImpl context =
+ new RunnerContextImpl(
+ metricGroup,
+ () -> {},
+ new AgentPlan(new HashMap<>(), new HashMap<>()),
+ null,
+ "test-job");
+ ActionStatePersister persister =
+ (key, sequenceNumber, action, event, state) -> {
+ persistCallCount.incrementAndGet();
+ lastPersistedState = state;
+ };
+ context.durableExecutionContext =
+ new RunnerContextImpl.DurableExecutionContext(
+ "test-key",
+ 1L,
+ mock(Action.class),
+ mock(Event.class),
+ actionState,
+ persister);
+ return context;
+ }
+}
diff --git
a/runtime/src/test/java/org/apache/flink/agents/runtime/context/TestDurableCallables.java
b/runtime/src/test/java/org/apache/flink/agents/runtime/context/TestDurableCallables.java
new file mode 100644
index 00000000..46ad2ced
--- /dev/null
+++
b/runtime/src/test/java/org/apache/flink/agents/runtime/context/TestDurableCallables.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.agents.runtime.context;
+
+import org.apache.flink.agents.api.context.DurableCallable;
+
+import java.util.concurrent.Callable;
+
+/** Shared durable callable fixtures for runner context tests. */
+class TestDurableCallable<T> implements DurableCallable<T> {
+ private final String id;
+ private final Class<T> resultClass;
+ private final Callable<T> callSupplier;
+ private int callCount;
+
+ TestDurableCallable(String id, Class<T> resultClass, Callable<T>
callSupplier) {
+ this.id = id;
+ this.resultClass = resultClass;
+ this.callSupplier = callSupplier;
+ }
+
+ @Override
+ public String getId() {
+ return id;
+ }
+
+ @Override
+ public Class<T> getResultClass() {
+ return resultClass;
+ }
+
+ @Override
+ public T call() throws Exception {
+ callCount++;
+ return callSupplier.call();
+ }
+
+ int getCallCount() {
+ return callCount;
+ }
+}
+
+/** Shared reconcilable durable callable fixture for runner context tests. */
+class TestReconcilableCallable<T> implements DurableCallable<T> {
+ private final String id;
+ private final Class<T> resultClass;
+ private final Callable<T> callSupplier;
+ private final Callable<T> reconcileSupplier;
+ private int callCount;
+ private int reconcileCount;
+
+ TestReconcilableCallable(
+ String id,
+ Class<T> resultClass,
+ Callable<T> callSupplier,
+ Callable<T> reconcileSupplier) {
+ this.id = id;
+ this.resultClass = resultClass;
+ this.callSupplier = callSupplier;
+ this.reconcileSupplier = reconcileSupplier;
+ }
+
+ @Override
+ public String getId() {
+ return id;
+ }
+
+ @Override
+ public Class<T> getResultClass() {
+ return resultClass;
+ }
+
+ @Override
+ public T call() throws Exception {
+ callCount++;
+ return callSupplier.call();
+ }
+
+ @Override
+ public Callable<T> reconciler() {
+ return () -> {
+ reconcileCount++;
+ return reconcileSupplier.call();
+ };
+ }
+
+ int getCallCount() {
+ return callCount;
+ }
+
+ int getReconcileCount() {
+ return reconcileCount;
+ }
+}
diff --git
a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
index a45a52d4..72a62509 100644
---
a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
+++
b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
@@ -17,6 +17,7 @@
*/
package org.apache.flink.agents.runtime.operator;
+import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.flink.agents.api.Event;
import org.apache.flink.agents.api.InputEvent;
import org.apache.flink.agents.api.OutputEvent;
@@ -29,6 +30,7 @@ import org.apache.flink.agents.plan.AgentPlan;
import org.apache.flink.agents.plan.JavaFunction;
import org.apache.flink.agents.plan.actions.Action;
import org.apache.flink.agents.runtime.actionstate.ActionState;
+import org.apache.flink.agents.runtime.actionstate.CallResult;
import org.apache.flink.agents.runtime.actionstate.InMemoryActionStateStore;
import org.apache.flink.agents.runtime.eventlog.FileEventLogger;
import org.apache.flink.api.common.typeinfo.TypeInformation;
@@ -41,6 +43,7 @@ import
org.apache.flink.streaming.runtime.tasks.mailbox.TaskMailbox;
import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
import org.apache.flink.util.ExceptionUtils;
+import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.lang.reflect.Field;
@@ -48,6 +51,7 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
@@ -59,6 +63,14 @@ import static
org.assertj.core.api.Assertions.assertThatThrownBy;
/** Tests for {@link ActionExecutionOperator}. */
public class ActionExecutionOperatorTest {
+ private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
+
+ @BeforeEach
+ void resetReconcilableFixtures() {
+ TestAgent.resetReconcilableRecoveryFixture();
+ TestAgent.resetMixedRecoveryFixture();
+ }
+
@Test
void testExecuteAgent() throws Exception {
try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object>
testHarness =
@@ -1029,6 +1041,188 @@ public class ActionExecutionOperatorTest {
.contains("Async operation failed");
}
+ @Test
+ void testDurableExecuteReconcilableRecoverySuccess() throws Exception {
+ AgentPlan agentPlan = TestAgent.getDurableReconcilableAgentPlan();
+ InMemoryActionStateStore actionStateStore = new
InMemoryActionStateStore(false);
+ long key = 1L;
+ long input = 1L;
+ TestAgent.RECONCILABLE_RECOVERY_BEHAVIOR =
TestAgent.ReconcileBehavior.SUCCESS;
+ TestAgent.RECONCILABLE_RECOVERY_RESULT = 42L;
+
+ seedActionState(
+ actionStateStore,
+ key,
+ input,
+ agentPlan,
+ "durableReconcilableAction",
+
actionStateWithCallResults(CallResult.pending("reconcilable-call", "")));
+
+ try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object>
testHarness =
+ new KeyedOneInputStreamOperatorTestHarness<>(
+ new ActionExecutionOperatorFactory<>(agentPlan, true,
actionStateStore),
+ (KeySelector<Long, Long>) value -> value,
+ TypeInformation.of(Long.class))) {
+ testHarness.open();
+ ActionExecutionOperator<Long, Object> operator =
+ (ActionExecutionOperator<Long, Object>)
testHarness.getOperator();
+
+ testHarness.processElement(new StreamRecord<>(input));
+ operator.waitInFlightEventsFinished();
+
+ List<StreamRecord<Object>> recordOutput =
+ (List<StreamRecord<Object>>) testHarness.getRecordOutput();
+ assertThat(recordOutput).hasSize(1);
+ assertThat(recordOutput.get(0).getValue()).isEqualTo(42L);
+ }
+
+ assertThat(TestAgent.RECONCILABLE_CALL_COUNTER.get()).isZero();
+
assertThat(TestAgent.RECONCILABLE_RECONCILE_COUNTER.get()).isEqualTo(1);
+
+ ActionState persistedState =
+ getStoredActionState(
+ actionStateStore, key, input, agentPlan,
"durableReconcilableAction");
+ assertThat(persistedState.isCompleted()).isTrue();
+ assertThat(persistedState.getCallResults()).isEmpty();
+ }
+
+ @Test
+ void testDurableExecuteReconcilableRecoveryException() throws Exception {
+ AgentPlan agentPlan = TestAgent.getDurableReconcilableAgentPlan();
+ InMemoryActionStateStore actionStateStore = new
InMemoryActionStateStore(false);
+ long key = 2L;
+ long input = 2L;
+ TestAgent.RECONCILABLE_RECOVERY_BEHAVIOR =
TestAgent.ReconcileBehavior.EXCEPTION;
+ TestAgent.RECONCILABLE_EXCEPTION_MESSAGE = "reconcile unavailable";
+
+ seedActionState(
+ actionStateStore,
+ key,
+ input,
+ agentPlan,
+ "durableReconcilableAction",
+
actionStateWithCallResults(CallResult.pending("reconcilable-call", "")));
+
+ try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object>
testHarness =
+ new KeyedOneInputStreamOperatorTestHarness<>(
+ new ActionExecutionOperatorFactory<>(agentPlan, true,
actionStateStore),
+ (KeySelector<Long, Long>) value -> value,
+ TypeInformation.of(Long.class))) {
+ testHarness.open();
+ ActionExecutionOperator<Long, Object> operator =
+ (ActionExecutionOperator<Long, Object>)
testHarness.getOperator();
+
+ testHarness.processElement(new StreamRecord<>(input));
+ operator.waitInFlightEventsFinished();
+
+ List<StreamRecord<Object>> recordOutput =
+ (List<StreamRecord<Object>>) testHarness.getRecordOutput();
+ assertThat(recordOutput).hasSize(1);
+
assertThat(recordOutput.get(0).getValue()).isEqualTo("ERROR:reconcile
unavailable");
+ }
+
+ assertThat(TestAgent.RECONCILABLE_CALL_COUNTER.get()).isZero();
+
assertThat(TestAgent.RECONCILABLE_RECONCILE_COUNTER.get()).isEqualTo(1);
+
+ ActionState persistedState =
+ getStoredActionState(
+ actionStateStore, key, input, agentPlan,
"durableReconcilableAction");
+ assertThat(persistedState.isCompleted()).isTrue();
+ assertThat(persistedState.getCallResults()).isEmpty();
+ }
+
+ @Test
+ void testDurableExecuteReconcilableRecoveryMismatchStartsNewCall() throws
Exception {
+ AgentPlan agentPlan = TestAgent.getDurableReconcilableAgentPlan();
+ InMemoryActionStateStore actionStateStore = new
InMemoryActionStateStore(false);
+ long key = 4L;
+ long input = 4L;
+
+ seedActionState(
+ actionStateStore,
+ key,
+ input,
+ agentPlan,
+ "durableReconcilableAction",
+ actionStateWithCallResults(CallResult.pending("stale-call",
"")));
+
+ try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object>
testHarness =
+ new KeyedOneInputStreamOperatorTestHarness<>(
+ new ActionExecutionOperatorFactory<>(agentPlan, true,
actionStateStore),
+ (KeySelector<Long, Long>) value -> value,
+ TypeInformation.of(Long.class))) {
+ testHarness.open();
+ ActionExecutionOperator<Long, Object> operator =
+ (ActionExecutionOperator<Long, Object>)
testHarness.getOperator();
+
+ testHarness.processElement(new StreamRecord<>(input));
+ operator.waitInFlightEventsFinished();
+
+ List<StreamRecord<Object>> recordOutput =
+ (List<StreamRecord<Object>>) testHarness.getRecordOutput();
+ assertThat(recordOutput).hasSize(1);
+ assertThat(recordOutput.get(0).getValue()).isEqualTo(20L);
+ }
+
+ assertThat(TestAgent.RECONCILABLE_CALL_COUNTER.get()).isEqualTo(1);
+ assertThat(TestAgent.RECONCILABLE_RECONCILE_COUNTER.get()).isZero();
+
+ ActionState persistedState =
+ getStoredActionState(
+ actionStateStore, key, input, agentPlan,
"durableReconcilableAction");
+ assertThat(persistedState.isCompleted()).isTrue();
+ assertThat(persistedState.getCallResults()).isEmpty();
+ }
+
+ @Test
+ void testDurableExecuteRecoveryMixedCompletionOnlyAndReconcilableCalls()
throws Exception {
+ AgentPlan agentPlan = TestAgent.getDurableMixedRecoveryAgentPlan();
+ InMemoryActionStateStore actionStateStore = new
InMemoryActionStateStore(false);
+ long key = 1L;
+ long input = 1L;
+ TestAgent.MIXED_RECONCILE_BEHAVIOR =
TestAgent.ReconcileBehavior.SUCCESS;
+ TestAgent.MIXED_RECONCILE_RESULT = 50L;
+
+ seedActionState(
+ actionStateStore,
+ key,
+ input,
+ agentPlan,
+ "durableMixedRecoveryAction",
+ actionStateWithCallResults(
+ new CallResult(
+ "mixed-legacy-call", "",
OBJECT_MAPPER.writeValueAsBytes(11L)),
+ CallResult.pending("mixed-reconcilable-call", "")));
+
+ try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object>
testHarness =
+ new KeyedOneInputStreamOperatorTestHarness<>(
+ new ActionExecutionOperatorFactory<>(agentPlan, true,
actionStateStore),
+ (KeySelector<Long, Long>) value -> value,
+ TypeInformation.of(Long.class))) {
+ testHarness.open();
+ ActionExecutionOperator<Long, Object> operator =
+ (ActionExecutionOperator<Long, Object>)
testHarness.getOperator();
+
+ testHarness.processElement(new StreamRecord<>(input));
+ operator.waitInFlightEventsFinished();
+
+ List<StreamRecord<Object>> recordOutput =
+ (List<StreamRecord<Object>>) testHarness.getRecordOutput();
+ assertThat(recordOutput).hasSize(1);
+ assertThat(recordOutput.get(0).getValue()).isEqualTo(61L);
+ }
+
+ assertThat(TestAgent.MIXED_LEGACY_CALL_COUNTER.get()).isZero();
+ assertThat(TestAgent.MIXED_RECONCILABLE_CALL_COUNTER.get()).isZero();
+ assertThat(TestAgent.MIXED_RECONCILE_COUNTER.get()).isEqualTo(1);
+
+ ActionState persistedState =
+ getStoredActionState(
+ actionStateStore, key, input, agentPlan,
"durableMixedRecoveryAction");
+ assertThat(persistedState.isCompleted()).isTrue();
+ assertThat(persistedState.getCallResults()).isEmpty();
+ }
+
public static class TestAgent {
/** Counter to track how many times the durable supplier is executed.
*/
@@ -1083,32 +1277,70 @@ public class ActionExecutionOperatorTest {
}
}
+ private static <T> DurableCallable<T> durableCallable(
+ String id, Class<T> resultClass, Callable<T> callSupplier) {
+ return new DurableCallable<T>() {
+ @Override
+ public String getId() {
+ return id;
+ }
+
+ @Override
+ public Class<T> getResultClass() {
+ return resultClass;
+ }
+
+ @Override
+ public T call() throws Exception {
+ return callSupplier.call();
+ }
+ };
+ }
+
+ private static <T> DurableCallable<T> reconcilableDurableCallable(
+ String id,
+ Class<T> resultClass,
+ Callable<T> callSupplier,
+ Callable<T> reconcileSupplier) {
+ return new DurableCallable<T>() {
+ @Override
+ public String getId() {
+ return id;
+ }
+
+ @Override
+ public Class<T> getResultClass() {
+ return resultClass;
+ }
+
+ @Override
+ public T call() throws Exception {
+ return callSupplier.call();
+ }
+
+ @Override
+ public Callable<T> reconciler() {
+ return reconcileSupplier;
+ }
+ };
+ }
+
public static void asyncAction1(InputEvent event, RunnerContext
context) {
Long inputData = (Long) event.getInput();
try {
Long result =
context.durableExecuteAsync(
- new DurableCallable<Long>() {
- @Override
- public String getId() {
- return "async-multiply";
- }
-
- @Override
- public Class<Long> getResultClass() {
- return Long.class;
- }
-
- @Override
- public Long call() {
- try {
- Thread.sleep(50);
- } catch (InterruptedException e) {
- Thread.currentThread().interrupt();
- }
- return inputData * 10;
- }
- });
+ durableCallable(
+ "async-multiply",
+ Long.class,
+ () -> {
+ try {
+ Thread.sleep(50);
+ } catch (InterruptedException e) {
+
Thread.currentThread().interrupt();
+ }
+ return inputData * 10;
+ }));
MemoryObject mem = context.getShortTermMemory();
mem.set("tmp", result);
@@ -1123,51 +1355,31 @@ public class ActionExecutionOperatorTest {
try {
Long result1 =
context.durableExecuteAsync(
- new DurableCallable<Long>() {
- @Override
- public String getId() {
- return "async-add";
- }
-
- @Override
- public Class<Long> getResultClass() {
- return Long.class;
- }
-
- @Override
- public Long call() {
- try {
- Thread.sleep(30);
- } catch (InterruptedException e) {
- Thread.currentThread().interrupt();
- }
- return inputData + 100;
- }
- });
+ durableCallable(
+ "async-add",
+ Long.class,
+ () -> {
+ try {
+ Thread.sleep(30);
+ } catch (InterruptedException e) {
+
Thread.currentThread().interrupt();
+ }
+ return inputData + 100;
+ }));
Long result2 =
context.durableExecuteAsync(
- new DurableCallable<Long>() {
- @Override
- public String getId() {
- return "async-multiply";
- }
-
- @Override
- public Class<Long> getResultClass() {
- return Long.class;
- }
-
- @Override
- public Long call() {
- try {
- Thread.sleep(30);
- } catch (InterruptedException e) {
- Thread.currentThread().interrupt();
- }
- return result1 * 2;
- }
- });
+ durableCallable(
+ "async-multiply",
+ Long.class,
+ () -> {
+ try {
+ Thread.sleep(30);
+ } catch (InterruptedException e) {
+
Thread.currentThread().interrupt();
+ }
+ return result1 * 2;
+ }));
MemoryObject mem = context.getShortTermMemory();
mem.set("multiAsyncResult", result2);
@@ -1182,23 +1394,13 @@ public class ActionExecutionOperatorTest {
try {
Long result =
context.durableExecute(
- new DurableCallable<Long>() {
- @Override
- public String getId() {
- return "sync-compute";
- }
-
- @Override
- public Class<Long> getResultClass() {
- return Long.class;
- }
-
- @Override
- public Long call() {
- DURABLE_CALL_COUNTER.incrementAndGet();
- return inputData * 3;
- }
- });
+ durableCallable(
+ "sync-compute",
+ Long.class,
+ () -> {
+
DURABLE_CALL_COUNTER.incrementAndGet();
+ return inputData * 3;
+ }));
context.sendEvent(new OutputEvent(result));
} catch (Exception e) {
@@ -1209,31 +1411,131 @@ public class ActionExecutionOperatorTest {
public static final java.util.concurrent.atomic.AtomicInteger
EXCEPTION_CALL_COUNTER =
new java.util.concurrent.atomic.AtomicInteger(0);
+ public enum ReconcileBehavior {
+ SUCCESS,
+ EXCEPTION
+ }
+
+ public static final java.util.concurrent.atomic.AtomicInteger
RECONCILABLE_CALL_COUNTER =
+ new java.util.concurrent.atomic.AtomicInteger(0);
+ public static final java.util.concurrent.atomic.AtomicInteger
+ RECONCILABLE_RECONCILE_COUNTER = new
java.util.concurrent.atomic.AtomicInteger(0);
+ public static volatile ReconcileBehavior
RECONCILABLE_RECOVERY_BEHAVIOR =
+ ReconcileBehavior.SUCCESS;
+ public static volatile long RECONCILABLE_RECOVERY_RESULT = 42L;
+ public static volatile String RECONCILABLE_EXCEPTION_MESSAGE =
"reconcile unavailable";
+
+ public static final java.util.concurrent.atomic.AtomicInteger
MIXED_LEGACY_CALL_COUNTER =
+ new java.util.concurrent.atomic.AtomicInteger(0);
+ public static final java.util.concurrent.atomic.AtomicInteger
+ MIXED_RECONCILABLE_CALL_COUNTER = new
java.util.concurrent.atomic.AtomicInteger(0);
+ public static final java.util.concurrent.atomic.AtomicInteger
MIXED_RECONCILE_COUNTER =
+ new java.util.concurrent.atomic.AtomicInteger(0);
+ public static volatile ReconcileBehavior MIXED_RECONCILE_BEHAVIOR =
+ ReconcileBehavior.SUCCESS;
+ public static volatile long MIXED_RECONCILE_RESULT = 50L;
+
public static void durableExceptionAction(InputEvent event,
RunnerContext context) {
try {
context.durableExecute(
- new DurableCallable<String>() {
- @Override
- public String getId() {
- return "exception-action";
- }
-
- @Override
- public Class<String> getResultClass() {
- return String.class;
- }
-
- @Override
- public String call() {
- EXCEPTION_CALL_COUNTER.incrementAndGet();
- throw new RuntimeException("Test exception
from durableExecute");
- }
- });
+ durableCallable(
+ "exception-action",
+ String.class,
+ () -> {
+ EXCEPTION_CALL_COUNTER.incrementAndGet();
+ throw new RuntimeException(
+ "Test exception from
durableExecute");
+ }));
} catch (Exception e) {
context.sendEvent(new OutputEvent("ERROR:" + e.getMessage()));
}
}
+ public static void durableReconcilableAction(InputEvent event,
RunnerContext context) {
+ Long inputData = (Long) event.getInput();
+ try {
+ Long result =
+ context.durableExecute(
+ reconcilableDurableCallable(
+ "reconcilable-call",
+ Long.class,
+ () -> {
+
RECONCILABLE_CALL_COUNTER.incrementAndGet();
+ return inputData * 5;
+ },
+ () -> {
+
RECONCILABLE_RECONCILE_COUNTER.incrementAndGet();
+ switch
(RECONCILABLE_RECOVERY_BEHAVIOR) {
+ case SUCCESS:
+ return
RECONCILABLE_RECOVERY_RESULT;
+ case EXCEPTION:
+ throw new
IllegalStateException(
+
RECONCILABLE_EXCEPTION_MESSAGE);
+ }
+ throw new IllegalStateException(
+ "Unsupported reconcile
behavior");
+ }));
+ context.sendEvent(new OutputEvent(result));
+ } catch (Exception e) {
+ context.sendEvent(new OutputEvent("ERROR:" + e.getMessage()));
+ }
+ }
+
+ public static void durableMixedRecoveryAction(InputEvent event,
RunnerContext context) {
+ Long inputData = (Long) event.getInput();
+ try {
+ Long firstResult =
+ context.durableExecute(
+ durableCallable(
+ "mixed-legacy-call",
+ Long.class,
+ () -> {
+
MIXED_LEGACY_CALL_COUNTER.incrementAndGet();
+ return inputData + 10;
+ }));
+ Long secondResult =
+ context.durableExecute(
+ reconcilableDurableCallable(
+ "mixed-reconcilable-call",
+ Long.class,
+ () -> {
+
MIXED_RECONCILABLE_CALL_COUNTER.incrementAndGet();
+ return firstResult * 2;
+ },
+ () -> {
+
MIXED_RECONCILE_COUNTER.incrementAndGet();
+ switch (MIXED_RECONCILE_BEHAVIOR) {
+ case SUCCESS:
+ return
MIXED_RECONCILE_RESULT;
+ case EXCEPTION:
+ throw new
IllegalStateException(
+ "mixed reconcile
failed");
+ }
+ throw new IllegalStateException(
+ "Unsupported reconcile
behavior");
+ }));
+ context.sendEvent(new OutputEvent(firstResult + secondResult));
+ } catch (Exception e) {
+ ExceptionUtils.rethrow(e);
+ }
+ }
+
+ public static void resetReconcilableRecoveryFixture() {
+ RECONCILABLE_CALL_COUNTER.set(0);
+ RECONCILABLE_RECONCILE_COUNTER.set(0);
+ RECONCILABLE_RECOVERY_BEHAVIOR = ReconcileBehavior.SUCCESS;
+ RECONCILABLE_RECOVERY_RESULT = 42L;
+ RECONCILABLE_EXCEPTION_MESSAGE = "reconcile unavailable";
+ }
+
+ public static void resetMixedRecoveryFixture() {
+ MIXED_LEGACY_CALL_COUNTER.set(0);
+ MIXED_RECONCILABLE_CALL_COUNTER.set(0);
+ MIXED_RECONCILE_COUNTER.set(0);
+ MIXED_RECONCILE_BEHAVIOR = ReconcileBehavior.SUCCESS;
+ MIXED_RECONCILE_RESULT = 50L;
+ }
+
public static AgentPlan getAgentPlan(boolean
testMemoryAccessOutOfMailbox) {
return getAgentPlanWithConfig(new AgentConfiguration(),
testMemoryAccessOutOfMailbox);
}
@@ -1375,6 +1677,54 @@ public class ActionExecutionOperatorTest {
return null;
}
+ public static AgentPlan getDurableReconcilableAgentPlan() {
+ try {
+ Map<String, List<Action>> actionsByEvent = new HashMap<>();
+ Map<String, Action> actions = new HashMap<>();
+
+ Action reconcilableAction =
+ new Action(
+ "durableReconcilableAction",
+ new JavaFunction(
+ TestAgent.class,
+ "durableReconcilableAction",
+ new Class<?>[] {InputEvent.class,
RunnerContext.class}),
+
Collections.singletonList(InputEvent.class.getName()));
+ actionsByEvent.put(
+ InputEvent.class.getName(),
Collections.singletonList(reconcilableAction));
+ actions.put(reconcilableAction.getName(), reconcilableAction);
+
+ return new AgentPlan(actions, actionsByEvent, new HashMap<>());
+ } catch (Exception e) {
+ ExceptionUtils.rethrow(e);
+ }
+ return null;
+ }
+
+ public static AgentPlan getDurableMixedRecoveryAgentPlan() {
+ try {
+ Map<String, List<Action>> actionsByEvent = new HashMap<>();
+ Map<String, Action> actions = new HashMap<>();
+
+ Action mixedRecoveryAction =
+ new Action(
+ "durableMixedRecoveryAction",
+ new JavaFunction(
+ TestAgent.class,
+ "durableMixedRecoveryAction",
+ new Class<?>[] {InputEvent.class,
RunnerContext.class}),
+
Collections.singletonList(InputEvent.class.getName()));
+ actionsByEvent.put(
+ InputEvent.class.getName(),
Collections.singletonList(mixedRecoveryAction));
+ actions.put(mixedRecoveryAction.getName(),
mixedRecoveryAction);
+
+ return new AgentPlan(actions, actionsByEvent, new HashMap<>());
+ } catch (Exception e) {
+ ExceptionUtils.rethrow(e);
+ }
+ return null;
+ }
+
public static AgentPlan getDurableExceptionAgentPlan() {
try {
Map<String, List<Action>> actionsByEvent = new HashMap<>();
@@ -1415,24 +1765,14 @@ public class ActionExecutionOperatorTest {
public static void durableExceptionUncaughtAction(InputEvent event,
RunnerContext context) {
try {
context.durableExecute(
- new DurableCallable<String>() {
- @Override
- public String getId() {
- return "uncaught-exception-action";
- }
-
- @Override
- public Class<String> getResultClass() {
- return String.class;
- }
-
- @Override
- public String call() {
-
UNCAUGHT_EXCEPTION_CALL_COUNTER.incrementAndGet();
- throw new IllegalStateException(
- "Simulated LLM failure: Connection
timeout");
- }
- });
+ durableCallable(
+ "uncaught-exception-action",
+ String.class,
+ () -> {
+
UNCAUGHT_EXCEPTION_CALL_COUNTER.incrementAndGet();
+ throw new IllegalStateException(
+ "Simulated LLM failure: Connection
timeout");
+ }));
} catch (Exception e) {
// Re-throw without wrapping - simulates built-in action
behavior
ExceptionUtils.rethrow(e);
@@ -1453,23 +1793,13 @@ public class ActionExecutionOperatorTest {
public static void durableAsyncExceptionAction(InputEvent event,
RunnerContext context) {
try {
context.durableExecuteAsync(
- new DurableCallable<String>() {
- @Override
- public String getId() {
- return "async-exception-action";
- }
-
- @Override
- public Class<String> getResultClass() {
- return String.class;
- }
-
- @Override
- public String call() {
- ASYNC_EXCEPTION_CALL_COUNTER.incrementAndGet();
- throw new RuntimeException("Async operation
failed: API error");
- }
- });
+ durableCallable(
+ "async-exception-action",
+ String.class,
+ () -> {
+
ASYNC_EXCEPTION_CALL_COUNTER.incrementAndGet();
+ throw new RuntimeException("Async
operation failed: API error");
+ }));
} catch (Exception e) {
ExceptionUtils.rethrow(e);
}
@@ -1524,6 +1854,39 @@ public class ActionExecutionOperatorTest {
}
}
+ private static ActionState actionStateWithCallResults(CallResult...
callResults) {
+ ActionState actionState = new ActionState(null);
+ for (CallResult callResult : callResults) {
+ actionState.addCallResult(callResult);
+ }
+ return actionState;
+ }
+
+ private static void seedActionState(
+ InMemoryActionStateStore actionStateStore,
+ long key,
+ long input,
+ AgentPlan agentPlan,
+ String actionName,
+ ActionState actionState)
+ throws Exception {
+ InputEvent event = new InputEvent(input);
+ Action action = agentPlan.getActions().get(actionName);
+ actionStateStore.put(key, 0L, action, event, actionState);
+ }
+
+ private static ActionState getStoredActionState(
+ InMemoryActionStateStore actionStateStore,
+ long key,
+ long input,
+ AgentPlan agentPlan,
+ String actionName)
+ throws Exception {
+ InputEvent event = new InputEvent(input);
+ Action action = agentPlan.getActions().get(actionName);
+ return actionStateStore.get(key, 0L, action, event);
+ }
+
private static void assertMailboxSizeAndRun(TaskMailbox mailbox, int
expectedSize)
throws Exception {
assertThat(mailbox.size()).isEqualTo(expectedSize);