This is an automated email from the ASF dual-hosted git repository.

xtsong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git


The following commit(s) were added to refs/heads/main by this push:
     new d1b2601  [runtime] Avoid create new runner context each time execute action. (#392)
d1b2601 is described below

commit d1b2601a316ef4c33c3b3ae762f0fcac414ab26a
Author: Wenjin Xie <[email protected]>
AuthorDate: Thu Dec 25 10:06:29 2025 +0800

    [runtime] Avoid create new runner context each time execute action. (#392)
---
 .../flink_agents/runtime/flink_runner_context.py   | 29 ++++----
 .../agents/runtime/context/RunnerContextImpl.java  | 67 +++++++++++++------
 .../runtime/operator/ActionExecutionOperator.java  | 77 +++++++++++++++-------
 .../flink/agents/runtime/operator/ActionTask.java  |  4 +-
 .../agents/runtime/operator/JavaActionTask.java    |  4 +-
 .../python/context/PythonRunnerContextImpl.java    | 22 +------
 .../runtime/python/operator/PythonActionTask.java  | 19 ++----
 .../python/operator/PythonGeneratorActionTask.java |  5 +-
 .../runtime/python/utils/PythonActionExecutor.java | 45 ++++++++-----
 9 files changed, 156 insertions(+), 116 deletions(-)

diff --git a/python/flink_agents/runtime/flink_runner_context.py 
b/python/flink_agents/runtime/flink_runner_context.py
index 4319074..75ba9c2 100644
--- a/python/flink_agents/runtime/flink_runner_context.py
+++ b/python/flink_agents/runtime/flink_runner_context.py
@@ -219,13 +219,23 @@ def create_flink_runner_context(
     agent_plan_json: str,
     executor: ThreadPoolExecutor,
     j_resource_adapter: Any,
-    job_identifier: str,
-    key: int,
 ) -> FlinkRunnerContext:
     """Used to create a FlinkRunnerContext Python object in Pemja 
environment."""
-    ctx = FlinkRunnerContext(
+    return FlinkRunnerContext(
         j_runner_context, agent_plan_json, executor, j_resource_adapter
     )
+
+
+def flink_runner_context_switch_action_context(
+    ctx: FlinkRunnerContext,
+    job_identifier: str,
+    key: int,
+) -> None:
+    """Switch the context of the flink runner context.
+
+    The ctx is reused across keyed partitions, the context related to
+    specific key should be switched when process new action.
+    """
     backend = ctx.config.get(LongTermMemoryOptions.BACKEND)
     # use external vector store based long term memory
     if backend == LongTermMemoryBackend.EXTERNAL_VECTOR_STORE:
@@ -240,19 +250,6 @@ def create_flink_runner_context(
                 key=str(key),
             )
         )
-    return ctx
-
-
-def create_long_term_memory(
-    j_runner_context: Any,
-    agent_plan_json: str,
-    executor: ThreadPoolExecutor,
-    j_resource_adapter: Any,
-) -> FlinkRunnerContext:
-    """Used to create a FlinkRunnerContext Python object in Pemja 
environment."""
-    return FlinkRunnerContext(
-        j_runner_context, agent_plan_json, executor, j_resource_adapter
-    )
 
 
 def create_async_thread_pool() -> ThreadPoolExecutor:
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java
index 6321d98..743b4aa 100644
--- 
a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java
@@ -42,34 +42,61 @@ import java.util.Map;
  * actions.
  */
 public class RunnerContextImpl implements RunnerContext {
+    /**
+     * Memory state that a {@link RunnerContextImpl} reads and writes while running one action.
+     *
+     * <p>Bundles the sensory and short-term memory stores with the lists that record the memory
+     * updates made through them. The runner context holds no memory state of its own; a
+     * {@code MemoryContext} is installed via {@link RunnerContextImpl#switchActionContext} before
+     * each action runs, which lets a single runner context instance be reused across actions and
+     * keys.
+     */
+    public static class MemoryContext {
+        private final CachedMemoryStore sensoryMemStore;
+        private final CachedMemoryStore shortTermMemStore;
+        // Live update lists: the runner context hands these (not copies) to MemoryObjectImpl.
+        private final List<MemoryUpdate> sensoryMemoryUpdates;
+        private final List<MemoryUpdate> shortTermMemoryUpdates;
+
+        /**
+         * Creates a memory context over the given stores, with empty update lists.
+         *
+         * @param sensoryMemStore store backing sensory memory; must not be null
+         * @param shortTermMemStore store backing short-term memory; must not be null
+         * @throws NullPointerException if either store is null
+         */
+        public MemoryContext(
+                CachedMemoryStore sensoryMemStore, CachedMemoryStore shortTermMemStore) {
+            // Fail fast here instead of on the first memory access or persistMemory() call.
+            this.sensoryMemStore =
+                    java.util.Objects.requireNonNull(sensoryMemStore, "sensoryMemStore");
+            this.shortTermMemStore =
+                    java.util.Objects.requireNonNull(shortTermMemStore, "shortTermMemStore");
+            this.sensoryMemoryUpdates = new LinkedList<>();
+            this.shortTermMemoryUpdates = new LinkedList<>();
+        }
+
+        /** Returns the live, mutable list of short-term memory updates (not a copy). */
+        public List<MemoryUpdate> getShortTermMemoryUpdates() {
+            return shortTermMemoryUpdates;
+        }
+
+        /** Returns the live, mutable list of sensory memory updates (not a copy). */
+        public List<MemoryUpdate> getSensoryMemoryUpdates() {
+            return sensoryMemoryUpdates;
+        }
+
+        /** Returns the store backing short-term memory. */
+        public CachedMemoryStore getShortTermMemStore() {
+            return shortTermMemStore;
+        }
+
+        /** Returns the store backing sensory memory. */
+        public CachedMemoryStore getSensoryMemStore() {
+            return sensoryMemStore;
+        }
+    }
 
     protected final List<Event> pendingEvents = new ArrayList<>();
-    protected final CachedMemoryStore sensoryMemStore;
-    protected final CachedMemoryStore shortTermMemStore;
     protected final FlinkAgentsMetricGroupImpl agentMetricGroup;
     protected final Runnable mailboxThreadChecker;
     protected final AgentPlan agentPlan;
-    protected final List<MemoryUpdate> sensoryMemoryUpdates;
-    protected final List<MemoryUpdate> shortTermMemoryUpdates;
+
+    protected MemoryContext memoryContext;
     protected String actionName;
 
     public RunnerContextImpl(
-            CachedMemoryStore sensoryMemStore,
-            CachedMemoryStore shortTermMemStore,
             FlinkAgentsMetricGroupImpl agentMetricGroup,
             Runnable mailboxThreadChecker,
             AgentPlan agentPlan) {
-        this.sensoryMemStore = sensoryMemStore;
-        this.shortTermMemStore = shortTermMemStore;
         this.agentMetricGroup = agentMetricGroup;
         this.mailboxThreadChecker = mailboxThreadChecker;
         this.agentPlan = agentPlan;
-        this.sensoryMemoryUpdates = new LinkedList<>();
-        this.shortTermMemoryUpdates = new LinkedList<>();
     }
 
-    public void setActionName(String actionName) {
+    public void switchActionContext(String actionName, MemoryContext 
memoryContext) {
         this.actionName = actionName;
+        this.memoryContext = memoryContext;
+    }
+
+    /**
+     * Returns the memory context installed by the most recent {@link #switchActionContext} call.
+     *
+     * <p>The constructor does not set a memory context, so this returns {@code null} until
+     * {@link #switchActionContext} has been called at least once.
+     */
+    public MemoryContext getMemoryContext() {
+        return memoryContext;
     }
 
     @Override
@@ -112,7 +139,7 @@ public class RunnerContextImpl implements RunnerContext {
 
     public List<MemoryUpdate> getSensoryMemoryUpdates() {
         mailboxThreadChecker.run();
-        return List.copyOf(sensoryMemoryUpdates);
+        return List.copyOf(memoryContext.getSensoryMemoryUpdates());
     }
 
     /**
@@ -124,7 +151,7 @@ public class RunnerContextImpl implements RunnerContext {
      */
     public List<MemoryUpdate> getShortTermMemoryUpdates() {
         mailboxThreadChecker.run();
-        return List.copyOf(shortTermMemoryUpdates);
+        return List.copyOf(memoryContext.getShortTermMemoryUpdates());
     }
 
     @Override
@@ -132,10 +159,10 @@ public class RunnerContextImpl implements RunnerContext {
         mailboxThreadChecker.run();
         return new MemoryObjectImpl(
                 MemoryObject.MemoryType.SENSORY,
-                sensoryMemStore,
+                memoryContext.getSensoryMemStore(),
                 MemoryObjectImpl.ROOT_KEY,
                 mailboxThreadChecker,
-                sensoryMemoryUpdates);
+                memoryContext.getSensoryMemoryUpdates());
     }
 
     @Override
@@ -143,10 +170,10 @@ public class RunnerContextImpl implements RunnerContext {
         mailboxThreadChecker.run();
         return new MemoryObjectImpl(
                 MemoryObject.MemoryType.SHORT_TERM,
-                shortTermMemStore,
+                memoryContext.getShortTermMemStore(),
                 MemoryObjectImpl.ROOT_KEY,
                 mailboxThreadChecker,
-                shortTermMemoryUpdates);
+                memoryContext.getShortTermMemoryUpdates());
     }
 
     @Override
@@ -177,11 +204,11 @@ public class RunnerContextImpl implements RunnerContext {
     }
 
     public void persistMemory() throws Exception {
-        sensoryMemStore.persistCache();
-        shortTermMemStore.persistCache();
+        memoryContext.getSensoryMemStore().persistCache();
+        memoryContext.getShortTermMemStore().persistCache();
     }
 
     public void clearSensoryMemory() throws Exception {
-        sensoryMemStore.clear();
+        memoryContext.getSensoryMemStore().clear();
     }
 }
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
index 98ff144..d827eb6 100644
--- 
a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
@@ -133,6 +133,9 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
     // PythonActionExecutor for Python actions
     private transient PythonActionExecutor pythonActionExecutor;
 
+    // RunnerContext for Python actions
+    private transient PythonRunnerContextImpl pythonRunnerContext;
+
     // PythonResourceAdapter for Python resources in Java actions
     private transient PythonResourceAdapterImpl pythonResourceAdapter;
 
@@ -144,6 +147,9 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
 
     private final transient MailboxExecutor mailboxExecutor;
 
+    // RunnerContext for Java Actions
+    private transient RunnerContextImpl runnerContext;
+
     // We need to check whether the current thread is the mailbox thread using 
the mailbox
     // processor.
     // TODO: This is a temporary workaround. In the future, we should add an 
interface in
@@ -174,7 +180,8 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
 
     // This in memory map keep track of the runner context for the async 
action task that having
     // been finished
-    private final transient Map<ActionTask, RunnerContextImpl> 
actionTaskRunnerContexts;
+    private final transient Map<ActionTask, RunnerContextImpl.MemoryContext>
+            actionTaskMemoryContexts;
 
     // Each job can only have one identifier and this identifier must be 
consistent across restarts.
     // We cannot use job id as the identifier here because user may change job 
id by
@@ -198,7 +205,7 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
         this.eventListeners = new ArrayList<>();
         this.actionStateStore = actionStateStore;
         this.checkpointIdToSeqNums = new HashMap<>();
-        this.actionTaskRunnerContexts = new HashMap<>();
+        this.actionTaskMemoryContexts = new HashMap<>();
     }
 
     @Override
@@ -443,12 +450,14 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
         } else {
             maybeInitActionState(key, sequenceNumber, actionTask.action, 
actionTask.event);
             ActionTask.ActionTaskResult actionTaskResult =
-                    
actionTask.invoke(getRuntimeContext().getUserCodeClassLoader());
+                    actionTask.invoke(
+                            getRuntimeContext().getUserCodeClassLoader(),
+                            this.pythonActionExecutor);
 
             // We remove the RunnerContext of the action task from the map 
after it is finished. The
             // RunnerContext will be added later if the action task has a 
generated action task,
             // meaning it is not finished.
-            actionTaskRunnerContexts.remove(actionTask);
+            actionTaskMemoryContexts.remove(actionTask);
             maybePersistTaskResult(
                     key,
                     sequenceNumber,
@@ -483,7 +492,8 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
 
             // If the action task is not finished, we keep the runner context 
in the memory for the
             // next generated ActionTask to be invoked.
-            actionTaskRunnerContexts.put(generatedActionTask, 
actionTask.getRunnerContext());
+            actionTaskMemoryContexts.put(
+                    generatedActionTask, 
actionTask.getRunnerContext().getMemoryContext());
 
             actionTasksKState.add(generatedActionTask);
         }
@@ -552,6 +562,9 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
             pythonEnvironmentManager.open();
             EmbeddedPythonEnvironment env = 
pythonEnvironmentManager.createEnvironment();
             pythonInterpreter = env.getInterpreter();
+            pythonRunnerContext =
+                    new PythonRunnerContextImpl(
+                            this.metricGroup, this::checkMailboxThread, 
this.agentPlan);
             if (containPythonAction) {
                 initPythonActionExecutor();
             } else {
@@ -568,6 +581,7 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
                         pythonInterpreter,
                         new ObjectMapper().writeValueAsString(agentPlan),
                         javaResourceAdapter,
+                        pythonRunnerContext,
                         jobIdentifier);
         pythonActionExecutor.open();
     }
@@ -752,31 +766,28 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
         }
 
         RunnerContextImpl runnerContext;
-        if (actionTaskRunnerContexts.containsKey(actionTask)) {
-            runnerContext = actionTaskRunnerContexts.get(actionTask);
-        } else if (actionTask.action.getExec() instanceof JavaFunction) {
-            runnerContext =
-                    new RunnerContextImpl(
-                            new CachedMemoryStore(sensoryMemState),
-                            new CachedMemoryStore(shortTermMemState),
-                            metricGroup,
-                            this::checkMailboxThread,
-                            agentPlan);
+        if (actionTask.action.getExec() instanceof JavaFunction) {
+            runnerContext = createOrGetRunnerContext(true);
         } else if (actionTask.action.getExec() instanceof PythonFunction) {
-            runnerContext =
-                    new PythonRunnerContextImpl(
-                            new CachedMemoryStore(sensoryMemState),
-                            new CachedMemoryStore(shortTermMemState),
-                            metricGroup,
-                            this::checkMailboxThread,
-                            agentPlan,
-                            pythonActionExecutor);
+            runnerContext = createOrGetRunnerContext(false);
         } else {
             throw new IllegalStateException(
                     "Unsupported action type: " + 
actionTask.action.getExec().getClass());
         }
 
-        runnerContext.setActionName(actionTask.action.getName());
+        RunnerContextImpl.MemoryContext memoryContext;
+        if (actionTaskMemoryContexts.containsKey(actionTask)) {
+            // action task for async execution action, should retrieve 
intermediate results from
+            // map.
+            memoryContext = actionTaskMemoryContexts.get(actionTask);
+        } else {
+            memoryContext =
+                    new RunnerContextImpl.MemoryContext(
+                            new CachedMemoryStore(sensoryMemState),
+                            new CachedMemoryStore(shortTermMemState));
+        }
+
+        runnerContext.switchActionContext(actionTask.action.getName(), 
memoryContext);
         actionTask.setRunnerContext(runnerContext);
     }
 
@@ -883,6 +894,24 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
         }
     }
 
+    /**
+     * Returns the runner context shared by all actions of one kind, creating it on first use.
+     *
+     * <p>One instance is reused for every Java action and one for every Python action; the
+     * per-action state (action name and memory) is swapped in afterwards via {@link
+     * RunnerContextImpl#switchActionContext}.
+     *
+     * @param isJava {@code true} for a Java action, {@code false} for a Python action
+     * @return the cached runner context for that kind of action
+     */
+    private RunnerContextImpl createOrGetRunnerContext(boolean isJava) {
+        if (isJava) {
+            if (runnerContext == null) {
+                runnerContext =
+                        new RunnerContextImpl(
+                                this.metricGroup, this::checkMailboxThread, this.agentPlan);
+            }
+            return runnerContext;
+        }
+        // The Python context is normally created eagerly during Python environment setup, because
+        // the PythonActionExecutor receives it at construction time. If this branch ever creates
+        // the instance, it will NOT be the one that executor was built with.
+        if (pythonRunnerContext == null) {
+            pythonRunnerContext =
+                    new PythonRunnerContextImpl(
+                            this.metricGroup, this::checkMailboxThread, this.agentPlan);
+        }
+        return pythonRunnerContext;
+    }
+
     /** Failed to execute Action task. */
     public static class ActionTaskExecutionException extends Exception {
         public ActionTaskExecutionException(String message, Throwable cause) {
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTask.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTask.java
index 34f7850..053b9d9 100644
--- 
a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTask.java
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTask.java
@@ -20,6 +20,7 @@ package org.apache.flink.agents.runtime.operator;
 import org.apache.flink.agents.api.Event;
 import org.apache.flink.agents.plan.actions.Action;
 import org.apache.flink.agents.runtime.context.RunnerContextImpl;
+import org.apache.flink.agents.runtime.python.utils.PythonActionExecutor;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -87,7 +88,8 @@ public abstract class ActionTask {
     }
 
     /** Invokes the action task. */
-    public abstract ActionTaskResult invoke(ClassLoader userCodeClassLoader) 
throws Exception;
+    public abstract ActionTaskResult invoke(
+            ClassLoader userCodeClassLoader, PythonActionExecutor executor) 
throws Exception;
 
     public class ActionTaskResult {
         private final boolean finished;
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/JavaActionTask.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/JavaActionTask.java
index 9fc641b..65d8ef4 100644
--- 
a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/JavaActionTask.java
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/JavaActionTask.java
@@ -20,6 +20,7 @@ package org.apache.flink.agents.runtime.operator;
 import org.apache.flink.agents.api.Event;
 import org.apache.flink.agents.plan.JavaFunction;
 import org.apache.flink.agents.plan.actions.Action;
+import org.apache.flink.agents.runtime.python.utils.PythonActionExecutor;
 
 import static org.apache.flink.util.Preconditions.checkState;
 
@@ -36,7 +37,8 @@ public class JavaActionTask extends ActionTask {
     }
 
     @Override
-    public ActionTaskResult invoke(ClassLoader userCodeClassLoader) throws 
Exception {
+    public ActionTaskResult invoke(ClassLoader userCodeClassLoader, 
PythonActionExecutor executor)
+            throws Exception {
         LOG.debug(
                 "Try execute java action {} for event {} with key {}.",
                 action.getName(),
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java
index 26d5a79..690d412 100644
--- 
a/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java
@@ -21,10 +21,8 @@ import org.apache.flink.agents.api.Event;
 import org.apache.flink.agents.api.context.RunnerContext;
 import org.apache.flink.agents.plan.AgentPlan;
 import org.apache.flink.agents.runtime.context.RunnerContextImpl;
-import org.apache.flink.agents.runtime.memory.CachedMemoryStore;
 import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl;
 import org.apache.flink.agents.runtime.python.event.PythonEvent;
-import org.apache.flink.agents.runtime.python.utils.PythonActionExecutor;
 import org.apache.flink.util.Preconditions;
 
 import javax.annotation.concurrent.NotThreadSafe;
@@ -32,23 +30,11 @@ import javax.annotation.concurrent.NotThreadSafe;
 /** A specialized {@link RunnerContext} that is specifically used when 
executing Python actions. */
 @NotThreadSafe
 public class PythonRunnerContextImpl extends RunnerContextImpl {
-
-    private final PythonActionExecutor pythonActionExecutor;
-
     public PythonRunnerContextImpl(
-            CachedMemoryStore sensoryMemStore,
-            CachedMemoryStore shortTermMemStore,
             FlinkAgentsMetricGroupImpl agentMetricGroup,
             Runnable mailboxThreadChecker,
-            AgentPlan agentPlan,
-            PythonActionExecutor pythonActionExecutor) {
-        super(
-                sensoryMemStore,
-                shortTermMemStore,
-                agentMetricGroup,
-                mailboxThreadChecker,
-                agentPlan);
-        this.pythonActionExecutor = pythonActionExecutor;
+            AgentPlan agentPlan) {
+        super(agentMetricGroup, mailboxThreadChecker, agentPlan);
     }
 
     @Override
@@ -62,8 +48,4 @@ public class PythonRunnerContextImpl extends 
RunnerContextImpl {
         // this method will be invoked by PythonActionExecutor's python 
interpreter.
         sendEvent(new PythonEvent(event, type, eventString));
     }
-
-    public PythonActionExecutor getPythonActionExecutor() {
-        return pythonActionExecutor;
-    }
 }
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/python/operator/PythonActionTask.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/python/operator/PythonActionTask.java
index a03c4c8..8f1e2d5 100644
--- 
a/runtime/src/main/java/org/apache/flink/agents/runtime/python/operator/PythonActionTask.java
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/python/operator/PythonActionTask.java
@@ -21,7 +21,6 @@ import org.apache.flink.agents.api.Event;
 import org.apache.flink.agents.plan.PythonFunction;
 import org.apache.flink.agents.plan.actions.Action;
 import org.apache.flink.agents.runtime.operator.ActionTask;
-import org.apache.flink.agents.runtime.python.context.PythonRunnerContextImpl;
 import org.apache.flink.agents.runtime.python.event.PythonEvent;
 import org.apache.flink.agents.runtime.python.utils.PythonActionExecutor;
 
@@ -43,7 +42,8 @@ public class PythonActionTask extends ActionTask {
                 "Python action only accept python event, but got " + event);
     }
 
-    public ActionTaskResult invoke(ClassLoader userCodeClassLoader) throws 
Exception {
+    public ActionTaskResult invoke(ClassLoader userCodeClassLoader, 
PythonActionExecutor executor)
+            throws Exception {
         LOG.debug(
                 "Try execute python action {} for event {} with key {}.",
                 action.getName(),
@@ -51,13 +51,9 @@ public class PythonActionTask extends ActionTask {
                 key);
         runnerContext.checkNoPendingEvents();
 
-        PythonActionExecutor pythonActionExecutor = getPythonActionExecutor();
         String pythonGeneratorRef =
-                pythonActionExecutor.executePythonFunction(
-                        (PythonFunction) action.getExec(),
-                        (PythonEvent) event,
-                        runnerContext,
-                        key.hashCode());
+                executor.executePythonFunction(
+                        (PythonFunction) action.getExec(), (PythonEvent) 
event, key.hashCode());
         // If a user-defined action uses an interface to submit asynchronous 
tasks, it will return a
         // Python generator object instance upon its first execution. 
Otherwise, it means that no
         // asynchronous tasks were submitted and the action has already 
completed.
@@ -67,14 +63,9 @@ public class PythonActionTask extends ActionTask {
             ActionTask tempGeneratedActionTask =
                     new PythonGeneratorActionTask(key, event, action, 
pythonGeneratorRef);
             tempGeneratedActionTask.setRunnerContext(runnerContext);
-            return tempGeneratedActionTask.invoke(userCodeClassLoader);
+            return tempGeneratedActionTask.invoke(userCodeClassLoader, 
executor);
         }
         return new ActionTaskResult(
                 true, runnerContext.drainEvents(event.getSourceTimestamp()), 
null);
     }
-
-    protected PythonActionExecutor getPythonActionExecutor() {
-        checkState(runnerContext != null && runnerContext instanceof 
PythonRunnerContextImpl);
-        return ((PythonRunnerContextImpl) 
runnerContext).getPythonActionExecutor();
-    }
 }
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/python/operator/PythonGeneratorActionTask.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/python/operator/PythonGeneratorActionTask.java
index 96afa19..969cb8f 100644
--- 
a/runtime/src/main/java/org/apache/flink/agents/runtime/python/operator/PythonGeneratorActionTask.java
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/python/operator/PythonGeneratorActionTask.java
@@ -20,6 +20,7 @@ package org.apache.flink.agents.runtime.python.operator;
 import org.apache.flink.agents.api.Event;
 import org.apache.flink.agents.plan.actions.Action;
 import org.apache.flink.agents.runtime.operator.ActionTask;
+import org.apache.flink.agents.runtime.python.utils.PythonActionExecutor;
 
 /** An {@link ActionTask} wrapper a Python Generator to represent a code block 
in Python action. */
 public class PythonGeneratorActionTask extends PythonActionTask {
@@ -32,13 +33,13 @@ public class PythonGeneratorActionTask extends 
PythonActionTask {
     }
 
     @Override
-    public ActionTaskResult invoke(ClassLoader userCodeClassLoader) {
+    public ActionTaskResult invoke(ClassLoader userCodeClassLoader, 
PythonActionExecutor executor) {
         LOG.debug(
                 "Try execute python generator action {} for event {} with key 
{}.",
                 action.getName(),
                 event,
                 key);
-        boolean finished = 
getPythonActionExecutor().callPythonGenerator(pythonGeneratorRef);
+        boolean finished = executor.callPythonGenerator(pythonGeneratorRef);
         ActionTask generatedActionTask = finished ? null : this;
         return new ActionTaskResult(
                 finished,
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java
index 4c08789..e0108fb 100644
--- 
a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java
@@ -18,10 +18,11 @@
 package org.apache.flink.agents.runtime.python.utils;
 
 import org.apache.flink.agents.plan.PythonFunction;
-import org.apache.flink.agents.runtime.context.RunnerContextImpl;
+import org.apache.flink.agents.runtime.python.context.PythonRunnerContextImpl;
 import org.apache.flink.agents.runtime.python.event.PythonEvent;
 import org.apache.flink.agents.runtime.utils.EventUtil;
 import pemja.core.PythonInterpreter;
+import pemja.core.object.PyObject;
 
 import java.util.concurrent.atomic.AtomicLong;
 
@@ -39,6 +40,9 @@ public class PythonActionExecutor {
     private static final String CREATE_FLINK_RUNNER_CONTEXT =
             "flink_runner_context.create_flink_runner_context";
 
+    private static final String FLINK_RUNNER_CONTEXT_SWITCH_ACTION_CONTEXT =
+            "flink_runner_context.flink_runner_context_switch_action_context";
+
     // ========== ASYNC THREAD POOL ===========
     private static final String CREATE_ASYNC_THREAD_POOL =
             "flink_runner_context.create_async_thread_pool";
@@ -59,17 +63,21 @@ public class PythonActionExecutor {
 
     private final PythonInterpreter interpreter;
     private final String agentPlanJson;
+    private final PythonRunnerContextImpl runnerContext;
     private final JavaResourceAdapter javaResourceAdapter;
     private final String jobIdentifier;
-    private Object pythonAsyncThreadPool;
+    private PyObject pythonAsyncThreadPool;
+    private PyObject pythonRunnerContext;
 
+    /**
+     * Creates an executor bound to one Pemja interpreter and one reused Java runner context.
+     *
+     * <p>The constructor only stores its arguments. {@link #open()} imports the Python runtime
+     * modules and creates the Python-side runner context that wraps {@code runnerContext}.
+     *
+     * @param interpreter the Pemja interpreter on which all Python calls are made
+     * @param agentPlanJson the agent plan serialized as JSON, passed to the Python runner context
+     * @param javaResourceAdapter adapter passed to the Python runner context in {@link #open()}
+     * @param runnerContext the Java runner context shared by every Python action this executor runs
+     * @param jobIdentifier job identifier forwarded to the Python side on each action invocation
+     */
     public PythonActionExecutor(
             PythonInterpreter interpreter,
             String agentPlanJson,
             JavaResourceAdapter javaResourceAdapter,
+            PythonRunnerContextImpl runnerContext,
             String jobIdentifier) {
         this.interpreter = interpreter;
         this.agentPlanJson = agentPlanJson;
+        this.runnerContext = runnerContext;
         this.javaResourceAdapter = javaResourceAdapter;
         this.jobIdentifier = jobIdentifier;
     }
@@ -77,7 +85,16 @@ public class PythonActionExecutor {
     public void open() throws Exception {
         interpreter.exec(PYTHON_IMPORTS);
 
-        pythonAsyncThreadPool = interpreter.invoke(CREATE_ASYNC_THREAD_POOL);
+        pythonAsyncThreadPool = (PyObject) 
interpreter.invoke(CREATE_ASYNC_THREAD_POOL);
+
+        pythonRunnerContext =
+                (PyObject)
+                        interpreter.invoke(
+                                CREATE_FLINK_RUNNER_CONTEXT,
+                                runnerContext,
+                                agentPlanJson,
+                                pythonAsyncThreadPool,
+                                javaResourceAdapter);
     }
 
     /**
@@ -90,29 +107,21 @@ public class PythonActionExecutor {
      * @return The name of the Python generator variable. It may be null if 
the Python function does
      *     not return a generator.
      */
-    public String executePythonFunction(
-            PythonFunction function,
-            PythonEvent event,
-            RunnerContextImpl runnerContext,
-            int hashOfKey)
+    public String executePythonFunction(PythonFunction function, PythonEvent 
event, int hashOfKey)
             throws Exception {
         runnerContext.checkNoPendingEvents();
         function.setInterpreter(interpreter);
 
-        Object pythonRunnerContextObject =
-                interpreter.invoke(
-                        CREATE_FLINK_RUNNER_CONTEXT,
-                        runnerContext,
-                        agentPlanJson,
-                        pythonAsyncThreadPool,
-                        javaResourceAdapter,
-                        jobIdentifier,
-                        hashOfKey);
+        interpreter.invoke(
+                FLINK_RUNNER_CONTEXT_SWITCH_ACTION_CONTEXT,
+                pythonRunnerContext,
+                jobIdentifier,
+                hashOfKey);
 
         Object pythonEventObject = 
interpreter.invoke(CONVERT_TO_PYTHON_OBJECT, event.getEvent());
 
         try {
-            Object calledResult = function.call(pythonEventObject, 
pythonRunnerContextObject);
+            Object calledResult = function.call(pythonEventObject, 
pythonRunnerContext);
             if (calledResult == null) {
                 return null;
             } else {

Reply via email to