This is an automated email from the ASF dual-hosted git repository.

arm pushed a commit to branch arm
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git


The following commit(s) were added to refs/heads/arm by this push:
     new 06076250 Some additional tidy up of typed decorators, and some additional tests for them
06076250 is described below

commit 060762506c7647aacb24075963d9644226f7c06c
Author: Alastair McFarlane <[email protected]>
AuthorDate: Thu Mar 19 15:29:26 2026 +0000

    Some additional tidy up of typed decorators, and some additional tests for them
---
 atr/blueprints/admin.py       | 100 +-----------------
 atr/blueprints/api.py         |   4 +-
 atr/blueprints/common.py      |  24 +++++
 atr/blueprints/get.py         |   5 +-
 atr/blueprints/post.py        |  16 +--
 tests/unit/test_blueprints.py | 240 ++++++++++++++++++++++++++++++++++++++++++
 6 files changed, 271 insertions(+), 118 deletions(-)

diff --git a/atr/blueprints/admin.py b/atr/blueprints/admin.py
index 5ef56003..ae49633e 100644
--- a/atr/blueprints/admin.py
+++ b/atr/blueprints/admin.py
@@ -14,7 +14,6 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import json
 import time
 from collections.abc import Awaitable, Callable
 from types import ModuleType
@@ -27,7 +26,6 @@ import quart
 import quart_schema
 
 import atr.blueprints.common as common
-import atr.form
 import atr.log as log
 import atr.user as user
 import atr.web as web
@@ -37,93 +35,12 @@ _BLUEPRINT = quart.Blueprint(_BLUEPRINT_NAME, __name__, 
url_prefix="/admin", tem
 _routes: list[str] = []
 
 
-def empty() -> Callable[[Callable[..., Awaitable[Any]]], Callable[..., 
Awaitable[Any]]]:
-    def decorator(func: Callable[..., Awaitable[Any]]) -> Callable[..., 
Awaitable[Any]]:
-        async def wrapper(session: web.Committer, *args: Any, **kwargs: Any) 
-> Any:
-            form_data = await atr.form.quart_request()
-            try:
-                context = {
-                    "args": args,
-                    "kwargs": kwargs,
-                    "session": session,
-                }
-                atr.form.validate(atr.form.Empty, form_data, context=context)
-                return await func(session, *args, **kwargs)
-            except pydantic.ValidationError:
-                msg = "Sorry, there was an empty form validation error. Please 
try again."
-                await quart.flash(msg, "error")
-                return quart.redirect(quart.request.path)
-
-        wrapper.__annotations__ = func.__annotations__.copy()
-        wrapper.__doc__ = func.__doc__
-        wrapper.__module__ = func.__module__
-        wrapper.__name__ = func.__name__
-        return wrapper
-
-    return decorator
-
-
-def form(
-    form_cls: Any,
-) -> Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]:
-    def decorator(func: Callable[..., Awaitable[Any]]) -> Callable[..., 
Awaitable[Any]]:
-        async def wrapper(session: web.Committer, *args: Any, **kwargs: Any) 
-> Any:
-            form_data = await atr.form.quart_request()
-            try:
-                context = {
-                    "args": args,
-                    "kwargs": kwargs,
-                    "session": session,
-                }
-                validated_form = atr.form.validate(form_cls, form_data, 
context=context)
-                return await func(session, validated_form, *args, **kwargs)
-            except pydantic.ValidationError as e:
-                errors = e.errors()
-                if len(errors) == 0:
-                    raise RuntimeError("Validation failed, but no errors were 
reported")
-                flash_data = atr.form.flash_error_data(form_cls, errors, 
form_data)
-                summary = atr.form.flash_error_summary(errors, flash_data)
-
-                await quart.flash(summary, category="error")
-                await quart.flash(json.dumps(flash_data), 
category="form-error-data")
-                return quart.redirect(quart.request.path)
-
-        wrapper.__annotations__ = func.__annotations__.copy()
-        wrapper.__doc__ = func.__doc__
-        wrapper.__module__ = func.__module__
-        wrapper.__name__ = func.__name__
-        return wrapper
-
-    return decorator
-
-
-def get(path: str) -> Callable[[web.CommitterRouteFunction[Any]], 
web.RouteFunction[Any]]:
-    def decorator(func: web.CommitterRouteFunction[Any]) -> 
web.RouteFunction[Any]:
-        async def wrapper(*args: Any, **kwargs: Any) -> Any:
-            return await func(quart.g.session, *args, **kwargs)
-
-        endpoint = func.__module__.replace(".", "_") + "_" + func.__name__
-        wrapper.__name__ = func.__name__
-        wrapper.__doc__ = func.__doc__
-        wrapper.__annotations__["endpoint"] = _BLUEPRINT_NAME + "." + endpoint
-
-        _BLUEPRINT.add_url_rule(path, endpoint=endpoint, view_func=wrapper, 
methods=["GET"])
-
-        return wrapper
-
-    return decorator
-
-
 def post(path: str) -> Callable[[web.CommitterRouteFunction[Any]], 
web.RouteFunction[Any]]:
     def decorator(func: web.CommitterRouteFunction[Any]) -> 
web.RouteFunction[Any]:
         async def wrapper(*args: Any, **kwargs: Any) -> Any:
             return await func(quart.g.session, *args, **kwargs)
 
-        endpoint = func.__module__.replace(".", "_") + "_" + func.__name__
-        wrapper.__name__ = func.__name__
-        wrapper.__doc__ = func.__doc__
-        wrapper.__annotations__["endpoint"] = _BLUEPRINT_NAME + "." + endpoint
-
+        endpoint = common.setup_wrapper(wrapper, func, _BLUEPRINT_NAME)
         _BLUEPRINT.add_url_rule(path, endpoint=endpoint, view_func=wrapper, 
methods=["POST"])
 
         return wrapper
@@ -179,15 +96,7 @@ def typed(func: Callable[..., Any]) -> Callable[..., Any]:
             try:
                 kwargs[form_param_name] = await 
enhanced_session.form_validate(form_cls, context)
             except pydantic.ValidationError as e:
-                errors = e.errors()
-                if len(errors) == 0:
-                    raise RuntimeError("Validation failed, but no errors were 
reported")
-                form_data_raw = await atr.form.quart_request()
-                flash_data = atr.form.flash_error_data(form_cls, errors, 
form_data_raw)
-                summary = atr.form.flash_error_summary(errors, flash_data)
-                await quart.flash(summary, category="error")
-                await quart.flash(json.dumps(flash_data), 
category="form-error-data")
-                return quart.redirect(quart.request.path)
+                return await common.flash_form_error(form_cls, e)
             if form_safe_params:
                 await common.validate_safe_fields(kwargs[form_param_name], 
form_safe_params, kwargs)
 
@@ -202,10 +111,7 @@ def typed(func: Callable[..., Any]) -> Callable[..., Any]:
 
         return response
 
-    endpoint = func.__module__.replace(".", "_") + "_" + func.__name__
-    wrapper.__name__ = func.__name__
-    wrapper.__doc__ = func.__doc__
-    wrapper.__annotations__["endpoint"] = _BLUEPRINT_NAME + "." + endpoint
+    endpoint = common.setup_wrapper(wrapper, func, _BLUEPRINT_NAME)
 
     # Replace the original quart request decorators
     if query_param is not None:
diff --git a/atr/blueprints/api.py b/atr/blueprints/api.py
index e335072b..6b5a3e28 100644
--- a/atr/blueprints/api.py
+++ b/atr/blueprints/api.py
@@ -84,9 +84,7 @@ def typed(func: Callable[..., Awaitable[Any]]) -> 
web.RouteFunction[Any]:
 
         return response
 
-    endpoint = original.__module__.replace(".", "_") + "_" + original.__name__
-    wrapper.__name__ = original.__name__
-    wrapper.__doc__ = original.__doc__
+    endpoint = common.setup_wrapper(wrapper, original, _BLUEPRINT_NAME)
 
     # Replace the original quart request decorators
     if query_param is not None:
diff --git a/atr/blueprints/common.py b/atr/blueprints/common.py
index 8457a9ab..86e216fe 100644
--- a/atr/blueprints/common.py
+++ b/atr/blueprints/common.py
@@ -174,6 +174,15 @@ def build_api_path(
     return path, validated_params, literal_params, unique.body, unique.form, 
unique.query, optional_params
 
 
+def setup_wrapper(wrapper: Callable[..., Any], func: Callable[..., Any], 
blueprint_name: str) -> str:
+    """Set standard metadata on a route wrapper and return the endpoint 
name."""
+    endpoint = func.__module__.replace(".", "_") + "_" + func.__name__
+    wrapper.__name__ = func.__name__
+    wrapper.__doc__ = func.__doc__
+    wrapper.__annotations__["endpoint"] = blueprint_name + "." + endpoint
+    return endpoint
+
+
 def register_route(func: Callable[..., Any], prefix: str, routes: list[str]) 
-> None:
     module_name = func.__module__.split(".")[-1]
     routes.append(f"{prefix}.{module_name}.{func.__name__}")
@@ -232,6 +241,21 @@ async def validate_safe_fields(
             setattr(instance, name, temp[name])
 
 
+async def flash_form_error(form_cls: type, error: pydantic.ValidationError) -> 
Any:
+    """Flash form validation errors and return a redirect to the current 
page."""
+    import json
+
+    errors = error.errors()
+    if len(errors) == 0:
+        raise RuntimeError("Validation failed, but no errors were reported")
+    form_data_raw = await form.quart_request()
+    flash_data = form.flash_error_data(form_cls, errors, form_data_raw)
+    summary = form.flash_error_summary(errors, flash_data)
+    await quart.flash(summary, category="error")
+    await quart.flash(json.dumps(flash_data), category="form-error-data")
+    return quart.redirect(quart.request.path)
+
+
 async def parse_body(
     body_param: tuple[str, type[pydantic.BaseModel]],
     safe_params: list[tuple[str, type]],
diff --git a/atr/blueprints/get.py b/atr/blueprints/get.py
index 9c8033db..1b5d9873 100644
--- a/atr/blueprints/get.py
+++ b/atr/blueprints/get.py
@@ -80,10 +80,7 @@ def typed(func: Callable[..., Any]) -> 
web.RouteFunction[Any]:
 
         return response
 
-    endpoint = func.__module__.replace(".", "_") + "_" + func.__name__
-    wrapper.__name__ = func.__name__
-    wrapper.__doc__ = func.__doc__
-    wrapper.__annotations__["endpoint"] = _BLUEPRINT_NAME + "." + endpoint
+    endpoint = common.setup_wrapper(wrapper, func, _BLUEPRINT_NAME)
 
     decorated = wrapper if public else 
auth.require(auth.Requirements.committer)(wrapper)
     _BLUEPRINT.add_url_rule(path, endpoint=endpoint, view_func=decorated, 
methods=["GET"])
diff --git a/atr/blueprints/post.py b/atr/blueprints/post.py
index 488e2ed3..638ef385 100644
--- a/atr/blueprints/post.py
+++ b/atr/blueprints/post.py
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import json
 import time
 from collections.abc import Awaitable, Callable
 from types import ModuleType
@@ -87,15 +86,7 @@ def typed(func: Callable[..., Any]) -> 
web.RouteFunction[Any]:
                         form_data = await atr.form.quart_request()
                         kwargs[form_param_name] = atr.form.validate(form_cls, 
form_data, context=context)
             except pydantic.ValidationError as e:
-                errors = e.errors()
-                if len(errors) == 0:
-                    raise RuntimeError("Validation failed, but no errors were 
reported")
-                form_data_raw = await atr.form.quart_request()
-                flash_data = atr.form.flash_error_data(form_cls, errors, 
form_data_raw)
-                summary = atr.form.flash_error_summary(errors, flash_data)
-                await quart.flash(summary, category="error")
-                await quart.flash(json.dumps(flash_data), 
category="form-error-data")
-                return quart.redirect(quart.request.path)
+                return await common.flash_form_error(form_cls, e)
             if form_safe_params:
                 await common.validate_safe_fields(kwargs[form_param_name], 
form_safe_params, kwargs)
 
@@ -109,10 +100,7 @@ def typed(func: Callable[..., Any]) -> 
web.RouteFunction[Any]:
 
         return response
 
-    endpoint = func.__module__.replace(".", "_") + "_" + func.__name__
-    wrapper.__name__ = func.__name__
-    wrapper.__doc__ = func.__doc__
-    wrapper.__annotations__["endpoint"] = _BLUEPRINT_NAME + "." + endpoint
+    endpoint = common.setup_wrapper(wrapper, func, _BLUEPRINT_NAME)
 
     decorated = wrapper if public else 
auth.require(auth.Requirements.committer)(wrapper)
     _BLUEPRINT.add_url_rule(path, endpoint=endpoint, view_func=decorated, 
methods=["POST"])
diff --git a/tests/unit/test_blueprints.py b/tests/unit/test_blueprints.py
index af268bfe..59edb7c0 100644
--- a/tests/unit/test_blueprints.py
+++ b/tests/unit/test_blueprints.py
@@ -15,11 +15,19 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import dataclasses
+from typing import Literal
+
 import asfquart
 import pytest
 
 import atr.blueprints as blueprints
+import atr.blueprints.common as common
+import atr.form as form
+import atr.models.safe as safe
+import atr.models.schema as schema
 import atr.util as util
+import atr.web as web
 
 _TESTED_BLUEPRINTS = frozenset({"get_blueprint", "post_blueprint", 
"admin_blueprint"})
 
@@ -60,3 +68,235 @@ async def 
test_all_routes_support_url_construction(monkeypatch):
 
     if failures:
         raise AssertionError("Routes incompatible with as_url:\n" + 
"\n".join(failures))
+
+
+def test_build_path_literal_segment():
+    async def route(_session: web.Committer, _page: Literal["dashboard"]) -> 
str:
+        return ""
+
+    path, validated, literals, form_param, public = common.build_path(route)
+    assert path == "/dashboard"
+    assert literals == {"_page": "dashboard"}
+    assert validated == []
+    assert form_param is None
+    assert public is False
+
+
+def test_build_path_safe_type():
+    async def route(_session: web.Committer, _project_key: safe.ProjectKey) -> 
str:
+        return ""
+
+    path, validated, _, _, _ = common.build_path(route)
+    assert path == "/<_project_key>"
+    assert validated == [("_project_key", safe.ProjectKey)]
+
+
+def test_build_path_combined_literal_and_safe():
+    async def route(
+        _session: web.Committer,
+        _page: Literal["project"],
+        _project_key: safe.ProjectKey,
+        _sub: Literal["version"],
+        _version_key: safe.VersionKey,
+    ) -> str:
+        return ""
+
+    path, validated, literals, _, _ = common.build_path(route)
+    assert path == "/project/<_project_key>/version/<_version_key>"
+    assert validated == [("_project_key", safe.ProjectKey), ("_version_key", 
safe.VersionKey)]
+    assert literals == {"_page": "project", "_sub": "version"}
+
+
+def test_build_path_public_session():
+    async def route(_session: web.Public, _page: Literal["home"]) -> str:
+        return ""
+
+    _, _, _, _, public = common.build_path(route)
+    assert public is True
+
+
+def test_build_path_int_converter():
+    async def route(_session: web.Committer, _page: Literal["item"], _item_id: 
int) -> str:
+        return ""
+
+    path, _, _, _, _ = common.build_path(route)
+    assert path == "/item/<int:_item_id>"
+
+
+def test_build_path_form_param():
+    class TestForm(form.Form):
+        name: str = ""
+
+    async def route(_session: web.Committer, _page: Literal["submit"], _data: 
TestForm) -> str:
+        return ""
+
+    path, _, _, form_param, _ = common.build_path(route)
+    assert path == "/submit"
+    assert form_param is not None
+    assert form_param[0] == "_data"
+    assert form_param[1] is TestForm
+
+
+def test_build_path_rejects_duplicate_form():
+    class FormA(form.Form):
+        name: str = ""
+
+    class FormB(form.Form):
+        name: str = ""
+
+    async def route(_session: web.Committer, _a: FormA, _b: FormB) -> str:
+        return ""
+
+    with pytest.raises(TypeError, match="only one Form is allowed"):
+        common.build_path(route)
+
+
+def test_build_path_rejects_session_not_first():
+    async def route(_page: Literal["home"], _session: web.Committer) -> str:
+        return ""
+
+    with pytest.raises(TypeError, match="must be first"):
+        common.build_path(route)
+
+
+def test_build_path_rejects_unannotated_param():
+    async def route(_session: web.Committer, _thing) -> str:  # type: 
ignore[reportUnknownParameterType]
+        return ""
+
+    with pytest.raises(TypeError, match="no type annotation"):
+        common.build_path(route)
+
+
+def test_build_path_rejects_bare_str():
+    async def route(_session: web.Committer, _name: str) -> str:
+        return ""
+
+    with pytest.raises(TypeError, match="unguarded str"):
+        common.build_path(route)
+
+
+def test_build_api_path_literal_and_safe():
+    async def route(
+        _session: web.Committer,
+        _page: Literal["project"],
+        _project_key: safe.ProjectKey,
+    ) -> str:
+        return ""
+
+    path, validated, literals, body, form_param, query, optional = 
common.build_api_path(route)
+    assert path == "/project/<_project_key>"
+    assert validated == [("_project_key", safe.ProjectKey)]
+    assert literals == {"_page": "project"}
+    assert body is None
+    assert form_param is None
+    assert query is None
+    assert optional == []
+
+
+def test_build_api_path_body_param():
+    class RequestBody(schema.Strict):
+        value: int
+
+    async def route(_session: web.Committer, _page: Literal["submit"], _data: 
RequestBody) -> str:
+        return ""
+
+    _, _, _, body, _, _, _ = common.build_api_path(route)
+    assert body is not None
+    assert body[0] == "_data"
+    assert body[1] is RequestBody
+
+
+def test_build_api_path_query_param():
+    @dataclasses.dataclass
+    class Filters:
+        page: int = 1
+        search: str = ""
+
+    async def route(_session: web.Committer, _page: Literal["list"], _filters: 
Filters) -> str:
+        return ""
+
+    _, _, _, _, _, query, _ = common.build_api_path(route)
+    assert query is not None
+    assert query[0] == "_filters"
+    assert query[1] is Filters
+
+
+def test_build_api_path_optional_param():
+    async def route(
+        _session: web.Committer,
+        _page: Literal["items"],
+        _category: str | None = None,
+    ) -> str:
+        return ""
+
+    path, _, _, _, _, _, optional = common.build_api_path(route)
+    assert path == "/items/<_category>"
+    assert optional == ["_category"]
+
+
+def test_build_api_path_rejects_duplicate_body():
+    class BodyA(schema.Lax):
+        x: int
+
+    class BodyB(schema.Lax):
+        y: int
+
+    async def route(_session: web.Committer, _a: BodyA, _b: BodyB) -> str:
+        return ""
+
+    with pytest.raises(TypeError, match="only one body type is allowed"):
+        common.build_api_path(route)
+
+
+def test_build_api_path_rejects_duplicate_query():
+    @dataclasses.dataclass
+    class QueryA:
+        x: int = 0
+
+    @dataclasses.dataclass
+    class QueryB:
+        y: int = 0
+
+    async def route(_session: web.Committer, _a: QueryA, _b: QueryB) -> str:
+        return ""
+
+    with pytest.raises(TypeError, match="only one query type is allowed"):
+        common.build_api_path(route)
+
+
+def test_setup_wrapper_sets_metadata():
+    async def original() -> str:
+        """Doc string."""
+        return ""
+
+    original.__module__ = "atr.get.dashboard"
+
+    async def wrapper() -> str:
+        return ""
+
+    endpoint = common.setup_wrapper(wrapper, original, "get_blueprint")
+    assert endpoint == "atr_get_dashboard_original"
+    assert wrapper.__name__ == "original"
+    assert wrapper.__doc__ == "Doc string."
+    assert wrapper.__annotations__["endpoint"] == 
"get_blueprint.atr_get_dashboard_original"
+
+
+def test_safe_params_for_type_finds_safe_fields():
+    class Body(schema.Strict):
+        project_key: safe.ProjectKey
+        version_key: safe.VersionKey
+        description: str
+
+    result = common.safe_params_for_type(Body)
+    assert ("project_key", safe.ProjectKey) in result
+    assert ("version_key", safe.VersionKey) in result
+    assert len(result) == 2
+
+
+def test_safe_params_for_type_empty_for_plain_model():
+    class Body(schema.Strict):
+        name: str
+        count: int
+
+    result = common.safe_params_for_type(Body)
+    assert result == []


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to