This is an automated email from the ASF dual-hosted git repository.
xtsong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git
The following commit(s) were added to refs/heads/main by this push:
new 824b093 Revert "Introduce sensory memory in python."
824b093 is described below
commit 824b09365f20a9b593e867b581456a08d5d4b335
Author: Xintong Song <[email protected]>
AuthorDate: Fri Nov 28 16:35:49 2025 +0800
Revert "Introduce sensory memory in python."
The PR was still under review when I accidentally merged it.
This reverts commits
901a96d97e9a8b8f0b74db6df781e300b9a5df68^..68e37e1a0f97389f10c0c7c1c835279deda48083.
---
.../flink/agents/api/context/MemoryObject.java | 4 -
.../apache/flink/agents/api/context/MemoryRef.java | 21 +----
.../flink/agents/api/context/RunnerContext.java | 11 ---
.../agents/integration/test/MemoryObjectAgent.java | 91 ++++++----------------
python/flink_agents/api/memory_object.py | 6 --
python/flink_agents/api/memory_reference.py | 17 +---
python/flink_agents/api/runner_context.py | 15 ----
python/flink_agents/runtime/flink_memory_object.py | 14 ++--
.../flink_agents/runtime/flink_runner_context.py | 17 ----
python/flink_agents/runtime/local_memory_object.py | 12 ++-
python/flink_agents/runtime/local_runner.py | 31 +-------
.../runtime/tests/test_local_memory_object.py | 3 +-
.../runtime/tests/test_memory_reference.py | 9 +--
.../agents/runtime/actionstate/ActionState.java | 53 ++++---------
.../agents/runtime/context/RunnerContextImpl.java | 48 +++---------
.../agents/runtime/memory/CachedMemoryStore.java | 5 --
.../agents/runtime/memory/MemoryObjectImpl.java | 15 ++--
.../flink/agents/runtime/memory/MemoryStore.java | 3 -
.../runtime/operator/ActionExecutionOperator.java | 31 +-------
.../python/context/PythonRunnerContextImpl.java | 10 +--
.../runtime/actionstate/ActionStateSerdeTest.java | 26 ++-----
.../agents/runtime/memory/MemoryObjectTest.java | 5 +-
.../flink/agents/runtime/memory/MemoryRefTest.java | 6 --
.../operator/ActionExecutionOperatorTest.java | 7 +-
24 files changed, 92 insertions(+), 368 deletions(-)
diff --git
a/api/src/main/java/org/apache/flink/agents/api/context/MemoryObject.java
b/api/src/main/java/org/apache/flink/agents/api/context/MemoryObject.java
index 78c325e..6d6af4b 100644
--- a/api/src/main/java/org/apache/flink/agents/api/context/MemoryObject.java
+++ b/api/src/main/java/org/apache/flink/agents/api/context/MemoryObject.java
@@ -27,10 +27,6 @@ import java.util.Map;
* nested object.Fields can be accessed using an absolute or relative path.
*/
public interface MemoryObject {
- enum MemoryType {
- SENSORY,
- SHORT_TERM
- }
/**
* Returns a MemoryObject that represents the given path.
*
diff --git
a/api/src/main/java/org/apache/flink/agents/api/context/MemoryRef.java
b/api/src/main/java/org/apache/flink/agents/api/context/MemoryRef.java
index f3ff87e..8f0a133 100644
--- a/api/src/main/java/org/apache/flink/agents/api/context/MemoryRef.java
+++ b/api/src/main/java/org/apache/flink/agents/api/context/MemoryRef.java
@@ -17,8 +17,6 @@
*/
package org.apache.flink.agents.api.context;
-import org.apache.flink.annotation.VisibleForTesting;
-
import java.io.Serializable;
import java.util.Objects;
@@ -30,15 +28,9 @@ import java.util.Objects;
public final class MemoryRef implements Serializable {
private static final long serialVersionUID = 1L;
- private final MemoryObject.MemoryType type;
private final String path;
private MemoryRef(String path) {
- this(MemoryObject.MemoryType.SHORT_TERM, path);
- }
-
- private MemoryRef(MemoryObject.MemoryType type, String path) {
- this.type = type;
this.path = path;
}
@@ -48,11 +40,6 @@ public final class MemoryRef implements Serializable {
* @param path The absolute path of the data in Short-Term Memory.
* @return A new MemoryRef instance.
*/
- public static MemoryRef create(MemoryObject.MemoryType type, String path) {
- return new MemoryRef(type, path);
- }
-
- @VisibleForTesting
public static MemoryRef create(String path) {
return new MemoryRef(path);
}
@@ -65,13 +52,7 @@ public final class MemoryRef implements Serializable {
* @throws Exception if the memory cannot be accessed or the data cannot
be resolved.
*/
public MemoryObject resolve(RunnerContext ctx) throws Exception {
- if (type.equals(MemoryObject.MemoryType.SENSORY)) {
- return ctx.getSensoryMemory().get(this);
- } else if (type.equals(MemoryObject.MemoryType.SHORT_TERM)) {
- return ctx.getShortTermMemory().get(this);
- } else {
- throw new RuntimeException(String.format("Unknown memory type %s",
type));
- }
+ return ctx.getShortTermMemory().get(this);
}
public String getPath() {
diff --git
a/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java
b/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java
index 6c1bd02..5960124 100644
--- a/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java
+++ b/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java
@@ -37,17 +37,6 @@ public interface RunnerContext {
*/
void sendEvent(Event event);
- /**
- * Gets the sensory memory.
- *
- * <p>Sensory memory is similar to short-term memory, but will be auto
cleared after agent run
- * finished. User could use it to store data that does not need to be
shared across agent runs.
- *
- * @return MemoryObject the root of the sensory memory
- * @throws Exception if the underlying state backend cannot be accessed
- */
- MemoryObject getSensoryMemory() throws Exception;
-
/**
* Gets the short-term memory.
*
diff --git
a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/MemoryObjectAgent.java
b/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/MemoryObjectAgent.java
index da1a03a..a54316c 100644
---
a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/MemoryObjectAgent.java
+++
b/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/MemoryObjectAgent.java
@@ -30,17 +30,6 @@ import java.util.*;
/** An example agent that tests usages of MemoryObject. */
public class MemoryObjectAgent extends Agent {
- public static class MyEvent extends Event {
- private final String value;
-
- public MyEvent(String value) {
- this.value = value;
- }
-
- public String getValue() {
- return value;
- }
- }
/** A custom POJO for testing serialization. */
public static class Person implements Serializable {
@@ -76,8 +65,6 @@ public class MemoryObjectAgent extends Agent {
@Action(listenEvents = {InputEvent.class})
public static void testMemoryObject(Event event, RunnerContext ctx) throws
Exception {
MemoryObject stm = ctx.getShortTermMemory();
- MemoryObject sm = ctx.getSensoryMemory();
-
Integer key = (Integer) ((InputEvent) event).getInput();
int visitCount = 1;
@@ -86,70 +73,42 @@ public class MemoryObjectAgent extends Agent {
}
stm.set("visit_count", visitCount);
+ // isExist
+ stm.set("existing.path", true);
+ assertEquals(stm.isExist("existing.path"), true);
+ assertEquals(stm.isExist("non.existing.path"), false);
+
+ // getFieldNames and getFields
+ MemoryObject fieldsTestObj = stm.newObject("fieldsTest", true);
+ fieldsTestObj.set("x", 1);
+ fieldsTestObj.set("y", 2);
+ fieldsTestObj.newObject("obj", false);
+ List<String> names = fieldsTestObj.getFieldNames();
+ assertEquals(new HashSet<>(names).containsAll(Arrays.asList("x", "y",
"obj")), true);
+ Map<String, Object> fields = fieldsTestObj.getFields();
+ assertEquals(1, ((Number) fields.get("x")).intValue());
+ assertEquals("NestedObject", fields.get("obj"));
+
+ // List
List<String> tags = Arrays.asList("gamer", "developer", "flink-user");
+ stm.set("list", tags);
+ assertEquals(tags, stm.get("list").getValue());
+ // Map
Map<String, Integer> inventory = new HashMap<>();
inventory.put("potion", 10);
inventory.put("gold", 500);
+ stm.set("map", inventory);
+ assertEquals(inventory, stm.get("map").getValue());
+ // Custom POJO
Person person = new Person("Bob", 22);
-
- if (visitCount == 1) {
- // Test sensory memory
- sm.set("existing.path", true);
- assertEquals(sm.isExist("existing"), true);
- assertEquals(sm.isExist("existing.path"), true);
-
- // Test short-term memory
- // exist
- stm.set("existing.path", true);
-
- // getFieldNames and getFields
- MemoryObject fieldsTestObj = stm.newObject("fieldsTest", true);
- fieldsTestObj.set("x", 1);
- fieldsTestObj.set("y", 2);
- fieldsTestObj.newObject("obj", false);
-
- // List
- stm.set("list", tags);
-
- // Map
- stm.set("map", inventory);
-
- // Custom POJO
- stm.set("person", person);
- } else {
- // Test sensory memory
- assertEquals(sm.isExist("existing"), false);
- assertEquals(sm.isExist("existing.path"), false);
-
- // Test short-term memory
- // exist
- assertEquals(stm.isExist("existing.path"), true);
- assertEquals(stm.isExist("non.existing.path"), false);
-
- // getFieldNames and getFields
- MemoryObject fieldsTestObj = stm.get("fieldsTest");
- List<String> names = fieldsTestObj.getFieldNames();
- assertEquals(new HashSet<>(names).containsAll(Arrays.asList("x",
"y", "obj")), true);
- Map<String, Object> fields = fieldsTestObj.getFields();
- assertEquals(1, ((Number) fields.get("x")).intValue());
- assertEquals("NestedObject", fields.get("obj"));
-
- // List
- assertEquals(tags, stm.get("list").getValue());
-
- // Map
- assertEquals(inventory, stm.get("map").getValue());
-
- // Custom POJO
- assertEquals(person, stm.get("person").getValue());
- }
+ stm.set("person", person);
+ assertEquals(person, stm.get("person").getValue());
String result =
String.format("All assertions passed for key: %d (visit #%d)",
key, visitCount);
String output = result + " [Agent Complete]";
-
ctx.sendEvent(new OutputEvent(output));
}
diff --git a/python/flink_agents/api/memory_object.py
b/python/flink_agents/api/memory_object.py
index 5e700b0..a250199 100644
--- a/python/flink_agents/api/memory_object.py
+++ b/python/flink_agents/api/memory_object.py
@@ -16,7 +16,6 @@
# limitations under the License.
#################################################################################
from abc import ABC, abstractmethod
-from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Union
from pydantic import BaseModel
@@ -24,11 +23,6 @@ from pydantic import BaseModel
if TYPE_CHECKING:
from flink_agents.api.memory_reference import MemoryRef
-class MemoryType(Enum):
- """Memory types based on MemoryObject."""
- SENSORY = "sensory",
- SHORT_TERM = "short_term"
-
class MemoryObject(BaseModel, ABC):
"""Representation of an object in the short-term memory.
diff --git a/python/flink_agents/api/memory_reference.py
b/python/flink_agents/api/memory_reference.py
index 5c37793..952f7b5 100644
--- a/python/flink_agents/api/memory_reference.py
+++ b/python/flink_agents/api/memory_reference.py
@@ -21,8 +21,6 @@ from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, ConfigDict
-from flink_agents.api.memory_object import MemoryType
-
if TYPE_CHECKING:
from flink_agents.api.runner_context import RunnerContext
@@ -30,28 +28,25 @@ if TYPE_CHECKING:
class MemoryRef(BaseModel):
"""Reference to a specific data item in the Short-Term Memory."""
- memory_type: MemoryType = MemoryType.SHORT_TERM
path: str
model_config = ConfigDict(frozen=True)
@staticmethod
- def create(memory_type: MemoryType, path: str) -> MemoryRef:
+ def create(path: str) -> MemoryRef:
"""Create a new MemoryRef instance based on the given path.
Parameters
----------
path: str
The absolute path of the data in the Short-Term Memory.
- memory_type:
- The type of the memory object this reference points to.
Returns:
-------
MemoryRef
A new MemoryRef instance.
"""
- return MemoryRef(memory_type=memory_type, path=path)
+ return MemoryRef(path=path)
def resolve(self, ctx: RunnerContext) -> Any:
"""Resolve the reference to get the actual data.
@@ -66,10 +61,4 @@ class MemoryRef(BaseModel):
Any
The deserialized, original data object.
"""
- if self.memory_type == MemoryType.SENSORY:
- return ctx.sensory_memory.get(self)
- elif self.memory_type == MemoryType.SHORT_TERM:
- return ctx.short_term_memory.get(self)
- else:
- msg = f"Unknown memory type: {self.memory_type}"
- raise RuntimeError(msg)
+ return ctx.short_term_memory.get(self)
diff --git a/python/flink_agents/api/runner_context.py
b/python/flink_agents/api/runner_context.py
index 8f3f7a7..0a3eb01 100644
--- a/python/flink_agents/api/runner_context.py
+++ b/python/flink_agents/api/runner_context.py
@@ -81,21 +81,6 @@ class RunnerContext(ABC):
The config option value.
"""
- @property
- @abstractmethod
- def sensory_memory(self) -> "MemoryObject":
- """Get the sensory memory.
-
- Sensory memory is similar to short-term memory, but will be auto
cleared
- after agent run finished. User could use it to store data that does
not need
- to be shared across agent runs.
-
- Returns:
- -------
- MemoryObject
- The root object of the sensory memory.
- """
-
@property
@abstractmethod
def short_term_memory(self) -> "MemoryObject":
diff --git a/python/flink_agents/runtime/flink_memory_object.py
b/python/flink_agents/runtime/flink_memory_object.py
index f0a8e9d..0b166a1 100644
--- a/python/flink_agents/runtime/flink_memory_object.py
+++ b/python/flink_agents/runtime/flink_memory_object.py
@@ -17,7 +17,7 @@
#################################################################################
from typing import Any, Dict, List
-from flink_agents.api.memory_object import MemoryObject, MemoryType
+from flink_agents.api.memory_object import MemoryObject
from flink_agents.api.memory_reference import MemoryRef
@@ -29,13 +29,9 @@ class FlinkMemoryObject(MemoryObject):
memory implemented in Java.
"""
- __type: MemoryType
-
- def __init__(self, type: MemoryType, j_memory_object: Any, /, **data: Any)
-> None:
+ def __init__(self, j_memory_object: Any) -> None:
"""Initialize with a Java MemoryObject instance."""
- super().__init__(**data)
self._j_memory_object = j_memory_object
- self.__type = type
def get(self, path_or_ref: str | MemoryRef) -> Any:
"""Get a nested object or value by path or MemoryRef.
@@ -55,7 +51,7 @@ class FlinkMemoryObject(MemoryObject):
if j_result is None:
return None
if j_result.isNestedObject():
- return FlinkMemoryObject(self.__type, j_result)
+ return FlinkMemoryObject(j_result)
else:
return j_result.getValue()
except Exception as e:
@@ -66,7 +62,7 @@ class FlinkMemoryObject(MemoryObject):
"""Set a value at the given path. Creates intermediate objects if
needed."""
try:
j_ref = self._j_memory_object.set(path, value)
- return MemoryRef.create(memory_type=self.__type,
path=j_ref.getPath())
+ return MemoryRef(path=j_ref.getPath())
except Exception as e:
msg = f"Failed to set value at path '{path}'"
raise MemoryObjectError(msg) from e
@@ -74,7 +70,7 @@ class FlinkMemoryObject(MemoryObject):
def new_object(self, path: str, *, overwrite: bool = False) ->
"FlinkMemoryObject":
"""Create a new object at the given path."""
try:
- return FlinkMemoryObject(self.__type,
self._j_memory_object.newObject(path, overwrite))
+ return FlinkMemoryObject(self._j_memory_object.newObject(path,
overwrite))
except Exception as e:
msg = f"Failed to create new object at path '{path}'"
raise MemoryObjectError(msg) from e
diff --git a/python/flink_agents/runtime/flink_runner_context.py
b/python/flink_agents/runtime/flink_runner_context.py
index 5e8331b..8f3c323 100644
--- a/python/flink_agents/runtime/flink_runner_context.py
+++ b/python/flink_agents/runtime/flink_runner_context.py
@@ -88,23 +88,6 @@ class FlinkRunnerContext(RunnerContext):
action_name=self._j_runner_context.getActionName(), key=key
)
- @property
- @override
- def sensory_memory(self) -> FlinkMemoryObject:
- """Get the sensory memory object associated with this context.
-
- Returns:
- -------
- MemoryObject
- The sensory memory object that can be used to access and modify
- temporary state data.
- """
- try:
- return FlinkMemoryObject(self._j_runner_context.getSensoryMemory())
- except Exception as e:
- err_msg = "Failed to get sensory memory of runner context"
- raise RuntimeError(err_msg) from e
-
@property
@override
def short_term_memory(self) -> FlinkMemoryObject:
diff --git a/python/flink_agents/runtime/local_memory_object.py
b/python/flink_agents/runtime/local_memory_object.py
index 20e03f6..a1b4c1b 100644
--- a/python/flink_agents/runtime/local_memory_object.py
+++ b/python/flink_agents/runtime/local_memory_object.py
@@ -17,7 +17,7 @@
#################################################################################
from typing import Any, ClassVar, Dict, List
-from flink_agents.api.memory_object import MemoryObject, MemoryType
+from flink_agents.api.memory_object import MemoryObject
from flink_agents.api.memory_reference import MemoryRef
@@ -33,11 +33,10 @@ class LocalMemoryObject(MemoryObject):
__SEPARATOR: ClassVar[str] = "."
__NESTED_MARK: ClassVar[str] = "NestedObject"
- __type: MemoryType
__store: dict[str, Any]
__prefix: str
- def __init__(self, type: MemoryType, store: Dict[str, Any], prefix: str =
ROOT_KEY) -> None:
+ def __init__(self, store: Dict[str, Any], prefix: str = ROOT_KEY) -> None:
"""Initialize a LocalMemoryObject.
Parameters
@@ -49,7 +48,6 @@ class LocalMemoryObject(MemoryObject):
shared store.
"""
super().__init__()
- self.__type = type
self.__store = store if store is not None else {}
self.__prefix = prefix
@@ -81,7 +79,7 @@ class LocalMemoryObject(MemoryObject):
if abs_path in self.__store:
value = self.__store[abs_path]
if self._is_nested_object(value):
- return LocalMemoryObject(self.__type, self.__store, abs_path)
+ return LocalMemoryObject(self.__store, abs_path)
return value
return None
@@ -117,7 +115,7 @@ class LocalMemoryObject(MemoryObject):
self._add_subfield(parent_path, parts[-1])
self.__store[abs_path] = value
- return MemoryRef(memory_type=self.__type, path=abs_path)
+ return MemoryRef(path=abs_path)
def new_object(self, path: str, *, overwrite: bool = False) ->
"LocalMemoryObject":
"""Create a new object as the value of an indirect field in the object.
@@ -148,7 +146,7 @@ class LocalMemoryObject(MemoryObject):
raise ValueError(msg)
self.__store[abs_path] = _ObjMarker()
- return LocalMemoryObject(self.__type, self.__store, abs_path)
+ return LocalMemoryObject(self.__store, abs_path)
def is_exist(self, path: str) -> bool:
"""Check whether a (direct or indirect) field exist in the object.
diff --git a/python/flink_agents/runtime/local_runner.py
b/python/flink_agents/runtime/local_runner.py
index 6b5f50b..8f0441c 100644
--- a/python/flink_agents/runtime/local_runner.py
+++ b/python/flink_agents/runtime/local_runner.py
@@ -24,7 +24,7 @@ from typing_extensions import override
from flink_agents.api.agent import Agent
from flink_agents.api.events.event import Event, InputEvent, OutputEvent
-from flink_agents.api.memory_object import MemoryObject, MemoryType
+from flink_agents.api.memory_object import MemoryObject
from flink_agents.api.metric_group import MetricGroup
from flink_agents.api.resource import Resource, ResourceType
from flink_agents.api.runner_context import RunnerContext
@@ -58,9 +58,7 @@ class LocalRunnerContext(RunnerContext):
__key: Any
events: deque[Event]
action_name: str
- _sensory_mem_store: dict[str, Any]
- _short_term_mem_store: dict[str, Any]
- _sensory_memory: MemoryObject
+ _store: dict[str, Any]
_short_term_memory: MemoryObject
_config: AgentConfiguration
@@ -78,13 +76,9 @@ class LocalRunnerContext(RunnerContext):
self.__agent_plan = agent_plan
self.__key = key
self.events = deque()
- self._sensory_mem_store = {}
- self._short_term_mem_store = {}
- self._sensory_memory = LocalMemoryObject(
- MemoryType.SENSORY, self._sensory_mem_store,
LocalMemoryObject.ROOT_KEY
- )
+ self._store = {}
self._short_term_memory = LocalMemoryObject(
- MemoryType.SHORT_TERM, self._short_term_mem_store,
LocalMemoryObject.ROOT_KEY
+ self._store, LocalMemoryObject.ROOT_KEY
)
self._config = config
@@ -128,18 +122,6 @@ class LocalRunnerContext(RunnerContext):
action_name=self.action_name, key=key
)
- @property
- @override
- def sensory_memory(self) -> MemoryObject:
- """Get the sensory memory object associated with this context.
-
- Returns:
- -------
- MemoryObject
- The root object of the short-term memory.
- """
- return self._sensory_memory
-
@property
@override
def short_term_memory(self) -> MemoryObject:
@@ -187,10 +169,6 @@ class LocalRunnerContext(RunnerContext):
def config(self) -> AgentConfiguration:
return self._config
- def clear_sensory_memory(self) -> None:
- """Clean up sensory memory."""
- self._sensory_mem_store.clear()
-
class LocalRunner(AgentRunner):
"""Agent runner implementation for local execution, which is
@@ -250,7 +228,6 @@ class LocalRunner(AgentRunner):
if key not in self.__keyed_contexts:
self.__keyed_contexts[key] = LocalRunnerContext(self.__agent_plan,
key, self.__config)
context = self.__keyed_contexts[key]
- context.clear_sensory_memory()
if "value" in data:
input_event = InputEvent(input=data["value"])
diff --git a/python/flink_agents/runtime/tests/test_local_memory_object.py
b/python/flink_agents/runtime/tests/test_local_memory_object.py
index 305aba2..0ab6f75 100644
--- a/python/flink_agents/runtime/tests/test_local_memory_object.py
+++ b/python/flink_agents/runtime/tests/test_local_memory_object.py
@@ -17,13 +17,12 @@
#################################################################################
from typing import Dict, List, Set
-from flink_agents.api.memory_object import MemoryType
from flink_agents.runtime.local_memory_object import LocalMemoryObject
def create_memory() -> LocalMemoryObject:
"""Return a MemoryObject for every test case."""
- return LocalMemoryObject(MemoryType.SHORT_TERM, {})
+ return LocalMemoryObject({})
class User: # noqa: D101
diff --git a/python/flink_agents/runtime/tests/test_memory_reference.py
b/python/flink_agents/runtime/tests/test_memory_reference.py
index 497de1a..d6f157b 100644
--- a/python/flink_agents/runtime/tests/test_memory_reference.py
+++ b/python/flink_agents/runtime/tests/test_memory_reference.py
@@ -15,7 +15,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#################################################################################
-from flink_agents.api.memory_object import MemoryType
from flink_agents.api.memory_reference import MemoryRef
from flink_agents.runtime.local_memory_object import LocalMemoryObject
@@ -32,7 +31,7 @@ class MockRunnerContext: # noqa D101
def create_memory() -> LocalMemoryObject:
"""Return a MemoryObject for every test case."""
- return LocalMemoryObject(MemoryType.SHORT_TERM, {})
+ return LocalMemoryObject({})
class User: # noqa: D101
@@ -74,7 +73,7 @@ def test_set_get_involved_ref() -> None: # noqa: D103
def test_memory_ref_create() -> None: # noqa: D103
path = "a.b.c"
- ref = MemoryRef.create(MemoryType.SHORT_TERM, path)
+ ref = MemoryRef.create(path)
assert isinstance(ref, MemoryRef)
assert ref.path == path
@@ -105,7 +104,7 @@ def test_get_with_ref_to_nested_object() -> None: # noqa:
D103
obj = mem.new_object("a.b")
obj.set("c", 10)
- ref = MemoryRef.create(MemoryType.SHORT_TERM, "a")
+ ref = MemoryRef.create("a")
resolved_obj = mem.get(ref)
assert isinstance(resolved_obj, LocalMemoryObject)
@@ -115,7 +114,7 @@ def test_get_with_ref_to_nested_object() -> None: # noqa:
D103
def test_get_with_non_existent_ref() -> None: # noqa: D103
mem = create_memory()
- non_existent_ref = MemoryRef.create(MemoryType.SHORT_TERM,
"this.path.does.not.exist")
+ non_existent_ref = MemoryRef.create("this.path.does.not.exist")
assert mem.get(non_existent_ref) is None
diff --git
a/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionState.java
b/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionState.java
index 34eefb3..5a7b1ff 100644
---
a/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionState.java
+++
b/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionState.java
@@ -26,36 +26,27 @@ import java.util.List;
/** Class representing the state of an action after processing an event. */
public class ActionState {
private final Event taskEvent;
- private final List<MemoryUpdate> sensoryMemoryUpdates;
- private final List<MemoryUpdate> shortTermMemoryUpdates;
+ private final List<MemoryUpdate> memoryUpdates;
private final List<Event> outputEvents;
/** Constructs a new TaskActionState instance. */
public ActionState(final Event taskEvent) {
this.taskEvent = taskEvent;
- this.sensoryMemoryUpdates = new ArrayList<>();
- this.shortTermMemoryUpdates = new ArrayList<>();
+ this.memoryUpdates = new ArrayList<>();
this.outputEvents = new ArrayList<>();
}
public ActionState() {
this.taskEvent = null;
- this.sensoryMemoryUpdates = new ArrayList<>();
- this.shortTermMemoryUpdates = new ArrayList<>();
+ this.memoryUpdates = new ArrayList<>();
this.outputEvents = new ArrayList<>();
}
/** Constructor for deserialization purposes. */
public ActionState(
- Event taskEvent,
- List<MemoryUpdate> sensoryMemoryUpdates,
- List<MemoryUpdate> shortTermMemoryUpdates,
- List<Event> outputEvents) {
+ Event taskEvent, List<MemoryUpdate> memoryUpdates, List<Event>
outputEvents) {
this.taskEvent = taskEvent;
- this.sensoryMemoryUpdates =
- sensoryMemoryUpdates != null ? sensoryMemoryUpdates : new
ArrayList<>();
- this.shortTermMemoryUpdates =
- shortTermMemoryUpdates != null ? shortTermMemoryUpdates : new
ArrayList<>();
+ this.memoryUpdates = memoryUpdates != null ? memoryUpdates : new
ArrayList<>();
this.outputEvents = outputEvents != null ? outputEvents : new
ArrayList<>();
}
@@ -64,12 +55,8 @@ public class ActionState {
return taskEvent;
}
- public List<MemoryUpdate> getSensoryMemoryUpdates() {
- return sensoryMemoryUpdates;
- }
-
- public List<MemoryUpdate> getShortTermMemoryUpdates() {
- return shortTermMemoryUpdates;
+ public List<MemoryUpdate> getMemoryUpdates() {
+ return memoryUpdates;
}
public List<Event> getOutputEvents() {
@@ -77,13 +64,8 @@ public class ActionState {
}
/** Setters for the fields */
- public void addSensoryMemoryUpdate(MemoryUpdate memoryUpdate) {
- sensoryMemoryUpdates.add(memoryUpdate);
- }
-
- /** Setters for the fields */
- public void addShortTermMemoryUpdate(MemoryUpdate memoryUpdate) {
- shortTermMemoryUpdates.add(memoryUpdate);
+ public void addMemoryUpdate(MemoryUpdate memoryUpdate) {
+ memoryUpdates.add(memoryUpdate);
}
public void addEvent(Event event) {
@@ -93,15 +75,8 @@ public class ActionState {
@Override
public int hashCode() {
int result = taskEvent != null ? taskEvent.hashCode() : 0;
- result =
- 31 * result
- + (sensoryMemoryUpdates.isEmpty() ? 0 :
sensoryMemoryUpdates.hashCode());
- result =
- 31 * result
- + (shortTermMemoryUpdates.isEmpty()
- ? 0
- : shortTermMemoryUpdates.hashCode());
- result = 31 * result + (outputEvents.isEmpty() ? 0 :
outputEvents.hashCode());
+ result = 31 * result + (memoryUpdates != null ?
memoryUpdates.hashCode() : 0);
+ result = 31 * result + (outputEvents != null ? outputEvents.hashCode()
: 0);
return result;
}
@@ -110,10 +85,8 @@ public class ActionState {
return "TaskActionState{"
+ "taskEvent="
+ taskEvent
- + ", sensoryMemoryUpdates="
- + sensoryMemoryUpdates
- + ", shortTermMemoryUpdates="
- + shortTermMemoryUpdates
+ + ", memoryUpdates="
+ + memoryUpdates
+ ", outputEvents="
+ outputEvents
+ '}';
diff --git
a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java
b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java
index 6321d98..d998af6 100644
---
a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java
+++
b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java
@@ -44,28 +44,23 @@ import java.util.Map;
public class RunnerContextImpl implements RunnerContext {
protected final List<Event> pendingEvents = new ArrayList<>();
- protected final CachedMemoryStore sensoryMemStore;
- protected final CachedMemoryStore shortTermMemStore;
+ protected final CachedMemoryStore store;
protected final FlinkAgentsMetricGroupImpl agentMetricGroup;
protected final Runnable mailboxThreadChecker;
protected final AgentPlan agentPlan;
- protected final List<MemoryUpdate> sensoryMemoryUpdates;
- protected final List<MemoryUpdate> shortTermMemoryUpdates;
+ protected final List<MemoryUpdate> memoryUpdates;
protected String actionName;
public RunnerContextImpl(
- CachedMemoryStore sensoryMemStore,
- CachedMemoryStore shortTermMemStore,
+ CachedMemoryStore store,
FlinkAgentsMetricGroupImpl agentMetricGroup,
Runnable mailboxThreadChecker,
AgentPlan agentPlan) {
- this.sensoryMemStore = sensoryMemStore;
- this.shortTermMemStore = shortTermMemStore;
+ this.store = store;
this.agentMetricGroup = agentMetricGroup;
this.mailboxThreadChecker = mailboxThreadChecker;
this.agentPlan = agentPlan;
- this.sensoryMemoryUpdates = new LinkedList<>();
- this.shortTermMemoryUpdates = new LinkedList<>();
+ this.memoryUpdates = new LinkedList<>();
}
public void setActionName(String actionName) {
@@ -110,11 +105,6 @@ public class RunnerContextImpl implements RunnerContext {
this.pendingEvents.isEmpty(), "There are pending events
remaining in the context.");
}
- public List<MemoryUpdate> getSensoryMemoryUpdates() {
- mailboxThreadChecker.run();
- return List.copyOf(sensoryMemoryUpdates);
- }
-
/**
* Gets all the updates made to this MemoryObject since it was created or
the last time this
* method was called. This method lives here because it is internally used
by the ActionTask to
@@ -122,31 +112,16 @@ public class RunnerContextImpl implements RunnerContext {
*
* @return list of memory updates
*/
- public List<MemoryUpdate> getShortTermMemoryUpdates() {
- mailboxThreadChecker.run();
- return List.copyOf(shortTermMemoryUpdates);
- }
-
- @Override
- public MemoryObject getSensoryMemory() throws Exception {
+ public List<MemoryUpdate> getAllMemoryUpdates() {
mailboxThreadChecker.run();
- return new MemoryObjectImpl(
- MemoryObject.MemoryType.SENSORY,
- sensoryMemStore,
- MemoryObjectImpl.ROOT_KEY,
- mailboxThreadChecker,
- sensoryMemoryUpdates);
+ return List.copyOf(memoryUpdates);
}
@Override
public MemoryObject getShortTermMemory() throws Exception {
mailboxThreadChecker.run();
return new MemoryObjectImpl(
- MemoryObject.MemoryType.SHORT_TERM,
- shortTermMemStore,
- MemoryObjectImpl.ROOT_KEY,
- mailboxThreadChecker,
- shortTermMemoryUpdates);
+ store, MemoryObjectImpl.ROOT_KEY, mailboxThreadChecker,
memoryUpdates);
}
@Override
@@ -177,11 +152,6 @@ public class RunnerContextImpl implements RunnerContext {
}
public void persistMemory() throws Exception {
- sensoryMemStore.persistCache();
- shortTermMemStore.persistCache();
- }
-
- public void clearSensoryMemory() throws Exception {
- sensoryMemStore.clear();
+ store.persistCache();
}
}
diff --git
a/runtime/src/main/java/org/apache/flink/agents/runtime/memory/CachedMemoryStore.java
b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/CachedMemoryStore.java
index 36360cd..71eb8d2 100644
---
a/runtime/src/main/java/org/apache/flink/agents/runtime/memory/CachedMemoryStore.java
+++
b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/CachedMemoryStore.java
@@ -57,9 +57,4 @@ public class CachedMemoryStore implements MemoryStore {
}
cache.clear();
}
-
- public void clear() throws Exception {
- cache.clear();
- store.clear();
- }
}
diff --git
a/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryObjectImpl.java
b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryObjectImpl.java
index b9550a6..a8fa1cb 100644
---
a/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryObjectImpl.java
+++
b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryObjectImpl.java
@@ -40,27 +40,22 @@ public class MemoryObjectImpl implements MemoryObject {
public static final String ROOT_KEY = "";
private static final String SEPARATOR = ".";
- private final MemoryType type;
-
private final MemoryStore store;
private final List<MemoryUpdate> memoryUpdates;
private final String prefix;
private final Runnable mailboxThreadChecker;
- public MemoryObjectImpl(
- MemoryType type, MemoryStore store, String prefix,
List<MemoryUpdate> memoryUpdates)
+ public MemoryObjectImpl(MemoryStore store, String prefix,
List<MemoryUpdate> memoryUpdates)
throws Exception {
- this(type, store, prefix, () -> {}, memoryUpdates);
+ this(store, prefix, () -> {}, memoryUpdates);
}
public MemoryObjectImpl(
- MemoryType type,
MemoryStore store,
String prefix,
Runnable mailboxThreadChecker,
List<MemoryUpdate> memoryUpdates)
throws Exception {
- this.type = type;
this.store = store;
this.prefix = prefix;
this.mailboxThreadChecker = mailboxThreadChecker;
@@ -75,7 +70,7 @@ public class MemoryObjectImpl implements MemoryObject {
mailboxThreadChecker.run();
String absPath = fullPath(path);
if (store.contains(absPath)) {
- return new MemoryObjectImpl(type, store, absPath, memoryUpdates);
+ return new MemoryObjectImpl(store, absPath, memoryUpdates);
}
return null;
}
@@ -109,7 +104,7 @@ public class MemoryObjectImpl implements MemoryObject {
store.put(absPath, val);
memoryUpdates.add(new MemoryUpdate(absPath, value));
- return MemoryRef.create(type, absPath);
+ return MemoryRef.create(absPath);
}
@Override
@@ -142,7 +137,7 @@ public class MemoryObjectImpl implements MemoryObject {
parentItem.getSubKeys().add(parts[parts.length - 1]);
store.put(parent, parentItem);
- return new MemoryObjectImpl(type, store, absPath, memoryUpdates);
+ return new MemoryObjectImpl(store, absPath, memoryUpdates);
}
@Override
diff --git
a/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryStore.java
b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryStore.java
index 28793fb..f466750 100644
---
a/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryStore.java
+++
b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryStore.java
@@ -45,7 +45,4 @@ public interface MemoryStore {
* @return true if the MemoryItem exists, false otherwise
*/
boolean contains(String key) throws Exception;
-
- /** Remove all the MemoryItem. */
- void clear() throws Exception;
}
diff --git
a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
index 2a54563..95b991b 100644
---
a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
+++
b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
@@ -115,8 +115,6 @@ public class ActionExecutionOperator<IN, OUT> extends
AbstractStreamOperator<OUT
private transient StreamRecord<OUT> reusedStreamRecord;
- private transient MapState<String, MemoryObjectImpl.MemoryItem>
sensoryMemState;
-
private transient MapState<String, MemoryObjectImpl.MemoryItem>
shortTermMemState;
// PythonActionExecutor for Python actions
@@ -184,13 +182,6 @@ public class ActionExecutionOperator<IN, OUT> extends
AbstractStreamOperator<OUT
public void open() throws Exception {
super.open();
reusedStreamRecord = new StreamRecord<>(null);
- // init sensoryMemState
- MapStateDescriptor<String, MemoryObjectImpl.MemoryItem>
sensoryMemStateDescriptor =
- new MapStateDescriptor<>(
- "sensoryMemory",
- TypeInformation.of(String.class),
- TypeInformation.of(MemoryObjectImpl.MemoryItem.class));
- sensoryMemState =
getRuntimeContext().getMapState(sensoryMemStateDescriptor);
// init shortTermMemState
MapStateDescriptor<String, MemoryObjectImpl.MemoryItem>
shortTermMemStateDescriptor =
@@ -406,19 +397,12 @@ public class ActionExecutionOperator<IN, OUT> extends
AbstractStreamOperator<OUT
if (actionState != null) {
isFinished = true;
outputEvents = actionState.getOutputEvents();
- for (MemoryUpdate memoryUpdate :
actionState.getShortTermMemoryUpdates()) {
+ for (MemoryUpdate memoryUpdate : actionState.getMemoryUpdates()) {
actionTask
.getRunnerContext()
.getShortTermMemory()
.set(memoryUpdate.getPath(), memoryUpdate.getValue());
}
-
- for (MemoryUpdate memoryUpdate :
actionState.getSensoryMemoryUpdates()) {
- actionTask
- .getRunnerContext()
- .getSensoryMemory()
- .set(memoryUpdate.getPath(), memoryUpdate.getValue());
- }
} else {
maybeInitActionState(key, sequenceNumber, actionTask.action,
actionTask.event);
ActionTask.ActionTaskResult actionTaskResult = actionTask.invoke();
@@ -468,9 +452,6 @@ public class ActionExecutionOperator<IN, OUT> extends
AbstractStreamOperator<OUT
// 3. Process the next InputEvent or next action task
if (currentInputEventFinished) {
- // Clean up sensory memory when a single run finished.
- actionTask.getRunnerContext().clearSensoryMemory();
-
// Once all sub-events and actions related to the current
InputEvent are completed,
// we can proceed to process the next InputEvent.
int removedCount =
removeFromListState(currentProcessingKeysOpState, key);
@@ -678,7 +659,6 @@ public class ActionExecutionOperator<IN, OUT> extends
AbstractStreamOperator<OUT
} else if (actionTask.action.getExec() instanceof JavaFunction) {
runnerContext =
new RunnerContextImpl(
- new CachedMemoryStore(sensoryMemState),
new CachedMemoryStore(shortTermMemState),
metricGroup,
this::checkMailboxThread,
@@ -686,7 +666,6 @@ public class ActionExecutionOperator<IN, OUT> extends
AbstractStreamOperator<OUT
} else if (actionTask.action.getExec() instanceof PythonFunction) {
runnerContext =
new PythonRunnerContextImpl(
- new CachedMemoryStore(sensoryMemState),
new CachedMemoryStore(shortTermMemState),
metricGroup,
this::checkMailboxThread,
@@ -776,12 +755,8 @@ public class ActionExecutionOperator<IN, OUT> extends
AbstractStreamOperator<OUT
ActionState actionState = actionStateStore.get(key, sequenceNum,
action, event);
- for (MemoryUpdate memoryUpdate : context.getSensoryMemoryUpdates()) {
- actionState.addSensoryMemoryUpdate(memoryUpdate);
- }
-
- for (MemoryUpdate memoryUpdate : context.getShortTermMemoryUpdates()) {
- actionState.addShortTermMemoryUpdate(memoryUpdate);
+ for (MemoryUpdate memoryUpdate : context.getAllMemoryUpdates()) {
+ actionState.addMemoryUpdate(memoryUpdate);
}
for (Event outputEvent : actionTaskResult.getOutputEvents()) {
diff --git
a/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java
b/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java
index 4bdb8d8..89f741d 100644
---
a/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java
+++
b/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java
@@ -36,18 +36,12 @@ public class PythonRunnerContextImpl extends
RunnerContextImpl {
private final PythonActionExecutor pythonActionExecutor;
public PythonRunnerContextImpl(
- CachedMemoryStore sensoryMemStore,
- CachedMemoryStore shortTermMemStore,
+ CachedMemoryStore store,
FlinkAgentsMetricGroupImpl agentMetricGroup,
Runnable mailboxThreadChecker,
AgentPlan agentPlan,
PythonActionExecutor pythonActionExecutor) {
- super(
- sensoryMemStore,
- shortTermMemStore,
- agentMetricGroup,
- mailboxThreadChecker,
- agentPlan);
+ super(store, agentMetricGroup, mailboxThreadChecker, agentPlan);
this.pythonActionExecutor = pythonActionExecutor;
}
diff --git
a/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateSerdeTest.java
b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateSerdeTest.java
index eac53d2..4f9ca04 100644
---
a/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateSerdeTest.java
+++
b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateSerdeTest.java
@@ -40,13 +40,11 @@ public class ActionStateSerdeTest {
OutputEvent outputEvent = new OutputEvent("test output");
outputEvent.setAttr("outputAttr", 123);
- MemoryUpdate sensoryMemoryUpdate = new MemoryUpdate("sm.test.path",
"sm test value");
- MemoryUpdate shortTermMemoryUpdate = new MemoryUpdate("stm.test.path",
"stm test value");
+ MemoryUpdate memoryUpdate = new MemoryUpdate("test.path", "test
value");
// Create ActionState
ActionState originalState = new ActionState(inputEvent);
- originalState.addSensoryMemoryUpdate(sensoryMemoryUpdate);
- originalState.addShortTermMemoryUpdate(shortTermMemoryUpdate);
+ originalState.addMemoryUpdate(memoryUpdate);
originalState.addEvent(outputEvent);
// Test Kafka seder/deserializer
@@ -69,16 +67,10 @@ public class ActionStateSerdeTest {
assertEquals("testValue", deserializedInputEvent.getAttr("testAttr"));
// Verify memoryUpdates
- assertEquals(1, deserializedState.getSensoryMemoryUpdates().size());
- MemoryUpdate deserializedSensoryMemoryUpdate =
- deserializedState.getSensoryMemoryUpdates().get(0);
- assertEquals("sm.test.path",
deserializedSensoryMemoryUpdate.getPath());
- assertEquals("sm test value",
deserializedSensoryMemoryUpdate.getValue());
- assertEquals(1, deserializedState.getShortTermMemoryUpdates().size());
- MemoryUpdate deserializedShortTermMemoryUpdate =
- deserializedState.getShortTermMemoryUpdates().get(0);
- assertEquals("stm.test.path",
deserializedShortTermMemoryUpdate.getPath());
- assertEquals("stm test value",
deserializedShortTermMemoryUpdate.getValue());
+ assertEquals(1, deserializedState.getMemoryUpdates().size());
+ MemoryUpdate deserializedMemoryUpdate =
deserializedState.getMemoryUpdates().get(0);
+ assertEquals("test.path", deserializedMemoryUpdate.getPath());
+ assertEquals("test value", deserializedMemoryUpdate.getValue());
// Verify outputEvents
assertEquals(1, deserializedState.getOutputEvents().size());
@@ -94,8 +86,7 @@ public class ActionStateSerdeTest {
// Create ActionState with null taskEvent
ActionState originalState = new ActionState();
MemoryUpdate memoryUpdate = new MemoryUpdate("test.path", "test
value");
- originalState.addShortTermMemoryUpdate(memoryUpdate);
- originalState.addSensoryMemoryUpdate(memoryUpdate);
+ originalState.addMemoryUpdate(memoryUpdate);
// Test serialization/deserialization
ActionStateKafkaSeder seder = new ActionStateKafkaSeder();
@@ -107,8 +98,7 @@ public class ActionStateSerdeTest {
assertNull(deserializedState.getTaskEvent());
// Verify other fields
- assertEquals(1, deserializedState.getSensoryMemoryUpdates().size());
- assertEquals(1, deserializedState.getShortTermMemoryUpdates().size());
+ assertEquals(1, deserializedState.getMemoryUpdates().size());
assertEquals(0, deserializedState.getOutputEvents().size());
}
diff --git
a/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryObjectTest.java
b/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryObjectTest.java
index 07949a6..970fe32 100644
---
a/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryObjectTest.java
+++
b/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryObjectTest.java
@@ -63,10 +63,7 @@ public class MemoryObjectTest {
memoryUpdates = new LinkedList<>();
memory =
new MemoryObjectImpl(
- MemoryObject.MemoryType.SHORT_TERM,
- new CachedMemoryStore(mapState),
- MemoryObjectImpl.ROOT_KEY,
- memoryUpdates);
+ new CachedMemoryStore(mapState),
MemoryObjectImpl.ROOT_KEY, memoryUpdates);
}
@Test
diff --git
a/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryRefTest.java
b/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryRefTest.java
index 8a0e02e..780784f 100644
---
a/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryRefTest.java
+++
b/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryRefTest.java
@@ -68,11 +68,6 @@ public class MemoryRefTest {
this.memoryObject = memoryObject;
}
- @Override
- public MemoryObject getSensoryMemory() throws Exception {
- return memoryObject;
- }
-
@Override
public MemoryObject getShortTermMemory() {
return memoryObject;
@@ -117,7 +112,6 @@ public class MemoryRefTest {
ForTestMemoryMapState<MemoryObjectImpl.MemoryItem> mapState = new
ForTestMemoryMapState<>();
memory =
new MemoryObjectImpl(
- MemoryObject.MemoryType.SHORT_TERM,
new CachedMemoryStore(mapState),
MemoryObjectImpl.ROOT_KEY,
new LinkedList<>());
diff --git
a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
index 4c58f98..4027589 100644
---
a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
+++
b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
@@ -267,11 +267,10 @@ public class ActionExecutionOperatorTest {
assertThat(taskEvent).isNotNull();
// Verify memory updates contain expected data
- if (!state.getShortTermMemoryUpdates().isEmpty()) {
+ if (!state.getMemoryUpdates().isEmpty()) {
// For action1, memory should contain input + 1
-
assertThat(state.getShortTermMemoryUpdates().get(0).getPath())
- .isEqualTo("tmp");
-
assertThat(state.getShortTermMemoryUpdates().get(0).getValue())
+
assertThat(state.getMemoryUpdates().get(0).getPath()).isEqualTo("tmp");
+ assertThat(state.getMemoryUpdates().get(0).getValue())
.isEqualTo(inputValue + 1);
}