This is an automated email from the ASF dual-hosted git repository.

tn pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-release.git


The following commit(s) were added to refs/heads/main by this push:
     new 3d6c3f8  further clean up database initialization, add shutdown, and 
reliably shut down the database for app and workers
3d6c3f8 is described below

commit 3d6c3f85c1b7c2c8c2d0dae20f78160ea51697af
Author: Thomas Neidhart <[email protected]>
AuthorDate: Tue Apr 1 15:41:39 2025 +0200

    further clean up database initialization, add shutdown, and reliably shut 
down the database for app and workers
---
 atr/db/__init__.py | 88 ++++++++++++++++++++++++++++--------------------------
 atr/manager.py     | 46 ++++++++++++++--------------
 atr/routes/dev.py  |  8 ++---
 atr/server.py      |  2 ++
 atr/worker.py      | 39 ++++++++++++++----------
 5 files changed, 98 insertions(+), 85 deletions(-)

diff --git a/atr/db/__init__.py b/atr/db/__init__.py
index 92ab0d4..010c0d7 100644
--- a/atr/db/__init__.py
+++ b/atr/db/__init__.py
@@ -21,7 +21,6 @@ import logging
 import os
 from typing import TYPE_CHECKING, Any, Final, Generic, TypeGuard, TypeVar
 
-import quart
 import sqlalchemy
 import sqlalchemy.ext.asyncio
 import sqlalchemy.orm as orm
@@ -41,9 +40,8 @@ if TYPE_CHECKING:
 
 _LOGGER: Final = logging.getLogger(__name__)
 
-_global_async_sessionmaker: sqlalchemy.ext.asyncio.async_sessionmaker | None = 
None
+_global_atr_engine: sqlalchemy.ext.asyncio.AsyncEngine | None = None
 _global_atr_sessionmaker: sqlalchemy.ext.asyncio.async_sessionmaker | None = 
None
-_global_sync_engine: sqlalchemy.Engine | None = None
 
 
 T = TypeVar("T")
@@ -481,25 +479,16 @@ def init_database(app: base.QuartApp) -> None:
 
     @app.before_serving
     async def create() -> None:
+        global _global_atr_engine, _global_atr_sessionmaker
+
         app_config = config.get()
-        engine = create_async_engine(app_config)
+        engine = await create_async_engine(app_config)
+        _global_atr_engine = engine
 
-        app.extensions["async_session"] = 
sqlalchemy.ext.asyncio.async_sessionmaker(
-            bind=engine, class_=sqlalchemy.ext.asyncio.AsyncSession, 
expire_on_commit=False
-        )
-        app.extensions["atr_db_session"] = 
sqlalchemy.ext.asyncio.async_sessionmaker(
+        _global_atr_sessionmaker = sqlalchemy.ext.asyncio.async_sessionmaker(
             bind=engine, class_=Session, expire_on_commit=False
         )
 
-        # Set SQLite pragmas for better performance
-        # Use 64 MB for the cache_size, and 5000ms for busy_timeout
-        async with engine.begin() as conn:
-            await conn.execute(sql.text("PRAGMA journal_mode=WAL"))
-            await conn.execute(sql.text("PRAGMA synchronous=NORMAL"))
-            await conn.execute(sql.text("PRAGMA cache_size=-64000"))
-            await conn.execute(sql.text("PRAGMA foreign_keys=ON"))
-            await conn.execute(sql.text("PRAGMA busy_timeout=5000"))
-
         # Run any pending migrations
         # In dev we'd do this first:
         # poetry run alembic revision --autogenerate -m "description"
@@ -519,17 +508,26 @@ def init_database(app: base.QuartApp) -> None:
             await conn.run_sync(sqlmodel.SQLModel.metadata.create_all)
 
 
-def init_database_for_worker() -> None:
-    global _global_async_sessionmaker
+async def init_database_for_worker() -> None:
+    global _global_atr_engine, _global_atr_sessionmaker
 
     _LOGGER.info(f"Creating database for worker {os.getpid()}")
-    engine = create_async_engine(config.get())
-    _global_async_sessionmaker = sqlalchemy.ext.asyncio.async_sessionmaker(
-        bind=engine, class_=sqlalchemy.ext.asyncio.AsyncSession, 
expire_on_commit=False
+    engine = await create_async_engine(config.get())
+    _global_atr_engine = engine
+    _global_atr_sessionmaker = sqlalchemy.ext.asyncio.async_sessionmaker(
+        bind=engine, class_=Session, expire_on_commit=False
     )
 
 
-def create_async_engine(app_config: type[config.AppConfig]) -> 
sqlalchemy.ext.asyncio.AsyncEngine:
+async def shutdown_database() -> None:
+    if _global_atr_engine:
+        _LOGGER.info("Closing database")
+        await _global_atr_engine.dispose()
+    else:
+        _LOGGER.info("No database to close")
+
+
+async def create_async_engine(app_config: type[config.AppConfig]) -> 
sqlalchemy.ext.asyncio.AsyncEngine:
     sqlite_url = f"sqlite+aiosqlite://{app_config.SQLITE_DB_PATH}"
     # Use aiosqlite for async SQLite access
     engine = sqlalchemy.ext.asyncio.create_async_engine(
@@ -540,18 +538,16 @@ def create_async_engine(app_config: 
type[config.AppConfig]) -> sqlalchemy.ext.as
         },
     )
 
-    return engine
-
+    # Set SQLite pragmas for better performance
+    # Use 64 MB for the cache_size, and 5000ms for busy_timeout
+    async with engine.begin() as conn:
+        await conn.execute(sql.text("PRAGMA journal_mode=WAL"))
+        await conn.execute(sql.text("PRAGMA synchronous=NORMAL"))
+        await conn.execute(sql.text("PRAGMA cache_size=-64000"))
+        await conn.execute(sql.text("PRAGMA foreign_keys=ON"))
+        await conn.execute(sql.text("PRAGMA busy_timeout=5000"))
 
-def create_async_db_session() -> sqlalchemy.ext.asyncio.AsyncSession:
-    """Create a new asynchronous database session."""
-    if quart.has_app_context():
-        extensions = quart.current_app.extensions
-        return util.validate_as_type(extensions["async_session"](), 
sqlalchemy.ext.asyncio.AsyncSession)
-    else:
-        if _global_async_sessionmaker is None:
-            raise RuntimeError("Global async_sessionmaker not initialized, run 
init_database() first.")
-        return util.validate_as_type(_global_async_sessionmaker(), 
sqlalchemy.ext.asyncio.AsyncSession)
+    return engine
 
 
 async def recent_tasks(data: Session, release_name: str, file_path: str, 
modified: int) -> dict[str, models.Task]:
@@ -598,17 +594,23 @@ def select_in_load_nested(parent: Any, *descendants: Any) 
-> orm.strategy_option
 
 def session() -> Session:
     """Create a new asynchronous database session."""
-    global _global_atr_sessionmaker
 
-    if quart.has_app_context():
-        extensions = quart.current_app.extensions
-        return util.validate_as_type(extensions["atr_db_session"](), Session)
+    # FIXME: occasionally you see this in the console output
+    # <sys>:0: SAWarning: The garbage collector is trying to clean up 
non-checked-in connection <AdaptedConnection
+    # <Connection(Thread-291, started daemon 138838634661440)>>, which will be 
dropped, as it cannot be safely
+    # terminated. Please ensure that SQLAlchemy pooled connections are 
returned to the pool explicitly, either by
+    # calling ``close()`` or by using appropriate context managers to manage 
their lifecycle.
+
+    # It is not fully clear where this comes from, but we could experiment by 
returning a session like this:
+    # async def session() -> AsyncGenerator[Session, None]:
+    #     async with _global_atr_sessionmaker() as session:
+    #         yield session
+
+    # from FastAPI documentation: 
https://fastapi-users.github.io/fastapi-users/latest/configuration/databases/sqlalchemy/
+
+    if _global_atr_sessionmaker is None:
+        raise RuntimeError("database not initialized")
     else:
-        if _global_atr_sessionmaker is None:
-            engine = create_async_engine(config.get())
-            _global_atr_sessionmaker = 
sqlalchemy.ext.asyncio.async_sessionmaker(
-                bind=engine, class_=Session, expire_on_commit=False
-            )
         return util.validate_as_type(_global_atr_sessionmaker(), Session)
 
 
diff --git a/atr/manager.py b/atr/manager.py
index c664e76..77e1973 100644
--- a/atr/manager.py
+++ b/atr/manager.py
@@ -206,18 +206,19 @@ class WorkerManager:
         """Check worker processes and restart if needed."""
         exited_workers = []
 
-        # Check each worker first
-        for pid, worker in list(self.workers.items()):
-            # Check if process is running
-            if not await worker.is_running():
-                exited_workers.append(pid)
-                _LOGGER.info(f"Worker {pid} has exited")
-                continue
-
-            # Check if worker has been processing its task for too long
-            # This also stops tasks if they have indeed been running for too 
long
-            if await self.check_task_duration(pid, worker):
-                exited_workers.append(pid)
+        async with db.session() as data:
+            # Check each worker first
+            for pid, worker in list(self.workers.items()):
+                # Check if process is running
+                if not await worker.is_running():
+                    exited_workers.append(pid)
+                    _LOGGER.info(f"Worker {pid} has exited")
+                    continue
+
+                # Check if worker has been processing its task for too long
+                # This also stops tasks if they have indeed been running for 
too long
+                if await self.check_task_duration(data, pid, worker):
+                    exited_workers.append(pid)
 
         # Remove exited workers
         for pid in exited_workers:
@@ -265,24 +266,23 @@ class WorkerManager:
         except Exception as e:
             _LOGGER.error(f"Error stopping long-running worker {pid}: {e}")
 
-    async def check_task_duration(self, pid: int, worker: WorkerProcess) -> 
bool:
+    async def check_task_duration(self, data: db.Session, pid: int, worker: 
WorkerProcess) -> bool:
         """
         Check if a worker has been processing its task for too long.
         Returns True if the worker has been terminated.
         """
         try:
-            async with db.session() as data:
-                async with data.begin():
-                    task = await data.task(pid=pid, 
status=models.TaskStatus.ACTIVE).get()
-                    if not task or not task.started:
-                        return False
+            async with data.begin():
+                task = await data.task(pid=pid, 
status=models.TaskStatus.ACTIVE).get()
+                if not task or not task.started:
+                    return False
 
-                    task_duration = (datetime.datetime.now(datetime.UTC) - 
task.started).total_seconds()
-                    if task_duration > self.max_task_seconds:
-                        await self.terminate_long_running_task(task, worker, 
task.id, pid)
-                        return True
+                task_duration = (datetime.datetime.now(datetime.UTC) - 
task.started).total_seconds()
+                if task_duration > self.max_task_seconds:
+                    await self.terminate_long_running_task(task, worker, 
task.id, pid)
+                    return True
 
-                    return False
+                return False
         except Exception as e:
             _LOGGER.error(f"Error checking task duration for worker {pid}: 
{e}")
             # TODO: Return False here to avoid over-reporting errors
diff --git a/atr/routes/dev.py b/atr/routes/dev.py
index 32b94af..423266f 100644
--- a/atr/routes/dev.py
+++ b/atr/routes/dev.py
@@ -54,16 +54,16 @@ async def send_email(session: routes.CommitterSession) -> 
quart.ResponseReturnVa
             )
 
         # Create a task for mail testing
-        async with db.create_async_db_session() as db_session:
-            async with db_session.begin():
+        async with db.session() as data:
+            async with data.begin():
                 task = models.Task(
                     status=models.TaskStatus.QUEUED,
                     task_type="mailtest_send",
                     task_args=[name, email, token],
                 )
-                db_session.add(task)
+                data.add(task)
                 # Flush to get the task ID
-                await db_session.flush()
+                await data.flush()
 
         return await quart.render_template(
             "dev-send-email.html",
diff --git a/atr/server.py b/atr/server.py
index e62d89b..be9b6d1 100644
--- a/atr/server.py
+++ b/atr/server.py
@@ -170,6 +170,8 @@ def app_setup_lifecycle(app: base.QuartApp) -> None:
         if ssh_server:
             await ssh.server_stop(ssh_server)
 
+        await db.shutdown_database()
+
         app.background_tasks.clear()
 
 
diff --git a/atr/worker.py b/atr/worker.py
index 04f6cbd..65791a4 100644
--- a/atr/worker.py
+++ b/atr/worker.py
@@ -29,7 +29,6 @@ import logging
 import os
 import resource
 import signal
-import sys
 import traceback
 from typing import TYPE_CHECKING, Any, Final
 
@@ -66,9 +65,6 @@ def main() -> None:
     """Main entry point."""
     import atr.config as config
 
-    signal.signal(signal.SIGTERM, _worker_signal_handle)
-    signal.signal(signal.SIGINT, _worker_signal_handle)
-
     conf = config.get()
     if os.path.isdir(conf.STATE_DIR):
         os.chdir(conf.STATE_DIR)
@@ -76,10 +72,32 @@ def main() -> None:
     _setup_logging()
 
     _LOGGER.info(f"Starting worker process with pid {os.getpid()}")
-    db.init_database_for_worker()
+
+    tasks: list[asyncio.Task] = []
+
+    async def _handle_signal(signum: int) -> None:
+        _LOGGER.info(f"Received signal {signum}, shutting down...")
+
+        await db.shutdown_database()
+
+        for t in tasks:
+            t.cancel()
+
+        _LOGGER.debug("Cancelled all running tasks")
+        asyncio.get_event_loop().stop()
+        _LOGGER.debug("Stopped event loop")
+
+    for s in (signal.SIGTERM, signal.SIGINT):
+        signal.signal(s, lambda signum, frame: 
asyncio.create_task(_handle_signal(signum)))
 
     _worker_resources_limit_set()
-    asyncio.run(_worker_loop_run())
+
+    async def _start() -> None:
+        await asyncio.create_task(db.init_database_for_worker())
+        tasks.append(asyncio.create_task(_worker_loop_run()))
+        await asyncio.gather(*tasks)
+
+    asyncio.run(_start())
 
 
 def _setup_logging() -> None:
@@ -284,17 +302,8 @@ def _worker_resources_limit_set() -> None:
         _LOGGER.warning(f"Could not set memory limit: {e}")
 
 
-def _worker_signal_handle(signum: int, frame: object) -> None:
-    """Handle termination signals gracefully."""
-    # For RLIMIT_AS we'll generally get a SIGKILL
-    # For RLIMIT_CPU we'll get a SIGXCPU, which we can catch
-    _LOGGER.info(f"Received signal {signum}, shutting down...")
-    sys.exit(0)
-
-
 if __name__ == "__main__":
     _LOGGER.info("Starting ATR worker...")
-    print("Starting ATR worker...")
     try:
         main()
     except Exception as e:


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to