This is an automated email from the ASF dual-hosted git repository.

sxnan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git

commit 53d76a75b0c7c0d34dbee902303311e3a57a8d7d
Author: sxnan <[email protected]>
AuthorDate: Thu Jan 8 18:00:37 2026 +0800

    [runtime] Implement CallRecord persistence and restore
---
 .../flink_agents/runtime/flink_runner_context.py   | 282 ++++++++++++++++++++-
 .../runtime/tests/test_durable_execution.py        | 150 +++++++++++
 .../runtime/context/ActionStatePersister.java      |  44 ++++
 .../agents/runtime/context/RunnerContextImpl.java  | 188 ++++++++++++++
 .../runtime/operator/ActionExecutionOperator.java  | 101 +++++++-
 .../python/context/PythonRunnerContextImpl.java    |   1 +
 .../context/DurableExecutionContextTest.java       | 206 +++++++++++++++
 7 files changed, 954 insertions(+), 18 deletions(-)

diff --git a/python/flink_agents/runtime/flink_runner_context.py 
b/python/flink_agents/runtime/flink_runner_context.py
index 257d22d4..dffdaf73 100644
--- a/python/flink_agents/runtime/flink_runner_context.py
+++ b/python/flink_agents/runtime/flink_runner_context.py
@@ -15,6 +15,8 @@
 #  See the License for the specific language governing permissions and
 # limitations under the License.
 
#################################################################################
+import hashlib
+import logging
 import os
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Callable, Dict
@@ -39,11 +41,136 @@ from 
flink_agents.runtime.memory.vector_store_long_term_memory import (
     VectorStoreLongTermMemory,
 )
 
+# Module-level logger: reports when durable execution falls back to direct
+# execution and when recording a call completion fails.
+logger = logging.getLogger(__name__)
+
+
+class _DurableExecutionResult:
+    """Successful outcome of a durable async call, recorded lazily on unwrap.
+
+    Holds everything ``record_callback`` needs so that recording can be
+    deferred until ``get_result`` is called.
+    """
+
+    def __init__(
+        self,
+        func: Callable,
+        args: tuple,
+        kwargs: dict,
+        result: Any,
+        record_callback: Callable,
+    ) -> None:
+        # Identity of the call, used by the callback to build its record.
+        self.func, self.args, self.kwargs = func, args, kwargs
+        self.result = result
+        self.record_callback = record_callback
+        # Guards against reporting the same completion twice.
+        self._recorded = False
+
+    def get_result(self) -> Any:
+        """Return the wrapped value, reporting it via the callback only once."""
+        if self._recorded:
+            return self.result
+        self.record_callback(self.func, self.args, self.kwargs, self.result, None)
+        self._recorded = True
+        return self.result
+
+
+class _DurableExecutionException(Exception):
+    """Wrapper exception that holds exception info and triggers recording.
+
+    The call's positional arguments are kept in ``call_args`` rather than
+    ``args``: ``BaseException.args`` holds the exception message, and
+    overwriting it would make ``str(wrapper)`` show the call's arguments.
+    """
+
+    def __init__(
+        self,
+        func: Callable,
+        args: tuple,
+        kwargs: dict,
+        result: Any,
+        exception: BaseException,
+        record_callback: Callable,
+    ) -> None:
+        super().__init__(str(exception))
+        self.func = func
+        self.call_args = args
+        self.kwargs = kwargs
+        # Kept for parity with _DurableExecutionResult; not used for recording,
+        # which always reports a None result alongside the exception.
+        self.result = result
+        self.original_exception = exception
+        self.record_callback = record_callback
+        self._recorded = False
+
+    def record_and_raise(self) -> None:
+        """Record the failed call (once) and raise the original exception.
+
+        Raises
+        ------
+        BaseException
+            Always: the exception originally raised by the wrapped call, with
+            the wrapper suppressed from the exception context.
+        """
+        if not self._recorded:
+            self.record_callback(
+                self.func, self.call_args, self.kwargs, None, self.original_exception
+            )
+            self._recorded = True
+        raise self.original_exception from None
+
+
+class _CachedAsyncExecutionResult(AsyncExecutionResult):
+    """Awaitable that completes immediately with a previously recorded value.
+
+    Nothing is executed: awaiting it simply hands back ``cached_result``.
+    """
+
+    def __init__(self, cached_result: Any) -> None:
+        # The parent constructor is deliberately skipped: no executor, function
+        # or arguments are needed to replay a cached value.
+        self._cached_result = cached_result
+
+    def __await__(self) -> Any:
+        """Finish at once with the cached value, never suspending."""
+        # An empty ``yield from`` makes this a generator that yields nothing.
+        yield from ()
+        return self._cached_result
+
+
+class _DurableAsyncExecutionResult(AsyncExecutionResult):
+    """An AsyncExecutionResult that records completion after execution.
+
+    The submitted callable is expected to produce a ``_DurableExecutionResult``
+    or a ``_DurableExecutionException``. The outcome is recorded on the awaiting
+    side, when this object is awaited.
+    """
+
+    def __init__(
+        self, executor: Any, func: Callable, args: tuple, kwargs: dict
+    ) -> None:
+        super().__init__(executor, func, args, kwargs)
+
+    def __await__(self) -> Any:
+        """Run the call on the executor and resolve it once it finishes.
+
+        Yields (without a value) while the call is still running. Then:
+
+        - a ``_DurableExecutionResult`` is recorded and unwrapped to its value;
+        - a ``_DurableExecutionException``, whether returned or raised by the
+          submitted callable, is recorded and its original exception re-raised;
+        - any other value is returned unchanged, and any other exception
+          propagates unrecorded.
+        """
+        future = self._executor.submit(self._func, *self._args, **self._kwargs)
+        while not future.done():
+            yield
+
+        # A wrapper raised in the worker thread would otherwise escape through
+        # future.result() unrecorded and with the wrapper's type instead of the
+        # original exception's type.
+        failure = future.exception()
+        if isinstance(failure, _DurableExecutionException):
+            failure.record_and_raise()
+
+        outcome = future.result()
+        if isinstance(outcome, _DurableExecutionResult):
+            return outcome.get_result()
+        if isinstance(outcome, _DurableExecutionException):
+            outcome.record_and_raise()
+        return outcome
+
+
+def _compute_function_id(func: Callable) -> str:
+    """Build a ``<module>.<qualname>`` identifier for *func*.
+
+    ``__name__`` is used when ``__qualname__`` is missing. ``"<unknown>"``
+    stands in for any attribute that is absent.
+    """
+    owner_module = getattr(func, "__module__", "<unknown>")
+    plain_name = getattr(func, "__name__", "<unknown>")
+    qualified = getattr(func, "__qualname__", plain_name)
+    return f"{owner_module}.{qualified}"
+
+
+def _compute_args_digest(args: tuple, kwargs: dict) -> str:
+    """Return a short fingerprint of a call's positional and keyword arguments.
+
+    The fingerprint is the first 16 hex characters of a SHA-256 digest. It lets
+    recovery check that a replayed call received the same arguments as the
+    recorded one. Arguments are hashed in cloudpickle form; if pickling fails,
+    their ``str()`` form is hashed instead.
+    """
+    try:
+        payload = cloudpickle.dumps((args, kwargs))
+    except Exception:
+        # NOTE(review): str() of arbitrary objects may embed memory addresses,
+        # so this fallback is not guaranteed to be stable across restarts.
+        payload = str((args, kwargs)).encode()
+    return hashlib.sha256(payload).hexdigest()[:16]
+
 
 class FlinkRunnerContext(RunnerContext):
     """Providing context for agent execution in Flink Environment.
 
-    This context allows access to event handling.
+    This context allows access to event handling and provides fine-grained
+    durable execution support through the durable_execute() and
+    durable_execute_async() methods.
     """
 
     __agent_plan: AgentPlan | None
@@ -185,34 +312,167 @@ class FlinkRunnerContext(RunnerContext):
         """
         return FlinkMetricGroup(self._j_runner_context.getActionMetricGroup())
 
+    def _try_get_cached_result(
+        self, func: Callable, args: tuple, kwargs: dict
+    ) -> tuple[bool, Any]:
+        """Look up the recorded outcome of this call from an earlier attempt.
+
+        The Java runner context is asked whether the next recorded call matches
+        this function id and argument digest.
+
+        Returns:
+        -------
+        tuple[bool, Any]
+            ``(True, value)`` when a recorded result is replayed, otherwise
+            ``(False, None)``.
+
+        Raises:
+        ------
+        BaseException
+            The recorded exception, when the matched call previously failed.
+        """
+        function_id = _compute_function_id(func)
+        args_digest = _compute_args_digest(args, kwargs)
+
+        replayed_error: BaseException | None = None
+        try:
+            match = self._j_runner_context.matchNextOrClearSubsequentCallResult(
+                function_id, args_digest
+            )
+            if match is not None:
+                hit, result_bytes, error_bytes = match
+                if hit and error_bytes is not None:
+                    # Deferred until after the try so the except clause below
+                    # cannot intercept the replayed exception.
+                    replayed_error = cloudpickle.loads(bytes(error_bytes))
+                elif hit:
+                    decoded = (
+                        None
+                        if result_bytes is None
+                        else cloudpickle.loads(bytes(result_bytes))
+                    )
+                    return True, decoded
+        except Exception as err:
+            # Tolerate a Java side without the method; re-raise anything else.
+            if "matchNextOrClearSubsequentCallResult" not in str(err):
+                raise
+            logger.debug("Durable execution not supported, executing directly")
+
+        if replayed_error is not None:
+            raise replayed_error
+        return False, None
+
+    def _record_call_completion(
+        self,
+        func: Callable,
+        args: tuple,
+        kwargs: dict,
+        result: Any,
+        exception: BaseException | None,
+    ) -> None:
+        """Record the completion of a call for durable execution.
+
+        Best effort: failures while serializing the payloads or handing them to
+        the Java runner context are logged rather than raised.
+
+        Parameters
+        ----------
+        func : Callable
+            The function that was executed.
+        args : tuple
+            Positional arguments passed to the function.
+        kwargs : dict
+            Keyword arguments passed to the function.
+        result : Any
+            The result of the function (None if exception occurred).
+        exception : BaseException | None
+            The exception raised by the function (None if successful).
+        """
+        function_id = _compute_function_id(func)
+        args_digest = _compute_args_digest(args, kwargs)
+
+        # Compare against None explicitly: an exception object is not
+        # guaranteed to be truthy (a subclass may define __bool__ or __len__).
+        failed = exception is not None
+        try:
+            result_payload = None if failed else cloudpickle.dumps(result)
+            exception_payload = cloudpickle.dumps(exception) if failed else None
+
+            self._j_runner_context.recordCallCompletion(
+                function_id, args_digest, result_payload, exception_payload
+            )
+        except Exception as e:
+            if "recordCallCompletion" in str(e):
+                # Presumably the Java context lacks the method (durable
+                # execution unsupported), so there is nothing to record.
+                # NOTE(review): a genuine failure whose message mentions the
+                # method name also lands here; consider a narrower check.
+                logger.debug("Call completion not recorded: %s", e)
+            else:
+                logger.warning("Failed to record call completion: %s", e)
+
     @override
-    def execute(
+    def durable_execute(
         self,
         func: Callable[[Any], Any],
         *args: Any,
         **kwargs: Any,
     ) -> Any:
-        """Synchronously execute the provided function. Access to memory
-        is prohibited within the function.
+        """Synchronously execute the provided function with durable execution.
+        Access to memory is prohibited within the function.
+
+        The result of the function will be stored and returned when the same
+        durable_execute call is made again during job recovery. The arguments
+        and the result must be serializable.
+
+        An ``Exception`` raised by the function is recorded and re-raised. During
+        recovery, a recorded exception is re-raised without calling the function.
+        ``BaseException`` subclasses that are not ``Exception`` (for example
+        ``KeyboardInterrupt``) propagate without being recorded.

         The function is executed synchronously in the current thread, blocking
         the operator until completion.
         """
-        # TODO: Add durable execution support (persist result for recovery)
-        return func(*args, **kwargs)
+        # On recovery, replay the recorded outcome instead of re-running func.
+        is_hit, cached_result = self._try_get_cached_result(func, args, kwargs)
+        if is_hit:
+            return cached_result
+
+        try:
+            result = func(*args, **kwargs)
+        except Exception as e:
+            # Record the failure as the call's outcome, then re-raise it.
+            # BaseException subclasses such as KeyboardInterrupt/SystemExit are
+            # deliberately not caught: they interrupt this attempt and must not
+            # be replayed as the function's outcome after recovery.
+            self._record_call_completion(func, args, kwargs, None, e)
+            raise
+
+        self._record_call_completion(func, args, kwargs, result, None)
+        return result
 
     @override
-    def execute_async(
+    def durable_execute_async(
         self,
         func: Callable[[Any], Any],
         *args: Any,
         **kwargs: Any,
     ) -> AsyncExecutionResult:
-        """Asynchronously execute the provided function. Access to memory
-        is prohibited within the function.
+        """Asynchronously execute the provided function with durable execution.
+        Access to memory is prohibited within the function.
+
+        The result of the function will be stored and returned when the same
+        durable_execute_async call is made again during job recovery. The
+        arguments and the result must be serializable.
+
+        Important: The result is only recorded when the returned
+        AsyncExecutionResult is awaited. Fire-and-forget calls (not awaiting the
+        result) will NOT be recorded and cannot be recovered.
+
+        Note: during recovery, a recorded exception is raised by this call
+        itself rather than when the returned object is awaited. Exceptions that
+        are not ``Exception`` subclasses (e.g. ``KeyboardInterrupt``) are never
+        recorded.
         """
-        # TODO: Add durable execution support (persist result for recovery)
-        return AsyncExecutionResult(self.executor, func, args, kwargs)
+        # On recovery, return an awaitable that completes with the cached value.
+        is_hit, cached_result = self._try_get_cached_result(func, args, kwargs)
+        if is_hit:
+            return _CachedAsyncExecutionResult(cached_result)
+
+        def wrapped_func(*a: Any, **kw: Any) -> Any:
+            # Runs on an executor thread, so nothing is recorded here. The
+            # returned wrapper records the outcome when the result is awaited.
+            try:
+                value = func(*a, **kw)
+            except Exception as e:
+                # Returned (not raised) so the awaiting side receives the
+                # wrapper as a normal result, records the failure, and then
+                # re-raises the original exception.
+                return _DurableExecutionException(
+                    func, args, kwargs, None, e, self._record_call_completion
+                )
+            return _DurableExecutionResult(
+                func, args, kwargs, value, self._record_call_completion
+            )
+
+        return _DurableAsyncExecutionResult(self.executor, wrapped_func, args, kwargs)
 
     @property
     @override
diff --git a/python/flink_agents/runtime/tests/test_durable_execution.py 
b/python/flink_agents/runtime/tests/test_durable_execution.py
new file mode 100644
index 00000000..e59e54cd
--- /dev/null
+++ b/python/flink_agents/runtime/tests/test_durable_execution.py
@@ -0,0 +1,150 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
+"""Tests for durable execution helper functions."""
+
+import cloudpickle
+
+from flink_agents.runtime.flink_runner_context import (
+    _compute_args_digest,
+    _compute_function_id,
+)
+
+
+def sample_function(x: int, y: int) -> int:
+    """Add two integers; a plain module-level function used as a fixture."""
+    total = x + y
+    return total
+
+
+class SampleClass:
+    """Fixture exposing instance, static and class methods for ID tests."""
+
+    def instance_method(self, x: int) -> int:
+        """Return twice the input."""
+        doubled = x * 2
+        return doubled
+
+    @staticmethod
+    def static_method(x: int) -> int:
+        """Return three times the input."""
+        tripled = x * 3
+        return tripled
+
+    @classmethod
+    def class_method(cls, x: int) -> int:
+        """Return four times the input."""
+        quadrupled = x * 4
+        return quadrupled
+
+
+def test_compute_function_id_for_function() -> None:
+    """A module-level function's ID names both its module and itself."""
+    func_id = _compute_function_id(sample_function)
+    for fragment in ("sample_function", "test_durable_execution"):
+        assert fragment in func_id
+
+
+def test_compute_function_id_for_lambda() -> None:
+    """Lambdas are identified through their ``<lambda>`` qualname."""
+    assert "<lambda>" in _compute_function_id(lambda x: x + 1)
+
+
+def test_compute_function_id_for_method() -> None:
+    """A bound method's ID includes the owning class and the method name."""
+    bound = SampleClass().instance_method
+    func_id = _compute_function_id(bound)
+    for fragment in ("instance_method", "SampleClass"):
+        assert fragment in func_id
+
+
+def test_compute_function_id_for_static_method() -> None:
+    """A static method looked up on its class keeps its own name in the ID."""
+    assert "static_method" in _compute_function_id(SampleClass.static_method)
+
+
+def test_compute_function_id_for_class_method() -> None:
+    """A class method bound to its class keeps its own name in the ID."""
+    assert "class_method" in _compute_function_id(SampleClass.class_method)
+
+
+def test_compute_args_digest_basic() -> None:
+    """Equal arguments hash equal; changing a positional value does not."""
+    reference = _compute_args_digest((1, 2), {"key": "value"})
+    # Recomputing for identical arguments must be deterministic.
+    assert reference == _compute_args_digest((1, 2), {"key": "value"})
+    # Changing any positional value must change the digest.
+    assert reference != _compute_args_digest((1, 3), {"key": "value"})
+
+
+def test_compute_args_digest_empty() -> None:
+    """With no arguments, the digest still has the truncated 16-char length."""
+    assert len(_compute_args_digest((), {})) == 16
+
+
+def test_compute_args_digest_complex_types() -> None:
+    """Nested containers produce a deterministic digest."""
+    args = ({"nested": {"key": [1, 2, 3]}}, [1, 2, {"inner": "value"}])
+    kwargs = {"data": {"x": 1, "y": 2}}
+    first, second = (_compute_args_digest(args, kwargs) for _ in range(2))
+    assert first == second
+
+
+def test_compute_args_digest_order_matters() -> None:
+    """Swapping positional arguments changes the digest."""
+    forward, backward = (_compute_args_digest(a, {}) for a in ((1, 2), (2, 1)))
+    assert forward != backward
+
+
+def test_compute_args_digest_kwargs_vs_args() -> None:
+    """A value passed by keyword is distinguished from one passed by position."""
+    by_keyword = _compute_args_digest((1,), {"y": 2})
+    by_position = _compute_args_digest((1, 2), {})
+    assert by_keyword != by_position
+
+
+def test_cloudpickle_serialization() -> None:
+    """Results and exceptions survive a cloudpickle round trip."""
+    # Plain data: the round trip yields an equal object.
+    original = {"key": "value", "number": 42, "list": [1, 2, 3]}
+    assert cloudpickle.loads(cloudpickle.dumps(original)) == original
+
+    # Exceptions: the round trip preserves both type and message.
+    def raise_test_error() -> None:
+        error_message = "test error"
+        raise ValueError(error_message)
+
+    try:
+        raise_test_error()
+    except ValueError as e:
+        restored = cloudpickle.loads(cloudpickle.dumps(e))
+        assert str(restored) == "test error"
+        assert isinstance(restored, ValueError)
+
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/context/ActionStatePersister.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/context/ActionStatePersister.java
new file mode 100644
index 00000000..529098a8
--- /dev/null
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/context/ActionStatePersister.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.agents.runtime.context;
+
+import org.apache.flink.agents.api.Event;
+import org.apache.flink.agents.plan.actions.Action;
+import org.apache.flink.agents.runtime.actionstate.ActionState;
+
+/**
+ * Storage hook through which {@link RunnerContextImpl.DurableExecutionContext} saves {@link
+ * ActionState} snapshots.
+ *
+ * <p>Routing persistence through this interface keeps the durable execution context independent
+ * of any concrete storage implementation.
+ */
+public interface ActionStatePersister {
+
+    /**
+     * Saves {@code actionState} for the given action invocation.
+     *
+     * @param key the key for the action
+     * @param sequenceNumber the sequence number for ordering
+     * @param action the action being executed
+     * @param event the event that triggered the action
+     * @param actionState the ActionState to persist
+     */
+    void persist(
+            Object key, long sequenceNumber, Action action, Event event, ActionState actionState);
+}
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java
index 2b8d8dc5..8a946f2d 100644
--- 
a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java
@@ -28,13 +28,20 @@ import 
org.apache.flink.agents.api.memory.LongTermMemoryOptions;
 import org.apache.flink.agents.api.resource.Resource;
 import org.apache.flink.agents.api.resource.ResourceType;
 import org.apache.flink.agents.plan.AgentPlan;
+import org.apache.flink.agents.plan.actions.Action;
 import org.apache.flink.agents.plan.utils.JsonUtils;
+import org.apache.flink.agents.runtime.actionstate.ActionState;
+import org.apache.flink.agents.runtime.actionstate.CallResult;
 import org.apache.flink.agents.runtime.memory.CachedMemoryStore;
 import org.apache.flink.agents.runtime.memory.InteranlBaseLongTermMemory;
 import org.apache.flink.agents.runtime.memory.MemoryObjectImpl;
 import org.apache.flink.agents.runtime.memory.VectorStoreLongTermMemory;
 import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl;
 import org.apache.flink.util.Preconditions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nullable;
 
 import java.util.ArrayList;
 import java.util.LinkedList;
@@ -77,6 +84,8 @@ public class RunnerContextImpl implements RunnerContext {
         }
     }
 
+    private static final Logger LOG = 
LoggerFactory.getLogger(RunnerContextImpl.class);
+
     protected final List<Event> pendingEvents = new ArrayList<>();
     protected final FlinkAgentsMetricGroupImpl agentMetricGroup;
     protected final Runnable mailboxThreadChecker;
@@ -86,6 +95,9 @@ public class RunnerContextImpl implements RunnerContext {
     protected String actionName;
     protected InteranlBaseLongTermMemory ltm;
 
+    /** Context for fine-grained durable execution, may be null if not 
enabled. */
+    @Nullable protected DurableExecutionContext durableExecutionContext;
+
     public RunnerContextImpl(
             FlinkAgentsMetricGroupImpl agentMetricGroup,
             Runnable mailboxThreadChecker,
@@ -247,4 +259,180 @@ public class RunnerContextImpl implements RunnerContext {
     public void clearSensoryMemory() throws Exception {
         memoryContext.getSensoryMemStore().clear();
     }
+
+    /**
+     * Installs the context that durable call matching and recording delegate to.
+     *
+     * @param durableExecutionContext the context to use, or {@code null} to disable durable
+     *     execution
+     */
+    public void setDurableExecutionContext(
+            @Nullable DurableExecutionContext durableExecutionContext) {
+        this.durableExecutionContext = durableExecutionContext;
+    }
+
+    /** Returns the current durable execution context, or {@code null} if none is set. */
+    @Nullable
+    public DurableExecutionContext getDurableExecutionContext() {
+        return this.durableExecutionContext;
+    }
+
+    /** Drops the durable execution context; later durable calls neither replay nor record. */
+    public void clearDurableExecutionContext() {
+        durableExecutionContext = null;
+    }
+
+    /**
+     * Looks up the recorded outcome of the next durable call, delegating to the {@link
+     * DurableExecutionContext} when one is set.
+     *
+     * <p>{@code mailboxThreadChecker} is run before anything else. On a mismatch the context
+     * discards the mismatching record and all later ones.
+     *
+     * @param functionId the function identifier
+     * @param argsDigest the digest of serialized arguments
+     * @return {@code [isHit (boolean), resultPayload (byte[]), exceptionPayload (byte[])]} on a
+     *     hit; {@code null} on a miss or when durable execution is not enabled
+     */
+    public Object[] matchNextOrClearSubsequentCallResult(String functionId, String argsDigest) {
+        mailboxThreadChecker.run();
+        DurableExecutionContext context = durableExecutionContext;
+        return context == null
+                ? null
+                : context.matchNextOrClearSubsequentCallResult(functionId, argsDigest);
+    }
+
+    /**
+     * Records a finished durable call via the {@link DurableExecutionContext}, if one is set;
+     * otherwise this is a no-op (after running {@code mailboxThreadChecker}).
+     *
+     * @param functionId the function identifier
+     * @param argsDigest the digest of serialized arguments
+     * @param resultPayload the serialized result (null if exception)
+     * @param exceptionPayload the serialized exception (null if success)
+     */
+    public void recordCallCompletion(
+            String functionId, String argsDigest, byte[] resultPayload, byte[] exceptionPayload) {
+        mailboxThreadChecker.run();
+        if (durableExecutionContext == null) {
+            return;
+        }
+        durableExecutionContext.recordCallCompletion(
+                functionId, argsDigest, resultPayload, exceptionPayload);
+    }
+
+    /**
+     * Per-action state backing fine-grained durable execution ({@code durable_execute} / {@code
+     * durable_execute_async}).
+     *
+     * <p>Every completed call is appended to the {@link ActionState} as a {@link CallResult}, and
+     * the state is persisted right away. Call results already present in the {@link ActionState}
+     * when this context is created are replayed in call order. The first call that does not match
+     * its recorded counterpart causes that record and all later ones to be discarded.
+     */
+    public static class DurableExecutionContext {
+        private final Object key;
+        private final long sequenceNumber;
+        private final Action action;
+        private final Event event;
+        private final ActionState actionState;
+        private final ActionStatePersister persister;
+
+        /** Index of the next durable call; advanced on every replay hit and every record. */
+        private int currentCallIndex = 0;
+
+        /** Recorded call results eligible for replay; truncated when a mismatch is detected. */
+        private List<CallResult> recoveryCallResults;
+
+        /**
+         * Creates a context for one action invocation.
+         *
+         * @param key the key for the action
+         * @param sequenceNumber the sequence number for ordering
+         * @param action the action being executed
+         * @param event the event that triggered the action
+         * @param actionState the state receiving new call results; its existing call results are
+         *     copied for replay
+         * @param persister used to persist {@code actionState} after each recorded call
+         */
+        public DurableExecutionContext(
+                Object key,
+                long sequenceNumber,
+                Action action,
+                Event event,
+                ActionState actionState,
+                ActionStatePersister persister) {
+            this.key = key;
+            this.sequenceNumber = sequenceNumber;
+            this.action = action;
+            this.event = event;
+            this.actionState = actionState;
+            this.persister = persister;
+            List<CallResult> recorded = actionState.getCallResults();
+            this.recoveryCallResults =
+                    recorded == null ? new ArrayList<>() : new ArrayList<>(recorded);
+        }
+
+        public int getCurrentCallIndex() {
+            return currentCallIndex;
+        }
+
+        public ActionState getActionState() {
+            return actionState;
+        }
+
+        /**
+         * Replays the recorded result at the current index if it belongs to this call.
+         *
+         * <p>On a match the index advances. On a mismatch, that record and every later one are
+         * removed from both the replay list and the {@link ActionState}.
+         *
+         * @param functionId the function identifier
+         * @param argsDigest the digest of serialized arguments
+         * @return {@code [isHit, resultPayload, exceptionPayload]} on a hit, or {@code null} on a
+         *     miss
+         */
+        public Object[] matchNextOrClearSubsequentCallResult(String functionId, String argsDigest) {
+            if (currentCallIndex >= recoveryCallResults.size()) {
+                return null;
+            }
+
+            CallResult recorded = recoveryCallResults.get(currentCallIndex);
+            if (!recorded.matches(functionId, argsDigest)) {
+                LOG.warn(
+                        "Non-deterministic call detected at index {}: expected functionId={}, "
+                                + "argsDigest={}, but got functionId={}, argsDigest={}. "
+                                + "Clearing subsequent results.",
+                        currentCallIndex,
+                        recorded.getFunctionId(),
+                        recorded.getArgsDigest(),
+                        functionId,
+                        argsDigest);
+                clearCallResultsFromCurrentIndex();
+                return null;
+            }
+
+            LOG.debug(
+                    "CallResult hit at index {}: functionId={}, argsDigest={}",
+                    currentCallIndex,
+                    functionId,
+                    argsDigest);
+            currentCallIndex++;
+            return new Object[] {
+                true, recorded.getResultPayload(), recorded.getExceptionPayload()
+            };
+        }
+
+        /**
+         * Appends the call's outcome to the {@link ActionState} and persists it immediately.
+         *
+         * @param functionId the function identifier
+         * @param argsDigest the digest of serialized arguments
+         * @param resultPayload the serialized result (null if exception)
+         * @param exceptionPayload the serialized exception (null if success)
+         */
+        public void recordCallCompletion(
+                String functionId,
+                String argsDigest,
+                byte[] resultPayload,
+                byte[] exceptionPayload) {
+            actionState.addCallResult(
+                    new CallResult(functionId, argsDigest, resultPayload, exceptionPayload));
+            persister.persist(key, sequenceNumber, action, event, actionState);
+
+            LOG.debug(
+                    "Recorded and persisted CallResult at index {}: functionId={}, argsDigest={}",
+                    currentCallIndex,
+                    functionId,
+                    argsDigest);
+            currentCallIndex++;
+        }
+
+        /** Drops the record at the current index and everything after it. */
+        private void clearCallResultsFromCurrentIndex() {
+            actionState.clearCallResultsFrom(currentCallIndex);
+            int keep = Math.min(currentCallIndex, recoveryCallResults.size());
+            recoveryCallResults = recoveryCallResults.subList(0, keep);
+        }
+    }
 }
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
index 1b569ac8..60b7e329 100644
--- 
a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
@@ -38,6 +38,7 @@ import 
org.apache.flink.agents.plan.resourceprovider.PythonResourceProvider;
 import org.apache.flink.agents.runtime.actionstate.ActionState;
 import org.apache.flink.agents.runtime.actionstate.ActionStateStore;
 import org.apache.flink.agents.runtime.actionstate.KafkaActionStateStore;
+import org.apache.flink.agents.runtime.context.ActionStatePersister;
 import org.apache.flink.agents.runtime.context.RunnerContextImpl;
 import org.apache.flink.agents.runtime.env.EmbeddedPythonEnvironment;
 import org.apache.flink.agents.runtime.env.PythonEnvironmentManager;
@@ -110,7 +111,7 @@ import static 
org.apache.flink.util.Preconditions.checkState;
  * and the resulting output event is collected for further processing.
  */
 public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT>
-        implements OneInputStreamOperator<IN, OUT>, BoundedOneInput {
+        implements OneInputStreamOperator<IN, OUT>, BoundedOneInput, 
ActionStatePersister {
 
     private static final long serialVersionUID = 1L;
 
@@ -190,6 +191,11 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
     private final transient Map<ActionTask, RunnerContextImpl.MemoryContext>
             actionTaskMemoryContexts;
 
+    // This in memory map keeps track of the durable execution context for 
async action tasks
+    // that have not been finished, allowing recovery of currentCallIndex 
across invocations
+    private final transient Map<ActionTask, 
RunnerContextImpl.DurableExecutionContext>
+            actionTaskDurableContexts;
+
     // Each job can only have one identifier and this identifier must be 
consistent across restarts.
     // We cannot use job id as the identifier here because user may change job 
id by
     // creating a savepoint, stop the job and then resume from savepoint.
@@ -212,6 +218,7 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
         this.actionStateStore = actionStateStore;
         this.checkpointIdToSeqNums = new HashMap<>();
         this.actionTaskMemoryContexts = new HashMap<>();
+        this.actionTaskDurableContexts = new HashMap<>();
         OperatorUtils.setChainStrategy(this, ChainingStrategy.ALWAYS);
     }
 
@@ -446,7 +453,10 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
         Optional<ActionTask> generatedActionTaskOpt = Optional.empty();
         ActionState actionState =
                 maybeGetActionState(key, sequenceNumber, actionTask.action, 
actionTask.event);
-        if (actionState != null) {
+
+        // Check if action is already completed
+        if (actionState != null && actionState.isCompleted()) {
+            // Action has completed, skip execution and replay memory/events
             isFinished = true;
             outputEvents = actionState.getOutputEvents();
             for (MemoryUpdate memoryUpdate : 
actionState.getShortTermMemoryUpdates()) {
@@ -463,16 +473,27 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
                         .set(memoryUpdate.getPath(), memoryUpdate.getValue());
             }
         } else {
-            maybeInitActionState(key, sequenceNumber, actionTask.action, 
actionTask.event);
+            // Initialize ActionState if not exists, or use existing one for 
recovery
+            if (actionState == null) {
+                maybeInitActionState(key, sequenceNumber, actionTask.action, 
actionTask.event);
+                actionState =
+                        maybeGetActionState(
+                                key, sequenceNumber, actionTask.action, 
actionTask.event);
+            }
+
+            // Set up durable execution context for fine-grained recovery
+            setupDurableExecutionContext(actionTask, actionState);
+
             ActionTask.ActionTaskResult actionTaskResult =
                     actionTask.invoke(
                             getRuntimeContext().getUserCodeClassLoader(),
                             this.pythonActionExecutor);
 
-            // We remove the RunnerContext of the action task from the map 
after it is finished. The
-            // RunnerContext will be added later if the action task has a 
generated action task,
-            // meaning it is not finished.
+            // We remove the contexts from the map after the task is 
processed. They will be added
+            // back later if the action task has a generated action task, 
meaning it is not
+            // finished.
             actionTaskMemoryContexts.remove(actionTask);
+            actionTaskDurableContexts.remove(actionTask);
             maybePersistTaskResult(
                     key,
                     sequenceNumber,
@@ -505,10 +526,15 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
             // execution.
             ActionTask generatedActionTask = generatedActionTaskOpt.get();
 
-            // If the action task is not finished, we keep the runner context 
in the memory for the
+            // If the action task is not finished, we keep the contexts in 
memory for the
             // next generated ActionTask to be invoked.
             actionTaskMemoryContexts.put(
                     generatedActionTask, 
actionTask.getRunnerContext().getMemoryContext());
+            RunnerContextImpl.DurableExecutionContext durableContext =
+                    actionTask.getRunnerContext().getDurableExecutionContext();
+            if (durableContext != null) {
+                actionTaskDurableContexts.put(generatedActionTask, 
durableContext);
+            }
 
             actionTasksKState.add(generatedActionTask);
         }
@@ -916,7 +942,68 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
         for (Event outputEvent : actionTaskResult.getOutputEvents()) {
             actionState.addEvent(outputEvent);
         }
+
+        // Mark the action as completed and clear call records
+        // This indicates that recovery should skip the entire action
+        actionState.markCompleted();
+
         actionStateStore.put(key, sequenceNum, action, event, actionState);
+
+        // Clear durable execution context
+        context.clearDurableExecutionContext();
+    }
+
+    /**
+     * Attaches a durable execution context to the runner context of {@code actionTask}.
+     *
+     * <p>The {@link RunnerContextImpl.DurableExecutionContext} lets execute/execute_async calls
+     * reuse call results recorded before a failure and persist a new CallResult after each call.
+     * A continuation of an unfinished async action reuses the context stored for it, so the call
+     * index carries over; otherwise a new context is created. No-op without an action state store.
+     */
+    private void setupDurableExecutionContext(ActionTask actionTask, ActionState actionState) {
+        if (actionStateStore == null) {
+            return;
+        }
+
+        // Only non-null contexts are ever stored, so a null lookup means a first invocation.
+        RunnerContextImpl.DurableExecutionContext durableContext =
+                actionTaskDurableContexts.get(actionTask);
+        if (durableContext == null) {
+            final long sequenceNumber;
+            try {
+                sequenceNumber = sequenceNumberKState.value();
+            } catch (Exception e) {
+                throw new RuntimeException("Failed to get sequence number from state", e);
+            }
+            durableContext =
+                    new RunnerContextImpl.DurableExecutionContext(
+                            actionTask.getKey(),
+                            sequenceNumber,
+                            actionTask.action,
+                            actionTask.event,
+                            actionState,
+                            this);
+        }
+
+        actionTask.getRunnerContext().setDurableExecutionContext(durableContext);
+    }
+
+    /**
+     * Writes {@code actionState} to the action state store for the given key, sequence number,
+     * action and event. Durable execution contexts call this each time they record a call result.
+     *
+     * @throws RuntimeException wrapping any exception thrown by the action state store
+     */
+    @Override
+    public void persist(
+            Object key, long sequenceNumber, Action action, Event event, ActionState actionState) {
+        try {
+            actionStateStore.put(key, sequenceNumber, action, event, actionState);
+        } catch (Exception e) {
+            LOG.error("Failed to persist ActionState", e);
+            throw new RuntimeException("Failed to persist ActionState", e);
+        }
     }
 
     private void maybePruneState(Object key, long sequenceNum) throws 
Exception {
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java
index 7df56e5e..ddabf503 100644
--- 
a/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java
@@ -15,6 +15,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+
 package org.apache.flink.agents.runtime.python.context;
 
 import org.apache.flink.agents.api.Event;
diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/context/DurableExecutionContextTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/context/DurableExecutionContextTest.java
new file mode 100644
index 00000000..f2701e50
--- /dev/null
+++ b/runtime/src/test/java/org/apache/flink/agents/runtime/context/DurableExecutionContextTest.java
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.agents.runtime.context;
+
+import org.apache.flink.agents.api.Event;
+import org.apache.flink.agents.plan.actions.Action;
+import org.apache.flink.agents.runtime.actionstate.ActionState;
+import org.apache.flink.agents.runtime.actionstate.CallResult;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import java.nio.charset.StandardCharsets;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.junit.jupiter.api.Assertions.*;
+import static org.mockito.Mockito.mock;
+
+/** Unit tests for {@link RunnerContextImpl.DurableExecutionContext}. */
+class DurableExecutionContextTest {
+
+    private ActionState actionState;
+    private AtomicInteger persistCallCount;
+    private ActionState lastPersistedState;
+    private Object testKey;
+    private long testSequenceNumber;
+    private Action mockAction;
+    private Event mockEvent;
+
+    @BeforeEach
+    void setUp() {
+        actionState = new ActionState(null);
+        persistCallCount = new AtomicInteger(0);
+        lastPersistedState = null;
+        testKey = "testKey";
+        testSequenceNumber = 1L;
+        mockAction = mock(Action.class);
+        mockEvent = mock(Event.class);
+    }
+
+    private RunnerContextImpl.DurableExecutionContext createContext() {
+        // Stub persister: records the call count and last state instead of writing to a store.
+        ActionStatePersister persister =
+                (key, seqNum, action, event, state) -> {
+                    persistCallCount.incrementAndGet();
+                    lastPersistedState = state;
+                };
+        return new RunnerContextImpl.DurableExecutionContext(
+                testKey, testSequenceNumber, mockAction, mockEvent, actionState, persister);
+    }
+
+    @Test
+    void testInitialization() {
+        actionState.addCallResult(
+                new CallResult("funcA", "digestA", "resultA".getBytes(StandardCharsets.UTF_8)));
+        actionState.addCallResult(
+                new CallResult("funcB", "digestB", "resultB".getBytes(StandardCharsets.UTF_8)));
+
+        RunnerContextImpl.DurableExecutionContext context = createContext();
+
+        assertEquals(0, context.getCurrentCallIndex());
+        assertSame(actionState, context.getActionState());
+    }
+
+    @Test
+    void testMatchNextOrClearSubsequentCallResultHit() {
+        byte[] expectedResult = "cached_result".getBytes(StandardCharsets.UTF_8);
+        actionState.addCallResult(new CallResult("funcA", "digestA", expectedResult));
+
+        RunnerContextImpl.DurableExecutionContext context = createContext();
+
+        Object[] result = context.matchNextOrClearSubsequentCallResult("funcA", "digestA");
+
+        assertNotNull(result);
+        assertEquals(3, result.length);
+        assertTrue((Boolean) result[0]); // isHit
+        assertArrayEquals(expectedResult, (byte[]) result[1]); // resultPayload
+        assertNull(result[2]); // exceptionPayload
+        assertEquals(1, context.getCurrentCallIndex());
+    }
+
+    @Test
+    void testMatchNextOrClearSubsequentCallResultMiss() {
+        RunnerContextImpl.DurableExecutionContext context = createContext();
+
+        Object[] result = context.matchNextOrClearSubsequentCallResult("funcA", "digestA");
+
+        assertNull(result);
+        assertEquals(0, context.getCurrentCallIndex());
+    }
+
+    @Test
+    void testMatchNextOrClearSubsequentCallResultMismatch() {
+        actionState.addCallResult(
+                new CallResult("funcA", "digestA", "result".getBytes(StandardCharsets.UTF_8)));
+        actionState.addCallResult(
+                new CallResult("funcB", "digestB", "result".getBytes(StandardCharsets.UTF_8)));
+
+        RunnerContextImpl.DurableExecutionContext context = createContext();
+
+        // Call with mismatched functionId - should clear subsequent results and return null
+        Object[] result = context.matchNextOrClearSubsequentCallResult("funcX", "digestX");
+
+        assertNull(result);
+        // ActionState should have results cleared from index 0
+        assertEquals(0, actionState.getCallResultCount());
+        // Persist is not called here - it will be called in recordCallCompletion
+        assertEquals(0, persistCallCount.get());
+    }
+
+    @Test
+    void testRecordCallCompletionSuccess() {
+        RunnerContextImpl.DurableExecutionContext context = createContext();
+
+        byte[] resultPayload = "success_result".getBytes(StandardCharsets.UTF_8);
+        context.recordCallCompletion("funcA", "digestA", resultPayload, null);
+
+        assertEquals(1, context.getCurrentCallIndex());
+        assertEquals(1, actionState.getCallResults().size());
+        assertEquals("funcA", actionState.getCallResults().get(0).getFunctionId());
+        // Verify persister was called
+        assertEquals(1, persistCallCount.get());
+        assertSame(actionState, lastPersistedState);
+    }
+
+    @Test
+    void testRecordCallCompletionException() {
+        RunnerContextImpl.DurableExecutionContext context = createContext();
+
+        byte[] exceptionPayload = "exception_data".getBytes(StandardCharsets.UTF_8);
+        context.recordCallCompletion("funcA", "digestA", null, exceptionPayload);
+
+        assertEquals(1, context.getCurrentCallIndex());
+        CallResult recorded = actionState.getCallResults().get(0);
+        assertNull(recorded.getResultPayload());
+        assertArrayEquals(exceptionPayload, recorded.getExceptionPayload());
+        assertEquals(1, persistCallCount.get());
+    }
+
+    @Test
+    void testMultipleCallResultRecovery() {
+        byte[] result1 = "result1".getBytes(StandardCharsets.UTF_8);
+        byte[] result2 = "result2".getBytes(StandardCharsets.UTF_8);
+        actionState.addCallResult(new CallResult("func1", "digest1", result1));
+        actionState.addCallResult(new CallResult("func2", "digest2", result2));
+
+        RunnerContextImpl.DurableExecutionContext context = createContext();
+
+        // First call should hit
+        Object[] hit1 = context.matchNextOrClearSubsequentCallResult("func1", "digest1");
+        assertNotNull(hit1);
+        assertTrue((Boolean) hit1[0]);
+        assertArrayEquals(result1, (byte[]) hit1[1]);
+
+        // Second call should hit
+        Object[] hit2 = context.matchNextOrClearSubsequentCallResult("func2", "digest2");
+        assertNotNull(hit2);
+        assertTrue((Boolean) hit2[0]);
+        assertArrayEquals(result2, (byte[]) hit2[1]);
+
+        // Third call should miss (no more results)
+        Object[] miss = context.matchNextOrClearSubsequentCallResult("func3", "digest3");
+        assertNull(miss);
+    }
+
+    @Test
+    void testRecoveryWithExceptionPayload() {
+        byte[] exceptionPayload = "exception_data".getBytes(StandardCharsets.UTF_8);
+        actionState.addCallResult(CallResult.ofException("funcA", "digestA", exceptionPayload));
+
+        RunnerContextImpl.DurableExecutionContext context = createContext();
+
+        Object[] result = context.matchNextOrClearSubsequentCallResult("funcA", "digestA");
+
+        assertNotNull(result);
+        assertTrue((Boolean) result[0]); // isHit
+        assertNull(result[1]); // resultPayload should be null
+        assertArrayEquals(exceptionPayload, (byte[]) result[2]); // exceptionPayload
+    }
+
+    @Test
+    void testMultiplePersistCalls() {
+        RunnerContextImpl.DurableExecutionContext context = createContext();
+
+        // Record multiple completions
+        context.recordCallCompletion(
+                "func1", "digest1", "result1".getBytes(StandardCharsets.UTF_8), null);
+        context.recordCallCompletion(
+                "func2", "digest2", "result2".getBytes(StandardCharsets.UTF_8), null);
+        context.recordCallCompletion(
+                "func3", "digest3", "result3".getBytes(StandardCharsets.UTF_8), null);
+
+        // Each call should trigger persistence
+        assertEquals(3, persistCallCount.get());
+        assertEquals(3, actionState.getCallResults().size());
+    }
+}

Reply via email to