This is an automated email from the ASF dual-hosted git repository. skrawcz pushed a commit to branch stefan/interceptor in repository https://gitbox.apache.org/repos/asf/burr.git
commit a0f7d2e18165154dd102305019206c378b51b8fe Author: Stefan Krawczyk <[email protected]> AuthorDate: Wed Nov 19 22:04:50 2025 -0800 WIP: run a burr action remotely --- burr/core/application.py | 148 +++++++- burr/lifecycle/__init__.py | 24 ++ burr/lifecycle/base.py | 352 ++++++++++++++++++ burr/lifecycle/internal.py | 76 +++- examples/remote-execution-ray/README.md | 209 +++++++++++ examples/remote-execution-ray/__init__.py | 6 + examples/remote-execution-ray/application.py | 268 ++++++++++++++ examples/remote-execution-ray/notebook.ipynb | 483 +++++++++++++++++++++++++ examples/remote-execution-ray/requirements.txt | 2 + tests/core/test_action_interceptor.py | 343 ++++++++++++++++++ 10 files changed, 1900 insertions(+), 11 deletions(-) diff --git a/burr/core/application.py b/burr/core/application.py index 55f98acf..b552df31 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -158,7 +158,13 @@ def _remap_dunder_parameters( return inputs -def _run_function(function: Function, state: State, inputs: Dict[str, Any], name: str) -> dict: +def _run_function( + function: Function, + state: State, + inputs: Dict[str, Any], + name: str, + adapter_set: Optional["LifecycleAdapterSet"] = None, +) -> dict: """Runs a function, returning the result of running the function. Note this restricts the keys in the state to only those that the function reads. @@ -166,6 +172,8 @@ def _run_function(function: Function, state: State, inputs: Dict[str, Any], name :param function: Function to run :param state: State at time of execution :param inputs: Inputs to the function + :param name: Name of the action (for error messages) + :param adapter_set: Optional lifecycle adapter set for checking interceptors :return: """ if function.is_async(): @@ -174,6 +182,21 @@ def _run_function(function: Function, state: State, inputs: Dict[str, Any], name "in non-async context. 
Use astep()/aiterate()/arun() " "instead...)" ) + + # Check for execution interceptors + if adapter_set: + interceptor = adapter_set.get_first_matching_hook( + "intercept_action_execution", lambda hook: hook.should_intercept(action=function) + ) + if interceptor: + worker_adapter_set = adapter_set.get_worker_adapter_set() + result = interceptor.intercept_run( + action=function, state=state, inputs=inputs, worker_adapter_set=worker_adapter_set + ) + _validate_result(result, name) + return result + + # Normal execution path state_to_use = state.subset(*function.reads) function.validate_inputs(inputs) if "__context" in inputs or "__tracer" in inputs: @@ -185,10 +208,30 @@ def _run_function(function: Function, state: State, inputs: Dict[str, Any], name async def _arun_function( - function: Function, state: State, inputs: Dict[str, Any], name: str + function: Function, + state: State, + inputs: Dict[str, Any], + name: str, + adapter_set: Optional["LifecycleAdapterSet"] = None, ) -> dict: """Runs a function, returning the result of running the function. 
Async version of the above.""" + + # Check for execution interceptors + if adapter_set: + interceptor = adapter_set.get_first_matching_hook( + "intercept_action_execution", + lambda hook: hook.should_intercept(action=function) and hasattr(hook, "intercept_run"), + ) + if interceptor and inspect.iscoroutinefunction(interceptor.intercept_run): + worker_adapter_set = adapter_set.get_worker_adapter_set() + result = await interceptor.intercept_run( + action=function, state=state, inputs=inputs, worker_adapter_set=worker_adapter_set + ) + _validate_result(result, name) + return result + + # Normal execution path state_to_use = state.subset(*function.reads) function.validate_inputs(inputs) result = await function.run(state_to_use, **inputs) @@ -299,7 +342,10 @@ def _format_BASE_ERROR_MESSAGE(action: Action, input_state: State, inputs: dict) def _run_single_step_action( - action: SingleStepAction, state: State, inputs: Optional[Dict[str, Any]] + action: SingleStepAction, + state: State, + inputs: Optional[Dict[str, Any]], + adapter_set: Optional["LifecycleAdapterSet"] = None, ) -> Tuple[Dict[str, Any], State]: """Runs a single step action. This API is internal-facing and a bit in flux, but it corresponds to the SingleStepAction class. 
@@ -307,9 +353,33 @@ def _run_single_step_action( :param action: Action to run :param state: State to run with :param inputs: Inputs to pass directly to the action + :param adapter_set: Optional lifecycle adapter set for checking interceptors :return: The result of running the action, and the new state """ - # TODO -- guard all reads/writes with a subset of the state + # Check for execution interceptors + if adapter_set: + interceptor = adapter_set.get_first_matching_hook( + "intercept_action_execution", lambda hook: hook.should_intercept(action=action) + ) + if interceptor: + worker_adapter_set = adapter_set.get_worker_adapter_set() + result = interceptor.intercept_run( + action=action, state=state, inputs=inputs, worker_adapter_set=worker_adapter_set + ) + # Check if interceptor returned state via special key (for single-step actions) + if "__INTERCEPTOR_NEW_STATE__" in result: + new_state = result.pop("__INTERCEPTOR_NEW_STATE__") + else: + # For multi-step actions or if state wasn't provided + # we need to compute it + new_state = action.update(result, state) + + _validate_result(result, action.name, action.schema) + out = result, _state_update(state, new_state) + _validate_reducer_writes(action, new_state, action.name) + return out + + # Normal execution path action.validate_inputs(inputs) result, new_state = _adjust_single_step_output( action.run_and_update(state, **inputs), action.name, action.schema @@ -334,7 +404,18 @@ def _run_single_step_streaming_action( action.validate_inputs(inputs) stream_initialize_time = system.now() first_stream_start_time = None - generator = action.stream_run_and_update(state, **inputs) + + # Check for streaming action interceptors + interceptor = lifecycle_adapters.get_first_matching_hook( + "intercept_streaming_action", lambda hook: hook.should_intercept(action=action) + ) + if interceptor: + worker_adapter_set = lifecycle_adapters.get_worker_adapter_set() + generator = interceptor.intercept_stream_run_and_update( + 
action=action, state=state, inputs=inputs, worker_adapter_set=worker_adapter_set + ) + else: + generator = action.stream_run_and_update(state, **inputs) result = None state_update = None count = 0 @@ -387,7 +468,20 @@ async def _arun_single_step_streaming_action( action.validate_inputs(inputs) stream_initialize_time = system.now() first_stream_start_time = None - generator = action.stream_run_and_update(state, **inputs) + + # Check for streaming action interceptors + interceptor = lifecycle_adapters.get_first_matching_hook( + "intercept_streaming_action", + lambda hook: hook.should_intercept(action=action) + and hasattr(hook, "intercept_stream_run_and_update"), + ) + if interceptor and inspect.isasyncgenfunction(interceptor.intercept_stream_run_and_update): + worker_adapter_set = lifecycle_adapters.get_worker_adapter_set() + generator = interceptor.intercept_stream_run_and_update( + action=action, state=state, inputs=inputs, worker_adapter_set=worker_adapter_set + ) + else: + generator = action.stream_run_and_update(state, **inputs) result = None state_update = None count = 0 @@ -523,9 +617,35 @@ async def _arun_multi_step_streaming_action( async def _arun_single_step_action( - action: SingleStepAction, state: State, inputs: Optional[Dict[str, Any]] + action: SingleStepAction, + state: State, + inputs: Optional[Dict[str, Any]], + adapter_set: Optional["LifecycleAdapterSet"] = None, ) -> Tuple[dict, State]: """Runs a single step action in async. 
See the synchronous version for more details.""" + # Check for execution interceptors + if adapter_set: + interceptor = adapter_set.get_first_matching_hook( + "intercept_action_execution", + lambda hook: hook.should_intercept(action=action) and hasattr(hook, "intercept_run"), + ) + if interceptor and inspect.iscoroutinefunction(interceptor.intercept_run): + worker_adapter_set = adapter_set.get_worker_adapter_set() + result = await interceptor.intercept_run( + action=action, state=state, inputs=inputs, worker_adapter_set=worker_adapter_set + ) + # Check if interceptor returned state via special key (for single-step actions) + if "__INTERCEPTOR_NEW_STATE__" in result: + new_state = result.pop("__INTERCEPTOR_NEW_STATE__") + else: + # For multi-step actions or if state wasn't provided + new_state = action.update(result, state) + + _validate_result(result, action.name, action.schema) + _validate_reducer_writes(action, new_state, action.name) + return result, _state_update(state, new_state) + + # Normal execution path state_to_use = state action.validate_inputs(inputs) result, new_state = _adjust_single_step_output( @@ -915,11 +1035,15 @@ class Application(Generic[ApplicationStateType]): try: if next_action.single_step: result, new_state = _run_single_step_action( - next_action, self._state, action_inputs + next_action, self._state, action_inputs, adapter_set=self._adapter_set ) else: result = _run_function( - next_action, self._state, action_inputs, name=next_action.name + next_action, + self._state, + action_inputs, + name=next_action.name, + adapter_set=self._adapter_set, ) new_state = _run_reducer(next_action, self._state, result, next_action.name) @@ -1065,7 +1189,10 @@ class Application(Generic[ApplicationStateType]): action_inputs = self._process_inputs(inputs, next_action) if next_action.single_step: result, new_state = await _arun_single_step_action( - next_action, self._state, inputs=action_inputs + next_action, + self._state, + inputs=action_inputs, + 
adapter_set=self._adapter_set, ) else: result = await _arun_function( @@ -1073,6 +1200,7 @@ class Application(Generic[ApplicationStateType]): self._state, inputs=action_inputs, name=next_action.name, + adapter_set=self._adapter_set, ) new_state = _run_reducer(next_action, self._state, result, next_action.name) new_state = self._update_internal_state_value(new_state, next_action) diff --git a/burr/lifecycle/__init__.py b/burr/lifecycle/__init__.py index 4ae24073..991ee984 100644 --- a/burr/lifecycle/__init__.py +++ b/burr/lifecycle/__init__.py @@ -16,18 +16,30 @@ # under the License. from burr.lifecycle.base import ( + ActionExecutionInterceptorHook, + ActionExecutionInterceptorHookAsync, LifecycleAdapter, PostApplicationCreateHook, PostApplicationExecuteCallHook, PostApplicationExecuteCallHookAsync, PostEndSpanHook, + PostEndStreamHookWorker, + PostEndStreamHookWorkerAsync, PostRunStepHook, PostRunStepHookAsync, + PostRunStepHookWorker, + PostRunStepHookWorkerAsync, PreApplicationExecuteCallHook, PreApplicationExecuteCallHookAsync, PreRunStepHook, PreRunStepHookAsync, + PreRunStepHookWorker, + PreRunStepHookWorkerAsync, PreStartSpanHook, + PreStartStreamHookWorker, + PreStartStreamHookWorkerAsync, + StreamingActionInterceptorHook, + StreamingActionInterceptorHookAsync, ) from burr.lifecycle.default import StateAndResultsFullLogger @@ -45,4 +57,16 @@ __all__ = [ "PostApplicationCreateHook", "PostEndSpanHook", "PreStartSpanHook", + "PreRunStepHookWorker", + "PreRunStepHookWorkerAsync", + "PostRunStepHookWorker", + "PostRunStepHookWorkerAsync", + "PreStartStreamHookWorker", + "PreStartStreamHookWorkerAsync", + "PostEndStreamHookWorker", + "PostEndStreamHookWorkerAsync", + "ActionExecutionInterceptorHook", + "ActionExecutionInterceptorHookAsync", + "StreamingActionInterceptorHook", + "StreamingActionInterceptorHookAsync", ] diff --git a/burr/lifecycle/base.py b/burr/lifecycle/base.py index 66d8bd7e..2117a7f5 100644 --- a/burr/lifecycle/base.py +++ 
b/burr/lifecycle/base.py @@ -492,6 +492,346 @@ class PostEndStreamHookAsync(abc.ABC): pass [email protected]_hook("pre_run_step_worker") +class PreRunStepHookWorker(abc.ABC): + """Hook that runs on the worker (e.g., Ray/Temporal) before action execution. + This hook is designed to be called by execution interceptors on remote workers, + as opposed to PreRunStepHook which always runs on the main orchestrator process.""" + + @abc.abstractmethod + def pre_run_step_worker( + self, + *, + action: "Action", + state: "State", + inputs: Dict[str, Any], + **future_kwargs: Any, + ): + """Run before a step is executed on the worker. + + :param action: Action to be executed + :param state: State prior to step execution + :param inputs: Inputs to the action + :param future_kwargs: Future keyword arguments + """ + pass + + [email protected]_hook("pre_run_step_worker") +class PreRunStepHookWorkerAsync(abc.ABC): + """Async hook that runs on the worker before action execution.""" + + @abc.abstractmethod + async def pre_run_step_worker( + self, + *, + action: "Action", + state: "State", + inputs: Dict[str, Any], + **future_kwargs: Any, + ): + """Async run before a step is executed on the worker. + + :param action: Action to be executed + :param state: State prior to step execution + :param inputs: Inputs to the action + :param future_kwargs: Future keyword arguments + """ + pass + + [email protected]_hook("post_run_step_worker") +class PostRunStepHookWorker(abc.ABC): + """Hook that runs on the worker after action execution. + This hook is designed to be called by execution interceptors on remote workers, + as opposed to PostRunStepHook which always runs on the main orchestrator process.""" + + @abc.abstractmethod + def post_run_step_worker( + self, + *, + action: "Action", + state: "State", + result: Optional[Dict[str, Any]], + exception: Exception, + **future_kwargs: Any, + ): + """Run after a step is executed on the worker. 
+ + :param action: Action that was executed + :param state: State after step execution + :param result: Result of the action + :param exception: Exception that was raised + :param future_kwargs: Future keyword arguments + """ + pass + + [email protected]_hook("post_run_step_worker") +class PostRunStepHookWorkerAsync(abc.ABC): + """Async hook that runs on the worker after action execution.""" + + @abc.abstractmethod + async def post_run_step_worker( + self, + *, + action: "Action", + state: "State", + result: Optional[dict], + exception: Exception, + **future_kwargs: Any, + ): + """Async run after a step is executed on the worker. + + :param action: Action that was executed + :param state: State after step execution + :param result: Result of the action + :param exception: Exception that was raised + :param future_kwargs: Future keyword arguments + """ + pass + + [email protected]_hook("pre_start_stream_worker") +class PreStartStreamHookWorker(abc.ABC): + """Hook that runs on the worker before a stream is started.""" + + @abc.abstractmethod + def pre_start_stream_worker( + self, + *, + action: str, + state: "State", + inputs: Dict[str, Any], + **future_kwargs: Any, + ): + pass + + [email protected]_hook("pre_start_stream_worker") +class PreStartStreamHookWorkerAsync(abc.ABC): + """Async hook that runs on the worker before a stream is started.""" + + @abc.abstractmethod + async def pre_start_stream_worker( + self, + *, + action: str, + state: "State", + inputs: Dict[str, Any], + **future_kwargs: Any, + ): + pass + + [email protected]_hook("post_end_stream_worker") +class PostEndStreamHookWorker(abc.ABC): + """Hook that runs on the worker after a stream is ended.""" + + @abc.abstractmethod + def post_end_stream_worker( + self, + *, + action: str, + result: Optional[Dict[str, Any]], + exception: Exception, + **future_kwargs: Any, + ): + pass + + [email protected]_hook("post_end_stream_worker") +class PostEndStreamHookWorkerAsync(abc.ABC): + """Async hook that runs on the 
worker after a stream is ended.""" + + @abc.abstractmethod + async def post_end_stream_worker( + self, + *, + action: str, + result: Optional[Dict[str, Any]], + exception: Exception, + **future_kwargs: Any, + ): + pass + + +class ActionExecutionInterceptorHook(abc.ABC): + """Hook that can wrap/replace action execution (e.g., for Ray/Temporal). + This hook allows you to intercept the execution of an action and run it + on a different execution backend while maintaining the same interface. + + The interceptor receives a worker_adapter_set containing only worker hooks + (PreRunStepHookWorker, PostRunStepHookWorker, etc.) that can be called + on the remote execution environment. + + Note: Interceptors don't use the @lifecycle.base_hook decorator because they + have multiple methods and special handling logic.""" + + @abc.abstractmethod + def should_intercept( + self, + *, + action: "Action", + **future_kwargs: Any, + ) -> bool: + """Determine if this action should be intercepted. + + :param action: Action to potentially intercept + :param future_kwargs: Future keyword arguments + :return: True if this hook should intercept execution + """ + pass + + @abc.abstractmethod + def intercept_run( + self, + *, + action: "Action", + state: "State", + inputs: Dict[str, Any], + **future_kwargs: Any, + ) -> dict: + """Replace the action.run() call with custom execution. + + Note: The state passed here is the FULL state, not subsetted. + You are responsible for subsetting it to action.reads if needed. 
+ + :param action: Action to execute + :param state: Current state (FULL state, not subsetted) + :param inputs: Inputs to the action + :param future_kwargs: Future keyword arguments (includes worker_adapter_set) + :return: Result dictionary from running the action + """ + pass + + +class ActionExecutionInterceptorHookAsync(abc.ABC): + """Async version of ActionExecutionInterceptorHook for intercepting async actions.""" + + @abc.abstractmethod + async def should_intercept( + self, + *, + action: "Action", + **future_kwargs: Any, + ) -> bool: + """Determine if this action should be intercepted. + + :param action: Action to potentially intercept + :param future_kwargs: Future keyword arguments + :return: True if this hook should intercept execution + """ + pass + + @abc.abstractmethod + async def intercept_run( + self, + *, + action: "Action", + state: "State", + inputs: Dict[str, Any], + **future_kwargs: Any, + ) -> dict: + """Replace the action.run() call with custom execution. + + Note: The state passed here is the FULL state, not subsetted. + You are responsible for subsetting it to action.reads if needed. + + :param action: Action to execute + :param state: Current state (FULL state, not subsetted) + :param inputs: Inputs to the action + :param future_kwargs: Future keyword arguments (includes worker_adapter_set) + :return: Result dictionary from running the action + """ + pass + + +class StreamingActionInterceptorHook(abc.ABC): + """Hook to intercept streaming action execution (e.g., for Ray/Temporal). + This hook allows you to wrap streaming actions to execute on different backends. + + The interceptor receives a worker_adapter_set containing only worker hooks + that can be called on the remote execution environment. 
+ + Note: Interceptors don't use the @lifecycle.base_hook decorator because they + have multiple methods and special handling logic.""" + + @abc.abstractmethod + def should_intercept( + self, + *, + action: "Action", + **future_kwargs: Any, + ) -> bool: + """Determine if this streaming action should be intercepted. + + :param action: Streaming action to potentially intercept + :param future_kwargs: Future keyword arguments + :return: True if this hook should intercept execution + """ + pass + + @abc.abstractmethod + def intercept_stream_run_and_update( + self, + *, + action: "Action", + state: "State", + inputs: Dict[str, Any], + **future_kwargs: Any, + ): + """Replace stream_run_and_update with custom execution. + Must be a generator that yields (result_dict, optional_state) tuples. + + :param action: Streaming action to execute + :param state: Current state + :param inputs: Inputs to the action + :param future_kwargs: Future keyword arguments (includes worker_adapter_set) + :return: Generator yielding (dict, Optional[State]) tuples + """ + pass + + +class StreamingActionInterceptorHookAsync(abc.ABC): + """Async version for intercepting async streaming actions.""" + + @abc.abstractmethod + async def should_intercept( + self, + *, + action: "Action", + **future_kwargs: Any, + ) -> bool: + """Determine if this streaming action should be intercepted. + + :param action: Streaming action to potentially intercept + :param future_kwargs: Future keyword arguments + :return: True if this hook should intercept execution + """ + pass + + @abc.abstractmethod + def intercept_stream_run_and_update( + self, + *, + action: "Action", + state: "State", + inputs: Dict[str, Any], + **future_kwargs: Any, + ): + """Replace stream_run_and_update with custom execution. + Must be an async generator that yields (result_dict, optional_state) tuples. 
+ + :param action: Streaming action to execute + :param state: Current state + :param inputs: Inputs to the action + :param future_kwargs: Future keyword arguments (includes worker_adapter_set) + :return: Async generator yielding (dict, Optional[State]) tuples + """ + pass + + # strictly for typing -- this conflicts a bit with the lifecycle decorator above, but its fine for now # This makes IDE completion/type-hinting easier LifecycleAdapter = Union[ @@ -515,4 +855,16 @@ LifecycleAdapter = Union[ PreStartStreamHookAsync, PostStreamItemHookAsync, PostEndStreamHookAsync, + PreRunStepHookWorker, + PreRunStepHookWorkerAsync, + PostRunStepHookWorker, + PostRunStepHookWorkerAsync, + PreStartStreamHookWorker, + PreStartStreamHookWorkerAsync, + PostEndStreamHookWorker, + PostEndStreamHookWorkerAsync, + ActionExecutionInterceptorHook, + ActionExecutionInterceptorHookAsync, + StreamingActionInterceptorHook, + StreamingActionInterceptorHookAsync, ] diff --git a/burr/lifecycle/internal.py b/burr/lifecycle/internal.py index 1043bd0a..1c7eb2df 100644 --- a/burr/lifecycle/internal.py +++ b/burr/lifecycle/internal.py @@ -119,7 +119,15 @@ class LifecycleAdapterSet: :param adapters: Adapters to group together """ self._adapters = list(adapters) - self.sync_hooks, self.async_hooks = self._get_lifecycle_hooks() + self._sync_hooks, self._async_hooks = self._get_lifecycle_hooks() + + @property + def sync_hooks(self): + return self._sync_hooks + + @property + def async_hooks(self): + return self._async_hooks def with_new_adapters(self, *adapters: "LifecycleAdapter") -> "LifecycleAdapterSet": """Adds new adapters to the set. @@ -212,3 +220,69 @@ class LifecycleAdapterSet: :return: A list of adapters """ return self._adapters + + def get_first_matching_hook( + self, hook_name: str, predicate: Callable[["LifecycleAdapter"], bool] + ): + """Get first hook of given type that matches predicate. 
+ + For interceptor hooks (intercept_action_execution, intercept_streaming_action), + this searches all adapters for instances of the interceptor classes, since + interceptors don't use the standard hook registration system. + + :param hook_name: Name of the hook to search for (or interceptor type) + :param predicate: Function that takes a hook and returns True if it matches + :return: The first matching hook, or None if no match found + """ + # Special handling for interceptors - they don't use the registration system + if hook_name in ("intercept_action_execution", "intercept_streaming_action"): + # Import here to avoid circular dependency + from burr.lifecycle.base import ( + ActionExecutionInterceptorHook, + ActionExecutionInterceptorHookAsync, + StreamingActionInterceptorHook, + StreamingActionInterceptorHookAsync, + ) + + interceptor_classes = ( + ActionExecutionInterceptorHook, + ActionExecutionInterceptorHookAsync, + StreamingActionInterceptorHook, + StreamingActionInterceptorHookAsync, + ) + + for adapter in self.adapters: + if isinstance(adapter, interceptor_classes): + if predicate(adapter): + return adapter + return None + + # Standard hook lookup for registered hooks + hooks = self.sync_hooks.get(hook_name, []) + self.async_hooks.get(hook_name, []) + for hook in hooks: + if predicate(hook): + return hook + return None + + def get_worker_adapter_set(self) -> "LifecycleAdapterSet": + """Create a new LifecycleAdapterSet containing only worker hooks. + Worker hooks are those with names ending in '_worker' and are designed + to be called on remote execution environments (Ray/Temporal workers). 
+ + :return: A new LifecycleAdapterSet with only worker hooks + """ + worker_hooks = [] + for adapter in self.adapters: + # Check if this adapter is a worker hook by looking at its registered hooks + is_worker = False + for cls in inspect.getmro(adapter.__class__): + sync_hook = getattr(cls, SYNC_HOOK, None) + async_hook = getattr(cls, ASYNC_HOOK, None) + if (sync_hook and sync_hook.endswith("_worker")) or ( + async_hook and async_hook.endswith("_worker") + ): + is_worker = True + break + if is_worker: + worker_hooks.append(adapter) + return LifecycleAdapterSet(*worker_hooks) diff --git a/examples/remote-execution-ray/README.md b/examples/remote-execution-ray/README.md new file mode 100644 index 00000000..c855aee8 --- /dev/null +++ b/examples/remote-execution-ray/README.md @@ -0,0 +1,209 @@ +# Remote Execution with Ray + +This example demonstrates how to use Burr's **Action Execution Interceptors** to run specific actions on Ray workers while keeping orchestration on the main process. + +## Overview + +Burr's lifecycle hook system includes **interceptors** that can wrap action execution and redirect it to different execution backends like Ray, Temporal, or custom distributed systems. 
+ +This example shows: +- ✅ Selective interception (only actions tagged with `ray` run remotely) +- ✅ Orchestrator hooks (run on main process) +- ✅ Worker hooks (run on Ray workers) +- ✅ Seamless mixing of local and remote execution +- ✅ State management across distributed execution + +## Architecture + +``` +┌─────────────────────────────────────────────────────────────┐ +│ Main Process (Orchestrator) │ +│ │ +│ ┌──────────────────────────────────────────────┐ │ +│ │ Burr Application │ │ +│ │ │ │ +│ │ PreRunStepHook (Orchestrator) ────┐ │ │ +│ │ ↓ │ │ +│ │ RayActionInterceptor │ │ │ +│ │ - should_intercept() │ │ │ +│ │ - intercept_run() ───────────────┼─────────┼─────┐ │ +│ │ │ │ │ │ +│ │ PostRunStepHook (Orchestrator) ←──┘ │ │ │ +│ └──────────────────────────────────────────────┘ │ │ +└────────────────────────────────────────────────────────┼─────┘ + │ + Ray Remote Call │ + ↓ +┌─────────────────────────────────────────────────────────────┐ +│ Ray Worker │ +│ │ +│ PreRunStepHookWorker ────┐ │ +│ ↓ │ +│ Action.run_and_update() (actual execution) │ +│ │ │ +│ PostRunStepHookWorker ←───┘ │ +│ │ +└─────────────────────────────────────────────────────────────┘ +``` + +## Key Concepts + +### 1. Two-Tier Hook System + +**Orchestrator Hooks** (run on main process): +- `PreRunStepHook` - runs before any action (local or remote) +- `PostRunStepHook` - runs after any action completes +- These hooks see all actions but don't know about execution details + +**Worker Hooks** (run on Ray workers): +- `PreRunStepHookWorker` - runs on the worker before execution +- `PostRunStepHookWorker` - runs on the worker after execution +- Only called for intercepted actions +- Must be serializable (picklable) + +### 2. 
Action Execution Interceptor + +The interceptor has two methods: + +```python +def should_intercept(self, *, action: Action, **kwargs) -> bool: + """Decide if this action should be intercepted""" + return "ray" in action.tags + +def intercept_run(self, *, action: Action, state: State, inputs: Dict[str, Any], **kwargs) -> dict: + """Execute the action on Ray and return the result""" + # Get worker hooks to pass to Ray worker + worker_adapter_set = kwargs.get("worker_adapter_set") + + # Execute on Ray with worker hooks + @ray.remote + def execute_on_ray(): + # Call worker hooks + # Execute action + # Return result + + return ray.get(execute_on_ray.remote()) +``` + +### 3. Selective Execution + +Actions are tagged to control where they run: + +```python +@action(reads=["count"], writes=["count"], tags=["local"]) +def local_task(state: State): + # Runs on main process + ... + +@action(reads=["count"], writes=["count"], tags=["ray"]) +def remote_task(state: State): + # Runs on Ray worker + ... +``` + +## Installation + +```bash +pip install -r requirements.txt +``` + +## Running the Example + +### Python Script + +```bash +python application.py +``` + +Expected output: +``` +================================================================================ +Burr + Ray Remote Execution Example +================================================================================ + +[Main Process] Initializing Ray... 
+ +================================================================================ +Step 1: Local execution (increment_local) +================================================================================ +[Main Process] About to execute action: increment_local +[Main Process] Finished executing action: increment_local +Result: count=1, operation=increment_local + +================================================================================ +Step 2: Ray execution (heavy_computation) +================================================================================ +[Main Process] About to execute action: heavy_computation +[Main Process] Dispatching heavy_computation to Ray... +[Ray Worker] Starting action: heavy_computation on Ray worker +[Ray Worker] Running heavy computation with multiplier=3 +[Ray Worker] Completed action: heavy_computation on Ray worker +[Main Process] Received result from Ray for heavy_computation +[Main Process] Finished executing action: heavy_computation +Result: count=3, operation=heavy_computation(x3) + +... +``` + +### Jupyter Notebook + +```bash +jupyter notebook notebook.ipynb +``` + +## Use Cases + +This pattern is useful for: + +1. **Compute-Intensive Operations**: Offload heavy computations to Ray clusters +2. **GPU Workloads**: Run ML inference/training on GPU workers +3. **Scalability**: Distribute work across multiple machines +4. **Resource Isolation**: Keep heavy operations away from orchestrator +5. 
**Hybrid Workflows**: Mix local control flow with distributed execution + +## Extending to Other Backends + +The same pattern works for other execution backends: + +### Temporal + +```python +class TemporalActionInterceptor(ActionExecutionInterceptorHook): + def should_intercept(self, *, action, **kwargs): + return "temporal" in action.tags + + def intercept_run(self, *, action, state, inputs, **kwargs): + # Execute as Temporal activity + return await workflow.execute_activity( + action.run_and_update, + state, + **inputs + ) +``` + +### Custom Distributed System + +```python +class CustomBackendInterceptor(ActionExecutionInterceptorHook): + def should_intercept(self, *, action, **kwargs): + return "distributed" in action.tags + + def intercept_run(self, *, action, state, inputs, **kwargs): + # Submit to your custom backend + job_id = backend.submit_job(action, state, inputs) + result = backend.wait_for_completion(job_id) + return result +``` + +## Important Notes + +1. **State Serialization**: State must be serializable to pass to workers +2. **Worker Hooks**: Must be picklable (avoid closures with local variables) +3. **Error Handling**: Exceptions on workers propagate back to orchestrator +4. **Performance**: Ray overhead ~100ms per task; use for tasks >1s + +## Related Documentation + +- [Burr Lifecycle Hooks](https://burr.dagworks.io/concepts/hooks/) +- [Ray Core API](https://docs.ray.io/en/latest/ray-core/walkthrough.html) +- [Temporal Workflows](https://docs.temporal.io/) diff --git a/examples/remote-execution-ray/__init__.py b/examples/remote-execution-ray/__init__.py new file mode 100644 index 00000000..cd242bc2 --- /dev/null +++ b/examples/remote-execution-ray/__init__.py @@ -0,0 +1,6 @@ +""" +Remote Execution with Ray + +This example demonstrates how to use Burr's Action Execution Interceptors +to run actions on Ray workers while maintaining orchestration on the main process. 
+""" diff --git a/examples/remote-execution-ray/application.py b/examples/remote-execution-ray/application.py new file mode 100644 index 00000000..ad44291a --- /dev/null +++ b/examples/remote-execution-ray/application.py @@ -0,0 +1,268 @@ +""" +Example demonstrating how to use Burr's action execution interceptors to run +actions remotely on Ray workers. + +This example shows: +1. How to create a RayActionInterceptor to execute actions on Ray +2. How worker hooks run on the remote Ray worker +3. How to mix local and remote execution based on action tags +""" + +import time +from typing import Any, Dict, Optional + +import ray + +from burr.core import Action, ApplicationBuilder, State, action +from burr.lifecycle import ( + ActionExecutionInterceptorHook, + PostRunStepHook, + PostRunStepHookWorker, + PreRunStepHook, + PreRunStepHookWorker, +) + + +# Define some example actions +@action(reads=["count"], writes=["count", "last_operation"], tags=["local"]) +def increment_local(state: State) -> tuple: + """Increment counter locally (not on Ray)""" + result = { + "count": state["count"] + 1, + "last_operation": "increment_local", + } + return result, state.update(**result) + + +@action(reads=["count"], writes=["count", "last_operation"], tags=["ray"]) +def heavy_computation(state: State, multiplier: int = 2) -> tuple: + """Simulate heavy computation that should run on Ray""" + print(f"[Ray Worker] Running heavy computation with multiplier={multiplier}") + time.sleep(0.5) # Simulate work + result = { + "count": state["count"] * multiplier, + "last_operation": f"heavy_computation(x{multiplier})", + } + return result, state.update(**result) + + +@action(reads=["count"], writes=["count", "last_operation"], tags=["ray"]) +def another_ray_task(state: State) -> tuple: + """Another task that runs on Ray""" + print("[Ray Worker] Running another Ray task") + time.sleep(0.3) # Simulate work + result = { + "count": state["count"] + 10, + "last_operation": "another_ray_task(+10)", + } 
+ return result, state.update(**result) + + +# Orchestrator hooks (run on main process) +class OrchestratorPreHook(PreRunStepHook): + """Hook that runs on the main process before action execution""" + + def pre_run_step(self, *, action: Action, state: State, inputs: Dict[str, Any], **kwargs): + print(f"[Main Process] About to execute action: {action.name}") + + +class OrchestratorPostHook(PostRunStepHook): + """Hook that runs on the main process after action execution""" + + def post_run_step( + self, + *, + action: Action, + state: State, + result: Optional[Dict[str, Any]], + exception: Exception, + **kwargs, + ): + print(f"[Main Process] Finished executing action: {action.name}") + + +# Worker hooks (run on Ray workers) +class WorkerPreHook(PreRunStepHookWorker): + """Hook that runs on the Ray worker before action execution""" + + def pre_run_step_worker( + self, *, action: Action, state: State, inputs: Dict[str, Any], **kwargs + ): + print(f"[Ray Worker] Starting action: {action.name} on Ray worker") + + +class WorkerPostHook(PostRunStepHookWorker): + """Hook that runs on the Ray worker after action execution""" + + def post_run_step_worker( + self, + *, + action: Action, + state: State, + result: Optional[Dict[str, Any]], + exception: Exception, + **kwargs, + ): + print(f"[Ray Worker] Completed action: {action.name} on Ray worker") + + +# Ray Execution Interceptor +class RayActionInterceptor(ActionExecutionInterceptorHook): + """Interceptor that executes actions tagged with 'ray' on Ray workers""" + + def __init__(self): + self.ray_initialized = False + + def _ensure_ray_initialized(self): + """Initialize Ray if not already initialized""" + if not self.ray_initialized: + if not ray.is_initialized(): + print("[Main Process] Initializing Ray...") + ray.init(ignore_reinit_error=True) + self.ray_initialized = True + + def should_intercept(self, *, action: Action, **kwargs) -> bool: + """Intercept actions tagged with 'ray'""" + return "ray" in action.tags + + def 
intercept_run( + self, *, action: Action, state: State, inputs: Dict[str, Any], **kwargs + ) -> dict: + """Execute the action on a Ray worker""" + self._ensure_ray_initialized() + + print(f"[Main Process] Dispatching {action.name} to Ray...") + + # Extract worker hooks + worker_adapter_set = kwargs.get("worker_adapter_set") + + # Create a Ray remote function that executes the action + @ray.remote + def execute_on_ray(): + """Execute action on Ray worker with worker hooks""" + # Call pre-worker hooks + if worker_adapter_set: + worker_adapter_set.call_all_lifecycle_hooks_sync( + "pre_run_step_worker", + action=action, + state=state, + inputs=inputs, + ) + + # Execute the action + if hasattr(action, "single_step") and action.single_step: + result, new_state = action.run_and_update(state, **inputs) + else: + state_to_use = state.subset(*action.reads) + result = action.run(state_to_use, **inputs) + new_state = None + + # Call post-worker hooks + if worker_adapter_set: + worker_adapter_set.call_all_lifecycle_hooks_sync( + "post_run_step_worker", + action=action, + state=state, + result=result, + exception=None, + ) + + return result, new_state + + # Execute remotely and wait for result + result_ref = execute_on_ray.remote() + result, new_state = ray.get(result_ref) + + print(f"[Main Process] Received result from Ray for {action.name}") + + # For single-step actions, include the new state + if new_state is not None: + result_with_state = result.copy() + result_with_state["__INTERCEPTOR_NEW_STATE__"] = new_state + return result_with_state + + return result + + +def main(): + """Run the example application""" + print("=" * 80) + print("Burr + Ray Remote Execution Example") + print("=" * 80) + print() + + # Create interceptor and hooks + ray_interceptor = RayActionInterceptor() + orchestrator_pre = OrchestratorPreHook() + orchestrator_post = OrchestratorPostHook() + worker_pre = WorkerPreHook() + worker_post = WorkerPostHook() + + # Build the application + app = ( + 
ApplicationBuilder() + .with_state(count=0) + .with_actions( + increment_local, + heavy_computation, + another_ray_task, + ) + .with_transitions( + ("increment_local", "heavy_computation"), + ("heavy_computation", "another_ray_task"), + ("another_ray_task", "increment_local"), + ) + .with_entrypoint("increment_local") + .with_hooks( + ray_interceptor, + orchestrator_pre, + orchestrator_post, + worker_pre, + worker_post, + ) + .build() + ) + + # Execute steps + print("\n" + "=" * 80) + print("Step 1: Local execution (increment_local)") + print("=" * 80) + action, result, state = app.step() + print(f"Result: count={state['count']}, operation={state['last_operation']}") + + print("\n" + "=" * 80) + print("Step 2: Ray execution (heavy_computation)") + print("=" * 80) + action, result, state = app.step(inputs={"multiplier": 3}) + print(f"Result: count={state['count']}, operation={state['last_operation']}") + + print("\n" + "=" * 80) + print("Step 3: Ray execution (another_ray_task)") + print("=" * 80) + action, result, state = app.step() + print(f"Result: count={state['count']}, operation={state['last_operation']}") + + print("\n" + "=" * 80) + print("Step 4: Back to local execution (increment_local)") + print("=" * 80) + action, result, state = app.step() + print(f"Result: count={state['count']}, operation={state['last_operation']}") + + print("\n" + "=" * 80) + print("Final State:") + print("=" * 80) + print(f"Count: {state['count']}") + print(f"Last Operation: {state['last_operation']}") + + # Shutdown Ray + if ray.is_initialized(): + print("\n[Main Process] Shutting down Ray...") + ray.shutdown() + + print("\n" + "=" * 80) + print("Example completed successfully!") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/examples/remote-execution-ray/notebook.ipynb b/examples/remote-execution-ray/notebook.ipynb new file mode 100644 index 00000000..fed2a1d0 --- /dev/null +++ b/examples/remote-execution-ray/notebook.ipynb @@ -0,0 +1,483 @@ +{ + "cells": 
[ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Remote Execution with Ray - Interactive Demo\n", + "\n", + "This notebook demonstrates how to use **Burr's Action Execution Interceptors** to run actions on Ray workers.\n", + "\n", + "## What You'll Learn\n", + "\n", + "1. How to create a Ray interceptor\n", + "2. How to define orchestrator vs. worker hooks\n", + "3. How to selectively run actions locally vs. remotely\n", + "4. How state flows between main process and Ray workers" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "from typing import Dict, Any, Optional\n", + "\n", + "import ray\n", + "from burr.core import Action, State, ApplicationBuilder, action\n", + "from burr.lifecycle import (\n", + " ActionExecutionInterceptorHook,\n", + " PreRunStepHookWorker,\n", + " PostRunStepHookWorker,\n", + " PreRunStepHook,\n", + " PostRunStepHook,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 1: Define Actions\n", + "\n", + "We'll create three actions:\n", + "- `increment_local` - runs locally (no `ray` tag)\n", + "- `heavy_computation` - runs on Ray (tagged with `ray`)\n", + "- `another_ray_task` - also runs on Ray" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@action(reads=[\"count\"], writes=[\"count\", \"last_operation\"], tags=[\"local\"])\n", + "def increment_local(state: State) -> tuple:\n", + " \"\"\"Increment counter locally (not on Ray)\"\"\"\n", + " result = {\n", + " \"count\": state[\"count\"] + 1,\n", + " \"last_operation\": \"increment_local\",\n", + " }\n", + " return result, state.update(**result)\n", + "\n", + "\n", + "@action(reads=[\"count\"], writes=[\"count\", \"last_operation\"], tags=[\"ray\"])\n", + "def heavy_computation(state: State, 
multiplier: int = 2) -> tuple:\n", + " \"\"\"Simulate heavy computation that should run on Ray\"\"\"\n", + " print(f\"🔧 [Ray Worker] Running heavy computation with multiplier={multiplier}\")\n", + " time.sleep(0.5) # Simulate work\n", + " result = {\n", + " \"count\": state[\"count\"] * multiplier,\n", + " \"last_operation\": f\"heavy_computation(x{multiplier})\",\n", + " }\n", + " return result, state.update(**result)\n", + "\n", + "\n", + "@action(reads=[\"count\"], writes=[\"count\", \"last_operation\"], tags=[\"ray\"])\n", + "def another_ray_task(state: State) -> tuple:\n", + " \"\"\"Another task that runs on Ray\"\"\"\n", + " print(\"🔧 [Ray Worker] Running another Ray task\")\n", + " time.sleep(0.3) # Simulate work\n", + " result = {\n", + " \"count\": state[\"count\"] + 10,\n", + " \"last_operation\": \"another_ray_task(+10)\",\n", + " }\n", + " return result, state.update(**result)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 2: Define Hooks\n", + "\n", + "We define two types of hooks:\n", + "1. **Orchestrator hooks** - run on the main process\n", + "2. 
**Worker hooks** - run on Ray workers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Orchestrator hooks (run on main process)\n", + "class OrchestratorPreHook(PreRunStepHook):\n", + " def pre_run_step(self, *, action: Action, state: State, inputs: Dict[str, Any], **kwargs):\n", + " print(f\"📋 [Main Process] About to execute: {action.name}\")\n", + "\n", + "\n", + "class OrchestratorPostHook(PostRunStepHook):\n", + " def post_run_step(\n", + " self, *, action: Action, state: State, result: Optional[Dict[str, Any]], exception, **kwargs\n", + " ):\n", + " print(f\"✅ [Main Process] Finished: {action.name}\")\n", + "\n", + "\n", + "# Worker hooks (run on Ray workers)\n", + "class WorkerPreHook(PreRunStepHookWorker):\n", + " def pre_run_step_worker(self, *, action: Action, state: State, inputs: Dict[str, Any], **kwargs):\n", + " print(f\"⚙️ [Ray Worker] Starting: {action.name}\")\n", + "\n", + "\n", + "class WorkerPostHook(PostRunStepHookWorker):\n", + " def post_run_step_worker(\n", + " self, *, action: Action, state: State, result: Optional[Dict[str, Any]], exception, **kwargs\n", + " ):\n", + " print(f\"✨ [Ray Worker] Completed: {action.name}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 3: Create the Ray Interceptor\n", + "\n", + "The interceptor decides which actions to run on Ray and handles the remote execution." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class RayActionInterceptor(ActionExecutionInterceptorHook):\n", + " \"\"\"Interceptor that executes actions tagged with 'ray' on Ray workers\"\"\"\n", + "\n", + " def __init__(self):\n", + " self.ray_initialized = False\n", + "\n", + " def _ensure_ray_initialized(self):\n", + " if not self.ray_initialized:\n", + " if not ray.is_initialized():\n", + " print(\"🚀 [Main Process] Initializing Ray...\")\n", + " ray.init(ignore_reinit_error=True)\n", + " self.ray_initialized = True\n", + "\n", + " def should_intercept(self, *, action: Action, **kwargs) -> bool:\n", + " \"\"\"Intercept actions tagged with 'ray'\"\"\"\n", + " return \"ray\" in action.tags\n", + "\n", + " def intercept_run(\n", + " self, *, action: Action, state: State, inputs: Dict[str, Any], **kwargs\n", + " ) -> dict:\n", + " \"\"\"Execute the action on a Ray worker\"\"\"\n", + " self._ensure_ray_initialized()\n", + "\n", + " print(f\"📤 [Main Process] Dispatching {action.name} to Ray...\")\n", + "\n", + " # Extract worker hooks\n", + " worker_adapter_set = kwargs.get(\"worker_adapter_set\")\n", + "\n", + " # Create a Ray remote function\n", + " @ray.remote\n", + " def execute_on_ray():\n", + " # Call pre-worker hooks\n", + " if worker_adapter_set:\n", + " worker_adapter_set.call_all_lifecycle_hooks_sync(\n", + " \"pre_run_step_worker\",\n", + " action=action,\n", + " state=state,\n", + " inputs=inputs,\n", + " )\n", + "\n", + " # Execute the action\n", + " if hasattr(action, \"single_step\") and action.single_step:\n", + " result, new_state = action.run_and_update(state, **inputs)\n", + " else:\n", + " state_to_use = state.subset(*action.reads)\n", + " result = action.run(state_to_use, **inputs)\n", + " new_state = None\n", + "\n", + " # Call post-worker hooks\n", + " if worker_adapter_set:\n", + " worker_adapter_set.call_all_lifecycle_hooks_sync(\n", + " \"post_run_step_worker\",\n", + " 
action=action,\n", + " state=state,\n", + " result=result,\n", + " exception=None,\n", + " )\n", + "\n", + " return result, new_state\n", + "\n", + " # Execute remotely and wait for result\n", + " result_ref = execute_on_ray.remote()\n", + " result, new_state = ray.get(result_ref)\n", + "\n", + " print(f\"📥 [Main Process] Received result from Ray for {action.name}\")\n", + "\n", + " # For single-step actions, include the new state\n", + " if new_state is not None:\n", + " result_with_state = result.copy()\n", + " result_with_state[\"__INTERCEPTOR_NEW_STATE__\"] = new_state\n", + " return result_with_state\n", + "\n", + " return result" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 4: Build the Application\n", + "\n", + "Now we put it all together!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create interceptor and hooks\n", + "ray_interceptor = RayActionInterceptor()\n", + "orchestrator_pre = OrchestratorPreHook()\n", + "orchestrator_post = OrchestratorPostHook()\n", + "worker_pre = WorkerPreHook()\n", + "worker_post = WorkerPostHook()\n", + "\n", + "# Build the application\n", + "app = (\n", + " ApplicationBuilder()\n", + " .with_state(count=0)\n", + " .with_actions(\n", + " increment_local,\n", + " heavy_computation,\n", + " another_ray_task,\n", + " )\n", + " .with_transitions(\n", + " (\"increment_local\", \"heavy_computation\"),\n", + " (\"heavy_computation\", \"another_ray_task\"),\n", + " (\"another_ray_task\", \"increment_local\"),\n", + " )\n", + " .with_entrypoint(\"increment_local\")\n", + " .with_hooks(\n", + " ray_interceptor,\n", + " orchestrator_pre,\n", + " orchestrator_post,\n", + " worker_pre,\n", + " worker_post,\n", + " )\n", + " .build()\n", + ")\n", + "\n", + "print(\"✨ Application built successfully!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 5: Execute Actions\n", + "\n", + "Let's run 
through the workflow step by step." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 5a: Local Execution" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"\\n\" + \"=\"*60)\n", + "print(\"STEP 1: Local Execution (increment_local)\")\n", + "print(\"=\"*60)\n", + "\n", + "action, result, state = app.step()\n", + "\n", + "print(f\"\\n📊 Result: count={state['count']}, operation={state['last_operation']}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice:\n", + "- ✅ Orchestrator hooks run\n", + "- ❌ Worker hooks DON'T run (action not intercepted)\n", + "- ❌ No Ray dispatch messages" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 5b: Ray Execution #1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"\\n\" + \"=\"*60)\n", + "print(\"STEP 2: Ray Execution (heavy_computation)\")\n", + "print(\"=\"*60)\n", + "\n", + "action, result, state = app.step(inputs={\"multiplier\": 3})\n", + "\n", + "print(f\"\\n📊 Result: count={state['count']}, operation={state['last_operation']}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice:\n", + "- ✅ Orchestrator hooks run (on main process)\n", + "- ✅ Worker hooks run (on Ray worker!)\n", + "- ✅ Ray dispatch and receive messages\n", + "- ✅ Actual computation happens on Ray worker" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 5c: Ray Execution #2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"\\n\" + \"=\"*60)\n", + "print(\"STEP 3: Ray Execution (another_ray_task)\")\n", + "print(\"=\"*60)\n", + "\n", + "action, result, state = app.step()\n", + "\n", + "print(f\"\\n📊 Result: count={state['count']}, operation={state['last_operation']}\")" + ] + 
},
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Step 5d: Back to Local"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(\"\\n\" + \"=\"*60)\n",
+    "print(\"STEP 4: Back to Local Execution (increment_local)\")\n",
+    "print(\"=\"*60)\n",
+    "\n",
+    "action, result, state = app.step()\n",
+    "\n",
+    "print(f\"\\n📊 Result: count={state['count']}, operation={state['last_operation']}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Step 6: View Final State"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(\"\\n\" + \"=\"*60)\n",
+    "print(\"FINAL STATE\")\n",
+    "print(\"=\"*60)\n",
+    "print(f\"Count: {state['count']}\")\n",
+    "print(f\"Last Operation: {state['last_operation']}\")\n",
+    "print(\"\\nWorkflow: 0 → +1 (local) → x3 (ray) → +10 (ray) → +1 (local) = 14\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Cleanup"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "if ray.is_initialized():\n",
+    "    print(\"🛑 Shutting down Ray...\")\n",
+    "    ray.shutdown()\n",
+    "    print(\"✅ Ray shutdown complete\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Key Takeaways\n",
+    "\n",
+    "1. **Selective Execution**: Actions can run locally or remotely based on tags\n",
+    "2. **Two-Tier Hooks**: Orchestrator hooks always run; worker hooks only run for intercepted actions\n",
+    "3. **Seamless Integration**: State flows naturally between main process and workers\n",
+    "4. **Transparent to Actions**: Actions don't know they're running on Ray\n",
+    "5. 
**Flexible**: Easy to add more actions or change execution backend\n", + "\n", + "## Next Steps\n", + "\n", + "Try modifying this notebook:\n", + "- Add your own actions with different tags\n", + "- Create a more complex workflow\n", + "- Add custom logging in the hooks\n", + "- Experiment with async actions\n", + "- Try with actual compute-intensive operations" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/remote-execution-ray/requirements.txt b/examples/remote-execution-ray/requirements.txt new file mode 100644 index 00000000..b0dbf87d --- /dev/null +++ b/examples/remote-execution-ray/requirements.txt @@ -0,0 +1,2 @@ +burr +ray>=2.0.0 diff --git a/tests/core/test_action_interceptor.py b/tests/core/test_action_interceptor.py new file mode 100644 index 00000000..97249b31 --- /dev/null +++ b/tests/core/test_action_interceptor.py @@ -0,0 +1,343 @@ +# Tests for action execution interceptor hooks +from typing import Any, Dict, Generator, Optional, Tuple + +import pytest + +from burr.core import Action, ApplicationBuilder, State, action +from burr.core.action import streaming_action +from burr.lifecycle import ( + ActionExecutionInterceptorHook, + PostRunStepHookWorker, + PreRunStepHookWorker, + StreamingActionInterceptorHook, +) + + +# Test actions +@action(reads=["x"], writes=["y"]) +def add_one(state: State) -> Tuple[dict, State]: + result = {"y": state["x"] + 1} + return result, state.update(**result) + + +@action(reads=["x"], writes=["z"], tags=["intercepted"]) +def multiply_by_two(state: State) -> Tuple[dict, State]: + result = {"z": state["x"] * 2} + return result, 
state.update(**result) + + +@streaming_action(reads=["prompt"], writes=["response"], tags=["streaming_intercepted"]) +def streaming_responder(state: State) -> Generator[Tuple[dict, Optional[State]], None, None]: + """Simple streaming action for testing""" + tokens = ["Hello", " ", "World", "!"] + buffer = [] + for token in tokens: + buffer.append(token) + yield {"response": token}, None + full_response = "".join(buffer) + yield {"response": full_response}, state.update(response=full_response) + + +# Mock interceptor that captures execution +class MockActionInterceptor(ActionExecutionInterceptorHook): + """Test interceptor that tracks which actions were intercepted""" + + def __init__(self): + self.intercepted_actions = [] + self.worker_hooks_called = [] + + def should_intercept(self, *, action: Action, **kwargs) -> bool: + # Intercept actions with the "intercepted" tag + return "intercepted" in action.tags + + def intercept_run( + self, *, action: Action, state: State, inputs: Dict[str, Any], **kwargs + ) -> dict: + self.intercepted_actions.append(action.name) + + # Extract worker_adapter_set if provided + worker_adapter_set = kwargs.get("worker_adapter_set") + + # Call worker pre-hooks if they exist + if worker_adapter_set: + worker_adapter_set.call_all_lifecycle_hooks_sync( + "pre_run_step_worker", + action=action, + state=state, + inputs=inputs, + ) + + # Simulate "remote" execution - check if it's a single-step action + # For single-step actions, we need to call run_and_update and handle both result and state + if hasattr(action, "single_step") and action.single_step: + # Store the new state in a special key that _run_single_step_action will extract + result, new_state = action.run_and_update(state, **inputs) + # Store state in result for extraction + result_with_state = result.copy() + result_with_state["__INTERCEPTOR_NEW_STATE__"] = new_state + result = result_with_state + else: + # For multi-step actions, call run + state_to_use = state.subset(*action.reads) 
+ action.validate_inputs(inputs) + result = action.run(state_to_use, **inputs) + + # Call worker post-hooks if they exist + if worker_adapter_set: + worker_adapter_set.call_all_lifecycle_hooks_sync( + "post_run_step_worker", + action=action, + state=state, + result=result, + exception=None, + ) + + return result + + +class MockStreamingInterceptor(StreamingActionInterceptorHook): + """Test interceptor for streaming actions""" + + def __init__(self): + self.intercepted_actions = [] + + def should_intercept(self, *, action: Action, **kwargs) -> bool: + return "streaming_intercepted" in action.tags + + def intercept_stream_run_and_update( + self, *, action: Action, state: State, inputs: Dict[str, Any], **kwargs + ): + self.intercepted_actions.append(action.name) + + # Extract worker_adapter_set if provided + worker_adapter_set = kwargs.get("worker_adapter_set") + + # Call worker pre-stream-hooks if they exist + if worker_adapter_set: + worker_adapter_set.call_all_lifecycle_hooks_sync( + "pre_start_stream_worker", + action=action.name, + state=state, + inputs=inputs, + ) + + # Run the streaming action normally (simulating remote execution) + generator = action.stream_run_and_update(state, **inputs) + result = None + for item in generator: + result = item + yield item + + # Call worker post-stream-hooks if they exist + if worker_adapter_set and result: + worker_adapter_set.call_all_lifecycle_hooks_sync( + "post_end_stream_worker", + action=action.name, + result=result[0] if result else None, + exception=None, + ) + + +class WorkerPreHook(PreRunStepHookWorker): + """Test worker hook that runs before action execution""" + + def __init__(self): + self.called_actions = [] + + def pre_run_step_worker( + self, *, action: Action, state: State, inputs: Dict[str, Any], **kwargs + ): + self.called_actions.append(("pre", action.name)) + + +class WorkerPostHook(PostRunStepHookWorker): + """Test worker hook that runs after action execution""" + + def __init__(self): + 
self.called_actions = [] + + def post_run_step_worker( + self, + *, + action: Action, + state: State, + result: Optional[Dict[str, Any]], + exception: Exception, + **kwargs, + ): + self.called_actions.append(("post", action.name)) + + +def test_interceptor_intercepts_tagged_action(): + """Test that interceptor only intercepts actions with specific tags""" + interceptor = MockActionInterceptor() + + app = ( + ApplicationBuilder() + .with_state(x=5) + .with_actions(add_one, multiply_by_two) + .with_transitions( + ("add_one", "multiply_by_two"), + ("multiply_by_two", "add_one"), + ) + .with_entrypoint("add_one") + .with_hooks(interceptor) + .build() + ) + + # Run add_one (not intercepted) + action, result, state = app.step() + assert action.name == "add_one" + assert state["y"] == 6 + assert "add_one" not in interceptor.intercepted_actions + + # Run multiply_by_two (intercepted) + action, result, state = app.step() + assert action.name == "multiply_by_two" + assert state["z"] == 10 # 5 * 2, using original x value + assert "multiply_by_two" in interceptor.intercepted_actions + + +def test_interceptor_calls_worker_hooks(): + """Test that interceptor properly calls worker hooks""" + interceptor = MockActionInterceptor() + worker_pre = WorkerPreHook() + worker_post = WorkerPostHook() + + app = ( + ApplicationBuilder() + .with_state(x=10) + .with_actions(multiply_by_two) + .with_entrypoint("multiply_by_two") + .with_hooks(interceptor, worker_pre, worker_post) + .build() + ) + + action, result, state = app.step() + assert action.name == "multiply_by_two" + assert state["z"] == 20 + + # Verify interceptor ran + assert "multiply_by_two" in interceptor.intercepted_actions + + # Verify worker hooks were called + assert ("pre", "multiply_by_two") in worker_pre.called_actions + assert ("post", "multiply_by_two") in worker_post.called_actions + + +def test_no_interceptor_normal_execution(): + """Test that actions run normally without interceptors""" + app = ( + 
ApplicationBuilder() + .with_state(x=3) + .with_actions(add_one, multiply_by_two) + .with_transitions( + ("add_one", "multiply_by_two"), + ) + .with_entrypoint("add_one") + .build() + ) + + # Both should run normally + action, result, state = app.step() + assert action.name == "add_one" + assert state["y"] == 4 + + action, result, state = app.step() + assert action.name == "multiply_by_two" + assert state["z"] == 6 # 3 * 2 + + +def test_streaming_action_interceptor(): + """Test interceptor for streaming actions""" + streaming_interceptor = MockStreamingInterceptor() + + app = ( + ApplicationBuilder() + .with_state(prompt="test") + .with_actions(streaming_responder) + .with_entrypoint("streaming_responder") + .with_hooks(streaming_interceptor) + .build() + ) + + # Run streaming action + action, streaming_container = app.stream_result( + halt_after=["streaming_responder"], + ) + + # Consume the stream + tokens = [] + for item in streaming_container: + tokens.append(item["response"]) + + result, final_state = streaming_container.get() + + # Verify interceptor ran + assert "streaming_responder" in streaming_interceptor.intercepted_actions + + # Verify streaming worked correctly + assert tokens == ["Hello", " ", "World", "!"] + assert final_state["response"] == "Hello World!" 
+ + +def test_multiple_interceptors_first_wins(): + """Test that when multiple interceptors match, the first one wins""" + + class FirstInterceptor(ActionExecutionInterceptorHook): + def __init__(self): + self.called = False + + def should_intercept(self, *, action: Action, **kwargs) -> bool: + return "intercepted" in action.tags + + def intercept_run( + self, *, action: Action, state: State, inputs: Dict[str, Any], **kwargs + ) -> dict: + self.called = True + # Return a custom result with state for single-step actions + result = {"z": 999} + if hasattr(action, "single_step") and action.single_step: + result["__INTERCEPTOR_NEW_STATE__"] = state.update(z=999) + return result + + class SecondInterceptor(ActionExecutionInterceptorHook): + def __init__(self): + self.called = False + + def should_intercept(self, *, action: Action, **kwargs) -> bool: + return "intercepted" in action.tags + + def intercept_run( + self, *, action: Action, state: State, inputs: Dict[str, Any], **kwargs + ) -> dict: + self.called = True + result = {"z": 777} + if hasattr(action, "single_step") and action.single_step: + result["__INTERCEPTOR_NEW_STATE__"] = state.update(z=777) + return result + + first = FirstInterceptor() + second = SecondInterceptor() + + app = ( + ApplicationBuilder() + .with_state(x=5) + .with_actions(multiply_by_two) + .with_entrypoint("multiply_by_two") + .with_hooks(first, second) # first is registered first + .build() + ) + + action, result, state = app.step() + + # First interceptor should have been called + assert first.called + assert state["z"] == 999 + + # Second interceptor should NOT have been called + assert not second.called + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])
