This is an automated email from the ASF dual-hosted git repository. sxnan pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit b5338b4d0c2f46f92aa37ceb96729aef611c88a3 Author: sxnan <[email protected]> AuthorDate: Thu Jan 8 17:59:26 2026 +0800 [api] Introduce execute to runner_context.py --- python/flink_agents/api/runner_context.py | 70 ++++++- .../e2e_tests_integration/execute_test.py | 205 +++++++++++++++++++++ .../e2e_tests_integration/execute_test_agent.py | 118 ++++++++++++ .../flink_integration_agent.py | 2 +- .../e2e_tests/long_term_memory_test.py | 4 +- .../resources/execute_test_input/input_data.txt | 3 + .../resources/ground_truth/test_execute_basic.txt | 3 + .../ground_truth/test_execute_multiple.txt | 3 + .../ground_truth/test_execute_with_async.txt | 3 + python/flink_agents/plan/function.py | 2 +- .../flink_agents/runtime/flink_runner_context.py | 17 ++ python/flink_agents/runtime/local_runner.py | 25 ++- .../tests/test_local_execution_environment.py | 2 +- .../runtime/tests/test_runner_context_execute.py | 190 +++++++++++++++++++ 14 files changed, 629 insertions(+), 18 deletions(-) diff --git a/python/flink_agents/api/runner_context.py b/python/flink_agents/api/runner_context.py index 5461e7fc..8b44efeb 100644 --- a/python/flink_agents/api/runner_context.py +++ b/python/flink_agents/api/runner_context.py @@ -32,9 +32,9 @@ class AsyncExecutionResult: """This class wraps an asynchronous task that will be submitted to a thread pool only when awaited. This ensures lazy submission and serial execution semantics. - Note: Only `await ctx.execute_async(...)` is supported. asyncio functions like - `asyncio.gather`, `asyncio.wait`, `asyncio.create_task`, and `asyncio.sleep` - are NOT supported because there is no asyncio event loop. + Note: Only `await ctx.durable_execute_async(...)` is supported. asyncio + functions like `asyncio.gather`, `asyncio.wait`, `asyncio.create_task`, + and `asyncio.sleep` are NOT supported because there is no asyncio event loop. """ def __init__(self, executor: Any, func: Callable, args: tuple, kwargs: dict) -> None: @@ -187,24 +187,74 @@ class RunnerContext(ABC): """ @abstractmethod - def execute_async( + def durable_execute( + self, + func: Callable[[Any], Any], + *args: Any, + **kwargs: Any, + ) -> Any: + """Synchronously execute the provided function with durable execution support. + Access to memory is prohibited within the function. + + The result of the function will be stored and returned when the same + durable_execute call is made again during job recovery. The arguments and the + result must be serializable. + + The function is executed synchronously in the current thread, blocking + the operator until completion. + + The action that calls this API should be deterministic, meaning that it + will always make the durable_execute call with the same arguments and in the + same order during job recovery. Otherwise, the behavior is undefined. + + Usage:: + + def my_action(event, ctx): + result = ctx.durable_execute(slow_function, arg1, arg2) + ctx.send_event(OutputEvent(output=result)) + + Parameters + ---------- + func : Callable + The function to be executed. + *args : Any + Positional arguments to pass to the function. + **kwargs : Any + Keyword arguments to pass to the function. + + Returns: + ------- + Any + The result of the function. + """ + + @abstractmethod + def durable_execute_async( self, func: Callable[[Any], Any], *args: Any, **kwargs: Any, ) -> "AsyncExecutionResult": - """Asynchronously execute the provided function. Access to memory - is prohibited within the function. + """Asynchronously execute the provided function with durable execution support. + Access to memory is prohibited within the function. + + The result of the function will be stored and returned when the same + durable_execute_async call is made again during job recovery. The arguments + and the result must be serializable. + + The action that calls this API should be deterministic, meaning that it + will always make the durable_execute_async call with the same arguments and in + the same order during job recovery. Otherwise, the behavior is undefined. Usage:: async def my_action(event, ctx): - result = await ctx.execute_async(slow_function, arg1, arg2) + result = await ctx.durable_execute_async(slow_function, arg1, arg2) ctx.send_event(OutputEvent(output=result)) - Note: Only `await ctx.execute_async(...)` is supported. asyncio functions - like `asyncio.gather`, `asyncio.wait`, `asyncio.create_task`, and - `asyncio.sleep` are NOT supported. + Note: Only `await ctx.durable_execute_async(...)` is supported. + asyncio functions like `asyncio.gather`, `asyncio.wait`, + `asyncio.create_task`, and `asyncio.sleep` are NOT supported. Parameters ---------- diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/execute_test.py b/python/flink_agents/e2e_tests/e2e_tests_integration/execute_test.py new file mode 100644 index 00000000..8118f15b --- /dev/null +++ b/python/flink_agents/e2e_tests/e2e_tests_integration/execute_test.py @@ -0,0 +1,205 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +"""E2E tests for durable_execute() method in Flink execution environment.""" + +import os +import sysconfig +from pathlib import Path + +from pyflink.common import Configuration, Encoder, WatermarkStrategy +from pyflink.common.typeinfo import Types +from pyflink.datastream import ( + RuntimeExecutionMode, + StreamExecutionEnvironment, +) +from pyflink.datastream.connectors.file_system import ( + FileSource, + StreamFormat, + StreamingFileSink, +) + +from flink_agents.api.execution_environment import AgentsExecutionEnvironment +from flink_agents.e2e_tests.e2e_tests_integration.execute_test_agent import ( + ExecuteMultipleTestAgent, + ExecuteTestAgent, + ExecuteTestData, + ExecuteTestKeySelector, + ExecuteWithAsyncTestAgent, +) +from flink_agents.e2e_tests.test_utils import check_result + +current_dir = Path(__file__).parent + +os.environ["PYTHONPATH"] = sysconfig.get_paths()["purelib"] + + +def test_durable_execute_basic_flink(tmp_path: Path) -> None: + """Test basic synchronous durable_execute() functionality in Flink environment.""" + config = Configuration() + config.set_string("state.backend.type", "rocksdb") + config.set_string("checkpointing.interval", "1s") + config.set_string("restart-strategy.type", "disable") + env = StreamExecutionEnvironment.get_execution_environment(config) + env.set_runtime_mode(RuntimeExecutionMode.STREAMING) + env.set_parallelism(1) + + input_datastream = env.from_source( + source=FileSource.for_record_stream_format( + StreamFormat.text_line_format(), + f"file:///{current_dir}/../resources/execute_test_input", + ).build(), + watermark_strategy=WatermarkStrategy.no_watermarks(), + source_name="execute_test_source", + ) + + deserialize_datastream = input_datastream.map( + lambda x: ExecuteTestData.model_validate_json(x) + ) + + agents_env = AgentsExecutionEnvironment.get_execution_environment(env=env) + output_datastream = ( + agents_env.from_datastream( + input=deserialize_datastream, key_selector=ExecuteTestKeySelector() + ) + .apply(ExecuteTestAgent()) + .to_datastream() + ) + + result_dir = tmp_path / "results" + result_dir.mkdir(parents=True, exist_ok=True) + + output_datastream.map(lambda x: x.model_dump_json(), Types.STRING()).add_sink( + StreamingFileSink.for_row_format( + base_path=str(result_dir.absolute()), + encoder=Encoder.simple_string_encoder(), + ).build() + ) + + agents_env.execute() + + check_result( + result_dir=result_dir, + groud_truth_dir=Path( + f"{current_dir}/../resources/ground_truth/test_execute_basic.txt" + ), + ) + + +def test_durable_execute_multiple_calls_flink(tmp_path: Path) -> None: + """Test multiple durable_execute() calls in Flink environment.""" + config = Configuration() + config.set_string("state.backend.type", "rocksdb") + config.set_string("checkpointing.interval", "1s") + config.set_string("restart-strategy.type", "disable") + env = StreamExecutionEnvironment.get_execution_environment(config) + env.set_runtime_mode(RuntimeExecutionMode.STREAMING) + env.set_parallelism(1) + + input_datastream = env.from_source( + source=FileSource.for_record_stream_format( + StreamFormat.text_line_format(), + f"file:///{current_dir}/../resources/execute_test_input", + ).build(), + watermark_strategy=WatermarkStrategy.no_watermarks(), + source_name="execute_test_source", + ) + + deserialize_datastream = input_datastream.map( + lambda x: ExecuteTestData.model_validate_json(x) + ) + + agents_env = AgentsExecutionEnvironment.get_execution_environment(env=env) + output_datastream = ( + agents_env.from_datastream( + input=deserialize_datastream, key_selector=ExecuteTestKeySelector() + ) + .apply(ExecuteMultipleTestAgent()) + .to_datastream() + ) + + result_dir = tmp_path / "results" + result_dir.mkdir(parents=True, exist_ok=True) + + output_datastream.map(lambda x: x.model_dump_json(), Types.STRING()).add_sink( + StreamingFileSink.for_row_format( + base_path=str(result_dir.absolute()), + encoder=Encoder.simple_string_encoder(), + ).build() + ) + + agents_env.execute() + + check_result( + result_dir=result_dir, + groud_truth_dir=Path( + f"{current_dir}/../resources/ground_truth/test_execute_multiple.txt" + ), + ) + + +def test_durable_execute_with_async_flink(tmp_path: Path) -> None: + """Test durable_execute() and durable_execute_async() in Flink environment.""" + config = Configuration() + config.set_string("state.backend.type", "rocksdb") + config.set_string("checkpointing.interval", "1s") + config.set_string("restart-strategy.type", "disable") + env = StreamExecutionEnvironment.get_execution_environment(config) + env.set_runtime_mode(RuntimeExecutionMode.STREAMING) + env.set_parallelism(1) + + input_datastream = env.from_source( + source=FileSource.for_record_stream_format( + StreamFormat.text_line_format(), + f"file:///{current_dir}/../resources/execute_test_input", + ).build(), + watermark_strategy=WatermarkStrategy.no_watermarks(), + source_name="execute_test_source", + ) + + deserialize_datastream = input_datastream.map( + lambda x: ExecuteTestData.model_validate_json(x) + ) + + agents_env = AgentsExecutionEnvironment.get_execution_environment(env=env) + output_datastream = ( + agents_env.from_datastream( + input=deserialize_datastream, key_selector=ExecuteTestKeySelector() + ) + .apply(ExecuteWithAsyncTestAgent()) + .to_datastream() + ) + + result_dir = tmp_path / "results" + result_dir.mkdir(parents=True, exist_ok=True) + + output_datastream.map(lambda x: x.model_dump_json(), Types.STRING()).add_sink( + StreamingFileSink.for_row_format( + base_path=str(result_dir.absolute()), + encoder=Encoder.simple_string_encoder(), + ).build() + ) + + agents_env.execute() + + check_result( + result_dir=result_dir, + groud_truth_dir=Path( + f"{current_dir}/../resources/ground_truth/test_execute_with_async.txt" + ), + ) + diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/execute_test_agent.py b/python/flink_agents/e2e_tests/e2e_tests_integration/execute_test_agent.py new file mode 100644 index 00000000..a0b39477 --- /dev/null +++ b/python/flink_agents/e2e_tests/e2e_tests_integration/execute_test_agent.py @@ -0,0 +1,118 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +"""Agent definitions for testing durable_execute() in Flink execution environment.""" + +from pydantic import BaseModel +from pyflink.datastream import KeySelector + +from flink_agents.api.agents.agent import Agent +from flink_agents.api.decorators import action +from flink_agents.api.events.event import Event, InputEvent, OutputEvent +from flink_agents.api.runner_context import RunnerContext + + +class ExecuteTestData(BaseModel): + """Data model for testing durable execute method. + + Attributes: + ---------- + id : int + Unique identifier of the item + value : int + The input value for computation + """ + + id: int + value: int + + +class ExecuteTestOutput(BaseModel): + """Output data model for durable execute test. + + Attributes: + ---------- + id : int + Unique identifier of the item + result : int + The computed result + """ + + id: int + result: int + + +class ExecuteTestKeySelector(KeySelector): + """KeySelector for extracting key from ExecuteTestData.""" + + def get_key(self, value: ExecuteTestData) -> int: + """Extract key from ExecuteTestData.""" + return value.id + + +def compute_value(x: int, y: int) -> int: + """A function that performs computation.""" + return x + y + + +def multiply_value(x: int, y: int) -> int: + """A function that multiplies two values.""" + return x * y + + +class ExecuteTestAgent(Agent): + """Agent that uses synchronous durable_execute() method for testing.""" + + @action(InputEvent) + @staticmethod + def process(event: Event, ctx: RunnerContext) -> None: + """Process an event using durable_execute().""" + input_data: ExecuteTestData = event.input + # Use synchronous durable execute + result = ctx.durable_execute(compute_value, input_data.value, 10) + ctx.send_event(OutputEvent(output=ExecuteTestOutput(id=input_data.id, result=result))) + + +class ExecuteMultipleTestAgent(Agent): + """Agent that makes multiple durable_execute() calls.""" + + @action(InputEvent) + @staticmethod + def process(event: Event, ctx: RunnerContext) -> None: + """Process an event with multiple durable_execute() calls.""" + input_data: ExecuteTestData = event.input + result1 = ctx.durable_execute(compute_value, input_data.value, 5) + result2 = ctx.durable_execute(multiply_value, result1, 2) + ctx.send_event(OutputEvent(output=ExecuteTestOutput(id=input_data.id, result=result2))) + + +class ExecuteWithAsyncTestAgent(Agent): + """Agent that uses both durable_execute() and durable_execute_async().""" + + @action(InputEvent) + @staticmethod + async def process(event: Event, ctx: RunnerContext) -> None: + """Process an event using both durable_execute() and durable_execute_async().""" + input_data: ExecuteTestData = event.input + # Use synchronous durable execute + sync_result = ctx.durable_execute(compute_value, input_data.value, 5) + # Use async durable execute + async_result = await ctx.durable_execute_async(multiply_value, sync_result, 3) + ctx.send_event( + OutputEvent(output=ExecuteTestOutput(id=input_data.id, result=async_result)) + ) + diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/flink_integration_agent.py b/python/flink_agents/e2e_tests/e2e_tests_integration/flink_integration_agent.py index 5e6249d1..989cc0ec 100644 --- a/python/flink_agents/e2e_tests/e2e_tests_integration/flink_integration_agent.py +++ b/python/flink_agents/e2e_tests/e2e_tests_integration/flink_integration_agent.py @@ -103,7 +103,7 @@ class DataStreamAgent(Agent): total = current_total + 1 stm.set("status.total_reviews", total) - log_success = await ctx.execute_async(log_to_stdout, input_data, total) + log_success = await ctx.durable_execute_async(log_to_stdout, input_data, total) content = copy.deepcopy(input_data) content.review += " first action, log success=" + str(log_success) + "," diff --git a/python/flink_agents/e2e_tests/long_term_memory_test.py b/python/flink_agents/e2e_tests/long_term_memory_test.py index 58ec7ccc..8c068e35 100644 --- a/python/flink_agents/e2e_tests/long_term_memory_test.py +++ b/python/flink_agents/e2e_tests/long_term_memory_test.py @@ -181,7 +181,7 @@ class LongTermMemoryAgent(Agent): capacity=5, compaction_strategy=SummarizationStrategy(model="ollama_qwen3"), ) - await ctx.execute_async(memory_set.add, items=input_data.review) + await ctx.durable_execute_async(memory_set.add, items=input_data.review) timestamp_after_add = datetime.now(timezone.utc).isoformat() stm = ctx.short_term_memory @@ -205,7 +205,7 @@ class LongTermMemoryAgent(Agent): record: Record = event.value record.timestamp_second_action = datetime.now(timezone.utc).isoformat() memory_set = ctx.long_term_memory.get_memory_set(name="test_ltm") - items = await ctx.execute_async(memory_set.get) + items = await ctx.durable_execute_async(memory_set.get) if ( (record.id == 1 and record.count == 3) or (record.id == 2 and record.count == 5) diff --git a/python/flink_agents/e2e_tests/resources/execute_test_input/input_data.txt b/python/flink_agents/e2e_tests/resources/execute_test_input/input_data.txt new file mode 100644 index 00000000..6148564b --- /dev/null +++ b/python/flink_agents/e2e_tests/resources/execute_test_input/input_data.txt @@ -0,0 +1,3 @@ +{"id": 1, "value": 5} +{"id": 2, "value": 15} +{"id": 3, "value": 10} \ No newline at end of file diff --git a/python/flink_agents/e2e_tests/resources/ground_truth/test_execute_basic.txt b/python/flink_agents/e2e_tests/resources/ground_truth/test_execute_basic.txt new file mode 100644 index 00000000..b99963fa --- /dev/null +++ b/python/flink_agents/e2e_tests/resources/ground_truth/test_execute_basic.txt @@ -0,0 +1,3 @@ +{"id":1,"result":15} +{"id":2,"result":25} +{"id":3,"result":20} \ No newline at end of file diff --git a/python/flink_agents/e2e_tests/resources/ground_truth/test_execute_multiple.txt b/python/flink_agents/e2e_tests/resources/ground_truth/test_execute_multiple.txt new file mode 100644 index 00000000..0f21d9d1 --- /dev/null +++ b/python/flink_agents/e2e_tests/resources/ground_truth/test_execute_multiple.txt @@ -0,0 +1,3 @@ +{"id":1,"result":20} +{"id":2,"result":40} +{"id":3,"result":30} \ No newline at end of file diff --git a/python/flink_agents/e2e_tests/resources/ground_truth/test_execute_with_async.txt b/python/flink_agents/e2e_tests/resources/ground_truth/test_execute_with_async.txt new file mode 100644 index 00000000..cd2df454 --- /dev/null +++ b/python/flink_agents/e2e_tests/resources/ground_truth/test_execute_with_async.txt @@ -0,0 +1,3 @@ +{"id":1,"result":30} +{"id":2,"result":60} +{"id":3,"result":45} \ No newline at end of file diff --git a/python/flink_agents/plan/function.py b/python/flink_agents/plan/function.py index 82b6f47f..bb5fa5a1 100644 --- a/python/flink_agents/plan/function.py +++ b/python/flink_agents/plan/function.py @@ -332,7 +332,7 @@ def get_python_function_cache_keys() -> List[Tuple[str, str]]: _ASYNCIO_ERROR_MESSAGE = ( "asyncio functions (gather/wait/create_task/sleep) are not supported " - "in Flink Agents. Only 'await ctx.execute_async(...)' is supported." + "in Flink Agents. Only 'await ctx.durable_execute_async(...)' is supported." ) diff --git a/python/flink_agents/runtime/flink_runner_context.py b/python/flink_agents/runtime/flink_runner_context.py index d9ac02a4..257d22d4 100644 --- a/python/flink_agents/runtime/flink_runner_context.py +++ b/python/flink_agents/runtime/flink_runner_context.py @@ -185,6 +185,22 @@ class FlinkRunnerContext(RunnerContext): """ return FlinkMetricGroup(self._j_runner_context.getActionMetricGroup()) + @override + def execute( + self, + func: Callable[[Any], Any], + *args: Any, + **kwargs: Any, + ) -> Any: + """Synchronously execute the provided function. Access to memory + is prohibited within the function. + + The function is executed synchronously in the current thread, blocking + the operator until completion. + """ + # TODO: Add durable execution support (persist result for recovery) + return func(*args, **kwargs) + @override def execute_async( self, @@ -195,6 +211,7 @@ class FlinkRunnerContext(RunnerContext): """Asynchronously execute the provided function. Access to memory is prohibited within the function. """ + # TODO: Add durable execution support (persist result for recovery) return AsyncExecutionResult(self.executor, func, args, kwargs) @property diff --git a/python/flink_agents/runtime/local_runner.py b/python/flink_agents/runtime/local_runner.py index 5ed6157d..64ef4cc6 100644 --- a/python/flink_agents/runtime/local_runner.py +++ b/python/flink_agents/runtime/local_runner.py @@ -179,7 +179,26 @@ class LocalRunnerContext(RunnerContext): err_msg = "Metric mechanism is not supported for local agent execution yet." raise NotImplementedError(err_msg) - def execute_async( + @override + def durable_execute( + self, + func: Callable[[Any], Any], + *args: Any, + **kwargs: Any, + ) -> Any: + """Synchronously execute the provided function. Access to memory + is prohibited within the function. + + Note: Local runner does not support durable execution, so recovery + is not available. + """ + logger.warning( + "Local runner does not support durable execution; recovery is not available." + ) + return func(*args, **kwargs) + + @override + def durable_execute_async( self, func: Callable[[Any], Any], *args: Any, @@ -189,10 +208,10 @@ class LocalRunnerContext(RunnerContext): is prohibited within the function. Note: Local runner executes synchronously but returns an AsyncExecutionResult - for API consistency. + for API consistency. Durable execution is not supported. """ logger.warning( - "Local runner does not support asynchronous execution; falling back to synchronous execution." + "Local runner does not support durable execution; recovery is not available." ) # Execute synchronously and wrap the result in a completed Future future: Future = Future() diff --git a/python/flink_agents/runtime/tests/test_local_execution_environment.py b/python/flink_agents/runtime/tests/test_local_execution_environment.py index 22010160..a34492c6 100644 --- a/python/flink_agents/runtime/tests/test_local_execution_environment.py +++ b/python/flink_agents/runtime/tests/test_local_execution_environment.py @@ -44,7 +44,7 @@ class Agent1WithAsync(Agent): # noqa: D101 return value + 1 input = event.input - value = await ctx.execute_async(my_func, input) + value = await ctx.durable_execute_async(my_func, input) ctx.send_event(OutputEvent(output=value)) diff --git a/python/flink_agents/runtime/tests/test_runner_context_execute.py b/python/flink_agents/runtime/tests/test_runner_context_execute.py new file mode 100644 index 00000000..fd3fef34 --- /dev/null +++ b/python/flink_agents/runtime/tests/test_runner_context_execute.py @@ -0,0 +1,190 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +"""Tests for RunnerContext durable_execute() method.""" + +from flink_agents.api.agents.agent import Agent +from flink_agents.api.decorators import action +from flink_agents.api.events.event import Event, InputEvent, OutputEvent +from flink_agents.api.execution_environment import AgentsExecutionEnvironment +from flink_agents.api.runner_context import RunnerContext + + +def slow_computation(x: int, y: int) -> int: + """A sample function that simulates slow computation.""" + return x + y + + +def multiply(x: int, y: int) -> int: + """A sample function that multiplies two numbers.""" + return x * y + + +def raise_exception(msg: str) -> None: + """A sample function that raises an exception.""" + raise ValueError(msg) + + +class AgentWithDurableExecute(Agent): + """Agent that uses synchronous durable_execute() method.""" + + @action(InputEvent) + @staticmethod + def process(event: Event, ctx: RunnerContext) -> None: + """Process an event using durable_execute().""" + input_val = event.input + # Use synchronous durable execute + result = ctx.durable_execute(slow_computation, input_val, 10) + ctx.send_event(OutputEvent(output=result)) + + +class AgentWithMultipleDurableExecute(Agent): + """Agent that makes multiple durable_execute() calls.""" + + @action(InputEvent) + @staticmethod + def process(event: Event, ctx: RunnerContext) -> None: + """Process an event with multiple durable_execute() calls.""" + input_val = event.input + result1 = ctx.durable_execute(slow_computation, input_val, 5) + result2 = ctx.durable_execute(multiply, result1, 2) + ctx.send_event(OutputEvent(output=result2)) + + +class AgentWithDurableExecuteAndAsync(Agent): + """Agent that uses both durable_execute() and durable_execute_async().""" + + @action(InputEvent) + @staticmethod + async def process(event: Event, ctx: RunnerContext) -> None: + """Process an event using both durable_execute() and durable_execute_async().""" + input_val = event.input + # Use synchronous durable execute + sync_result = ctx.durable_execute(slow_computation, input_val, 5) + # Use async durable execute + async_result = await ctx.durable_execute_async(multiply, sync_result, 3) + ctx.send_event(OutputEvent(output=async_result)) + + +class AgentWithDurableExecuteException(Agent): + """Agent that uses durable_execute() with a function that raises an exception.""" + + @action(InputEvent) + @staticmethod + def process(event: Event, ctx: RunnerContext) -> None: + """Process an event where durable_execute() raises an exception.""" + input_val = event.input + try: + ctx.durable_execute(raise_exception, f"Test error: {input_val}") + except ValueError as e: + ctx.send_event(OutputEvent(output=f"Caught: {e}")) + + +class AgentWithKwargs(Agent): + """Agent that uses durable_execute() with keyword arguments.""" + + @action(InputEvent) + @staticmethod + def process(event: Event, ctx: RunnerContext) -> None: + """Process an event using durable_execute() with kwargs.""" + input_val = event.input + result = ctx.durable_execute(slow_computation, x=input_val, y=20) + ctx.send_event(OutputEvent(output=result)) + + +def test_durable_execute_basic() -> None: + """Test basic synchronous durable_execute() functionality.""" + env = AgentsExecutionEnvironment.get_execution_environment() + + input_list = [] + agent = AgentWithDurableExecute() + + output_list = env.from_list(input_list).apply(agent).to_list() + + input_list.append({"key": "alice", "value": 5}) + input_list.append({"key": "bob", "value": 15}) + + env.execute() + + assert output_list == [{"alice": 15}, {"bob": 25}] + + +def test_durable_execute_multiple_calls() -> None: + """Test multiple durable_execute() calls in a single action.""" + env = AgentsExecutionEnvironment.get_execution_environment() + + input_list = [] + agent = AgentWithMultipleDurableExecute() + + output_list = env.from_list(input_list).apply(agent).to_list() + + input_list.append({"key": "alice", "value": 10}) + + env.execute() + + # (10 + 5) * 2 = 30 + assert output_list == [{"alice": 30}] + + +def test_durable_execute_with_async() -> None: + """Test durable_execute() and durable_execute_async() in the same action.""" + env = AgentsExecutionEnvironment.get_execution_environment() + + input_list = [] + agent = AgentWithDurableExecuteAndAsync() + + output_list = env.from_list(input_list).apply(agent).to_list() + + input_list.append({"key": "alice", "value": 7}) + + env.execute() + + # (7 + 5) * 3 = 36 + assert output_list == [{"alice": 36}] + + +def test_durable_execute_exception_handling() -> None: + """Test that exceptions from durable_execute() can be caught.""" + env = AgentsExecutionEnvironment.get_execution_environment() + + input_list = [] + agent = AgentWithDurableExecuteException() + + output_list = env.from_list(input_list).apply(agent).to_list() + + input_list.append({"key": "alice", "value": "test"}) + + env.execute() + + assert output_list == [{"alice": "Caught: Test error: test"}] + + +def test_durable_execute_with_kwargs() -> None: + """Test durable_execute() with keyword arguments.""" + env = AgentsExecutionEnvironment.get_execution_environment() + + input_list = [] + agent = AgentWithKwargs() + + output_list = env.from_list(input_list).apply(agent).to_list() + + input_list.append({"key": "alice", "value": 5}) + + env.execute() + + assert output_list == [{"alice": 25}] +
