This is an automated email from the ASF dual-hosted git repository.

skrawcz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/burr.git


The following commit(s) were added to refs/heads/main by this push:
     new f884adff core: add flexible_api decorator to fix mypy override errors (#683)
f884adff is described below

commit f884adff8d5b0afd0e813f2c46b6196c060a9649
Author: André Ahlert <[email protected]>
AuthorDate: Sat Mar 28 16:16:44 2026 -0300

    core: add flexible_api decorator to fix mypy override errors (#683)
    
    * core: add flexible_api decorator to fix mypy override errors on class-based actions
    
    Adds a `flexible_api` decorator that users can apply to `run`,
    `stream_run`, or `run_and_update` overrides that use explicit
    parameters instead of `**run_kwargs`. This prevents mypy [override]
    errors caused by narrowing the base-class signature.
    
    Closes #457
    
    * style: apply black formatting
    
    * Rename flexible_api to type_eraser and fix async support
    
    Address review feedback: rename decorator per maintainer suggestion.
    Fix critical bug where @wraps on async/generator functions broke
    is_async() detection. Add tests for all function types.
---
 burr/core/__init__.py     |   3 +-
 burr/core/action.py       |  67 ++++++++++++++++++++++
 tests/core/test_action.py | 141 ++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 210 insertions(+), 1 deletion(-)

diff --git a/burr/core/__init__.py b/burr/core/__init__.py
index c4da5a48..aa2f75a4 100644
--- a/burr/core/__init__.py
+++ b/burr/core/__init__.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from burr.core.action import Action, Condition, Result, action, default, expr, when
+from burr.core.action import Action, Condition, Result, action, default, expr, type_eraser, when
 from burr.core.application import (
     Application,
     ApplicationBuilder,
@@ -35,6 +35,7 @@ __all__ = [
     "Condition",
     "default",
     "expr",
+    "type_eraser",
     "Result",
     "State",
     "when",
diff --git a/burr/core/action.py b/burr/core/action.py
index df15b7f0..eb6fceff 100644
--- a/burr/core/action.py
+++ b/burr/core/action.py
@@ -100,8 +100,75 @@ def _validate_declared_reads(fn: Callable, declared_reads: list[str]) -> None:
         )
 
 
+from functools import wraps
+
 from burr.core.typing import ActionSchema
 
+
+def type_eraser(func: Callable[..., Any]) -> Callable[..., Any]:
+    """Decorator for ``run``, ``stream_run``, and ``run_and_update`` overrides
+    that declare explicit parameters instead of ``**run_kwargs``.
+
+    Its ``Callable[..., Any]`` return type hides the narrower signature from
+    mypy, so no ``[override]`` error is reported against the ``**run_kwargs`` base.
+
+    Example usage::
+
+        from burr.core import Action, State, type_eraser
+
+        class Counter(Action):
+            @property
+            def reads(self) -> list[str]:
+                return ["counter"]
+
+            @type_eraser
+            def run(self, state: State, increment_by: int) -> dict:
+                return {"counter": state["counter"] + increment_by}
+
+            @property
+            def writes(self) -> list[str]:
+                return ["counter"]
+
+            def update(self, result: dict, state: State) -> State:
+                return state.update(**result)
+
+            @property
+            def inputs(self) -> list[str]:
+                return ["increment_by"]
+    """
+    # wraps() copies metadata only, so each wrapper must itself match func's kind.
+    if inspect.iscoroutinefunction(func):  # async def
+
+        @wraps(func)
+        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:  # keeps is_async() True
+            return await func(*args, **kwargs)
+
+        return async_wrapper
+
+    if inspect.isasyncgenfunction(func):  # async def containing yield
+
+        @wraps(func)
+        async def async_gen_wrapper(*args: Any, **kwargs: Any) -> Any:  # re-yields every item
+            async for item in func(*args, **kwargs):
+                yield item
+
+        return async_gen_wrapper
+
+    if inspect.isgeneratorfunction(func):  # def containing yield
+
+        @wraps(func)
+        def gen_wrapper(*args: Any, **kwargs: Any) -> Any:  # delegates via yield from
+            yield from func(*args, **kwargs)
+
+        return gen_wrapper
+
+    @wraps(func)
+    def wrapper(*args: Any, **kwargs: Any) -> Any:  # plain synchronous def
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
 # This is here to make accessing the pydantic actions easier
 # we just attach them to action so you can call `@action.pyddantic...`
 # The IDE will like it better and thus be able to auto-complete/type-check
diff --git a/tests/core/test_action.py b/tests/core/test_action.py
index f65a0683..c58a2989 100644
--- a/tests/core/test_action.py
+++ b/tests/core/test_action.py
@@ -38,6 +38,7 @@ from burr.core.action import (
     default,
     derive_inputs_from_fn,
     streaming_action,
+    type_eraser,
 )
 
 
@@ -1043,3 +1044,143 @@ def test_pydantic_action_not_impacted():
     from burr.core.action import create_action
 
     create_action(good_action, name="test")
+
+
+class TestTypeEraser:
+    def test_sync_run(self):
+        class MyAction(Action):
+            @property
+            def reads(self) -> list[str]:
+                return ["counter"]
+
+            @type_eraser
+            def run(self, state: State, increment_by: int) -> dict:
+                return {"counter": state["counter"] + increment_by}
+
+            @property
+            def writes(self) -> list[str]:
+                return ["counter"]
+
+            def update(self, result: dict, state: State) -> State:
+                return state.update(**result)
+
+            @property
+            def inputs(self) -> list[str]:
+                return ["increment_by"]
+
+        a = MyAction()
+        result = a.run(State({"counter": 0}), increment_by=5)
+        assert result == {"counter": 5}
+        assert a.is_async() is False
+
+    def test_async_run(self):
+        class MyAsyncAction(Action):
+            @property
+            def reads(self) -> list[str]:
+                return ["counter"]
+
+            @type_eraser
+            async def run(self, state: State, increment_by: int) -> dict:
+                return {"counter": state["counter"] + increment_by}
+
+            @property
+            def writes(self) -> list[str]:
+                return ["counter"]
+
+            def update(self, result: dict, state: State) -> State:
+                return state.update(**result)
+
+            @property
+            def inputs(self) -> list[str]:
+                return ["increment_by"]
+
+        a = MyAsyncAction()
+        assert a.is_async() is True
+        result = asyncio.run(a.run(State({"counter": 0}), increment_by=5))
+        assert result == {"counter": 5}
+
+    def test_sync_stream_run(self):
+        class MyStreamingAction(StreamingAction):
+            @property
+            def reads(self) -> list[str]:
+                return ["items"]
+
+            @type_eraser
+            def stream_run(self, state: State, prefix: str) -> Generator[dict, 
None, None]:
+                for item in state["items"]:
+                    yield {"val": f"{prefix}_{item}"}
+
+            @property
+            def writes(self) -> list[str]:
+                return ["result"]
+
+            def update(self, result: dict, state: State) -> State:
+                return state.update(**result)
+
+            @property
+            def inputs(self) -> list[str]:
+                return ["prefix"]
+
+        a = MyStreamingAction()
+        results = list(a.stream_run(State({"items": ["a", "b"]}), prefix="x"))
+        assert results == [{"val": "x_a"}, {"val": "x_b"}]
+
+    def test_async_stream_run(self):
+        class MyAsyncStreamingAction(AsyncStreamingAction):
+            @property
+            def reads(self) -> list[str]:
+                return ["items"]
+
+            @type_eraser
+            async def stream_run(self, state: State, prefix: str) -> 
AsyncGenerator[dict, None]:
+                for item in state["items"]:
+                    yield {"val": f"{prefix}_{item}"}
+
+            @property
+            def writes(self) -> list[str]:
+                return ["result"]
+
+            def update(self, result: dict, state: State) -> State:
+                return state.update(**result)
+
+            @property
+            def inputs(self) -> list[str]:
+                return ["prefix"]
+
+        a = MyAsyncStreamingAction()
+        assert a.is_async() is True
+
+        async def collect():
+            return [item async for item in a.stream_run(State({"items": ["a", 
"b"]}), prefix="x")]
+
+        results = asyncio.run(collect())
+        assert results == [{"val": "x_a"}, {"val": "x_b"}]
+
+    def test_preserves_wrapped_name(self):
+        class MyAction(Action):
+            @property
+            def reads(self) -> list[str]:
+                return []
+
+            @type_eraser
+            def run(self, state: State, custom_param: str) -> dict:
+                return {}
+
+            @property
+            def writes(self) -> list[str]:
+                return []
+
+            def update(self, result: dict, state: State) -> State:
+                return state
+
+            @property
+            def inputs(self) -> list[str]:
+                return []
+
+        assert MyAction().run.__name__ == "run"
+        assert MyAction().run.__wrapped__.__name__ == "run"
+
+    def test_exported_from_burr_core(self):
+        from burr.core import type_eraser as te
+
+        assert te is type_eraser

Reply via email to