This is an automated email from the ASF dual-hosted git repository.
arm pushed a commit to branch arm
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git
The following commit(s) were added to refs/heads/arm by this push:
     new 06076250 Some additional tidy-up of typed decorators, and some additional tests for them
06076250 is described below
commit 060762506c7647aacb24075963d9644226f7c06c
Author: Alastair McFarlane <[email protected]>
AuthorDate: Thu Mar 19 15:29:26 2026 +0000
    Some additional tidy-up of typed decorators, and some additional tests for them
---
atr/blueprints/admin.py | 100 +-----------------
atr/blueprints/api.py | 4 +-
atr/blueprints/common.py | 24 +++++
atr/blueprints/get.py | 5 +-
atr/blueprints/post.py | 16 +--
tests/unit/test_blueprints.py | 240 ++++++++++++++++++++++++++++++++++++++++++
6 files changed, 271 insertions(+), 118 deletions(-)
diff --git a/atr/blueprints/admin.py b/atr/blueprints/admin.py
index 5ef56003..ae49633e 100644
--- a/atr/blueprints/admin.py
+++ b/atr/blueprints/admin.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import json
import time
from collections.abc import Awaitable, Callable
from types import ModuleType
@@ -27,7 +26,6 @@ import quart
import quart_schema
import atr.blueprints.common as common
-import atr.form
import atr.log as log
import atr.user as user
import atr.web as web
@@ -37,93 +35,12 @@ _BLUEPRINT = quart.Blueprint(_BLUEPRINT_NAME, __name__,
url_prefix="/admin", tem
_routes: list[str] = []
-def empty() -> Callable[[Callable[..., Awaitable[Any]]], Callable[...,
Awaitable[Any]]]:
- def decorator(func: Callable[..., Awaitable[Any]]) -> Callable[...,
Awaitable[Any]]:
- async def wrapper(session: web.Committer, *args: Any, **kwargs: Any)
-> Any:
- form_data = await atr.form.quart_request()
- try:
- context = {
- "args": args,
- "kwargs": kwargs,
- "session": session,
- }
- atr.form.validate(atr.form.Empty, form_data, context=context)
- return await func(session, *args, **kwargs)
- except pydantic.ValidationError:
- msg = "Sorry, there was an empty form validation error. Please
try again."
- await quart.flash(msg, "error")
- return quart.redirect(quart.request.path)
-
- wrapper.__annotations__ = func.__annotations__.copy()
- wrapper.__doc__ = func.__doc__
- wrapper.__module__ = func.__module__
- wrapper.__name__ = func.__name__
- return wrapper
-
- return decorator
-
-
-def form(
- form_cls: Any,
-) -> Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]:
- def decorator(func: Callable[..., Awaitable[Any]]) -> Callable[...,
Awaitable[Any]]:
- async def wrapper(session: web.Committer, *args: Any, **kwargs: Any)
-> Any:
- form_data = await atr.form.quart_request()
- try:
- context = {
- "args": args,
- "kwargs": kwargs,
- "session": session,
- }
- validated_form = atr.form.validate(form_cls, form_data,
context=context)
- return await func(session, validated_form, *args, **kwargs)
- except pydantic.ValidationError as e:
- errors = e.errors()
- if len(errors) == 0:
- raise RuntimeError("Validation failed, but no errors were
reported")
- flash_data = atr.form.flash_error_data(form_cls, errors,
form_data)
- summary = atr.form.flash_error_summary(errors, flash_data)
-
- await quart.flash(summary, category="error")
- await quart.flash(json.dumps(flash_data),
category="form-error-data")
- return quart.redirect(quart.request.path)
-
- wrapper.__annotations__ = func.__annotations__.copy()
- wrapper.__doc__ = func.__doc__
- wrapper.__module__ = func.__module__
- wrapper.__name__ = func.__name__
- return wrapper
-
- return decorator
-
-
-def get(path: str) -> Callable[[web.CommitterRouteFunction[Any]],
web.RouteFunction[Any]]:
- def decorator(func: web.CommitterRouteFunction[Any]) ->
web.RouteFunction[Any]:
- async def wrapper(*args: Any, **kwargs: Any) -> Any:
- return await func(quart.g.session, *args, **kwargs)
-
- endpoint = func.__module__.replace(".", "_") + "_" + func.__name__
- wrapper.__name__ = func.__name__
- wrapper.__doc__ = func.__doc__
- wrapper.__annotations__["endpoint"] = _BLUEPRINT_NAME + "." + endpoint
-
- _BLUEPRINT.add_url_rule(path, endpoint=endpoint, view_func=wrapper,
methods=["GET"])
-
- return wrapper
-
- return decorator
-
-
def post(path: str) -> Callable[[web.CommitterRouteFunction[Any]],
web.RouteFunction[Any]]:
def decorator(func: web.CommitterRouteFunction[Any]) ->
web.RouteFunction[Any]:
async def wrapper(*args: Any, **kwargs: Any) -> Any:
return await func(quart.g.session, *args, **kwargs)
- endpoint = func.__module__.replace(".", "_") + "_" + func.__name__
- wrapper.__name__ = func.__name__
- wrapper.__doc__ = func.__doc__
- wrapper.__annotations__["endpoint"] = _BLUEPRINT_NAME + "." + endpoint
-
+ endpoint = common.setup_wrapper(wrapper, func, _BLUEPRINT_NAME)
_BLUEPRINT.add_url_rule(path, endpoint=endpoint, view_func=wrapper,
methods=["POST"])
return wrapper
@@ -179,15 +96,7 @@ def typed(func: Callable[..., Any]) -> Callable[..., Any]:
try:
kwargs[form_param_name] = await
enhanced_session.form_validate(form_cls, context)
except pydantic.ValidationError as e:
- errors = e.errors()
- if len(errors) == 0:
- raise RuntimeError("Validation failed, but no errors were
reported")
- form_data_raw = await atr.form.quart_request()
- flash_data = atr.form.flash_error_data(form_cls, errors,
form_data_raw)
- summary = atr.form.flash_error_summary(errors, flash_data)
- await quart.flash(summary, category="error")
- await quart.flash(json.dumps(flash_data),
category="form-error-data")
- return quart.redirect(quart.request.path)
+ return await common.flash_form_error(form_cls, e)
if form_safe_params:
await common.validate_safe_fields(kwargs[form_param_name],
form_safe_params, kwargs)
@@ -202,10 +111,7 @@ def typed(func: Callable[..., Any]) -> Callable[..., Any]:
return response
- endpoint = func.__module__.replace(".", "_") + "_" + func.__name__
- wrapper.__name__ = func.__name__
- wrapper.__doc__ = func.__doc__
- wrapper.__annotations__["endpoint"] = _BLUEPRINT_NAME + "." + endpoint
+ endpoint = common.setup_wrapper(wrapper, func, _BLUEPRINT_NAME)
# Replace the original quart request decorators
if query_param is not None:
diff --git a/atr/blueprints/api.py b/atr/blueprints/api.py
index e335072b..6b5a3e28 100644
--- a/atr/blueprints/api.py
+++ b/atr/blueprints/api.py
@@ -84,9 +84,7 @@ def typed(func: Callable[..., Awaitable[Any]]) ->
web.RouteFunction[Any]:
return response
- endpoint = original.__module__.replace(".", "_") + "_" + original.__name__
- wrapper.__name__ = original.__name__
- wrapper.__doc__ = original.__doc__
+ endpoint = common.setup_wrapper(wrapper, original, _BLUEPRINT_NAME)
# Replace the original quart request decorators
if query_param is not None:
diff --git a/atr/blueprints/common.py b/atr/blueprints/common.py
index 8457a9ab..86e216fe 100644
--- a/atr/blueprints/common.py
+++ b/atr/blueprints/common.py
@@ -174,6 +174,15 @@ def build_api_path(
return path, validated_params, literal_params, unique.body, unique.form,
unique.query, optional_params
+def setup_wrapper(wrapper: Callable[..., Any], func: Callable[..., Any],
blueprint_name: str) -> str:
+ """Set standard metadata on a route wrapper and return the endpoint
name."""
+ endpoint = func.__module__.replace(".", "_") + "_" + func.__name__
+ wrapper.__name__ = func.__name__
+ wrapper.__doc__ = func.__doc__
+ wrapper.__annotations__["endpoint"] = blueprint_name + "." + endpoint
+ return endpoint
+
+
def register_route(func: Callable[..., Any], prefix: str, routes: list[str])
-> None:
module_name = func.__module__.split(".")[-1]
routes.append(f"{prefix}.{module_name}.{func.__name__}")
@@ -232,6 +241,21 @@ async def validate_safe_fields(
setattr(instance, name, temp[name])
+async def flash_form_error(form_cls: type, error: pydantic.ValidationError) ->
Any:
+ """Flash form validation errors and return a redirect to the current
page."""
+ import json
+
+ errors = error.errors()
+ if len(errors) == 0:
+ raise RuntimeError("Validation failed, but no errors were reported")
+ form_data_raw = await form.quart_request()
+ flash_data = form.flash_error_data(form_cls, errors, form_data_raw)
+ summary = form.flash_error_summary(errors, flash_data)
+ await quart.flash(summary, category="error")
+ await quart.flash(json.dumps(flash_data), category="form-error-data")
+ return quart.redirect(quart.request.path)
+
+
async def parse_body(
body_param: tuple[str, type[pydantic.BaseModel]],
safe_params: list[tuple[str, type]],
diff --git a/atr/blueprints/get.py b/atr/blueprints/get.py
index 9c8033db..1b5d9873 100644
--- a/atr/blueprints/get.py
+++ b/atr/blueprints/get.py
@@ -80,10 +80,7 @@ def typed(func: Callable[..., Any]) ->
web.RouteFunction[Any]:
return response
- endpoint = func.__module__.replace(".", "_") + "_" + func.__name__
- wrapper.__name__ = func.__name__
- wrapper.__doc__ = func.__doc__
- wrapper.__annotations__["endpoint"] = _BLUEPRINT_NAME + "." + endpoint
+ endpoint = common.setup_wrapper(wrapper, func, _BLUEPRINT_NAME)
decorated = wrapper if public else
auth.require(auth.Requirements.committer)(wrapper)
_BLUEPRINT.add_url_rule(path, endpoint=endpoint, view_func=decorated,
methods=["GET"])
diff --git a/atr/blueprints/post.py b/atr/blueprints/post.py
index 488e2ed3..638ef385 100644
--- a/atr/blueprints/post.py
+++ b/atr/blueprints/post.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-import json
import time
from collections.abc import Awaitable, Callable
from types import ModuleType
@@ -87,15 +86,7 @@ def typed(func: Callable[..., Any]) ->
web.RouteFunction[Any]:
form_data = await atr.form.quart_request()
kwargs[form_param_name] = atr.form.validate(form_cls,
form_data, context=context)
except pydantic.ValidationError as e:
- errors = e.errors()
- if len(errors) == 0:
- raise RuntimeError("Validation failed, but no errors were
reported")
- form_data_raw = await atr.form.quart_request()
- flash_data = atr.form.flash_error_data(form_cls, errors,
form_data_raw)
- summary = atr.form.flash_error_summary(errors, flash_data)
- await quart.flash(summary, category="error")
- await quart.flash(json.dumps(flash_data),
category="form-error-data")
- return quart.redirect(quart.request.path)
+ return await common.flash_form_error(form_cls, e)
if form_safe_params:
await common.validate_safe_fields(kwargs[form_param_name],
form_safe_params, kwargs)
@@ -109,10 +100,7 @@ def typed(func: Callable[..., Any]) ->
web.RouteFunction[Any]:
return response
- endpoint = func.__module__.replace(".", "_") + "_" + func.__name__
- wrapper.__name__ = func.__name__
- wrapper.__doc__ = func.__doc__
- wrapper.__annotations__["endpoint"] = _BLUEPRINT_NAME + "." + endpoint
+ endpoint = common.setup_wrapper(wrapper, func, _BLUEPRINT_NAME)
decorated = wrapper if public else
auth.require(auth.Requirements.committer)(wrapper)
_BLUEPRINT.add_url_rule(path, endpoint=endpoint, view_func=decorated,
methods=["POST"])
diff --git a/tests/unit/test_blueprints.py b/tests/unit/test_blueprints.py
index af268bfe..59edb7c0 100644
--- a/tests/unit/test_blueprints.py
+++ b/tests/unit/test_blueprints.py
@@ -15,11 +15,19 @@
# specific language governing permissions and limitations
# under the License.
+import dataclasses
+from typing import Literal
+
import asfquart
import pytest
import atr.blueprints as blueprints
+import atr.blueprints.common as common
+import atr.form as form
+import atr.models.safe as safe
+import atr.models.schema as schema
import atr.util as util
+import atr.web as web
_TESTED_BLUEPRINTS = frozenset({"get_blueprint", "post_blueprint",
"admin_blueprint"})
@@ -60,3 +68,235 @@ async def
test_all_routes_support_url_construction(monkeypatch):
if failures:
raise AssertionError("Routes incompatible with as_url:\n" +
"\n".join(failures))
+
+
+def test_build_path_literal_segment():
+ async def route(_session: web.Committer, _page: Literal["dashboard"]) ->
str:
+ return ""
+
+ path, validated, literals, form_param, public = common.build_path(route)
+ assert path == "/dashboard"
+ assert literals == {"_page": "dashboard"}
+ assert validated == []
+ assert form_param is None
+ assert public is False
+
+
+def test_build_path_safe_type():
+ async def route(_session: web.Committer, _project_key: safe.ProjectKey) ->
str:
+ return ""
+
+ path, validated, _, _, _ = common.build_path(route)
+ assert path == "/<_project_key>"
+ assert validated == [("_project_key", safe.ProjectKey)]
+
+
+def test_build_path_combined_literal_and_safe():
+ async def route(
+ _session: web.Committer,
+ _page: Literal["project"],
+ _project_key: safe.ProjectKey,
+ _sub: Literal["version"],
+ _version_key: safe.VersionKey,
+ ) -> str:
+ return ""
+
+ path, validated, literals, _, _ = common.build_path(route)
+ assert path == "/project/<_project_key>/version/<_version_key>"
+ assert validated == [("_project_key", safe.ProjectKey), ("_version_key",
safe.VersionKey)]
+ assert literals == {"_page": "project", "_sub": "version"}
+
+
+def test_build_path_public_session():
+ async def route(_session: web.Public, _page: Literal["home"]) -> str:
+ return ""
+
+ _, _, _, _, public = common.build_path(route)
+ assert public is True
+
+
+def test_build_path_int_converter():
+ async def route(_session: web.Committer, _page: Literal["item"], _item_id:
int) -> str:
+ return ""
+
+ path, _, _, _, _ = common.build_path(route)
+ assert path == "/item/<int:_item_id>"
+
+
+def test_build_path_form_param():
+ class TestForm(form.Form):
+ name: str = ""
+
+ async def route(_session: web.Committer, _page: Literal["submit"], _data:
TestForm) -> str:
+ return ""
+
+ path, _, _, form_param, _ = common.build_path(route)
+ assert path == "/submit"
+ assert form_param is not None
+ assert form_param[0] == "_data"
+ assert form_param[1] is TestForm
+
+
+def test_build_path_rejects_duplicate_form():
+ class FormA(form.Form):
+ name: str = ""
+
+ class FormB(form.Form):
+ name: str = ""
+
+ async def route(_session: web.Committer, _a: FormA, _b: FormB) -> str:
+ return ""
+
+ with pytest.raises(TypeError, match="only one Form is allowed"):
+ common.build_path(route)
+
+
+def test_build_path_rejects_session_not_first():
+ async def route(_page: Literal["home"], _session: web.Committer) -> str:
+ return ""
+
+ with pytest.raises(TypeError, match="must be first"):
+ common.build_path(route)
+
+
+def test_build_path_rejects_unannotated_param():
+ async def route(_session: web.Committer, _thing) -> str: # type:
ignore[reportUnknownParameterType]
+ return ""
+
+ with pytest.raises(TypeError, match="no type annotation"):
+ common.build_path(route)
+
+
+def test_build_path_rejects_bare_str():
+ async def route(_session: web.Committer, _name: str) -> str:
+ return ""
+
+ with pytest.raises(TypeError, match="unguarded str"):
+ common.build_path(route)
+
+
+def test_build_api_path_literal_and_safe():
+ async def route(
+ _session: web.Committer,
+ _page: Literal["project"],
+ _project_key: safe.ProjectKey,
+ ) -> str:
+ return ""
+
+ path, validated, literals, body, form_param, query, optional =
common.build_api_path(route)
+ assert path == "/project/<_project_key>"
+ assert validated == [("_project_key", safe.ProjectKey)]
+ assert literals == {"_page": "project"}
+ assert body is None
+ assert form_param is None
+ assert query is None
+ assert optional == []
+
+
+def test_build_api_path_body_param():
+ class RequestBody(schema.Strict):
+ value: int
+
+ async def route(_session: web.Committer, _page: Literal["submit"], _data:
RequestBody) -> str:
+ return ""
+
+ _, _, _, body, _, _, _ = common.build_api_path(route)
+ assert body is not None
+ assert body[0] == "_data"
+ assert body[1] is RequestBody
+
+
+def test_build_api_path_query_param():
+ @dataclasses.dataclass
+ class Filters:
+ page: int = 1
+ search: str = ""
+
+ async def route(_session: web.Committer, _page: Literal["list"], _filters:
Filters) -> str:
+ return ""
+
+ _, _, _, _, _, query, _ = common.build_api_path(route)
+ assert query is not None
+ assert query[0] == "_filters"
+ assert query[1] is Filters
+
+
+def test_build_api_path_optional_param():
+ async def route(
+ _session: web.Committer,
+ _page: Literal["items"],
+ _category: str | None = None,
+ ) -> str:
+ return ""
+
+ path, _, _, _, _, _, optional = common.build_api_path(route)
+ assert path == "/items/<_category>"
+ assert optional == ["_category"]
+
+
+def test_build_api_path_rejects_duplicate_body():
+ class BodyA(schema.Lax):
+ x: int
+
+ class BodyB(schema.Lax):
+ y: int
+
+ async def route(_session: web.Committer, _a: BodyA, _b: BodyB) -> str:
+ return ""
+
+ with pytest.raises(TypeError, match="only one body type is allowed"):
+ common.build_api_path(route)
+
+
+def test_build_api_path_rejects_duplicate_query():
+ @dataclasses.dataclass
+ class QueryA:
+ x: int = 0
+
+ @dataclasses.dataclass
+ class QueryB:
+ y: int = 0
+
+ async def route(_session: web.Committer, _a: QueryA, _b: QueryB) -> str:
+ return ""
+
+ with pytest.raises(TypeError, match="only one query type is allowed"):
+ common.build_api_path(route)
+
+
+def test_setup_wrapper_sets_metadata():
+ async def original() -> str:
+ """Doc string."""
+ return ""
+
+ original.__module__ = "atr.get.dashboard"
+
+ async def wrapper() -> str:
+ return ""
+
+ endpoint = common.setup_wrapper(wrapper, original, "get_blueprint")
+ assert endpoint == "atr_get_dashboard_original"
+ assert wrapper.__name__ == "original"
+ assert wrapper.__doc__ == "Doc string."
+ assert wrapper.__annotations__["endpoint"] ==
"get_blueprint.atr_get_dashboard_original"
+
+
+def test_safe_params_for_type_finds_safe_fields():
+ class Body(schema.Strict):
+ project_key: safe.ProjectKey
+ version_key: safe.VersionKey
+ description: str
+
+ result = common.safe_params_for_type(Body)
+ assert ("project_key", safe.ProjectKey) in result
+ assert ("version_key", safe.VersionKey) in result
+ assert len(result) == 2
+
+
+def test_safe_params_for_type_empty_for_plain_model():
+ class Body(schema.Strict):
+ name: str
+ count: int
+
+ result = common.safe_params_for_type(Body)
+ assert result == []
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]