This is an automated email from the ASF dual-hosted git repository.
skrawcz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/burr.git
The following commit(s) were added to refs/heads/main by this push:
new 2c5abdb0 fix: graceful stream shutdown on exceptions in streaming
actions (#680)
2c5abdb0 is described below
commit 2c5abdb0e3829228bb1397c8063eff7d9b85c6ae
Author: André Ahlert <[email protected]>
AuthorDate: Sat Mar 28 16:24:33 2026 -0300
fix: graceful stream shutdown on exceptions in streaming actions (#680)
* fix: allow streaming actions to gracefully handle raised exceptions
When a streaming action catches an exception and yields a final state
in a try/except/finally block, the stream now completes gracefully
instead of propagating the exception and killing the connection.
If the generator yields a valid state_update before the exception
propagates, the exception is suppressed and the stream terminates
normally. If no state was yielded, the exception propagates as before.
Closes #581
* fix: add logging for suppressed exceptions and tests for streaming
graceful shutdown
- Add logger.warning with exc_info in all 4 streaming except blocks
- Remove dead caught_exc variable in multi-step functions
- Fix count increment asymmetry between sync and async single-step
- Add 8 tests covering graceful exception handling and propagation
* style: apply black formatting
* style: apply pre-commit formatting (black 23.11.0, isort 5.12.0)
---
burr/core/application.py | 249 +++++++++++++++---------
tests/core/test_application.py | 430 +++++++++++++++++++++++++++++++++++++++--
2 files changed, 565 insertions(+), 114 deletions(-)
diff --git a/burr/core/application.py b/burr/core/application.py
index 84e3f7cb..7cbe9614 100644
--- a/burr/core/application.py
+++ b/burr/core/application.py
@@ -338,31 +338,44 @@ def _run_single_step_streaming_action(
result = None
state_update = None
count = 0
- for item in generator:
- if not isinstance(item, tuple):
- # TODO -- consider adding support for just returning a result.
- raise ValueError(
- f"Action {action.name} must yield a tuple of (result,
state_update). "
- f"For all non-final results (intermediate),"
- f"the state update must be None"
- )
- result, state_update = item
- count += 1
+ try:
+ for item in generator:
+ if not isinstance(item, tuple):
+ # TODO -- consider adding support for just returning a result.
+ raise ValueError(
+ f"Action {action.name} must yield a tuple of (result,
state_update). "
+ f"For all non-final results (intermediate),"
+ f"the state update must be None"
+ )
+ result, state_update = item
+ if state_update is None:
+ count += 1
+ if first_stream_start_time is None:
+ first_stream_start_time = system.now()
+ lifecycle_adapters.call_all_lifecycle_hooks_sync(
+ "post_stream_item",
+ item=result,
+ item_index=count,
+ stream_initialize_time=stream_initialize_time,
+ first_stream_item_start_time=first_stream_start_time,
+ action=action.name,
+ app_id=app_id,
+ partition_key=partition_key,
+ sequence_id=sequence_id,
+ )
+ yield result, None
+ except Exception as e:
if state_update is None:
- if first_stream_start_time is None:
- first_stream_start_time = system.now()
- lifecycle_adapters.call_all_lifecycle_hooks_sync(
- "post_stream_item",
- item=result,
- item_index=count,
- stream_initialize_time=stream_initialize_time,
- first_stream_item_start_time=first_stream_start_time,
- action=action.name,
- app_id=app_id,
- partition_key=partition_key,
- sequence_id=sequence_id,
- )
- yield result, None
+ raise
+ logger.warning(
+ "Streaming action '%s' raised %s after yielding %d items. "
+ "Proceeding with final state from generator cleanup. Original
error: %s",
+ action.name,
+ type(e).__name__,
+ count,
+ e,
+ exc_info=True,
+ )
if state_update is None:
raise ValueError(
@@ -391,31 +404,45 @@ async def _arun_single_step_streaming_action(
result = None
state_update = None
count = 0
- async for item in generator:
- if not isinstance(item, tuple):
- # TODO -- consider adding support for just returning a result.
- raise ValueError(
- f"Action {action.name} must yield a tuple of (result,
state_update). "
- f"For all non-final results (intermediate),"
- f"the state update must be None"
- )
- result, state_update = item
+ try:
+ async for item in generator:
+ if not isinstance(item, tuple):
+ # TODO -- consider adding support for just returning a result.
+ raise ValueError(
+ f"Action {action.name} must yield a tuple of (result,
state_update). "
+ f"For all non-final results (intermediate),"
+ f"the state update must be None"
+ )
+ result, state_update = item
+ if state_update is None:
+ count += 1
+ if first_stream_start_time is None:
+ first_stream_start_time = system.now()
+ await
lifecycle_adapters.call_all_lifecycle_hooks_sync_and_async(
+ "post_stream_item",
+ item=result,
+ item_index=count,
+ stream_initialize_time=stream_initialize_time,
+ first_stream_item_start_time=first_stream_start_time,
+ action=action.name,
+ app_id=app_id,
+ partition_key=partition_key,
+ sequence_id=sequence_id,
+ )
+ yield result, None
+ except Exception as e:
if state_update is None:
- if first_stream_start_time is None:
- first_stream_start_time = system.now()
- await lifecycle_adapters.call_all_lifecycle_hooks_sync_and_async(
- "post_stream_item",
- item=result,
- item_index=count,
- stream_initialize_time=stream_initialize_time,
- first_stream_item_start_time=first_stream_start_time,
- action=action.name,
- app_id=app_id,
- partition_key=partition_key,
- sequence_id=sequence_id,
- )
- count += 1
- yield result, None
+ raise
+ logger.warning(
+ "Streaming action '%s' raised %s after yielding %d items. "
+ "Proceeding with final state from generator cleanup. Original
error: %s",
+ action.name,
+ type(e).__name__,
+ count,
+ e,
+ exc_info=True,
+ )
+
if state_update is None:
raise ValueError(
f"Action {action.name} did not return a state update. For async
actions, the last yield "
@@ -450,28 +477,42 @@ def _run_multi_step_streaming_action(
result = None
first_stream_start_time = None
count = 0
- for item in generator:
- # We want to peek ahead so we can return the last one
- # This is slightly eager, but only in the case in which we
- # are using a multi-step streaming action
- next_result = result
- result = item
- if next_result is not None:
- if first_stream_start_time is None:
- first_stream_start_time = system.now()
- lifecycle_adapters.call_all_lifecycle_hooks_sync(
- "post_stream_item",
- item=next_result,
- item_index=count,
- stream_initialize_time=stream_initialize_time,
- first_stream_item_start_time=first_stream_start_time,
- action=action.name,
- app_id=app_id,
- partition_key=partition_key,
- sequence_id=sequence_id,
- )
- count += 1
- yield next_result, None
+ try:
+ for item in generator:
+ # We want to peek ahead so we can return the last one
+ # This is slightly eager, but only in the case in which we
+ # are using a multi-step streaming action
+ next_result = result
+ result = item
+ if next_result is not None:
+ if first_stream_start_time is None:
+ first_stream_start_time = system.now()
+ lifecycle_adapters.call_all_lifecycle_hooks_sync(
+ "post_stream_item",
+ item=next_result,
+ item_index=count,
+ stream_initialize_time=stream_initialize_time,
+ first_stream_item_start_time=first_stream_start_time,
+ action=action.name,
+ app_id=app_id,
+ partition_key=partition_key,
+ sequence_id=sequence_id,
+ )
+ count += 1
+ yield next_result, None
+ except Exception as e:
+ if result is None:
+ raise
+ logger.warning(
+ "Streaming action '%s' raised %s after yielding %d items. "
+ "Proceeding with last yielded result for reducer. "
+ "Note: the reducer will run on potentially partial data. Original
error: %s",
+ action.name,
+ type(e).__name__,
+ count,
+ e,
+ exc_info=True,
+ )
state_update = _run_reducer(action, state, result, action.name)
_validate_result(result, action.name, action.schema)
_validate_reducer_writes(action, state_update, action.name)
@@ -494,28 +535,42 @@ async def _arun_multi_step_streaming_action(
result = None
first_stream_start_time = None
count = 0
- async for item in generator:
- # We want to peek ahead so we can return the last one
- # This is slightly eager, but only in the case in which we
- # are using a multi-step streaming action
- next_result = result
- result = item
- if next_result is not None:
- if first_stream_start_time is None:
- first_stream_start_time = system.now()
- await lifecycle_adapters.call_all_lifecycle_hooks_sync_and_async(
- "post_stream_item",
- item=next_result,
- stream_initialize_time=stream_initialize_time,
- item_index=count,
- first_stream_item_start_time=first_stream_start_time,
- action=action.name,
- app_id=app_id,
- partition_key=partition_key,
- sequence_id=sequence_id,
- )
- count += 1
- yield next_result, None
+ try:
+ async for item in generator:
+ # We want to peek ahead so we can return the last one
+ # This is slightly eager, but only in the case in which we
+ # are using a multi-step streaming action
+ next_result = result
+ result = item
+ if next_result is not None:
+ if first_stream_start_time is None:
+ first_stream_start_time = system.now()
+ await
lifecycle_adapters.call_all_lifecycle_hooks_sync_and_async(
+ "post_stream_item",
+ item=next_result,
+ stream_initialize_time=stream_initialize_time,
+ item_index=count,
+ first_stream_item_start_time=first_stream_start_time,
+ action=action.name,
+ app_id=app_id,
+ partition_key=partition_key,
+ sequence_id=sequence_id,
+ )
+ count += 1
+ yield next_result, None
+ except Exception as e:
+ if result is None:
+ raise
+ logger.warning(
+ "Streaming action '%s' raised %s after yielding %d items. "
+ "Proceeding with last yielded result for reducer. "
+ "Note: the reducer will run on potentially partial data. Original
error: %s",
+ action.name,
+ type(e).__name__,
+ count,
+ e,
+ exc_info=True,
+ )
state_update = _run_reducer(action, state, result, action.name)
_validate_result(result, action.name, action.schema)
_validate_reducer_writes(action, state_update, action.name)
@@ -1862,7 +1917,9 @@ class Application(Generic[ApplicationStateType]):
halt_before: Optional[Union[str, List[str]]] = None,
inputs: Optional[Dict[str, Any]] = None,
) -> Generator[
- Tuple[Action, StreamingResultContainer[ApplicationStateType,
Union[dict, Any]]], None, None
+ Tuple[Action, StreamingResultContainer[ApplicationStateType,
Union[dict, Any]]],
+ None,
+ None,
]:
"""Produces an iterator that iterates through intermediate streams.
You may want
to use this in something like deep research mode in which:
@@ -1905,7 +1962,11 @@ class Application(Generic[ApplicationStateType]):
halt_before: Optional[Union[str, List[str]]] = None,
inputs: Optional[Dict[str, Any]] = None,
) -> AsyncGenerator[
- Tuple[Action, AsyncStreamingResultContainer[ApplicationStateType,
Union[dict, Any]]], None
+ Tuple[
+ Action,
+ AsyncStreamingResultContainer[ApplicationStateType, Union[dict,
Any]],
+ ],
+ None,
]:
"""Async version of stream_iterate. Produces an async generator that
iterates
through intermediate streams. See stream_iterate for more details.
diff --git a/tests/core/test_application.py b/tests/core/test_application.py
index c90c4067..7383583f 100644
--- a/tests/core/test_application.py
+++ b/tests/core/test_application.py
@@ -630,7 +630,12 @@ def
test_run_single_step_streaming_action_errors_missing_write():
state = State()
with pytest.raises(ValueError, match="missing_value"):
gen = _run_single_step_streaming_action(
- action, state, inputs={}, sequence_id=0,
partition_key="partition_key", app_id="app_id"
+ action,
+ state,
+ inputs={},
+ sequence_id=0,
+ partition_key="partition_key",
+ app_id="app_id",
)
collections.deque(gen, maxlen=0) # exhaust the generator
@@ -687,7 +692,12 @@ def
test_run_multi_step_streaming_action_errors_missing_write():
state = State()
with pytest.raises(ValueError, match="missing_value"):
gen = _run_multi_step_streaming_action(
- action, state, inputs={}, sequence_id=0,
partition_key="partition_key", app_id="app_id"
+ action,
+ state,
+ inputs={},
+ sequence_id=0,
+ partition_key="partition_key",
+ app_id="app_id",
)
collections.deque(gen, maxlen=0) # exhaust the generator
@@ -1005,7 +1015,12 @@ def test__run_multistep_streaming_action():
action = base_streaming_counter.with_name("counter")
state = State({"count": 0, "tracker": []})
generator = _run_multi_step_streaming_action(
- action, state, inputs={}, sequence_id=0,
partition_key="partition_key", app_id="app_id"
+ action,
+ state,
+ inputs={},
+ sequence_id=0,
+ partition_key="partition_key",
+ app_id="app_id",
)
last_result = -1
result = None
@@ -1111,7 +1126,12 @@ def test__run_streaming_action_incorrect_result_type():
state = State()
with pytest.raises(ValueError, match="returned a non-dict"):
gen = _run_multi_step_streaming_action(
- action, state, inputs={}, sequence_id=0,
partition_key="partition_key", app_id="app_id"
+ action,
+ state,
+ inputs={},
+ sequence_id=0,
+ partition_key="partition_key",
+ app_id="app_id",
)
collections.deque(gen, maxlen=0) # exhaust the generator
@@ -1168,7 +1188,12 @@ def test__run_single_step_streaming_action():
action = base_streaming_single_step_counter.with_name("counter")
state = State({"count": 0, "tracker": []})
generator = _run_single_step_streaming_action(
- action, state, inputs={}, sequence_id=0,
partition_key="partition_key", app_id="app_id"
+ action,
+ state,
+ inputs={},
+ sequence_id=0,
+ partition_key="partition_key",
+ app_id="app_id",
)
last_result = -1
result, state = None, None
@@ -1275,6 +1300,328 @@ async def
test__run_single_step_streaming_action_async_callbacks():
assert len(hook.items) == 10 # one for each streaming callback
+class SingleStepStreamingCounterWithException(SingleStepStreamingAction):
+ """Yields intermediate items, raises, then yields final state in finally
block."""
+
+ def stream_run_and_update(
+ self, state: State, **run_kwargs
+ ) -> Generator[Tuple[dict, Optional[State]], None, None]:
+ count = state["count"]
+ try:
+ for i in range(3):
+ yield {"count": count + ((i + 1) / 10)}, None
+ raise RuntimeError("simulated failure")
+ finally:
+ yield {"count": count + 1}, state.update(count=count +
1).append(tracker=count + 1)
+
+ @property
+ def reads(self) -> list[str]:
+ return ["count"]
+
+ @property
+ def writes(self) -> list[str]:
+ return ["count", "tracker"]
+
+
+class
SingleStepStreamingCounterWithExceptionNoState(SingleStepStreamingAction):
+ """Raises without ever yielding a final state update."""
+
+ def stream_run_and_update(
+ self, state: State, **run_kwargs
+ ) -> Generator[Tuple[dict, Optional[State]], None, None]:
+ count = state["count"]
+ for i in range(3):
+ yield {"count": count + ((i + 1) / 10)}, None
+ raise RuntimeError("simulated failure with no state")
+
+ @property
+ def reads(self) -> list[str]:
+ return ["count"]
+
+ @property
+ def writes(self) -> list[str]:
+ return ["count", "tracker"]
+
+
+class SingleStepStreamingCounterWithExceptionAsync(SingleStepStreamingAction):
+ """Async variant: yields intermediate items, raises, then yields final
state in finally."""
+
+ async def stream_run_and_update(
+ self, state: State, **run_kwargs
+ ) -> AsyncGenerator[Tuple[dict, Optional[State]], None]:
+ count = state["count"]
+ try:
+ for i in range(3):
+ yield {"count": count + ((i + 1) / 10)}, None
+ raise RuntimeError("simulated failure")
+ finally:
+ yield {"count": count + 1}, state.update(count=count +
1).append(tracker=count + 1)
+
+ @property
+ def reads(self) -> list[str]:
+ return ["count"]
+
+ @property
+ def writes(self) -> list[str]:
+ return ["count", "tracker"]
+
+
+class
SingleStepStreamingCounterWithExceptionNoStateAsync(SingleStepStreamingAction):
+ """Async variant: raises without ever yielding a final state update."""
+
+ async def stream_run_and_update(
+ self, state: State, **run_kwargs
+ ) -> AsyncGenerator[Tuple[dict, Optional[State]], None]:
+ count = state["count"]
+ for i in range(3):
+ yield {"count": count + ((i + 1) / 10)}, None
+ raise RuntimeError("simulated failure with no state")
+
+ @property
+ def reads(self) -> list[str]:
+ return ["count"]
+
+ @property
+ def writes(self) -> list[str]:
+ return ["count", "tracker"]
+
+
+class MultiStepStreamingCounterWithException(StreamingAction):
+ """Yields intermediate items, raises, then yields final result in finally
block."""
+
+ def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None,
None]:
+ count = state["count"]
+ try:
+ for i in range(3):
+ yield {"count": count + ((i + 1) / 10)}
+ raise RuntimeError("simulated failure")
+ finally:
+ yield {"count": count + 1}
+
+ @property
+ def reads(self) -> list[str]:
+ return ["count"]
+
+ @property
+ def writes(self) -> list[str]:
+ return ["count", "tracker"]
+
+ def update(self, result: dict, state: State) -> State:
+ return state.update(**result).append(tracker=result["count"])
+
+
+class MultiStepStreamingCounterWithExceptionNoResult(StreamingAction):
+ """Raises without ever yielding any item."""
+
+ def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None,
None]:
+ raise RuntimeError("simulated failure with no result")
+ yield # make this a generator function
+
+ @property
+ def reads(self) -> list[str]:
+ return ["count"]
+
+ @property
+ def writes(self) -> list[str]:
+ return ["count", "tracker"]
+
+ def update(self, result: dict, state: State) -> State:
+ return state.update(**result).append(tracker=result["count"])
+
+
+class MultiStepStreamingCounterWithExceptionAsync(AsyncStreamingAction):
+ """Async variant: yields intermediate items, raises, then yields final
result in finally."""
+
+ async def stream_run(self, state: State, **run_kwargs) ->
AsyncGenerator[dict, None]:
+ count = state["count"]
+ try:
+ for i in range(3):
+ yield {"count": count + ((i + 1) / 10)}
+ raise RuntimeError("simulated failure")
+ finally:
+ yield {"count": count + 1}
+
+ @property
+ def reads(self) -> list[str]:
+ return ["count"]
+
+ @property
+ def writes(self) -> list[str]:
+ return ["count", "tracker"]
+
+ def update(self, result: dict, state: State) -> State:
+ return state.update(**result).append(tracker=result["count"])
+
+
+class
MultiStepStreamingCounterWithExceptionNoResultAsync(AsyncStreamingAction):
+ """Async variant: raises without ever yielding any item."""
+
+ async def stream_run(self, state: State, **run_kwargs) ->
AsyncGenerator[dict, None]:
+ raise RuntimeError("simulated failure with no result")
+ yield # make this an async generator
+
+ @property
+ def reads(self) -> list[str]:
+ return ["count"]
+
+ @property
+ def writes(self) -> list[str]:
+ return ["count", "tracker"]
+
+ def update(self, result: dict, state: State) -> State:
+ return state.update(**result).append(tracker=result["count"])
+
+
+def test__run_single_step_streaming_action_graceful_exception():
+ """When the generator raises but yields a final state in finally, stream
completes gracefully."""
+ action = SingleStepStreamingCounterWithException().with_name("counter")
+ state = State({"count": 0, "tracker": []})
+ generator = _run_single_step_streaming_action(
+ action, state, inputs={}, sequence_id=0, partition_key="pk",
app_id="app"
+ )
+ results = list(generator)
+ intermediate = [(r, s) for r, s in results if s is None]
+ final = [(r, s) for r, s in results if s is not None]
+ assert len(intermediate) == 3
+ assert len(final) == 1
+ assert final[0][0] == {"count": 1}
+ assert final[0][1].subset("count", "tracker").get_all() == {
+ "count": 1,
+ "tracker": [1],
+ }
+
+
+def test__run_single_step_streaming_action_exception_propagates():
+ """When the generator raises without yielding a final state, exception
propagates."""
+ action =
SingleStepStreamingCounterWithExceptionNoState().with_name("counter")
+ state = State({"count": 0, "tracker": []})
+ generator = _run_single_step_streaming_action(
+ action, state, inputs={}, sequence_id=0, partition_key="pk",
app_id="app"
+ )
+ with pytest.raises(RuntimeError, match="simulated failure with no state"):
+ list(generator)
+
+
+async def test__run_single_step_streaming_action_graceful_exception_async():
+ """Async: when the generator raises but yields a final state in finally,
stream completes."""
+ action =
SingleStepStreamingCounterWithExceptionAsync().with_name("counter")
+ state = State({"count": 0, "tracker": []})
+ generator = _arun_single_step_streaming_action(
+ action=action,
+ state=state,
+ inputs={},
+ sequence_id=0,
+ app_id="app",
+ partition_key="pk",
+ lifecycle_adapters=LifecycleAdapterSet(),
+ )
+ results = []
+ async for item in generator:
+ results.append(item)
+ intermediate = [(r, s) for r, s in results if s is None]
+ final = [(r, s) for r, s in results if s is not None]
+ assert len(intermediate) == 3
+ assert len(final) == 1
+ assert final[0][0] == {"count": 1}
+ assert final[0][1].subset("count", "tracker").get_all() == {
+ "count": 1,
+ "tracker": [1],
+ }
+
+
+async def test__run_single_step_streaming_action_exception_propagates_async():
+ """Async: when the generator raises without yielding a final state,
exception propagates."""
+ action =
SingleStepStreamingCounterWithExceptionNoStateAsync().with_name("counter")
+ state = State({"count": 0, "tracker": []})
+ generator = _arun_single_step_streaming_action(
+ action=action,
+ state=state,
+ inputs={},
+ sequence_id=0,
+ app_id="app",
+ partition_key="pk",
+ lifecycle_adapters=LifecycleAdapterSet(),
+ )
+ with pytest.raises(RuntimeError, match="simulated failure with no state"):
+ async for _ in generator:
+ pass
+
+
+def test__run_multi_step_streaming_action_graceful_exception():
+ """When the generator raises but yields a final result in finally, stream
completes."""
+ action = MultiStepStreamingCounterWithException().with_name("counter")
+ state = State({"count": 0, "tracker": []})
+ generator = _run_multi_step_streaming_action(
+ action, state, inputs={}, sequence_id=0, partition_key="pk",
app_id="app"
+ )
+ results = list(generator)
+ intermediate = [(r, s) for r, s in results if s is None]
+ final = [(r, s) for r, s in results if s is not None]
+ assert len(intermediate) == 3
+ assert len(final) == 1
+ assert final[0][0] == {"count": 1}
+ assert final[0][1].subset("count", "tracker").get_all() == {
+ "count": 1,
+ "tracker": [1],
+ }
+
+
+def test__run_multi_step_streaming_action_exception_propagates():
+ """When the generator raises without yielding any result, exception
propagates."""
+ action =
MultiStepStreamingCounterWithExceptionNoResult().with_name("counter")
+ state = State({"count": 0, "tracker": []})
+ generator = _run_multi_step_streaming_action(
+ action, state, inputs={}, sequence_id=0, partition_key="pk",
app_id="app"
+ )
+ with pytest.raises(RuntimeError, match="simulated failure with no result"):
+ list(generator)
+
+
+async def test__run_multi_step_streaming_action_graceful_exception_async():
+ """Async: when the generator raises but yields a final result in finally,
stream completes."""
+ action = MultiStepStreamingCounterWithExceptionAsync().with_name("counter")
+ state = State({"count": 0, "tracker": []})
+ generator = _arun_multi_step_streaming_action(
+ action=action,
+ state=state,
+ inputs={},
+ sequence_id=0,
+ app_id="app",
+ partition_key="pk",
+ lifecycle_adapters=LifecycleAdapterSet(),
+ )
+ results = []
+ async for item in generator:
+ results.append(item)
+ intermediate = [(r, s) for r, s in results if s is None]
+ final = [(r, s) for r, s in results if s is not None]
+ assert len(intermediate) == 3
+ assert len(final) == 1
+ assert final[0][0] == {"count": 1}
+ assert final[0][1].subset("count", "tracker").get_all() == {
+ "count": 1,
+ "tracker": [1],
+ }
+
+
+async def test__run_multi_step_streaming_action_exception_propagates_async():
+ """Async: when the generator raises without yielding any result, exception
propagates."""
+ action =
MultiStepStreamingCounterWithExceptionNoResultAsync().with_name("counter")
+ state = State({"count": 0, "tracker": []})
+ generator = _arun_multi_step_streaming_action(
+ action=action,
+ state=state,
+ inputs={},
+ sequence_id=0,
+ app_id="app",
+ partition_key="pk",
+ lifecycle_adapters=LifecycleAdapterSet(),
+ )
+ with pytest.raises(RuntimeError, match="simulated failure with no result"):
+ async for _ in generator:
+ pass
+
+
class SingleStepActionWithDeletionAsync(SingleStepActionWithDeletion):
async def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict,
State]:
return {}, state.wipe(delete=["to_delete"])
@@ -1935,7 +2282,11 @@ async def test_app_a_run_async_and_sync():
graph=Graph(
actions=[counter_action_sync, counter_action_async, result_action],
transitions=[
- Transition(counter_action_sync, counter_action_async,
Condition.expr("count < 20")),
+ Transition(
+ counter_action_sync,
+ counter_action_async,
+ Condition.expr("count < 20"),
+ ),
Transition(counter_action_async, counter_action_sync, default),
Transition(counter_action_sync, result_action, default),
],
@@ -2060,7 +2411,8 @@ async def
test_astream_result_halt_after_unique_ordered_sequence_id():
def test_stream_result_halt_after_run_through_streaming():
"""Tests that we can pass through streaming results,
- fully realize them, then get to the streaming results at the end and
return the stream"""
+ fully realize them, then get to the streaming results at the end and
return the stream
+ """
action_tracker = CallCaptureTracker()
stream_event_tracker = StreamEventCaptureTracker()
counter_action = base_streaming_single_step_counter.with_name("counter")
@@ -2665,7 +3017,10 @@ def test__adjust_single_step_output_result_and_state():
def test__adjust_single_step_output_just_state():
state = State({"count": 1})
- assert _adjust_single_step_output(state, "test_action", DEFAULT_SCHEMA) ==
({}, state)
+ assert _adjust_single_step_output(state, "test_action", DEFAULT_SCHEMA) ==
(
+ {},
+ state,
+ )
def test__adjust_single_step_output_errors_incorrect_type():
@@ -2912,7 +3267,11 @@ class BrokenPersister(BaseStatePersister):
"""Broken persistor."""
def load(
- self, partition_key: str, app_id: Optional[str], sequence_id:
Optional[int] = None, **kwargs
+ self,
+ partition_key: str,
+ app_id: Optional[str],
+ sequence_id: Optional[int] = None,
+ **kwargs,
) -> Optional[PersistedStateData]:
return dict(
partition_key="key",
@@ -2968,7 +3327,8 @@ def
test_load_from_sync_cannot_have_async_persistor_error():
default_entrypoint="foo",
)
with pytest.raises(
- ValueError, match="are building the sync application, but have used an
async initializer."
+ ValueError,
+ match="are building the sync application, but have used an async
initializer.",
):
# we have not initialized
builder._load_from_sync_persister()
@@ -2984,7 +3344,8 @@ async def
test_load_from_async_cannot_have_sync_persistor_error():
default_entrypoint="foo",
)
with pytest.raises(
- ValueError, match="are building the async application, but have used
an sync initializer."
+ ValueError,
+ match="are building the async application, but have used an sync
initializer.",
):
# we have not initialized
await builder._load_from_async_persister()
@@ -3055,7 +3416,11 @@ class DummyPersister(BaseStatePersister):
"""Dummy persistor."""
def load(
- self, partition_key: str, app_id: Optional[str], sequence_id:
Optional[int] = None, **kwargs
+ self,
+ partition_key: str,
+ app_id: Optional[str],
+ sequence_id: Optional[int] = None,
+ **kwargs,
) -> Optional[PersistedStateData]:
return PersistedStateData(
partition_key="user123",
@@ -3416,7 +3781,10 @@ def test_application_recursive_action_lifecycle_hooks():
hook = TestingHook()
foo = []
- @action(reads=["recursion_count", "total_count"],
writes=["recursion_count", "total_count"])
+ @action(
+ reads=["recursion_count", "total_count"],
+ writes=["recursion_count", "total_count"],
+ )
def recursive_action(state: State) -> State:
foo.append(1)
recursion_count = state["recursion_count"]
@@ -3498,7 +3866,8 @@ def
test_set_sync_state_persister_cannot_have_async_error():
persister = AsyncDevNullPersister()
builder.with_state_persister(persister)
with pytest.raises(
- ValueError, match="are building the sync application, but have used an
async persister."
+ ValueError,
+ match="are building the sync application, but have used an async
persister.",
):
# we have not initialized
builder._set_sync_state_persister()
@@ -3519,7 +3888,8 @@ async def
test_set_async_state_persister_cannot_have_sync_error():
persister = DevNullPersister()
builder.with_state_persister(persister)
with pytest.raises(
- ValueError, match="are building the async application, but have used
an sync persister."
+ ValueError,
+ match="are building the async application, but have used an sync
persister.",
):
# we have not initialized
await builder._set_async_state_persister()
@@ -3620,15 +3990,27 @@ class ActionWithContextTracer(ActionWithoutContext):
def test_remap_context_variable_with_mangled_context_kwargs():
_action = ActionWithKwargs()
- inputs = {"__context": "context_value", "other_key": "other_value", "foo":
"foo_value"}
- expected = {"__context": "context_value", "other_key": "other_value",
"foo": "foo_value"}
+ inputs = {
+ "__context": "context_value",
+ "other_key": "other_value",
+ "foo": "foo_value",
+ }
+ expected = {
+ "__context": "context_value",
+ "other_key": "other_value",
+ "foo": "foo_value",
+ }
assert _remap_dunder_parameters(_action.run, inputs, ["__context",
"__tracer"]) == expected
def test_remap_context_variable_with_mangled_context():
_action = ActionWithContext()
- inputs = {"__context": "context_value", "other_key": "other_value", "foo":
"foo_value"}
+ inputs = {
+ "__context": "context_value",
+ "other_key": "other_value",
+ "foo": "foo_value",
+ }
expected = {
f"_{ActionWithContext.__name__}__context": "context_value",
"other_key": "other_value",
@@ -3657,8 +4039,16 @@ def
test_remap_context_variable_with_mangled_contexttracer():
def test_remap_context_variable_without_mangled_context():
_action = ActionWithoutContext()
- inputs = {"__context": "context_value", "other_key": "other_value", "foo":
"foo_value"}
- expected = {"__context": "context_value", "other_key": "other_value",
"foo": "foo_value"}
+ inputs = {
+ "__context": "context_value",
+ "other_key": "other_value",
+ "foo": "foo_value",
+ }
+ expected = {
+ "__context": "context_value",
+ "other_key": "other_value",
+ "foo": "foo_value",
+ }
assert _remap_dunder_parameters(_action.run, inputs, ["__context",
"__tracer"]) == expected