This is an automated email from the ASF dual-hosted git repository.

sxnan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git

commit 15b1976087762f9c810c1ef37c4500f2a082d1f4
Author: sxnan <[email protected]>
AuthorDate: Thu Jan 8 15:45:10 2026 +0800

    [runtime] Introduce CallRecord to ActionState
---
 .../agents/runtime/actionstate/ActionState.java    | 126 ++++++++-
 .../agents/runtime/actionstate/CallResult.java     | 176 ++++++++++++
 .../runtime/actionstate/ActionStateSerdeTest.java  | 117 +++++++-
 .../runtime/actionstate/ActionStateTest.java       | 310 +++++++++++++++++++++
 .../agents/runtime/actionstate/CallResultTest.java | 157 +++++++++++
 .../actionstate/KafkaActionStateStoreTest.java     |   7 +-
 6 files changed, 884 insertions(+), 9 deletions(-)

diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionState.java b/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionState.java
index 34eefb35..031928ad 100644
--- a/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionState.java
+++ b/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionState.java
@@ -17,6 +17,7 @@
  */
 package org.apache.flink.agents.runtime.actionstate;
 
+import com.fasterxml.jackson.annotation.JsonIgnore;
 import org.apache.flink.agents.api.Event;
 import org.apache.flink.agents.api.context.MemoryUpdate;
 
@@ -30,19 +31,32 @@ public class ActionState {
     private final List<MemoryUpdate> shortTermMemoryUpdates;
     private final List<Event> outputEvents;
 
-    /** Constructs a new TaskActionState instance. */
-    public ActionState(final Event taskEvent) {
-        this.taskEvent = taskEvent;
+    /**
+     * Records of completed durable_execute/durable_execute_async calls for fine-grained recovery.
+     * Entries are kept in the order they were added and are cleared by {@link #markCompleted()}.
+     */
+    private final List<CallResult> callResults;
+
+    /** Indicates whether the action has completed execution; set by {@link #markCompleted()}. */
+    private boolean completed;
+
+    /**
+     * Default constructor for Jackson deserialization. Starts with a null task event, empty
+     * mutable lists, and {@code completed} set to false.
+     */
+    private ActionState() {
+        this.taskEvent = null;
         this.sensoryMemoryUpdates = new ArrayList<>();
         this.shortTermMemoryUpdates = new ArrayList<>();
         this.outputEvents = new ArrayList<>();
+        this.callResults = new ArrayList<>();
+        this.completed = false;
     }
 
-    public ActionState() {
-        this.taskEvent = null;
+    /**
+     * Constructs a new, not-yet-completed ActionState for the given task event, with empty memory
+     * update, output event and call result lists.
+     *
+     * @param taskEvent the task event stored in this state
+     */
+    public ActionState(final Event taskEvent) {
+        this.taskEvent = taskEvent;
         this.sensoryMemoryUpdates = new ArrayList<>();
         this.shortTermMemoryUpdates = new ArrayList<>();
         this.outputEvents = new ArrayList<>();
+        this.callResults = new ArrayList<>();
+        this.completed = false;
     }
 
     /** Constructor for deserialization purposes. */
@@ -50,13 +64,17 @@ public class ActionState {
             Event taskEvent,
             List<MemoryUpdate> sensoryMemoryUpdates,
             List<MemoryUpdate> shortTermMemoryUpdates,
-            List<Event> outputEvents) {
+            List<Event> outputEvents,
+            List<CallResult> callResults,
+            boolean completed) {
         this.taskEvent = taskEvent;
         this.sensoryMemoryUpdates =
                 sensoryMemoryUpdates != null ? sensoryMemoryUpdates : new 
ArrayList<>();
         this.shortTermMemoryUpdates =
                 shortTermMemoryUpdates != null ? shortTermMemoryUpdates : new 
ArrayList<>();
         this.outputEvents = outputEvents != null ? outputEvents : new 
ArrayList<>();
+        this.callResults = callResults != null ? callResults : new 
ArrayList<>();
+        this.completed = completed;
     }
 
     /** Getters for the fields */
@@ -90,6 +108,77 @@ public class ActionState {
         outputEvents.add(event);
     }
 
+    /**
+     * Returns the recorded call results used for fine-grained durable execution, in the order they
+     * were added.
+     *
+     * @return the live list backing this state (never null)
+     */
+    public List<CallResult> getCallResults() {
+        return this.callResults;
+    }
+
+    /**
+     * Appends the result of a finished durable_execute/durable_execute_async call.
+     *
+     * @param callResult the call result to append after any already recorded
+     */
+    public void addCallResult(CallResult callResult) {
+        this.callResults.add(callResult);
+    }
+
+    /**
+     * Looks up a recorded call result by position.
+     *
+     * @param index zero-based position in the recorded call order
+     * @return the call result at {@code index}, or null when {@code index} is negative or not less
+     *     than the number of recorded results
+     */
+    public CallResult getCallResult(int index) {
+        final boolean outOfRange = index < 0 || index >= callResults.size();
+        return outOfRange ? null : callResults.get(index);
+    }
+
+    /**
+     * Returns how many call results are currently recorded. Annotated with {@code @JsonIgnore} so
+     * it is not serialized as a separate JSON property.
+     *
+     * @return the size of the call result list
+     */
+    @JsonIgnore
+    public int getCallResultCount() {
+        final int count = this.callResults.size();
+        return count;
+    }
+
+    /**
+     * Removes every recorded call result. Intended for use once the action completes, to reduce
+     * storage overhead.
+     */
+    public void clearCallResults() {
+        this.callResults.clear();
+    }
+
+    /**
+     * Truncates the recorded call results so that only the entries before {@code fromIndex}
+     * remain. Used when recovery detects a non-deterministic call order.
+     *
+     * <p>Does nothing when {@code fromIndex} is negative or not less than the current count.
+     *
+     * @param fromIndex the first index to remove (inclusive)
+     */
+    public void clearCallResultsFrom(int fromIndex) {
+        final int size = callResults.size();
+        if (fromIndex < 0 || fromIndex >= size) {
+            return;
+        }
+        callResults.subList(fromIndex, size).clear();
+    }
+
+    /**
+     * Reports whether the action has finished execution.
+     *
+     * @return true after {@link #markCompleted()} has been called, or when this state was
+     *     constructed with {@code completed} set to true
+     */
+    public boolean isCompleted() {
+        return this.completed;
+    }
+
+    /**
+     * Flags the action as finished and discards its recorded call results. Call this when the
+     * action finishes execution, so that recovery can skip the entire action.
+     */
+    public void markCompleted() {
+        callResults.clear();
+        completed = true;
+    }
+
     @Override
     public int hashCode() {
         int result = taskEvent != null ? taskEvent.hashCode() : 0;
@@ -102,12 +191,31 @@ public class ActionState {
                                 ? 0
                                 : shortTermMemoryUpdates.hashCode());
         result = 31 * result + (outputEvents.isEmpty() ? 0 : 
outputEvents.hashCode());
+        result = 31 * result + (callResults.isEmpty() ? 0 : 
callResults.hashCode());
+        result = 31 * result + (completed ? 1 : 0);
         return result;
     }
 
+    /**
+     * Two states are equal when they have the same runtime class, the same completion flag, and
+     * equal task event, memory updates, output events and call results.
+     */
+    @Override
+    public boolean equals(Object o) {
+        if (o == this) {
+            return true;
+        }
+        if (o == null || o.getClass() != getClass()) {
+            return false;
+        }
+        final ActionState rhs = (ActionState) o;
+        if (completed != rhs.completed) {
+            return false;
+        }
+        return java.util.Objects.equals(taskEvent, rhs.taskEvent)
+                && java.util.Objects.equals(sensoryMemoryUpdates, rhs.sensoryMemoryUpdates)
+                && java.util.Objects.equals(shortTermMemoryUpdates, rhs.shortTermMemoryUpdates)
+                && java.util.Objects.equals(outputEvents, rhs.outputEvents)
+                && java.util.Objects.equals(callResults, rhs.callResults);
+    }
+
     @Override
     public String toString() {
-        return "TaskActionState{"
+        return "ActionState{"
                 + "taskEvent="
                 + taskEvent
                 + ", sensoryMemoryUpdates="
@@ -116,6 +224,10 @@ public class ActionState {
                 + shortTermMemoryUpdates
                 + ", outputEvents="
                 + outputEvents
+                + ", callResults="
+                + callResults
+                + ", completed="
+                + completed
                 + '}';
     }
 }
diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/CallResult.java b/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/CallResult.java
new file mode 100644
index 00000000..cb9c5338
--- /dev/null
+++ b/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/CallResult.java
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.runtime.actionstate;
+
+import com.fasterxml.jackson.annotation.JsonIgnore;
+
+import java.util.Arrays;
+import java.util.Objects;
+
+/**
+ * Represents a result of a function call execution for fine-grained durable execution.
+ *
+ * <p>This class stores the execution result of a single {@code durable_execute} or {@code
+ * durable_execute_async} call, enabling recovery without re-execution when the same call is
+ * encountered during job recovery.
+ *
+ * <p>During recovery, the success or failure of the original call is determined by checking
+ * whether {@code exceptionPayload} is null.
+ *
+ * <p>The byte-array payloads are copied when an instance is constructed and again whenever they
+ * are read through a getter. Code outside this class therefore cannot mutate a recorded payload.
+ * This matters because {@link #equals(Object)} and {@link #hashCode()} compare payload contents.
+ */
+public class CallResult {
+
+    /** Function identifier: module+qualname for Python, or method signature for Java. */
+    private final String functionId;
+
+    /** Stable digest of the serialized arguments for validation during recovery. */
+    private final String argsDigest;
+
+    /** Serialized return value of the function call (null if the call threw an exception). */
+    private final byte[] resultPayload;
+
+    /** Serialized exception info if the call failed (null if the call succeeded). */
+    private final byte[] exceptionPayload;
+
+    /** Default constructor for deserialization; every field starts out null. */
+    public CallResult() {
+        this(null, null, null, null);
+    }
+
+    /**
+     * Constructs a CallResult for a successful function call.
+     *
+     * @param functionId the function identifier
+     * @param argsDigest the digest of serialized arguments
+     * @param resultPayload the serialized return value; the array is copied
+     */
+    public CallResult(String functionId, String argsDigest, byte[] resultPayload) {
+        this(functionId, argsDigest, resultPayload, null);
+    }
+
+    /**
+     * Constructs a CallResult with explicit result and exception payloads.
+     *
+     * @param functionId the function identifier
+     * @param argsDigest the digest of serialized arguments
+     * @param resultPayload the serialized return value (null if exception occurred); the array is
+     *     copied
+     * @param exceptionPayload the serialized exception (null if call succeeded); the array is
+     *     copied
+     */
+    public CallResult(
+            String functionId, String argsDigest, byte[] resultPayload, byte[] exceptionPayload) {
+        this.functionId = functionId;
+        this.argsDigest = argsDigest;
+        // Copy so that later writes to the caller's arrays cannot alter this record.
+        this.resultPayload = copyOrNull(resultPayload);
+        this.exceptionPayload = copyOrNull(exceptionPayload);
+    }
+
+    /**
+     * Creates a CallResult for a failed function call.
+     *
+     * @param functionId the function identifier
+     * @param argsDigest the digest of serialized arguments
+     * @param exceptionPayload the serialized exception; the array is copied
+     * @return a new CallResult representing a failed call
+     */
+    public static CallResult ofException(
+            String functionId, String argsDigest, byte[] exceptionPayload) {
+        return new CallResult(functionId, argsDigest, null, exceptionPayload);
+    }
+
+    public String getFunctionId() {
+        return functionId;
+    }
+
+    public String getArgsDigest() {
+        return argsDigest;
+    }
+
+    /** Returns a copy of the serialized return value, or null if none was recorded. */
+    public byte[] getResultPayload() {
+        return copyOrNull(resultPayload);
+    }
+
+    /** Returns a copy of the serialized exception, or null if none was recorded. */
+    public byte[] getExceptionPayload() {
+        return copyOrNull(exceptionPayload);
+    }
+
+    /**
+     * Checks if this call result represents a successful execution.
+     *
+     * @return true if the call succeeded (no exception), false otherwise
+     */
+    @JsonIgnore
+    public boolean isSuccess() {
+        return exceptionPayload == null;
+    }
+
+    /**
+     * Validates if this CallResult matches the given function identifier and arguments digest.
+     *
+     * @param functionId the function identifier to match
+     * @param argsDigest the arguments digest to match
+     * @return true if both functionId and argsDigest match, false otherwise
+     */
+    public boolean matches(String functionId, String argsDigest) {
+        return Objects.equals(this.functionId, functionId)
+                && Objects.equals(this.argsDigest, argsDigest);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        CallResult that = (CallResult) o;
+        return Objects.equals(functionId, that.functionId)
+                && Objects.equals(argsDigest, that.argsDigest)
+                && Arrays.equals(resultPayload, that.resultPayload)
+                && Arrays.equals(exceptionPayload, that.exceptionPayload);
+    }
+
+    @Override
+    public int hashCode() {
+        int result = Objects.hash(functionId, argsDigest);
+        result = 31 * result + Arrays.hashCode(resultPayload);
+        result = 31 * result + Arrays.hashCode(exceptionPayload);
+        return result;
+    }
+
+    @Override
+    public String toString() {
+        return "CallResult{"
+                + "functionId='"
+                + functionId
+                + '\''
+                + ", argsDigest='"
+                + argsDigest
+                + '\''
+                + ", resultPayload="
+                + (resultPayload != null ? resultPayload.length + " bytes" : "null")
+                + ", exceptionPayload="
+                + (exceptionPayload != null ? exceptionPayload.length + " bytes" : "null")
+                + '}';
+    }
+
+    /** Returns a copy of {@code bytes}, or null when {@code bytes} is null. */
+    private static byte[] copyOrNull(byte[] bytes) {
+        return bytes == null ? null : bytes.clone();
+    }
+}
diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateSerdeTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateSerdeTest.java
index eac53d2e..74181d0f 100644
--- a/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateSerdeTest.java
+++ b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateSerdeTest.java
@@ -23,7 +23,9 @@ import org.apache.flink.agents.api.OutputEvent;
 import org.apache.flink.agents.api.context.MemoryUpdate;
 import org.junit.jupiter.api.Test;
 
+import java.util.ArrayList;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 
 import static org.junit.jupiter.api.Assertions.*;
@@ -92,7 +94,7 @@ public class ActionStateSerdeTest {
     @Test
     public void testActionStateWithNullTaskEvent() throws Exception {
         // Create ActionState with null taskEvent
-        ActionState originalState = new ActionState();
+        ActionState originalState = new ActionState(null, null, null, null, 
null, false);
         MemoryUpdate memoryUpdate = new MemoryUpdate("test.path", "test 
value");
         originalState.addShortTermMemoryUpdate(memoryUpdate);
         originalState.addSensoryMemoryUpdate(memoryUpdate);
@@ -138,4 +140,117 @@ public class ActionStateSerdeTest {
         assertEquals("value", deserializedComplexAttr.get("nested"));
         assertEquals(42, deserializedComplexAttr.get("number"));
     }
+
+    @Test
+    public void testActionStateWithCallResults() throws Exception {
+        // A state holding one successful and one failed call result.
+        ActionState originalState = new ActionState(new InputEvent("test input"));
+        originalState.addCallResult(
+                new CallResult("module.func1", "digest1", "result1".getBytes()));
+        originalState.addCallResult(
+                CallResult.ofException("module.func2", "digest2", "exception".getBytes()));
+
+        // Round-trip the state through the Kafka serializer/deserializer.
+        ActionStateKafkaSeder seder = new ActionStateKafkaSeder();
+        ActionState restored =
+                seder.deserialize("test-topic", seder.serialize("test-topic", originalState));
+
+        assertEquals(2, restored.getCallResultCount());
+
+        // The successful call keeps its result payload and carries no exception payload.
+        CallResult success = restored.getCallResult(0);
+        assertEquals("module.func1", success.getFunctionId());
+        assertEquals("digest1", success.getArgsDigest());
+        assertArrayEquals("result1".getBytes(), success.getResultPayload());
+        assertNull(success.getExceptionPayload());
+        assertTrue(success.isSuccess());
+
+        // The failed call keeps its exception payload and carries no result payload.
+        CallResult failure = restored.getCallResult(1);
+        assertEquals("module.func2", failure.getFunctionId());
+        assertEquals("digest2", failure.getArgsDigest());
+        assertNull(failure.getResultPayload());
+        assertArrayEquals("exception".getBytes(), failure.getExceptionPayload());
+        assertFalse(failure.isSuccess());
+    }
+
+    @Test
+    public void testActionStateWithCompletedFlag() throws Exception {
+        List<MemoryUpdate> sensoryUpdates = new ArrayList<>();
+        sensoryUpdates.add(new MemoryUpdate("sm.path", "value"));
+        List<MemoryUpdate> shortTermUpdates = new ArrayList<>();
+        shortTermUpdates.add(new MemoryUpdate("stm.path", "value"));
+        List<Event> outputEvents = new ArrayList<>();
+        outputEvents.add(new OutputEvent("output"));
+
+        // A completed state without call results, as markCompleted() would leave it.
+        ActionState completedState =
+                new ActionState(
+                        new InputEvent("test input"),
+                        sensoryUpdates,
+                        shortTermUpdates,
+                        outputEvents,
+                        null,
+                        true);
+
+        ActionStateKafkaSeder seder = new ActionStateKafkaSeder();
+        byte[] bytes = seder.serialize("test-topic", completedState);
+        ActionState restored = seder.deserialize("test-topic", bytes);
+
+        // The completion flag survives and there are still no call results.
+        assertTrue(restored.isCompleted());
+        assertEquals(0, restored.getCallResultCount());
+
+        // The memory updates and output events survive as well.
+        assertEquals(1, restored.getSensoryMemoryUpdates().size());
+        assertEquals(1, restored.getShortTermMemoryUpdates().size());
+        assertEquals(1, restored.getOutputEvents().size());
+    }
+
+    @Test
+    public void testActionStateInProgressWithCallResults() throws Exception {
+        // A partially executed action: two calls recorded, action not yet completed.
+        List<CallResult> recorded = new ArrayList<>();
+        recorded.add(new CallResult("func1", "hash1", "result1".getBytes()));
+        recorded.add(new CallResult("func2", "hash2", "result2".getBytes()));
+        ActionState inProgress =
+                new ActionState(new InputEvent("test input"), null, null, null, recorded, false);
+
+        ActionStateKafkaSeder seder = new ActionStateKafkaSeder();
+        ActionState restored =
+                seder.deserialize("test-topic", seder.serialize("test-topic", inProgress));
+
+        assertFalse(restored.isCompleted());
+        assertEquals(2, restored.getCallResultCount());
+        assertTrue(restored.getCallResult(0).matches("func1", "hash1"));
+        assertTrue(restored.getCallResult(1).matches("func2", "hash2"));
+    }
+
+    @Test
+    public void testCallResultWithNullPayloads() throws Exception {
+        // A call result that carries neither a result payload nor an exception payload.
+        ActionState originalState = new ActionState(new InputEvent("test"));
+        originalState.addCallResult(new CallResult("func", "digest", null, null));
+
+        ActionStateKafkaSeder seder = new ActionStateKafkaSeder();
+        byte[] serialized = seder.serialize("test-topic", originalState);
+        ActionState restored = seder.deserialize("test-topic", serialized);
+
+        assertEquals(1, restored.getCallResultCount());
+        CallResult only = restored.getCallResult(0);
+        assertEquals("func", only.getFunctionId());
+        assertEquals("digest", only.getArgsDigest());
+        // Both payloads come back as null.
+        assertNull(only.getResultPayload());
+        assertNull(only.getExceptionPayload());
+    }
 }
diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateTest.java
new file mode 100644
index 00000000..aa00d119
--- /dev/null
+++ b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateTest.java
@@ -0,0 +1,310 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.runtime.actionstate;
+
+import org.apache.flink.agents.api.InputEvent;
+import org.apache.flink.agents.api.OutputEvent;
+import org.apache.flink.agents.api.context.MemoryUpdate;
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+/**
+ * Unit tests for {@link ActionState}, focusing on the fields that support fine-grained durable
+ * execution: the recorded {@link CallResult}s and the completion flag.
+ */
+public class ActionStateTest {
+
+    /** Builds a successful call result whose payload is the bytes of {@code payload}. */
+    private static CallResult success(String functionId, String argsDigest, String payload) {
+        return new CallResult(functionId, argsDigest, payload.getBytes());
+    }
+
+    /** Builds an empty, not-yet-completed state for a simple input event. */
+    private static ActionState newState() {
+        return new ActionState(new InputEvent("test"));
+    }
+
+    @Test
+    public void testConstructorWithEvent() {
+        InputEvent event = new InputEvent("test");
+        ActionState state = new ActionState(event);
+
+        assertEquals(event, state.getTaskEvent());
+        assertTrue(state.getSensoryMemoryUpdates().isEmpty());
+        assertTrue(state.getShortTermMemoryUpdates().isEmpty());
+        assertTrue(state.getOutputEvents().isEmpty());
+        assertTrue(state.getCallResults().isEmpty());
+        assertFalse(state.isCompleted());
+    }
+
+    @Test
+    public void testFullConstructorWithCallResults() {
+        InputEvent taskEvent = new InputEvent("test");
+        List<MemoryUpdate> sensory = new ArrayList<>();
+        sensory.add(new MemoryUpdate("sm.path", "value"));
+        List<MemoryUpdate> shortTerm = new ArrayList<>();
+        shortTerm.add(new MemoryUpdate("stm.path", "value"));
+        List<org.apache.flink.agents.api.Event> outputs = new ArrayList<>();
+        outputs.add(new OutputEvent("output"));
+        List<CallResult> calls = new ArrayList<>();
+        calls.add(success("func1", "digest1", "result1"));
+        calls.add(success("func2", "digest2", "result2"));
+
+        ActionState state = new ActionState(taskEvent, sensory, shortTerm, outputs, calls, true);
+
+        assertEquals(taskEvent, state.getTaskEvent());
+        assertEquals(1, state.getSensoryMemoryUpdates().size());
+        assertEquals(1, state.getShortTermMemoryUpdates().size());
+        assertEquals(1, state.getOutputEvents().size());
+        assertEquals(2, state.getCallResults().size());
+        assertTrue(state.isCompleted());
+    }
+
+    @Test
+    public void testAddCallResult() {
+        ActionState state = newState();
+        CallResult first = success("func1", "digest1", "result1");
+        CallResult second = success("func2", "digest2", "result2");
+
+        state.addCallResult(first);
+        assertEquals(1, state.getCallResultCount());
+        assertEquals(first, state.getCallResult(0));
+
+        state.addCallResult(second);
+        assertEquals(2, state.getCallResultCount());
+        assertEquals(second, state.getCallResult(1));
+    }
+
+    @Test
+    public void testGetCallResultOutOfBounds() {
+        ActionState state = newState();
+        for (int index : new int[] {-1, 0, 100}) {
+            assertNull(state.getCallResult(index));
+        }
+
+        state.addCallResult(success("func", "digest", "result"));
+        assertNull(state.getCallResult(1));
+        assertNotNull(state.getCallResult(0));
+    }
+
+    @Test
+    public void testClearCallResults() {
+        ActionState state = newState();
+        state.addCallResult(success("func1", "digest1", "result1"));
+        state.addCallResult(success("func2", "digest2", "result2"));
+        assertEquals(2, state.getCallResultCount());
+
+        state.clearCallResults();
+        assertEquals(0, state.getCallResultCount());
+        assertTrue(state.getCallResults().isEmpty());
+    }
+
+    @Test
+    public void testClearCallResultsFrom() {
+        ActionState state = newState();
+        for (int i = 0; i < 4; i++) {
+            state.addCallResult(success("func" + i, "digest" + i, "result" + i));
+        }
+        assertEquals(4, state.getCallResultCount());
+
+        // Truncate at index 2, keeping func0 and func1.
+        state.clearCallResultsFrom(2);
+
+        assertEquals(2, state.getCallResultCount());
+        assertEquals("func0", state.getCallResult(0).getFunctionId());
+        assertEquals("func1", state.getCallResult(1).getFunctionId());
+    }
+
+    @Test
+    public void testClearCallResultsFromInvalidIndex() {
+        ActionState state = newState();
+        state.addCallResult(success("func", "digest", "result"));
+
+        // Neither a negative nor an out-of-range index removes anything.
+        for (int index : new int[] {-1, 10}) {
+            state.clearCallResultsFrom(index);
+            assertEquals(1, state.getCallResultCount());
+        }
+    }
+
+    @Test
+    public void testClearCallResultsFromZero() {
+        ActionState state = newState();
+        state.addCallResult(success("func1", "digest1", "result1"));
+        state.addCallResult(success("func2", "digest2", "result2"));
+
+        // Truncating at index 0 removes every entry.
+        state.clearCallResultsFrom(0);
+        assertEquals(0, state.getCallResultCount());
+    }
+
+    @Test
+    public void testMarkCompleted() {
+        ActionState state = newState();
+        state.addCallResult(success("func1", "digest1", "result1"));
+        state.addCallResult(success("func2", "digest2", "result2"));
+
+        assertFalse(state.isCompleted());
+        assertEquals(2, state.getCallResultCount());
+
+        state.markCompleted();
+
+        assertTrue(state.isCompleted());
+        assertEquals(0, state.getCallResultCount());
+    }
+
+    @Test
+    public void testEqualsWithCallResultsAndCompleted() {
+        InputEvent event = new InputEvent("test");
+        List<CallResult> firstCalls = new ArrayList<>();
+        firstCalls.add(success("func", "digest", "result"));
+        List<CallResult> secondCalls = new ArrayList<>();
+        secondCalls.add(success("func", "digest", "result"));
+
+        ActionState completedA = new ActionState(event, null, null, null, firstCalls, true);
+        ActionState completedB = new ActionState(event, null, null, null, secondCalls, true);
+        ActionState inProgress = new ActionState(event, null, null, null, firstCalls, false);
+
+        assertEquals(completedA, completedB);
+        // Only the completed flag differs between these two.
+        assertNotEquals(completedA, inProgress);
+    }
+
+    @Test
+    public void testHashCodeWithCallResultsAndCompleted() {
+        InputEvent event = new InputEvent("test");
+        List<CallResult> calls = new ArrayList<>();
+        calls.add(success("func", "digest", "result"));
+
+        int sharedListHash = new ActionState(event, null, null, null, calls, true).hashCode();
+        int copiedListHash =
+                new ActionState(event, null, null, null, new ArrayList<>(calls), true).hashCode();
+
+        assertEquals(sharedListHash, copiedListHash);
+    }
+
+    @Test
+    public void testToStringIncludesNewFields() {
+        ActionState state = newState();
+        state.addCallResult(success("func", "digest", "result"));
+        state.markCompleted();
+
+        String rendered = state.toString();
+
+        assertTrue(rendered.contains("callResults"));
+        assertTrue(rendered.contains("completed=true"));
+    }
+
+    @Test
+    public void testNullListsInFullConstructor() {
+        ActionState state = new ActionState(null, null, null, null, null, false);
+
+        assertNull(state.getTaskEvent());
+        // Every null list argument is replaced by a non-null, empty list.
+        assertNotNull(state.getSensoryMemoryUpdates());
+        assertNotNull(state.getShortTermMemoryUpdates());
+        assertNotNull(state.getOutputEvents());
+        assertNotNull(state.getCallResults());
+        assertTrue(state.getSensoryMemoryUpdates().isEmpty());
+        assertTrue(state.getShortTermMemoryUpdates().isEmpty());
+        assertTrue(state.getOutputEvents().isEmpty());
+        assertTrue(state.getCallResults().isEmpty());
+    }
+
+    @Test
+    public void testIntegrationScenario() {
+        // Walks through a typical fine-grained durable execution flow.
+
+        // 1. A fresh state has no results and is not completed.
+        ActionState state = newState();
+        assertFalse(state.isCompleted());
+        assertEquals(0, state.getCallResultCount());
+
+        // 2. The first code block completes.
+        state.addCallResult(success("llm.call", "hash1", "response1"));
+        assertEquals(1, state.getCallResultCount());
+        assertFalse(state.isCompleted());
+
+        // 3. The second code block completes.
+        state.addCallResult(success("db.query", "hash2", "data"));
+        assertEquals(2, state.getCallResultCount());
+
+        // 4. The action finishes: memory updates and output are recorded, then the state is
+        //    marked completed, which drops the call results.
+        state.addSensoryMemoryUpdate(new MemoryUpdate("sm.key", "value"));
+        state.addShortTermMemoryUpdate(new MemoryUpdate("stm.key", "value"));
+        state.addEvent(new OutputEvent("final_output"));
+        state.markCompleted();
+
+        assertTrue(state.isCompleted());
+        assertEquals(0, state.getCallResultCount());
+        // Memory updates and output events are kept.
+        assertEquals(1, state.getSensoryMemoryUpdates().size());
+        assertEquals(1, state.getShortTermMemoryUpdates().size());
+        assertEquals(1, state.getOutputEvents().size());
+    }
+
+    @Test
+    public void testRecoveryScenario() {
+        // State as persisted before the failure: two code blocks already completed.
+        ActionState recovered = newState();
+        recovered.addCallResult(success("func1", "digest1", "result1"));
+        recovered.addCallResult(success("func2", "digest2", "result2"));
+
+        assertFalse(recovered.isCompleted());
+
+        // Re-execution checks the recorded results in order.
+        CallResult replayedFirst = recovered.getCallResult(0);
+        assertTrue(replayedFirst.matches("func1", "digest1"));
+        assertTrue(replayedFirst.isSuccess());
+        assertTrue(recovered.getCallResult(1).matches("func2", "digest2"));
+
+        // A third call has nothing recorded.
+        assertNull(recovered.getCallResult(2));
+    }
+
+    @Test
+    public void testNonDeterministicRecovery() {
+        ActionState state = newState();
+        for (int i = 1; i <= 3; i++) {
+            state.addCallResult(success("func" + i, "digest" + i, "result" + i));
+        }
+
+        // The first replayed call matches what was recorded.
+        assertTrue(state.getCallResult(0).matches("func1", "digest1"));
+        // The second call targets a different function, so the recording diverges here.
+        assertFalse(state.getCallResult(1).matches("different_func", "digest2"));
+
+        // Drop everything from the divergence point onwards.
+        state.clearCallResultsFrom(1);
+        assertEquals(1, state.getCallResultCount());
+        assertEquals("func1", state.getCallResult(0).getFunctionId());
+    }
+}
diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/CallResultTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/CallResultTest.java
new file mode 100644
index 00000000..11d8eb14
--- /dev/null
+++ b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/CallResultTest.java
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.runtime.actionstate;
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+/**
+ * Unit tests for {@link CallResult}.
+ *
+ * <p>Payload byte arrays are produced with an explicit UTF-8 charset so the tests do not depend on
+ * the platform default encoding. The equality and hash-code tests use distinct (but content-equal)
+ * array instances, so an implementation that compared payloads by array identity would fail.
+ */
+public class CallResultTest {
+
+    /** Encodes {@code text} as UTF-8 bytes, independent of the platform default charset. */
+    private static byte[] utf8(String text) {
+        return text.getBytes(java.nio.charset.StandardCharsets.UTF_8);
+    }
+
+    @Test
+    public void testSuccessfulCallResult() {
+        String functionId = "my_module.my_function";
+        String argsDigest = "abc123";
+        byte[] resultPayload = utf8("result");
+
+        CallResult result = new CallResult(functionId, argsDigest, resultPayload);
+
+        assertEquals(functionId, result.getFunctionId());
+        assertEquals(argsDigest, result.getArgsDigest());
+        assertArrayEquals(resultPayload, result.getResultPayload());
+        assertNull(result.getExceptionPayload());
+        assertTrue(result.isSuccess());
+    }
+
+    @Test
+    public void testFailedCallResult() {
+        String functionId = "my_module.my_function";
+        String argsDigest = "abc123";
+        byte[] exceptionPayload = utf8("exception");
+
+        CallResult result = CallResult.ofException(functionId, argsDigest, exceptionPayload);
+
+        assertEquals(functionId, result.getFunctionId());
+        assertEquals(argsDigest, result.getArgsDigest());
+        assertNull(result.getResultPayload());
+        assertArrayEquals(exceptionPayload, result.getExceptionPayload());
+        assertFalse(result.isSuccess());
+    }
+
+    @Test
+    public void testFullConstructor() {
+        String functionId = "my_module.my_function";
+        String argsDigest = "abc123";
+        byte[] resultPayload = utf8("result");
+        byte[] exceptionPayload = null;
+
+        CallResult result =
+                new CallResult(functionId, argsDigest, resultPayload, exceptionPayload);
+
+        assertEquals(functionId, result.getFunctionId());
+        assertEquals(argsDigest, result.getArgsDigest());
+        assertArrayEquals(resultPayload, result.getResultPayload());
+        assertNull(result.getExceptionPayload());
+        assertTrue(result.isSuccess());
+    }
+
+    @Test
+    public void testMatches() {
+        String functionId = "my_module.my_function";
+        String argsDigest = "abc123";
+
+        CallResult result = new CallResult(functionId, argsDigest, utf8("result"));
+
+        assertTrue(result.matches(functionId, argsDigest));
+        assertFalse(result.matches("other_function", argsDigest));
+        assertFalse(result.matches(functionId, "other_digest"));
+        assertFalse(result.matches("other_function", "other_digest"));
+    }
+
+    @Test
+    public void testMatchesWithNullValues() {
+        CallResult result = new CallResult();
+
+        assertTrue(result.matches(null, null));
+        assertFalse(result.matches("function", null));
+        assertFalse(result.matches(null, "digest"));
+    }
+
+    @Test
+    public void testEquals() {
+        String functionId = "my_module.my_function";
+        String argsDigest = "abc123";
+
+        // Each instance gets its own payload array: equality must compare payload contents, not
+        // array identity (e.g. a CallResult deserialized from stored state gets fresh arrays).
+        CallResult result1 = new CallResult(functionId, argsDigest, utf8("result"));
+        CallResult result2 = new CallResult(functionId, argsDigest, utf8("result"));
+        CallResult result3 = new CallResult("other", argsDigest, utf8("result"));
+
+        assertEquals(result1, result2);
+        assertNotEquals(result1, result3);
+        assertNotEquals(result1, new CallResult(functionId, "other_digest", utf8("result")));
+        assertNotEquals(result1, new CallResult(functionId, argsDigest, utf8("other")));
+        // The same bytes stored as an exception rather than a result must not compare equal.
+        assertNotEquals(result1, CallResult.ofException(functionId, argsDigest, utf8("result")));
+        assertEquals(
+                CallResult.ofException(functionId, argsDigest, utf8("exception")),
+                CallResult.ofException(functionId, argsDigest, utf8("exception")));
+
+        // result1 stays the first argument so that CallResult#equals is the method exercised.
+        assertNotEquals(result1, null);
+        assertNotEquals(result1, "string");
+        assertEquals(result1, result1);
+    }
+
+    @Test
+    public void testHashCode() {
+        String functionId = "my_module.my_function";
+        String argsDigest = "abc123";
+
+        // Distinct but content-equal payload arrays: equal objects must have equal hash codes.
+        CallResult result1 = new CallResult(functionId, argsDigest, utf8("result"));
+        CallResult result2 = new CallResult(functionId, argsDigest, utf8("result"));
+
+        assertEquals(result1.hashCode(), result2.hashCode());
+    }
+
+    @Test
+    public void testToString() {
+        String functionId = "my_module.my_function";
+        String argsDigest = "abc123";
+
+        CallResult result = new CallResult(functionId, argsDigest, utf8("result"));
+        String str = result.toString();
+
+        assertTrue(str.contains(functionId));
+        assertTrue(str.contains(argsDigest));
+        assertTrue(str.contains("bytes"));
+    }
+
+    @Test
+    public void testToStringWithNullPayloads() {
+        CallResult result = new CallResult("func", "digest", null, null);
+        String str = result.toString();
+
+        assertTrue(str.contains("null"));
+    }
+
+    @Test
+    public void testDefaultConstructor() {
+        CallResult result = new CallResult();
+
+        assertNull(result.getFunctionId());
+        assertNull(result.getArgsDigest());
+        assertNull(result.getResultPayload());
+        assertNull(result.getExceptionPayload());
+        assertTrue(result.isSuccess()); // exceptionPayload is null
+    }
+}
diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/KafkaActionStateStoreTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/KafkaActionStateStoreTest.java
index c285adf3..cd32524b 100644
--- a/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/KafkaActionStateStoreTest.java
+++ b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/KafkaActionStateStoreTest.java
@@ -159,7 +159,12 @@ public class KafkaActionStateStoreTest {
                         3L));
         for (int i = 0; i < 5; i++) {
             mockConsumer.addRecord(
-                    new ConsumerRecord<>(TEST_TOPIC, 0, i++, "key", new 
ActionState()));
+                    new ConsumerRecord<>(
+                            TEST_TOPIC,
+                            0,
+                            i++,
+                            "key",
+                            new ActionState(null, null, null, null, null, 
false)));
         }
         // Test getting recovery marker after putting state
         Object secondMarker = actionStateStore.getRecoveryMarker();


Reply via email to