This is an automated email from the ASF dual-hosted git repository.

skrawcz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/burr.git


The following commit(s) were added to refs/heads/main by this push:
     new 2c5abdb0 fix: graceful stream shutdown on exceptions in streaming actions (#680)
2c5abdb0 is described below

commit 2c5abdb0e3829228bb1397c8063eff7d9b85c6ae
Author: André Ahlert <[email protected]>
AuthorDate: Sat Mar 28 16:24:33 2026 -0300

    fix: graceful stream shutdown on exceptions in streaming actions (#680)
    
    * fix: allow streaming actions to gracefully handle raised exceptions
    
    When a streaming action catches an exception and yields a final state
    in a try/except/finally block, the stream now completes gracefully
    instead of propagating the exception and killing the connection.
    
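    If a single-step streaming action yields a non-None state_update before
    the exception propagates, the exception is logged (with traceback) and
    suppressed, and the stream terminates normally using that state. If no
    state was yielded, the exception propagates as before. For multi-step
    streaming actions the condition is weaker: any previously yielded result
    is enough to suppress the exception, and the reducer then runs on the
    last yielded result, which may be partial data.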
    
    Closes #581
    
    * fix: add logging for suppressed exceptions and tests for streaming graceful shutdown
    
    - Add logger.warning with exc_info in all 4 streaming except blocks
    - Remove dead caught_exc variable in multi-step functions
    - Fix count increment asymmetry between sync and async single-step
    - Add 8 tests covering graceful exception handling and propagation
    
    * style: apply black formatting
    
    * style: apply pre-commit formatting (black 23.11.0, isort 5.12.0)
---
 burr/core/application.py       | 249 +++++++++++++++---------
 tests/core/test_application.py | 430 +++++++++++++++++++++++++++++++++++++++--
 2 files changed, 565 insertions(+), 114 deletions(-)

diff --git a/burr/core/application.py b/burr/core/application.py
index 84e3f7cb..7cbe9614 100644
--- a/burr/core/application.py
+++ b/burr/core/application.py
@@ -338,31 +338,44 @@ def _run_single_step_streaming_action(
     result = None
     state_update = None
     count = 0
-    for item in generator:
-        if not isinstance(item, tuple):
-            # TODO -- consider adding support for just returning a result.
-            raise ValueError(
-                f"Action {action.name} must yield a tuple of (result, 
state_update). "
-                f"For all non-final results (intermediate),"
-                f"the state update must be None"
-            )
-        result, state_update = item
-        count += 1
+    try:
+        for item in generator:
+            if not isinstance(item, tuple):
+                # TODO -- consider adding support for just returning a result.
+                raise ValueError(
+                    f"Action {action.name} must yield a tuple of (result, 
state_update). "
+                    f"For all non-final results (intermediate),"
+                    f"the state update must be None"
+                )
+            result, state_update = item
+            if state_update is None:
+                count += 1
+                if first_stream_start_time is None:
+                    first_stream_start_time = system.now()
+                lifecycle_adapters.call_all_lifecycle_hooks_sync(
+                    "post_stream_item",
+                    item=result,
+                    item_index=count,
+                    stream_initialize_time=stream_initialize_time,
+                    first_stream_item_start_time=first_stream_start_time,
+                    action=action.name,
+                    app_id=app_id,
+                    partition_key=partition_key,
+                    sequence_id=sequence_id,
+                )
+                yield result, None
+    except Exception as e:
         if state_update is None:
-            if first_stream_start_time is None:
-                first_stream_start_time = system.now()
-            lifecycle_adapters.call_all_lifecycle_hooks_sync(
-                "post_stream_item",
-                item=result,
-                item_index=count,
-                stream_initialize_time=stream_initialize_time,
-                first_stream_item_start_time=first_stream_start_time,
-                action=action.name,
-                app_id=app_id,
-                partition_key=partition_key,
-                sequence_id=sequence_id,
-            )
-            yield result, None
+            raise
+        logger.warning(
+            "Streaming action '%s' raised %s after yielding %d items. "
+            "Proceeding with final state from generator cleanup. Original 
error: %s",
+            action.name,
+            type(e).__name__,
+            count,
+            e,
+            exc_info=True,
+        )
 
     if state_update is None:
         raise ValueError(
@@ -391,31 +404,45 @@ async def _arun_single_step_streaming_action(
     result = None
     state_update = None
     count = 0
-    async for item in generator:
-        if not isinstance(item, tuple):
-            # TODO -- consider adding support for just returning a result.
-            raise ValueError(
-                f"Action {action.name} must yield a tuple of (result, 
state_update). "
-                f"For all non-final results (intermediate),"
-                f"the state update must be None"
-            )
-        result, state_update = item
+    try:
+        async for item in generator:
+            if not isinstance(item, tuple):
+                # TODO -- consider adding support for just returning a result.
+                raise ValueError(
+                    f"Action {action.name} must yield a tuple of (result, 
state_update). "
+                    f"For all non-final results (intermediate),"
+                    f"the state update must be None"
+                )
+            result, state_update = item
+            if state_update is None:
+                count += 1
+                if first_stream_start_time is None:
+                    first_stream_start_time = system.now()
+                await 
lifecycle_adapters.call_all_lifecycle_hooks_sync_and_async(
+                    "post_stream_item",
+                    item=result,
+                    item_index=count,
+                    stream_initialize_time=stream_initialize_time,
+                    first_stream_item_start_time=first_stream_start_time,
+                    action=action.name,
+                    app_id=app_id,
+                    partition_key=partition_key,
+                    sequence_id=sequence_id,
+                )
+                yield result, None
+    except Exception as e:
         if state_update is None:
-            if first_stream_start_time is None:
-                first_stream_start_time = system.now()
-            await lifecycle_adapters.call_all_lifecycle_hooks_sync_and_async(
-                "post_stream_item",
-                item=result,
-                item_index=count,
-                stream_initialize_time=stream_initialize_time,
-                first_stream_item_start_time=first_stream_start_time,
-                action=action.name,
-                app_id=app_id,
-                partition_key=partition_key,
-                sequence_id=sequence_id,
-            )
-            count += 1
-            yield result, None
+            raise
+        logger.warning(
+            "Streaming action '%s' raised %s after yielding %d items. "
+            "Proceeding with final state from generator cleanup. Original 
error: %s",
+            action.name,
+            type(e).__name__,
+            count,
+            e,
+            exc_info=True,
+        )
+
     if state_update is None:
         raise ValueError(
             f"Action {action.name} did not return a state update. For async 
actions, the last yield "
@@ -450,28 +477,42 @@ def _run_multi_step_streaming_action(
     result = None
     first_stream_start_time = None
     count = 0
-    for item in generator:
-        # We want to peek ahead so we can return the last one
-        # This is slightly eager, but only in the case in which we
-        # are using a multi-step streaming action
-        next_result = result
-        result = item
-        if next_result is not None:
-            if first_stream_start_time is None:
-                first_stream_start_time = system.now()
-            lifecycle_adapters.call_all_lifecycle_hooks_sync(
-                "post_stream_item",
-                item=next_result,
-                item_index=count,
-                stream_initialize_time=stream_initialize_time,
-                first_stream_item_start_time=first_stream_start_time,
-                action=action.name,
-                app_id=app_id,
-                partition_key=partition_key,
-                sequence_id=sequence_id,
-            )
-            count += 1
-            yield next_result, None
+    try:
+        for item in generator:
+            # We want to peek ahead so we can return the last one
+            # This is slightly eager, but only in the case in which we
+            # are using a multi-step streaming action
+            next_result = result
+            result = item
+            if next_result is not None:
+                if first_stream_start_time is None:
+                    first_stream_start_time = system.now()
+                lifecycle_adapters.call_all_lifecycle_hooks_sync(
+                    "post_stream_item",
+                    item=next_result,
+                    item_index=count,
+                    stream_initialize_time=stream_initialize_time,
+                    first_stream_item_start_time=first_stream_start_time,
+                    action=action.name,
+                    app_id=app_id,
+                    partition_key=partition_key,
+                    sequence_id=sequence_id,
+                )
+                count += 1
+                yield next_result, None
+    except Exception as e:
+        if result is None:
+            raise
+        logger.warning(
+            "Streaming action '%s' raised %s after yielding %d items. "
+            "Proceeding with last yielded result for reducer. "
+            "Note: the reducer will run on potentially partial data. Original 
error: %s",
+            action.name,
+            type(e).__name__,
+            count,
+            e,
+            exc_info=True,
+        )
     state_update = _run_reducer(action, state, result, action.name)
     _validate_result(result, action.name, action.schema)
     _validate_reducer_writes(action, state_update, action.name)
@@ -494,28 +535,42 @@ async def _arun_multi_step_streaming_action(
     result = None
     first_stream_start_time = None
     count = 0
-    async for item in generator:
-        # We want to peek ahead so we can return the last one
-        # This is slightly eager, but only in the case in which we
-        # are using a multi-step streaming action
-        next_result = result
-        result = item
-        if next_result is not None:
-            if first_stream_start_time is None:
-                first_stream_start_time = system.now()
-            await lifecycle_adapters.call_all_lifecycle_hooks_sync_and_async(
-                "post_stream_item",
-                item=next_result,
-                stream_initialize_time=stream_initialize_time,
-                item_index=count,
-                first_stream_item_start_time=first_stream_start_time,
-                action=action.name,
-                app_id=app_id,
-                partition_key=partition_key,
-                sequence_id=sequence_id,
-            )
-            count += 1
-            yield next_result, None
+    try:
+        async for item in generator:
+            # We want to peek ahead so we can return the last one
+            # This is slightly eager, but only in the case in which we
+            # are using a multi-step streaming action
+            next_result = result
+            result = item
+            if next_result is not None:
+                if first_stream_start_time is None:
+                    first_stream_start_time = system.now()
+                await 
lifecycle_adapters.call_all_lifecycle_hooks_sync_and_async(
+                    "post_stream_item",
+                    item=next_result,
+                    stream_initialize_time=stream_initialize_time,
+                    item_index=count,
+                    first_stream_item_start_time=first_stream_start_time,
+                    action=action.name,
+                    app_id=app_id,
+                    partition_key=partition_key,
+                    sequence_id=sequence_id,
+                )
+                count += 1
+                yield next_result, None
+    except Exception as e:
+        if result is None:
+            raise
+        logger.warning(
+            "Streaming action '%s' raised %s after yielding %d items. "
+            "Proceeding with last yielded result for reducer. "
+            "Note: the reducer will run on potentially partial data. Original 
error: %s",
+            action.name,
+            type(e).__name__,
+            count,
+            e,
+            exc_info=True,
+        )
     state_update = _run_reducer(action, state, result, action.name)
     _validate_result(result, action.name, action.schema)
     _validate_reducer_writes(action, state_update, action.name)
@@ -1862,7 +1917,9 @@ class Application(Generic[ApplicationStateType]):
         halt_before: Optional[Union[str, List[str]]] = None,
         inputs: Optional[Dict[str, Any]] = None,
     ) -> Generator[
-        Tuple[Action, StreamingResultContainer[ApplicationStateType, 
Union[dict, Any]]], None, None
+        Tuple[Action, StreamingResultContainer[ApplicationStateType, 
Union[dict, Any]]],
+        None,
+        None,
     ]:
         """Produces an iterator that iterates through intermediate streams. 
You may want
         to use this in something like deep research mode in which:
@@ -1905,7 +1962,11 @@ class Application(Generic[ApplicationStateType]):
         halt_before: Optional[Union[str, List[str]]] = None,
         inputs: Optional[Dict[str, Any]] = None,
     ) -> AsyncGenerator[
-        Tuple[Action, AsyncStreamingResultContainer[ApplicationStateType, 
Union[dict, Any]]], None
+        Tuple[
+            Action,
+            AsyncStreamingResultContainer[ApplicationStateType, Union[dict, 
Any]],
+        ],
+        None,
     ]:
         """Async version of stream_iterate. Produces an async generator that 
iterates
         through intermediate streams. See stream_iterate for more details.
diff --git a/tests/core/test_application.py b/tests/core/test_application.py
index c90c4067..7383583f 100644
--- a/tests/core/test_application.py
+++ b/tests/core/test_application.py
@@ -630,7 +630,12 @@ def 
test_run_single_step_streaming_action_errors_missing_write():
     state = State()
     with pytest.raises(ValueError, match="missing_value"):
         gen = _run_single_step_streaming_action(
-            action, state, inputs={}, sequence_id=0, 
partition_key="partition_key", app_id="app_id"
+            action,
+            state,
+            inputs={},
+            sequence_id=0,
+            partition_key="partition_key",
+            app_id="app_id",
         )
         collections.deque(gen, maxlen=0)  # exhaust the generator
 
@@ -687,7 +692,12 @@ def 
test_run_multi_step_streaming_action_errors_missing_write():
     state = State()
     with pytest.raises(ValueError, match="missing_value"):
         gen = _run_multi_step_streaming_action(
-            action, state, inputs={}, sequence_id=0, 
partition_key="partition_key", app_id="app_id"
+            action,
+            state,
+            inputs={},
+            sequence_id=0,
+            partition_key="partition_key",
+            app_id="app_id",
         )
         collections.deque(gen, maxlen=0)  # exhaust the generator
 
@@ -1005,7 +1015,12 @@ def test__run_multistep_streaming_action():
     action = base_streaming_counter.with_name("counter")
     state = State({"count": 0, "tracker": []})
     generator = _run_multi_step_streaming_action(
-        action, state, inputs={}, sequence_id=0, 
partition_key="partition_key", app_id="app_id"
+        action,
+        state,
+        inputs={},
+        sequence_id=0,
+        partition_key="partition_key",
+        app_id="app_id",
     )
     last_result = -1
     result = None
@@ -1111,7 +1126,12 @@ def test__run_streaming_action_incorrect_result_type():
     state = State()
     with pytest.raises(ValueError, match="returned a non-dict"):
         gen = _run_multi_step_streaming_action(
-            action, state, inputs={}, sequence_id=0, 
partition_key="partition_key", app_id="app_id"
+            action,
+            state,
+            inputs={},
+            sequence_id=0,
+            partition_key="partition_key",
+            app_id="app_id",
         )
         collections.deque(gen, maxlen=0)  # exhaust the generator
 
@@ -1168,7 +1188,12 @@ def test__run_single_step_streaming_action():
     action = base_streaming_single_step_counter.with_name("counter")
     state = State({"count": 0, "tracker": []})
     generator = _run_single_step_streaming_action(
-        action, state, inputs={}, sequence_id=0, 
partition_key="partition_key", app_id="app_id"
+        action,
+        state,
+        inputs={},
+        sequence_id=0,
+        partition_key="partition_key",
+        app_id="app_id",
     )
     last_result = -1
     result, state = None, None
@@ -1275,6 +1300,328 @@ async def 
test__run_single_step_streaming_action_async_callbacks():
     assert len(hook.items) == 10  # one for each streaming callback
 
 
+class SingleStepStreamingCounterWithException(SingleStepStreamingAction):
+    """Yields intermediate items, raises, then yields final state in finally 
block."""
+
+    def stream_run_and_update(
+        self, state: State, **run_kwargs
+    ) -> Generator[Tuple[dict, Optional[State]], None, None]:
+        count = state["count"]
+        try:
+            for i in range(3):
+                yield {"count": count + ((i + 1) / 10)}, None
+            raise RuntimeError("simulated failure")
+        finally:
+            yield {"count": count + 1}, state.update(count=count + 
1).append(tracker=count + 1)
+
+    @property
+    def reads(self) -> list[str]:
+        return ["count"]
+
+    @property
+    def writes(self) -> list[str]:
+        return ["count", "tracker"]
+
+
+class 
SingleStepStreamingCounterWithExceptionNoState(SingleStepStreamingAction):
+    """Raises without ever yielding a final state update."""
+
+    def stream_run_and_update(
+        self, state: State, **run_kwargs
+    ) -> Generator[Tuple[dict, Optional[State]], None, None]:
+        count = state["count"]
+        for i in range(3):
+            yield {"count": count + ((i + 1) / 10)}, None
+        raise RuntimeError("simulated failure with no state")
+
+    @property
+    def reads(self) -> list[str]:
+        return ["count"]
+
+    @property
+    def writes(self) -> list[str]:
+        return ["count", "tracker"]
+
+
+class SingleStepStreamingCounterWithExceptionAsync(SingleStepStreamingAction):
+    """Async variant: yields intermediate items, raises, then yields final 
state in finally."""
+
+    async def stream_run_and_update(
+        self, state: State, **run_kwargs
+    ) -> AsyncGenerator[Tuple[dict, Optional[State]], None]:
+        count = state["count"]
+        try:
+            for i in range(3):
+                yield {"count": count + ((i + 1) / 10)}, None
+            raise RuntimeError("simulated failure")
+        finally:
+            yield {"count": count + 1}, state.update(count=count + 
1).append(tracker=count + 1)
+
+    @property
+    def reads(self) -> list[str]:
+        return ["count"]
+
+    @property
+    def writes(self) -> list[str]:
+        return ["count", "tracker"]
+
+
+class 
SingleStepStreamingCounterWithExceptionNoStateAsync(SingleStepStreamingAction):
+    """Async variant: raises without ever yielding a final state update."""
+
+    async def stream_run_and_update(
+        self, state: State, **run_kwargs
+    ) -> AsyncGenerator[Tuple[dict, Optional[State]], None]:
+        count = state["count"]
+        for i in range(3):
+            yield {"count": count + ((i + 1) / 10)}, None
+        raise RuntimeError("simulated failure with no state")
+
+    @property
+    def reads(self) -> list[str]:
+        return ["count"]
+
+    @property
+    def writes(self) -> list[str]:
+        return ["count", "tracker"]
+
+
+class MultiStepStreamingCounterWithException(StreamingAction):
+    """Yields intermediate items, raises, then yields final result in finally 
block."""
+
+    def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None, 
None]:
+        count = state["count"]
+        try:
+            for i in range(3):
+                yield {"count": count + ((i + 1) / 10)}
+            raise RuntimeError("simulated failure")
+        finally:
+            yield {"count": count + 1}
+
+    @property
+    def reads(self) -> list[str]:
+        return ["count"]
+
+    @property
+    def writes(self) -> list[str]:
+        return ["count", "tracker"]
+
+    def update(self, result: dict, state: State) -> State:
+        return state.update(**result).append(tracker=result["count"])
+
+
+class MultiStepStreamingCounterWithExceptionNoResult(StreamingAction):
+    """Raises without ever yielding any item."""
+
+    def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None, 
None]:
+        raise RuntimeError("simulated failure with no result")
+        yield  # make this a generator function
+
+    @property
+    def reads(self) -> list[str]:
+        return ["count"]
+
+    @property
+    def writes(self) -> list[str]:
+        return ["count", "tracker"]
+
+    def update(self, result: dict, state: State) -> State:
+        return state.update(**result).append(tracker=result["count"])
+
+
+class MultiStepStreamingCounterWithExceptionAsync(AsyncStreamingAction):
+    """Async variant: yields intermediate items, raises, then yields final 
result in finally."""
+
+    async def stream_run(self, state: State, **run_kwargs) -> 
AsyncGenerator[dict, None]:
+        count = state["count"]
+        try:
+            for i in range(3):
+                yield {"count": count + ((i + 1) / 10)}
+            raise RuntimeError("simulated failure")
+        finally:
+            yield {"count": count + 1}
+
+    @property
+    def reads(self) -> list[str]:
+        return ["count"]
+
+    @property
+    def writes(self) -> list[str]:
+        return ["count", "tracker"]
+
+    def update(self, result: dict, state: State) -> State:
+        return state.update(**result).append(tracker=result["count"])
+
+
+class 
MultiStepStreamingCounterWithExceptionNoResultAsync(AsyncStreamingAction):
+    """Async variant: raises without ever yielding any item."""
+
+    async def stream_run(self, state: State, **run_kwargs) -> 
AsyncGenerator[dict, None]:
+        raise RuntimeError("simulated failure with no result")
+        yield  # make this an async generator
+
+    @property
+    def reads(self) -> list[str]:
+        return ["count"]
+
+    @property
+    def writes(self) -> list[str]:
+        return ["count", "tracker"]
+
+    def update(self, result: dict, state: State) -> State:
+        return state.update(**result).append(tracker=result["count"])
+
+
+def test__run_single_step_streaming_action_graceful_exception():
+    """When the generator raises but yields a final state in finally, stream 
completes gracefully."""
+    action = SingleStepStreamingCounterWithException().with_name("counter")
+    state = State({"count": 0, "tracker": []})
+    generator = _run_single_step_streaming_action(
+        action, state, inputs={}, sequence_id=0, partition_key="pk", 
app_id="app"
+    )
+    results = list(generator)
+    intermediate = [(r, s) for r, s in results if s is None]
+    final = [(r, s) for r, s in results if s is not None]
+    assert len(intermediate) == 3
+    assert len(final) == 1
+    assert final[0][0] == {"count": 1}
+    assert final[0][1].subset("count", "tracker").get_all() == {
+        "count": 1,
+        "tracker": [1],
+    }
+
+
+def test__run_single_step_streaming_action_exception_propagates():
+    """When the generator raises without yielding a final state, exception 
propagates."""
+    action = 
SingleStepStreamingCounterWithExceptionNoState().with_name("counter")
+    state = State({"count": 0, "tracker": []})
+    generator = _run_single_step_streaming_action(
+        action, state, inputs={}, sequence_id=0, partition_key="pk", 
app_id="app"
+    )
+    with pytest.raises(RuntimeError, match="simulated failure with no state"):
+        list(generator)
+
+
+async def test__run_single_step_streaming_action_graceful_exception_async():
+    """Async: when the generator raises but yields a final state in finally, 
stream completes."""
+    action = 
SingleStepStreamingCounterWithExceptionAsync().with_name("counter")
+    state = State({"count": 0, "tracker": []})
+    generator = _arun_single_step_streaming_action(
+        action=action,
+        state=state,
+        inputs={},
+        sequence_id=0,
+        app_id="app",
+        partition_key="pk",
+        lifecycle_adapters=LifecycleAdapterSet(),
+    )
+    results = []
+    async for item in generator:
+        results.append(item)
+    intermediate = [(r, s) for r, s in results if s is None]
+    final = [(r, s) for r, s in results if s is not None]
+    assert len(intermediate) == 3
+    assert len(final) == 1
+    assert final[0][0] == {"count": 1}
+    assert final[0][1].subset("count", "tracker").get_all() == {
+        "count": 1,
+        "tracker": [1],
+    }
+
+
+async def test__run_single_step_streaming_action_exception_propagates_async():
+    """Async: when the generator raises without yielding a final state, 
exception propagates."""
+    action = 
SingleStepStreamingCounterWithExceptionNoStateAsync().with_name("counter")
+    state = State({"count": 0, "tracker": []})
+    generator = _arun_single_step_streaming_action(
+        action=action,
+        state=state,
+        inputs={},
+        sequence_id=0,
+        app_id="app",
+        partition_key="pk",
+        lifecycle_adapters=LifecycleAdapterSet(),
+    )
+    with pytest.raises(RuntimeError, match="simulated failure with no state"):
+        async for _ in generator:
+            pass
+
+
+def test__run_multi_step_streaming_action_graceful_exception():
+    """When the generator raises but yields a final result in finally, stream 
completes."""
+    action = MultiStepStreamingCounterWithException().with_name("counter")
+    state = State({"count": 0, "tracker": []})
+    generator = _run_multi_step_streaming_action(
+        action, state, inputs={}, sequence_id=0, partition_key="pk", 
app_id="app"
+    )
+    results = list(generator)
+    intermediate = [(r, s) for r, s in results if s is None]
+    final = [(r, s) for r, s in results if s is not None]
+    assert len(intermediate) == 3
+    assert len(final) == 1
+    assert final[0][0] == {"count": 1}
+    assert final[0][1].subset("count", "tracker").get_all() == {
+        "count": 1,
+        "tracker": [1],
+    }
+
+
+def test__run_multi_step_streaming_action_exception_propagates():
+    """When the generator raises without yielding any result, exception 
propagates."""
+    action = 
MultiStepStreamingCounterWithExceptionNoResult().with_name("counter")
+    state = State({"count": 0, "tracker": []})
+    generator = _run_multi_step_streaming_action(
+        action, state, inputs={}, sequence_id=0, partition_key="pk", 
app_id="app"
+    )
+    with pytest.raises(RuntimeError, match="simulated failure with no result"):
+        list(generator)
+
+
+async def test__run_multi_step_streaming_action_graceful_exception_async():
+    """Async: when the generator raises but yields a final result in finally, 
stream completes."""
+    action = MultiStepStreamingCounterWithExceptionAsync().with_name("counter")
+    state = State({"count": 0, "tracker": []})
+    generator = _arun_multi_step_streaming_action(
+        action=action,
+        state=state,
+        inputs={},
+        sequence_id=0,
+        app_id="app",
+        partition_key="pk",
+        lifecycle_adapters=LifecycleAdapterSet(),
+    )
+    results = []
+    async for item in generator:
+        results.append(item)
+    intermediate = [(r, s) for r, s in results if s is None]
+    final = [(r, s) for r, s in results if s is not None]
+    assert len(intermediate) == 3
+    assert len(final) == 1
+    assert final[0][0] == {"count": 1}
+    assert final[0][1].subset("count", "tracker").get_all() == {
+        "count": 1,
+        "tracker": [1],
+    }
+
+
+async def test__run_multi_step_streaming_action_exception_propagates_async():
+    """Async: when the generator raises without yielding any result, exception 
propagates."""
+    action = 
MultiStepStreamingCounterWithExceptionNoResultAsync().with_name("counter")
+    state = State({"count": 0, "tracker": []})
+    generator = _arun_multi_step_streaming_action(
+        action=action,
+        state=state,
+        inputs={},
+        sequence_id=0,
+        app_id="app",
+        partition_key="pk",
+        lifecycle_adapters=LifecycleAdapterSet(),
+    )
+    with pytest.raises(RuntimeError, match="simulated failure with no result"):
+        async for _ in generator:
+            pass
+
+
 class SingleStepActionWithDeletionAsync(SingleStepActionWithDeletion):
     async def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, 
State]:
         return {}, state.wipe(delete=["to_delete"])
@@ -1935,7 +2282,11 @@ async def test_app_a_run_async_and_sync():
         graph=Graph(
             actions=[counter_action_sync, counter_action_async, result_action],
             transitions=[
-                Transition(counter_action_sync, counter_action_async, 
Condition.expr("count < 20")),
+                Transition(
+                    counter_action_sync,
+                    counter_action_async,
+                    Condition.expr("count < 20"),
+                ),
                 Transition(counter_action_async, counter_action_sync, default),
                 Transition(counter_action_sync, result_action, default),
             ],
@@ -2060,7 +2411,8 @@ async def 
test_astream_result_halt_after_unique_ordered_sequence_id():
 
 def test_stream_result_halt_after_run_through_streaming():
     """Tests that we can pass through streaming results,
-    fully realize them, then get to the streaming results at the end and 
return the stream"""
+    fully realize them, then get to the streaming results at the end and 
return the stream
+    """
     action_tracker = CallCaptureTracker()
     stream_event_tracker = StreamEventCaptureTracker()
     counter_action = base_streaming_single_step_counter.with_name("counter")
@@ -2665,7 +3017,10 @@ def test__adjust_single_step_output_result_and_state():
 
 def test__adjust_single_step_output_just_state():
     state = State({"count": 1})
-    assert _adjust_single_step_output(state, "test_action", DEFAULT_SCHEMA) == 
({}, state)
+    assert _adjust_single_step_output(state, "test_action", DEFAULT_SCHEMA) == 
(
+        {},
+        state,
+    )
 
 
 def test__adjust_single_step_output_errors_incorrect_type():
@@ -2912,7 +3267,11 @@ class BrokenPersister(BaseStatePersister):
     """Broken persistor."""
 
     def load(
-        self, partition_key: str, app_id: Optional[str], sequence_id: 
Optional[int] = None, **kwargs
+        self,
+        partition_key: str,
+        app_id: Optional[str],
+        sequence_id: Optional[int] = None,
+        **kwargs,
     ) -> Optional[PersistedStateData]:
         return dict(
             partition_key="key",
@@ -2968,7 +3327,8 @@ def 
test_load_from_sync_cannot_have_async_persistor_error():
         default_entrypoint="foo",
     )
     with pytest.raises(
-        ValueError, match="are building the sync application, but have used an 
async initializer."
+        ValueError,
+        match="are building the sync application, but have used an async 
initializer.",
     ):
         # we have not initialized
         builder._load_from_sync_persister()
@@ -2984,7 +3344,8 @@ async def 
test_load_from_async_cannot_have_sync_persistor_error():
         default_entrypoint="foo",
     )
     with pytest.raises(
-        ValueError, match="are building the async application, but have used 
an sync initializer."
+        ValueError,
+        match="are building the async application, but have used an sync 
initializer.",
     ):
         # we have not initialized
         await builder._load_from_async_persister()
@@ -3055,7 +3416,11 @@ class DummyPersister(BaseStatePersister):
     """Dummy persistor."""
 
     def load(
-        self, partition_key: str, app_id: Optional[str], sequence_id: 
Optional[int] = None, **kwargs
+        self,
+        partition_key: str,
+        app_id: Optional[str],
+        sequence_id: Optional[int] = None,
+        **kwargs,
     ) -> Optional[PersistedStateData]:
         return PersistedStateData(
             partition_key="user123",
@@ -3416,7 +3781,10 @@ def test_application_recursive_action_lifecycle_hooks():
     hook = TestingHook()
     foo = []
 
-    @action(reads=["recursion_count", "total_count"], 
writes=["recursion_count", "total_count"])
+    @action(
+        reads=["recursion_count", "total_count"],
+        writes=["recursion_count", "total_count"],
+    )
     def recursive_action(state: State) -> State:
         foo.append(1)
         recursion_count = state["recursion_count"]
@@ -3498,7 +3866,8 @@ def 
test_set_sync_state_persister_cannot_have_async_error():
     persister = AsyncDevNullPersister()
     builder.with_state_persister(persister)
     with pytest.raises(
-        ValueError, match="are building the sync application, but have used an 
async persister."
+        ValueError,
+        match="are building the sync application, but have used an async 
persister.",
     ):
         # we have not initialized
         builder._set_sync_state_persister()
@@ -3519,7 +3888,8 @@ async def 
test_set_async_state_persister_cannot_have_sync_error():
     persister = DevNullPersister()
     builder.with_state_persister(persister)
     with pytest.raises(
-        ValueError, match="are building the async application, but have used 
an sync persister."
+        ValueError,
+        match="are building the async application, but have used an sync 
persister.",
     ):
         # we have not initialized
         await builder._set_async_state_persister()
@@ -3620,15 +3990,27 @@ class ActionWithContextTracer(ActionWithoutContext):
 def test_remap_context_variable_with_mangled_context_kwargs():
     _action = ActionWithKwargs()
 
-    inputs = {"__context": "context_value", "other_key": "other_value", "foo": 
"foo_value"}
-    expected = {"__context": "context_value", "other_key": "other_value", 
"foo": "foo_value"}
+    inputs = {
+        "__context": "context_value",
+        "other_key": "other_value",
+        "foo": "foo_value",
+    }
+    expected = {
+        "__context": "context_value",
+        "other_key": "other_value",
+        "foo": "foo_value",
+    }
     assert _remap_dunder_parameters(_action.run, inputs, ["__context", 
"__tracer"]) == expected
 
 
 def test_remap_context_variable_with_mangled_context():
     _action = ActionWithContext()
 
-    inputs = {"__context": "context_value", "other_key": "other_value", "foo": 
"foo_value"}
+    inputs = {
+        "__context": "context_value",
+        "other_key": "other_value",
+        "foo": "foo_value",
+    }
     expected = {
         f"_{ActionWithContext.__name__}__context": "context_value",
         "other_key": "other_value",
@@ -3657,8 +4039,16 @@ def 
test_remap_context_variable_with_mangled_contexttracer():
 
 def test_remap_context_variable_without_mangled_context():
     _action = ActionWithoutContext()
-    inputs = {"__context": "context_value", "other_key": "other_value", "foo": 
"foo_value"}
-    expected = {"__context": "context_value", "other_key": "other_value", 
"foo": "foo_value"}
+    inputs = {
+        "__context": "context_value",
+        "other_key": "other_value",
+        "foo": "foo_value",
+    }
+    expected = {
+        "__context": "context_value",
+        "other_key": "other_value",
+        "foo": "foo_value",
+    }
     assert _remap_dunder_parameters(_action.run, inputs, ["__context", 
"__tracer"]) == expected
 
 


Reply via email to