This is an automated email from the ASF dual-hosted git repository.

arm pushed a commit to branch arm
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git

commit 32f4ee3b79dded6f95559d1435dbc8adaf3766d2
Author: Alastair McFarlane <[email protected]>
AuthorDate: Tue Feb 17 15:13:34 2026 +0000

    Check for running tasks as well as completed checks when using cache keys
---
 atr/db/__init__.py                              |  7 +++
 atr/models/sql.py                               |  1 +
 atr/tasks/__init__.py                           | 59 +++++++++++++++----------
 migrations/versions/0050_2026.02.17_7406bb29.py | 29 ++++++++++++
 4 files changed, 72 insertions(+), 24 deletions(-)

diff --git a/atr/db/__init__.py b/atr/db/__init__.py
index 2c6d579e..fb282628 100644
--- a/atr/db/__init__.py
+++ b/atr/db/__init__.py
@@ -670,8 +670,10 @@ class Session(sqlalchemy.ext.asyncio.AsyncSession):
         self,
         id: Opt[int] = NOT_SET,
         status: Opt[sql.TaskStatus] = NOT_SET,
+        status_in: Opt[list[sql.TaskStatus]] = NOT_SET,
         task_type: Opt[str] = NOT_SET,
         task_args: Opt[Any] = NOT_SET,
+        inputs_hash: Opt[str] = NOT_SET,
         asf_uid: Opt[str] = NOT_SET,
         added: Opt[datetime.datetime] = NOT_SET,
         started: Opt[datetime.datetime | None] = NOT_SET,
@@ -687,14 +689,19 @@ class Session(sqlalchemy.ext.asyncio.AsyncSession):
     ) -> Query[sql.Task]:
         query = sqlmodel.select(sql.Task)
 
+        via = sql.validate_instrumented_attribute
         if is_defined(id):
             query = query.where(sql.Task.id == id)
         if is_defined(status):
             query = query.where(sql.Task.status == status)
+        if is_defined(status_in):
+            query = query.where(via(sql.Task.status).in_(status_in))
         if is_defined(task_type):
             query = query.where(sql.Task.task_type == task_type)
         if is_defined(task_args):
             query = query.where(sql.Task.task_args == task_args)
+        if is_defined(inputs_hash):
+            query = query.where(sql.Task.inputs_hash == inputs_hash)
         if is_defined(asf_uid):
             query = query.where(sql.Task.asf_uid == asf_uid)
         if is_defined(added):
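
For reference, a minimal sketch of how the two new filters compose, assuming a
Session instance named data in an async context (the hash value here is
illustrative only):

    pending = await data.task(
        inputs_hash="blake3:7f83b1657ff1fc...",
        status_in=[sql.TaskStatus.QUEUED, sql.TaskStatus.ACTIVE],
    ).all()
    if pending:
        # A task covering these inputs is already queued or running
        ...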
diff --git a/atr/models/sql.py b/atr/models/sql.py
index c03a3953..c1eda568 100644
--- a/atr/models/sql.py
+++ b/atr/models/sql.py
@@ -357,6 +357,7 @@ class Task(sqlmodel.SQLModel, table=True):
     status: TaskStatus = sqlmodel.Field(default=TaskStatus.QUEUED, index=True)
     task_type: TaskType
     task_args: Any = sqlmodel.Field(sa_column=sqlalchemy.Column(sqlalchemy.JSON))
+    inputs_hash: str | None = sqlmodel.Field(default=None, index=True, **example("blake3:7f83b1657ff1fc..."))
     asf_uid: str
     added: datetime.datetime = sqlmodel.Field(
         default_factory=lambda: datetime.datetime.now(datetime.UTC),
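
A sketch of how the new column might be populated when constructing a Task,
assuming hashes.compute_dict_hash as used later in this commit; the cache key
contents and surrounding variables are illustrative, and this hunk does not
itself show where inputs_hash is assigned:

    cache_key = {"task_type": "sbom_tool_score", "file_path": path_str}
    task = sql.Task(
        status=sql.TaskStatus.QUEUED,
        task_type=sql.TaskType.SBOM_TOOL_SCORE,
        task_args={"file_path": path_str},
        inputs_hash=hashes.compute_dict_hash(cache_key),
        asf_uid=asf_uid,
    )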
diff --git a/atr/tasks/__init__.py b/atr/tasks/__init__.py
index 00beb6d7..42e03406 100644
--- a/atr/tasks/__init__.py
+++ b/atr/tasks/__init__.py
@@ -17,7 +17,6 @@
 
 import asyncio
 import datetime
-import logging
 import pathlib
 from collections.abc import Awaitable, Callable, Coroutine
 from typing import Any, Final
@@ -220,24 +219,24 @@ async def _draft_file_checks(
     # TODO: Should we check .json files for their content?
     # Ideally we would not have to do that
     if path.name.endswith(".cdx.json"):
-        data.add(
-            await queued(
-                asf_uid,
-                sql.TaskType.SBOM_TOOL_SCORE,
-                release,
-                revision_number,
-                caller_data,
-                path_str,
-                extra_args={
-                    "project_name": project_name,
-                    "version_name": release_version,
-                    "revision_number": revision_number,
-                    "previous_release_version": previous_version.version if previous_version else None,
-                    "file_path": path_str,
-                    "asf_uid": asf_uid,
-                },
-            )
+        cdx_task = await queued(
+            asf_uid,
+            sql.TaskType.SBOM_TOOL_SCORE,
+            release,
+            revision_number,
+            caller_data,
+            path_str,
+            extra_args={
+                "project_name": project_name,
+                "version_name": release_version,
+                "revision_number": revision_number,
+                "previous_release_version": previous_version.version if previous_version else None,
+                "file_path": path_str,
+                "asf_uid": asf_uid,
+            },
         )
+        if cdx_task:
+            data.add(cdx_task)
 
 
 async def keys_import_file(
@@ -299,17 +298,29 @@ async def queued(
     extra_args: dict[str, Any] | None = None,
     check_cache_key: dict[str, Any] | None = None,
 ) -> sql.Task | None:
+    # If there's a queued or running task for this same inputs hash, don't start a new one
+    # If there isn't one, but there are existing check results for the hash, reuse them instead of running a new task
     if check_cache_key is not None:
-        logging.info("cache key", check_cache_key)
         hash_val = hashes.compute_dict_hash(check_cache_key)
         if not data:
             raise RuntimeError("DB Session is required for check_cache_key")
-        existing = await data.check_result(inputs_hash=hash_val, release_name=release.name).all()
-        if existing:
-            await attestable.write_checks_data(
-                release.project.name, release.version, revision_number, [c.id for c in existing]
-            )
+        existing_task = await data.task(
+            inputs_hash=hash_val,
+            project_name=release.project_name,
+            version_name=release.version,
+            task_args=extra_args or {},
+            status_in=[sql.TaskStatus.QUEUED, sql.TaskStatus.ACTIVE],
+        ).all()
+        if existing_task:
             return None
+        else:
+            existing = await data.check_result(inputs_hash=hash_val, release_name=release.name).all()
+            if existing:
+                await attestable.write_checks_data(
+                    release.project.name, release.version, revision_number, [c.id for c in existing]
+                )
+                return None
+
     return sql.Task(
         status=sql.TaskStatus.QUEUED,
         task_type=task_type,
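
To illustrate the intended call pattern under the new behaviour, a hedged
sketch of a caller passing a cache key; the key fields are illustrative, and
only the check_cache_key parameter and the None return on a cache hit come
from this commit:

    task = await queued(
        asf_uid,
        sql.TaskType.SBOM_TOOL_SCORE,
        release,
        revision_number,
        data,
        path_str,
        check_cache_key={"tool": "sbom_tool_score", "file_path": path_str},
    )
    if task is not None:
        # Only persist a new Task when neither a pending task nor a cached
        # check result already covered these inputs
        data.add(task)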
diff --git a/migrations/versions/0050_2026.02.17_7406bb29.py b/migrations/versions/0050_2026.02.17_7406bb29.py
new file mode 100644
index 00000000..0b635c1c
--- /dev/null
+++ b/migrations/versions/0050_2026.02.17_7406bb29.py
@@ -0,0 +1,29 @@
+"""Add inputs hash to task table
+
+Revision ID: 0050_2026.02.17_7406bb29
+Revises: 0049_2026.02.11_5b874ed2
+Create Date: 2026-02-17 14:34:59.166215+00:00
+"""
+
+from collections.abc import Sequence
+
+import sqlalchemy as sa
+from alembic import op
+
+# Revision identifiers, used by Alembic
+revision: str = "0050_2026.02.17_7406bb29"
+down_revision: str | None = "0049_2026.02.11_5b874ed2"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
+
+
+def upgrade() -> None:
+    with op.batch_alter_table("task", schema=None) as batch_op:
+        batch_op.add_column(sa.Column("inputs_hash", sa.String(), nullable=True))
+        batch_op.create_index(batch_op.f("ix_task_inputs_hash"), ["inputs_hash"], unique=False)
+
+
+def downgrade() -> None:
+    with op.batch_alter_table("task", schema=None) as batch_op:
+        batch_op.drop_index(batch_op.f("ix_task_inputs_hash"))
+        batch_op.drop_column("inputs_hash")
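
The migration is additive: existing task rows keep a NULL inputs_hash until a
new task records one, and the downgrade drops the index before the column.
Presumably applied with the usual "alembic upgrade head".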


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
