This is an automated email from the ASF dual-hosted git repository.
sbp pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-release.git
The following commit(s) were added to refs/heads/main by this push:
new 69375f7 Add decorators for database sessions
69375f7 is described below
commit 69375f7c8c49f96ef27878211fd6d89b9b745537
Author: Sean B. Palmer <[email protected]>
AuthorDate: Thu Jul 3 18:45:47 2025 +0100
Add decorators for database sessions
---
atr/db/__init__.py | 46 ++++++++++++++++++++++++++++++++++++++++++++--
atr/routes/tokens.py | 36 ++++++++++++++++++------------------
2 files changed, 62 insertions(+), 20 deletions(-)
diff --git a/atr/db/__init__.py b/atr/db/__init__.py
index 0d6ad66..2b7f9ee 100644
--- a/atr/db/__init__.py
+++ b/atr/db/__init__.py
@@ -18,9 +18,11 @@
from __future__ import annotations
import contextlib
+import functools
import logging
import os
-from typing import TYPE_CHECKING, Any, Final, TypeGuard, TypeVar
+from collections.abc import Awaitable, Callable
+from typing import TYPE_CHECKING, Any, Concatenate, Final, TypeGuard, TypeVar
import alembic.command as command
import alembic.config as alembic_config
@@ -39,7 +41,7 @@ import atr.util as util
if TYPE_CHECKING:
import datetime
- from collections.abc import Iterator, Sequence
+ from collections.abc import Awaitable, Callable, Iterator, Sequence
import asfquart.base as base
@@ -51,6 +53,7 @@ _global_atr_sessionmaker: sqlalchemy.ext.asyncio.async_sessionmaker | None = Non
T = TypeVar("T")
class NotSet:
@@ -375,6 +378,22 @@ class Session(sqlalchemy.ext.asyncio.AsyncSession):
return Query(self, query)
+    async def query_all(self, stmt: sql.Select[Any]) -> list[Any]:
+        """Run ``stmt`` and collect every scalar value it yields into a new list."""
+        scalars = (await self.execute(stmt)).scalars()
+        return list(scalars.all())
+
+    async def query_first(self, stmt: sql.Select[Any]) -> Any | None:
+        """Run ``stmt`` and return the first scalar value via ``ScalarResult.first()``."""
+        scalars = (await self.execute(stmt)).scalars()
+        return scalars.first()
+
+    async def query_one(self, stmt: sql.Select[Any]) -> Any:
+        """Run ``stmt`` and return its single scalar value via ``ScalarResult.one()``.
+
+        Per SQLAlchemy, ``one()`` raises unless exactly one row is returned.
+        """
+        scalars = (await self.execute(stmt)).scalars()
+        return scalars.one()
+
+    async def query_one_or_none(self, stmt: sql.Select[Any]) -> Any | None:
+        """Run ``stmt`` and return ``ScalarResult.one_or_none()`` for its scalars.
+
+        Per SQLAlchemy, this yields ``None`` for no rows and raises for more than one.
+        """
+        scalars = (await self.execute(stmt)).scalars()
+        return scalars.one_or_none()
+
def release(
self,
name: Opt[str] = NOT_SET,
@@ -800,6 +819,29 @@ def session(log_queries: bool | None = None) -> Session:
return session_instance
+def session_commit_function[**P, R](
+    func: Callable[Concatenate[Session, P], Awaitable[R]],
+) -> Callable[P, Awaitable[R]]:
+    """Inject a fresh ``Session`` into ``func`` and run it inside ``data.begin()``.
+
+    ``func`` must accept a ``Session`` as its first positional argument. The
+    returned coroutine function omits that parameter, so callers supply only the
+    remaining arguments. Each call opens a new ``session()``, enters
+    ``data.begin()`` on it, and awaits ``func(data, *args, **kwargs)``. Under
+    SQLAlchemy's ``AsyncSession.begin()`` semantics, the transaction commits when
+    ``func`` returns normally and rolls back if it raises.
+    """
+
+    @functools.wraps(func)
+    async def run_in_transaction(*args: P.args, **kwargs: P.kwargs) -> R:
+        # Items of a multi-item ``async with`` are entered left to right, so this
+        # is equivalent to nesting ``data.begin()`` inside ``session()``.
+        async with session() as data, data.begin():
+            return await func(data, *args, **kwargs)
+
+    return run_in_transaction
+
+
+def session_function[**P, R](
+    func: Callable[Concatenate[Session, P], Awaitable[R]],
+) -> Callable[P, Awaitable[R]]:
+    """Inject a fresh ``Session`` into ``func`` without opening a transaction.
+
+    ``func`` must accept a ``Session`` as its first positional argument. The
+    returned coroutine function omits that parameter, so callers supply only the
+    remaining arguments. Each call opens a new ``session()`` as an async context
+    manager for the duration of ``func``. The wrapper itself never calls
+    ``begin()`` or ``commit()``; use ``session_commit_function`` for work that
+    must be committed.
+    """
+
+    @functools.wraps(func)
+    async def run_with_session(*args: P.args, **kwargs: P.kwargs) -> R:
+        async with session() as data:
+            return await func(data, *args, **kwargs)
+
+    return run_with_session
+
+
async def shutdown_database() -> None:
if _global_atr_engine:
_LOGGER.info("Closing database")
diff --git a/atr/routes/tokens.py b/atr/routes/tokens.py
index d9f83f1..998aec4 100644
--- a/atr/routes/tokens.py
+++ b/atr/routes/tokens.py
@@ -186,27 +186,27 @@ async def _create_token(uid: str, label: str | None) -> str:
return plaintext
-async def _delete_token(uid: str, token_id: int) -> None:
-    async with db.session() as data:
-        async with data.begin():
-            stmt = sqlmodel.select(models.PersonalAccessToken).where(
-                models.PersonalAccessToken.id == token_id,
-                models.PersonalAccessToken.asfuid == uid,
-            )
-            pat = (await data.execute(stmt)).scalar_one_or_none()
-            if pat:
-                await data.delete(pat)
+@db.session_commit_function
+async def _delete_token(data: db.Session, uid: str, token_id: int) -> None:
+    """Delete personal access token ``token_id`` if it is owned by ``uid``.
+
+    ``data`` is injected by ``db.session_commit_function``, so callers pass only
+    ``uid`` and ``token_id``. The deletion runs inside the transaction that the
+    decorator opens with ``data.begin()``. If no row matches, nothing is deleted
+    and no error is raised.
+    """
+    # Match on both id and owner so that only a token belonging to ``uid`` can be deleted.
+    pat = await data.query_one_or_none(
+        sqlmodel.select(models.PersonalAccessToken).where(
+            models.PersonalAccessToken.id == token_id,
+            models.PersonalAccessToken.asfuid == uid,
+        )
+    )
+    if pat is not None:
+        await data.delete(pat)
-async def _fetch_tokens(uid: str) -> list[models.PersonalAccessToken]:
+@db.session_function
+async def _fetch_tokens(data: db.Session, uid: str) -> list[models.PersonalAccessToken]:
+    """Return every personal access token whose ``asfuid`` equals ``uid``.
+
+    ``data`` is injected by ``db.session_function``, so callers pass only ``uid``.
+    Results are ordered by ``created`` in ascending order.
+    """
     via = models.validate_instrumented_attribute
-    async with db.session() as data:
-        stmt = (
-            sqlmodel.select(models.PersonalAccessToken)
-            .where(models.PersonalAccessToken.asfuid == uid)
-            .order_by(via(models.PersonalAccessToken.created))
-        )
-        return list((await data.execute(stmt)).scalars())
+    stmt = (
+        sqlmodel.select(models.PersonalAccessToken)
+        .where(models.PersonalAccessToken.asfuid == uid)
+        .order_by(via(models.PersonalAccessToken.created))
+    )
+    return await data.query_all(stmt)
async def _handle_post(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]