This is an automated email from the ASF dual-hosted git repository.

skrawcz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/burr.git


The following commit(s) were added to refs/heads/main by this push:
     new 8a975b02 core: add AST-based linter for undeclared state reads in function-bas… (#656)
8a975b02 is described below

commit 8a975b02f8fd76097917de25ae5217ddee4272f5
Author: Smita D Ambiger <[email protected]>
AuthorDate: Sun Mar 1 13:05:12 2026 +0530

    core: add AST-based linter for undeclared state reads in function-bas… (#656)
    
    * core: add AST-based linter for undeclared state reads in function-based actions
    
    * core: cleanup duplicate originating_fn assignment
    
    * fix: address review feedback and add regression tests
    
    ---------
    
    Co-authored-by: Smita Ambiger <[email protected]>
---
 burr/core/action.py       | 58 ++++++++++++++++++++++++++++++++++++++++++++++-
 tests/core/test_action.py | 55 ++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 112 insertions(+), 1 deletion(-)

diff --git a/burr/core/action.py b/burr/core/action.py
index 69a7c75b..b2e7c16d 100644
--- a/burr/core/action.py
+++ b/burr/core/action.py
@@ -21,6 +21,7 @@ import builtins
 import copy
 import inspect
 import sys
+import textwrap
 import types
 import typing
 from collections.abc import AsyncIterator
@@ -49,6 +50,56 @@ else:
     from typing import Self
 
 from burr.core.state import State
+
+
+def _validate_declared_reads(fn: Callable, declared_reads: list[str]) -> None:
+    if not declared_reads:
+        return
+
+    try:
+        source = inspect.getsource(fn)
+    except OSError:
+        return  # skip if source unavailable
+
+    # detect actual state parameter name
+    sig = inspect.signature(fn)
+    state_param_name = None
+
+    for name, param in sig.parameters.items():
+        if param.annotation is State:
+            state_param_name = name
+            break
+
+    if state_param_name is None:
+        return
+
+    tree = ast.parse(textwrap.dedent(source))
+
+    declared = set(declared_reads)
+    violations = []
+
+    class Visitor(ast.NodeVisitor):
+        def visit_Subscript(self, node):
+            if (
+                isinstance(node.value, ast.Name)
+                and node.value.id == state_param_name
+                and isinstance(node.slice, ast.Constant)
+                and isinstance(node.slice.value, str)
+            ):
+                key = node.slice.value
+                if key not in declared:
+                    violations.append(key)
+            self.generic_visit(node)
+
+    Visitor().visit(tree)
+
+    if violations:
+        raise ValueError(
+            f"Action reads undeclared state keys: {violations}. "
+            f"Declared reads: {declared_reads}"
+        )
+
+
 from burr.core.typing import ActionSchema
 
 # This is here to make accessing the pydantic actions easier
@@ -628,6 +679,8 @@ class FunctionBasedAction(SingleStepAction):
         self._fn = fn
         self._reads = reads
         self._writes = writes
+        _validate_declared_reads(self._originating_fn, self._reads)
+
         self._bound_params = bound_params if bound_params is not None else {}
         self._inputs = (
             derive_inputs_from_fn(self._bound_params, self._fn)
@@ -1106,9 +1159,12 @@ class FunctionBasedStreamingAction(SingleStepStreamingAction):
         :param writes:
         """
         super(FunctionBasedStreamingAction, self).__init__()
+        self._originating_fn = originating_fn if originating_fn is not None 
else fn
         self._fn = fn
         self._reads = reads
         self._writes = writes
+        _validate_declared_reads(self._originating_fn, self._reads)
+
         self._bound_params = bound_params if bound_params is not None else {}
         self._inputs = (
             derive_inputs_from_fn(self._bound_params, self._fn)
@@ -1118,7 +1174,7 @@ class FunctionBasedStreamingAction(SingleStepStreamingAction):
                 [item for item in input_spec[1] if item not in 
self._bound_params],
             )
         )
-        self._originating_fn = originating_fn if originating_fn is not None 
else fn
+
         self._schema = schema
         self._tags = tags if tags is not None else []
 
diff --git a/tests/core/test_action.py b/tests/core/test_action.py
index fd0ed36b..83fecf3b 100644
--- a/tests/core/test_action.py
+++ b/tests/core/test_action.py
@@ -823,3 +823,58 @@ def test_non_existent_bound_parameters():
     required, optional = derive_inputs_from_fn(bound_params, fn)
     assert required == []
     assert optional == []
+
+
+def test_undeclared_state_read_raises_error():
+    with pytest.raises(ValueError):
+
+        @action(reads=["foo"], writes=[])
+        def bad_action(state: State):
+            _ = state["bar"]
+            return {}, state
+
+
+def test_declared_state_read_passes():
+    @action(reads=["foo"], writes=[])
+    def good_action(state: State):
+        _ = state["foo"]
+        return {}, state
+
+
+def test_multiple_undeclared_reads_interleaved():
+    with pytest.raises(ValueError) as exc:
+
+        @action(reads=["foo"], writes=[])
+        def bad_action(state: State):
+            _ = state["foo"]
+            _ = state["bar"]
+            _ = state["baz"]
+            return {}, state
+
+    message = str(exc.value)
+    assert "bar" in message
+    assert "baz" in message
+
+
+def test_pydantic_action_not_impacted():
+    try:
+        from pydantic import BaseModel
+    except ImportError:
+        pytest.skip("pydantic not installed")
+
+    class MyState(BaseModel):
+        foo: str
+
+    @action.pydantic(
+        reads=["foo"],
+        writes=["foo"],
+        state_input_type=MyState,
+        state_output_type=MyState,
+    )
+    def good_action(state: MyState):
+        return {"foo": state.foo}
+
+    # ensure decoration didn't raise and action is creatable
+    from burr.core.action import create_action
+
+    create_action(good_action, name="test")

Reply via email to