This is an automated email from the ASF dual-hosted git repository.
skrawcz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/burr.git
The following commit(s) were added to refs/heads/main by this push:
new 8a975b02 core: add AST-based linter for undeclared state reads in
function-based actions (#656)
8a975b02 is described below
commit 8a975b02f8fd76097917de25ae5217ddee4272f5
Author: Smita D Ambiger <[email protected]>
AuthorDate: Sun Mar 1 13:05:12 2026 +0530
core: add AST-based linter for undeclared state reads in function-based actions
(#656)
* core: add AST-based linter for undeclared state reads in function-based
actions
* core: clean up duplicate originating_fn assignment
* fix: address review feedback and add regression tests
---------
Co-authored-by: Smita Ambiger <[email protected]>
---
burr/core/action.py | 58 ++++++++++++++++++++++++++++++++++++++++++++++-
tests/core/test_action.py | 55 ++++++++++++++++++++++++++++++++++++++++++++
2 files changed, 112 insertions(+), 1 deletion(-)
diff --git a/burr/core/action.py b/burr/core/action.py
index 69a7c75b..b2e7c16d 100644
--- a/burr/core/action.py
+++ b/burr/core/action.py
@@ -21,6 +21,7 @@ import builtins
import copy
import inspect
import sys
+import textwrap
import types
import typing
from collections.abc import AsyncIterator
@@ -49,6 +50,56 @@ else:
from typing import Self
from burr.core.state import State
+
+
+def _validate_declared_reads(fn: Callable, declared_reads: list[str]) -> None:
+ if not declared_reads:
+ return
+
+ try:
+ source = inspect.getsource(fn)
+ except OSError:
+ return # skip if source unavailable
+
+ # detect actual state parameter name
+ sig = inspect.signature(fn)
+ state_param_name = None
+
+ for name, param in sig.parameters.items():
+ if param.annotation is State:
+ state_param_name = name
+ break
+
+ if state_param_name is None:
+ return
+
+ tree = ast.parse(textwrap.dedent(source))
+
+ declared = set(declared_reads)
+ violations = []
+
+ class Visitor(ast.NodeVisitor):
+ def visit_Subscript(self, node):
+ if (
+ isinstance(node.value, ast.Name)
+ and node.value.id == state_param_name
+ and isinstance(node.slice, ast.Constant)
+ and isinstance(node.slice.value, str)
+ ):
+ key = node.slice.value
+ if key not in declared:
+ violations.append(key)
+ self.generic_visit(node)
+
+ Visitor().visit(tree)
+
+ if violations:
+ raise ValueError(
+ f"Action reads undeclared state keys: {violations}. "
+ f"Declared reads: {declared_reads}"
+ )
+
+
from burr.core.typing import ActionSchema
# This is here to make accessing the pydantic actions easier
@@ -628,6 +679,8 @@ class FunctionBasedAction(SingleStepAction):
self._fn = fn
self._reads = reads
self._writes = writes
+ _validate_declared_reads(self._originating_fn, self._reads)
+
self._bound_params = bound_params if bound_params is not None else {}
self._inputs = (
derive_inputs_from_fn(self._bound_params, self._fn)
@@ -1106,9 +1159,12 @@ class FunctionBasedStreamingAction(SingleStepStreamingAction):
:param writes:
"""
super(FunctionBasedStreamingAction, self).__init__()
+ self._originating_fn = originating_fn if originating_fn is not None
else fn
self._fn = fn
self._reads = reads
self._writes = writes
+ _validate_declared_reads(self._originating_fn, self._reads)
+
self._bound_params = bound_params if bound_params is not None else {}
self._inputs = (
derive_inputs_from_fn(self._bound_params, self._fn)
@@ -1118,7 +1174,7 @@ class FunctionBasedStreamingAction(SingleStepStreamingAction):
[item for item in input_spec[1] if item not in
self._bound_params],
)
)
- self._originating_fn = originating_fn if originating_fn is not None
else fn
+
self._schema = schema
self._tags = tags if tags is not None else []
diff --git a/tests/core/test_action.py b/tests/core/test_action.py
index fd0ed36b..83fecf3b 100644
--- a/tests/core/test_action.py
+++ b/tests/core/test_action.py
@@ -823,3 +823,58 @@ def test_non_existent_bound_parameters():
required, optional = derive_inputs_from_fn(bound_params, fn)
assert required == []
assert optional == []
+
+
def test_undeclared_state_read_raises_error():
    """Subscripting an undeclared key fails while the decorator runs."""
    # ``match`` pins that the ValueError is the undeclared-read check and that
    # it names the offending key, rather than accepting any ValueError.
    with pytest.raises(ValueError, match=r"undeclared state keys: \['bar'\]"):

        @action(reads=["foo"], writes=[])
        def bad_action(state: State):
            _ = state["bar"]
            return {}, state
+
+
def test_declared_state_read_passes():
    """An action that only subscripts declared keys decorates without error."""

    @action(writes=[], reads=["foo"])
    def reads_only_declared(state: State):
        _ = state["foo"]
        return {}, state
+
+
def test_multiple_undeclared_reads_interleaved():
    """Every undeclared key is reported, even when mixed with declared reads."""
    with pytest.raises(ValueError) as exc_info:

        @action(reads=["foo"], writes=[])
        def mixed_reads(state: State):
            _ = state["foo"]  # declared
            _ = state["bar"]  # undeclared
            _ = state["baz"]  # undeclared
            return {}, state

    reported = str(exc_info.value)
    for key in ("bar", "baz"):
        assert key in reported
+
+
def test_pydantic_action_not_impacted():
    """Pydantic-typed actions still decorate and create successfully.

    The undeclared-read lint only looks for a parameter annotated as
    ``State``. This action's state parameter is annotated with a pydantic
    model instead.
    """
    # pytest's built-in replacement for "try import / except ImportError: skip".
    pydantic = pytest.importorskip("pydantic")

    class MyState(pydantic.BaseModel):
        foo: str

    @action.pydantic(
        reads=["foo"],
        writes=["foo"],
        state_input_type=MyState,
        state_output_type=MyState,
    )
    def good_action(state: MyState):
        return {"foo": state.foo}

    # ensure decoration didn't raise and action is creatable
    from burr.core.action import create_action

    create_action(good_action, name="test")