This is an automated email from the ASF dual-hosted git repository. skrawcz pushed a commit to branch stefan/improve-generator-spans in repository https://gitbox.apache.org/repos/asf/burr.git
commit 7f34286be6cfb473a671d9f77c2420ba3764597d Author: Stefan Krawczyk <[email protected]> AuthorDate: Sat Feb 28 22:35:23 2026 -0800 Add streaming timing to Burr tracker, UI, and documentation Surface generation-vs-consumer timing data in the Burr tracker and UI, independent of the OpenTelemetry streaming telemetry modes. Tracker changes: - Extend StreamState with timing fields (generation_time_ns, consumer_time_ns, first_item_time_ns, etc.) accumulated via PreStreamGenerateHook/PostStreamGenerateHook - Add generate hooks to SyncTrackingClient in both base.py and client.py with defensive getattr for subclass compatibility - Add optional timing fields to EndStreamModel (generation_time_ms, consumer_time_ms, first_item_time_ms) with None defaults for backwards compatibility - Update post_end_stream in both LocalTrackingClient and S3TrackingClient to convert and write accumulated timing UI changes: - Add timing fields to EndStreamModel.ts TypeScript type - Update StepList.tsx end_stream rendering to show "gen: Xms · consumer: Yms · N items · TTFT: Zms" when timing data is available, falling back to legacy throughput display Documentation: - Add "Streaming Telemetry Modes" section to additional-visibility.rst - Add StreamingTelemetryMode to opentelemetry.rst API reference - Add 4 new generate hook classes to lifecycle.rst - Add "Telemetry & Observability" section to streaming-actions.rst - Add "Streaming Timing" section to tracking.rst - Update monitoring.rst and examples/opentelemetry/README.md - Update design doc to reflect final implementation Example: - Add --tracker flag to streaming_telemetry_modes.py for validating timing data in the Burr UI Tests: - 6 new tests covering StreamState defaults, timing accumulation, defensive noop, EndStreamModel backwards compat, and sync/async end-to-end with LocalTrackingClient --- burr/tracking/base.py | 105 +++++++- burr/tracking/client.py | 136 +++++++++++ burr/tracking/common/models.py | 17 +- burr/tracking/s3client.py | 
14 ++ docs/concepts/additional-visibility.rst | 42 ++++ docs/concepts/streaming-actions.rst | 18 ++ docs/examples/deployment/monitoring.rst | 5 +- docs/reference/integrations/opentelemetry.rst | 4 + docs/reference/lifecycle.rst | 12 + docs/reference/tracking.rst | 15 ++ examples/opentelemetry/README.md | 15 +- .../opentelemetry/streaming_telemetry_modes.py | 38 ++- telemetry/ui/src/api/models/EndStreamModel.ts | 11 +- .../ui/src/components/routes/app/StepList.tsx | 24 +- tests/tracking/test_local_tracking_client.py | 272 ++++++++++++++++++++- 15 files changed, 710 insertions(+), 18 deletions(-) diff --git a/burr/tracking/base.py b/burr/tracking/base.py index d8b3f54f..ab26784d 100644 --- a/burr/tracking/base.py +++ b/burr/tracking/base.py @@ -16,6 +16,9 @@ # under the License. import abc +import datetime +import time +from typing import Any, Optional from burr.lifecycle import ( PostApplicationCreateHook, @@ -27,8 +30,10 @@ from burr.lifecycle import ( from burr.lifecycle.base import ( DoLogAttributeHook, PostEndStreamHook, + PostStreamGenerateHook, PostStreamItemHook, PreStartStreamHook, + PreStreamGenerateHook, ) @@ -42,10 +47,106 @@ class SyncTrackingClient( PreStartStreamHook, PostStreamItemHook, PostEndStreamHook, + PreStreamGenerateHook, + PostStreamGenerateHook, abc.ABC, ): - """Base class for synchronous tracking clients. All tracking clients must implement from this - TODO -- create an async tracking client""" + """Base class for synchronous tracking clients. + + Inherits from PreStreamGenerateHook/PostStreamGenerateHook so that all + tracker implementations automatically accumulate generation-vs-consumer + timing for streaming actions. The accumulated data is written to the + EndStreamModel in post_end_stream. + + Subclasses do NOT need to override pre_stream_generate/post_stream_generate + unless they want custom behavior — the default implementations here handle + timing accumulation using the StreamState dataclass. 
+ + TODO -- create an async tracking client + """ + + def pre_stream_generate( + self, + *, + item_index: int, + stream_initialize_time: datetime.datetime, + action: str, + sequence_id: int, + app_id: str, + partition_key: Optional[str], + **future_kwargs: Any, + ): + """Records the start of a single generator __next__() call. + + Uses defensive getattr to access stream_state so that custom subclasses + that don't call super().__init__() or don't have stream_state won't crash. + """ + stream_state = getattr(self, "stream_state", None) + if stream_state is None: + return + key = (app_id, action, partition_key) + state = stream_state.get(key) + if state is None: + return + + now_ns = time.monotonic_ns() + state._pre_generate_ns = now_ns + + # Record the stream start time on the first yield + if state.stream_start_ns is None: + state.stream_start_ns = now_ns + + # Consumer time = gap between previous post_stream_generate and this + # pre_stream_generate. On the first call there's no previous post, so + # consumer_time stays at 0. + if state.last_post_generate_ns is not None: + state.consumer_time_ns += now_ns - state.last_post_generate_ns + + def post_stream_generate( + self, + *, + item: Any, + item_index: int, + stream_initialize_time: datetime.datetime, + action: str, + sequence_id: int, + app_id: str, + partition_key: Optional[str], + exception: Optional[Exception] = None, + **future_kwargs: Any, + ): + """Records the end of a single generator __next__() call. + + Accumulates generation_time_ns from the paired pre_stream_generate call, + tracks iteration_count, and captures first_item_time_ns for TTFT. + + Uses defensive getattr to access stream_state so that custom subclasses + that don't call super().__init__() or don't have stream_state won't crash. 
+ """ + stream_state = getattr(self, "stream_state", None) + if stream_state is None: + return + key = (app_id, action, partition_key) + state = stream_state.get(key) + if state is None: + return + + now_ns = time.monotonic_ns() + state.last_post_generate_ns = now_ns + + # Accumulate generation time (time spent inside the generator) + if state._pre_generate_ns is not None: + state.generation_time_ns += now_ns - state._pre_generate_ns + state._pre_generate_ns = None + + # Track iteration count (only for actual items, not StopIteration) + if item is not None: + state.iteration_count += 1 + + # Capture TTFT (time from stream start to first item) + if state.first_item_time_ns is None and item is not None: + if state.stream_start_ns is not None: + state.first_item_time_ns = now_ns - state.stream_start_ns @abc.abstractmethod def copy(self): diff --git a/burr/tracking/client.py b/burr/tracking/client.py index 44919aed..1ba49af3 100644 --- a/burr/tracking/client.py +++ b/burr/tracking/client.py @@ -18,13 +18,16 @@ import abc import dataclasses import datetime +import time from burr.common.types import BaseCopyable from burr.lifecycle.base import ( DoLogAttributeHook, PostEndStreamHook, + PostStreamGenerateHook, PostStreamItemHook, PreStartStreamHook, + PreStreamGenerateHook, ) # this is a quick hack to get it to work on windows @@ -120,9 +123,38 @@ def _allowed_project_name(project_name: str, on_windows: bool) -> bool: @dataclasses.dataclass class StreamState: + """Tracks state for an in-progress stream. + + The timing fields (generation_time_ns, consumer_time_ns, etc.) are populated + by the PreStreamGenerateHook/PostStreamGenerateHook implementations on the + tracker. They accumulate generation vs. consumer timing across all yields, + enabling the tracker to write a timing summary when the stream ends. + + These fields default to 0/None so that existing code that only uses + stream_init_time/count continues to work unchanged. 
+ """ + stream_init_time: datetime.datetime count: Optional[int] + # --- Streaming timing fields (populated by pre/post_stream_generate) --- + # Accumulated wall-clock nanoseconds the generator spent producing items. + generation_time_ns: int = 0 + # Accumulated wall-clock nanoseconds the consumer spent processing items. + consumer_time_ns: int = 0 + # Total number of items the generator has yielded so far. + iteration_count: int = 0 + # Nanosecond timestamp of the first item produced (for TTFT calculation). + first_item_time_ns: Optional[int] = None + # Nanosecond timestamp when the stream started (first pre_stream_generate). + stream_start_ns: Optional[int] = None + # Nanosecond timestamp of the most recent post_stream_generate call, + # used to compute consumer_time between yields. + last_post_generate_ns: Optional[int] = None + # Nanosecond timestamp captured at the start of the current generation + # (set in pre_stream_generate, consumed in post_stream_generate). + _pre_generate_ns: Optional[int] = None + StateKey = Tuple[str, str, Optional[str]] @@ -137,9 +169,98 @@ class SyncTrackingClient( PreStartStreamHook, PostStreamItemHook, PostEndStreamHook, + PreStreamGenerateHook, + PostStreamGenerateHook, BaseCopyable, ABC, ): + """Synchronous tracking client base class (client.py variant). + + Includes PreStreamGenerateHook/PostStreamGenerateHook so that all tracker + implementations automatically accumulate generation-vs-consumer timing for + streaming actions. The concrete implementations below populate the + StreamState timing fields; post_end_stream reads them to write the + EndStreamModel with timing data. + + Subclasses do NOT need to override pre/post_stream_generate unless they + want custom behavior. 
+ """ + + def pre_stream_generate( + self, + *, + item_index: int, + stream_initialize_time: datetime.datetime, + action: str, + sequence_id: int, + app_id: str, + partition_key: Optional[str], + **future_kwargs: Any, + ): + """Records the start of a single generator __next__() call. + + Uses defensive getattr so custom subclasses without stream_state + won't crash. + """ + stream_state = getattr(self, "stream_state", None) + if stream_state is None: + return + key = (app_id, action, partition_key) + state = stream_state.get(key) + if state is None: + return + + now_ns = time.monotonic_ns() + state._pre_generate_ns = now_ns + + if state.stream_start_ns is None: + state.stream_start_ns = now_ns + + # Consumer time = gap between previous post_stream_generate and this call. + # On the first call there's no previous post, so consumer_time stays at 0. + if state.last_post_generate_ns is not None: + state.consumer_time_ns += now_ns - state.last_post_generate_ns + + def post_stream_generate( + self, + *, + item: Any, + item_index: int, + stream_initialize_time: datetime.datetime, + action: str, + sequence_id: int, + app_id: str, + partition_key: Optional[str], + exception: Optional[Exception] = None, + **future_kwargs: Any, + ): + """Records the end of a single generator __next__() call. + + Accumulates generation_time_ns, tracks iteration_count, and captures + first_item_time_ns for TTFT. Uses defensive getattr for compatibility. 
+ """ + stream_state = getattr(self, "stream_state", None) + if stream_state is None: + return + key = (app_id, action, partition_key) + state = stream_state.get(key) + if state is None: + return + + now_ns = time.monotonic_ns() + state.last_post_generate_ns = now_ns + + if state._pre_generate_ns is not None: + state.generation_time_ns += now_ns - state._pre_generate_ns + state._pre_generate_ns = None + + if item is not None: + state.iteration_count += 1 + + if state.first_item_time_ns is None and item is not None: + if state.stream_start_ns is not None: + state.first_item_time_ns = now_ns - state.stream_start_ns + @abc.abstractmethod def copy(self) -> Self: """Clones the tracking client. This is useful for forking applications. @@ -591,12 +712,27 @@ class LocalTrackingClient( **future_kwargs: Any, ): stream_state = self.stream_state[app_id, action, partition_key] + # Convert nanosecond timing accumulated by pre/post_stream_generate + # into millisecond floats for the EndStreamModel. If stream_start_ns + # is None, the generate hooks never fired (e.g. the action isn't using + # the instrumented generator), so we leave timing fields as None. 
+ generation_time_ms = None + consumer_time_ms = None + first_item_time_ms = None + if stream_state.stream_start_ns is not None: + generation_time_ms = stream_state.generation_time_ns / 1_000_000 + consumer_time_ms = stream_state.consumer_time_ns / 1_000_000 + if stream_state.first_item_time_ns is not None: + first_item_time_ms = stream_state.first_item_time_ns / 1_000_000 self._append_write_line( EndStreamModel( action_sequence_id=sequence_id, span_id=None, end_time=system.now(), items_streamed=stream_state.count, + generation_time_ms=generation_time_ms, + consumer_time_ms=consumer_time_ms, + first_item_time_ms=first_item_time_ms, ) ) del self.stream_state[app_id, action, partition_key] diff --git a/burr/tracking/common/models.py b/burr/tracking/common/models.py index 5980bf9d..911b1e48 100644 --- a/burr/tracking/common/models.py +++ b/burr/tracking/common/models.py @@ -261,7 +261,14 @@ class FirstItemStreamModel(IdentifyingModel): class EndStreamModel(IdentifyingModel): - """Pydantic model that represents an entry for the first item of a stream""" + """Pydantic model that represents the end of a stream. + + The optional timing fields (generation_time_ms, consumer_time_ms, etc.) are + populated when the tracker has PreStreamGenerateHook/PostStreamGenerateHook + support. They are Optional so that: + - Old log files (without timing) still parse with new server code. + - New log files don't crash old server code (Pydantic ignores extra keys). + """ action_sequence_id: int span_id: Optional[ @@ -271,6 +278,14 @@ class EndStreamModel(IdentifyingModel): items_streamed: int type: str = "end_stream" + # --- Streaming timing summary (Optional for backwards compatibility) --- + # Sum of time spent inside the generator producing items (excludes consumer wait). + generation_time_ms: Optional[float] = None + # Sum of time the consumer spent processing yielded items between yields. 
+ consumer_time_ms: Optional[float] = None + # Time from stream start to first item produced (time to first token / TTFT). + first_item_time_ms: Optional[float] = None + @property def sequence_id(self) -> int: return self.action_sequence_id diff --git a/burr/tracking/s3client.py b/burr/tracking/s3client.py index 561d517a..e6857b89 100644 --- a/burr/tracking/s3client.py +++ b/burr/tracking/s3client.py @@ -500,12 +500,26 @@ class S3TrackingClient(SyncTrackingClient): **future_kwargs: Any, ): stream_state = self.stream_state[app_id, action, partition_key] + # Convert nanosecond timing accumulated by pre/post_stream_generate + # into millisecond floats for the EndStreamModel. If stream_start_ns + # is None, the generate hooks never fired, so we leave timing as None. + generation_time_ms = None + consumer_time_ms = None + first_item_time_ms = None + if stream_state.stream_start_ns is not None: + generation_time_ms = stream_state.generation_time_ns / 1_000_000 + consumer_time_ms = stream_state.consumer_time_ns / 1_000_000 + if stream_state.first_item_time_ns is not None: + first_item_time_ms = stream_state.first_item_time_ns / 1_000_000 self.submit_log_event( EndStreamModel( action_sequence_id=sequence_id, span_id=None, end_time=system.now(), items_streamed=stream_state.count, + generation_time_ms=generation_time_ms, + consumer_time_ms=consumer_time_ms, + first_item_time_ms=first_item_time_ms, ), app_id, partition_key, diff --git a/docs/concepts/additional-visibility.rst b/docs/concepts/additional-visibility.rst index a315e68d..656ac44c 100644 --- a/docs/concepts/additional-visibility.rst +++ b/docs/concepts/additional-visibility.rst @@ -308,6 +308,48 @@ it as you see fit). With this you can log to any OpenTelemetry provider. +Streaming Telemetry Modes +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +When using streaming actions with the ``OpenTelemetryBridge``, you can control how +streaming telemetry is emitted via the ``streaming_telemetry`` parameter. 
This accepts +a :py:class:`StreamingTelemetryMode <burr.integrations.opentelemetry.StreamingTelemetryMode>` enum value: + +.. code-block:: python + + from burr.integrations.opentelemetry import OpenTelemetryBridge, StreamingTelemetryMode + + # Default: one action span with streaming attributes (generation time, consumer time, TTFT) + bridge = OpenTelemetryBridge(tracer=tracer, streaming_telemetry=StreamingTelemetryMode.SINGLE_SPAN) + + # Lightest-weight: no action span, single summary event on the method span + bridge = OpenTelemetryBridge(tracer=tracer, streaming_telemetry=StreamingTelemetryMode.EVENT) + + # Per-yield spans: one child span per generator yield, measuring generation time only + bridge = OpenTelemetryBridge(tracer=tracer, streaming_telemetry=StreamingTelemetryMode.CHUNK_SPANS) + + # Maximum visibility: action span with attributes + per-yield child spans + bridge = OpenTelemetryBridge(tracer=tracer, streaming_telemetry=StreamingTelemetryMode.SINGLE_AND_CHUNK_SPANS) + +The four modes are: + +- **SINGLE_SPAN** (default) — The action span covers the full generator lifetime. Span attributes + provide the generation-vs-consumer time breakdown (``stream.generation_time_ms``, + ``stream.consumer_time_ms``, ``stream.iteration_count``, ``stream.first_item_time_ms``). + This is backwards-compatible with the pre-existing behavior. +- **EVENT** — No action span is created. A single ``stream_completed`` event is added to the + method span with the same timing attributes. This is the lightest-weight option. +- **CHUNK_SPANS** — No action span. One child span per generator yield, each measuring only + the generation time for that item. +- **SINGLE_AND_CHUNK_SPANS** — Combines SINGLE_SPAN and CHUNK_SPANS: the action span with + streaming attributes, plus per-yield child spans nested under it. + +Non-streaming actions are unaffected by this setting — they always get an action span regardless +of the mode. 
+ +See the `streaming_telemetry_modes.py example <https://github.com/apache/burr/tree/main/examples/opentelemetry/streaming_telemetry_modes.py>`_ +for a runnable demo of all four modes. + Instrumenting libraries ----------------------- diff --git a/docs/concepts/streaming-actions.rst b/docs/concepts/streaming-actions.rst index 95e3a06e..edf64241 100644 --- a/docs/concepts/streaming-actions.rst +++ b/docs/concepts/streaming-actions.rst @@ -242,3 +242,21 @@ be consistent with the asynchronous method. If you're using the old version, the 1. The return type of the streaming action should be ``Generator[Tuple[dict, Optional[State], None, None]]`` instead of ``Generator[dict, None, Tuple[dict, State]]``. 2. All intermediate results should be yielded as ``yield {'response': delta}, None`` instead of ``yield {'response': delta}``. 3. The final result will be a ``yield`` instead of a ``return`` + +------------------------- +Telemetry & Observability +------------------------- + +Streaming actions are instrumented with lifecycle hooks that bracket each generator +yield. This enables two forms of observability: + +**OpenTelemetry** — The :py:class:`OpenTelemetryBridge <burr.integrations.opentelemetry.OpenTelemetryBridge>` +supports configurable streaming telemetry modes (``SINGLE_SPAN``, ``EVENT``, ``CHUNK_SPANS``, +``SINGLE_AND_CHUNK_SPANS``) that control how streaming spans and events are emitted. These +modes provide timing attributes such as generation time, consumer time, iteration count, and +time to first token (TTFT). See :ref:`Streaming Telemetry Modes <opentelref>` for details. + +**Burr Tracker** — The :py:class:`LocalTrackingClient <burr.tracking.LocalTrackingClient>` (and +``S3TrackingClient``) automatically accumulate generation-vs-consumer timing for streaming +actions and write it to the ``end_stream`` log entry. The Burr UI displays this data in the +step detail view, showing the generation/consumer time split and TTFT. 
diff --git a/docs/examples/deployment/monitoring.rst b/docs/examples/deployment/monitoring.rst index 44f104e2..0a0909e4 100644 --- a/docs/examples/deployment/monitoring.rst +++ b/docs/examples/deployment/monitoring.rst @@ -22,7 +22,10 @@ Monitoring in Production ------------------------ Burr's telemetry UI is meant both for debugging and running in production. It can consume `OpenTelemetry traces <https://burr.apache.org/reference/integrations/opentelemetry/>`_, -and has a suite of useful capabilities for debugging Burr applications. +and has a suite of useful capabilities for debugging Burr applications. For streaming actions, the tracker +and UI surface generation-vs-consumer timing (including time to first token), and the +``OpenTelemetryBridge`` supports :ref:`configurable streaming telemetry modes <opentelref>` for +controlling span and event granularity. It has two (current) implementations: diff --git a/docs/reference/integrations/opentelemetry.rst b/docs/reference/integrations/opentelemetry.rst index e68ffd7c..ff8755e6 100644 --- a/docs/reference/integrations/opentelemetry.rst +++ b/docs/reference/integrations/opentelemetry.rst @@ -41,4 +41,8 @@ Reference for the various useful methods: .. autoclass:: burr.integrations.opentelemetry.OpenTelemetryBridge :members: +.. autoclass:: burr.integrations.opentelemetry.StreamingTelemetryMode + :members: + :undoc-members: + .. autofunction:: burr.integrations.opentelemetry.init_instruments diff --git a/docs/reference/lifecycle.rst b/docs/reference/lifecycle.rst index 84703e93..0175e76b 100644 --- a/docs/reference/lifecycle.rst +++ b/docs/reference/lifecycle.rst @@ -68,6 +68,18 @@ and add instances to the application builder to customize your state machines's .. autoclass:: burr.lifecycle.base.PostApplicationExecuteCallHookAsync :members: +.. autoclass:: burr.lifecycle.base.PreStreamGenerateHook + :members: + +.. autoclass:: burr.lifecycle.base.PreStreamGenerateHookAsync + :members: + +.. 
autoclass:: burr.lifecycle.base.PostStreamGenerateHook + :members: + +.. autoclass:: burr.lifecycle.base.PostStreamGenerateHookAsync + :members: + These hooks are available for you to use: .. autoclass:: burr.lifecycle.default.StateAndResultsFullLogger diff --git a/docs/reference/tracking.rst b/docs/reference/tracking.rst index 4bd268be..b48a3651 100644 --- a/docs/reference/tracking.rst +++ b/docs/reference/tracking.rst @@ -29,3 +29,18 @@ Rather, you should use this through/in conjunction with :py:meth:`burr.core.appl :members: .. automethod:: __init__ + +Streaming Timing +~~~~~~~~~~~~~~~~ + +For streaming actions, the tracker automatically accumulates timing data by implementing +``PreStreamGenerateHook`` and ``PostStreamGenerateHook``. When a streaming action completes, +the ``end_stream`` log entry includes the following optional timing fields: + +- ``generation_time_ms`` — Sum of time spent inside the generator producing items (excludes consumer wait time). +- ``consumer_time_ms`` — Sum of time the consumer spent processing yielded items between yields. +- ``first_item_time_ms`` — Time from stream start to first item produced (time to first token / TTFT). + +These fields are ``null`` when the streaming timing hooks have not fired (e.g. old log files or +non-instrumented generators). The Burr UI renders these fields in the step detail view when +available, falling back to the legacy throughput calculation otherwise. diff --git a/examples/opentelemetry/README.md b/examples/opentelemetry/README.md index d0c18c6c..ff350c2d 100644 --- a/examples/opentelemetry/README.md +++ b/examples/opentelemetry/README.md @@ -27,6 +27,19 @@ We have two modes: 2. Log Burr to OpenTelemetry See [notebook.ipynb](./notebook.ipynb) for a simple overview. -See [application.py](./application.py) for the full code +See [application.py](./application.py) for the full code. 
+ +## Streaming Telemetry + +For streaming actions, the `OpenTelemetryBridge` supports four configurable +telemetry modes via `StreamingTelemetryMode`: + +- **SINGLE_SPAN** (default) — one action span with streaming attributes (generation time, consumer time, TTFT) +- **EVENT** — no action span, single summary event on the method span +- **CHUNK_SPANS** — per-yield child spans measuring generation time only +- **SINGLE_AND_CHUNK_SPANS** — action span with attributes + per-yield child spans + +See [streaming_telemetry_modes.py](./streaming_telemetry_modes.py) for a runnable +demo exercising all four modes with the console exporter. See the [documentation](https://burr.dagworks.io/concepts/additional-visibility/#open-telemetry) for more info diff --git a/examples/opentelemetry/streaming_telemetry_modes.py b/examples/opentelemetry/streaming_telemetry_modes.py index c50381f1..67a653bc 100644 --- a/examples/opentelemetry/streaming_telemetry_modes.py +++ b/examples/opentelemetry/streaming_telemetry_modes.py @@ -20,14 +20,23 @@ Runs a simple async streaming action under each mode with the OTel console exporter so you can see the spans and events printed to stdout. +When --tracker is passed, each mode also gets a LocalTrackingClient so the +results show up in the Burr UI (run ``burr`` to open it). + Usage: + # OTel console output only python examples/opentelemetry/streaming_telemetry_modes.py + # OTel console output + Burr tracker (viewable in the UI) + python examples/opentelemetry/streaming_telemetry_modes.py --tracker + No external APIs are needed — the streaming action simulates an LLM by yielding tokens with small delays. 
""" +import argparse import asyncio +import time from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor @@ -68,7 +77,7 @@ graph = GraphBuilder().with_actions(generate=generate_response).with_transitions # --------------------------------------------------------------------------- -async def run_with_mode(mode: StreamingTelemetryMode) -> None: +async def run_with_mode(mode: StreamingTelemetryMode, use_tracker: bool = False) -> None: """Builds an app with the given streaming telemetry mode and runs it.""" # Each run gets its own tracer provider so console output stays grouped provider = TracerProvider() @@ -77,16 +86,20 @@ async def run_with_mode(mode: StreamingTelemetryMode) -> None: bridge = OpenTelemetryBridge(tracer=tracer, streaming_telemetry=mode) - app = ( + builder = ( ApplicationBuilder() .with_graph(graph) .with_entrypoint("generate") .with_state(State({"prompt": "hello world from burr streaming"})) .with_hooks(bridge) - .with_identifiers(app_id=f"demo-{mode.value}") - .build() + .with_identifiers(app_id=f"demo-{mode.value}-{time.time()}") ) + if use_tracker: + builder = builder.with_tracker(project="streaming-telemetry-modes", tracker="local") + + app = builder.build() + action, container = await app.astream_result(halt_after=["generate"]) async for item in container: await asyncio.sleep(0.05) # simulate consumer processing time per token @@ -100,7 +113,7 @@ async def run_with_mode(mode: StreamingTelemetryMode) -> None: # --------------------------------------------------------------------------- -async def main(): +async def main(use_tracker: bool = False): modes = [ StreamingTelemetryMode.SINGLE_SPAN, StreamingTelemetryMode.EVENT, @@ -111,9 +124,20 @@ async def main(): print(f"\n{'=' * 70}") print(f" StreamingTelemetryMode.{mode.name}") print(f"{'=' * 70}\n") - await run_with_mode(mode) + await run_with_mode(mode, use_tracker=use_tracker) print() + if use_tracker: + 
print("Tracker data written to ~/.burr/streaming-telemetry-modes/") + print("Run `burr` to open the UI and inspect the streaming timing data.") + if __name__ == "__main__": - asyncio.run(main()) + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--tracker", + action="store_true", + help="Enable the Burr LocalTrackingClient so results appear in the UI", + ) + args = parser.parse_args() + asyncio.run(main(use_tracker=args.tracker)) diff --git a/telemetry/ui/src/api/models/EndStreamModel.ts b/telemetry/ui/src/api/models/EndStreamModel.ts index 60dcd1d7..061c95a2 100644 --- a/telemetry/ui/src/api/models/EndStreamModel.ts +++ b/telemetry/ui/src/api/models/EndStreamModel.ts @@ -22,7 +22,10 @@ /* tslint:disable */ /* eslint-disable */ /** - * Pydantic model that represents an entry for the first item of a stream + * Pydantic model that represents the end of a stream. + * + * The optional timing fields are populated when the tracker has + * PreStreamGenerateHook/PostStreamGenerateHook support. */ export type EndStreamModel = { type?: string; @@ -30,4 +33,10 @@ export type EndStreamModel = { span_id: string | null; end_time: string; items_streamed: number; + /** Sum of time spent inside the generator producing items (ms). */ + generation_time_ms?: number | null; + /** Sum of time the consumer spent processing yielded items (ms). */ + consumer_time_ms?: number | null; + /** Time from stream start to first item produced / TTFT (ms). */ + first_item_time_ms?: number | null; }; diff --git a/telemetry/ui/src/components/routes/app/StepList.tsx b/telemetry/ui/src/components/routes/app/StepList.tsx index 5f0d197e..f12ec130 100644 --- a/telemetry/ui/src/components/routes/app/StepList.tsx +++ b/telemetry/ui/src/components/routes/app/StepList.tsx @@ -892,9 +892,27 @@ const StepSubTable = (props: { new Date(firstStream?.first_item_time).getTime() : undefined; const numStreamed = streamModel.items_streamed; - // const name = ellapsedStreamTime ? 
`last token (throughput=${ellapsedStreamTime/streamModel.items_streamed} ms/token)` : 'last token'; - // const name = `throughput: ${(ellapsedStreamTime || 0) / numStreamed} ms/token (${numStreamed} tokens/${ellapsedStreamTime}ms)`; - const name = `throughput: ${((ellapsedStreamTime || 0) / numStreamed).toFixed(1)} ms/token (${numStreamed}/${ellapsedStreamTime}ms)`; + // Build a descriptive name that includes generation/consumer timing + // when available (from the new PreStreamGenerateHook/PostStreamGenerateHook + // timing accumulation), falling back to the legacy throughput calculation. + const genTime = streamModel.generation_time_ms; + const consTime = streamModel.consumer_time_ms; + const ttftTime = streamModel.first_item_time_ms; + let name: string; + if ( + genTime !== null && + genTime !== undefined && + consTime !== null && + consTime !== undefined + ) { + const ttft = + ttftTime !== null && ttftTime !== undefined + ? ` · TTFT: ${ttftTime.toFixed(0)}ms` + : ''; + name = `gen: ${genTime.toFixed(0)}ms · consumer: ${consTime.toFixed(0)}ms · ${numStreamed} items${ttft}`; + } else { + name = `throughput: ${((ellapsedStreamTime || 0) / numStreamed).toFixed(1)} ms/token (${numStreamed}/${ellapsedStreamTime}ms)`; + } return ( <StepSubTableRow key={`streaming-${i}`} diff --git a/tests/tracking/test_local_tracking_client.py b/tests/tracking/test_local_tracking_client.py index 7a8196c0..1100d8b8 100644 --- a/tests/tracking/test_local_tracking_client.py +++ b/tests/tracking/test_local_tracking_client.py @@ -15,19 +15,22 @@ # specific language governing permissions and limitations # under the License. 
+import asyncio
 import json
 import os
+import time
 import uuid
-from typing import Literal, Optional, Tuple
+from typing import Generator, Literal, Optional, Tuple
 
 import pytest
 
 import burr
 from burr import lifecycle
 from burr.core import Action, Application, ApplicationBuilder, Result, State, action, default, expr
+from burr.core.action import StreamingAction, streaming_action
 from burr.core.persistence import BaseStatePersister, PersistedStateData
 from burr.tracking import LocalTrackingClient
-from burr.tracking.client import _allowed_project_name
+from burr.tracking.client import StreamState, _allowed_project_name
 from burr.tracking.common.models import (
     ApplicationMetadataModel,
     ApplicationModel,
@@ -37,6 +40,7 @@ from burr.tracking.common.models import (
     ChildApplicationModel,
     EndEntryModel,
     EndSpanModel,
+    EndStreamModel,
 )
 from burr.visibility import TracerFactory
 
@@ -494,3 +498,280 @@ def test_local_tracking_client_copy():
     assert copy.project_id == tracking_client.project_id
     assert copy.serde_kwargs == tracking_client.serde_kwargs
     assert copy.storage_dir == tracking_client.storage_dir
+
+
+# ---------------------------------------------------------------------------
+# StreamState timing accumulation tests
+# ---------------------------------------------------------------------------
+
+
+def test_stream_state_defaults():
+    """StreamState's timing fields must default to 0/None so existing code that
+    only passes stream_init_time/count is unaffected."""
+    import datetime
+
+    ss = StreamState(stream_init_time=datetime.datetime.now(), count=0)
+    # Accumulators start at zero.
+    assert ss.generation_time_ns == 0
+    assert ss.consumer_time_ns == 0
+    assert ss.iteration_count == 0
+    # Timestamps stay unset on a freshly constructed state.
+    assert ss.first_item_time_ns is None
+    assert ss.stream_start_ns is None
+    assert ss.last_post_generate_ns is None
+    assert ss._pre_generate_ns is None
+
+
+def test_pre_post_stream_generate_accumulates_timing(tmpdir):
+    """Directly exercises pre/post_stream_generate on LocalTrackingClient to
+    verify that generation_time_ns, consumer_time_ns, iteration_count, and
+    first_item_time_ns are accumulated correctly.
+
+    NOTE(review): the ``> 0`` checks assume the tracker's clock advances between
+    back-to-back hook calls -- confirm on platforms with a coarse clock.
+    """
+    import datetime
+
+    # A pytest-managed directory keeps the test off a hard-coded /tmp path.
+    tracker = LocalTrackingClient("test", str(tmpdir))
+    app_id = "app1"
+    action_name = "gen"
+    pk = None
+    key = (app_id, action_name, pk)
+    now = datetime.datetime.now()
+
+    # Simulate pre_start_stream creating the StreamState
+    tracker.stream_state[key] = StreamState(stream_init_time=now, count=0)
+
+    common = dict(
+        stream_initialize_time=now,
+        action=action_name,
+        sequence_id=0,
+        app_id=app_id,
+        partition_key=pk,
+    )
+
+    # Yield 0: pre -> (generation) -> post
+    tracker.pre_stream_generate(item_index=0, **common)
+    state = tracker.stream_state[key]
+    assert state.stream_start_ns is not None  # set on first call
+    assert state._pre_generate_ns is not None
+
+    tracker.post_stream_generate(item={"token": "hello"}, item_index=0, **common)
+    assert state.iteration_count == 1
+    assert state.generation_time_ns > 0
+    assert state.first_item_time_ns is not None  # TTFT captured
+    first_gen_time = state.generation_time_ns
+
+    # Yield 1: pre -> (generation) -> post
+    tracker.pre_stream_generate(item_index=1, **common)
+    # Consumer time should now be > 0 (gap between previous post and this pre)
+    assert state.consumer_time_ns > 0
+
+    tracker.post_stream_generate(item={"token": "world"}, item_index=1, **common)
+    assert state.iteration_count == 2
+    assert state.generation_time_ns > first_gen_time
+
+    # Final yield (item=None signals StopIteration)
+    tracker.pre_stream_generate(item_index=2, **common)
+    tracker.post_stream_generate(item=None, item_index=2, **common)
+    # item=None should NOT increment iteration_count
+    assert state.iteration_count == 2
+
+
+def test_pre_stream_generate_no_stream_state_is_noop(tmpdir):
+    """pre/post_stream_generate should silently do nothing when there's no
+    matching stream_state entry (defensive getattr pattern)."""
+    import datetime
+
+    tracker = LocalTrackingClient("test", str(tmpdir))
+    now = datetime.datetime.now()
+    common = dict(
+        stream_initialize_time=now,
+        action="missing",
+        sequence_id=0,
+        app_id="missing",
+        partition_key=None,
+    )
+    # Should not raise
+    tracker.pre_stream_generate(item_index=0, **common)
+    tracker.post_stream_generate(item={"x": 1}, item_index=0, **common)
+
+
+# ---------------------------------------------------------------------------
+# EndStreamModel backwards compatibility tests
+# ---------------------------------------------------------------------------
+
+
+def test_end_stream_model_without_timing_fields():
+    """Old-style EndStreamModel JSON (no timing fields) should parse into the
+    new model with None timing values — backwards compatibility."""
+    old_json = (
+        '{"type":"end_stream","action_sequence_id":1,"span_id":null,'
+        '"end_time":"2024-01-01T00:00:00","items_streamed":10}'
+    )
+    model = EndStreamModel.model_validate_json(old_json)
+    assert model.items_streamed == 10
+    assert model.generation_time_ms is None
+    assert model.consumer_time_ms is None
+    assert model.first_item_time_ms is None
+
+
+def test_end_stream_model_with_timing_fields():
+    """EndStreamModel with timing fields should round-trip through JSON."""
+    import datetime
+
+    model = EndStreamModel(
+        action_sequence_id=1,
+        span_id=None,
+        end_time=datetime.datetime.now(),
+        items_streamed=47,
+        generation_time_ms=245.3,
+        consumer_time_ms=1830.1,
+        first_item_time_ms=52.0,
+    )
+    dumped = model.model_dump_json()
+    restored = EndStreamModel.model_validate_json(dumped)
+    assert restored.generation_time_ms == 245.3
+    assert restored.consumer_time_ms == 1830.1
+    assert restored.first_item_time_ms == 52.0
+
+
+# ---------------------------------------------------------------------------
+# End-to-end streaming test with LocalTrackingClient
+# ---------------------------------------------------------------------------
+
+
+class _SimpleStreamingAction(StreamingAction):
+    """A streaming action that yields a fixed number of items with a small
+    delay to produce measurable generation time."""
+
+    @property
+    def reads(self) -> list[str]:
+        return ["prompt"]
+
+    @property
+    def writes(self) -> list[str]:
+        return ["response"]
+
+    def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None, None]:
+        tokens = state["prompt"].split()
+        for token in tokens:
+            time.sleep(0.01)  # small delay so generation_time_ns > 0
+            yield {"token": token}
+
+    def update(self, result: dict, state: State) -> State:
+        return state.update(response=result.get("token", ""))
+
+
+def _read_end_stream_entries(log_dir: str, project_name: str, app_id: str) -> list[EndStreamModel]:
+    """Return the ``end_stream`` records the tracker logged for ``app_id``.
+
+    Asserts the log file exists, parses every line as JSON, and validates the
+    records whose ``type`` is ``"end_stream"`` as ``EndStreamModel``.
+    """
+    log_path = os.path.join(log_dir, project_name, app_id, LocalTrackingClient.LOG_FILENAME)
+    assert os.path.exists(log_path)
+    with open(log_path) as f:
+        # Iterate the handle directly; no need to materialize readlines().
+        log_lines = [json.loads(line) for line in f]
+    return [
+        EndStreamModel.model_validate(line) for line in log_lines if line["type"] == "end_stream"
+    ]
+
+
+def test_streaming_action_end_to_end_writes_timing(tmpdir):
+    """Integration test: run a streaming action through ApplicationBuilder with
+    a LocalTrackingClient and verify that the end_stream log entry contains
+    non-null timing fields."""
+    app_id = str(uuid.uuid4())
+    log_dir = os.path.join(tmpdir, "tracking")
+    project_name = "test_streaming_timing"
+
+    tracker = LocalTrackingClient(project=project_name, storage_dir=log_dir)
+    app = (
+        ApplicationBuilder()
+        .with_state(prompt="hello world test", response="")
+        .with_actions(generate=_SimpleStreamingAction())
+        .with_transitions()
+        .with_entrypoint("generate")
+        .with_tracker(tracker)
+        .with_identifiers(app_id=app_id)
+        .build()
+    )
+
+    _, streaming_container = app.stream_result(halt_after=["generate"])
+    for _ in streaming_container:
+        time.sleep(0.01)  # simulate consumer processing
+    streaming_container.get()
+
+    # Read the log file and find the end_stream entry
+    end_stream_entries = _read_end_stream_entries(log_dir, project_name, app_id)
+    assert len(end_stream_entries) == 1
+    end_stream = end_stream_entries[0]
+
+    # Verify timing fields are populated (not None)
+    assert end_stream.generation_time_ms is not None
+    assert (
+        end_stream.generation_time_ms > 0
+    ), "generation_time_ms should be > 0 (we slept in stream_run)"
+    assert end_stream.consumer_time_ms is not None
+    assert (
+        end_stream.consumer_time_ms > 0
+    ), "consumer_time_ms should be > 0 (we slept between items)"
+    assert end_stream.first_item_time_ms is not None
+    assert end_stream.first_item_time_ms > 0, "first_item_time_ms (TTFT) should be > 0"
+    # items_streamed is tracked by the existing post_stream_item hook, which
+    # may not count all yields depending on the streaming container semantics.
+    assert end_stream.items_streamed >= 1
+
+
+async def test_async_streaming_action_end_to_end_writes_timing(tmpdir):
+    """Async variant: verify timing fields appear in end_stream log entry.
+
+    NOTE(review): no asyncio marker is visible here -- presumably the project
+    runs pytest-asyncio in auto mode; confirm.
+    """
+
+    @streaming_action(reads=["prompt"], writes=["response"])
+    async def async_generate(state: State):
+        tokens = state["prompt"].split()
+        buffer = []
+        for token in tokens:
+            await asyncio.sleep(0.01)
+            buffer.append(token)
+            yield {"token": token}, None
+        yield {"token": ""}, state.update(response=" ".join(buffer))
+
+    app_id = str(uuid.uuid4())
+    log_dir = os.path.join(tmpdir, "tracking")
+    project_name = "test_async_streaming_timing"
+
+    tracker = LocalTrackingClient(project=project_name, storage_dir=log_dir)
+    app = (
+        ApplicationBuilder()
+        .with_state(prompt="async streaming test tokens", response="")
+        .with_actions(generate=async_generate)
+        .with_transitions()
+        .with_entrypoint("generate")
+        .with_tracker(tracker)
+        .with_identifiers(app_id=app_id)
+        .build()
+    )
+
+    _, streaming_container = await app.astream_result(halt_after=["generate"])
+    async for _ in streaming_container:
+        await asyncio.sleep(0.01)
+    await streaming_container.get()
+
+    end_stream_entries = _read_end_stream_entries(log_dir, project_name, app_id)
+    assert len(end_stream_entries) == 1
+    end_stream = end_stream_entries[0]
+
+    assert end_stream.generation_time_ms is not None
+    assert end_stream.generation_time_ms > 0
+    assert end_stream.consumer_time_ms is not None
+    assert end_stream.consumer_time_ms > 0
+    assert end_stream.first_item_time_ms is not None
+    assert end_stream.first_item_time_ms > 0
+    assert end_stream.items_streamed >= 1
