This is an automated email from the ASF dual-hosted git repository.

sxnan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git


The following commit(s) were added to refs/heads/main by this push:
     new a2fef7e  [runtime] Only persist memory changes to flink state when 
action is finished (#219)
a2fef7e is described below

commit a2fef7efa4d8bd9e71e7d34360d81eea008cfe9e
Author: Xuannan <[email protected]>
AuthorDate: Fri Sep 26 09:30:50 2025 +0800

    [runtime] Only persist memory changes to flink state when action is 
finished (#219)
---
 .../agents/runtime/context/RunnerContextImpl.java  | 10 ++-
 .../agents/runtime/memory/CachedMemoryStore.java   | 60 +++++++++++++++
 .../agents/runtime/memory/MemoryObjectImpl.java    |  8 +-
 .../flink/agents/runtime/memory/MemoryStore.java   | 48 ++++++++++++
 .../runtime/operator/ActionExecutionOperator.java  | 44 +++++++++--
 .../python/context/PythonRunnerContextImpl.java    |  5 +-
 .../runtime/memory/CachedMemoryStoreTest.java      | 90 ++++++++++++++++++++++
 .../runtime/memory/ForTestMemoryMapState.java      | 85 ++++++++++++++++++++
 .../agents/runtime/memory/MemoryObjectTest.java    | 66 +---------------
 .../flink/agents/runtime/memory/MemoryRefTest.java |  6 +-
 10 files changed, 339 insertions(+), 83 deletions(-)

diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java
index 653c5cc..e8ca999 100644
--- 
a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java
@@ -26,9 +26,9 @@ import org.apache.flink.agents.api.resource.Resource;
 import org.apache.flink.agents.api.resource.ResourceType;
 import org.apache.flink.agents.plan.AgentPlan;
 import org.apache.flink.agents.plan.utils.JsonUtils;
+import org.apache.flink.agents.runtime.memory.CachedMemoryStore;
 import org.apache.flink.agents.runtime.memory.MemoryObjectImpl;
 import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl;
-import org.apache.flink.api.common.state.MapState;
 import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
 import org.apache.flink.util.Preconditions;
 
@@ -44,7 +44,7 @@ import java.util.Map;
 public class RunnerContextImpl implements RunnerContext {
 
     protected final List<Event> pendingEvents = new ArrayList<>();
-    protected final MapState<String, MemoryObjectImpl.MemoryItem> store;
+    protected final CachedMemoryStore store;
     protected final FlinkAgentsMetricGroupImpl agentMetricGroup;
     protected final Runnable mailboxThreadChecker;
     protected final AgentPlan agentPlan;
@@ -52,7 +52,7 @@ public class RunnerContextImpl implements RunnerContext {
     protected String actionName;
 
     public RunnerContextImpl(
-            MapState<String, MemoryObjectImpl.MemoryItem> store,
+            CachedMemoryStore store,
             FlinkAgentsMetricGroupImpl agentMetricGroup,
             Runnable mailboxThreadChecker,
             AgentPlan agentPlan) {
@@ -147,4 +147,8 @@ public class RunnerContextImpl implements RunnerContext {
     public String getActionName() {
         return actionName;
     }
+
+    public void persistMemory() throws Exception {
+        store.persistCache();
+    }
 }
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/memory/CachedMemoryStore.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/CachedMemoryStore.java
new file mode 100644
index 0000000..71eb8d2
--- /dev/null
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/CachedMemoryStore.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.runtime.memory;
+
+import org.apache.flink.api.common.state.MapState;
+
+import java.util.HashMap;
+import java.util.Map;
+
+public class CachedMemoryStore implements MemoryStore {
+
+    private final Map<String, MemoryObjectImpl.MemoryItem> cache;
+    private final MapState<String, MemoryObjectImpl.MemoryItem> store;
+
+    public CachedMemoryStore(MapState<String, MemoryObjectImpl.MemoryItem> 
store) {
+        this.store = store;
+        this.cache = new HashMap<>();
+    }
+
+    @Override
+    public MemoryObjectImpl.MemoryItem get(String key) throws Exception {
+        if (cache.containsKey(key)) {
+            return cache.get(key);
+        }
+
+        return store.get(key);
+    }
+
+    @Override
+    public void put(String key, MemoryObjectImpl.MemoryItem value) throws 
Exception {
+        cache.put(key, value);
+    }
+
+    @Override
+    public boolean contains(String key) throws Exception {
+        return cache.containsKey(key) || store.contains(key);
+    }
+
+    public void persistCache() throws Exception {
+        for (Map.Entry<String, MemoryObjectImpl.MemoryItem> entry : 
cache.entrySet()) {
+            store.put(entry.getKey(), entry.getValue());
+        }
+        cache.clear();
+    }
+}
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryObjectImpl.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryObjectImpl.java
index 70856eb..a8fa1cb 100644
--- 
a/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryObjectImpl.java
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryObjectImpl.java
@@ -20,7 +20,6 @@ package org.apache.flink.agents.runtime.memory;
 import org.apache.flink.agents.api.context.MemoryObject;
 import org.apache.flink.agents.api.context.MemoryRef;
 import org.apache.flink.agents.api.context.MemoryUpdate;
-import org.apache.flink.api.common.state.MapState;
 
 import java.io.Serializable;
 import java.util.ArrayList;
@@ -41,19 +40,18 @@ public class MemoryObjectImpl implements MemoryObject {
     public static final String ROOT_KEY = "";
     private static final String SEPARATOR = ".";
 
-    private final MapState<String, MemoryItem> store;
+    private final MemoryStore store;
     private final List<MemoryUpdate> memoryUpdates;
     private final String prefix;
     private final Runnable mailboxThreadChecker;
 
-    public MemoryObjectImpl(
-            MapState<String, MemoryItem> store, String prefix, 
List<MemoryUpdate> memoryUpdates)
+    public MemoryObjectImpl(MemoryStore store, String prefix, 
List<MemoryUpdate> memoryUpdates)
             throws Exception {
         this(store, prefix, () -> {}, memoryUpdates);
     }
 
     public MemoryObjectImpl(
-            MapState<String, MemoryItem> store,
+            MemoryStore store,
             String prefix,
             Runnable mailboxThreadChecker,
             List<MemoryUpdate> memoryUpdates)
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryStore.java 
b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryStore.java
new file mode 100644
index 0000000..f466750
--- /dev/null
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryStore.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.runtime.memory;
+
+import org.apache.flink.agents.runtime.memory.MemoryObjectImpl.MemoryItem;
+
+/** MemoryStore to put and get MemoryItems. */
+public interface MemoryStore {
+
+    /**
+     * Get a MemoryItem by key.
+     *
+     * @param key the key of the MemoryItem
+     * @return the MemoryItem
+     */
+    MemoryItem get(String key) throws Exception;
+
+    /**
+     * Put a MemoryItem by key.
+     *
+     * @param key the key of the MemoryItem
+     * @param value the MemoryItem
+     */
+    void put(String key, MemoryItem value) throws Exception;
+
+    /**
+     * Check if a MemoryItem exists by key.
+     *
+     * @param key the key of the MemoryItem
+     * @return true if the MemoryItem exists, false otherwise
+     */
+    boolean contains(String key) throws Exception;
+}
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
index 0683f2e..b91fa52 100644
--- 
a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
@@ -36,6 +36,7 @@ import 
org.apache.flink.agents.runtime.actionstate.ActionStateStore;
 import org.apache.flink.agents.runtime.actionstate.KafkaActionStateStore;
 import org.apache.flink.agents.runtime.context.RunnerContextImpl;
 import org.apache.flink.agents.runtime.env.PythonEnvironmentManager;
+import org.apache.flink.agents.runtime.memory.CachedMemoryStore;
 import org.apache.flink.agents.runtime.memory.MemoryObjectImpl;
 import org.apache.flink.agents.runtime.metrics.BuiltInMetrics;
 import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl;
@@ -83,7 +84,6 @@ import java.util.Optional;
 import static 
org.apache.flink.agents.api.configuration.AgentConfigOptions.ACTION_STATE_STORE_BACKEND;
 import static 
org.apache.flink.agents.runtime.actionstate.ActionStateStore.BackendType.KAFKA;
 import static org.apache.flink.agents.runtime.utils.StateUtil.*;
-import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.apache.flink.util.Preconditions.checkState;
 
 /**
@@ -151,6 +151,10 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
     private transient ListState<Object> recoveryMarkerOpState;
     private transient Map<Long, Map<Object, Long>> checkpointIdToSeqNums;
 
+    // This in-memory map keeps track of the runner contexts of async 
action tasks that have
+    // not finished yet.
+    private final transient Map<ActionTask, RunnerContextImpl> 
actionTaskRunnerContexts;
+
     public ActionExecutionOperator(
             AgentPlan agentPlan,
             Boolean inputIsJava,
@@ -166,6 +170,7 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
         this.eventListeners = new ArrayList<>();
         this.actionStateStore = actionStateStore;
         this.checkpointIdToSeqNums = new HashMap<>();
+        this.actionTaskRunnerContexts = new HashMap<>();
     }
 
     @Override
@@ -374,6 +379,11 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
         } else {
             maybeInitActionState(key, sequenceNumber, actionTask.action, 
actionTask.event);
             ActionTask.ActionTaskResult actionTaskResult = actionTask.invoke();
+
+            // We remove the RunnerContext of the action task from the map 
once it is invoked. The
+            // RunnerContext is added back later, keyed by the generated 
action task, if the action
+            // task is not finished.
+            actionTaskRunnerContexts.remove(actionTask);
             maybePersistTaskResult(
                     key,
                     sequenceNumber,
@@ -394,13 +404,23 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
         if (isFinished) {
             builtInMetrics.markActionExecuted(actionTask.action.getName());
             currentInputEventFinished = !currentKeyHasMoreActionTask();
+
+            // Persist memory to the Flink state when the action task is 
finished.
+            actionTask.getRunnerContext().persistMemory();
         } else {
+            checkState(
+                    generatedActionTaskOpt.isPresent(),
+                    "ActionTask not finished, but the generated action task is 
null.");
+
             // If the action task is not finished, we should get a new action 
task to continue the
             // execution.
-            checkNotNull(
-                    generatedActionTaskOpt.get(),
-                    "ActionTask not finished, but the generated action task is 
null.");
-            actionTasksKState.add(generatedActionTaskOpt.get());
+            ActionTask generatedActionTask = generatedActionTaskOpt.get();
+
+            // If the action task is not finished, we keep the runner context 
in the memory for the
+            // next generated ActionTask to be invoked.
+            actionTaskRunnerContexts.put(generatedActionTask, 
actionTask.getRunnerContext());
+
+            actionTasksKState.add(generatedActionTask);
         }
 
         // 3. Process the next InputEvent or next action task
@@ -602,14 +622,22 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
         }
 
         RunnerContextImpl runnerContext;
-        if (actionTask.action.getExec() instanceof JavaFunction) {
+        if (actionTaskRunnerContexts.containsKey(actionTask)) {
+            runnerContext = actionTaskRunnerContexts.get(actionTask);
+        } else if (actionTask.action.getExec() instanceof JavaFunction) {
             runnerContext =
                     new RunnerContextImpl(
-                            shortTermMemState, metricGroup, 
this::checkMailboxThread, agentPlan);
+                            new CachedMemoryStore(shortTermMemState),
+                            metricGroup,
+                            this::checkMailboxThread,
+                            agentPlan);
         } else if (actionTask.action.getExec() instanceof PythonFunction) {
             runnerContext =
                     new PythonRunnerContextImpl(
-                            shortTermMemState, metricGroup, 
this::checkMailboxThread, agentPlan);
+                            new CachedMemoryStore(shortTermMemState),
+                            metricGroup,
+                            this::checkMailboxThread,
+                            agentPlan);
         } else {
             throw new IllegalStateException(
                     "Unsupported action type: " + 
actionTask.action.getExec().getClass());
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java
index b787a6a..43e3f62 100644
--- 
a/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java
@@ -21,10 +21,9 @@ import org.apache.flink.agents.api.Event;
 import org.apache.flink.agents.api.context.RunnerContext;
 import org.apache.flink.agents.plan.AgentPlan;
 import org.apache.flink.agents.runtime.context.RunnerContextImpl;
-import org.apache.flink.agents.runtime.memory.MemoryObjectImpl;
+import org.apache.flink.agents.runtime.memory.CachedMemoryStore;
 import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl;
 import org.apache.flink.agents.runtime.python.event.PythonEvent;
-import org.apache.flink.api.common.state.MapState;
 import org.apache.flink.util.Preconditions;
 
 import javax.annotation.concurrent.NotThreadSafe;
@@ -34,7 +33,7 @@ import javax.annotation.concurrent.NotThreadSafe;
 public class PythonRunnerContextImpl extends RunnerContextImpl {
 
     public PythonRunnerContextImpl(
-            MapState<String, MemoryObjectImpl.MemoryItem> store,
+            CachedMemoryStore store,
             FlinkAgentsMetricGroupImpl agentMetricGroup,
             Runnable mailboxThreadChecker,
             AgentPlan agentPlan) {
diff --git 
a/runtime/src/test/java/org/apache/flink/agents/runtime/memory/CachedMemoryStoreTest.java
 
b/runtime/src/test/java/org/apache/flink/agents/runtime/memory/CachedMemoryStoreTest.java
new file mode 100644
index 0000000..51c55b9
--- /dev/null
+++ 
b/runtime/src/test/java/org/apache/flink/agents/runtime/memory/CachedMemoryStoreTest.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.runtime.memory;
+
+import org.apache.flink.agents.runtime.memory.MemoryObjectImpl.MemoryItem;
+import org.junit.jupiter.api.Test;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+class CachedMemoryStoreTest {
+
+    @Test
+    void testPutAndGet() throws Exception {
+        ForTestMemoryMapState<MemoryItem> store = new 
ForTestMemoryMapState<>();
+        CachedMemoryStore cachedStore = new CachedMemoryStore(store);
+
+        MemoryItem v1 = new MemoryItem();
+        MemoryItem v2 = new MemoryItem(10);
+        cachedStore.put("k1", v1);
+        cachedStore.put("k2", v2);
+
+        assertThat(cachedStore.get("k1")).isEqualTo(v1);
+        assertThat(cachedStore.get("k2")).isEqualTo(v2);
+    }
+
+    @Test
+    void testPutToCache() throws Exception {
+        ForTestMemoryMapState<MemoryItem> store = new 
ForTestMemoryMapState<>();
+        MemoryItem v1 = new MemoryItem();
+        MemoryItem v2 = new MemoryItem();
+        store.put("k1", v1);
+        store.put("k2", v2);
+
+        CachedMemoryStore cachedStore = new CachedMemoryStore(store);
+
+        MemoryItem v11 = new MemoryItem();
+        cachedStore.put("k1", v11);
+
+        assertThat(cachedStore.get("k1")).isEqualTo(v11);
+        assertThat(cachedStore.get("k2")).isEqualTo(v2);
+    }
+
+    @Test
+    void testContains() throws Exception {
+        ForTestMemoryMapState<MemoryItem> store = new 
ForTestMemoryMapState<>();
+        MemoryItem v1 = new MemoryItem();
+        store.put("k1", v1);
+
+        CachedMemoryStore cachedStore = new CachedMemoryStore(store);
+        assertThat(cachedStore.contains("k1")).isTrue();
+        assertThat(cachedStore.contains("k2")).isFalse();
+
+        MemoryItem v2 = new MemoryItem();
+        cachedStore.put("k2", v2);
+        assertThat(cachedStore.contains("k2")).isTrue();
+    }
+
+    @Test
+    void testPersistCache() throws Exception {
+        ForTestMemoryMapState<MemoryItem> store = new 
ForTestMemoryMapState<>();
+        MemoryItem v1 = new MemoryItem();
+        MemoryItem v2 = new MemoryItem(10);
+        store.put("k1", v1);
+        store.put("k2", v2);
+
+        CachedMemoryStore cachedStore = new CachedMemoryStore(store);
+        MemoryItem v11 = new MemoryItem();
+        cachedStore.put("k1", v11);
+
+        cachedStore.persistCache();
+
+        assertThat(store.get("k1")).isEqualTo(v11);
+        assertThat(store.get("k2")).isEqualTo(v2);
+    }
+}
diff --git 
a/runtime/src/test/java/org/apache/flink/agents/runtime/memory/ForTestMemoryMapState.java
 
b/runtime/src/test/java/org/apache/flink/agents/runtime/memory/ForTestMemoryMapState.java
new file mode 100644
index 0000000..916d68c
--- /dev/null
+++ 
b/runtime/src/test/java/org/apache/flink/agents/runtime/memory/ForTestMemoryMapState.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.runtime.memory;
+
+import org.apache.flink.api.common.state.MapState;
+
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Map;
+
+/** Simple, non-serialized HashMap implementation. */
+class ForTestMemoryMapState<V> implements MapState<String, V> {
+
+    private final Map<String, V> fortest = new HashMap<>();
+
+    @Override
+    public V get(String key) {
+        return fortest.get(key);
+    }
+
+    @Override
+    public void put(String key, V value) {
+        fortest.put(key, value);
+    }
+
+    @Override
+    public void putAll(Map<String, V> map) {
+        fortest.putAll(map);
+    }
+
+    @Override
+    public void remove(String key) {
+        fortest.remove(key);
+    }
+
+    @Override
+    public boolean contains(String key) {
+        return fortest.containsKey(key);
+    }
+
+    @Override
+    public Iterable<Map.Entry<String, V>> entries() {
+        return fortest.entrySet();
+    }
+
+    @Override
+    public Iterable<String> keys() {
+        return fortest.keySet();
+    }
+
+    @Override
+    public Iterable<V> values() {
+        return fortest.values();
+    }
+
+    @Override
+    public Iterator<Map.Entry<String, V>> iterator() {
+        return fortest.entrySet().iterator();
+    }
+
+    @Override
+    public boolean isEmpty() {
+        return fortest.isEmpty();
+    }
+
+    @Override
+    public void clear() {
+        fortest.clear();
+    }
+}
diff --git 
a/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryObjectTest.java
 
b/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryObjectTest.java
index c6942c9..970fe32 100644
--- 
a/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryObjectTest.java
+++ 
b/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryObjectTest.java
@@ -19,7 +19,6 @@ package org.apache.flink.agents.runtime.memory;
 
 import org.apache.flink.agents.api.context.MemoryObject;
 import org.apache.flink.agents.api.context.MemoryUpdate;
-import org.apache.flink.api.common.state.MapState;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 
@@ -62,7 +61,9 @@ public class MemoryObjectTest {
     void setUp() throws Exception {
         ForTestMemoryMapState<MemoryObjectImpl.MemoryItem> mapState = new 
ForTestMemoryMapState<>();
         memoryUpdates = new LinkedList<>();
-        memory = new MemoryObjectImpl(mapState, MemoryObjectImpl.ROOT_KEY, 
memoryUpdates);
+        memory =
+                new MemoryObjectImpl(
+                        new CachedMemoryStore(mapState), 
MemoryObjectImpl.ROOT_KEY, memoryUpdates);
     }
 
     @Test
@@ -180,64 +181,3 @@ public class MemoryObjectTest {
                         new MemoryUpdate("str.new_str.str", "world"));
     }
 }
-
-/** Simple, non-serialized HashMap implementation. */
-class ForTestMemoryMapState<V> implements MapState<String, V> {
-
-    private final Map<String, V> fortest = new HashMap<>();
-
-    @Override
-    public V get(String key) {
-        return fortest.get(key);
-    }
-
-    @Override
-    public void put(String key, V value) {
-        fortest.put(key, value);
-    }
-
-    @Override
-    public void putAll(Map<String, V> map) {
-        fortest.putAll(map);
-    }
-
-    @Override
-    public void remove(String key) {
-        fortest.remove(key);
-    }
-
-    @Override
-    public boolean contains(String key) {
-        return fortest.containsKey(key);
-    }
-
-    @Override
-    public Iterable<Map.Entry<String, V>> entries() {
-        return fortest.entrySet();
-    }
-
-    @Override
-    public Iterable<String> keys() {
-        return fortest.keySet();
-    }
-
-    @Override
-    public Iterable<V> values() {
-        return fortest.values();
-    }
-
-    @Override
-    public Iterator<Map.Entry<String, V>> iterator() {
-        return fortest.entrySet().iterator();
-    }
-
-    @Override
-    public boolean isEmpty() {
-        return fortest.isEmpty();
-    }
-
-    @Override
-    public void clear() {
-        fortest.clear();
-    }
-}
diff --git 
a/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryRefTest.java
 
b/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryRefTest.java
index a997a32..780784f 100644
--- 
a/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryRefTest.java
+++ 
b/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryRefTest.java
@@ -110,7 +110,11 @@ public class MemoryRefTest {
     @BeforeEach
     void setUp() throws Exception {
         ForTestMemoryMapState<MemoryObjectImpl.MemoryItem> mapState = new 
ForTestMemoryMapState<>();
-        memory = new MemoryObjectImpl(mapState, MemoryObjectImpl.ROOT_KEY, new 
LinkedList<>());
+        memory =
+                new MemoryObjectImpl(
+                        new CachedMemoryStore(mapState),
+                        MemoryObjectImpl.ROOT_KEY,
+                        new LinkedList<>());
     }
 
     @Test

Reply via email to