This is an automated email from the ASF dual-hosted git repository.

sbp pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-release.git


The following commit(s) were added to refs/heads/main by this push:
     new 69375f7  Add decorators for database sessions
69375f7 is described below

commit 69375f7c8c49f96ef27878211fd6d89b9b745537
Author: Sean B. Palmer <[email protected]>
AuthorDate: Thu Jul 3 18:45:47 2025 +0100

    Add decorators for database sessions
---
 atr/db/__init__.py   | 46 ++++++++++++++++++++++++++++++++++++++++++++--
 atr/routes/tokens.py | 36 ++++++++++++++++++------------------
 2 files changed, 62 insertions(+), 20 deletions(-)

diff --git a/atr/db/__init__.py b/atr/db/__init__.py
index 0d6ad66..2b7f9ee 100644
--- a/atr/db/__init__.py
+++ b/atr/db/__init__.py
@@ -18,9 +18,11 @@
 from __future__ import annotations
 
 import contextlib
+import functools
 import logging
 import os
-from typing import TYPE_CHECKING, Any, Final, TypeGuard, TypeVar
+from collections.abc import Awaitable, Callable
+from typing import TYPE_CHECKING, Any, Concatenate, Final, TypeGuard, TypeVar
 
 import alembic.command as command
 import alembic.config as alembic_config
@@ -39,7 +41,7 @@ import atr.util as util
 
 if TYPE_CHECKING:
     import datetime
-    from collections.abc import Iterator, Sequence
+    from collections.abc import Awaitable, Callable, Iterator, Sequence
 
     import asfquart.base as base
 
@@ -51,6 +53,7 @@ _global_atr_sessionmaker: sqlalchemy.ext.asyncio.async_sessionmaker | None = Non
 
 
 T = TypeVar("T")
+R = TypeVar("R")
 
 
 class NotSet:
@@ -375,6 +378,22 @@ class Session(sqlalchemy.ext.asyncio.AsyncSession):
 
         return Query(self, query)
 
+    async def query_all(self, stmt: sql.Select[Any]) -> list[Any]:
+        result = await self.execute(stmt)
+        return list(result.scalars().all())
+
+    async def query_first(self, stmt: sql.Select[Any]) -> Any | None:
+        result = await self.execute(stmt)
+        return result.scalars().first()
+
+    async def query_one(self, stmt: sql.Select[Any]) -> Any:
+        result = await self.execute(stmt)
+        return result.scalars().one()
+
+    async def query_one_or_none(self, stmt: sql.Select[Any]) -> Any | None:
+        result = await self.execute(stmt)
+        return result.scalars().one_or_none()
+
     def release(
         self,
         name: Opt[str] = NOT_SET,
@@ -800,6 +819,29 @@ def session(log_queries: bool | None = None) -> Session:
     return session_instance
 
 
+def session_commit_function[**P, R](
+    func: Callable[Concatenate[Session, P], Awaitable[R]],
+) -> Callable[P, Awaitable[R]]:
+    @functools.wraps(func)
+    async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+        async with session() as data:
+            async with data.begin():
+                return await func(data, *args, **kwargs)
+
+    return wrapper
+
+
+def session_function[**P, R](
+    func: Callable[Concatenate[Session, P], Awaitable[R]],
+) -> Callable[P, Awaitable[R]]:
+    @functools.wraps(func)
+    async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+        async with session() as data:
+            return await func(data, *args, **kwargs)
+
+    return wrapper
+
+
 async def shutdown_database() -> None:
     if _global_atr_engine:
         _LOGGER.info("Closing database")
diff --git a/atr/routes/tokens.py b/atr/routes/tokens.py
index d9f83f1..998aec4 100644
--- a/atr/routes/tokens.py
+++ b/atr/routes/tokens.py
@@ -186,27 +186,27 @@ async def _create_token(uid: str, label: str | None) -> str:
     return plaintext
 
 
-async def _delete_token(uid: str, token_id: int) -> None:
-    async with db.session() as data:
-        async with data.begin():
-            stmt = sqlmodel.select(models.PersonalAccessToken).where(
-                models.PersonalAccessToken.id == token_id,
-                models.PersonalAccessToken.asfuid == uid,
-            )
-            pat = (await data.execute(stmt)).scalar_one_or_none()
-            if pat:
-                await data.delete(pat)
+@db.session_commit_function
+async def _delete_token(data: db.Session, uid: str, token_id: int) -> None:
+    pat = await data.query_one_or_none(
+        sqlmodel.select(models.PersonalAccessToken).where(
+            models.PersonalAccessToken.id == token_id,
+            models.PersonalAccessToken.asfuid == uid,
+        )
+    )
+    if pat:
+        await data.delete(pat)
 
 
-async def _fetch_tokens(uid: str) -> list[models.PersonalAccessToken]:
+@db.session_function
+async def _fetch_tokens(data: db.Session, uid: str) -> list[models.PersonalAccessToken]:
     via = models.validate_instrumented_attribute
-    async with db.session() as data:
-        stmt = (
-            sqlmodel.select(models.PersonalAccessToken)
-            .where(models.PersonalAccessToken.asfuid == uid)
-            .order_by(via(models.PersonalAccessToken.created))
-        )
-        return list((await data.execute(stmt)).scalars())
+    stmt = (
+        sqlmodel.select(models.PersonalAccessToken)
+        .where(models.PersonalAccessToken.asfuid == uid)
+        .order_by(via(models.PersonalAccessToken.created))
+    )
+    return await data.query_all(stmt)
 
 
 async def _handle_post(


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to