This is an automated email from the ASF dual-hosted git repository. arm pushed a commit to branch arm in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git
commit 53a8813bed93e4fbf4a0346573816e84f6b359d1 Author: Alastair McFarlane <[email protected]> AuthorDate: Thu Mar 19 15:29:26 2026 +0000 Some additional tidy up of typed decorators, and some additional tests for them --- atr/blueprints/admin.py | 100 +----------------- atr/blueprints/api.py | 4 +- atr/blueprints/common.py | 24 +++++ atr/blueprints/get.py | 5 +- atr/blueprints/post.py | 16 +-- tests/unit/test_blueprints.py | 240 ++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 271 insertions(+), 118 deletions(-) diff --git a/atr/blueprints/admin.py b/atr/blueprints/admin.py index 5ef56003..ae49633e 100644 --- a/atr/blueprints/admin.py +++ b/atr/blueprints/admin.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import json import time from collections.abc import Awaitable, Callable from types import ModuleType @@ -27,7 +26,6 @@ import quart import quart_schema import atr.blueprints.common as common -import atr.form import atr.log as log import atr.user as user import atr.web as web @@ -37,93 +35,12 @@ _BLUEPRINT = quart.Blueprint(_BLUEPRINT_NAME, __name__, url_prefix="/admin", tem _routes: list[str] = [] -def empty() -> Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]: - def decorator(func: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]: - async def wrapper(session: web.Committer, *args: Any, **kwargs: Any) -> Any: - form_data = await atr.form.quart_request() - try: - context = { - "args": args, - "kwargs": kwargs, - "session": session, - } - atr.form.validate(atr.form.Empty, form_data, context=context) - return await func(session, *args, **kwargs) - except pydantic.ValidationError: - msg = "Sorry, there was an empty form validation error. Please try again." - await quart.flash(msg, "error") - return quart.redirect(quart.request.path) - - wrapper.__annotations__ = func.__annotations__.copy() - wrapper.__doc__ = func.__doc__ - wrapper.__module__ = func.__module__ - wrapper.__name__ = func.__name__ - return wrapper - - return decorator - - -def form( - form_cls: Any, -) -> Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]: - def decorator(func: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]: - async def wrapper(session: web.Committer, *args: Any, **kwargs: Any) -> Any: - form_data = await atr.form.quart_request() - try: - context = { - "args": args, - "kwargs": kwargs, - "session": session, - } - validated_form = atr.form.validate(form_cls, form_data, context=context) - return await func(session, validated_form, *args, **kwargs) - except pydantic.ValidationError as e: - errors = e.errors() - if len(errors) == 0: - raise RuntimeError("Validation failed, but no errors were reported") - flash_data = atr.form.flash_error_data(form_cls, errors, form_data) - summary = atr.form.flash_error_summary(errors, flash_data) - - await quart.flash(summary, category="error") - await quart.flash(json.dumps(flash_data), category="form-error-data") - return quart.redirect(quart.request.path) - - wrapper.__annotations__ = func.__annotations__.copy() - wrapper.__doc__ = func.__doc__ - wrapper.__module__ = func.__module__ - wrapper.__name__ = func.__name__ - return wrapper - - return decorator - - -def get(path: str) -> Callable[[web.CommitterRouteFunction[Any]], web.RouteFunction[Any]]: - def decorator(func: web.CommitterRouteFunction[Any]) -> web.RouteFunction[Any]: - async def wrapper(*args: Any, **kwargs: Any) -> Any: - return await func(quart.g.session, *args, **kwargs) - - endpoint = func.__module__.replace(".", "_") + "_" + func.__name__ - wrapper.__name__ = func.__name__ - wrapper.__doc__ = func.__doc__ - wrapper.__annotations__["endpoint"] = _BLUEPRINT_NAME + "." + endpoint - - _BLUEPRINT.add_url_rule(path, endpoint=endpoint, view_func=wrapper, methods=["GET"]) - - return wrapper - - return decorator - - def post(path: str) -> Callable[[web.CommitterRouteFunction[Any]], web.RouteFunction[Any]]: def decorator(func: web.CommitterRouteFunction[Any]) -> web.RouteFunction[Any]: async def wrapper(*args: Any, **kwargs: Any) -> Any: return await func(quart.g.session, *args, **kwargs) - endpoint = func.__module__.replace(".", "_") + "_" + func.__name__ - wrapper.__name__ = func.__name__ - wrapper.__doc__ = func.__doc__ - wrapper.__annotations__["endpoint"] = _BLUEPRINT_NAME + "." + endpoint - + endpoint = common.setup_wrapper(wrapper, func, _BLUEPRINT_NAME) _BLUEPRINT.add_url_rule(path, endpoint=endpoint, view_func=wrapper, methods=["POST"]) return wrapper @@ -179,15 +96,7 @@ def typed(func: Callable[..., Any]) -> Callable[..., Any]: try: kwargs[form_param_name] = await enhanced_session.form_validate(form_cls, context) except pydantic.ValidationError as e: - errors = e.errors() - if len(errors) == 0: - raise RuntimeError("Validation failed, but no errors were reported") - form_data_raw = await atr.form.quart_request() - flash_data = atr.form.flash_error_data(form_cls, errors, form_data_raw) - summary = atr.form.flash_error_summary(errors, flash_data) - await quart.flash(summary, category="error") - await quart.flash(json.dumps(flash_data), category="form-error-data") - return quart.redirect(quart.request.path) + return await common.flash_form_error(form_cls, e) if form_safe_params: await common.validate_safe_fields(kwargs[form_param_name], form_safe_params, kwargs) @@ -202,10 +111,7 @@ def typed(func: Callable[..., Any]) -> Callable[..., Any]: return response - endpoint = func.__module__.replace(".", "_") + "_" + func.__name__ - wrapper.__name__ = func.__name__ - wrapper.__doc__ = func.__doc__ - wrapper.__annotations__["endpoint"] = _BLUEPRINT_NAME + "." + endpoint + endpoint = common.setup_wrapper(wrapper, func, _BLUEPRINT_NAME) # Replace the original quart request decorators if query_param is not None: diff --git a/atr/blueprints/api.py b/atr/blueprints/api.py index e335072b..6b5a3e28 100644 --- a/atr/blueprints/api.py +++ b/atr/blueprints/api.py @@ -84,9 +84,7 @@ def typed(func: Callable[..., Awaitable[Any]]) -> web.RouteFunction[Any]: return response - endpoint = original.__module__.replace(".", "_") + "_" + original.__name__ - wrapper.__name__ = original.__name__ - wrapper.__doc__ = original.__doc__ + endpoint = common.setup_wrapper(wrapper, original, _BLUEPRINT_NAME) # Replace the original quart request decorators if query_param is not None: diff --git a/atr/blueprints/common.py b/atr/blueprints/common.py index 8457a9ab..86e216fe 100644 --- a/atr/blueprints/common.py +++ b/atr/blueprints/common.py @@ -174,6 +174,15 @@ def build_api_path( return path, validated_params, literal_params, unique.body, unique.form, unique.query, optional_params +def setup_wrapper(wrapper: Callable[..., Any], func: Callable[..., Any], blueprint_name: str) -> str: + """Set standard metadata on a route wrapper and return the endpoint name.""" + endpoint = func.__module__.replace(".", "_") + "_" + func.__name__ + wrapper.__name__ = func.__name__ + wrapper.__doc__ = func.__doc__ + wrapper.__annotations__["endpoint"] = blueprint_name + "." + endpoint + return endpoint + + def register_route(func: Callable[..., Any], prefix: str, routes: list[str]) -> None: module_name = func.__module__.split(".")[-1] routes.append(f"{prefix}.{module_name}.{func.__name__}") @@ -232,6 +241,21 @@ async def validate_safe_fields( setattr(instance, name, temp[name]) +async def flash_form_error(form_cls: type, error: pydantic.ValidationError) -> Any: + """Flash form validation errors and return a redirect to the current page.""" + import json + + errors = error.errors() + if len(errors) == 0: + raise RuntimeError("Validation failed, but no errors were reported") + form_data_raw = await form.quart_request() + flash_data = form.flash_error_data(form_cls, errors, form_data_raw) + summary = form.flash_error_summary(errors, flash_data) + await quart.flash(summary, category="error") + await quart.flash(json.dumps(flash_data), category="form-error-data") + return quart.redirect(quart.request.path) + + async def parse_body( body_param: tuple[str, type[pydantic.BaseModel]], safe_params: list[tuple[str, type]], diff --git a/atr/blueprints/get.py b/atr/blueprints/get.py index 9c8033db..1b5d9873 100644 --- a/atr/blueprints/get.py +++ b/atr/blueprints/get.py @@ -80,10 +80,7 @@ def typed(func: Callable[..., Any]) -> web.RouteFunction[Any]: return response - endpoint = func.__module__.replace(".", "_") + "_" + func.__name__ - wrapper.__name__ = func.__name__ - wrapper.__doc__ = func.__doc__ - wrapper.__annotations__["endpoint"] = _BLUEPRINT_NAME + "." + endpoint + endpoint = common.setup_wrapper(wrapper, func, _BLUEPRINT_NAME) decorated = wrapper if public else auth.require(auth.Requirements.committer)(wrapper) _BLUEPRINT.add_url_rule(path, endpoint=endpoint, view_func=decorated, methods=["GET"]) diff --git a/atr/blueprints/post.py b/atr/blueprints/post.py index 488e2ed3..638ef385 100644 --- a/atr/blueprints/post.py +++ b/atr/blueprints/post.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -import json import time from collections.abc import Awaitable, Callable from types import ModuleType @@ -87,15 +86,7 @@ def typed(func: Callable[..., Any]) -> web.RouteFunction[Any]: form_data = await atr.form.quart_request() kwargs[form_param_name] = atr.form.validate(form_cls, form_data, context=context) except pydantic.ValidationError as e: - errors = e.errors() - if len(errors) == 0: - raise RuntimeError("Validation failed, but no errors were reported") - form_data_raw = await atr.form.quart_request() - flash_data = atr.form.flash_error_data(form_cls, errors, form_data_raw) - summary = atr.form.flash_error_summary(errors, flash_data) - await quart.flash(summary, category="error") - await quart.flash(json.dumps(flash_data), category="form-error-data") - return quart.redirect(quart.request.path) + return await common.flash_form_error(form_cls, e) if form_safe_params: await common.validate_safe_fields(kwargs[form_param_name], form_safe_params, kwargs) @@ -109,10 +100,7 @@ def typed(func: Callable[..., Any]) -> web.RouteFunction[Any]: return response - endpoint = func.__module__.replace(".", "_") + "_" + func.__name__ - wrapper.__name__ = func.__name__ - wrapper.__doc__ = func.__doc__ - wrapper.__annotations__["endpoint"] = _BLUEPRINT_NAME + "." + endpoint + endpoint = common.setup_wrapper(wrapper, func, _BLUEPRINT_NAME) decorated = wrapper if public else auth.require(auth.Requirements.committer)(wrapper) _BLUEPRINT.add_url_rule(path, endpoint=endpoint, view_func=decorated, methods=["POST"]) diff --git a/tests/unit/test_blueprints.py b/tests/unit/test_blueprints.py index af268bfe..28e79279 100644 --- a/tests/unit/test_blueprints.py +++ b/tests/unit/test_blueprints.py @@ -15,11 +15,19 @@ # specific language governing permissions and limitations # under the License. +import dataclasses +from typing import Literal + import asfquart import pytest import atr.blueprints as blueprints +import atr.blueprints.common as common +import atr.form as form +import atr.models.safe as safe +import atr.models.schema as schema import atr.util as util +import atr.web as web _TESTED_BLUEPRINTS = frozenset({"get_blueprint", "post_blueprint", "admin_blueprint"}) @@ -60,3 +68,235 @@ async def test_all_routes_support_url_construction(monkeypatch): if failures: raise AssertionError("Routes incompatible with as_url:\n" + "\n".join(failures)) + + +def test_build_path_literal_segment(): + async def route(_session: web.Committer, _page: Literal["dashboard"]) -> str: + return "" + + path, validated, literals, form_param, public = common.build_path(route) + assert path == "/dashboard" + assert literals == {"_page": "dashboard"} + assert validated == [] + assert form_param is None + assert public is False + + +def test_build_path_safe_type(): + async def route(_session: web.Committer, _project_key: safe.ProjectKey) -> str: + return "" + + path, validated, _, _, _ = common.build_path(route) + assert path == "/<_project_key>" + assert validated == [("_project_key", safe.ProjectKey)] + + +def test_build_path_combined_literal_and_safe(): + async def route( + _session: web.Committer, + _page: Literal["project"], + _project_key: safe.ProjectKey, + _sub: Literal["version"], + _version_key: safe.VersionKey, + ) -> str: + return "" + + path, validated, literals, _, _ = common.build_path(route) + assert path == "/project/<_project_key>/version/<_version_key>" + assert validated == [("_project_key", safe.ProjectKey), ("_version_key", safe.VersionKey)] + assert literals == {"_page": "project", "_sub": "version"} + + +def test_build_path_public_session(): + async def route(_session: web.Public, _page: Literal["home"]) -> str: + return "" + + _, _, _, _, public = common.build_path(route) + assert public is True + + +def test_build_path_int_converter(): + async def route(_session: web.Committer, _page: Literal["item"], _item_id: int) -> str: + return "" + + path, _, _, _, _ = common.build_path(route) + assert path == "/item/<int:_item_id>" + + +def test_build_path_form_param(): + class TestForm(form.Form): + name: str = "" + + async def route(_session: web.Committer, _page: Literal["submit"], _data: TestForm) -> str: + return "" + + path, _, _, form_param, _ = common.build_path(route) + assert path == "/submit" + assert form_param is not None + assert form_param[0] == "_data" + assert form_param[1] is TestForm + + +def test_build_path_rejects_duplicate_form(): + class FormA(form.Form): + name: str = "" + + class FormB(form.Form): + name: str = "" + + async def route(_session: web.Committer, _a: FormA, _b: FormB) -> str: + return "" + + with pytest.raises(TypeError, match="only one Form is allowed"): + common.build_path(route) + + +def test_build_path_rejects_session_not_first(): + async def route(_page: Literal["home"], _session: web.Committer) -> str: + return "" + + with pytest.raises(TypeError, match="must be first"): + common.build_path(route) + + +def test_build_path_rejects_unannotated_param(): + async def route(_session: web.Committer, _thing) -> str: # type: ignore[reportUnknownParameterType] + return "" + + with pytest.raises(TypeError, match="no type annotation"): + common.build_path(route) + + +def test_build_path_rejects_bare_str(): + async def route(_session: web.Committer, _name: str) -> str: + return "" + + with pytest.raises(TypeError, match="unguarded str"): + common.build_path(route) + + +def test_build_api_path_literal_and_safe(): + async def route( + _session: web.Committer, + _page: Literal["project"], + _project_key: safe.ProjectKey, + ) -> str: + return "" + + path, validated, literals, body, form_param, query, optional = common.build_api_path(route) + assert path == "/project/<_project_key>" + assert validated == [("_project_key", safe.ProjectKey)] + assert literals == {"_page": "project"} + assert body is None + assert form_param is None + assert query is None + assert optional == [] + + +def test_build_api_path_body_param(): + class RequestBody(schema.Strict): + value: int + + async def route(_session: web.Committer, _page: Literal["submit"], _data: RequestBody) -> str: + return "" + + _, _, _, body, _, _, _ = common.build_api_path(route) + assert body is not None + assert body[0] == "_data" + assert body[1] is RequestBody + + +def test_build_api_path_query_param(): + @dataclasses.dataclass + class Filters: + page: int = 1 + search: str = "" + + async def route(_session: web.Committer, _page: Literal["list"], _filters: Filters) -> str: + return "" + + _, _, _, _, _, query, _ = common.build_api_path(route) + assert query is not None + assert query[0] == "_filters" + assert query[1] is Filters + + +def test_build_api_path_optional_param(): + async def route( + _session: web.Committer, + _page: Literal["items"], + _category: str | None = None, + ) -> str: + return "" + + path, _, _, _, _, _, optional = common.build_api_path(route) + assert path == "/items/<_category>" + assert optional == ["_category"] + + +def test_build_api_path_rejects_duplicate_body(): + class BodyA(schema.Lax): + x: int + + class BodyB(schema.Lax): + y: int + + async def route(_session: web.Committer, _a: BodyA, _b: BodyB) -> str: + return "" + + with pytest.raises(TypeError, match="only one body type is allowed"): + common.build_api_path(route) + + +def test_build_api_path_rejects_duplicate_query(): + @dataclasses.dataclass + class QueryA: + x: int = 0 + + @dataclasses.dataclass + class QueryB: + y: int = 0 + + async def route(_session: web.Committer, _a: QueryA, _b: QueryB) -> str: + return "" + + with pytest.raises(TypeError, match="only one query type is allowed"): + common.build_api_path(route) + + +def test_setup_wrapper_sets_metadata(): + async def index() -> str: + """Doc string.""" + return "" + + index.__module__ = "atr.get.dashboard" + + async def wrapper() -> str: + return "" + + endpoint = common.setup_wrapper(wrapper, index, "get_blueprint") + assert endpoint == "atr_get_dashboard_index" + assert wrapper.__name__ == "index" + assert wrapper.__doc__ == "Doc string." + assert wrapper.__annotations__["endpoint"] == "get_blueprint.atr_get_dashboard_index" + + +def test_safe_params_for_type_finds_safe_fields(): + class Body(schema.Strict): + project_key: safe.ProjectKey + version_key: safe.VersionKey + description: str + + result = common.safe_params_for_type(Body) + assert ("project_key", safe.ProjectKey) in result + assert ("version_key", safe.VersionKey) in result + assert len(result) == 2 + + +def test_safe_params_for_type_empty_for_plain_model(): + class Body(schema.Strict): + name: str + count: int + + result = common.safe_params_for_type(Body) + assert result == [] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
