This is an automated email from the ASF dual-hosted git repository. skrawcz pushed a commit to branch stefan/improve-generator-spans in repository https://gitbox.apache.org/repos/asf/burr.git
commit 564fcde84b0674141f650dc1b5e1b5c7e44403dc Author: Stefan Krawczyk <[email protected]> AuthorDate: Sat Feb 28 15:24:35 2026 -0800 feat: add configurable streaming telemetry to OpenTelemetryBridge Add pre_stream_generate/post_stream_generate lifecycle hooks that bracket each generator __next__/__anext__ call, enabling per-yield instrumentation for streaming actions. Also improves attributes added for streaming to distinguish internal computation from consumer time. OpenTelemetryBridge accepts a new streaming_telemetry parameter (StreamingTelemetryMode enum) controlling how streaming actions are instrumented: - SINGLE_SPAN (default): single action span, backwards compatible; generation/consumer timing, iteration count, and TTFT are recorded as span attributes - EVENT: no action span; a stream_completed (or stream_error) summary event is added to the method span with generation/consumer/total timing, iteration count, and TTFT - CHUNK_SPANS: per-yield child spans, no action span - SINGLE_AND_CHUNK_SPANS: action span with streaming attributes + per-yield child spans --- burr/core/application.py | 170 ++- burr/integrations/opentelemetry.py | 301 ++++- burr/lifecycle/__init__.py | 8 + burr/lifecycle/base.py | 96 ++ .../opentelemetry/streaming_telemetry_modes.py | 119 ++ tests/core/test_application.py | 365 ++++++ tests/integrations/test_opentelemetry.py | 1234 ++++++++++++++++++++ 7 files changed, 2276 insertions(+), 17 deletions(-) diff --git a/burr/core/application.py b/burr/core/application.py index dc8067c4..a92cdcb1 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -320,6 +320,132 @@ def _run_single_step_action( return out +def _instrumented_sync_generator( + generator: Generator, + lifecycle_adapters: LifecycleAdapterSet, + stream_initialize_time, + action: str, + sequence_id: int, + app_id: str, + partition_key: Optional[str], +) -> Generator: + """Wraps a synchronous generator to instrument stream generation with lifecycle hooks. 
+ + This function wraps a generator and fires pre_stream_generate and post_stream_generate + hooks around each __next__() call.
This brackets the actual generation time for each + item, excluding consumer processing time. The hooks receive metadata including the + item index, action name, sequence_id, app_id, and partition_key. + + Exceptions raised during generation are propagated after firing the post_stream_generate + hook with the exception. StopIteration is handled gracefully to signal generator completion. + + Args: + generator: The synchronous generator to wrap and instrument. + lifecycle_adapters: Set of lifecycle adapters to call hooks on. + stream_initialize_time: Timestamp when the stream was initialized. + action: Name of the action generating the stream. + sequence_id: Sequence identifier for this execution. + app_id: Application identifier. + partition_key: Optional partition key for distributed execution. + + Yields: + Items from the wrapped generator, one at a time. + """ + gen_iter = iter(generator) + count = 0 + while True: + hook_kwargs = dict( + item_index=count, + stream_initialize_time=stream_initialize_time, + action=action, + sequence_id=sequence_id, + app_id=app_id, + partition_key=partition_key, + ) + lifecycle_adapters.call_all_lifecycle_hooks_sync("pre_stream_generate", **hook_kwargs) + try: + item = next(gen_iter) + except StopIteration: + lifecycle_adapters.call_all_lifecycle_hooks_sync( + "post_stream_generate", item=None, exception=None, **hook_kwargs + ) + return + except Exception as e: + lifecycle_adapters.call_all_lifecycle_hooks_sync( + "post_stream_generate", item=None, exception=e, **hook_kwargs + ) + raise + lifecycle_adapters.call_all_lifecycle_hooks_sync( + "post_stream_generate", item=item, exception=None, **hook_kwargs + ) + yield item + count += 1 + + +async def _instrumented_async_generator( + generator: AsyncGenerator, + lifecycle_adapters: LifecycleAdapterSet, + stream_initialize_time, + action: str, + sequence_id: int, + app_id: str, + partition_key: Optional[str], +) -> AsyncGenerator: + """Wraps an asynchronous generator to instrument 
stream generation with lifecycle hooks. + + This function wraps an async generator and fires pre_stream_generate and post_stream_generate + hooks around each __anext__() call. This brackets the actual generation time for each + item, excluding consumer processing time. The hooks receive metadata including the + item index, action name, sequence_id, app_id, and partition_key. + + Exceptions raised during generation are propagated after firing the post_stream_generate + hook with the exception. StopAsyncIteration is handled gracefully to signal generator completion. + + Args: + generator: The asynchronous generator to wrap and instrument. + lifecycle_adapters: Set of lifecycle adapters to call hooks on. + stream_initialize_time: Timestamp when the stream was initialized. + action: Name of the action generating the stream. + sequence_id: Sequence identifier for this execution. + app_id: Application identifier. + partition_key: Optional partition key for distributed execution. + + Yields: + Items from the wrapped generator, one at a time. 
+ """ + aiter = generator.__aiter__() + count = 0 + while True: + hook_kwargs = dict( + item_index=count, + stream_initialize_time=stream_initialize_time, + action=action, + sequence_id=sequence_id, + app_id=app_id, + partition_key=partition_key, + ) + await lifecycle_adapters.call_all_lifecycle_hooks_sync_and_async( + "pre_stream_generate", **hook_kwargs + ) + try: + item = await aiter.__anext__() + except StopAsyncIteration: + await lifecycle_adapters.call_all_lifecycle_hooks_sync_and_async( + "post_stream_generate", item=None, exception=None, **hook_kwargs + ) + return + except Exception as e: + await lifecycle_adapters.call_all_lifecycle_hooks_sync_and_async( + "post_stream_generate", item=None, exception=e, **hook_kwargs + ) + raise + await lifecycle_adapters.call_all_lifecycle_hooks_sync_and_async( + "post_stream_generate", item=item, exception=None, **hook_kwargs + ) + yield item + count += 1 + + def _run_single_step_streaming_action( action: SingleStepStreamingAction, state: State, @@ -334,7 +460,16 @@ def _run_single_step_streaming_action( action.validate_inputs(inputs) stream_initialize_time = system.now() first_stream_start_time = None - generator = action.stream_run_and_update(state, **inputs) + raw_generator = action.stream_run_and_update(state, **inputs) + generator = _instrumented_sync_generator( + raw_generator, + lifecycle_adapters, + stream_initialize_time=stream_initialize_time, + action=action.name, + sequence_id=sequence_id, + app_id=app_id, + partition_key=partition_key, + ) result = None state_update = None count = 0 @@ -387,7 +522,16 @@ async def _arun_single_step_streaming_action( action.validate_inputs(inputs) stream_initialize_time = system.now() first_stream_start_time = None - generator = action.stream_run_and_update(state, **inputs) + raw_generator = action.stream_run_and_update(state, **inputs) + generator = _instrumented_async_generator( + raw_generator, + lifecycle_adapters, + stream_initialize_time=stream_initialize_time, + 
action=action.name, + sequence_id=sequence_id, + app_id=app_id, + partition_key=partition_key, + ) result = None state_update = None count = 0 @@ -446,7 +590,16 @@ def _run_multi_step_streaming_action( """ action.validate_inputs(inputs) stream_initialize_time = system.now() - generator = action.stream_run(state, **inputs) + raw_generator = action.stream_run(state, **inputs) + generator = _instrumented_sync_generator( + raw_generator, + lifecycle_adapters, + stream_initialize_time=stream_initialize_time, + action=action.name, + sequence_id=sequence_id, + app_id=app_id, + partition_key=partition_key, + ) result = None first_stream_start_time = None count = 0 @@ -490,7 +643,16 @@ async def _arun_multi_step_streaming_action( """Runs a multi-step streaming action in async. See the synchronous version for more details.""" action.validate_inputs(inputs) stream_initialize_time = system.now() - generator = action.stream_run(state, **inputs) + raw_generator = action.stream_run(state, **inputs) + generator = _instrumented_async_generator( + raw_generator, + lifecycle_adapters, + stream_initialize_time=stream_initialize_time, + action=action.name, + sequence_id=sequence_id, + app_id=app_id, + partition_key=partition_key, + ) result = None first_stream_start_time = None count = 0 diff --git a/burr/integrations/opentelemetry.py b/burr/integrations/opentelemetry.py index 32dc4dd7..7dc09c6d 100644 --- a/burr/integrations/opentelemetry.py +++ b/burr/integrations/opentelemetry.py @@ -17,12 +17,14 @@ import dataclasses import datetime +import enum import importlib import importlib.metadata import json import logging import random import sys +import time from contextvars import ContextVar from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple @@ -47,8 +49,10 @@ from burr.core import Action, ApplicationGraph, State, serde from burr.lifecycle import ( PostApplicationExecuteCallHook, PostRunStepHook, + PostStreamGenerateHook, PreApplicationExecuteCallHook, PreRunStepHook, 
+ PreStreamGenerateHook, ) from burr.lifecycle.base import DoLogAttributeHook, ExecuteMethod, PostEndSpanHook, PreStartSpanHook @@ -87,6 +91,43 @@ def get_cached_span(span_id: int) -> Optional[FullSpanContext]: tracker_context = ContextVar[Optional[SyncTrackingClient]]("tracker_context", default=None) +# Tracks whether the action-level span was skipped for streaming actions +_skipped_action_span = ContextVar[bool]("_skipped_action_span", default=False) + + +# Valid streaming telemetry modes +class StreamingTelemetryMode(enum.Enum): + """Controls how streaming actions are instrumented by the OpenTelemetryBridge. + + - ``SINGLE_SPAN``: A single action span covers the full generator lifetime (default); streaming timing is recorded as span attributes. + - ``EVENT``: No action span. A ``stream_completed`` (or ``stream_error``) summary span event is added to the method span. + - ``CHUNK_SPANS``: No action span. Per-yield child spans under the method span. + - ``SINGLE_AND_CHUNK_SPANS``: Action span with streaming attributes plus per-yield child spans. + """ + + SINGLE_SPAN = "single_span" + EVENT = "event" + CHUNK_SPANS = "chunk_spans" + SINGLE_AND_CHUNK_SPANS = "single_and_chunk_spans" + + +@dataclasses.dataclass +class _StreamingAccumulator: + """Accumulates timing data across stream yields for the streaming span attributes or summary event.""" + + generation_time_ns: int = 0 + consumer_time_ns: int = 0 + iteration_count: int = 0 + first_item_time_ns: Optional[int] = None + stream_start_ns: Optional[int] = None + last_post_generate_ns: Optional[int] = None + _pre_generate_ns: Optional[int] = None + + +_streaming_accumulator = ContextVar[Optional[_StreamingAccumulator]]( + "_streaming_accumulator", default=None +) + def _is_homogeneous_sequence(value: Sequence): if len(value) == 0: @@ -149,20 +190,34 @@ class OpenTelemetryBridge( PreStartSpanHook, PostEndSpanHook, DoLogAttributeHook, + PreStreamGenerateHook, + PostStreamGenerateHook, ): - """Adapter to log Burr events to OpenTelemetry.
At a high level, this works as follows: + """Lifecycle adapter that maps Burr execution events to OpenTelemetry spans and events. + + **How it works** + + The bridge implements Burr lifecycle hooks to create a span hierarchy that mirrors the + execution structure: - 1. On any of the start/pre hooks (pre_run_execute_call, pre_run_step, pre_start_span), we start a new span - 2. On any of the post ones we exit the span, accounting for the error (setting it if needed) - 3. On do_log_attributes, we log the attributes to the current span -- these are serialized using the serde module + 1. ``pre_run_execute_call`` / ``post_run_execute_call`` — creates a top-level **method span** + for the application method being called (e.g. ``step``, ``astream_result``). + 2. ``pre_run_step`` / ``post_run_step`` — creates an **action span** as a child of the + method span. For streaming actions, behavior depends on the ``streaming_telemetry`` mode. + 3. ``pre_start_span`` / ``post_end_span`` — creates **sub-action spans** for user-defined + visibility spans (via ``TracerFactory`` / ``__tracer``). + 4. ``do_log_attributes`` — sets OTel attributes on the current span. + 5. ``pre_stream_generate`` / ``post_stream_generate`` — for streaming actions, optionally + creates per-yield **chunk spans** and/or accumulates timing data for a summary event. - This works by logging to OpenTelemetry, and setting the span processor to be the right one (that knows about the tracker). + All spans are managed via a ContextVar-based token stack (``token_stack``) to correctly + handle nesting across sync and async execution. - You can use this as follows: + **Usage** .. 
code-block:: python - # replace with instructions from your prefered vendor + # replace with instructions from your preferred vendor my_vendor_library_or_tracer_provider.init() app = ( @@ -174,15 +229,44 @@ class OpenTelemetryBridge( .build() ) - app.run() # will log to OpenTelemetry + app.run() # will log to OpenTelemetry + + **Streaming telemetry modes** + + The ``streaming_telemetry`` parameter controls how streaming actions are instrumented. + Non-streaming actions are unaffected — they always produce a single action span. + + - ``StreamingTelemetryMode.SINGLE_SPAN`` (default): A single action span covers the full + generator lifetime (including consumer wait time). Streaming **attributes** are set on + the span with the generation/consumer timing breakdown: + + - ``stream.generation_time_ms`` — time spent inside the generator producing items + - ``stream.consumer_time_ms`` — time the consumer spent processing yielded items + - ``stream.iteration_count`` — number of items yielded + - ``stream.first_item_time_ms`` — time to first item (TTFT) + + - ``StreamingTelemetryMode.EVENT``: No action span. A ``stream_completed`` (or + ``stream_error``) span event is added to the **method span** with the timing summary + (including ``stream.total_time_ms`` since there is no action span to carry duration). + - ``StreamingTelemetryMode.CHUNK_SPANS``: No action span. A child span + (``{action}::chunk_{N}``) is created for each generator yield under the method span. + Each chunk span measures only generation time (excludes consumer processing time). + - ``StreamingTelemetryMode.SINGLE_AND_CHUNK_SPANS``: Combines ``SINGLE_SPAN`` and ``CHUNK_SPANS`` — the + action span (with streaming attributes) plus per-yield chunk spans as its children. 
""" - def __init__(self, tracer_name: str = None, tracer: trace.Tracer = None): + def __init__( + self, + tracer_name: str = None, + tracer: trace.Tracer = None, + streaming_telemetry: StreamingTelemetryMode = StreamingTelemetryMode.SINGLE_SPAN, + ): """Initializes an OpenTel adapter. Passes in a tracer_name or a tracer object, should only pass one. :param tracer_name: Name of the tracer if you want it to initialize for you -- not including it will use a default :param tracer: Tracer object if you want to pass it in yourself + :param streaming_telemetry: How to instrument streaming actions. See :class:`StreamingTelemetryMode`. """ if tracer_name and tracer: raise ValueError( @@ -192,6 +276,54 @@ class OpenTelemetryBridge( self.tracer = tracer else: self.tracer = trace.get_tracer(__name__ if tracer_name is None else tracer_name) + self.streaming_telemetry = streaming_telemetry + + @property + def _emit_chunk_spans(self) -> bool: + """Whether to create per-yield chunk spans (CHUNK_SPANS or BOTH).""" + return self.streaming_telemetry in ( + StreamingTelemetryMode.CHUNK_SPANS, + StreamingTelemetryMode.SINGLE_AND_CHUNK_SPANS, + ) + + @property + def _emit_event(self) -> bool: + """Whether to emit a summary span event on the method span (EVENT only). + + EVENT mode skips the action span entirely and attaches a ``stream_completed`` + event to the method span instead. + """ + return self.streaming_telemetry == StreamingTelemetryMode.EVENT + + @property + def _emit_attributes(self) -> bool: + """Whether to set streaming attributes on the action span (SINGLE_SPAN or BOTH). + + These modes create an action span and set generation time, consumer time, + iteration count, and TTFT as span attributes. 
+ """ + return self.streaming_telemetry in ( + StreamingTelemetryMode.SINGLE_SPAN, + StreamingTelemetryMode.SINGLE_AND_CHUNK_SPANS, + ) + + @property + def _use_accumulator(self) -> bool: + """Whether timing accumulation is needed (all modes except CHUNK_SPANS).""" + return self.streaming_telemetry != StreamingTelemetryMode.CHUNK_SPANS + + @property + def _skip_single_action_span_for_streaming(self) -> bool: + """Whether to skip the action-level span for streaming actions. + + True for EVENT and CHUNK_SPANS modes. EVENT attaches data to the method span + instead. CHUNK_SPANS replaces the action span with per-yield child spans. + In SINGLE_SPAN and BOTH modes, the action span is created normally. + """ + return self.streaming_telemetry in ( + StreamingTelemetryMode.EVENT, + StreamingTelemetryMode.CHUNK_SPANS, + ) def pre_run_execute_call( self, @@ -199,6 +331,11 @@ class OpenTelemetryBridge( method: ExecuteMethod, **future_kwargs: Any, ): + """Opens the top-level **method span** (e.g. ``step``, ``astream_result``). + + This is the outermost span in the Burr trace hierarchy. Action spans and chunk + spans are nested under it. + """ # TODO -- handle links -- we need to wire this through _enter_span(method.value, self.tracer) @@ -208,6 +345,11 @@ class OpenTelemetryBridge( attributes: Dict[str, Any], **future_kwargs: Any, ): + """Sets key-value attributes on the current OTel span. + + Values are serialized via :func:`convert_to_otel_attribute` to ensure they are + OTel-compatible types (str, bool, int, float, or homogeneous sequences thereof). + """ otel_span = get_current_span() if otel_span is None: logger.warning( @@ -224,7 +366,22 @@ class OpenTelemetryBridge( action: "Action", **future_kwargs: Any, ): - _enter_span(action.name, self.tracer) + """Opens an **action span** for the step about to execute. + + For streaming actions in ``EVENT`` or ``CHUNK_SPANS`` mode, the action span is + skipped. 
In ``SINGLE_SPAN`` and ``SINGLE_AND_CHUNK_SPANS`` modes, the action span is created normally. + + For all modes except ``CHUNK_SPANS``, a :class:`_StreamingAccumulator` is initialized + to collect timing data across generator yields. + """ + if getattr(action, "streaming", False) and self._skip_single_action_span_for_streaming: + _skipped_action_span.set(True) + else: + _skipped_action_span.set(False) + _enter_span(action.name, self.tracer) + # Initialize accumulator for modes that need timing data + if getattr(action, "streaming", False) and self._use_accumulator: + _streaming_accumulator.set(_StreamingAccumulator()) def pre_start_span( self, @@ -232,6 +389,11 @@ class OpenTelemetryBridge( span: "ActionSpan", **future_kwargs: Any, ): + """Opens a **sub-action span** for a user-defined visibility span. + + These are created by the ``TracerFactory`` (``__tracer``) context manager inside + actions, and are nested under the current action span. + """ _enter_span(span.name, self.tracer) def post_end_span( @@ -240,6 +402,7 @@ class OpenTelemetryBridge( span: "ActionSpan", **future_kwargs: Any, ): + """Closes a sub-action span opened by :meth:`pre_start_span`.""" # TODO -- wire through exceptions _exit_span() @@ -249,7 +412,120 @@ class OpenTelemetryBridge( exception: Exception, **future_kwargs: Any, ): - _exit_span(exception) + """Closes the action span and, for streaming actions, emits summary telemetry. + + Behavior depends on mode: + + - ``SINGLE_SPAN`` / ``SINGLE_AND_CHUNK_SPANS``: Sets streaming attributes on the action span, then + closes it. + - ``EVENT``: Emits a ``stream_completed`` (or ``stream_error``) span event on the + method span (the action span was skipped). Resets the skipped flag. + - ``CHUNK_SPANS``: The action span was skipped; just resets the flag. 
+ """ + acc = _streaming_accumulator.get() + if acc is not None: + first_item_ms = 0.0 + if acc.first_item_time_ns is not None and acc.stream_start_ns is not None: + first_item_ms = (acc.first_item_time_ns - acc.stream_start_ns) / 1e6 + + if self._emit_attributes: + # SINGLE_SPAN / BOTH: set attributes on the action span + otel_span = get_current_span() + if otel_span is not None: + otel_span.set_attributes( + { + "stream.generation_time_ms": acc.generation_time_ns / 1e6, + "stream.consumer_time_ms": acc.consumer_time_ns / 1e6, + "stream.iteration_count": acc.iteration_count, + "stream.first_item_time_ms": first_item_ms, + } + ) + + elif self._emit_event: + # EVENT: emit span event on the method span (action span was skipped) + otel_span = get_current_span() + if otel_span is not None: + total_time_ns = 0 + if acc.stream_start_ns is not None and acc.last_post_generate_ns is not None: + total_time_ns = acc.last_post_generate_ns - acc.stream_start_ns + event_name = "stream_error" if exception else "stream_completed" + attrs: Dict[str, Any] = { + "stream.generation_time_ms": acc.generation_time_ns / 1e6, + "stream.consumer_time_ms": acc.consumer_time_ns / 1e6, + "stream.total_time_ms": total_time_ns / 1e6, + "stream.iteration_count": acc.iteration_count, + "stream.first_item_time_ms": first_item_ms, + } + if exception: + attrs["stream.error"] = str(exception) + otel_span.add_event(event_name, attributes=attrs) + + _streaming_accumulator.set(None) + + if _skipped_action_span.get(): + _skipped_action_span.set(False) + else: + _exit_span(exception) + + def pre_stream_generate( + self, + *, + action: str, + item_index: int, + **future_kwargs: Any, + ): + """Called just before each ``__next__()`` / ``__anext__()`` on the generator. + + For modes with accumulation (``SINGLE_SPAN``, ``EVENT``, ``SINGLE_AND_CHUNK_SPANS``), records the + start of generation time and accumulates consumer time (the gap between the previous + ``post_stream_generate`` and now). 
+ + In ``CHUNK_SPANS`` or ``SINGLE_AND_CHUNK_SPANS`` mode, opens a child span named + ``{action}::chunk_{item_index}``. + """ + now_ns = time.time_ns() + acc = _streaming_accumulator.get() + if acc is not None: + if acc.stream_start_ns is None: + acc.stream_start_ns = now_ns + if acc.last_post_generate_ns is not None: + acc.consumer_time_ns += now_ns - acc.last_post_generate_ns + acc._pre_generate_ns = now_ns # stash for post + + if self._emit_chunk_spans: + _enter_span(f"{action}::chunk_{item_index}", self.tracer) + + def post_stream_generate( + self, + *, + item: Any, + item_index: int, + exception: Optional[Exception], + **future_kwargs: Any, + ): + """Called just after each ``__next__()`` / ``__anext__()`` returns (or raises). + + For modes with accumulation (``SINGLE_SPAN``, ``EVENT``, ``SINGLE_AND_CHUNK_SPANS``), accumulates + generation time and updates the iteration count. When ``item`` is not ``None``, + the item is counted; a ``None`` item signals generator exhaustion (``StopIteration``). + + In ``CHUNK_SPANS`` or ``SINGLE_AND_CHUNK_SPANS`` mode, closes the chunk span opened by + :meth:`pre_stream_generate`, setting an error status if ``exception`` is provided. 
+ """ + now_ns = time.time_ns() + acc = _streaming_accumulator.get() + if acc is not None: + pre_ns = acc._pre_generate_ns + if pre_ns is not None: + acc.generation_time_ns += now_ns - pre_ns + if item is not None: + acc.iteration_count += 1 + if acc.first_item_time_ns is None: + acc.first_item_time_ns = now_ns + acc.last_post_generate_ns = now_ns + + if self._emit_chunk_spans: + _exit_span(exception) def post_run_execute_call( self, @@ -257,6 +533,7 @@ class OpenTelemetryBridge( exception: Optional[Exception], **future_kwargs, ): + """Closes the top-level method span opened by :meth:`pre_run_execute_call`.""" _exit_span(exception) @@ -705,8 +982,6 @@ if __name__ == "__main__": tracker = LocalTrackingClient("otel_test") opentel_adapter = OpenTelemetryTracker(burr_tracker=tracker) - import time - from burr.core import ApplicationBuilder, Result, action, default, expr from burr.visibility import TracerFactory diff --git a/burr/lifecycle/__init__.py b/burr/lifecycle/__init__.py index 4ae24073..9cf10eff 100644 --- a/burr/lifecycle/__init__.py +++ b/burr/lifecycle/__init__.py @@ -23,11 +23,15 @@ from burr.lifecycle.base import ( PostEndSpanHook, PostRunStepHook, PostRunStepHookAsync, + PostStreamGenerateHook, + PostStreamGenerateHookAsync, PreApplicationExecuteCallHook, PreApplicationExecuteCallHookAsync, PreRunStepHook, PreRunStepHookAsync, PreStartSpanHook, + PreStreamGenerateHook, + PreStreamGenerateHookAsync, ) from burr.lifecycle.default import StateAndResultsFullLogger @@ -45,4 +49,8 @@ __all__ = [ "PostApplicationCreateHook", "PostEndSpanHook", "PreStartSpanHook", + "PreStreamGenerateHook", + "PreStreamGenerateHookAsync", + "PostStreamGenerateHook", + "PostStreamGenerateHookAsync", ] diff --git a/burr/lifecycle/base.py b/burr/lifecycle/base.py index 66d8bd7e..a2c3e3ff 100644 --- a/burr/lifecycle/base.py +++ b/burr/lifecycle/base.py @@ -492,6 +492,98 @@ class PostEndStreamHookAsync(abc.ABC): pass +@lifecycle.base_hook("pre_stream_generate") +class
PreStreamGenerateHook(abc.ABC): + """Hook that runs before the generator produces its next item. + Paired with PostStreamGenerateHook to bracket the actual generation time + for each stream item, excluding consumer processing time. + """ + + @abc.abstractmethod + def pre_stream_generate( + self, + *, + item_index: int, + stream_initialize_time: datetime.datetime, + action: str, + sequence_id: int, + app_id: str, + partition_key: Optional[str], + **future_kwargs: Any, + ): + pass + + +@lifecycle.base_hook("pre_stream_generate") +class PreStreamGenerateHookAsync(abc.ABC): + """Hook that runs before the generator produces its next item (async variant). + Paired with PostStreamGenerateHookAsync to bracket the actual generation time + for each stream item, excluding consumer processing time. + """ + + @abc.abstractmethod + async def pre_stream_generate( + self, + *, + item_index: int, + stream_initialize_time: datetime.datetime, + action: str, + sequence_id: int, + app_id: str, + partition_key: Optional[str], + **future_kwargs: Any, + ): + pass + + +@lifecycle.base_hook("post_stream_generate") +class PostStreamGenerateHook(abc.ABC): + """Hook that runs after the generator has produced an item (or exhausted/errored). + Paired with PreStreamGenerateHook to bracket the actual generation time + for each stream item, excluding consumer processing time. + """ + + @abc.abstractmethod + def post_stream_generate( + self, + *, + item: Any, + item_index: int, + stream_initialize_time: datetime.datetime, + action: str, + sequence_id: int, + app_id: str, + partition_key: Optional[str], + exception: Optional[Exception], + **future_kwargs: Any, + ): + pass + + +@lifecycle.base_hook("post_stream_generate") +class
Paired with PreStreamGenerateHookAsync to bracket the actual + generation time for each stream item, excluding consumer processing time. + """ + + @abc.abstractmethod + async def post_stream_generate( + self, + *, + item: Any, + item_index: int, + stream_initialize_time: datetime.datetime, + action: str, + sequence_id: int, + app_id: str, + partition_key: Optional[str], + exception: Optional[Exception], + **future_kwargs: Any, + ): + pass + + # strictly for typing -- this conflicts a bit with the lifecycle decorator above, but its fine for now # This makes IDE completion/type-hinting easier LifecycleAdapter = Union[ @@ -515,4 +607,8 @@ LifecycleAdapter = Union[ PreStartStreamHookAsync, PostStreamItemHookAsync, PostEndStreamHookAsync, + PreStreamGenerateHook, + PreStreamGenerateHookAsync, + PostStreamGenerateHook, + PostStreamGenerateHookAsync, ] diff --git a/examples/opentelemetry/streaming_telemetry_modes.py b/examples/opentelemetry/streaming_telemetry_modes.py new file mode 100644 index 00000000..c50381f1 --- /dev/null +++ b/examples/opentelemetry/streaming_telemetry_modes.py @@ -0,0 +1,119 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Demonstrates the four StreamingTelemetryMode options for OpenTelemetryBridge. 
+
+Runs a simple async streaming action under each mode with the OTel console
+exporter so you can see the spans and events printed to stdout.
+
+Usage:
+    python examples/opentelemetry/streaming_telemetry_modes.py
+
+No external APIs are needed — the streaming action simulates an LLM by yielding
+tokens with small delays.
+"""
+
+import asyncio
+
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor
+
+from burr.core import ApplicationBuilder, State
+from burr.core.action import streaming_action
+from burr.core.graph import GraphBuilder
+from burr.integrations.opentelemetry import OpenTelemetryBridge, StreamingTelemetryMode
+
+# ---------------------------------------------------------------------------
+# A simple streaming action that simulates token-by-token LLM output
+# ---------------------------------------------------------------------------
+
+
+@streaming_action(reads=["prompt"], writes=["response"])
+async def generate_response(state: State):
+    """Simulates a streaming LLM response (an async generator), yielding one token at a time."""
+    tokens = state["prompt"].split()
+    buffer = []
+    for token in tokens:
+        await asyncio.sleep(0.02)  # simulate generation latency per token
+        buffer.append(token)
+        yield {"token": token}, None
+
+    response = " ".join(buffer)
+    yield {"token": "", "response": response}, state.update(response=response)
+
+
+# ---------------------------------------------------------------------------
+# Build the graph (shared across all modes)
+# ---------------------------------------------------------------------------
+
+graph = GraphBuilder().with_actions(generate=generate_response).with_transitions().build()
+
+
+# ---------------------------------------------------------------------------
+# Run one mode
+# ---------------------------------------------------------------------------
+
+
+async def run_with_mode(mode: StreamingTelemetryMode) -> None:
+    """Builds an app with the given streaming telemetry mode and runs it."""
+    # Each run gets its own tracer provider so console output stays grouped
+    provider = TracerProvider()
+    provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))
+    tracer = provider.get_tracer("streaming-telemetry-demo")
+
+    bridge = OpenTelemetryBridge(tracer=tracer, streaming_telemetry=mode)
+
+    app = (
+        ApplicationBuilder()
+        .with_graph(graph)
+        .with_entrypoint("generate")
+        .with_state(State({"prompt": "hello world from burr streaming"}))
+        .with_hooks(bridge)
+        .with_identifiers(app_id=f"demo-{mode.value}")
+        .build()
+    )
+
+    _action, container = await app.astream_result(halt_after=["generate"])
+    async for _item in container:
+        await asyncio.sleep(0.05)  # simulate consumer processing time per token
+    await container.get()
+
+    provider.shutdown()
+
+
+# ---------------------------------------------------------------------------
+# Main — run every StreamingTelemetryMode
+# ---------------------------------------------------------------------------
+
+
+async def main():
+    """Runs the demo app once per ``StreamingTelemetryMode``, printing a banner
+    before each run."""
+    # Iterate the enum itself rather than hard-coding member names: every mode is
+    # covered (in definition order) and the demo cannot reference a member that
+    # has been renamed or removed.
+    modes = list(StreamingTelemetryMode)
+    for mode in modes:
+        print(f"\n{'=' * 70}")
+        print(f" StreamingTelemetryMode.{mode.name}")
+        print(f"{'=' * 70}\n")
+        await run_with_mode(mode)
+        print()
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/tests/core/test_application.py b/tests/core/test_application.py
index c90c4067..c9facd9b 100644
--- a/tests/core/test_application.py
+++ b/tests/core/test_application.py
@@ -83,12 +83,16 @@ from burr.lifecycle.base import (
     PostApplicationExecuteCallHook,
     PostApplicationExecuteCallHookAsync,
     PostEndStreamHook,
+    PostStreamGenerateHook,
+    PostStreamGenerateHookAsync,
     PostStreamItemHook,
     PostStreamItemHookAsync,
     PreApplicationExecuteCallHook,
     PreApplicationExecuteCallHookAsync,
     PreStartStreamHook,
     PreStartStreamHookAsync,
+    PreStreamGenerateHook,
+    PreStreamGenerateHookAsync,
 )
 from burr.lifecycle.internal import LifecycleAdapterSet
 from burr.tracking.base import SyncTrackingClient
@@ -3761,3 +3765,364 @@ def test_initialize_from_applies_override_state_values():
 
     app = builder.build()
     assert app.state["x"] == 100
+
+
+# ============================================================================
+# Tests for pre_stream_generate / post_stream_generate lifecycle hooks
+# ============================================================================
+
+
+class GenerateEventCaptureTracker(
+    PreStartStreamHook,
+    PostEndStreamHook,
+):
+    """Records pre_start_stream / post_end_stream calls (with the action name) in
+    ``calls``. It does not implement the stream-generate hooks; no test below uses it."""
+
+    def __init__(self):
+        self.calls: list[tuple[str, dict]] = []
+
+    def pre_start_stream(
+        self, *, action: str, sequence_id: int, app_id: str, partition_key, **future_kwargs
+    ):
+        self.calls.append(("pre_start_stream", {"action": action}))
+
+    def post_end_stream(
+        self, *, action: str, sequence_id: int, app_id: str, partition_key, **future_kwargs
+    ):
+        self.calls.append(("post_end_stream", {"action": action}))
+
+
+class StreamGenerateTracker(PreStreamGenerateHook, PostStreamGenerateHook):
+    """Sync tracker that captures pre/post_stream_generate hook calls."""
+
+    def __init__(self):
+        self.calls: list[tuple[str, int]] = []  # (hook_name, item_index)
+
+    def pre_stream_generate(self, *, item_index: int, action: str, **future_kwargs):
+        self.calls.append(("pre_stream_generate", item_index))
+
+    def post_stream_generate(
+        self, *, item, item_index: int, action: str, exception, **future_kwargs
+    ):
+        self.calls.append(("post_stream_generate", item_index))
+
+
+class StreamGenerateTrackerAsync(PreStreamGenerateHookAsync, PostStreamGenerateHookAsync):
+    """Async tracker that captures pre/post_stream_generate hook calls."""
+
+    def __init__(self):
+        self.calls: list[tuple[str, int]] = []
+
+    async def pre_stream_generate(self, *, item_index: int, action: str, **future_kwargs):
+        self.calls.append(("pre_stream_generate", item_index))
+
+    async def post_stream_generate(
+        self, *, item, item_index: int, action: str, exception, **future_kwargs
+    ):
+        self.calls.append(("post_stream_generate", item_index))
+
+
+# --- Test #1: sync single-step calls pre/post_stream_generate ---
+
+
+def test__run_single_step_streaming_action_calls_stream_generate_hooks():
+    tracker = StreamGenerateTracker()
+    action = base_streaming_single_step_counter.with_name("counter")
+    state = State({"count": 0, "tracker": []})
+    generator = _run_single_step_streaming_action(
+        action,
+        state,
+        inputs={},
+        sequence_id=0,
+        partition_key="pk",
+        app_id="app",
+        lifecycle_adapters=LifecycleAdapterSet(tracker),
+    )
+    collections.deque(generator, maxlen=0)  # exhaust
+    # 10 intermediate yields + 1 final yield = 11 items from the action generator
+    pre_calls = [c for c in tracker.calls if c[0] == "pre_stream_generate"]
+    post_calls = [c for c in tracker.calls if c[0] == "post_stream_generate"]
+    # pre fires before each __next__ including the StopIteration attempt (11 items + 1 stop = 12)
+    assert len(pre_calls) == 12
+    assert len(post_calls) == 12  # matched: 11 items + 1 StopIteration
+
+
+# --- Test #2: sync multi-step calls pre/post_stream_generate ---
+
+
+def test__run_multi_step_streaming_action_calls_stream_generate_hooks():
+    tracker = StreamGenerateTracker()
+    action = base_streaming_counter.with_name("counter")
+    state = State({"count": 0, "tracker": []})
+    generator = _run_multi_step_streaming_action(
+        action,
+        state,
+        inputs={},
+        sequence_id=0,
+        partition_key="pk",
+        app_id="app",
+        lifecycle_adapters=LifecycleAdapterSet(tracker),
+    )
+    collections.deque(generator, maxlen=0)
+    pre_calls = [c for c in tracker.calls if c[0] == "pre_stream_generate"]
+    post_calls = [c for c in tracker.calls if c[0] == "post_stream_generate"]
+    # Multi-step: 11 items from generator + 1 StopIteration attempt = 12 pre, 12 post
+    assert len(pre_calls) == 12
+    assert len(post_calls) == 12
+
+
+# --- Test #3: async single-step calls pre/post_stream_generate ---
+
+
+async def test__arun_single_step_streaming_action_calls_stream_generate_hooks():
+    tracker = StreamGenerateTrackerAsync()
+    action = base_streaming_single_step_counter_async.with_name("counter")
+    state = State({"count": 0, "tracker": []})
+    generator = _arun_single_step_streaming_action(
+        action=action,
+        state=state,
+        inputs={},
+        sequence_id=0,
+        app_id="app",
+        partition_key="pk",
+        lifecycle_adapters=LifecycleAdapterSet(tracker),
+    )
+    async for _ in generator:
+        pass
+    pre_calls = [c for c in tracker.calls if c[0] == "pre_stream_generate"]
+    post_calls = [c for c in tracker.calls if c[0] == "post_stream_generate"]
+    assert len(pre_calls) == 12
+    assert len(post_calls) == 12
+
+
+# --- Test #4: async multi-step calls pre/post_stream_generate ---
+
+
+async def test__arun_multi_step_streaming_action_calls_stream_generate_hooks():
+    tracker = StreamGenerateTrackerAsync()
+    action = base_streaming_counter_async.with_name("counter")
+    state = State({"count": 0, "tracker": []})
+    generator = _arun_multi_step_streaming_action(
+        action=action,
+        state=state,
+        inputs={},
+        sequence_id=0,
+        app_id="app",
+        partition_key="pk",
+        lifecycle_adapters=LifecycleAdapterSet(tracker),
+    )
+    async for _ in generator:
+        pass
+    pre_calls = [c for c in tracker.calls if c[0] == "pre_stream_generate"]
+    post_calls = [c for c in tracker.calls if c[0] == "post_stream_generate"]
+    assert len(pre_calls) == 12
+    assert len(post_calls) == 12
+
+
+# --- Test #5: hook ordering (pre always before corresponding post) ---
+
+
+def test__run_single_step_streaming_action_stream_generate_hook_ordering():
+    tracker = StreamGenerateTracker()
+    action = base_streaming_single_step_counter.with_name("counter")
+    state = State({"count": 0, "tracker": []})
+    generator = _run_single_step_streaming_action(
+        action,
+        state,
+        inputs={},
+        sequence_id=0,
+        partition_key="pk",
+        app_id="app",
+        lifecycle_adapters=LifecycleAdapterSet(tracker),
+    )
+    collections.deque(generator, maxlen=0)
+    # Exact interleaving pre(0), post(0), ..., pre(11), post(11): 11 items + 1 StopIteration
+    hooks = ("pre_stream_generate", "post_stream_generate")
+    expected = [(hook, index) for index in range(12) for hook in hooks]
+    assert tracker.calls == expected
+
+
+async def test__arun_multi_step_streaming_action_stream_generate_hook_ordering():
+    tracker = StreamGenerateTrackerAsync()
+    action = base_streaming_counter_async.with_name("counter")
+    state = State({"count": 0, "tracker": []})
+    generator = _arun_multi_step_streaming_action(
+        action=action,
+        state=state,
+        inputs={},
+        sequence_id=0,
+        app_id="app",
+        partition_key="pk",
+        lifecycle_adapters=LifecycleAdapterSet(tracker),
+    )
+    async for _ in generator:
+        pass
+    hooks = ("pre_stream_generate", "post_stream_generate")
+    expected = [(hook, index) for index in range(12) for hook in hooks]
+    assert tracker.calls == expected
+
+
+# --- Test #7: error mid-stream ---
+
+
+class ErrorAfterNSingleStep(SingleStepStreamingAction):
+    """Action that raises after n intermediate yields."""
+
+    def __init__(self, n: int):
+        super().__init__()
+        self.n = n
+
+    def stream_run_and_update(self, state, **run_kwargs):
+        for i in range(self.n):
+            yield {"i": i}, None
+        raise RuntimeError("boom")
+
+    @property
+    def reads(self):
+        return []
+
+    @property
+    def writes(self):
+        return []
+
+
+def test__run_single_step_streaming_action_stream_generate_on_error():
+    tracker = StreamGenerateTracker()
+    action = ErrorAfterNSingleStep(3).with_name("errorer")
+    state = State({})
+    generator = _run_single_step_streaming_action(
+        action,
+        state,
+        inputs={},
+        sequence_id=0,
+        partition_key="pk",
+        app_id="app",
+        lifecycle_adapters=LifecycleAdapterSet(tracker),
+    )
+    with pytest.raises(RuntimeError, match="boom"):
+        collections.deque(generator, maxlen=0)
+    pre_calls = [c for c in tracker.calls if c[0] == "pre_stream_generate"]
+    post_calls = [c for c in tracker.calls if c[0] == "post_stream_generate"]
+    # 3 successful yields + 1 that raises = 4 pre, 4 post (the raising call still gets a post)
+    assert len(pre_calls) == 4
+    assert len(post_calls) == 4
+
+
+# --- Test #8: existing post_stream_item callbacks unchanged ---
+
+
+def test__run_single_step_streaming_action_existing_callbacks_unchanged_with_generate_hooks():
+    class TrackingCallback(PostStreamItemHook):
+        def __init__(self):
+            self.items = []
+
+        def post_stream_item(self, item, **future_kwargs):
+            self.items.append(item)
+
+    hook = TrackingCallback()
+    gen_tracker = StreamGenerateTracker()
+    action = base_streaming_single_step_counter.with_name("counter")
+    state = State({"count": 0, "tracker": []})
+    generator = _run_single_step_streaming_action(
+        action,
+        state,
+        inputs={},
+        sequence_id=0,
+        partition_key="pk",
+        app_id="app",
+        lifecycle_adapters=LifecycleAdapterSet(hook, gen_tracker),
+    )
+    collections.deque(generator, maxlen=0)
+    # post_stream_item still fires exactly 10 times (only for intermediate items)
+    assert len(hook.items) == 10
+    # But pre/post_stream_generate fire for all items + StopIteration
+    pre_calls = [c for c in gen_tracker.calls if c[0] == "pre_stream_generate"]
+    assert len(pre_calls) == 12
+
+
+# --- Test #18: single yield ---
+
+
+class SingleYieldAction(SingleStepStreamingAction):
+    def stream_run_and_update(self, state, **run_kwargs):
+        yield {"val": 1}, None
+        yield {"val": 2}, state
+
+    @property
+    def reads(self):
+        return []
+
+    @property
+    def writes(self):
+        return []
+
+
+def test__stream_generate_hooks_single_intermediate_yield():
+    tracker = StreamGenerateTracker()
+    action = SingleYieldAction().with_name("single")
+    state = State({})
+    gen = _run_single_step_streaming_action(
+        action,
+        state,
+        inputs={},
+        sequence_id=0,
+        partition_key="pk",
+        app_id="app",
+        lifecycle_adapters=LifecycleAdapterSet(tracker),
+    )
+    collections.deque(gen, maxlen=0)
+    pre_calls = [c for c in tracker.calls if c[0] == "pre_stream_generate"]
+    post_calls = [c for c in tracker.calls if c[0] == "post_stream_generate"]
+    # 2 items + 1 StopIteration = 3 pre/post pairs
+    assert len(pre_calls) == 3
+    assert len(post_calls) == 3
+
+
+# --- Test #19: zero intermediate yields ---
+
+
+class NoIntermediateYieldAction(SingleStepStreamingAction):
+    def stream_run_and_update(self, state, **run_kwargs):
+        yield {"val": 1}, state
+
+    @property
+    def reads(self):
+        return []
+
+    @property
+    def writes(self):
+        return []
+
+
+def test__stream_generate_hooks_zero_intermediate_yields():
+    tracker = StreamGenerateTracker()
+    action = NoIntermediateYieldAction().with_name("noint")
+    state = State({})
+    gen = _run_single_step_streaming_action(
+        action,
+        state,
+        inputs={},
+        sequence_id=0,
+        partition_key="pk",
+        app_id="app",
+        lifecycle_adapters=LifecycleAdapterSet(tracker),
+    )
+    collections.deque(gen, maxlen=0)
+    pre_calls = [c for c in tracker.calls if c[0] == "pre_stream_generate"]
+    post_calls = [c for c in tracker.calls if c[0] == "post_stream_generate"]
+    # 1 item (final) + 1 StopIteration = 2 pre/post pairs
+    assert len(pre_calls) == 2
+    assert len(post_calls) == 2
+
+
+# --- Test #20: non-streaming action doesn't fire stream generate hooks ---
+
+
+def test__non_streaming_action_does_not_fire_stream_generate_hooks():
+    tracker = StreamGenerateTracker()
+    builder = ApplicationBuilder().with_actions(counter=base_single_step_counter)
+    builder = builder.with_transitions().with_entrypoint("counter").with_hooks(tracker)
+    builder.with_state(State({"count": 0, "tracker": []})).build().step()
+    # The tracker is registered with the app, so any stream-generate hook call would be seen
+    assert tracker.calls == []
diff --git a/tests/integrations/test_opentelemetry.py b/tests/integrations/test_opentelemetry.py
index b8606276..f9644f47 100644
--- a/tests/integrations/test_opentelemetry.py
+++ b/tests/integrations/test_opentelemetry.py
@@ -15,10 +15,1244 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import asyncio
+import datetime
+import threading
+import time
 import typing
+from typing import Sequence
+from unittest.mock import MagicMock
+
+from opentelemetry import trace
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter, SpanExportResult
 
 import burr.integrations.opentelemetry as burr_otel
+from burr.core.action import SingleStepAction, SingleStepStreamingAction
+from burr.core.application import Application, _arun_single_step_streaming_action
+from burr.core.graph import Graph
+from burr.core.state import State
+from burr.integrations.opentelemetry import OpenTelemetryBridge
+from burr.integrations.opentelemetry import StreamingTelemetryMode as STM
+from burr.integrations.opentelemetry import (
+    _exit_span,
+    _skipped_action_span,
+    _streaming_accumulator,
+    token_stack,
+)
+from burr.lifecycle.internal import LifecycleAdapterSet
+
+# ============================================================================
+# Simple in-memory exporter (not available in all otel SDK versions)
+# ============================================================================
+
+
+class _InMemorySpanExporter(SpanExporter):
+    """Collects finished spans in memory for test assertions."""
+
+    def __init__(self):
+        self._spans = []
+        self._lock = threading.Lock()
+
+    def export(self, spans: Sequence) -> SpanExportResult:
+        with self._lock:
+            self._spans.extend(spans)
+        return SpanExportResult.SUCCESS
+
+    def shutdown(self):
+        pass
+
+    def get_finished_spans(self):
+        with self._lock:
+            return list(self._spans)
+
+    def clear(self):
+        with self._lock:
+            self._spans.clear()
 
 
 def test_instrument_specs_match_instruments_literal():
     assert set(typing.get_args(burr_otel.INSTRUMENTS)) == set(burr_otel.INSTRUMENTS_SPECS.keys())
+
+
+# ============================================================================
+# Helpers
+# ============================================================================
+
+
+def _make_bridge_and_exporter(streaming_telemetry: STM = STM.SINGLE_SPAN):
+    """Creates an OpenTelemetryBridge with an in-memory exporter for test assertions."""
+    exporter = _InMemorySpanExporter()
+    provider = TracerProvider()
+    provider.add_span_processor(SimpleSpanProcessor(exporter))
+    tracer = provider.get_tracer("test")
+    bridge = OpenTelemetryBridge(tracer=tracer, streaming_telemetry=streaming_telemetry)
+    return bridge, exporter
+
+
+def _make_mock_action(name: str, streaming: bool = False):
+    """Creates a mock Action object with the given name and streaming flag."""
+    action = MagicMock()
+    action.name = name
+    action.streaming = streaming
+    return action
+
+
+def _reset_token_stack():
+    """Reset the token_stack and streaming ContextVars to clean state."""
+    token_stack.set(None)
+    _skipped_action_span.set(False)
+    _streaming_accumulator.set(None)
+
+
+# ============================================================================
+# Test #9: pre_stream_generate enters a span
+# ============================================================================
+
+
+def test_bridge_pre_stream_generate_enters_span():
+    _reset_token_stack()
+    bridge, _ = _make_bridge_and_exporter(streaming_telemetry=STM.CHUNK_SPANS)
+
+    bridge.pre_stream_generate(
+        action="my_action",
+        item_index=0,
+        stream_initialize_time=datetime.datetime.now(),
+        sequence_id=0,
+        app_id="app",
+        partition_key="pk",
+    )
+
+    stack = token_stack.get()
+    assert stack is not None
+    assert len(stack) == 1
+    _, span = stack[0]
+    assert span.name == "my_action::chunk_0"
+
+    # Clean up
+    _exit_span()
+    _reset_token_stack()
+
+
+# ============================================================================
+# Test #10: post_stream_generate exits a span
+# ============================================================================
+
+
+def test_bridge_post_stream_generate_exits_span():
+    _reset_token_stack()
+    bridge, exporter = _make_bridge_and_exporter(streaming_telemetry=STM.CHUNK_SPANS)
+
+    bridge.pre_stream_generate(
+        action="my_action",
+        item_index=0,
+        stream_initialize_time=datetime.datetime.now(),
+        sequence_id=0,
+        app_id="app",
+        partition_key="pk",
+    )
+    assert len(token_stack.get()) == 1
+
+    bridge.post_stream_generate(
+        item={"chunk": "data"},
+        item_index=0,
+        stream_initialize_time=datetime.datetime.now(),
+        action="my_action",
+        sequence_id=0,
+        app_id="app",
+        partition_key="pk",
+        exception=None,
+    )
+
+    stack = token_stack.get()
+    assert len(stack) == 0
+
+    spans = exporter.get_finished_spans()
+    assert len(spans) == 1
+    assert spans[0].name == "my_action::chunk_0"
+    assert spans[0].status.status_code == trace.StatusCode.OK
+
+    _reset_token_stack()
+
+
+# ============================================================================
+# Test #12: span naming for multiple chunks
+# ============================================================================
+
+
+def test_bridge_stream_span_naming():
+    _reset_token_stack()
+    bridge, exporter = _make_bridge_and_exporter(streaming_telemetry=STM.CHUNK_SPANS)
+
+    now = datetime.datetime.now()
+    for i in range(3):
+        bridge.pre_stream_generate(
+            action="my_action",
+            item_index=i,
+            stream_initialize_time=now,
+            sequence_id=0,
+            app_id="app",
+            partition_key="pk",
+        )
+        bridge.post_stream_generate(
+            item={"i": i},
+            item_index=i,
+            stream_initialize_time=now,
+            action="my_action",
+            sequence_id=0,
+            app_id="app",
+            partition_key="pk",
+            exception=None,
+        )
+
+    spans = exporter.get_finished_spans()
+    assert [s.name for s in spans] == [
+        "my_action::chunk_0",
+        "my_action::chunk_1",
+        "my_action::chunk_2",
+    ]
+    _reset_token_stack()
+
+
+# ============================================================================
+# Test #14: span closed on generator error
+# ============================================================================
+
+
+def test_bridge_stream_span_closed_on_generator_error():
+    _reset_token_stack()
+    bridge, exporter = _make_bridge_and_exporter(streaming_telemetry=STM.CHUNK_SPANS)
+
+    now = datetime.datetime.now()
+    exc = RuntimeError("generator failed")
+
+    bridge.pre_stream_generate(
+        action="my_action",
+        item_index=0,
+        stream_initialize_time=now,
+        sequence_id=0,
+        app_id="app",
+        partition_key="pk",
+    )
+    bridge.post_stream_generate(
+        item=None,
+        item_index=0,
+        stream_initialize_time=now,
+        action="my_action",
+        sequence_id=0,
+        app_id="app",
+        partition_key="pk",
+        exception=exc,
+    )
+
+    spans = exporter.get_finished_spans()
+    assert len(spans) == 1
+    assert spans[0].status.status_code == trace.StatusCode.ERROR
+    assert "generator failed" in spans[0].status.description
+
+    stack = token_stack.get()
+    assert len(stack) == 0
+    _reset_token_stack()
+
+
+# ============================================================================
+# Test #21: pre_run_step skips span for streaming action
+# ============================================================================
+
+
+def test_bridge_pre_run_step_skips_span_for_streaming_action():
+    _reset_token_stack()
+    bridge, _ = _make_bridge_and_exporter(streaming_telemetry=STM.CHUNK_SPANS)
+    action = _make_mock_action("stream_action", streaming=True)
+
+    bridge.pre_run_step(action=action)
+
+    assert _skipped_action_span.get() is True
+    stack = token_stack.get()
+    assert stack is None or len(stack) == 0
+
+    _reset_token_stack()
+
+
+# ============================================================================
+# Test #22: pre_run_step creates span for non-streaming action
+# ============================================================================
+
+
+def test_bridge_pre_run_step_creates_span_for_non_streaming_action():
+    _reset_token_stack()
+    bridge, _ = _make_bridge_and_exporter()
+    action = _make_mock_action("normal_action", streaming=False)
+
+    bridge.pre_run_step(action=action)
+
+    assert _skipped_action_span.get() is False
+    stack = token_stack.get()
+    assert stack is not None
+    assert len(stack) == 1
+    _, span = stack[0]
+    assert span.name == "normal_action"
+
+    # Clean up
+    _exit_span()
+    _reset_token_stack()
+
+
+# ============================================================================
+# Test #23: post_run_step skips exit when action span was skipped
+# ============================================================================
+
+
+def test_bridge_post_run_step_skips_exit_when_action_span_was_skipped():
+    _reset_token_stack()
+    bridge, _ = _make_bridge_and_exporter(streaming_telemetry=STM.CHUNK_SPANS)
+
+    # Simulate streaming action: pre_run_step skipped the span
+    _skipped_action_span.set(True)
+
+    # post_run_step should not pop anything (nothing was pushed)
+    bridge.post_run_step(exception=None)
+
+    assert _skipped_action_span.get() is False
+    stack = token_stack.get()
+    assert stack is None or len(stack) == 0
+
+    _reset_token_stack()
+
+
+# ============================================================================
+# Test #24: post_run_step exits span for non-streaming action
+# ============================================================================
+
+
+def test_bridge_post_run_step_exits_span_for_non_streaming_action():
+    _reset_token_stack()
+    bridge, exporter = _make_bridge_and_exporter()
+    action = _make_mock_action("normal_action", streaming=False)
+
+    bridge.pre_run_step(action=action)
+    bridge.post_run_step(exception=None)
+
+    stack = token_stack.get()
+    assert len(stack) == 0
+
+    spans = exporter.get_finished_spans()
+    assert len(spans) == 1
+    assert spans[0].name == "normal_action"
+    assert spans[0].status.status_code == trace.StatusCode.OK
+
+    _reset_token_stack()
+
+
+# ============================================================================
+# Test #25: streaming hierarchy has no action span — chunks are children of method span
+# ============================================================================
+
+
+def test_bridge_streaming_span_hierarchy_no_action_span():
+    _reset_token_stack()
+    bridge, exporter = _make_bridge_and_exporter(streaming_telemetry=STM.CHUNK_SPANS)
+    action = _make_mock_action("my_stream", streaming=True)
+    now = datetime.datetime.now()
+
+    # Simulate full streaming hook sequence
+    bridge.pre_run_execute_call(method=burr_otel.ExecuteMethod.stream_result)
+    bridge.pre_run_step(action=action)
+
+    for i in range(3):
+        bridge.pre_stream_generate(
+            action="my_stream",
+            item_index=i,
+            stream_initialize_time=now,
+            sequence_id=0,
+            app_id="app",
+            partition_key="pk",
+        )
+        bridge.post_stream_generate(
+            item={"i": i},
+            item_index=i,
+            stream_initialize_time=now,
+            action="my_stream",
+            sequence_id=0,
+            app_id="app",
+            partition_key="pk",
+            exception=None,
+        )
+
+    bridge.post_run_step(exception=None)
+    bridge.post_run_execute_call(exception=None)
+
+    spans = exporter.get_finished_spans()
+    span_names = [s.name for s in spans]
+
+    # Should have 4 spans: 3 chunks + 1 method. No "my_stream" action span.
+    assert "my_stream" not in span_names
+    assert "my_stream::chunk_0" in span_names
+    assert "my_stream::chunk_1" in span_names
+    assert "my_stream::chunk_2" in span_names
+    assert "stream_result" in span_names
+    assert len(spans) == 4
+
+    # Chunk spans should be children of the stream_result method span
+    method_span = next(s for s in spans if s.name == "stream_result")
+    for s in spans:
+        if s.name.startswith("my_stream::chunk_"):
+            assert s.parent is not None
+            assert s.parent.span_id == method_span.context.span_id
+
+    _reset_token_stack()
+
+
+# ============================================================================
+# Test #26: non-streaming then streaming — no state leak
+# ============================================================================
+
+
+def test_bridge_non_streaming_then_streaming_no_state_leak():
+    _reset_token_stack()
+    bridge, exporter = _make_bridge_and_exporter(streaming_telemetry=STM.CHUNK_SPANS)
+    now = datetime.datetime.now()
+
+    # First: non-streaming action
+    normal_action = _make_mock_action("normal", streaming=False)
+    bridge.pre_run_step(action=normal_action)
+    bridge.post_run_step(exception=None)
+    assert _skipped_action_span.get() is False
+
+    # Second: streaming action
+    stream_action = _make_mock_action("streamer", streaming=True)
+    bridge.pre_run_step(action=stream_action)
+    assert _skipped_action_span.get() is True
+
+    bridge.pre_stream_generate(
+        action="streamer",
+        item_index=0,
+        stream_initialize_time=now,
+        sequence_id=0,
+        app_id="app",
+        partition_key="pk",
+    )
+    bridge.post_stream_generate(
+        item={"x": 1},
+        item_index=0,
+        stream_initialize_time=now,
+        action="streamer",
+        sequence_id=0,
+        app_id="app",
+        partition_key="pk",
+        exception=None,
+    )
+    bridge.post_run_step(exception=None)
+
+    spans = exporter.get_finished_spans()
+    span_names = [s.name for s in spans]
+    # Should have: "normal" action span + "streamer::chunk_0" chunk span
+    assert "normal" in span_names
+    assert "streamer::chunk_0" in span_names
+    assert "streamer" not in span_names  # no action-level span for streaming
+
+    _reset_token_stack()
+
+
+# ============================================================================
+# Test #27: streaming then non-streaming — no state leak
+# ============================================================================
+
+
+def test_bridge_streaming_then_non_streaming_no_state_leak():
+    _reset_token_stack()
+    bridge, exporter = _make_bridge_and_exporter(streaming_telemetry=STM.CHUNK_SPANS)
+    now = datetime.datetime.now()
+
+    # First: streaming action
+    stream_action = _make_mock_action("streamer", streaming=True)
+    bridge.pre_run_step(action=stream_action)
+    bridge.pre_stream_generate(
+        action="streamer",
+        item_index=0,
+        stream_initialize_time=now,
+        sequence_id=0,
+        app_id="app",
+        partition_key="pk",
+    )
+    bridge.post_stream_generate(
+        item={"x": 1},
+        item_index=0,
+        stream_initialize_time=now,
+        action="streamer",
+        sequence_id=0,
+        app_id="app",
+        partition_key="pk",
+        exception=None,
+    )
+    bridge.post_run_step(exception=None)
+
+    # Second: non-streaming action
+    normal_action = _make_mock_action("normal", streaming=False)
+    bridge.pre_run_step(action=normal_action)
+    assert _skipped_action_span.get() is False
+    stack = token_stack.get()
+    assert len(stack) == 1
+    _, span = stack[0]
+    assert span.name == "normal"
+
+    bridge.post_run_step(exception=None)
+
+    spans = exporter.get_finished_spans()
+    span_names = [s.name for s in spans]
+    assert "streamer::chunk_0" in span_names
+    assert "normal" in span_names
+    assert "streamer" not in span_names
+
+    _reset_token_stack()
+
+
+# ============================================================================
+# Test #11 (updated): child spans under action span for non-streaming
+# ============================================================================
+
+
+def test_bridge_non_streaming_creates_child_spans_under_action_span():
+    """For non-streaming actions, spans opened via pre_start_span nest under the action span."""
+    _reset_token_stack()
+    bridge, exporter = _make_bridge_and_exporter()
+    action = _make_mock_action("my_action", streaming=False)
+
+    bridge.pre_run_step(action=action)
+    # Simulate a nested span (e.g., from TracerFactory)
+    mock_span = MagicMock()
+    mock_span.name = "inner_op"
+    bridge.pre_start_span(span=mock_span)
+    bridge.post_end_span(span=mock_span)
+    bridge.post_run_step(exception=None)
+
+    spans = exporter.get_finished_spans()
+    assert len(spans) == 2
+    action_span = next(s for s in spans if s.name == "my_action")
+    inner_span = next(s for s in spans if s.name == "inner_op")
+    assert inner_span.parent is not None
+    assert inner_span.parent.span_id == action_span.context.span_id
+
+    _reset_token_stack()
+
+
+# ============================================================================
+# Test #13: span timing excludes consumer time
+# ============================================================================
+
+
+async def test_bridge_stream_span_timing_excludes_consumer_time():
+    """Verify that chunk spans measure generation time, not consumer processing time."""
+    _reset_token_stack()
+    bridge, exporter = _make_bridge_and_exporter(streaming_telemetry=STM.CHUNK_SPANS)
+
+    class SlowGeneratorAction(SingleStepStreamingAction):
+        """Each yield takes ~50ms of 'generation time'."""
+
+        async def stream_run_and_update(self, state, **run_kwargs):
+            for i in range(3):
+                await asyncio.sleep(0.05)  # simulate generation time
+                yield {"i": i}, None
+            await asyncio.sleep(0.05)
+            yield {"i": 3}, state
+
+        @property
+        def reads(self):
+            return []
+
+        @property
+        def writes(self):
+            return []
+
+    action = SlowGeneratorAction().with_name("slow_gen")
+    state = State({})
+
+    generator = _arun_single_step_streaming_action(
+        action=action,
+        state=state,
+        inputs={},
+        sequence_id=0,
+        app_id="app",
+        partition_key="pk",
+        lifecycle_adapters=LifecycleAdapterSet(bridge),
+    )
+
+    # Consumer adds significant delay
+    async for item, state_update in generator:
+        await asyncio.sleep(0.3)  # simulate slow consumer
+
+    spans = exporter.get_finished_spans()
+    chunk_spans = [s for s in spans if "chunk_" in s.name]
+    assert len(chunk_spans) >= 4  # one per item: 3 intermediate + 1 final (+1 stop attempt)
+
+    for span in chunk_spans:
+        duration_ns = span.end_time - span.start_time
+        duration_ms = duration_ns / 1e6
+        # Each chunk should take roughly 50ms of generation time,
+        # NOT 350ms (50ms generation + 300ms consumer).
+        # Use generous tolerance to avoid flakiness.
+        assert duration_ms < 200, (
+            f"Span {span.name} took {duration_ms:.0f}ms, expected <200ms. "
+            f"Consumer time is leaking into the span."
+        )
+
+    _reset_token_stack()
+
+
+# ============================================================================
+# Test #15: full integration — astream_result produces per-yield spans
+# ============================================================================
+
+
+async def test_astream_result_with_otel_bridge_produces_per_yield_spans():
+    _reset_token_stack()
+    bridge, exporter = _make_bridge_and_exporter(streaming_telemetry=STM.CHUNK_SPANS)
+
+    class SimpleStreamer(SingleStepStreamingAction):
+        async def stream_run_and_update(self, state, **run_kwargs):
+            for i in range(5):
+                yield {"i": i}, None
+            yield {"i": 5}, state.update(done=True)
+
+        @property
+        def reads(self):
+            return []
+
+        @property
+        def writes(self):
+            return ["done"]
+
+    streamer = SimpleStreamer().with_name("streamer")
+    app = Application(
+        state=State({"done": False}),
+        entrypoint="streamer",
+        adapter_set=LifecycleAdapterSet(bridge),
+        partition_key="test",
+        uid="test-app",
+        graph=Graph(
+            actions=[streamer],
+            transitions=[],
+        ),
+    )
+
+    action, container = await app.astream_result(halt_after=["streamer"])
+    _ = [item async for item in container]
+    result, state = await container.get()
+
+    spans = exporter.get_finished_spans()
+    span_names = [s.name for s in spans]
+
+    # Should have: stream_result method span + chunk spans (no action span)
+    assert "stream_result" in span_names
+    assert "streamer" not in span_names  # no action-level span
+    chunk_names = [n for n in span_names if n.startswith("streamer::chunk_")]
+    # One span per generated item (5 intermediate + 1 final); the stop attempt may add a 7th
+    assert len(chunk_names) >= 6
+
+    # All chunk spans are children of stream_result
+    method_span = next(s for s in spans if s.name == "stream_result")
+    for s in spans:
+        if s.name.startswith("streamer::chunk_"):
+            assert s.parent is not None
+            assert s.parent.span_id == method_span.context.span_id
+
+    _reset_token_stack()
+
+
+# ============================================================================
+# Test
#17: non-streaming action via astream_result still gets action span +# ============================================================================ + + +async def test_astream_result_with_otel_bridge_non_streaming_action(): + _reset_token_stack() + bridge, exporter = _make_bridge_and_exporter() + + class SimpleAction(SingleStepAction): + def run_and_update(self, state, **run_kwargs): + return {"val": 1}, state.update(val=1) + + @property + def reads(self): + return [] + + @property + def writes(self): + return ["val"] + + action_obj = SimpleAction().with_name("simple") + app = Application( + state=State({"val": 0}), + entrypoint="simple", + adapter_set=LifecycleAdapterSet(bridge), + partition_key="test", + uid="test-app", + graph=Graph( + actions=[action_obj], + transitions=[], + ), + ) + + action, container = await app.astream_result(halt_after=["simple"]) + result, state = await container.get() + + spans = exporter.get_finished_spans() + span_names = [s.name for s in spans] + + # Non-streaming should have an action span + assert "simple" in span_names + # No chunk spans + chunk_names = [n for n in span_names if "chunk_" in n] + assert len(chunk_names) == 0 + + _reset_token_stack() + + +# ============================================================================ +# Mode: "single_span" — backwards compatible (action span, no chunks, no events) +# ============================================================================ + + +def test_bridge_single_span_mode_creates_action_span_for_streaming(): + """In 'single_span' mode, streaming actions get a normal action span, no chunk spans.""" + _reset_token_stack() + bridge, exporter = _make_bridge_and_exporter(streaming_telemetry=STM.SINGLE_SPAN) + action = _make_mock_action("my_stream", streaming=True) + now = datetime.datetime.now() + + bridge.pre_run_execute_call(method=burr_otel.ExecuteMethod.stream_result) + bridge.pre_run_step(action=action) + + # The action span should have been created + assert 
_skipped_action_span.get() is False + stack = token_stack.get() + assert len(stack) == 2 # method span + action span + + for i in range(3): + bridge.pre_stream_generate( + action="my_stream", + item_index=i, + stream_initialize_time=now, + sequence_id=0, + app_id="app", + partition_key="pk", + ) + bridge.post_stream_generate( + item={"i": i}, + item_index=i, + stream_initialize_time=now, + action="my_stream", + sequence_id=0, + app_id="app", + partition_key="pk", + exception=None, + ) + + bridge.post_run_step(exception=None) + bridge.post_run_execute_call(exception=None) + + spans = exporter.get_finished_spans() + span_names = [s.name for s in spans] + + # Should have action span + method span, no chunk spans + assert "my_stream" in span_names + assert "stream_result" in span_names + chunk_names = [n for n in span_names if "chunk_" in n] + assert len(chunk_names) == 0 + + # No span events on the action span + action_span = next(s for s in spans if s.name == "my_stream") + assert len(action_span.events) == 0 + + _reset_token_stack() + + +def test_bridge_single_span_mode_sets_attributes_on_action_span(): + """In 'single_span' mode, streaming attributes are set on the action span.""" + _reset_token_stack() + bridge, exporter = _make_bridge_and_exporter(streaming_telemetry=STM.SINGLE_SPAN) + action = _make_mock_action("my_stream", streaming=True) + now = datetime.datetime.now() + + bridge.pre_run_execute_call(method=burr_otel.ExecuteMethod.stream_result) + bridge.pre_run_step(action=action) + + # Accumulator should be initialized + assert _streaming_accumulator.get() is not None + + for i in range(3): + bridge.pre_stream_generate( + action="my_stream", + item_index=i, + stream_initialize_time=now, + sequence_id=0, + app_id="app", + partition_key="pk", + ) + bridge.post_stream_generate( + item={"i": i}, + item_index=i, + stream_initialize_time=now, + action="my_stream", + sequence_id=0, + app_id="app", + partition_key="pk", + exception=None, + ) + + 
bridge.post_run_step(exception=None) + bridge.post_run_execute_call(exception=None) + + spans = exporter.get_finished_spans() + action_span = next(s for s in spans if s.name == "my_stream") + + # Should have attributes, not events + assert len(action_span.events) == 0 + attrs = dict(action_span.attributes) + assert attrs["stream.iteration_count"] == 3 + assert "stream.generation_time_ms" in attrs + assert "stream.consumer_time_ms" in attrs + assert "stream.first_item_time_ms" in attrs + + # No chunk spans + chunk_names = [s.name for s in spans if "chunk_" in s.name] + assert len(chunk_names) == 0 + + _reset_token_stack() + + +# ============================================================================ +# Mode: "event" — action span + summary event, no chunk spans +# ============================================================================ + + +def test_bridge_event_mode_emits_stream_completed_event(): + """In 'event' mode, no action span. A stream_completed event is added to the method span.""" + _reset_token_stack() + bridge, exporter = _make_bridge_and_exporter(streaming_telemetry=STM.EVENT) + action = _make_mock_action("my_stream", streaming=True) + now = datetime.datetime.now() + + bridge.pre_run_execute_call(method=burr_otel.ExecuteMethod.stream_result) + bridge.pre_run_step(action=action) + + # Action span should be skipped + assert _skipped_action_span.get() is True + # Accumulator should be initialized + assert _streaming_accumulator.get() is not None + + for i in range(3): + bridge.pre_stream_generate( + action="my_stream", + item_index=i, + stream_initialize_time=now, + sequence_id=0, + app_id="app", + partition_key="pk", + ) + bridge.post_stream_generate( + item={"i": i}, + item_index=i, + stream_initialize_time=now, + action="my_stream", + sequence_id=0, + app_id="app", + partition_key="pk", + exception=None, + ) + + # Signal end of stream (StopIteration case) + bridge.pre_stream_generate( + action="my_stream", + item_index=3, + 
stream_initialize_time=now, + sequence_id=0, + app_id="app", + partition_key="pk", + ) + bridge.post_stream_generate( + item=None, + item_index=3, + stream_initialize_time=now, + action="my_stream", + sequence_id=0, + app_id="app", + partition_key="pk", + exception=None, + ) + + bridge.post_run_step(exception=None) + bridge.post_run_execute_call(exception=None) + + spans = exporter.get_finished_spans() + span_names = [s.name for s in spans] + + # Should have only method span, no action span, no chunk spans + assert "stream_result" in span_names + assert "my_stream" not in span_names + chunk_names = [n for n in span_names if "chunk_" in n] + assert len(chunk_names) == 0 + + # Method span should have a stream_completed event + method_span = next(s for s in spans if s.name == "stream_result") + assert len(method_span.events) == 1 + event = method_span.events[0] + assert event.name == "stream_completed" + attrs = dict(event.attributes) + assert "stream.generation_time_ms" in attrs + assert "stream.consumer_time_ms" in attrs + assert "stream.total_time_ms" in attrs + assert attrs["stream.iteration_count"] == 3 + assert "stream.first_item_time_ms" in attrs + + # Accumulator should be cleaned up + assert _streaming_accumulator.get() is None + + _reset_token_stack() + + +def test_bridge_event_mode_emits_stream_error_event(): + """In 'event' mode with an exception, a stream_error event is emitted on method span.""" + _reset_token_stack() + bridge, exporter = _make_bridge_and_exporter(streaming_telemetry=STM.EVENT) + action = _make_mock_action("my_stream", streaming=True) + now = datetime.datetime.now() + + bridge.pre_run_execute_call(method=burr_otel.ExecuteMethod.stream_result) + bridge.pre_run_step(action=action) + + # One successful yield + bridge.pre_stream_generate( + action="my_stream", + item_index=0, + stream_initialize_time=now, + sequence_id=0, + app_id="app", + partition_key="pk", + ) + bridge.post_stream_generate( + item={"i": 0}, + item_index=0, + 
stream_initialize_time=now, + action="my_stream", + sequence_id=0, + app_id="app", + partition_key="pk", + exception=None, + ) + + # Error on second yield + exc = RuntimeError("stream failed") + bridge.pre_stream_generate( + action="my_stream", + item_index=1, + stream_initialize_time=now, + sequence_id=0, + app_id="app", + partition_key="pk", + ) + bridge.post_stream_generate( + item=None, + item_index=1, + stream_initialize_time=now, + action="my_stream", + sequence_id=0, + app_id="app", + partition_key="pk", + exception=exc, + ) + + bridge.post_run_step(exception=exc) + bridge.post_run_execute_call(exception=exc) + + spans = exporter.get_finished_spans() + # No action span — event is on the method span + method_span = next(s for s in spans if s.name == "stream_result") + + assert len(method_span.events) == 1 + event = method_span.events[0] + assert event.name == "stream_error" + attrs = dict(event.attributes) + assert attrs["stream.iteration_count"] == 1 # only 1 successful yield + assert "stream.error" in attrs + assert "stream failed" in attrs["stream.error"] + + _reset_token_stack() + + +def test_bridge_event_mode_accumulator_timing_values(): + """Verify that the accumulator separates generation time from consumer time.""" + _reset_token_stack() + bridge, exporter = _make_bridge_and_exporter(streaming_telemetry=STM.EVENT) + action = _make_mock_action("my_stream", streaming=True) + now = datetime.datetime.now() + + bridge.pre_run_execute_call(method=burr_otel.ExecuteMethod.stream_result) + bridge.pre_run_step(action=action) + + # Simulate 2 yields with measurable time gaps + bridge.pre_stream_generate( + action="my_stream", + item_index=0, + stream_initialize_time=now, + sequence_id=0, + app_id="app", + partition_key="pk", + ) + time.sleep(0.05) # ~50ms generation time + bridge.post_stream_generate( + item={"i": 0}, + item_index=0, + stream_initialize_time=now, + action="my_stream", + sequence_id=0, + app_id="app", + partition_key="pk", + exception=None, + ) + 
time.sleep(0.1) # ~100ms consumer time + + bridge.pre_stream_generate( + action="my_stream", + item_index=1, + stream_initialize_time=now, + sequence_id=0, + app_id="app", + partition_key="pk", + ) + time.sleep(0.05) # ~50ms generation time + bridge.post_stream_generate( + item={"i": 1}, + item_index=1, + stream_initialize_time=now, + action="my_stream", + sequence_id=0, + app_id="app", + partition_key="pk", + exception=None, + ) + + bridge.post_run_step(exception=None) + bridge.post_run_execute_call(exception=None) + + spans = exporter.get_finished_spans() + # Event is on the method span (no action span in EVENT mode) + method_span = next(s for s in spans if s.name == "stream_result") + event = method_span.events[0] + attrs = dict(event.attributes) + + gen_ms = attrs["stream.generation_time_ms"] + consumer_ms = attrs["stream.consumer_time_ms"] + total_ms = attrs["stream.total_time_ms"] + first_item_ms = attrs["stream.first_item_time_ms"] + + # Generation: ~100ms total (2 × 50ms) + assert gen_ms >= 50, f"generation_time_ms={gen_ms}, expected >= 50" + assert gen_ms < 300, f"generation_time_ms={gen_ms}, expected < 300" + + # Consumer: ~100ms (gap between first post and second pre) + assert consumer_ms >= 50, f"consumer_time_ms={consumer_ms}, expected >= 50" + assert consumer_ms < 300, f"consumer_time_ms={consumer_ms}, expected < 300" + + # Total should be >= generation + consumer + assert total_ms >= gen_ms, f"total_time_ms={total_ms} < generation_time_ms={gen_ms}" + + # First item time should be close to first generation time (~50ms) + assert first_item_ms >= 20, f"first_item_time_ms={first_item_ms}, expected >= 20" + assert first_item_ms < 200, f"first_item_time_ms={first_item_ms}, expected < 200" + + assert attrs["stream.iteration_count"] == 2 + + _reset_token_stack() + + +# ============================================================================ +# Mode: SINGLE_AND_CHUNK_SPANS — action span with attributes + per-yield child spans +# 
============================================================================ + + +def test_bridge_single_and_chunk_spans_mode(): + """In SINGLE_AND_CHUNK_SPANS mode, action span has streaming attributes AND per-yield child spans exist.""" + _reset_token_stack() + bridge, exporter = _make_bridge_and_exporter(streaming_telemetry=STM.SINGLE_AND_CHUNK_SPANS) + action = _make_mock_action("my_stream", streaming=True) + now = datetime.datetime.now() + + bridge.pre_run_execute_call(method=burr_otel.ExecuteMethod.stream_result) + bridge.pre_run_step(action=action) + + # Action span should be created (not skipped) + assert _skipped_action_span.get() is False + + for i in range(3): + bridge.pre_stream_generate( + action="my_stream", + item_index=i, + stream_initialize_time=now, + sequence_id=0, + app_id="app", + partition_key="pk", + ) + bridge.post_stream_generate( + item={"i": i}, + item_index=i, + stream_initialize_time=now, + action="my_stream", + sequence_id=0, + app_id="app", + partition_key="pk", + exception=None, + ) + + bridge.post_run_step(exception=None) + bridge.post_run_execute_call(exception=None) + + spans = exporter.get_finished_spans() + span_names = [s.name for s in spans] + + # Should have: method span + action span + 3 chunk spans + assert "stream_result" in span_names + assert "my_stream" in span_names + assert "my_stream::chunk_0" in span_names + assert "my_stream::chunk_1" in span_names + assert "my_stream::chunk_2" in span_names + assert len(spans) == 5 + + # Action span should have streaming attributes (not events) + action_span = next(s for s in spans if s.name == "my_stream") + assert len(action_span.events) == 0 + attrs = dict(action_span.attributes) + assert attrs["stream.iteration_count"] == 3 + assert "stream.generation_time_ms" in attrs + assert "stream.consumer_time_ms" in attrs + assert "stream.first_item_time_ms" in attrs + + # Chunk spans should be children of the action span + for s in spans: + if s.name.startswith("my_stream::chunk_"): + 
assert s.parent is not None + assert s.parent.span_id == action_span.context.span_id + + _reset_token_stack() + + +# ============================================================================ +# Mode: "chunk_spans" — per-yield spans, no action span (already covered above, +# this test verifies no event is emitted) +# ============================================================================ + + +def test_bridge_chunk_spans_mode_no_event_emitted(): + """In 'chunk_spans' mode, no span event is emitted (no accumulator).""" + _reset_token_stack() + bridge, exporter = _make_bridge_and_exporter(streaming_telemetry=STM.CHUNK_SPANS) + action = _make_mock_action("my_stream", streaming=True) + now = datetime.datetime.now() + + bridge.pre_run_execute_call(method=burr_otel.ExecuteMethod.stream_result) + bridge.pre_run_step(action=action) + + # No accumulator in spans-only mode + assert _streaming_accumulator.get() is None + + bridge.pre_stream_generate( + action="my_stream", + item_index=0, + stream_initialize_time=now, + sequence_id=0, + app_id="app", + partition_key="pk", + ) + bridge.post_stream_generate( + item={"i": 0}, + item_index=0, + stream_initialize_time=now, + action="my_stream", + sequence_id=0, + app_id="app", + partition_key="pk", + exception=None, + ) + + bridge.post_run_step(exception=None) + bridge.post_run_execute_call(exception=None) + + spans = exporter.get_finished_spans() + + # Only chunk span + method span, no action span + span_names = [s.name for s in spans] + assert "my_stream::chunk_0" in span_names + assert "stream_result" in span_names + assert "my_stream" not in span_names + + # No events on any span + for s in spans: + assert len(s.events) == 0 + + _reset_token_stack() + + +# ============================================================================ +# Integration: "event" mode with astream_result +# ============================================================================ + + +async def 
test_astream_result_event_mode_produces_summary_event(): + """Full integration test: event mode with astream_result produces summary event.""" + _reset_token_stack() + bridge, exporter = _make_bridge_and_exporter(streaming_telemetry=STM.EVENT) + + class SimpleStreamer(SingleStepStreamingAction): + async def stream_run_and_update(self, state, **run_kwargs): + for i in range(5): + yield {"i": i}, None + yield {"i": 5}, state.update(done=True) + + @property + def reads(self): + return [] + + @property + def writes(self): + return ["done"] + + streamer = SimpleStreamer().with_name("streamer") + app = Application( + state=State({"done": False}), + entrypoint="streamer", + adapter_set=LifecycleAdapterSet(bridge), + partition_key="test", + uid="test-app", + graph=Graph( + actions=[streamer], + transitions=[], + ), + ) + + action, container = await app.astream_result(halt_after=["streamer"]) + _ = [item async for item in container] + result, state = await container.get() + + spans = exporter.get_finished_spans() + span_names = [s.name for s in spans] + + # Should have only method span, no action span, no chunk spans + assert "stream_result" in span_names + assert "streamer" not in span_names + chunk_names = [n for n in span_names if "chunk_" in n] + assert len(chunk_names) == 0 + + # Method span should have stream_completed event + method_span = next(s for s in spans if s.name == "stream_result") + assert len(method_span.events) == 1 + event = method_span.events[0] + assert event.name == "stream_completed" + attrs = dict(event.attributes) + # 5 intermediate + 1 final = 6 yielded items + assert attrs["stream.iteration_count"] >= 5 + assert attrs["stream.generation_time_ms"] >= 0 + assert attrs["stream.consumer_time_ms"] >= 0 + assert attrs["stream.total_time_ms"] >= 0 + assert attrs["stream.first_item_time_ms"] >= 0 + + _reset_token_stack()
