This is an automated email from the ASF dual-hosted git repository.

skrawcz pushed a commit to branch stefan/interceptor
in repository https://gitbox.apache.org/repos/asf/burr.git

commit a0f7d2e18165154dd102305019206c378b51b8fe
Author: Stefan Krawczyk <[email protected]>
AuthorDate: Wed Nov 19 22:04:50 2025 -0800

    WIP: run a burr action remotely
---
 burr/core/application.py                       | 148 +++++++-
 burr/lifecycle/__init__.py                     |  24 ++
 burr/lifecycle/base.py                         | 352 ++++++++++++++++++
 burr/lifecycle/internal.py                     |  76 +++-
 examples/remote-execution-ray/README.md        | 209 +++++++++++
 examples/remote-execution-ray/__init__.py      |   6 +
 examples/remote-execution-ray/application.py   | 268 ++++++++++++++
 examples/remote-execution-ray/notebook.ipynb   | 483 +++++++++++++++++++++++++
 examples/remote-execution-ray/requirements.txt |   2 +
 tests/core/test_action_interceptor.py          | 343 ++++++++++++++++++
 10 files changed, 1900 insertions(+), 11 deletions(-)

diff --git a/burr/core/application.py b/burr/core/application.py
index 55f98acf..b552df31 100644
--- a/burr/core/application.py
+++ b/burr/core/application.py
@@ -158,7 +158,13 @@ def _remap_dunder_parameters(
     return inputs
 
 
-def _run_function(function: Function, state: State, inputs: Dict[str, Any], 
name: str) -> dict:
+def _run_function(
+    function: Function,
+    state: State,
+    inputs: Dict[str, Any],
+    name: str,
+    adapter_set: Optional["LifecycleAdapterSet"] = None,
+) -> dict:
     """Runs a function, returning the result of running the function.
     Note this restricts the keys in the state to only those that the
     function reads.
@@ -166,6 +172,8 @@ def _run_function(function: Function, state: State, inputs: 
Dict[str, Any], name
     :param function: Function to run
     :param state: State at time of execution
     :param inputs: Inputs to the function
+    :param name: Name of the action (for error messages)
+    :param adapter_set: Optional lifecycle adapter set for checking 
interceptors
     :return:
     """
     if function.is_async():
@@ -174,6 +182,21 @@ def _run_function(function: Function, state: State, 
inputs: Dict[str, Any], name
             "in non-async context. Use astep()/aiterate()/arun() "
             "instead...)"
         )
+
+    # Check for execution interceptors
+    if adapter_set:
+        interceptor = adapter_set.get_first_matching_hook(
+            "intercept_action_execution", lambda hook: 
hook.should_intercept(action=function)
+        )
+        if interceptor:
+            worker_adapter_set = adapter_set.get_worker_adapter_set()
+            result = interceptor.intercept_run(
+                action=function, state=state, inputs=inputs, 
worker_adapter_set=worker_adapter_set
+            )
+            _validate_result(result, name)
+            return result
+
+    # Normal execution path
     state_to_use = state.subset(*function.reads)
     function.validate_inputs(inputs)
     if "__context" in inputs or "__tracer" in inputs:
@@ -185,10 +208,30 @@ def _run_function(function: Function, state: State, 
inputs: Dict[str, Any], name
 
 
 async def _arun_function(
-    function: Function, state: State, inputs: Dict[str, Any], name: str
+    function: Function,
+    state: State,
+    inputs: Dict[str, Any],
+    name: str,
+    adapter_set: Optional["LifecycleAdapterSet"] = None,
 ) -> dict:
     """Runs a function, returning the result of running the function.
     Async version of the above."""
+
+    # Check for execution interceptors
+    if adapter_set:
+        interceptor = adapter_set.get_first_matching_hook(
+            "intercept_action_execution",
+            lambda hook: hook.should_intercept(action=function) and 
hasattr(hook, "intercept_run"),
+        )
+        if interceptor and 
inspect.iscoroutinefunction(interceptor.intercept_run):
+            worker_adapter_set = adapter_set.get_worker_adapter_set()
+            result = await interceptor.intercept_run(
+                action=function, state=state, inputs=inputs, 
worker_adapter_set=worker_adapter_set
+            )
+            _validate_result(result, name)
+            return result
+
+    # Normal execution path
     state_to_use = state.subset(*function.reads)
     function.validate_inputs(inputs)
     result = await function.run(state_to_use, **inputs)
@@ -299,7 +342,10 @@ def _format_BASE_ERROR_MESSAGE(action: Action, 
input_state: State, inputs: dict)
 
 
 def _run_single_step_action(
-    action: SingleStepAction, state: State, inputs: Optional[Dict[str, Any]]
+    action: SingleStepAction,
+    state: State,
+    inputs: Optional[Dict[str, Any]],
+    adapter_set: Optional["LifecycleAdapterSet"] = None,
 ) -> Tuple[Dict[str, Any], State]:
     """Runs a single step action. This API is internal-facing and a bit in 
flux, but
     it corresponds to the SingleStepAction class.
@@ -307,9 +353,33 @@ def _run_single_step_action(
     :param action: Action to run
     :param state: State to run with
     :param inputs: Inputs to pass directly to the action
+    :param adapter_set: Optional lifecycle adapter set for checking 
interceptors
     :return: The result of running the action, and the new state
     """
-    # TODO -- guard all reads/writes with a subset of the state
+    # Check for execution interceptors
+    if adapter_set:
+        interceptor = adapter_set.get_first_matching_hook(
+            "intercept_action_execution", lambda hook: 
hook.should_intercept(action=action)
+        )
+        if interceptor:
+            worker_adapter_set = adapter_set.get_worker_adapter_set()
+            result = interceptor.intercept_run(
+                action=action, state=state, inputs=inputs, 
worker_adapter_set=worker_adapter_set
+            )
+            # Check if interceptor returned state via special key (for 
single-step actions)
+            if "__INTERCEPTOR_NEW_STATE__" in result:
+                new_state = result.pop("__INTERCEPTOR_NEW_STATE__")
+            else:
+                # For multi-step actions or if state wasn't provided
+                # we need to compute it
+                new_state = action.update(result, state)
+
+            _validate_result(result, action.name, action.schema)
+            out = result, _state_update(state, new_state)
+            _validate_reducer_writes(action, new_state, action.name)
+            return out
+
+    # Normal execution path
     action.validate_inputs(inputs)
     result, new_state = _adjust_single_step_output(
         action.run_and_update(state, **inputs), action.name, action.schema
@@ -334,7 +404,18 @@ def _run_single_step_streaming_action(
     action.validate_inputs(inputs)
     stream_initialize_time = system.now()
     first_stream_start_time = None
-    generator = action.stream_run_and_update(state, **inputs)
+
+    # Check for streaming action interceptors
+    interceptor = lifecycle_adapters.get_first_matching_hook(
+        "intercept_streaming_action", lambda hook: 
hook.should_intercept(action=action)
+    )
+    if interceptor:
+        worker_adapter_set = lifecycle_adapters.get_worker_adapter_set()
+        generator = interceptor.intercept_stream_run_and_update(
+            action=action, state=state, inputs=inputs, 
worker_adapter_set=worker_adapter_set
+        )
+    else:
+        generator = action.stream_run_and_update(state, **inputs)
     result = None
     state_update = None
     count = 0
@@ -387,7 +468,20 @@ async def _arun_single_step_streaming_action(
     action.validate_inputs(inputs)
     stream_initialize_time = system.now()
     first_stream_start_time = None
-    generator = action.stream_run_and_update(state, **inputs)
+
+    # Check for streaming action interceptors
+    interceptor = lifecycle_adapters.get_first_matching_hook(
+        "intercept_streaming_action",
+        lambda hook: hook.should_intercept(action=action)
+        and hasattr(hook, "intercept_stream_run_and_update"),
+    )
+    if interceptor and 
inspect.isasyncgenfunction(interceptor.intercept_stream_run_and_update):
+        worker_adapter_set = lifecycle_adapters.get_worker_adapter_set()
+        generator = interceptor.intercept_stream_run_and_update(
+            action=action, state=state, inputs=inputs, 
worker_adapter_set=worker_adapter_set
+        )
+    else:
+        generator = action.stream_run_and_update(state, **inputs)
     result = None
     state_update = None
     count = 0
@@ -523,9 +617,35 @@ async def _arun_multi_step_streaming_action(
 
 
 async def _arun_single_step_action(
-    action: SingleStepAction, state: State, inputs: Optional[Dict[str, Any]]
+    action: SingleStepAction,
+    state: State,
+    inputs: Optional[Dict[str, Any]],
+    adapter_set: Optional["LifecycleAdapterSet"] = None,
 ) -> Tuple[dict, State]:
     """Runs a single step action in async. See the synchronous version for 
more details."""
+    # Check for execution interceptors
+    if adapter_set:
+        interceptor = adapter_set.get_first_matching_hook(
+            "intercept_action_execution",
+            lambda hook: hook.should_intercept(action=action) and 
hasattr(hook, "intercept_run"),
+        )
+        if interceptor and 
inspect.iscoroutinefunction(interceptor.intercept_run):
+            worker_adapter_set = adapter_set.get_worker_adapter_set()
+            result = await interceptor.intercept_run(
+                action=action, state=state, inputs=inputs, 
worker_adapter_set=worker_adapter_set
+            )
+            # Check if interceptor returned state via special key (for 
single-step actions)
+            if "__INTERCEPTOR_NEW_STATE__" in result:
+                new_state = result.pop("__INTERCEPTOR_NEW_STATE__")
+            else:
+                # For multi-step actions or if state wasn't provided
+                new_state = action.update(result, state)
+
+            _validate_result(result, action.name, action.schema)
+            _validate_reducer_writes(action, new_state, action.name)
+            return result, _state_update(state, new_state)
+
+    # Normal execution path
     state_to_use = state
     action.validate_inputs(inputs)
     result, new_state = _adjust_single_step_output(
@@ -915,11 +1035,15 @@ class Application(Generic[ApplicationStateType]):
             try:
                 if next_action.single_step:
                     result, new_state = _run_single_step_action(
-                        next_action, self._state, action_inputs
+                        next_action, self._state, action_inputs, 
adapter_set=self._adapter_set
                     )
                 else:
                     result = _run_function(
-                        next_action, self._state, action_inputs, 
name=next_action.name
+                        next_action,
+                        self._state,
+                        action_inputs,
+                        name=next_action.name,
+                        adapter_set=self._adapter_set,
                     )
                     new_state = _run_reducer(next_action, self._state, result, 
next_action.name)
 
@@ -1065,7 +1189,10 @@ class Application(Generic[ApplicationStateType]):
                 action_inputs = self._process_inputs(inputs, next_action)
                 if next_action.single_step:
                     result, new_state = await _arun_single_step_action(
-                        next_action, self._state, inputs=action_inputs
+                        next_action,
+                        self._state,
+                        inputs=action_inputs,
+                        adapter_set=self._adapter_set,
                     )
                 else:
                     result = await _arun_function(
@@ -1073,6 +1200,7 @@ class Application(Generic[ApplicationStateType]):
                         self._state,
                         inputs=action_inputs,
                         name=next_action.name,
+                        adapter_set=self._adapter_set,
                     )
                     new_state = _run_reducer(next_action, self._state, result, 
next_action.name)
                 new_state = self._update_internal_state_value(new_state, 
next_action)
diff --git a/burr/lifecycle/__init__.py b/burr/lifecycle/__init__.py
index 4ae24073..991ee984 100644
--- a/burr/lifecycle/__init__.py
+++ b/burr/lifecycle/__init__.py
@@ -16,18 +16,30 @@
 # under the License.
 
 from burr.lifecycle.base import (
+    ActionExecutionInterceptorHook,
+    ActionExecutionInterceptorHookAsync,
     LifecycleAdapter,
     PostApplicationCreateHook,
     PostApplicationExecuteCallHook,
     PostApplicationExecuteCallHookAsync,
     PostEndSpanHook,
+    PostEndStreamHookWorker,
+    PostEndStreamHookWorkerAsync,
     PostRunStepHook,
     PostRunStepHookAsync,
+    PostRunStepHookWorker,
+    PostRunStepHookWorkerAsync,
     PreApplicationExecuteCallHook,
     PreApplicationExecuteCallHookAsync,
     PreRunStepHook,
     PreRunStepHookAsync,
+    PreRunStepHookWorker,
+    PreRunStepHookWorkerAsync,
     PreStartSpanHook,
+    PreStartStreamHookWorker,
+    PreStartStreamHookWorkerAsync,
+    StreamingActionInterceptorHook,
+    StreamingActionInterceptorHookAsync,
 )
 from burr.lifecycle.default import StateAndResultsFullLogger
 
@@ -45,4 +57,16 @@ __all__ = [
     "PostApplicationCreateHook",
     "PostEndSpanHook",
     "PreStartSpanHook",
+    "PreRunStepHookWorker",
+    "PreRunStepHookWorkerAsync",
+    "PostRunStepHookWorker",
+    "PostRunStepHookWorkerAsync",
+    "PreStartStreamHookWorker",
+    "PreStartStreamHookWorkerAsync",
+    "PostEndStreamHookWorker",
+    "PostEndStreamHookWorkerAsync",
+    "ActionExecutionInterceptorHook",
+    "ActionExecutionInterceptorHookAsync",
+    "StreamingActionInterceptorHook",
+    "StreamingActionInterceptorHookAsync",
 ]
diff --git a/burr/lifecycle/base.py b/burr/lifecycle/base.py
index 66d8bd7e..2117a7f5 100644
--- a/burr/lifecycle/base.py
+++ b/burr/lifecycle/base.py
@@ -492,6 +492,346 @@ class PostEndStreamHookAsync(abc.ABC):
         pass
 
 
[email protected]_hook("pre_run_step_worker")
+class PreRunStepHookWorker(abc.ABC):
+    """Hook that runs on the worker (e.g., Ray/Temporal) before action 
execution.
+    This hook is designed to be called by execution interceptors on remote 
workers,
+    as opposed to PreRunStepHook which always runs on the main orchestrator 
process."""
+
+    @abc.abstractmethod
+    def pre_run_step_worker(
+        self,
+        *,
+        action: "Action",
+        state: "State",
+        inputs: Dict[str, Any],
+        **future_kwargs: Any,
+    ):
+        """Run before a step is executed on the worker.
+
+        :param action: Action to be executed
+        :param state: State prior to step execution
+        :param inputs: Inputs to the action
+        :param future_kwargs: Future keyword arguments
+        """
+        pass
+
+
[email protected]_hook("pre_run_step_worker")
+class PreRunStepHookWorkerAsync(abc.ABC):
+    """Async hook that runs on the worker before action execution."""
+
+    @abc.abstractmethod
+    async def pre_run_step_worker(
+        self,
+        *,
+        action: "Action",
+        state: "State",
+        inputs: Dict[str, Any],
+        **future_kwargs: Any,
+    ):
+        """Async run before a step is executed on the worker.
+
+        :param action: Action to be executed
+        :param state: State prior to step execution
+        :param inputs: Inputs to the action
+        :param future_kwargs: Future keyword arguments
+        """
+        pass
+
+
[email protected]_hook("post_run_step_worker")
+class PostRunStepHookWorker(abc.ABC):
+    """Hook that runs on the worker after action execution.
+    This hook is designed to be called by execution interceptors on remote 
workers,
+    as opposed to PostRunStepHook which always runs on the main orchestrator 
process."""
+
+    @abc.abstractmethod
+    def post_run_step_worker(
+        self,
+        *,
+        action: "Action",
+        state: "State",
+        result: Optional[Dict[str, Any]],
+        exception: Exception,
+        **future_kwargs: Any,
+    ):
+        """Run after a step is executed on the worker.
+
+        :param action: Action that was executed
+        :param state: State after step execution
+        :param result: Result of the action
+        :param exception: Exception that was raised
+        :param future_kwargs: Future keyword arguments
+        """
+        pass
+
+
[email protected]_hook("post_run_step_worker")
+class PostRunStepHookWorkerAsync(abc.ABC):
+    """Async hook that runs on the worker after action execution."""
+
+    @abc.abstractmethod
+    async def post_run_step_worker(
+        self,
+        *,
+        action: "Action",
+        state: "State",
+        result: Optional[dict],
+        exception: Exception,
+        **future_kwargs: Any,
+    ):
+        """Async run after a step is executed on the worker.
+
+        :param action: Action that was executed
+        :param state: State after step execution
+        :param result: Result of the action
+        :param exception: Exception that was raised
+        :param future_kwargs: Future keyword arguments
+        """
+        pass
+
+
[email protected]_hook("pre_start_stream_worker")
+class PreStartStreamHookWorker(abc.ABC):
+    """Hook that runs on the worker after a stream is started."""
+
+    @abc.abstractmethod
+    def pre_start_stream_worker(
+        self,
+        *,
+        action: str,
+        state: "State",
+        inputs: Dict[str, Any],
+        **future_kwargs: Any,
+    ):
+        pass
+
+
[email protected]_hook("pre_start_stream_worker")
+class PreStartStreamHookWorkerAsync(abc.ABC):
+    """Async hook that runs on the worker after a stream is started."""
+
+    @abc.abstractmethod
+    async def pre_start_stream_worker(
+        self,
+        *,
+        action: str,
+        state: "State",
+        inputs: Dict[str, Any],
+        **future_kwargs: Any,
+    ):
+        pass
+
+
[email protected]_hook("post_end_stream_worker")
+class PostEndStreamHookWorker(abc.ABC):
+    """Hook that runs on the worker after a stream is ended."""
+
+    @abc.abstractmethod
+    def post_end_stream_worker(
+        self,
+        *,
+        action: str,
+        result: Optional[Dict[str, Any]],
+        exception: Exception,
+        **future_kwargs: Any,
+    ):
+        pass
+
+
[email protected]_hook("post_end_stream_worker")
+class PostEndStreamHookWorkerAsync(abc.ABC):
+    """Async hook that runs on the worker after a stream is ended."""
+
+    @abc.abstractmethod
+    async def post_end_stream_worker(
+        self,
+        *,
+        action: str,
+        result: Optional[Dict[str, Any]],
+        exception: Exception,
+        **future_kwargs: Any,
+    ):
+        pass
+
+
+class ActionExecutionInterceptorHook(abc.ABC):
+    """Hook that can wrap/replace action execution (e.g., for Ray/Temporal).
+    This hook allows you to intercept the execution of an action and run it
+    on a different execution backend while maintaining the same interface.
+
+    The interceptor receives a worker_adapter_set containing only worker hooks
+    (PreRunStepHookWorker, PostRunStepHookWorker, etc.) that can be called
+    on the remote execution environment.
+
+    Note: Interceptors don't use the @lifecycle.base_hook decorator because 
they
+    have multiple methods and special handling logic."""
+
+    @abc.abstractmethod
+    def should_intercept(
+        self,
+        *,
+        action: "Action",
+        **future_kwargs: Any,
+    ) -> bool:
+        """Determine if this action should be intercepted.
+
+        :param action: Action to potentially intercept
+        :param future_kwargs: Future keyword arguments
+        :return: True if this hook should intercept execution
+        """
+        pass
+
+    @abc.abstractmethod
+    def intercept_run(
+        self,
+        *,
+        action: "Action",
+        state: "State",
+        inputs: Dict[str, Any],
+        **future_kwargs: Any,
+    ) -> dict:
+        """Replace the action.run() call with custom execution.
+
+        Note: The state passed here is the FULL state, not subsetted.
+        You are responsible for subsetting it to action.reads if needed.
+
+        :param action: Action to execute
+        :param state: Current state (FULL state, not subsetted)
+        :param inputs: Inputs to the action
+        :param future_kwargs: Future keyword arguments (includes 
worker_adapter_set)
+        :return: Result dictionary from running the action
+        """
+        pass
+
+
+class ActionExecutionInterceptorHookAsync(abc.ABC):
+    """Async version of ActionExecutionInterceptorHook for intercepting async 
actions."""
+
+    @abc.abstractmethod
+    async def should_intercept(
+        self,
+        *,
+        action: "Action",
+        **future_kwargs: Any,
+    ) -> bool:
+        """Determine if this action should be intercepted.
+
+        :param action: Action to potentially intercept
+        :param future_kwargs: Future keyword arguments
+        :return: True if this hook should intercept execution
+        """
+        pass
+
+    @abc.abstractmethod
+    async def intercept_run(
+        self,
+        *,
+        action: "Action",
+        state: "State",
+        inputs: Dict[str, Any],
+        **future_kwargs: Any,
+    ) -> dict:
+        """Replace the action.run() call with custom execution.
+
+        Note: The state passed here is the FULL state, not subsetted.
+        You are responsible for subsetting it to action.reads if needed.
+
+        :param action: Action to execute
+        :param state: Current state (FULL state, not subsetted)
+        :param inputs: Inputs to the action
+        :param future_kwargs: Future keyword arguments (includes 
worker_adapter_set)
+        :return: Result dictionary from running the action
+        """
+        pass
+
+
+class StreamingActionInterceptorHook(abc.ABC):
+    """Hook to intercept streaming action execution (e.g., for Ray/Temporal).
+    This hook allows you to wrap streaming actions to execute on different 
backends.
+
+    The interceptor receives a worker_adapter_set containing only worker hooks
+    that can be called on the remote execution environment.
+
+    Note: Interceptors don't use the @lifecycle.base_hook decorator because 
they
+    have multiple methods and special handling logic."""
+
+    @abc.abstractmethod
+    def should_intercept(
+        self,
+        *,
+        action: "Action",
+        **future_kwargs: Any,
+    ) -> bool:
+        """Determine if this streaming action should be intercepted.
+
+        :param action: Streaming action to potentially intercept
+        :param future_kwargs: Future keyword arguments
+        :return: True if this hook should intercept execution
+        """
+        pass
+
+    @abc.abstractmethod
+    def intercept_stream_run_and_update(
+        self,
+        *,
+        action: "Action",
+        state: "State",
+        inputs: Dict[str, Any],
+        **future_kwargs: Any,
+    ):
+        """Replace stream_run_and_update with custom execution.
+        Must be a generator that yields (result_dict, optional_state) tuples.
+
+        :param action: Streaming action to execute
+        :param state: Current state
+        :param inputs: Inputs to the action
+        :param future_kwargs: Future keyword arguments (includes 
worker_adapter_set)
+        :return: Generator yielding (dict, Optional[State]) tuples
+        """
+        pass
+
+
+class StreamingActionInterceptorHookAsync(abc.ABC):
+    """Async version for intercepting async streaming actions."""
+
+    @abc.abstractmethod
+    async def should_intercept(
+        self,
+        *,
+        action: "Action",
+        **future_kwargs: Any,
+    ) -> bool:
+        """Determine if this streaming action should be intercepted.
+
+        :param action: Streaming action to potentially intercept
+        :param future_kwargs: Future keyword arguments
+        :return: True if this hook should intercept execution
+        """
+        pass
+
+    @abc.abstractmethod
+    def intercept_stream_run_and_update(
+        self,
+        *,
+        action: "Action",
+        state: "State",
+        inputs: Dict[str, Any],
+        **future_kwargs: Any,
+    ):
+        """Replace stream_run_and_update with custom execution.
+        Must be an async generator that yields (result_dict, optional_state) 
tuples.
+
+        :param action: Streaming action to execute
+        :param state: Current state
+        :param inputs: Inputs to the action
+        :param future_kwargs: Future keyword arguments (includes 
worker_adapter_set)
+        :return: Async generator yielding (dict, Optional[State]) tuples
+        """
+        pass
+
+
 # strictly for typing -- this conflicts a bit with the lifecycle decorator 
above, but its fine for now
 # This makes IDE completion/type-hinting easier
 LifecycleAdapter = Union[
@@ -515,4 +855,16 @@ LifecycleAdapter = Union[
     PreStartStreamHookAsync,
     PostStreamItemHookAsync,
     PostEndStreamHookAsync,
+    PreRunStepHookWorker,
+    PreRunStepHookWorkerAsync,
+    PostRunStepHookWorker,
+    PostRunStepHookWorkerAsync,
+    PreStartStreamHookWorker,
+    PreStartStreamHookWorkerAsync,
+    PostEndStreamHookWorker,
+    PostEndStreamHookWorkerAsync,
+    ActionExecutionInterceptorHook,
+    ActionExecutionInterceptorHookAsync,
+    StreamingActionInterceptorHook,
+    StreamingActionInterceptorHookAsync,
 ]
diff --git a/burr/lifecycle/internal.py b/burr/lifecycle/internal.py
index 1043bd0a..1c7eb2df 100644
--- a/burr/lifecycle/internal.py
+++ b/burr/lifecycle/internal.py
@@ -119,7 +119,15 @@ class LifecycleAdapterSet:
         :param adapters: Adapters to group together
         """
         self._adapters = list(adapters)
-        self.sync_hooks, self.async_hooks = self._get_lifecycle_hooks()
+        self._sync_hooks, self._async_hooks = self._get_lifecycle_hooks()
+
+    @property
+    def sync_hooks(self):
+        return self._sync_hooks
+
+    @property
+    def async_hooks(self):
+        return self._async_hooks
 
     def with_new_adapters(self, *adapters: "LifecycleAdapter") -> 
"LifecycleAdapterSet":
         """Adds new adapters to the set.
@@ -212,3 +220,69 @@ class LifecycleAdapterSet:
         :return: A list of adapters
         """
         return self._adapters
+
+    def get_first_matching_hook(
+        self, hook_name: str, predicate: Callable[["LifecycleAdapter"], bool]
+    ):
+        """Get first hook of given type that matches predicate.
+
+        For interceptor hooks (intercept_action_execution, 
intercept_streaming_action),
+        this searches all adapters for instances of the interceptor classes, 
since
+        interceptors don't use the standard hook registration system.
+
+        :param hook_name: Name of the hook to search for (or interceptor type)
+        :param predicate: Function that takes a hook and returns True if it 
matches
+        :return: The first matching hook, or None if no match found
+        """
+        # Special handling for interceptors - they don't use the registration 
system
+        if hook_name in ("intercept_action_execution", 
"intercept_streaming_action"):
+            # Import here to avoid circular dependency
+            from burr.lifecycle.base import (
+                ActionExecutionInterceptorHook,
+                ActionExecutionInterceptorHookAsync,
+                StreamingActionInterceptorHook,
+                StreamingActionInterceptorHookAsync,
+            )
+
+            interceptor_classes = (
+                ActionExecutionInterceptorHook,
+                ActionExecutionInterceptorHookAsync,
+                StreamingActionInterceptorHook,
+                StreamingActionInterceptorHookAsync,
+            )
+
+            for adapter in self.adapters:
+                if isinstance(adapter, interceptor_classes):
+                    if predicate(adapter):
+                        return adapter
+            return None
+
+        # Standard hook lookup for registered hooks
+        hooks = self.sync_hooks.get(hook_name, []) + 
self.async_hooks.get(hook_name, [])
+        for hook in hooks:
+            if predicate(hook):
+                return hook
+        return None
+
+    def get_worker_adapter_set(self) -> "LifecycleAdapterSet":
+        """Create a new LifecycleAdapterSet containing only worker hooks.
+        Worker hooks are those with names ending in '_worker' and are designed
+        to be called on remote execution environments (Ray/Temporal workers).
+
+        :return: A new LifecycleAdapterSet with only worker hooks
+        """
+        worker_hooks = []
+        for adapter in self.adapters:
+            # Check if this adapter is a worker hook by looking at its 
registered hooks
+            is_worker = False
+            for cls in inspect.getmro(adapter.__class__):
+                sync_hook = getattr(cls, SYNC_HOOK, None)
+                async_hook = getattr(cls, ASYNC_HOOK, None)
+                if (sync_hook and sync_hook.endswith("_worker")) or (
+                    async_hook and async_hook.endswith("_worker")
+                ):
+                    is_worker = True
+                    break
+            if is_worker:
+                worker_hooks.append(adapter)
+        return LifecycleAdapterSet(*worker_hooks)
diff --git a/examples/remote-execution-ray/README.md 
b/examples/remote-execution-ray/README.md
new file mode 100644
index 00000000..c855aee8
--- /dev/null
+++ b/examples/remote-execution-ray/README.md
@@ -0,0 +1,209 @@
+# Remote Execution with Ray
+
+This example demonstrates how to use Burr's **Action Execution Interceptors** 
to run specific actions on Ray workers while keeping orchestration on the main 
process.
+
+## Overview
+
+Burr's lifecycle hook system includes **interceptors** that can wrap action 
execution and redirect it to different execution backends like Ray, Temporal, 
or custom distributed systems.
+
+This example shows:
+- ✅ Selective interception (only actions tagged with `ray` run remotely)
+- ✅ Orchestrator hooks (run on main process)
+- ✅ Worker hooks (run on Ray workers)
+- ✅ Seamless mixing of local and remote execution
+- ✅ State management across distributed execution
+
+## Architecture
+
+```
+┌─────────────────────────────────────────────────────────────┐
+│  Main Process (Orchestrator)                                 │
+│                                                               │
+│  ┌──────────────────────────────────────────────┐            │
+│  │  Burr Application                             │            │
+│  │                                               │            │
+│  │  PreRunStepHook (Orchestrator) ────┐         │            │
+│  │                                     ↓         │            │
+│  │  RayActionInterceptor               │         │            │
+│  │    - should_intercept()             │         │            │
+│  │    - intercept_run() ───────────────┼─────────┼─────┐     │
+│  │                                     │         │      │     │
+│  │  PostRunStepHook (Orchestrator) ←──┘         │      │     │
+│  └──────────────────────────────────────────────┘      │     │
+└────────────────────────────────────────────────────────┼─────┘
+                                                          │
+                                         Ray Remote Call  │
+                                                          ↓
+┌─────────────────────────────────────────────────────────────┐
+│  Ray Worker                                                  │
+│                                                              │
+│  PreRunStepHookWorker  ────┐                                │
+│                            ↓                                │
+│  Action.run_and_update()  (actual execution)                │
+│                            │                                │
+│  PostRunStepHookWorker ←───┘                                │
+│                                                              │
+└─────────────────────────────────────────────────────────────┘
+```
+
+## Key Concepts
+
+### 1. Two-Tier Hook System
+
+**Orchestrator Hooks** (run on main process):
+- `PreRunStepHook` - runs before any action (local or remote)
+- `PostRunStepHook` - runs after any action completes
+- These hooks see all actions but don't know about execution details
+
+**Worker Hooks** (run on Ray workers):
+- `PreRunStepHookWorker` - runs on the worker before execution
+- `PostRunStepHookWorker` - runs on the worker after execution
+- Only called for intercepted actions
+- Must be serializable (picklable)
+
+### 2. Action Execution Interceptor
+
+The interceptor has two methods:
+
+```python
+def should_intercept(self, *, action: Action, **kwargs) -> bool:
+    """Decide if this action should be intercepted"""
+    return "ray" in action.tags
+
+def intercept_run(self, *, action: Action, state: State, inputs: Dict[str, 
Any], **kwargs) -> dict:
+    """Execute the action on Ray and return the result"""
+    # Get worker hooks to pass to Ray worker
+    worker_adapter_set = kwargs.get("worker_adapter_set")
+
+    # Execute on Ray with worker hooks
+    @ray.remote
+    def execute_on_ray():
+        # Call worker hooks
+        # Execute action
+        # Return result
+
+    return ray.get(execute_on_ray.remote())
+```
+
+### 3. Selective Execution
+
+Actions are tagged to control where they run:
+
+```python
+@action(reads=["count"], writes=["count"], tags=["local"])
+def local_task(state: State):
+    # Runs on main process
+    ...
+
+@action(reads=["count"], writes=["count"], tags=["ray"])
+def remote_task(state: State):
+    # Runs on Ray worker
+    ...
+```
+
+## Installation
+
+```bash
+pip install -r requirements.txt
+```
+
+## Running the Example
+
+### Python Script
+
+```bash
+python application.py
+```
+
+Expected output:
+```
+================================================================================
+Burr + Ray Remote Execution Example
+================================================================================
+
+[Main Process] Initializing Ray...
+
+================================================================================
+Step 1: Local execution (increment_local)
+================================================================================
+[Main Process] About to execute action: increment_local
+[Main Process] Finished executing action: increment_local
+Result: count=1, operation=increment_local
+
+================================================================================
+Step 2: Ray execution (heavy_computation)
+================================================================================
+[Main Process] About to execute action: heavy_computation
+[Main Process] Dispatching heavy_computation to Ray...
+[Ray Worker] Starting action: heavy_computation on Ray worker
+[Ray Worker] Running heavy computation with multiplier=3
+[Ray Worker] Completed action: heavy_computation on Ray worker
+[Main Process] Received result from Ray for heavy_computation
+[Main Process] Finished executing action: heavy_computation
+Result: count=3, operation=heavy_computation(x3)
+
+...
+```
+
+### Jupyter Notebook
+
+```bash
+jupyter notebook notebook.ipynb
+```
+
+## Use Cases
+
+This pattern is useful for:
+
+1. **Compute-Intensive Operations**: Offload heavy computations to Ray clusters
+2. **GPU Workloads**: Run ML inference/training on GPU workers
+3. **Scalability**: Distribute work across multiple machines
+4. **Resource Isolation**: Keep heavy operations away from orchestrator
+5. **Hybrid Workflows**: Mix local control flow with distributed execution
+
+## Extending to Other Backends
+
+The same pattern works for other execution backends:
+
+### Temporal
+
+```python
+class TemporalActionInterceptor(ActionExecutionInterceptorHookAsync):
+    def should_intercept(self, *, action, **kwargs):
+        return "temporal" in action.tags
+
+    async def intercept_run(self, *, action, state, inputs, **kwargs):
+        # Execute as a Temporal activity (the async hook variant is required
+        # because we await inside the method)
+        return await workflow.execute_activity(
+            action.run_and_update,
+            state,
+            **inputs
+        )
+```
+
+### Custom Distributed System
+
+```python
+class CustomBackendInterceptor(ActionExecutionInterceptorHook):
+    def should_intercept(self, *, action, **kwargs):
+        return "distributed" in action.tags
+
+    def intercept_run(self, *, action, state, inputs, **kwargs):
+        # Submit to your custom backend
+        job_id = backend.submit_job(action, state, inputs)
+        result = backend.wait_for_completion(job_id)
+        return result
+```
+
+## Important Notes
+
+1. **State Serialization**: State must be serializable to pass to workers
+2. **Worker Hooks**: Must be picklable (avoid closures with local variables)
+3. **Error Handling**: Exceptions on workers propagate back to orchestrator
+4. **Performance**: Ray overhead ~100ms per task; use for tasks >1s
+
+## Related Documentation
+
+- [Burr Lifecycle Hooks](https://burr.dagworks.io/concepts/hooks/)
+- [Ray Core API](https://docs.ray.io/en/latest/ray-core/walkthrough.html)
+- [Temporal Workflows](https://docs.temporal.io/)
diff --git a/examples/remote-execution-ray/__init__.py 
b/examples/remote-execution-ray/__init__.py
new file mode 100644
index 00000000..cd242bc2
--- /dev/null
+++ b/examples/remote-execution-ray/__init__.py
@@ -0,0 +1,6 @@
+"""
+Remote Execution with Ray
+
+This example demonstrates how to use Burr's Action Execution Interceptors
+to run actions on Ray workers while maintaining orchestration on the main 
process.
+"""
diff --git a/examples/remote-execution-ray/application.py 
b/examples/remote-execution-ray/application.py
new file mode 100644
index 00000000..ad44291a
--- /dev/null
+++ b/examples/remote-execution-ray/application.py
@@ -0,0 +1,268 @@
+"""
+Example demonstrating how to use Burr's action execution interceptors to run
+actions remotely on Ray workers.
+
+This example shows:
+1. How to create a RayActionInterceptor to execute actions on Ray
+2. How worker hooks run on the remote Ray worker
+3. How to mix local and remote execution based on action tags
+"""
+
+import time
+from typing import Any, Dict, Optional
+
+import ray
+
+from burr.core import Action, ApplicationBuilder, State, action
+from burr.lifecycle import (
+    ActionExecutionInterceptorHook,
+    PostRunStepHook,
+    PostRunStepHookWorker,
+    PreRunStepHook,
+    PreRunStepHookWorker,
+)
+
+
+# Define some example actions
@action(reads=["count"], writes=["count", "last_operation"], tags=["local"])
def increment_local(state: State) -> tuple:
    """Increment the counter by one on the main process (no `ray` tag)."""
    updates = {"count": state["count"] + 1, "last_operation": "increment_local"}
    return updates, state.update(**updates)
+
+
@action(reads=["count"], writes=["count", "last_operation"], tags=["ray"])
def heavy_computation(state: State, multiplier: int = 2) -> tuple:
    """Simulate a heavy computation; tagged `ray` so it is dispatched remotely."""
    print(f"[Ray Worker] Running heavy computation with multiplier={multiplier}")
    time.sleep(0.5)  # Simulate work
    updates = {
        "count": state["count"] * multiplier,
        "last_operation": f"heavy_computation(x{multiplier})",
    }
    return updates, state.update(**updates)
+
+
@action(reads=["count"], writes=["count", "last_operation"], tags=["ray"])
def another_ray_task(state: State) -> tuple:
    """A second remotely-executed task; also tagged `ray`."""
    print("[Ray Worker] Running another Ray task")
    time.sleep(0.3)  # Simulate work
    updates = {
        "count": state["count"] + 10,
        "last_operation": "another_ray_task(+10)",
    }
    return updates, state.update(**updates)
+
+
+# Orchestrator hooks (run on main process)
class OrchestratorPreHook(PreRunStepHook):
    """Hook that runs on the main process before action execution"""

    def pre_run_step(self, *, action: Action, state: State, inputs: Dict[str, Any], **kwargs):
        # Fires for every action, local or intercepted, before it executes;
        # it does not know where the action will actually run.
        print(f"[Main Process] About to execute action: {action.name}")
+
+
class OrchestratorPostHook(PostRunStepHook):
    """Hook that runs on the main process after action execution"""

    def post_run_step(
        self,
        *,
        action: Action,
        state: State,
        result: Optional[Dict[str, Any]],
        # None on success; the raised exception otherwise.
        exception: Optional[Exception],
        **kwargs,
    ):
        # Fires after every action completes, whether it ran locally or was
        # intercepted and executed on a Ray worker.
        print(f"[Main Process] Finished executing action: {action.name}")
+
+
+# Worker hooks (run on Ray workers)
# Worker hooks (run on Ray workers)
class WorkerPreHook(PreRunStepHookWorker):
    """Hook that runs on the Ray worker before action execution"""

    def pre_run_step_worker(
        self, *, action: Action, state: State, inputs: Dict[str, Any], **kwargs
    ):
        # Executed on the remote worker (only for intercepted actions); the
        # hook object is shipped to the worker, so it must be picklable.
        print(f"[Ray Worker] Starting action: {action.name} on Ray worker")
+
+
class WorkerPostHook(PostRunStepHookWorker):
    """Hook that runs on the Ray worker after action execution"""

    def post_run_step_worker(
        self,
        *,
        action: Action,
        state: State,
        result: Optional[Dict[str, Any]],
        # None on success; the raised exception otherwise.
        exception: Optional[Exception],
        **kwargs,
    ):
        # Executed on the remote worker after the action body finishes
        # (only for intercepted actions); must be picklable.
        print(f"[Ray Worker] Completed action: {action.name} on Ray worker")
+
+
+# Ray Execution Interceptor
class RayActionInterceptor(ActionExecutionInterceptorHook):
    """Interceptor that executes actions tagged with 'ray' on Ray workers.

    The orchestrator calls :meth:`should_intercept` for each action; when it
    returns True, :meth:`intercept_run` is invoked instead of running the
    action locally, and the framework passes the worker-side hooks in via
    ``worker_adapter_set``.
    """

    def __init__(self):
        # Lazily initialize Ray on first intercepted action so that building
        # the application does not require a running Ray cluster.
        self.ray_initialized = False

    def _ensure_ray_initialized(self):
        """Initialize Ray exactly once (no-op if a cluster is already up)."""
        if not self.ray_initialized:
            if not ray.is_initialized():
                print("[Main Process] Initializing Ray...")
                ray.init(ignore_reinit_error=True)
            self.ray_initialized = True

    def should_intercept(self, *, action: Action, **kwargs) -> bool:
        """Intercept actions tagged with 'ray'"""
        return "ray" in action.tags

    def intercept_run(
        self, *, action: Action, state: State, inputs: Dict[str, Any], **kwargs
    ) -> dict:
        """Execute the action on a Ray worker and return its result.

        :param action: Action to execute remotely
        :param state: State at time of execution (must be serializable)
        :param inputs: Inputs to the action
        :return: The action's result dict; for single-step actions the new
            state is attached under the ``__INTERCEPTOR_NEW_STATE__`` key
        :raises Exception: re-raises any exception raised by the action on
            the worker (after worker post-hooks have observed it)
        """
        self._ensure_ray_initialized()

        print(f"[Main Process] Dispatching {action.name} to Ray...")

        # Worker hooks, provided by the framework, to run on the remote side.
        worker_adapter_set = kwargs.get("worker_adapter_set")

        @ray.remote
        def execute_on_ray():
            """Execute the action on a Ray worker, bracketed by worker hooks."""
            if worker_adapter_set:
                worker_adapter_set.call_all_lifecycle_hooks_sync(
                    "pre_run_step_worker",
                    action=action,
                    state=state,
                    inputs=inputs,
                )

            result, new_state, error = None, None, None
            try:
                if getattr(action, "single_step", False):
                    # Single-step actions compute result and new state together.
                    result, new_state = action.run_and_update(state, **inputs)
                else:
                    # Restrict state to the keys the action declares it reads.
                    result = action.run(state.subset(*action.reads), **inputs)
            except Exception as e:  # reported to the post hooks, then re-raised
                error = e

            # Always run the post hooks, so they observe failures too
            # (previously they were skipped on error and exception was
            # hard-coded to None).
            if worker_adapter_set:
                worker_adapter_set.call_all_lifecycle_hooks_sync(
                    "post_run_step_worker",
                    action=action,
                    state=state,
                    result=result,
                    exception=error,
                )

            if error is not None:
                raise error  # Ray propagates this back to the orchestrator
            return result, new_state

        # Execute remotely and block until the result is available.
        result, new_state = ray.get(execute_on_ray.remote())

        print(f"[Main Process] Received result from Ray for {action.name}")

        # For single-step actions, smuggle the new state back to the caller.
        if new_state is not None:
            result_with_state = result.copy()
            result_with_state["__INTERCEPTOR_NEW_STATE__"] = new_state
            return result_with_state

        return result
+
+
def main():
    """Run the example application end to end.

    Demonstrates mixing local and remote execution: actions tagged 'ray' are
    dispatched to Ray workers by RayActionInterceptor; everything else runs
    on the main process.
    """
    print("=" * 80)
    print("Burr + Ray Remote Execution Example")
    print("=" * 80)
    print()

    # Create interceptor and hooks
    ray_interceptor = RayActionInterceptor()
    orchestrator_pre = OrchestratorPreHook()
    orchestrator_post = OrchestratorPostHook()
    worker_pre = WorkerPreHook()
    worker_post = WorkerPostHook()

    # Build the application: a three-node loop
    # increment_local -> heavy_computation -> another_ray_task -> increment_local
    app = (
        ApplicationBuilder()
        .with_state(count=0)
        .with_actions(
            increment_local,
            heavy_computation,
            another_ray_task,
        )
        .with_transitions(
            ("increment_local", "heavy_computation"),
            ("heavy_computation", "another_ray_task"),
            ("another_ray_task", "increment_local"),
        )
        .with_entrypoint("increment_local")
        .with_hooks(
            # Orchestrator hooks run locally; worker hooks are forwarded to
            # the remote side for intercepted actions.
            ray_interceptor,
            orchestrator_pre,
            orchestrator_post,
            worker_pre,
            worker_post,
        )
        .build()
    )

    # Execute steps
    print("\n" + "=" * 80)
    print("Step 1: Local execution (increment_local)")
    print("=" * 80)
    action, result, state = app.step()
    print(f"Result: count={state['count']}, operation={state['last_operation']}")

    print("\n" + "=" * 80)
    print("Step 2: Ray execution (heavy_computation)")
    print("=" * 80)
    # Inputs are forwarded to the action even when it runs on a Ray worker.
    action, result, state = app.step(inputs={"multiplier": 3})
    print(f"Result: count={state['count']}, operation={state['last_operation']}")

    print("\n" + "=" * 80)
    print("Step 3: Ray execution (another_ray_task)")
    print("=" * 80)
    action, result, state = app.step()
    print(f"Result: count={state['count']}, operation={state['last_operation']}")

    print("\n" + "=" * 80)
    print("Step 4: Back to local execution (increment_local)")
    print("=" * 80)
    action, result, state = app.step()
    print(f"Result: count={state['count']}, operation={state['last_operation']}")

    print("\n" + "=" * 80)
    print("Final State:")
    print("=" * 80)
    print(f"Count: {state['count']}")
    print(f"Last Operation: {state['last_operation']}")

    # Shutdown Ray
    if ray.is_initialized():
        print("\n[Main Process] Shutting down Ray...")
        ray.shutdown()

    print("\n" + "=" * 80)
    print("Example completed successfully!")
    print("=" * 80)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/remote-execution-ray/notebook.ipynb 
b/examples/remote-execution-ray/notebook.ipynb
new file mode 100644
index 00000000..fed2a1d0
--- /dev/null
+++ b/examples/remote-execution-ray/notebook.ipynb
@@ -0,0 +1,483 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Remote Execution with Ray - Interactive Demo\n",
+    "\n",
+    "This notebook demonstrates how to use **Burr's Action Execution 
Interceptors** to run actions on Ray workers.\n",
+    "\n",
+    "## What You'll Learn\n",
+    "\n",
+    "1. How to create a Ray interceptor\n",
+    "2. How to define orchestrator vs. worker hooks\n",
+    "3. How to selectively run actions locally vs. remotely\n",
+    "4. How state flows between main process and Ray workers"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Setup"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import time\n",
+    "from typing import Dict, Any, Optional\n",
+    "\n",
+    "import ray\n",
+    "from burr.core import Action, State, ApplicationBuilder, action\n",
+    "from burr.lifecycle import (\n",
+    "    ActionExecutionInterceptorHook,\n",
+    "    PreRunStepHookWorker,\n",
+    "    PostRunStepHookWorker,\n",
+    "    PreRunStepHook,\n",
+    "    PostRunStepHook,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Step 1: Define Actions\n",
+    "\n",
+    "We'll create three actions:\n",
+    "- `increment_local` - runs locally (no `ray` tag)\n",
+    "- `heavy_computation` - runs on Ray (tagged with `ray`)\n",
+    "- `another_ray_task` - also runs on Ray"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "@action(reads=[\"count\"], writes=[\"count\", \"last_operation\"], 
tags=[\"local\"])\n",
+    "def increment_local(state: State) -> tuple:\n",
+    "    \"\"\"Increment counter locally (not on Ray)\"\"\"\n",
+    "    result = {\n",
+    "        \"count\": state[\"count\"] + 1,\n",
+    "        \"last_operation\": \"increment_local\",\n",
+    "    }\n",
+    "    return result, state.update(**result)\n",
+    "\n",
+    "\n",
+    "@action(reads=[\"count\"], writes=[\"count\", \"last_operation\"], 
tags=[\"ray\"])\n",
+    "def heavy_computation(state: State, multiplier: int = 2) -> tuple:\n",
+    "    \"\"\"Simulate heavy computation that should run on Ray\"\"\"\n",
+    "    print(f\"🔧 [Ray Worker] Running heavy computation with 
multiplier={multiplier}\")\n",
+    "    time.sleep(0.5)  # Simulate work\n",
+    "    result = {\n",
+    "        \"count\": state[\"count\"] * multiplier,\n",
+    "        \"last_operation\": f\"heavy_computation(x{multiplier})\",\n",
+    "    }\n",
+    "    return result, state.update(**result)\n",
+    "\n",
+    "\n",
+    "@action(reads=[\"count\"], writes=[\"count\", \"last_operation\"], 
tags=[\"ray\"])\n",
+    "def another_ray_task(state: State) -> tuple:\n",
+    "    \"\"\"Another task that runs on Ray\"\"\"\n",
+    "    print(\"🔧 [Ray Worker] Running another Ray task\")\n",
+    "    time.sleep(0.3)  # Simulate work\n",
+    "    result = {\n",
+    "        \"count\": state[\"count\"] + 10,\n",
+    "        \"last_operation\": \"another_ray_task(+10)\",\n",
+    "    }\n",
+    "    return result, state.update(**result)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Step 2: Define Hooks\n",
+    "\n",
+    "We define two types of hooks:\n",
+    "1. **Orchestrator hooks** - run on the main process\n",
+    "2. **Worker hooks** - run on Ray workers"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Orchestrator hooks (run on main process)\n",
+    "class OrchestratorPreHook(PreRunStepHook):\n",
+    "    def pre_run_step(self, *, action: Action, state: State, inputs: 
Dict[str, Any], **kwargs):\n",
+    "        print(f\"📋 [Main Process] About to execute: {action.name}\")\n",
+    "\n",
+    "\n",
+    "class OrchestratorPostHook(PostRunStepHook):\n",
+    "    def post_run_step(\n",
+    "        self, *, action: Action, state: State, result: Optional[Dict[str, 
Any]], exception, **kwargs\n",
+    "    ):\n",
+    "        print(f\"✅ [Main Process] Finished: {action.name}\")\n",
+    "\n",
+    "\n",
+    "# Worker hooks (run on Ray workers)\n",
+    "class WorkerPreHook(PreRunStepHookWorker):\n",
+    "    def pre_run_step_worker(self, *, action: Action, state: State, 
inputs: Dict[str, Any], **kwargs):\n",
+    "        print(f\"⚙️  [Ray Worker] Starting: {action.name}\")\n",
+    "\n",
+    "\n",
+    "class WorkerPostHook(PostRunStepHookWorker):\n",
+    "    def post_run_step_worker(\n",
+    "        self, *, action: Action, state: State, result: Optional[Dict[str, 
Any]], exception, **kwargs\n",
+    "    ):\n",
+    "        print(f\"✨ [Ray Worker] Completed: {action.name}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Step 3: Create the Ray Interceptor\n",
+    "\n",
+    "The interceptor decides which actions to run on Ray and handles the 
remote execution."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class RayActionInterceptor(ActionExecutionInterceptorHook):\n",
+    "    \"\"\"Interceptor that executes actions tagged with 'ray' on Ray 
workers\"\"\"\n",
+    "\n",
+    "    def __init__(self):\n",
+    "        self.ray_initialized = False\n",
+    "\n",
+    "    def _ensure_ray_initialized(self):\n",
+    "        if not self.ray_initialized:\n",
+    "            if not ray.is_initialized():\n",
+    "                print(\"🚀 [Main Process] Initializing Ray...\")\n",
+    "                ray.init(ignore_reinit_error=True)\n",
+    "            self.ray_initialized = True\n",
+    "\n",
+    "    def should_intercept(self, *, action: Action, **kwargs) -> bool:\n",
+    "        \"\"\"Intercept actions tagged with 'ray'\"\"\"\n",
+    "        return \"ray\" in action.tags\n",
+    "\n",
+    "    def intercept_run(\n",
+    "        self, *, action: Action, state: State, inputs: Dict[str, Any], 
**kwargs\n",
+    "    ) -> dict:\n",
+    "        \"\"\"Execute the action on a Ray worker\"\"\"\n",
+    "        self._ensure_ray_initialized()\n",
+    "\n",
+    "        print(f\"📤 [Main Process] Dispatching {action.name} to 
Ray...\")\n",
+    "\n",
+    "        # Extract worker hooks\n",
+    "        worker_adapter_set = kwargs.get(\"worker_adapter_set\")\n",
+    "\n",
+    "        # Create a Ray remote function\n",
+    "        @ray.remote\n",
+    "        def execute_on_ray():\n",
+    "            # Call pre-worker hooks\n",
+    "            if worker_adapter_set:\n",
+    "                worker_adapter_set.call_all_lifecycle_hooks_sync(\n",
+    "                    \"pre_run_step_worker\",\n",
+    "                    action=action,\n",
+    "                    state=state,\n",
+    "                    inputs=inputs,\n",
+    "                )\n",
+    "\n",
+    "            # Execute the action\n",
+    "            if hasattr(action, \"single_step\") and 
action.single_step:\n",
+    "                result, new_state = action.run_and_update(state, 
**inputs)\n",
+    "            else:\n",
+    "                state_to_use = state.subset(*action.reads)\n",
+    "                result = action.run(state_to_use, **inputs)\n",
+    "                new_state = None\n",
+    "\n",
+    "            # Call post-worker hooks\n",
+    "            if worker_adapter_set:\n",
+    "                worker_adapter_set.call_all_lifecycle_hooks_sync(\n",
+    "                    \"post_run_step_worker\",\n",
+    "                    action=action,\n",
+    "                    state=state,\n",
+    "                    result=result,\n",
+    "                    exception=None,\n",
+    "                )\n",
+    "\n",
+    "            return result, new_state\n",
+    "\n",
+    "        # Execute remotely and wait for result\n",
+    "        result_ref = execute_on_ray.remote()\n",
+    "        result, new_state = ray.get(result_ref)\n",
+    "\n",
+    "        print(f\"📥 [Main Process] Received result from Ray for 
{action.name}\")\n",
+    "\n",
+    "        # For single-step actions, include the new state\n",
+    "        if new_state is not None:\n",
+    "            result_with_state = result.copy()\n",
+    "            result_with_state[\"__INTERCEPTOR_NEW_STATE__\"] = 
new_state\n",
+    "            return result_with_state\n",
+    "\n",
+    "        return result"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Step 4: Build the Application\n",
+    "\n",
+    "Now we put it all together!"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Create interceptor and hooks\n",
+    "ray_interceptor = RayActionInterceptor()\n",
+    "orchestrator_pre = OrchestratorPreHook()\n",
+    "orchestrator_post = OrchestratorPostHook()\n",
+    "worker_pre = WorkerPreHook()\n",
+    "worker_post = WorkerPostHook()\n",
+    "\n",
+    "# Build the application\n",
+    "app = (\n",
+    "    ApplicationBuilder()\n",
+    "    .with_state(count=0)\n",
+    "    .with_actions(\n",
+    "        increment_local,\n",
+    "        heavy_computation,\n",
+    "        another_ray_task,\n",
+    "    )\n",
+    "    .with_transitions(\n",
+    "        (\"increment_local\", \"heavy_computation\"),\n",
+    "        (\"heavy_computation\", \"another_ray_task\"),\n",
+    "        (\"another_ray_task\", \"increment_local\"),\n",
+    "    )\n",
+    "    .with_entrypoint(\"increment_local\")\n",
+    "    .with_hooks(\n",
+    "        ray_interceptor,\n",
+    "        orchestrator_pre,\n",
+    "        orchestrator_post,\n",
+    "        worker_pre,\n",
+    "        worker_post,\n",
+    "    )\n",
+    "    .build()\n",
+    ")\n",
+    "\n",
+    "print(\"✨ Application built successfully!\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Step 5: Execute Actions\n",
+    "\n",
+    "Let's run through the workflow step by step."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Step 5a: Local Execution"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(\"\\n\" + \"=\"*60)\n",
+    "print(\"STEP 1: Local Execution (increment_local)\")\n",
+    "print(\"=\"*60)\n",
+    "\n",
+    "action, result, state = app.step()\n",
+    "\n",
+    "print(f\"\\n📊 Result: count={state['count']}, 
operation={state['last_operation']}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Notice:\n",
+    "- ✅ Orchestrator hooks run\n",
+    "- ❌ Worker hooks DON'T run (action not intercepted)\n",
+    "- ❌ No Ray dispatch messages"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Step 5b: Ray Execution #1"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(\"\\n\" + \"=\"*60)\n",
+    "print(\"STEP 2: Ray Execution (heavy_computation)\")\n",
+    "print(\"=\"*60)\n",
+    "\n",
+    "action, result, state = app.step(inputs={\"multiplier\": 3})\n",
+    "\n",
+    "print(f\"\\n📊 Result: count={state['count']}, 
operation={state['last_operation']}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Notice:\n",
+    "- ✅ Orchestrator hooks run (on main process)\n",
+    "- ✅ Worker hooks run (on Ray worker!)\n",
+    "- ✅ Ray dispatch and receive messages\n",
+    "- ✅ Actual computation happens on Ray worker"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Step 5c: Ray Execution #2"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(\"\\n\" + \"=\"*60)\n",
+    "print(\"STEP 3: Ray Execution (another_ray_task)\")\n",
+    "print(\"=\"*60)\n",
+    "\n",
+    "action, result, state = app.step()\n",
+    "\n",
+    "print(f\"\\n📊 Result: count={state['count']}, 
operation={state['last_operation']}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Step 5d: Back to Local"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(\"\\n\" + \"=\"*60)\n",
+    "print(\"STEP 4: Back to Local Execution (increment_local)\")\n",
+    "print(\"=\"*60)\n",
+    "\n",
+    "action, result, state = app.step()\n",
+    "\n",
+    "print(f\"\\n📊 Result: count={state['count']}, 
operation={state['last_operation']}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Step 6: View Final State"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(\"\\n\" + \"=\"*60)\n",
+    "print(\"FINAL STATE\")\n",
+    "print(\"=\"*60)\n",
+    "print(f\"Count: {state['count']}\")\n",
+    "print(f\"Last Operation: {state['last_operation']}\")\n",
+    "print(\"\\nWorkflow: 0 → +1 (local) → x3 (ray) → +10 (ray) → +1 (local) = 14\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Cleanup"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "if ray.is_initialized():\n",
+    "    print(\"🛑 Shutting down Ray...\")\n",
+    "    ray.shutdown()\n",
+    "    print(\"✅ Ray shutdown complete\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Key Takeaways\n",
+    "\n",
+    "1. **Selective Execution**: Actions can run locally or remotely based on 
tags\n",
+    "2. **Two-Tier Hooks**: Orchestrator hooks always run; worker hooks only 
run for intercepted actions\n",
+    "3. **Seamless Integration**: State flows naturally between main process 
and workers\n",
+    "4. **Transparent to Actions**: Actions don't know they're running on 
Ray\n",
+    "5. **Flexible**: Easy to add more actions or change execution backend\n",
+    "\n",
+    "## Next Steps\n",
+    "\n",
+    "Try modifying this notebook:\n",
+    "- Add your own actions with different tags\n",
+    "- Create a more complex workflow\n",
+    "- Add custom logging in the hooks\n",
+    "- Experiment with async actions\n",
+    "- Try with actual compute-intensive operations"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.0"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/examples/remote-execution-ray/requirements.txt 
b/examples/remote-execution-ray/requirements.txt
new file mode 100644
index 00000000..b0dbf87d
--- /dev/null
+++ b/examples/remote-execution-ray/requirements.txt
@@ -0,0 +1,2 @@
+burr
+ray>=2.0.0
diff --git a/tests/core/test_action_interceptor.py 
b/tests/core/test_action_interceptor.py
new file mode 100644
index 00000000..97249b31
--- /dev/null
+++ b/tests/core/test_action_interceptor.py
@@ -0,0 +1,343 @@
+# Tests for action execution interceptor hooks
+from typing import Any, Dict, Generator, Optional, Tuple
+
+import pytest
+
+from burr.core import Action, ApplicationBuilder, State, action
+from burr.core.action import streaming_action
+from burr.lifecycle import (
+    ActionExecutionInterceptorHook,
+    PostRunStepHookWorker,
+    PreRunStepHookWorker,
+    StreamingActionInterceptorHook,
+)
+
+
+# Test actions
@action(reads=["x"], writes=["y"])
def add_one(state: State) -> Tuple[dict, State]:
    """Increment ``x`` by one and store it under ``y`` (untagged: never intercepted)."""
    incremented = state["x"] + 1
    return {"y": incremented}, state.update(y=incremented)
+
+
@action(reads=["x"], writes=["z"], tags=["intercepted"])
def multiply_by_two(state: State) -> Tuple[dict, State]:
    """Double ``x`` and store it under ``z``; tagged so interceptors pick it up."""
    doubled = {"z": state["x"] * 2}
    return doubled, state.update(**doubled)
+
+
@streaming_action(reads=["prompt"], writes=["response"], tags=["streaming_intercepted"])
def streaming_responder(state: State) -> Generator[Tuple[dict, Optional[State]], None, None]:
    """Stream a fixed token sequence, then yield the joined response with the final state."""
    pieces = ("Hello", " ", "World", "!")
    collected = []
    for piece in pieces:
        collected.append(piece)
        # Intermediate chunks carry no state update (None).
        yield {"response": piece}, None
    joined = "".join(collected)
    yield {"response": joined}, state.update(response=joined)
+
+
+# Mock interceptor that captures execution
class MockActionInterceptor(ActionExecutionInterceptorHook):
    """Test interceptor: records which actions it took over and simulates
    "remote" execution, firing worker-side hooks around the run."""

    def __init__(self):
        # Names of actions this interceptor executed, in order.
        self.intercepted_actions = []
        self.worker_hooks_called = []

    def should_intercept(self, *, action: Action, **kwargs) -> bool:
        # Only actions explicitly tagged "intercepted" are taken over.
        return "intercepted" in action.tags

    def intercept_run(
        self, *, action: Action, state: State, inputs: Dict[str, Any], **kwargs
    ) -> dict:
        self.intercepted_actions.append(action.name)
        workers = kwargs.get("worker_adapter_set")

        # Worker-side pre-hooks fire only for intercepted actions.
        if workers:
            workers.call_all_lifecycle_hooks_sync(
                "pre_run_step_worker",
                action=action,
                state=state,
                inputs=inputs,
            )

        # Simulated "remote" execution path.
        if getattr(action, "single_step", False):
            # Single-step actions produce a result AND a new state; the new
            # state is smuggled back under a magic key so the caller's
            # _run_single_step_action can extract it.
            raw_result, updated_state = action.run_and_update(state, **inputs)
            result = dict(raw_result)
            result["__INTERCEPTOR_NEW_STATE__"] = updated_state
        else:
            # Multi-step actions only see the state slice they declare to read.
            action.validate_inputs(inputs)
            result = action.run(state.subset(*action.reads), **inputs)

        # Worker-side post-hooks mirror the pre-hooks.
        if workers:
            workers.call_all_lifecycle_hooks_sync(
                "post_run_step_worker",
                action=action,
                state=state,
                result=result,
                exception=None,
            )

        return result
+
+
class MockStreamingInterceptor(StreamingActionInterceptorHook):
    """Test interceptor for streaming actions: records what it took over and
    relays the stream unchanged while firing worker-side stream hooks."""

    def __init__(self):
        # Names of streaming actions this interceptor handled, in order.
        self.intercepted_actions = []

    def should_intercept(self, *, action: Action, **kwargs) -> bool:
        # Only actions explicitly tagged "streaming_intercepted" are taken over.
        return "streaming_intercepted" in action.tags

    def intercept_stream_run_and_update(
        self, *, action: Action, state: State, inputs: Dict[str, Any], **kwargs
    ):
        self.intercepted_actions.append(action.name)
        workers = kwargs.get("worker_adapter_set")

        # NOTE: stream hooks receive the action *name*, not the Action object.
        if workers:
            workers.call_all_lifecycle_hooks_sync(
                "pre_start_stream_worker",
                action=action.name,
                state=state,
                inputs=inputs,
            )

        # Relay the stream as-is (stands in for remote execution), remembering
        # the last item so the post-hook can observe the final result.
        last_item = None
        for chunk in action.stream_run_and_update(state, **inputs):
            last_item = chunk
            yield chunk

        if workers and last_item:
            workers.call_all_lifecycle_hooks_sync(
                "post_end_stream_worker",
                action=action.name,
                result=last_item[0] if last_item else None,
                exception=None,
            )
+
+
class WorkerPreHook(PreRunStepHookWorker):
    """Records ("pre", action name) each time the worker pre-hook fires."""

    def __init__(self):
        # (phase, action-name) tuples, in call order.
        self.called_actions = []

    def pre_run_step_worker(
        self, *, action: Action, state: State, inputs: Dict[str, Any], **kwargs
    ):
        entry = ("pre", action.name)
        self.called_actions.append(entry)
+
+
class WorkerPostHook(PostRunStepHookWorker):
    """Test worker hook that records actions after they execute.

    Each invocation appends ("post", <action name>) to ``called_actions`` so
    tests can assert the hook fired for a given action.
    """

    def __init__(self):
        # (phase, action-name) tuples, in call order.
        self.called_actions = []

    def post_run_step_worker(
        self,
        *,
        action: Action,
        state: State,
        result: Optional[Dict[str, Any]],
        # Fixed annotation: callers pass exception=None on success (see the
        # interceptors in this file), so this must be Optional.
        exception: Optional[Exception],
        **kwargs,
    ):
        self.called_actions.append(("post", action.name))
+
+
def test_interceptor_intercepts_tagged_action():
    """Only actions carrying the "intercepted" tag go through the interceptor."""
    interceptor = MockActionInterceptor()

    app = (
        ApplicationBuilder()
        .with_state(x=5)
        .with_actions(add_one, multiply_by_two)
        .with_transitions(
            ("add_one", "multiply_by_two"),
            ("multiply_by_two", "add_one"),
        )
        .with_entrypoint("add_one")
        .with_hooks(interceptor)
        .build()
    )

    # add_one is untagged, so it runs through the normal path.
    ran_action, _, new_state = app.step()
    assert ran_action.name == "add_one"
    assert new_state["y"] == 6
    assert "add_one" not in interceptor.intercepted_actions

    # multiply_by_two is tagged, so the interceptor handles it.
    ran_action, _, new_state = app.step()
    assert ran_action.name == "multiply_by_two"
    assert new_state["z"] == 10  # x is still 5, so 5 * 2
    assert "multiply_by_two" in interceptor.intercepted_actions
+
+
def test_interceptor_calls_worker_hooks():
    """Worker-side pre/post hooks fire when the interceptor takes over an action."""
    interceptor = MockActionInterceptor()
    pre_hook = WorkerPreHook()
    post_hook = WorkerPostHook()

    app = (
        ApplicationBuilder()
        .with_state(x=10)
        .with_actions(multiply_by_two)
        .with_entrypoint("multiply_by_two")
        .with_hooks(interceptor, pre_hook, post_hook)
        .build()
    )

    ran_action, _, new_state = app.step()
    assert ran_action.name == "multiply_by_two"
    assert new_state["z"] == 20

    # The interceptor handled the action...
    assert "multiply_by_two" in interceptor.intercepted_actions
    # ...and drove both worker hooks around it.
    assert ("pre", "multiply_by_two") in pre_hook.called_actions
    assert ("post", "multiply_by_two") in post_hook.called_actions
+
+
def test_no_interceptor_normal_execution():
    """Without any interceptor registered, actions execute on the normal path."""
    app = (
        ApplicationBuilder()
        .with_state(x=3)
        .with_actions(add_one, multiply_by_two)
        .with_transitions(
            ("add_one", "multiply_by_two"),
        )
        .with_entrypoint("add_one")
        .build()
    )

    first_action, _, state_after = app.step()
    assert first_action.name == "add_one"
    assert state_after["y"] == 4

    second_action, _, state_after = app.step()
    assert second_action.name == "multiply_by_two"
    assert state_after["z"] == 6  # x stays 3, so 3 * 2
+
+
def test_streaming_action_interceptor():
    """Streaming actions can be intercepted without disturbing the stream."""
    streaming_interceptor = MockStreamingInterceptor()

    app = (
        ApplicationBuilder()
        .with_state(prompt="test")
        .with_actions(streaming_responder)
        .with_entrypoint("streaming_responder")
        .with_hooks(streaming_interceptor)
        .build()
    )

    ran_action, streaming_container = app.stream_result(
        halt_after=["streaming_responder"],
    )

    # Drain the stream, collecting each partial response.
    seen = [piece["response"] for piece in streaming_container]
    _, final_state = streaming_container.get()

    # The interceptor saw the action...
    assert "streaming_responder" in streaming_interceptor.intercepted_actions
    # ...and the stream itself came through intact.
    assert seen == ["Hello", " ", "World", "!"]
    assert final_state["response"] == "Hello World!"
+
+
def test_multiple_interceptors_first_wins():
    """When several interceptors match an action, only the first registered runs."""

    class StubInterceptor(ActionExecutionInterceptorHook):
        """Matches the "intercepted" tag and writes a fixed value to z."""

        def __init__(self, z_value):
            self.z_value = z_value
            self.called = False

        def should_intercept(self, *, action: Action, **kwargs) -> bool:
            return "intercepted" in action.tags

        def intercept_run(
            self, *, action: Action, state: State, inputs: Dict[str, Any], **kwargs
        ) -> dict:
            self.called = True
            result = {"z": self.z_value}
            # Single-step actions need the new state handed back alongside
            # the result, under the magic extraction key.
            if getattr(action, "single_step", False):
                result["__INTERCEPTOR_NEW_STATE__"] = state.update(z=self.z_value)
            return result

    first = StubInterceptor(999)
    second = StubInterceptor(777)

    app = (
        ApplicationBuilder()
        .with_state(x=5)
        .with_actions(multiply_by_two)
        .with_entrypoint("multiply_by_two")
        .with_hooks(first, second)  # registration order decides priority
        .build()
    )

    _, _, final_state = app.step()

    # Only the first-registered interceptor ran and its value stuck.
    assert first.called
    assert final_state["z"] == 999
    assert not second.called
+
+
# Allow running this test module directly (outside a pytest invocation).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

Reply via email to