This is an automated email from the ASF dual-hosted git repository. arm pushed a commit to branch taint_tracking_types in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git
commit 5b438db7fa0ec234010eab57a3cb09fa94f27ae9 Author: Alastair McFarlane <[email protected]> AuthorDate: Wed Feb 25 11:56:34 2026 +0000 First cut of taint tracking types for project and version --- atr/blueprints/get.py | 129 +++++++++++++++++++++++++++++++++++++++++++----- atr/blueprints/post.py | 129 +++++++++++++++++++++++++++++++++++++++++++----- atr/cache.py | 61 +++++++++++++++++++++++ atr/get/announce.py | 11 ++++- atr/get/checks.py | 6 ++- atr/get/compose.py | 11 ++++- atr/get/distribution.py | 10 ++-- atr/get/download.py | 4 +- atr/get/draft.py | 2 +- atr/get/file.py | 4 +- atr/get/finish.py | 2 +- atr/get/ignores.py | 2 +- atr/get/manual.py | 4 +- atr/get/release.py | 2 +- atr/get/report.py | 2 +- atr/get/result.py | 2 +- atr/get/revisions.py | 2 +- atr/get/sbom.py | 2 +- atr/get/test.py | 2 +- atr/get/upload.py | 2 +- atr/get/voting.py | 2 +- atr/server.py | 4 ++ atr/taint.py | 33 +++++++++++++ atr/validated.py | 25 ++++++++++ atr/web.py | 29 +++++++++++ 25 files changed, 428 insertions(+), 54 deletions(-) diff --git a/atr/blueprints/get.py b/atr/blueprints/get.py index f715defd..c1b1511b 100644 --- a/atr/blueprints/get.py +++ b/atr/blueprints/get.py @@ -25,8 +25,13 @@ import asfquart.base as base import asfquart.session import quart +import atr.cache as cache +import atr.db as db import atr.ldap as ldap import atr.log as log +import atr.models.sql as sql +import atr.taint as taint +import atr.validated as validated import atr.web as web _BLUEPRINT_NAME = "get_blueprint" @@ -34,17 +39,50 @@ _BLUEPRINT = quart.Blueprint(_BLUEPRINT_NAME, __name__) _routes: list[str] = [] +async def _authenticate() -> web.Committer: + web_session = await asfquart.session.read() + if web_session is None: + raise base.ASFQuartException("Not authenticated", errorcode=401) + if (web_session.uid is None) or (not await ldap.is_active(web_session.uid)): + asfquart.session.clear() + raise base.ASFQuartException("Account is disabled", errorcode=401) + return web.Committer(web_session) + + +def _register(func: Callable[..., Any]) -> None: + module_name = func.__module__.split(".")[-1] + _routes.append(f"get.{module_name}.{func.__name__}") + + +async def _validate_project(raw: str) -> validated.ProjectName: + if cache.project_version_has_project(raw): + return validated.ProjectName(raw) + async with db.session() as data: + project = await data.project(name=raw, status=sql.ProjectStatus.ACTIVE, _committee=False).get() + if project is None: + raise base.ASFQuartException(f"Project {raw!r} not found", errorcode=404) + return validated.ProjectName(project.name) + + +async def _validate_version(project_name: str, raw: str) -> validated.VersionName: + if cache.project_version_has_version(project_name, raw): + return validated.VersionName(raw) + async with db.session() as data: + release = await data.release( + project_name=project_name, + version=raw, + _project=False, + _committee=False, + ).get() + if release is None: + raise base.ASFQuartException(f"Version {raw!r} not found for project {project_name!r}", errorcode=404) + return validated.VersionName(release.version) + + def committer(path: str) -> Callable[[web.CommitterRouteFunction[Any]], web.RouteFunction[Any]]: def decorator(func: web.CommitterRouteFunction[Any]) -> web.RouteFunction[Any]: async def wrapper(*args: Any, **kwargs: Any) -> Any: - web_session = await asfquart.session.read() - if web_session is None: - raise base.ASFQuartException("Not authenticated", errorcode=401) - if (web_session.uid is None) or (not await ldap.is_active(web_session.uid)): - asfquart.session.clear() - raise base.ASFQuartException("Account is disabled", errorcode=401) - - enhanced_session = web.Committer(web_session) + enhanced_session = await _authenticate() start_time_ns = time.perf_counter_ns() response = await func(enhanced_session, *args, **kwargs) end_time_ns = time.perf_counter_ns() @@ -65,9 +103,76 @@ def committer(path: str) -> Callable[[web.CommitterRouteFunction[Any]], web.Rout decorated = auth.require(auth.Requirements.committer)(wrapper) _BLUEPRINT.add_url_rule(path, endpoint=endpoint, view_func=decorated, methods=["GET"]) + _register(func) + + return decorated + + return decorator + + +def committer_project(path: str) -> Callable[[web.CommitterProjectHandler[Any]], web.RouteFunction[Any]]: + def decorator(func: web.CommitterProjectHandler[Any]) -> web.RouteFunction[Any]: + async def wrapper(*_args: Any, **kwargs: Any) -> Any: + enhanced_session = await _authenticate() + project_name = await _validate_project(kwargs.pop("project_name")) + unsafe_kwargs: dict[str, taint.UnsafeStr] = {k: taint.UnsafeStr(v) for k, v in kwargs.items()} - module_name = func.__module__.split(".")[-1] - _routes.append(f"get.{module_name}.{func.__name__}") + start_time_ns = time.perf_counter_ns() + response = await func(enhanced_session, project_name=project_name, **unsafe_kwargs) + end_time_ns = time.perf_counter_ns() + total_ns = end_time_ns - start_time_ns + total_ms = total_ns // 1_000_000 + + log.performance( + f"GET {path} {func.__name__} = 0 0 {total_ms}", + ) + + return response + + endpoint = func.__module__.replace(".", "_") + "_" + func.__name__ + wrapper.__name__ = func.__name__ + wrapper.__doc__ = func.__doc__ + wrapper.__annotations__["endpoint"] = _BLUEPRINT_NAME + "." + endpoint + + decorated = auth.require(auth.Requirements.committer)(wrapper) + _BLUEPRINT.add_url_rule(path, endpoint=endpoint, view_func=decorated, methods=["GET"]) + _register(func) + + return decorated + + return decorator + + +def committer_project_version(path: str) -> Callable[[web.CommitterProjectVersionHandler[Any]], web.RouteFunction[Any]]: + def decorator(func: web.CommitterProjectVersionHandler[Any]) -> web.RouteFunction[Any]: + async def wrapper(*_args: Any, **kwargs: Any) -> Any: + enhanced_session = await _authenticate() + project_name = await _validate_project(kwargs.pop("project_name")) + version_name = await _validate_version(project_name, kwargs.pop("version_name")) + unsafe_kwargs: dict[str, taint.UnsafeStr] = {k: taint.UnsafeStr(v) for k, v in kwargs.items()} + + start_time_ns = time.perf_counter_ns() + response = await func( + enhanced_session, project_name=project_name, version_name=version_name, **unsafe_kwargs + ) + end_time_ns = time.perf_counter_ns() + total_ns = end_time_ns - start_time_ns + total_ms = total_ns // 1_000_000 + + log.performance( + f"GET {path} {func.__name__} = 0 0 {total_ms}", + ) + + return response + + endpoint = func.__module__.replace(".", "_") + "_" + func.__name__ + wrapper.__name__ = func.__name__ + wrapper.__doc__ = func.__doc__ + wrapper.__annotations__["endpoint"] = _BLUEPRINT_NAME + "." + endpoint + + decorated = auth.require(auth.Requirements.committer)(wrapper) + _BLUEPRINT.add_url_rule(path, endpoint=endpoint, view_func=decorated, methods=["GET"]) + _register(func) return decorated @@ -87,9 +192,7 @@ def public(path: str) -> Callable[[Callable[..., Awaitable[Any]]], web.RouteFunc wrapper.__annotations__["endpoint"] = _BLUEPRINT_NAME + "." + endpoint _BLUEPRINT.add_url_rule(path, endpoint=endpoint, view_func=wrapper, methods=["GET"]) - - module_name = func.__module__.split(".")[-1] - _routes.append(f"get.{module_name}.{func.__name__}") + _register(func) return wrapper diff --git a/atr/blueprints/post.py b/atr/blueprints/post.py index 8c9133a0..052136a9 100644 --- a/atr/blueprints/post.py +++ b/atr/blueprints/post.py @@ -27,9 +27,14 @@ import asfquart.session import pydantic import quart +import atr.cache as cache +import atr.db as db import atr.form import atr.ldap as ldap import atr.log as log +import atr.models.sql as sql +import atr.taint as taint +import atr.validated as validated import atr.web as web _BLUEPRINT_NAME = "post_blueprint" @@ -37,17 +42,50 @@ _BLUEPRINT = quart.Blueprint(_BLUEPRINT_NAME, __name__) _routes: list[str] = [] +async def _authenticate() -> web.Committer: + web_session = await asfquart.session.read() + if web_session is None: + raise base.ASFQuartException("Not authenticated", errorcode=401) + if (web_session.uid is None) or (not await ldap.is_active(web_session.uid)): + asfquart.session.clear() + raise base.ASFQuartException("Account is disabled", errorcode=401) + return web.Committer(web_session) + + +def _register(func: Callable[..., Any]) -> None: + module_name = func.__module__.split(".")[-1] + _routes.append(f"post.{module_name}.{func.__name__}") + + +async def _validate_project(raw: str) -> validated.ProjectName: + if cache.project_version_has_project(raw): + return validated.ProjectName(raw) + async with db.session() as data: + project = await data.project(name=raw, status=sql.ProjectStatus.ACTIVE, _committee=False).get() + if project is None: + raise base.ASFQuartException(f"Project {raw!r} not found", errorcode=404) + return validated.ProjectName(project.name) + + +async def _validate_version(project_name: str, raw: str) -> validated.VersionName: + if cache.project_version_has_version(project_name, raw): + return validated.VersionName(raw) + async with db.session() as data: + release = await data.release( + project_name=project_name, + version=raw, + _project=False, + _committee=False, + ).get() + if release is None: + raise base.ASFQuartException(f"Version {raw!r} not found for project {project_name!r}", errorcode=404) + return validated.VersionName(release.version) + + def committer(path: str) -> Callable[[web.CommitterRouteFunction[Any]], web.RouteFunction[Any]]: def decorator(func: web.CommitterRouteFunction[Any]) -> web.RouteFunction[Any]: async def wrapper(*args: Any, **kwargs: Any) -> Any: - web_session = await asfquart.session.read() - if web_session is None: - raise base.ASFQuartException("Not authenticated", errorcode=401) - if (web_session.uid is None) or (not await ldap.is_active(web_session.uid)): - asfquart.session.clear() - raise base.ASFQuartException("Account is disabled", errorcode=401) - - enhanced_session = web.Committer(web_session) + enhanced_session = await _authenticate() start_time_ns = time.perf_counter_ns() response = await func(enhanced_session, *args, **kwargs) end_time_ns = time.perf_counter_ns() @@ -68,9 +106,76 @@ def committer(path: str) -> Callable[[web.CommitterRouteFunction[Any]], web.Rout decorated = auth.require(auth.Requirements.committer)(wrapper) _BLUEPRINT.add_url_rule(path, endpoint=endpoint, view_func=decorated, methods=["POST"]) + _register(func) + + return decorated + + return decorator + + +def committer_project(path: str) -> Callable[[web.CommitterProjectHandler[Any]], web.RouteFunction[Any]]: + def decorator(func: web.CommitterProjectHandler[Any]) -> web.RouteFunction[Any]: + async def wrapper(*_args: Any, **kwargs: Any) -> Any: + enhanced_session = await _authenticate() + project_name = await _validate_project(kwargs.pop("project_name")) + unsafe_kwargs: dict[str, taint.UnsafeStr] = {k: taint.UnsafeStr(v) for k, v in kwargs.items()} - module_name = func.__module__.split(".")[-1] - _routes.append(f"post.{module_name}.{func.__name__}") + start_time_ns = time.perf_counter_ns() + response = await func(enhanced_session, project_name=project_name, **unsafe_kwargs) + end_time_ns = time.perf_counter_ns() + total_ns = end_time_ns - start_time_ns + total_ms = total_ns // 1_000_000 + + log.performance( + f"POST {path} {func.__name__} = 0 0 {total_ms}", + ) + + return response + + endpoint = func.__module__.replace(".", "_") + "_" + func.__name__ + wrapper.__name__ = func.__name__ + wrapper.__doc__ = func.__doc__ + wrapper.__annotations__["endpoint"] = _BLUEPRINT_NAME + "." + endpoint + + decorated = auth.require(auth.Requirements.committer)(wrapper) + _BLUEPRINT.add_url_rule(path, endpoint=endpoint, view_func=decorated, methods=["POST"]) + _register(func) + + return decorated + + return decorator + + +def committer_project_version(path: str) -> Callable[[web.CommitterProjectVersionHandler[Any]], web.RouteFunction[Any]]: + def decorator(func: web.CommitterProjectVersionHandler[Any]) -> web.RouteFunction[Any]: + async def wrapper(*_args: Any, **kwargs: Any) -> Any: + enhanced_session = await _authenticate() + project_name = await _validate_project(kwargs.pop("project_name")) + version_name = await _validate_version(project_name, kwargs.pop("version_name")) + unsafe_kwargs: dict[str, taint.UnsafeStr] = {k: taint.UnsafeStr(v) for k, v in kwargs.items()} + + start_time_ns = time.perf_counter_ns() + response = await func( + enhanced_session, project_name=project_name, version_name=version_name, **unsafe_kwargs + ) + end_time_ns = time.perf_counter_ns() + total_ns = end_time_ns - start_time_ns + total_ms = total_ns // 1_000_000 + + log.performance( + f"POST {path} {func.__name__} = 0 0 {total_ms}", + ) + + return response + + endpoint = func.__module__.replace(".", "_") + "_" + func.__name__ + wrapper.__name__ = func.__name__ + wrapper.__doc__ = func.__doc__ + wrapper.__annotations__["endpoint"] = _BLUEPRINT_NAME + "." + endpoint + + decorated = auth.require(auth.Requirements.committer)(wrapper) + _BLUEPRINT.add_url_rule(path, endpoint=endpoint, view_func=decorated, methods=["POST"]) + _register(func) return decorated @@ -187,9 +292,7 @@ def public(path: str) -> Callable[[Callable[..., Awaitable[Any]]], web.RouteFunc wrapper.__name__ = func.__name__ _BLUEPRINT.add_url_rule(path, endpoint=endpoint, view_func=wrapper, methods=["POST"]) - - module_name = func.__module__.split(".")[-1] - _routes.append(f"post.{module_name}.{func.__name__}") + _register(func) return wrapper diff --git a/atr/cache.py b/atr/cache.py index f652a02d..b6be4a85 100644 --- a/atr/cache.py +++ b/atr/cache.py @@ -32,6 +32,8 @@ import atr.models.schema as schema # Fifth prime after 3600 ADMINS_POLL_INTERVAL_SECONDS: Final[int] = 3631 +PROJECT_VERSION_POLL_INTERVAL_SECONDS: Final[int] = 307 + class AdminsCache(schema.Strict): refreshed: datetime.datetime = schema.description("When the cache was last refreshed") @@ -92,6 +94,39 @@ async def admins_startup_load() -> None: log.warning(f"Failed to fetch admin users from LDAP at startup: {e}") +def project_version_get() -> dict[str, set[str]]: + if asfquart.APP is not None: + return asfquart.APP.extensions.get("project_versions", {}) + return {} + + +def project_version_has_project(project_name: str) -> bool: + return project_name in project_version_get() + + +def project_version_has_version(project_name: str, version_name: str) -> bool: + projects = project_version_get() + if project_name not in projects: + return False + return version_name in projects[project_name] + + +async def project_version_refresh_loop() -> None: + while True: + await asyncio.sleep(PROJECT_VERSION_POLL_INTERVAL_SECONDS) + try: + await _project_version_refresh() + except Exception as e: + log.warning(f"Project/version cache refresh failed: {e}") + + +async def project_version_startup_load() -> None: + try: + await _project_version_refresh() + except Exception as e: + log.warning(f"Failed to populate project/version cache at startup: {e}") + + def _admins_path() -> pathlib.Path: return pathlib.Path(config.get().STATE_DIR) / "cache" / "admins.json" @@ -134,3 +169,29 @@ def _admins_update_app_extensions(admins: frozenset[str]) -> None: app = asfquart.APP app.extensions["admins"] = admins app.extensions["admins_refreshed"] = datetime.datetime.now(datetime.UTC) + + +async def _project_version_fetch_from_db() -> dict[str, set[str]]: + import atr.db as db + import atr.models.sql as sql + + projects: dict[str, set[str]] = {} + async with db.session() as data: + all_projects = await data.project(status=sql.ProjectStatus.ACTIVE, _committee=False).all() + for project in all_projects: + all_releases = await data.release(project_name=project.name, _project=False, _committee=False).all() + projects[project.name] = {release.version for release in all_releases} + return projects + + +async def _project_version_refresh() -> None: + projects = await _project_version_fetch_from_db() + _project_version_update_app_extensions(projects) + total_versions = sum(len(v) for v in projects.values()) + log.info(f"Project/version cache refreshed: {len(projects)} projects, {total_versions} versions") + + +def _project_version_update_app_extensions(projects: dict[str, set[str]]) -> None: + app = asfquart.APP + app.extensions["project_versions"] = projects + app.extensions["project_versions_refreshed"] = datetime.datetime.now(datetime.UTC) diff --git a/atr/get/announce.py b/atr/get/announce.py index 235c8463..ee5ea9d6 100644 --- a/atr/get/announce.py +++ b/atr/get/announce.py @@ -30,13 +30,20 @@ import atr.models.sql as sql import atr.post as post import atr.render as render import atr.shared as shared +import atr.taint as taint import atr.template as template import atr.util as util +import atr.validated as validated import atr.web as web [email protected]("/announce/<project_name>/<version_name>") -async def selected(session: web.Committer, project_name: str, version_name: str) -> str | web.WerkzeugResponse: [email protected]_project_version("/announce/<project_name>/<version_name>") +async def selected( + session: web.Committer, + project_name: str, + version_name: validated.VersionName, + **_kwargs: taint.UnsafeStr, +) -> str | web.WerkzeugResponse: """Allow the user to announce a release preview.""" await session.check_access(project_name) diff --git a/atr/get/checks.py b/atr/get/checks.py index 8faa5ce2..dfd8538b 100644 --- a/atr/get/checks.py +++ b/atr/get/checks.py @@ -40,6 +40,7 @@ import atr.render as render import atr.shared as shared import atr.shared.draft as draft import atr.storage as storage +import atr.taint as taint import atr.template as template import atr.util as util import atr.web as web @@ -134,12 +135,13 @@ async def selected(session: web.Committer | None, project_name: str, version_nam ) [email protected]("/checks/<project_name>/<version_name>/<revision_number>") [email protected]_project_version("/checks/<project_name>/<version_name>/<revision_number>") async def selected_revision( session: web.Committer, project_name: str, version_name: str, - revision_number: str, + revision_number: taint.UnsafeStr, + **_kwargs: taint.UnsafeStr, ) -> web.QuartResponse: """Return JSON with ongoing count and HTML fragments for dynamic updates.""" async with db.session() as data: diff --git a/atr/get/compose.py b/atr/get/compose.py index 6e400872..dcb51979 100644 --- a/atr/get/compose.py +++ b/atr/get/compose.py @@ -22,11 +22,18 @@ import atr.db as db import atr.mapping as mapping import atr.models.sql as sql import atr.shared as shared +import atr.taint as taint +import atr.validated as validated import atr.web as web [email protected]("/compose/<project_name>/<version_name>") -async def selected(session: web.Committer, project_name: str, version_name: str) -> web.WerkzeugResponse | str: [email protected]_project_version("/compose/<project_name>/<version_name>") +async def selected( + session: web.Committer, + project_name: validated.ProjectName, + version_name: validated.VersionName, + **_kwargs: taint.UnsafeStr, +) -> web.WerkzeugResponse | str: """Show the contents of the release candidate draft.""" await session.check_access(project_name) diff --git a/atr/get/distribution.py b/atr/get/distribution.py index 4bbf9225..deccad44 100644 --- a/atr/get/distribution.py +++ b/atr/get/distribution.py @@ -33,12 +33,12 @@ import atr.web as web from atr.tasks import gha [email protected]("/distribution/automate/<project>/<version>") [email protected]_project_version("/distribution/automate/<project>/<version>") async def automate(session: web.Committer, project: str, version: str) -> str: return await _automate_form_page(project, version, staging=False) [email protected]("/distributions/list/<project_name>/<version_name>") [email protected]_project_version("/distributions/list/<project_name>/<version_name>") async def list_get(session: web.Committer, project_name: str, version_name: str) -> str: distributions, tasks = await _get_page_data(project_name, version_name) @@ -122,17 +122,17 @@ async def list_get(session: web.Committer, project_name: str, version_name: str) return await template.blank(title, content=block.collect()) [email protected]("/distribution/record/<project>/<version>") [email protected]_project_version("/distribution/record/<project>/<version>") async def record(session: web.Committer, project: str, version: str) -> str: return await _record_form_page(project, version, staging=False) [email protected]("/distribution/stage/automate/<project>/<version>") [email protected]_project_version("/distribution/stage/automate/<project>/<version>") async def stage_automate(session: web.Committer, project: str, version: str) -> str: return await _automate_form_page(project, version, staging=True) [email protected]("/distribution/stage/record/<project>/<version>") [email protected]_project_version("/distribution/stage/record/<project>/<version>") async def stage_record(session: web.Committer, project: str, version: str) -> str: return await _record_form_page(project, version, staging=True) diff --git a/atr/get/download.py b/atr/get/download.py index bb8d5145..22a3b909 100644 --- a/atr/get/download.py +++ b/atr/get/download.py @@ -37,7 +37,7 @@ import atr.util as util import atr.web as web [email protected]("/download/all/<project_name>/<version_name>") [email protected]_project_version("/download/all/<project_name>/<version_name>") async def all_selected(session: web.Committer, project_name: str, version_name: str) -> web.WerkzeugResponse | str: """Display download commands for a release.""" import atr.get.root as root @@ -108,7 +108,7 @@ async def urls_selected(session: web.Committer | None, project_name: str, versio return web.TextResponse(f"Internal server error: {e}", status=500) [email protected]("/download/zip/<project_name>/<version_name>") [email protected]_project_version("/download/zip/<project_name>/<version_name>") async def zip_selected(session: web.Committer, project_name: str, version_name: str) -> web.Response: try: release = await session.release(project_name=project_name, version_name=version_name, phase=None) diff --git a/atr/get/draft.py b/atr/get/draft.py index 16a091ac..0777aa91 100644 --- a/atr/get/draft.py +++ b/atr/get/draft.py @@ -32,7 +32,7 @@ import atr.util as util import atr.web as web [email protected]("/draft/tools/<project_name>/<version_name>/<path:file_path>") [email protected]_project_version("/draft/tools/<project_name>/<version_name>/<path:file_path>") async def tools(session: web.Committer, project_name: str, version_name: str, file_path: str) -> str: """Show the tools for a specific file.""" await session.check_access(project_name) diff --git a/atr/get/file.py b/atr/get/file.py index 2c004b55..bcd427cb 100644 --- a/atr/get/file.py +++ b/atr/get/file.py @@ -33,7 +33,7 @@ import atr.web as web type Phase = Literal["COMPOSE", "VOTE", "FINISH"] [email protected]("/file/<project_name>/<version_name>") [email protected]_project_version("/file/<project_name>/<version_name>") async def selected(session: web.Committer, project_name: str, version_name: str) -> str: """View all the files in a release (any phase).""" await session.check_access(project_name) @@ -125,7 +125,7 @@ async def selected(session: web.Committer, project_name: str, version_name: str) return await template.blank(f"Files in {release.short_display_name}", content=block.collect()) [email protected]("/file/<project_name>/<version_name>/<path:file_path>") [email protected]_project_version("/file/<project_name>/<version_name>/<path:file_path>") async def selected_path(session: web.Committer, project_name: str, version_name: str, file_path: str) -> str: """View the content of a specific file in a release (any phase).""" await session.check_access(project_name) diff --git a/atr/get/finish.py b/atr/get/finish.py index ac95b2c1..4269ecdc 100644 --- a/atr/get/finish.py +++ b/atr/get/finish.py @@ -56,7 +56,7 @@ class RCTagAnalysisResult: total_paths: int [email protected]("/finish/<project_name>/<version_name>") [email protected]_project_version("/finish/<project_name>/<version_name>") async def selected( session: web.Committer, project_name: str, version_name: str ) -> tuple[web.QuartResponse, int] | web.WerkzeugResponse | str: diff --git a/atr/get/ignores.py b/atr/get/ignores.py index 5c741d37..063b0ed8 100644 --- a/atr/get/ignores.py +++ b/atr/get/ignores.py @@ -27,7 +27,7 @@ import atr.util as util import atr.web as web [email protected]("/ignores/<project_name>") [email protected]_project("/ignores/<project_name>") async def ignores(session: web.Committer, project_name: str) -> str | web.WerkzeugResponse: async with storage.read() as read: ragp = read.as_general_public() diff --git a/atr/get/manual.py b/atr/get/manual.py index ee88a1ea..df67fc20 100644 --- a/atr/get/manual.py +++ b/atr/get/manual.py @@ -31,7 +31,7 @@ import atr.util as util import atr.web as web [email protected]("/manual/resolve/<project_name>/<version_name>") [email protected]_project_version("/manual/resolve/<project_name>/<version_name>") async def resolve_selected(session: web.Committer, project_name: str, version_name: str) -> str: """Get the manual vote resolution page.""" await session.check_access(project_name) @@ -55,7 +55,7 @@ async def resolve_selected(session: web.Committer, project_name: str, version_na ) [email protected]("/manual/start/<project_name>/<version_name>/<revision>") [email protected]_project_version("/manual/start/<project_name>/<version_name>/<revision>") async def start_selected_revision( session: web.Committer, project_name: str, version_name: str, revision: str ) -> web.WerkzeugResponse | str: diff --git a/atr/get/release.py b/atr/get/release.py index 10a7061b..81a8efc0 100644 --- a/atr/get/release.py +++ b/atr/get/release.py @@ -77,7 +77,7 @@ async def releases(session: web.Committer | None) -> str: ) [email protected]("/release/select/<project_name>") [email protected]_project("/release/select/<project_name>") async def select(session: web.Committer, project_name: str) -> str: """Show releases in progress for a project.""" await session.check_access(project_name) diff --git a/atr/get/report.py b/atr/get/report.py index a3d194c9..6cba3443 100644 --- a/atr/get/report.py +++ b/atr/get/report.py @@ -30,7 +30,7 @@ import atr.util as util import atr.web as web [email protected]("/report/<project_name>/<version_name>/<path:rel_path>") [email protected]_project_version("/report/<project_name>/<version_name>/<path:rel_path>") async def selected_path(session: web.Committer, project_name: str, version_name: str, rel_path: str) -> str: """Show the report for a specific file.""" await session.check_access(project_name) diff --git a/atr/get/result.py b/atr/get/result.py index 5fccc29c..c651fdfe 100644 --- a/atr/get/result.py +++ b/atr/get/result.py @@ -25,7 +25,7 @@ import atr.models.sql as sql import atr.web as web [email protected]("/result/data/<project_name>/<version_name>/<int:check_id>") [email protected]_project_version("/result/data/<project_name>/<version_name>/<int:check_id>") async def data( session: web.Committer, project_name: str, diff --git a/atr/get/revisions.py b/atr/get/revisions.py index 680227a1..46114ac5 100644 --- a/atr/get/revisions.py +++ b/atr/get/revisions.py @@ -47,7 +47,7 @@ class FilesDiff(schema.Strict): modified: list[pathlib.Path] [email protected]("/revisions/<project_name>/<version_name>") [email protected]_project_version("/revisions/<project_name>/<version_name>") async def selected(session: web.Committer, project_name: str, version_name: str) -> str: """Show the revision history for a release candidate draft or release preview.""" await session.check_access(project_name) diff --git a/atr/get/sbom.py b/atr/get/sbom.py index 9a16a479..f4c35be2 100644 --- a/atr/get/sbom.py +++ b/atr/get/sbom.py @@ -43,7 +43,7 @@ if TYPE_CHECKING: from collections.abc import Sequence [email protected]("/sbom/report/<project>/<version>/<path:file_path>") [email protected]_project_version("/sbom/report/<project>/<version>/<path:file_path>") async def report(session: web.Committer, project: str, version: str, file_path: str) -> str: await session.check_access(project) diff --git a/atr/get/test.py b/atr/get/test.py index ad70cd64..7e20be3c 100644 --- a/atr/get/test.py +++ b/atr/get/test.py @@ -76,7 +76,7 @@ async def test_login(session: web.Committer | None) -> web.WerkzeugResponse: return await web.redirect(root.index) [email protected]("/test/merge/<project_name>/<version_name>") [email protected]_project_version("/test/merge/<project_name>/<version_name>") async def test_merge(session: web.Committer, project_name: str, version_name: str) -> web.WerkzeugResponse: if not config.get().ALLOW_TESTS: raise base.ASFQuartException("Test routes not enabled", errorcode=404) diff --git a/atr/get/upload.py b/atr/get/upload.py index 5e57258d..911acd07 100644 --- a/atr/get/upload.py +++ b/atr/get/upload.py @@ -35,7 +35,7 @@ import atr.util as util import atr.web as web [email protected]("/upload/<project_name>/<version_name>") [email protected]_project_version("/upload/<project_name>/<version_name>") async def selected(session: web.Committer, project_name: str, version_name: str) -> str: await session.check_access(project_name) diff --git a/atr/get/voting.py b/atr/get/voting.py index 665cff66..98205d2c 100644 --- a/atr/get/voting.py +++ b/atr/get/voting.py @@ -38,7 +38,7 @@ import atr.util as util import atr.web as web [email protected]("/voting/<project_name>/<version_name>/<revision>") [email protected]_project_version("/voting/<project_name>/<version_name>/<revision>") async def selected_revision( session: web.Committer, project_name: str, version_name: str, revision: str ) -> web.WerkzeugResponse | str: diff --git a/atr/server.py b/atr/server.py index 6dcee179..fe73716a 100644 --- a/atr/server.py +++ b/atr/server.py @@ -277,6 +277,10 @@ def _app_setup_lifecycle(app: base.QuartApp, app_config: type[config.AppConfig]) admins_task = asyncio.create_task(cache.admins_refresh_loop()) app.extensions["admins_task"] = admins_task + await cache.project_version_startup_load() + project_version_task = asyncio.create_task(cache.project_version_refresh_loop()) + app.extensions["project_version_task"] = project_version_task + worker_manager = manager.get_worker_manager() await worker_manager.start() diff --git a/atr/taint.py b/atr/taint.py new file mode 100644 index 00000000..f6461bd8 --- /dev/null +++ b/atr/taint.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +class UnsafeStr: + """A raw string from URL routing that has not been validated. + + This is NOT a subclass of str, so pyright will refuse to pass it to any + function that expects a str. The only way to get a str out is to validate + it explicitly. + """ + + __slots__ = ("_value",) + + def __init__(self, value: str) -> None: + self._value = value + + def __repr__(self) -> str: + return f"UnsafeStr({self._value!r})" diff --git a/atr/validated.py b/atr/validated.py new file mode 100644 index 00000000..a938ea73 --- /dev/null +++ b/atr/validated.py @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import NewType + +# Validated route parameter types. These are NewType subtypes of str, so they +# work anywhere str does (DB queries, path construction, f-strings) but pyright +# treats them as distinct nominal types. You cannot pass a ProjectName where a +# VersionName is expected, or vice versa. +ProjectName = NewType("ProjectName", str) +VersionName = NewType("VersionName", str) diff --git a/atr/web.py b/atr/web.py index 4d7d8c9c..b8585fff 100644 --- a/atr/web.py +++ b/atr/web.py @@ -42,6 +42,9 @@ if TYPE_CHECKING: import pydantic import werkzeug.wrappers.response as response + import atr.taint as taint + import atr.validated as validated + R = TypeVar("R", covariant=True) type WerkzeugResponse = response.Response @@ -49,6 +52,32 @@ type QuartResponse = quart.Response type Response = WerkzeugResponse | QuartResponse +class CommitterProjectHandler(Protocol[R]): + """Protocol for @committer_project decorated functions.""" + + __name__: str + __doc__: str | None + + def __call__( + self, session: Committer, project_name: validated.ProjectName, **kwargs: taint.UnsafeStr + ) -> Awaitable[R]: ... + + +class CommitterProjectVersionHandler(Protocol[R]): + """Protocol for @committer_project_version decorated functions.""" + + __name__: str + __doc__: str | None + + def __call__( + self, + session: Committer, + project_name: validated.ProjectName, + version_name: validated.VersionName, + **kwargs: taint.UnsafeStr, + ) -> Awaitable[R]: ... + + class CommitterRouteFunction(Protocol[R]): """Protocol for @committer_get decorated functions.""" --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
