This is an automated email from the ASF dual-hosted git repository.
skrawcz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/burr.git
The following commit(s) were added to refs/heads/main by this push:
new fc2f1447 Fixes some otel and nested burr application bugs (#590)
fc2f1447 is described below
commit fc2f14477521050ef1cba481d7c8118adfbd5925
Author: Stefan Krawczyk <[email protected]>
AuthorDate: Sun Nov 9 20:15:56 2025 -0800
Fixes some otel and nested burr application bugs (#590)
* Fixes some otel and nested burr application bugs
This was caught while trying to run a nested Burr application with OpenTelemetry (otel) enabled.
* Updates test structure
It is important that mocks use `spec=` where possible.
---
burr/core/application.py | 2 +-
burr/integrations/opentelemetry.py | 19 ++--
tests/core/test_application.py | 43 ++++++++
tests/integrations/test_burr_opentelemetry.py | 151 +++++++++++++++++++++++++-
4 files changed, 204 insertions(+), 11 deletions(-)
diff --git a/burr/core/application.py b/burr/core/application.py
index 0450d983..55f98acf 100644
--- a/burr/core/application.py
+++ b/burr/core/application.py
@@ -232,7 +232,7 @@ def _state_update(state_to_modify: State, modified_state:
State) -> State:
def _validate_reducer_writes(reducer: Reducer, state: State, name: str) ->
None:
required_writes = reducer.writes
- missing_writes = set(reducer.writes) - state.keys()
+ missing_writes = set(reducer.writes) - set(state.keys())
if len(missing_writes) > 0:
raise ValueError(
f"State is missing write keys after running: {name}. Missing keys
are: {missing_writes}. "
diff --git a/burr/integrations/opentelemetry.py
b/burr/integrations/opentelemetry.py
index 78944569..32dc4dd7 100644
--- a/burr/integrations/opentelemetry.py
+++ b/burr/integrations/opentelemetry.py
@@ -499,19 +499,20 @@ class BurrTrackingSpanProcessor(SpanProcessor):
app_id=parent_span.app_id,
),
)
- self.tracker.pre_start_span(
- action=context.action_span.action,
- action_sequence_id=context.action_span.action_sequence_id,
- span=context.action_span,
- span_dependencies=[], # TODO -- log
- app_id=context.app_id,
- partition_key=context.partition_key,
- )
+ if self.tracker is not None:
+ self.tracker.pre_start_span(
+ action=context.action_span.action,
+
action_sequence_id=context.action_span.action_sequence_id,
+ span=context.action_span,
+ span_dependencies=[], # TODO -- log
+ app_id=context.app_id,
+ partition_key=context.partition_key,
+ )
def on_end(self, span: "Span") -> None:
cached_span = get_cached_span(span.get_span_context().span_id)
# If this is none it means we're outside of the burr context
- if cached_span is not None:
+ if cached_span is not None and self.tracker is not None:
# TODO -- get tracker context to work
self.tracker.post_end_span(
action=cached_span.action_span.action,
diff --git a/tests/core/test_application.py b/tests/core/test_application.py
index 5ebd3341..7e0073dc 100644
--- a/tests/core/test_application.py
+++ b/tests/core/test_application.py
@@ -57,6 +57,7 @@ from burr.core.application import (
_run_reducer,
_run_single_step_action,
_run_single_step_streaming_action,
+ _validate_reducer_writes,
_validate_start,
)
from burr.core.graph import Graph, GraphBuilder, Transition
@@ -486,6 +487,48 @@ def test__run_reducer_deletes_state():
assert "count" not in state
+def test__validate_reducer_writes_with_state_keys_returning_list():
+ """Tests that _validate_reducer_writes works when state.keys() returns a
list.
+
+ This is a regression test for a bug where state.keys() could return a list
+ instead of a set, causing a TypeError when trying to do set subtraction.
+ """
+ # Create a reducer with some expected writes
+ reducer = PassedInAction(
+ reads=["input"],
+ writes=["output", "result"],
+ fn=...,
+ update_fn=lambda result, state: state.update(**result),
+ inputs=[],
+ )
+
+ # Create a state that has all the required writes
+ state = State({"input": 1, "output": 2, "result": 3})
+
+ # This should not raise a TypeError even if state.keys() returns a list
+ # (which was the original bug)
+ _validate_reducer_writes(reducer, state, "test_action")
+
+
+def test__validate_reducer_writes_raises_on_missing_keys():
+ """Tests that _validate_reducer_writes raises ValueError when required
keys are missing."""
+ # Create a reducer with some expected writes
+ reducer = PassedInAction(
+ reads=["input"],
+ writes=["output", "result", "missing_key"],
+ fn=...,
+ update_fn=lambda result, state: state.update(**result),
+ inputs=[],
+ )
+
+ # Create a state that is missing some required writes
+ state = State({"input": 1, "output": 2, "result": 3})
+
+ # This should raise a ValueError for missing "missing_key"
+ with pytest.raises(ValueError, match="missing_key"):
+ _validate_reducer_writes(reducer, state, "test_action")
+
+
async def test__arun_function():
"""Tests that we can run an async function"""
action = base_counter_action_async
diff --git a/tests/integrations/test_burr_opentelemetry.py
b/tests/integrations/test_burr_opentelemetry.py
index e5568202..9c5f640c 100644
--- a/tests/integrations/test_burr_opentelemetry.py
+++ b/tests/integrations/test_burr_opentelemetry.py
@@ -16,12 +16,22 @@
# under the License.
import json
+from unittest.mock import Mock, patch
import pydantic
import pytest
+from opentelemetry.sdk.trace import Span
+from opentelemetry.trace import SpanContext
from burr.core import serde
-from burr.integrations.opentelemetry import convert_to_otel_attribute
+from burr.integrations.opentelemetry import (
+ BurrTrackingSpanProcessor,
+ FullSpanContext,
+ convert_to_otel_attribute,
+ tracker_context,
+)
+from burr.tracking.base import SyncTrackingClient
+from burr.visibility import ActionSpan
class SampleModel(pydantic.BaseModel):
@@ -43,3 +53,142 @@ class SampleModel(pydantic.BaseModel):
)
def test_convert_to_otel_attribute(value, expected):
assert convert_to_otel_attribute(value) == expected
+
+
+def test_burr_tracking_span_processor_on_start_with_none_tracker():
+ """Test that on_start handles None tracker gracefully without raising an
error."""
+ processor = BurrTrackingSpanProcessor()
+
+ # Mock a span with a parent
+ mock_span = Mock(spec=Span)
+ mock_span.parent = Mock()
+ mock_span.parent.span_id = 12345
+ mock_span.name = "test_span"
+
+ # Mock the get_cached_span to return a parent span context
+ with patch("burr.integrations.opentelemetry.get_cached_span") as
mock_get_cached:
+ mock_parent_context = Mock(spec=FullSpanContext)
+ mock_parent_context.action_span = Mock(spec=ActionSpan)
+ mock_spawned_span = Mock(spec=ActionSpan)
+ mock_parent_context.action_span.spawn =
Mock(return_value=mock_spawned_span)
+ mock_parent_context.partition_key = "test_partition"
+ mock_parent_context.app_id = "test_app"
+ mock_get_cached.return_value = mock_parent_context
+
+ # Mock cache_span
+ with patch("burr.integrations.opentelemetry.cache_span"):
+ # Set tracker_context to None (simulating no tracker in context)
+ token = tracker_context.set(None)
+ try:
+ # This should not raise an error even though tracker is None
+ processor.on_start(mock_span, parent_context=None)
+ finally:
+ tracker_context.reset(token)
+
+
+def test_burr_tracking_span_processor_on_end_with_none_tracker():
+ """Test that on_end handles None tracker gracefully without raising an
error."""
+ processor = BurrTrackingSpanProcessor()
+
+ # Mock a span
+ mock_span = Mock(spec=Span)
+ mock_span_context = Mock(spec=SpanContext)
+ mock_span_context.span_id = 67890
+ mock_span.get_span_context = Mock(return_value=mock_span_context)
+ mock_span.attributes = {}
+
+ # Mock the get_cached_span to return a cached span
+ with patch("burr.integrations.opentelemetry.get_cached_span") as
mock_get_cached:
+ mock_cached_span = Mock(spec=FullSpanContext)
+ mock_cached_span.action_span = Mock(spec=ActionSpan)
+ mock_cached_span.action_span.action = "test_action"
+ mock_cached_span.action_span.action_sequence_id = 1
+ mock_cached_span.app_id = "test_app"
+ mock_cached_span.partition_key = "test_partition"
+ mock_get_cached.return_value = mock_cached_span
+
+ # Mock uncache_span
+ with patch("burr.integrations.opentelemetry.uncache_span"):
+ # Set tracker_context to None (simulating no tracker in context)
+ token = tracker_context.set(None)
+ try:
+ # This should not raise an error even though tracker is None
+ processor.on_end(mock_span)
+ finally:
+ tracker_context.reset(token)
+
+
+def test_burr_tracking_span_processor_on_start_with_valid_tracker():
+ """Test that on_start calls tracker methods when tracker is available."""
+ processor = BurrTrackingSpanProcessor()
+
+ # Mock a span with a parent
+ mock_span = Mock(spec=Span)
+ mock_span.parent = Mock()
+ mock_span.parent.span_id = 12345
+ mock_span.name = "test_span"
+
+ # Mock tracker
+ mock_tracker = Mock(spec=SyncTrackingClient)
+
+ # Mock the get_cached_span to return a parent span context
+ with patch("burr.integrations.opentelemetry.get_cached_span") as
mock_get_cached:
+ mock_parent_context = Mock(spec=FullSpanContext)
+ mock_parent_action_span = Mock(spec=ActionSpan)
+ mock_spawned_span = Mock(spec=ActionSpan)
+ mock_spawned_span.action = "test_action"
+ mock_spawned_span.action_sequence_id = 1
+ mock_parent_action_span.spawn = Mock(return_value=mock_spawned_span)
+ mock_parent_context.action_span = mock_parent_action_span
+ mock_parent_context.partition_key = "test_partition"
+ mock_parent_context.app_id = "test_app"
+ mock_get_cached.return_value = mock_parent_context
+
+ # Mock cache_span
+ with patch("burr.integrations.opentelemetry.cache_span"):
+ # Set tracker_context to a valid tracker
+ token = tracker_context.set(mock_tracker)
+ try:
+ processor.on_start(mock_span, parent_context=None)
+
+ # Verify that pre_start_span was called on the tracker
+ assert mock_tracker.pre_start_span.called
+ finally:
+ tracker_context.reset(token)
+
+
+def test_burr_tracking_span_processor_on_end_with_valid_tracker():
+ """Test that on_end calls tracker methods when tracker is available."""
+ processor = BurrTrackingSpanProcessor()
+
+ # Mock a span
+ mock_span = Mock(spec=Span)
+ mock_span_context = Mock(spec=SpanContext)
+ mock_span_context.span_id = 67890
+ mock_span.get_span_context = Mock(return_value=mock_span_context)
+ mock_span.attributes = {}
+
+ # Mock tracker
+ mock_tracker = Mock(spec=SyncTrackingClient)
+
+ # Mock the get_cached_span to return a cached span
+ with patch("burr.integrations.opentelemetry.get_cached_span") as
mock_get_cached:
+ mock_cached_span = Mock(spec=FullSpanContext)
+ mock_cached_span.action_span = Mock(spec=ActionSpan)
+ mock_cached_span.action_span.action = "test_action"
+ mock_cached_span.action_span.action_sequence_id = 1
+ mock_cached_span.app_id = "test_app"
+ mock_cached_span.partition_key = "test_partition"
+ mock_get_cached.return_value = mock_cached_span
+
+ # Mock uncache_span
+ with patch("burr.integrations.opentelemetry.uncache_span"):
+ # Set tracker_context to a valid tracker
+ token = tracker_context.set(mock_tracker)
+ try:
+ processor.on_end(mock_span)
+
+ # Verify that post_end_span was called on the tracker
+ assert mock_tracker.post_end_span.called
+ finally:
+ tracker_context.reset(token)