This is an automated email from the ASF dual-hosted git repository. arm pushed a commit to branch arm in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git
commit a106a69d8e0908493be389acca726daec23b42fc Author: Alastair McFarlane <[email protected]> AuthorDate: Thu Mar 19 15:29:26 2026 +0000 Some additional tidy up of typed decorators, and some additional tests for them --- atr/blueprints/admin.py | 115 +++----------------- atr/blueprints/api.py | 15 ++- atr/blueprints/common.py | 58 ++++++---- atr/blueprints/get.py | 19 +--- atr/blueprints/post.py | 26 ++--- tests/unit/test_blueprints.py | 240 ++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 305 insertions(+), 168 deletions(-) diff --git a/atr/blueprints/admin.py b/atr/blueprints/admin.py index 5ef56003..6d2bf7fe 100644 --- a/atr/blueprints/admin.py +++ b/atr/blueprints/admin.py @@ -14,11 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import json import time from collections.abc import Awaitable, Callable from types import ModuleType -from typing import Any, Concatenate, overload +from typing import Any, Concatenate, Final, overload import asfquart.base as base import asfquart.session @@ -27,103 +26,23 @@ import quart import quart_schema import atr.blueprints.common as common -import atr.form import atr.log as log import atr.user as user import atr.web as web -_BLUEPRINT_NAME = "admin_blueprint" -_BLUEPRINT = quart.Blueprint(_BLUEPRINT_NAME, __name__, url_prefix="/admin", template_folder="../admin/templates") +_BLUEPRINT_NAME: Final = "admin_blueprint" +_BLUEPRINT: Final = quart.Blueprint( + _BLUEPRINT_NAME, __name__, url_prefix="/admin", template_folder="../admin/templates" +) _routes: list[str] = [] -def empty() -> Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]: - def decorator(func: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]: - async def wrapper(session: web.Committer, *args: Any, **kwargs: Any) -> Any: - form_data = await atr.form.quart_request() - try: - context = { - "args": args, - "kwargs": kwargs, - "session": session, - } - atr.form.validate(atr.form.Empty, form_data, context=context) - return await func(session, *args, **kwargs) - except pydantic.ValidationError: - msg = "Sorry, there was an empty form validation error. Please try again." - await quart.flash(msg, "error") - return quart.redirect(quart.request.path) - - wrapper.__annotations__ = func.__annotations__.copy() - wrapper.__doc__ = func.__doc__ - wrapper.__module__ = func.__module__ - wrapper.__name__ = func.__name__ - return wrapper - - return decorator - - -def form( - form_cls: Any, -) -> Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]: - def decorator(func: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]: - async def wrapper(session: web.Committer, *args: Any, **kwargs: Any) -> Any: - form_data = await atr.form.quart_request() - try: - context = { - "args": args, - "kwargs": kwargs, - "session": session, - } - validated_form = atr.form.validate(form_cls, form_data, context=context) - return await func(session, validated_form, *args, **kwargs) - except pydantic.ValidationError as e: - errors = e.errors() - if len(errors) == 0: - raise RuntimeError("Validation failed, but no errors were reported") - flash_data = atr.form.flash_error_data(form_cls, errors, form_data) - summary = atr.form.flash_error_summary(errors, flash_data) - - await quart.flash(summary, category="error") - await quart.flash(json.dumps(flash_data), category="form-error-data") - return quart.redirect(quart.request.path) - - wrapper.__annotations__ = func.__annotations__.copy() - wrapper.__doc__ = func.__doc__ - wrapper.__module__ = func.__module__ - wrapper.__name__ = func.__name__ - return wrapper - - return decorator - - -def get(path: str) -> Callable[[web.CommitterRouteFunction[Any]], web.RouteFunction[Any]]: - def decorator(func: web.CommitterRouteFunction[Any]) -> web.RouteFunction[Any]: - async def wrapper(*args: Any, **kwargs: Any) -> Any: - return await func(quart.g.session, *args, **kwargs) - - endpoint = func.__module__.replace(".", "_") + "_" + func.__name__ - wrapper.__name__ = func.__name__ - wrapper.__doc__ = func.__doc__ - wrapper.__annotations__["endpoint"] = _BLUEPRINT_NAME + "." + endpoint - - _BLUEPRINT.add_url_rule(path, endpoint=endpoint, view_func=wrapper, methods=["GET"]) - - return wrapper - - return decorator - - def post(path: str) -> Callable[[web.CommitterRouteFunction[Any]], web.RouteFunction[Any]]: def decorator(func: web.CommitterRouteFunction[Any]) -> web.RouteFunction[Any]: async def wrapper(*args: Any, **kwargs: Any) -> Any: return await func(quart.g.session, *args, **kwargs) - endpoint = func.__module__.replace(".", "_") + "_" + func.__name__ - wrapper.__name__ = func.__name__ - wrapper.__doc__ = func.__doc__ - wrapper.__annotations__["endpoint"] = _BLUEPRINT_NAME + "." + endpoint - + endpoint = common.setup_wrapper(wrapper, func, _BLUEPRINT_NAME) _BLUEPRINT.add_url_rule(path, endpoint=endpoint, view_func=wrapper, methods=["POST"]) return wrapper @@ -150,12 +69,13 @@ def typed(func: Callable[..., Any]) -> Callable[..., Any]: """Decorator that derives the URL path from the function's type annotations. - Literal["..."] parameters become literal path segments - - safe.ProjectName / safe.VersionName parameters are validated via cache/DB + - safe.SafeType subclass parameters are validated from the URL path - pydantic.BaseModel subclass parameters are parsed from the JSON request body + - form.Form subclass parameters are validated from the request body - dataclass parameters are parsed from the query string - str | None parameters create optional URL segments (two routes registered) - - int, float, str use Quart's built-in type converters - - HTTP method is POST if a body param is present, GET otherwise + - int, float use Quart's built-in type converters + - HTTP method is POST if a body or form param is present, GET otherwise """ path, validated_params, literal_params, body_param, form_param, query_param, optional_params = ( common.build_api_path(func) @@ -179,15 +99,7 @@ def typed(func: Callable[..., Any]) -> Callable[..., Any]: try: kwargs[form_param_name] = await enhanced_session.form_validate(form_cls, context) except pydantic.ValidationError as e: - errors = e.errors() - if len(errors) == 0: - raise RuntimeError("Validation failed, but no errors were reported") - form_data_raw = await atr.form.quart_request() - flash_data = atr.form.flash_error_data(form_cls, errors, form_data_raw) - summary = atr.form.flash_error_summary(errors, flash_data) - await quart.flash(summary, category="error") - await quart.flash(json.dumps(flash_data), category="form-error-data") - return quart.redirect(quart.request.path) + return await common.flash_form_error(form_cls, e) if form_safe_params: await common.validate_safe_fields(kwargs[form_param_name], form_safe_params, kwargs) @@ -202,10 +114,7 @@ def typed(func: Callable[..., Any]) -> Callable[..., Any]: return response - endpoint = func.__module__.replace(".", "_") + "_" + func.__name__ - wrapper.__name__ = func.__name__ - wrapper.__doc__ = func.__doc__ - wrapper.__annotations__["endpoint"] = _BLUEPRINT_NAME + "." + endpoint + endpoint = common.setup_wrapper(wrapper, func, _BLUEPRINT_NAME) # Replace the original quart request decorators if query_param is not None: diff --git a/atr/blueprints/api.py b/atr/blueprints/api.py index e335072b..172e75ec 100644 --- a/atr/blueprints/api.py +++ b/atr/blueprints/api.py @@ -20,7 +20,7 @@ import sys import time from collections.abc import Awaitable, Callable from types import ModuleType -from typing import Any +from typing import Any, Final import asfquart.base as base import pydantic @@ -34,8 +34,8 @@ import atr.blueprints.common as common import atr.log as log import atr.web as web -_BLUEPRINT_NAME = "api_blueprint" -_BLUEPRINT = quart.Blueprint(_BLUEPRINT_NAME, __name__, url_prefix="/api") +_BLUEPRINT_NAME: Final = "api_blueprint" +_BLUEPRINT: Final = quart.Blueprint(_BLUEPRINT_NAME, __name__, url_prefix="/api") _routes: list[str] = [] @@ -49,13 +49,12 @@ def register(app: base.QuartApp) -> tuple[ModuleType, list[str]]: def typed(func: Callable[..., Awaitable[Any]]) -> web.RouteFunction[Any]: """Decorator that derives the URL path from the function's type annotations. - - Arguments after session are joined with / to make the web path - Literal["..."] parameters become literal path segments - - safe.ProjectName / safe.VersionName parameters are validated via cache/DB + - safe.SafeType subclass parameters are validated from the URL path - pydantic.BaseModel subclass parameters are parsed from the JSON request body - dataclass parameters are parsed from the query string - str | None parameters create optional URL segments (two routes registered) - - int, float, str use Quart's built-in type converters + - int, float use Quart's built-in type converters - HTTP method is POST if a body param is present, GET otherwise """ original = inspect.unwrap(func) @@ -84,9 +83,7 @@ def typed(func: Callable[..., Awaitable[Any]]) -> web.RouteFunction[Any]: return response - endpoint = original.__module__.replace(".", "_") + "_" + original.__name__ - wrapper.__name__ = original.__name__ - wrapper.__doc__ = original.__doc__ + endpoint = common.setup_wrapper(wrapper, original, _BLUEPRINT_NAME) # Replace the original quart request decorators if query_param is not None: diff --git a/atr/blueprints/common.py b/atr/blueprints/common.py index 8457a9ab..c5e26a32 100644 --- a/atr/blueprints/common.py +++ b/atr/blueprints/common.py @@ -20,7 +20,7 @@ import inspect import types import typing from collections.abc import Callable -from typing import Annotated, Any, Literal, TypeAliasType, get_args, get_origin, get_type_hints +from typing import Annotated, Any, Final, Literal, TypeAliasType, get_args, get_origin, get_type_hints import asfquart.base as base import asfquart.session @@ -35,13 +35,13 @@ import atr.models.safe as safe import atr.models.unsafe as unsafe import atr.web as web -QUART_CONVERTERS: dict[Any, str] = { +QUART_CONVERTERS: Final[dict[Any, str]] = { int: "int", float: "float", unsafe.Path: "path", } -VALIDATED_TYPES: set[Any] = { +VALIDATED_TYPES: Final[set[Any]] = { safe.Alphanumeric, safe.CommitteeKey, safe.ProjectKey, @@ -81,7 +81,7 @@ def build_path( Validates that the session param (web.Committer or web.Public) is first, and that only one Form param is allowed Returns (path, validated_params, literal_params, form_param, public) where: - - validated_params: (name, type) pairs for URL params to be validated with cache/DB + - validated_params: (name, type) pairs for safe.SafeType subclass URL params to be validated - literal_params: param name → literal string value for Literal["..."] params - form_param: (name, type) for the single form.Form subclass param, or None - public: True if the session type is web.Public @@ -130,14 +130,14 @@ def build_api_path( Accepts URL path params for data, Literal strings for plain URL text, dataclasses for GET query params and Pydantic model params for POST bodies - Returns (path, validated_params, literal_params, body_param, query_param, + Returns (path, validated_params, literal_params, body_param, form_param, query_param, optional_params) where: - - validated_params: (name, type) pairs for URL params to be validated with cache/DB + - validated_params: (name, type) pairs for safe.SafeType subclass URL params to be validated - literal_params: param name -> literal string value for Literal["..."] params - body_param: (name, type) for the single BaseModel param, or None + - form_param: (name, type) for the single form.Form subclass param, or None - query_param: (name, type) for the single dataclass param, or None - optional_params: param names whose type is T | None with a default of None - - return_type: the return type of the function """ hints = get_type_hints(func, include_extras=True) sig = inspect.signature(func) @@ -174,6 +174,15 @@ def build_api_path( return path, validated_params, literal_params, unique.body, unique.form, unique.query, optional_params +def setup_wrapper(wrapper: Callable[..., Any], func: Callable[..., Any], blueprint_name: str) -> str: + """Set standard metadata on a route wrapper and return the endpoint name.""" + endpoint = func.__module__.replace(".", "_") + "_" + func.__name__ + wrapper.__name__ = func.__name__ + wrapper.__doc__ = func.__doc__ + wrapper.__annotations__["endpoint"] = blueprint_name + "." + endpoint + return endpoint + + def register_route(func: Callable[..., Any], prefix: str, routes: list[str]) -> None: module_name = func.__module__.split(".")[-1] routes.append(f"{prefix}.{module_name}.{func.__name__}") @@ -192,23 +201,13 @@ async def validate_params(kwargs: dict[str, Any], known_params: list[tuple[str, """Validate URL parameters in order, using the type-specific validators.""" for param_name, param_type in known_params: raw = kwargs[param_name] - if param_type is safe.ProjectKey: - try: - kwargs[param_name] = safe.ProjectKey(raw) - except ValueError: - raise base.ASFQuartException(f"Project name {param_name!r} is invalid. ") - elif param_type is safe.VersionKey: - try: - kwargs[param_name] = safe.VersionKey(raw) - except ValueError: - raise base.ASFQuartException(f"Version name {param_name!r} is invalid. ") - elif param_type is safe.RevisionNumber: + if param_type is unsafe.UnsafeStr: + kwargs[param_name] = unsafe.UnsafeStr(raw) + elif issubclass(param_type, safe.SafeType): try: - kwargs[param_name] = safe.RevisionNumber(raw) + kwargs[param_name] = param_type(raw) except ValueError: - raise base.ASFQuartException(f"Revision number {param_name!r} is invalid. ") - elif param_type is unsafe.UnsafeStr: - kwargs[param_name] = unsafe.UnsafeStr(raw) + raise base.ASFQuartException(f"Parameter {param_name!r} is invalid. ") async def validate_safe_fields( @@ -232,6 +231,21 @@ async def validate_safe_fields( setattr(instance, name, temp[name]) +async def flash_form_error(form_cls: type, error: pydantic.ValidationError) -> Any: + """Flash form validation errors and return a redirect to the current page.""" + import json + + errors = error.errors() + if len(errors) == 0: + raise RuntimeError("Validation failed, but no errors were reported") + form_data_raw = await form.quart_request() + flash_data = form.flash_error_data(form_cls, errors, form_data_raw) + summary = form.flash_error_summary(errors, flash_data) + await quart.flash(summary, category="error") + await quart.flash(json.dumps(flash_data), category="form-error-data") + return quart.redirect(quart.request.path) + + async def parse_body( body_param: tuple[str, type[pydantic.BaseModel]], safe_params: list[tuple[str, type]], diff --git a/atr/blueprints/get.py b/atr/blueprints/get.py index 9c8033db..b7c60c51 100644 --- a/atr/blueprints/get.py +++ b/atr/blueprints/get.py @@ -18,7 +18,7 @@ import time from collections.abc import Awaitable, Callable from types import ModuleType -from typing import Any, Concatenate, ParamSpec, TypeVar, overload +from typing import Any, Concatenate, Final, overload import asfquart.auth as auth import asfquart.base as base @@ -28,13 +28,10 @@ import atr.blueprints.common as common import atr.log as log import atr.web as web -_BLUEPRINT_NAME = "get_blueprint" -_BLUEPRINT = quart.Blueprint(_BLUEPRINT_NAME, __name__) +_BLUEPRINT_NAME: Final = "get_blueprint" +_BLUEPRINT: Final = quart.Blueprint(_BLUEPRINT_NAME, __name__) _routes: list[str] = [] -_P = ParamSpec("_P") -_R = TypeVar("_R") - def register(app: base.QuartApp) -> tuple[ModuleType, list[str]]: import atr.get as get @@ -54,12 +51,9 @@ def typed[**P, R](func: Callable[Concatenate[web.Public, P], Awaitable[R]]) -> w def typed(func: Callable[..., Any]) -> web.RouteFunction[Any]: """Decorator that derives the URL path from the function's type annotations. - - Arguments after session are joined with / to make the web path - Literal["..."] parameters become literal path segments - - safe.ProjectName / safe.VersionName parameters are validated via cache/DB + - safe.SafeType subclass parameters are validated from the URL path - int, float use Quart's built-in type converters - - str parameters pass through as-is - - check_access is called automatically for committer routes with project_key """ path, validated_params, literal_params, _, public = common.build_path(func) @@ -80,10 +74,7 @@ def typed(func: Callable[..., Any]) -> web.RouteFunction[Any]: return response - endpoint = func.__module__.replace(".", "_") + "_" + func.__name__ - wrapper.__name__ = func.__name__ - wrapper.__doc__ = func.__doc__ - wrapper.__annotations__["endpoint"] = _BLUEPRINT_NAME + "." + endpoint + endpoint = common.setup_wrapper(wrapper, func, _BLUEPRINT_NAME) decorated = wrapper if public else auth.require(auth.Requirements.committer)(wrapper) _BLUEPRINT.add_url_rule(path, endpoint=endpoint, view_func=decorated, methods=["GET"]) diff --git a/atr/blueprints/post.py b/atr/blueprints/post.py index 488e2ed3..ca9daeaf 100644 --- a/atr/blueprints/post.py +++ b/atr/blueprints/post.py @@ -15,11 +15,10 @@ # specific language governing permissions and limitations # under the License. -import json import time from collections.abc import Awaitable, Callable from types import ModuleType -from typing import Any, Concatenate, overload +from typing import Any, Concatenate, Final, overload import asfquart.auth as auth import asfquart.base as base @@ -32,8 +31,8 @@ import atr.log as log import atr.models.safe as safe import atr.web as web -_BLUEPRINT_NAME = "post_blueprint" -_BLUEPRINT = quart.Blueprint(_BLUEPRINT_NAME, __name__) +_BLUEPRINT_NAME: Final = "post_blueprint" +_BLUEPRINT: Final = quart.Blueprint(_BLUEPRINT_NAME, __name__) _routes: list[str] = [] @@ -55,11 +54,9 @@ def typed[**P, R](func: Callable[Concatenate[web.Public, P], Awaitable[R]]) -> w def typed(func: Callable[..., Any]) -> web.RouteFunction[Any]: """Decorator that derives the URL path from the function's type annotations. - - Arguments after session are joined with / to make the web path - Literal["..."] parameters become literal path segments - - safe.ProjectName / safe.VersionName parameters are validated via cache/DB + - safe.SafeType subclass parameters are validated from the URL path - int, float use Quart's built-in type converters - - str parameters pass through as-is - A single form.Form subclass parameter is validated from the request body and injected - check_access is called automatically for committer routes with project_key """ @@ -87,15 +84,7 @@ def typed(func: Callable[..., Any]) -> web.RouteFunction[Any]: form_data = await atr.form.quart_request() kwargs[form_param_name] = atr.form.validate(form_cls, form_data, context=context) except pydantic.ValidationError as e: - errors = e.errors() - if len(errors) == 0: - raise RuntimeError("Validation failed, but no errors were reported") - form_data_raw = await atr.form.quart_request() - flash_data = atr.form.flash_error_data(form_cls, errors, form_data_raw) - summary = atr.form.flash_error_summary(errors, flash_data) - await quart.flash(summary, category="error") - await quart.flash(json.dumps(flash_data), category="form-error-data") - return quart.redirect(quart.request.path) + return await common.flash_form_error(form_cls, e) if form_safe_params: await common.validate_safe_fields(kwargs[form_param_name], form_safe_params, kwargs) @@ -109,10 +98,7 @@ def typed(func: Callable[..., Any]) -> web.RouteFunction[Any]: return response - endpoint = func.__module__.replace(".", "_") + "_" + func.__name__ - wrapper.__name__ = func.__name__ - wrapper.__doc__ = func.__doc__ - wrapper.__annotations__["endpoint"] = _BLUEPRINT_NAME + "." + endpoint + endpoint = common.setup_wrapper(wrapper, func, _BLUEPRINT_NAME) decorated = wrapper if public else auth.require(auth.Requirements.committer)(wrapper) _BLUEPRINT.add_url_rule(path, endpoint=endpoint, view_func=decorated, methods=["POST"]) diff --git a/tests/unit/test_blueprints.py b/tests/unit/test_blueprints.py index af268bfe..28e79279 100644 --- a/tests/unit/test_blueprints.py +++ b/tests/unit/test_blueprints.py @@ -15,11 +15,19 @@ # specific language governing permissions and limitations # under the License. +import dataclasses +from typing import Literal + import asfquart import pytest import atr.blueprints as blueprints +import atr.blueprints.common as common +import atr.form as form +import atr.models.safe as safe +import atr.models.schema as schema import atr.util as util +import atr.web as web _TESTED_BLUEPRINTS = frozenset({"get_blueprint", "post_blueprint", "admin_blueprint"}) @@ -60,3 +68,235 @@ async def test_all_routes_support_url_construction(monkeypatch): if failures: raise AssertionError("Routes incompatible with as_url:\n" + "\n".join(failures)) + + +def test_build_path_literal_segment(): + async def route(_session: web.Committer, _page: Literal["dashboard"]) -> str: + return "" + + path, validated, literals, form_param, public = common.build_path(route) + assert path == "/dashboard" + assert literals == {"_page": "dashboard"} + assert validated == [] + assert form_param is None + assert public is False + + +def test_build_path_safe_type(): + async def route(_session: web.Committer, _project_key: safe.ProjectKey) -> str: + return "" + + path, validated, _, _, _ = common.build_path(route) + assert path == "/<_project_key>" + assert validated == [("_project_key", safe.ProjectKey)] + + +def test_build_path_combined_literal_and_safe(): + async def route( + _session: web.Committer, + _page: Literal["project"], + _project_key: safe.ProjectKey, + _sub: Literal["version"], + _version_key: safe.VersionKey, + ) -> str: + return "" + + path, validated, literals, _, _ = common.build_path(route) + assert path == "/project/<_project_key>/version/<_version_key>" + assert validated == [("_project_key", safe.ProjectKey), ("_version_key", safe.VersionKey)] + assert literals == {"_page": "project", "_sub": "version"} + + +def test_build_path_public_session(): + async def route(_session: web.Public, _page: Literal["home"]) -> str: + return "" + + _, _, _, _, public = common.build_path(route) + assert public is True + + +def test_build_path_int_converter(): + async def route(_session: web.Committer, _page: Literal["item"], _item_id: int) -> str: + return "" + + path, _, _, _, _ = common.build_path(route) + assert path == "/item/<int:_item_id>" + + +def test_build_path_form_param(): + class TestForm(form.Form): + name: str = "" + + async def route(_session: web.Committer, _page: Literal["submit"], _data: TestForm) -> str: + return "" + + path, _, _, form_param, _ = common.build_path(route) + assert path == "/submit" + assert form_param is not None + assert form_param[0] == "_data" + assert form_param[1] is TestForm + + +def test_build_path_rejects_duplicate_form(): + class FormA(form.Form): + name: str = "" + + class FormB(form.Form): + name: str = "" + + async def route(_session: web.Committer, _a: FormA, _b: FormB) -> str: + return "" + + with pytest.raises(TypeError, match="only one Form is allowed"): + common.build_path(route) + + +def test_build_path_rejects_session_not_first(): + async def route(_page: Literal["home"], _session: web.Committer) -> str: + return "" + + with pytest.raises(TypeError, match="must be first"): + common.build_path(route) + + +def test_build_path_rejects_unannotated_param(): + async def route(_session: web.Committer, _thing) -> str: # type: ignore[reportUnknownParameterType] + return "" + + with pytest.raises(TypeError, match="no type annotation"): + common.build_path(route) + + +def test_build_path_rejects_bare_str(): + async def route(_session: web.Committer, _name: str) -> str: + return "" + + with pytest.raises(TypeError, match="unguarded str"): + common.build_path(route) + + +def test_build_api_path_literal_and_safe(): + async def route( + _session: web.Committer, + _page: Literal["project"], + _project_key: safe.ProjectKey, + ) -> str: + return "" + + path, validated, literals, body, form_param, query, optional = common.build_api_path(route) + assert path == "/project/<_project_key>" + assert validated == [("_project_key", safe.ProjectKey)] + assert literals == {"_page": "project"} + assert body is None + assert form_param is None + assert query is None + assert optional == [] + + +def test_build_api_path_body_param(): + class RequestBody(schema.Strict): + value: int + + async def route(_session: web.Committer, _page: Literal["submit"], _data: RequestBody) -> str: + return "" + + _, _, _, body, _, _, _ = common.build_api_path(route) + assert body is not None + assert body[0] == "_data" + assert body[1] is RequestBody + + +def test_build_api_path_query_param(): + @dataclasses.dataclass + class Filters: + page: int = 1 + search: str = "" + + async def route(_session: web.Committer, _page: Literal["list"], _filters: Filters) -> str: + return "" + + _, _, _, _, _, query, _ = common.build_api_path(route) + assert query is not None + assert query[0] == "_filters" + assert query[1] is Filters + + +def test_build_api_path_optional_param(): + async def route( + _session: web.Committer, + _page: Literal["items"], + _category: str | None = None, + ) -> str: + return "" + + path, _, _, _, _, _, optional = common.build_api_path(route) + assert path == "/items/<_category>" + assert optional == ["_category"] + + +def test_build_api_path_rejects_duplicate_body(): + class BodyA(schema.Lax): + x: int + + class BodyB(schema.Lax): + y: int + + async def route(_session: web.Committer, _a: BodyA, _b: BodyB) -> str: + return "" + + with pytest.raises(TypeError, match="only one body type is allowed"): + common.build_api_path(route) + + +def test_build_api_path_rejects_duplicate_query(): + @dataclasses.dataclass + class QueryA: + x: int = 0 + + @dataclasses.dataclass + class QueryB: + y: int = 0 + + async def route(_session: web.Committer, _a: QueryA, _b: QueryB) -> str: + return "" + + with pytest.raises(TypeError, match="only one query type is allowed"): + common.build_api_path(route) + + +def test_setup_wrapper_sets_metadata(): + async def index() -> str: + """Doc string.""" + return "" + + index.__module__ = "atr.get.dashboard" + + async def wrapper() -> str: + return "" + + endpoint = common.setup_wrapper(wrapper, index, "get_blueprint") + assert endpoint == "atr_get_dashboard_index" + assert wrapper.__name__ == "index" + assert wrapper.__doc__ == "Doc string." + assert wrapper.__annotations__["endpoint"] == "get_blueprint.atr_get_dashboard_index" + + +def test_safe_params_for_type_finds_safe_fields(): + class Body(schema.Strict): + project_key: safe.ProjectKey + version_key: safe.VersionKey + description: str + + result = common.safe_params_for_type(Body) + assert ("project_key", safe.ProjectKey) in result + assert ("version_key", safe.VersionKey) in result + assert len(result) == 2 + + +def test_safe_params_for_type_empty_for_plain_model(): + class Body(schema.Strict): + name: str + count: int + + result = common.safe_params_for_type(Body) + assert result == [] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
