This is an automated email from the ASF dual-hosted git repository.
skrawcz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/burr.git
The following commit(s) were added to refs/heads/main by this push:
new f884adff core: add flexible_api decorator to fix mypy override errors
(#683)
f884adff is described below
commit f884adff8d5b0afd0e813f2c46b6196c060a9649
Author: André Ahlert <[email protected]>
AuthorDate: Sat Mar 28 16:16:44 2026 -0300
core: add flexible_api decorator to fix mypy override errors (#683)
* core: add flexible_api decorator to fix mypy override errors on
class-based actions
Adds a `flexible_api` decorator that users can apply to `run`,
`stream_run`, or `run_and_update` overrides that use explicit
parameters instead of `**run_kwargs`. This prevents mypy [override]
errors caused by narrowing the base-class signature.
Closes #457
* style: apply black formatting
* Rename flexible_api to type_eraser and fix async support
Address review feedback: rename decorator per maintainer suggestion.
Fix critical bug where wrapping async/generator functions in a plain sync
wrapper (even with @wraps) broke is_async() detection; the decorator now
builds a wrapper of the matching kind. Add tests for all function types.
---
burr/core/__init__.py | 3 +-
burr/core/action.py | 67 ++++++++++++++++++++++
tests/core/test_action.py | 141 ++++++++++++++++++++++++++++++++++++++++++++++
3 files changed, 210 insertions(+), 1 deletion(-)
diff --git a/burr/core/__init__.py b/burr/core/__init__.py
index c4da5a48..aa2f75a4 100644
--- a/burr/core/__init__.py
+++ b/burr/core/__init__.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-from burr.core.action import Action, Condition, Result, action, default, expr,
when
+from burr.core.action import Action, Condition, Result, action, default, expr,
type_eraser, when
from burr.core.application import (
Application,
ApplicationBuilder,
@@ -35,6 +35,7 @@ __all__ = [
"Condition",
"default",
"expr",
+ "type_eraser",
"Result",
"State",
"when",
diff --git a/burr/core/action.py b/burr/core/action.py
index df15b7f0..eb6fceff 100644
--- a/burr/core/action.py
+++ b/burr/core/action.py
@@ -100,8 +100,75 @@ def _validate_declared_reads(fn: Callable, declared_reads:
list[str]) -> None:
)
+from functools import wraps
+
from burr.core.typing import ActionSchema
+
def type_eraser(func: Callable[..., Any]) -> Callable[..., Any]:
    """Decorator for ``run``, ``stream_run``, and ``run_and_update`` overrides
    that declare explicit parameters instead of ``**run_kwargs``.

    Applying this decorator prevents mypy ``[override]`` errors caused by
    narrowing the base-class signature (which uses ``**run_kwargs``). The
    decorator is annotated to return ``Callable[..., Any]``, so type checkers
    see the decorated method as accepting any arguments. At runtime every
    call is forwarded to ``func`` unchanged.

    The returned wrapper has the same kind as ``func``: coroutine function,
    async generator function, generator function, or plain function. This
    keeps ``inspect``-based detection (e.g. ``Action.is_async()``) working.
    ``functools.wraps`` copies ``__name__``/``__doc__`` and sets
    ``__wrapped__`` to ``func``.

    Example usage::

        from burr.core import Action, State, type_eraser

        class Counter(Action):
            @property
            def reads(self) -> list[str]:
                return ["counter"]

            @type_eraser
            def run(self, state: State, increment_by: int) -> dict:
                return {"counter": state["counter"] + increment_by}

            @property
            def writes(self) -> list[str]:
                return ["counter"]

            def update(self, result: dict, state: State) -> State:
                return state.update(**result)

            @property
            def inputs(self) -> list[str]:
                return ["increment_by"]

    :param func: The method (or function) whose static signature should be erased.
    :return: A wrapper of the same function kind that forwards all positional
        and keyword arguments to ``func``.
    """

    if inspect.iscoroutinefunction(func):

        @wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            return await func(*args, **kwargs)

        return async_wrapper

    if inspect.isasyncgenfunction(func):

        @wraps(func)
        async def async_gen_wrapper(*args: Any, **kwargs: Any) -> Any:
            inner = func(*args, **kwargs)
            try:
                async for item in inner:
                    yield item
            finally:
                # Async generators have no ``yield from``. Close the inner
                # generator explicitly so that if the consumer stops early
                # (``aclose()``/``athrow()``), ``func``'s own ``finally``
                # blocks run now rather than whenever it is finalized.
                # NOTE: values sent via ``asend()`` are not forwarded to ``func``.
                await inner.aclose()

        return async_gen_wrapper

    if inspect.isgeneratorfunction(func):

        @wraps(func)
        def gen_wrapper(*args: Any, **kwargs: Any) -> Any:
            # ``yield from`` also delegates send()/throw()/close() to the inner generator.
            yield from func(*args, **kwargs)

        return gen_wrapper

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        return func(*args, **kwargs)

    return wrapper
+
+
# This is here to make accessing the pydantic actions easier
# we just attach them to action so you can call `@action.pyddantic...`
# The IDE will like it better and thus be able to auto-complete/type-check
diff --git a/tests/core/test_action.py b/tests/core/test_action.py
index f65a0683..c58a2989 100644
--- a/tests/core/test_action.py
+++ b/tests/core/test_action.py
@@ -38,6 +38,7 @@ from burr.core.action import (
default,
derive_inputs_from_fn,
streaming_action,
+ type_eraser,
)
@@ -1043,3 +1044,143 @@ def test_pydantic_action_not_impacted():
from burr.core.action import create_action
create_action(good_action, name="test")
+
+
class TestTypeEraser:
    """Tests for ``type_eraser``.

    Each test defines an action whose ``run``/``stream_run`` override replaces
    the base-class ``**run_kwargs`` with an explicit parameter. It then checks
    that the decorated method still behaves like the undecorated one.
    """

    def test_sync_run(self):
        """A decorated sync ``run`` returns its result and is not reported as async."""

        class MyAction(Action):
            @property
            def reads(self) -> list[str]:
                return ["counter"]

            @type_eraser
            def run(self, state: State, increment_by: int) -> dict:
                return {"counter": state["counter"] + increment_by}

            @property
            def writes(self) -> list[str]:
                return ["counter"]

            def update(self, result: dict, state: State) -> State:
                return state.update(**result)

            @property
            def inputs(self) -> list[str]:
                return ["increment_by"]

        a = MyAction()
        result = a.run(State({"counter": 0}), increment_by=5)
        assert result == {"counter": 5}
        assert a.is_async() is False

    def test_async_run(self):
        """A decorated ``async def run`` is still detected as async and awaitable."""

        class MyAsyncAction(Action):
            @property
            def reads(self) -> list[str]:
                return ["counter"]

            @type_eraser
            async def run(self, state: State, increment_by: int) -> dict:
                return {"counter": state["counter"] + increment_by}

            @property
            def writes(self) -> list[str]:
                return ["counter"]

            def update(self, result: dict, state: State) -> State:
                return state.update(**result)

            @property
            def inputs(self) -> list[str]:
                return ["increment_by"]

        a = MyAsyncAction()
        assert a.is_async() is True
        result = asyncio.run(a.run(State({"counter": 0}), increment_by=5))
        assert result == {"counter": 5}

    def test_sync_stream_run(self):
        """A decorated sync generator ``stream_run`` yields every item."""

        class MyStreamingAction(StreamingAction):
            @property
            def reads(self) -> list[str]:
                return ["items"]

            @type_eraser
            def stream_run(self, state: State, prefix: str) -> Generator[dict, None, None]:
                for item in state["items"]:
                    yield {"val": f"{prefix}_{item}"}

            @property
            def writes(self) -> list[str]:
                # update() stores the "val" key that stream_run yields.
                return ["val"]

            def update(self, result: dict, state: State) -> State:
                return state.update(**result)

            @property
            def inputs(self) -> list[str]:
                return ["prefix"]

        a = MyStreamingAction()
        results = list(a.stream_run(State({"items": ["a", "b"]}), prefix="x"))
        assert results == [{"val": "x_a"}, {"val": "x_b"}]

    def test_async_stream_run(self):
        """A decorated async generator ``stream_run`` is async and yields every item."""

        class MyAsyncStreamingAction(AsyncStreamingAction):
            @property
            def reads(self) -> list[str]:
                return ["items"]

            @type_eraser
            async def stream_run(self, state: State, prefix: str) -> AsyncGenerator[dict, None]:
                for item in state["items"]:
                    yield {"val": f"{prefix}_{item}"}

            @property
            def writes(self) -> list[str]:
                # update() stores the "val" key that stream_run yields.
                return ["val"]

            def update(self, result: dict, state: State) -> State:
                return state.update(**result)

            @property
            def inputs(self) -> list[str]:
                return ["prefix"]

        a = MyAsyncStreamingAction()
        assert a.is_async() is True

        async def collect():
            return [item async for item in a.stream_run(State({"items": ["a", "b"]}), prefix="x")]

        results = asyncio.run(collect())
        assert results == [{"val": "x_a"}, {"val": "x_b"}]

    def test_preserves_wrapped_name(self):
        """``functools.wraps`` keeps the method name and exposes ``__wrapped__``."""

        class MyAction(Action):
            @property
            def reads(self) -> list[str]:
                return []

            @type_eraser
            def run(self, state: State, custom_param: str) -> dict:
                return {}

            @property
            def writes(self) -> list[str]:
                return []

            def update(self, result: dict, state: State) -> State:
                return state

            @property
            def inputs(self) -> list[str]:
                return []

        assert MyAction().run.__name__ == "run"
        assert MyAction().run.__wrapped__.__name__ == "run"

    def test_exported_from_burr_core(self):
        """``type_eraser`` is re-exported from ``burr.core``."""
        from burr.core import type_eraser as te

        assert te is type_eraser