This is an automated email from the ASF dual-hosted git repository.

skrawcz pushed a commit to branch stefan/fix_nested_otel
in repository https://gitbox.apache.org/repos/asf/burr.git

commit 8d50ef1044a719700bc7507614357aae1d7a0134
Author: Stefan Krawczyk <[email protected]>
AuthorDate: Thu Oct 23 14:39:03 2025 -0700

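    Fix OpenTelemetry and nested Burr application bugs
    
    These bugs were caught while running a nested Burr application with
    OpenTelemetry:
    
    - `_validate_reducer_writes` now wraps `state.keys()` in `set()` before
      the set subtraction, since `state.keys()` may return a list.
    - `BurrTrackingSpanProcessor.on_start`/`on_end` now skip the tracker
      calls when `self.tracker` is `None` (e.g. no tracker is set in the
      current context) instead of raising.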
---
 burr/core/application.py                      |   2 +-
 burr/integrations/opentelemetry.py            |  19 ++--
 tests/core/test_application.py                |  46 +++++++++
 tests/integrations/test_burr_opentelemetry.py | 141 +++++++++++++++++++++++++-
 4 files changed, 197 insertions(+), 11 deletions(-)

diff --git a/burr/core/application.py b/burr/core/application.py
index 9831e1af..1bcea94e 100644
--- a/burr/core/application.py
+++ b/burr/core/application.py
@@ -232,7 +232,7 @@ def _state_update(state_to_modify: State, modified_state: 
State) -> State:
 
 def _validate_reducer_writes(reducer: Reducer, state: State, name: str) -> 
None:
     required_writes = reducer.writes
-    missing_writes = set(reducer.writes) - state.keys()
+    missing_writes = set(reducer.writes) - set(state.keys())
     if len(missing_writes) > 0:
         raise ValueError(
             f"State is missing write keys after running: {name}. Missing keys 
are: {missing_writes}. "
diff --git a/burr/integrations/opentelemetry.py b/burr/integrations/opentelemetry.py
index 78944569..32dc4dd7 100644
--- a/burr/integrations/opentelemetry.py
+++ b/burr/integrations/opentelemetry.py
@@ -499,19 +499,20 @@ class BurrTrackingSpanProcessor(SpanProcessor):
                         app_id=parent_span.app_id,
                     ),
                 )
-                self.tracker.pre_start_span(
-                    action=context.action_span.action,
-                    action_sequence_id=context.action_span.action_sequence_id,
-                    span=context.action_span,
-                    span_dependencies=[],  # TODO -- log
-                    app_id=context.app_id,
-                    partition_key=context.partition_key,
-                )
+                if self.tracker is not None:
+                    self.tracker.pre_start_span(
+                        action=context.action_span.action,
+                        
action_sequence_id=context.action_span.action_sequence_id,
+                        span=context.action_span,
+                        span_dependencies=[],  # TODO -- log
+                        app_id=context.app_id,
+                        partition_key=context.partition_key,
+                    )
 
     def on_end(self, span: "Span") -> None:
         cached_span = get_cached_span(span.get_span_context().span_id)
         # If this is none it means we're outside of the burr context
-        if cached_span is not None:
+        if cached_span is not None and self.tracker is not None:
             # TODO -- get tracker context to work
             self.tracker.post_end_span(
                 action=cached_span.action_span.action,
diff --git a/tests/core/test_application.py b/tests/core/test_application.py
index 5ebd3341..8bab570a 100644
--- a/tests/core/test_application.py
+++ b/tests/core/test_application.py
@@ -486,6 +486,52 @@ def test__run_reducer_deletes_state():
     assert "count" not in state
 
 
+def test__validate_reducer_writes_with_state_keys_returning_list():
+    """Tests that _validate_reducer_writes works when state.keys() returns a 
list.
+
+    This is a regression test for a bug where state.keys() could return a list
+    instead of a set, causing a TypeError when trying to do set subtraction.
+    """
+    from burr.core.application import _validate_reducer_writes
+
+    # Create a reducer with some expected writes
+    reducer = PassedInAction(
+        reads=["input"],
+        writes=["output", "result"],
+        fn=...,
+        update_fn=lambda result, state: state.update(**result),
+        inputs=[],
+    )
+
+    # Create a state that has all the required writes
+    state = State({"input": 1, "output": 2, "result": 3})
+
+    # This should not raise a TypeError even if state.keys() returns a list
+    # (which was the original bug)
+    _validate_reducer_writes(reducer, state, "test_action")
+
+
+def test__validate_reducer_writes_raises_on_missing_keys():
+    """Tests that _validate_reducer_writes raises ValueError when required 
keys are missing."""
+    from burr.core.application import _validate_reducer_writes
+
+    # Create a reducer with some expected writes
+    reducer = PassedInAction(
+        reads=["input"],
+        writes=["output", "result", "missing_key"],
+        fn=...,
+        update_fn=lambda result, state: state.update(**result),
+        inputs=[],
+    )
+
+    # Create a state that is missing some required writes
+    state = State({"input": 1, "output": 2, "result": 3})
+
+    # This should raise a ValueError for missing "missing_key"
+    with pytest.raises(ValueError, match="missing_key"):
+        _validate_reducer_writes(reducer, state, "test_action")
+
+
 async def test__arun_function():
     """Tests that we can run an async function"""
     action = base_counter_action_async
diff --git a/tests/integrations/test_burr_opentelemetry.py b/tests/integrations/test_burr_opentelemetry.py
index e5568202..6fc7ee3a 100644
--- a/tests/integrations/test_burr_opentelemetry.py
+++ b/tests/integrations/test_burr_opentelemetry.py
@@ -16,12 +16,17 @@
 # under the License.
 
 import json
+from unittest.mock import Mock, patch
 
 import pydantic
 import pytest
 
 from burr.core import serde
-from burr.integrations.opentelemetry import convert_to_otel_attribute
+from burr.integrations.opentelemetry import (
+    BurrTrackingSpanProcessor,
+    convert_to_otel_attribute,
+    tracker_context,
+)
 
 
 class SampleModel(pydantic.BaseModel):
@@ -43,3 +48,137 @@ class SampleModel(pydantic.BaseModel):
 )
 def test_convert_to_otel_attribute(value, expected):
     assert convert_to_otel_attribute(value) == expected
+
+
+def test_burr_tracking_span_processor_on_start_with_none_tracker():
+    """Test that on_start handles None tracker gracefully without raising an 
error."""
+    processor = BurrTrackingSpanProcessor()
+
+    # Mock a span with a parent
+    mock_span = Mock()
+    mock_span.parent = Mock()
+    mock_span.parent.span_id = 12345
+    mock_span.name = "test_span"
+
+    # Mock the get_cached_span to return a parent span context
+    with patch("burr.integrations.opentelemetry.get_cached_span") as 
mock_get_cached:
+        mock_parent_context = Mock()
+        mock_parent_context.action_span = Mock()
+        mock_parent_context.action_span.spawn = Mock(return_value=Mock())
+        mock_parent_context.partition_key = "test_partition"
+        mock_parent_context.app_id = "test_app"
+        mock_get_cached.return_value = mock_parent_context
+
+        # Mock cache_span
+        with patch("burr.integrations.opentelemetry.cache_span"):
+            # Set tracker_context to None (simulating no tracker in context)
+            token = tracker_context.set(None)
+            try:
+                # This should not raise an error even though tracker is None
+                processor.on_start(mock_span, parent_context=None)
+            finally:
+                tracker_context.reset(token)
+
+
+def test_burr_tracking_span_processor_on_end_with_none_tracker():
+    """Test that on_end handles None tracker gracefully without raising an 
error."""
+    processor = BurrTrackingSpanProcessor()
+
+    # Mock a span
+    mock_span = Mock()
+    mock_span_context = Mock()
+    mock_span_context.span_id = 67890
+    mock_span.get_span_context = Mock(return_value=mock_span_context)
+    mock_span.attributes = {}
+
+    # Mock the get_cached_span to return a cached span
+    with patch("burr.integrations.opentelemetry.get_cached_span") as 
mock_get_cached:
+        mock_cached_span = Mock()
+        mock_cached_span.action_span = Mock()
+        mock_cached_span.action_span.action = "test_action"
+        mock_cached_span.action_span.action_sequence_id = 1
+        mock_cached_span.app_id = "test_app"
+        mock_cached_span.partition_key = "test_partition"
+        mock_get_cached.return_value = mock_cached_span
+
+        # Mock uncache_span
+        with patch("burr.integrations.opentelemetry.uncache_span"):
+            # Set tracker_context to None (simulating no tracker in context)
+            token = tracker_context.set(None)
+            try:
+                # This should not raise an error even though tracker is None
+                processor.on_end(mock_span)
+            finally:
+                tracker_context.reset(token)
+
+
+def test_burr_tracking_span_processor_on_start_with_valid_tracker():
+    """Test that on_start calls tracker methods when tracker is available."""
+    processor = BurrTrackingSpanProcessor()
+
+    # Mock a span with a parent
+    mock_span = Mock()
+    mock_span.parent = Mock()
+    mock_span.parent.span_id = 12345
+    mock_span.name = "test_span"
+
+    # Mock tracker
+    mock_tracker = Mock()
+
+    # Mock the get_cached_span to return a parent span context
+    with patch("burr.integrations.opentelemetry.get_cached_span") as 
mock_get_cached:
+        mock_parent_context = Mock()
+        mock_parent_context.action_span = Mock()
+        mock_parent_context.action_span.spawn = 
Mock(return_value=Mock(action="test_action"))
+        mock_parent_context.partition_key = "test_partition"
+        mock_parent_context.app_id = "test_app"
+        mock_get_cached.return_value = mock_parent_context
+
+        # Mock cache_span
+        with patch("burr.integrations.opentelemetry.cache_span"):
+            # Set tracker_context to a valid tracker
+            token = tracker_context.set(mock_tracker)
+            try:
+                processor.on_start(mock_span, parent_context=None)
+
+                # Verify that pre_start_span was called on the tracker
+                assert mock_tracker.pre_start_span.called
+            finally:
+                tracker_context.reset(token)
+
+
+def test_burr_tracking_span_processor_on_end_with_valid_tracker():
+    """Test that on_end calls tracker methods when tracker is available."""
+    processor = BurrTrackingSpanProcessor()
+
+    # Mock a span
+    mock_span = Mock()
+    mock_span_context = Mock()
+    mock_span_context.span_id = 67890
+    mock_span.get_span_context = Mock(return_value=mock_span_context)
+    mock_span.attributes = {}
+
+    # Mock tracker
+    mock_tracker = Mock()
+
+    # Mock the get_cached_span to return a cached span
+    with patch("burr.integrations.opentelemetry.get_cached_span") as 
mock_get_cached:
+        mock_cached_span = Mock()
+        mock_cached_span.action_span = Mock()
+        mock_cached_span.action_span.action = "test_action"
+        mock_cached_span.action_span.action_sequence_id = 1
+        mock_cached_span.app_id = "test_app"
+        mock_cached_span.partition_key = "test_partition"
+        mock_get_cached.return_value = mock_cached_span
+
+        # Mock uncache_span
+        with patch("burr.integrations.opentelemetry.uncache_span"):
+            # Set tracker_context to a valid tracker
+            token = tracker_context.set(mock_tracker)
+            try:
+                processor.on_end(mock_span)
+
+                # Verify that post_end_span was called on the tracker
+                assert mock_tracker.post_end_span.called
+            finally:
+                tracker_context.reset(token)

Reply via email to