This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 490f6a5b122 Fix mypy warnings for SQLA2 migration (#56989)
490f6a5b122 is described below

commit 490f6a5b1222993ebb9abc57ee7d33cd5e3fc118
Author: Jens Scheffler <[email protected]>
AuthorDate: Mon Oct 27 14:14:03 2025 +0100

    Fix mypy warnings for SQLA2 migration (#56989)
    
    * Fix mypy warnings for SQLA2 migration
    
    * Revert sourcing TaskInstanceState from common.compat.sdk
---
 .../providers/edge3/executors/edge_executor.py     | 10 ++++----
 .../src/airflow/providers/edge3/models/edge_job.py |  2 +-
 .../airflow/providers/edge3/models/edge_worker.py  | 28 ++++++++++++++++------
 .../edge3/plugins/edge_executor_plugin.py          |  3 +--
 .../edge3/worker_api/routes/_v2_compat.py          |  1 +
 .../providers/edge3/worker_api/routes/jobs.py      |  9 ++++---
 .../providers/edge3/worker_api/routes/ui.py        | 11 ++++++---
 .../providers/edge3/worker_api/routes/worker.py    | 10 +++++---
 8 files changed, 49 insertions(+), 25 deletions(-)

diff --git 
a/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py 
b/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py
index 6382756b14a..ea818b6fb70 100644
--- a/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py
+++ b/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py
@@ -30,7 +30,7 @@ from sqlalchemy.orm import Session
 from airflow.cli.cli_config import GroupCommand
 from airflow.configuration import conf
 from airflow.executors.base_executor import BaseExecutor
-from airflow.models.taskinstance import TaskInstance, TaskInstanceState
+from airflow.models.taskinstance import TaskInstance
 from airflow.providers.common.compat.sdk import timezone
 from airflow.providers.edge3.cli.edge_command import EDGE_COMMANDS
 from airflow.providers.edge3.models.edge_job import EdgeJobModel
@@ -40,6 +40,7 @@ from airflow.providers.edge3.version_compat import 
AIRFLOW_V_3_0_PLUS
 from airflow.stats import Stats
 from airflow.utils.db import DBLocks, create_global_lock
 from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     import argparse
@@ -68,7 +69,8 @@ class EdgeExecutor(BaseExecutor):
         """
         Check if already existing table matches the newest table schema.
 
-        workaround till Airflow 3.0.0, then it is possible to use alembic also 
for provider distributions.
+        workaround till support for Airflow 2.x is dropped,
+        then it is possible to use alembic also for provider distributions.
         """
         inspector = inspect(engine)
         edge_job_columns = None
@@ -78,7 +80,7 @@ class EdgeExecutor(BaseExecutor):
             edge_job_columns = [column["name"] for column in edge_job_schema]
             for column in edge_job_schema:
                 if column["name"] == "command":
-                    edge_job_command_len = column["type"].length
+                    edge_job_command_len = column["type"].length  # type: 
ignore[attr-defined]
 
         # version 0.6.0rc1 added new column concurrency_slots
         if edge_job_columns and "concurrency_slots" not in edge_job_columns:
@@ -284,7 +286,7 @@ class EdgeExecutor(BaseExecutor):
                 map_index=job.map_index,
                 session=session,
             )
-            job.state = ti.state if ti else TaskInstanceState.REMOVED
+            job.state = ti.state if ti and ti.state else 
TaskInstanceState.REMOVED
 
             if job.state != TaskInstanceState.RUNNING:
                 # Edge worker does not backport emitted Airflow metrics, so 
export some metrics
diff --git a/providers/edge3/src/airflow/providers/edge3/models/edge_job.py 
b/providers/edge3/src/airflow/providers/edge3/models/edge_job.py
index e836716fe2e..ccf21de848f 100644
--- a/providers/edge3/src/airflow/providers/edge3/models/edge_job.py
+++ b/providers/edge3/src/airflow/providers/edge3/models/edge_job.py
@@ -94,4 +94,4 @@ class EdgeJobModel(Base, LoggingMixin):
 
     @property
     def last_update_t(self) -> float:
-        return self.last_update.timestamp()
+        return self.last_update.timestamp() if self.last_update else 
datetime.now().timestamp()
diff --git a/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py 
b/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py
index 9eda47517ed..7d4d53d39f1 100644
--- a/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py
+++ b/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py
@@ -230,7 +230,9 @@ def request_maintenance(
 ) -> None:
     """Write maintenance request to the db."""
     query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == 
worker_name)
-    worker: EdgeWorkerModel = session.scalar(query)
+    worker: EdgeWorkerModel | None = session.scalar(query)
+    if not worker:
+        raise ValueError(f"Edge Worker {worker_name} not found in list of 
registered workers")
     worker.state = EdgeWorkerState.MAINTENANCE_REQUEST
     worker.maintenance_comment = maintenance_comment
 
@@ -239,7 +241,9 @@ def request_maintenance(
 def exit_maintenance(worker_name: str, session: Session = NEW_SESSION) -> None:
     """Write maintenance exit to the db."""
     query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == 
worker_name)
-    worker: EdgeWorkerModel = session.scalar(query)
+    worker: EdgeWorkerModel | None = session.scalar(query)
+    if not worker:
+        raise ValueError(f"Edge Worker {worker_name} not found in list of 
registered workers")
     worker.state = EdgeWorkerState.MAINTENANCE_EXIT
     worker.maintenance_comment = None
 
@@ -248,7 +252,9 @@ def exit_maintenance(worker_name: str, session: Session = 
NEW_SESSION) -> None:
 def remove_worker(worker_name: str, session: Session = NEW_SESSION) -> None:
     """Remove a worker that is offline or just gone from DB."""
     query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == 
worker_name)
-    worker: EdgeWorkerModel = session.scalar(query)
+    worker: EdgeWorkerModel | None = session.scalar(query)
+    if not worker:
+        raise ValueError(f"Edge Worker {worker_name} not found in list of 
registered workers")
     if worker.state in (
         EdgeWorkerState.OFFLINE,
         EdgeWorkerState.OFFLINE_MAINTENANCE,
@@ -267,7 +273,9 @@ def change_maintenance_comment(
 ) -> None:
     """Write maintenance comment in the db."""
     query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == 
worker_name)
-    worker: EdgeWorkerModel = session.scalar(query)
+    worker: EdgeWorkerModel | None = session.scalar(query)
+    if not worker:
+        raise ValueError(f"Edge Worker {worker_name} not found in list of 
registered workers")
     if worker.state in (
         EdgeWorkerState.MAINTENANCE_MODE,
         EdgeWorkerState.MAINTENANCE_PENDING,
@@ -285,7 +293,9 @@ def change_maintenance_comment(
 def request_shutdown(worker_name: str, session: Session = NEW_SESSION) -> None:
     """Request to shutdown the edge worker."""
     query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == 
worker_name)
-    worker: EdgeWorkerModel = session.scalar(query)
+    worker: EdgeWorkerModel | None = session.scalar(query)
+    if not worker:
+        raise ValueError(f"Edge Worker {worker_name} not found in list of 
registered workers")
     if worker.state not in (
         EdgeWorkerState.OFFLINE,
         EdgeWorkerState.OFFLINE_MAINTENANCE,
@@ -298,7 +308,9 @@ def request_shutdown(worker_name: str, session: Session = 
NEW_SESSION) -> None:
 def add_worker_queues(worker_name: str, queues: list[str], session: Session = 
NEW_SESSION) -> None:
     """Add queues to an edge worker."""
     query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == 
worker_name)
-    worker: EdgeWorkerModel = session.scalar(query)
+    worker: EdgeWorkerModel | None = session.scalar(query)
+    if not worker:
+        raise ValueError(f"Edge Worker {worker_name} not found in list of 
registered workers")
     if worker.state in (
         EdgeWorkerState.OFFLINE,
         EdgeWorkerState.OFFLINE_MAINTENANCE,
@@ -314,7 +326,9 @@ def add_worker_queues(worker_name: str, queues: list[str], 
session: Session = NE
 def remove_worker_queues(worker_name: str, queues: list[str], session: Session 
= NEW_SESSION) -> None:
     """Remove queues from an edge worker."""
     query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == 
worker_name)
-    worker: EdgeWorkerModel = session.scalar(query)
+    worker: EdgeWorkerModel | None = session.scalar(query)
+    if not worker:
+        raise ValueError(f"Edge Worker {worker_name} not found in list of 
registered workers")
     if worker.state in (
         EdgeWorkerState.OFFLINE,
         EdgeWorkerState.OFFLINE_MAINTENANCE,
diff --git 
a/providers/edge3/src/airflow/providers/edge3/plugins/edge_executor_plugin.py 
b/providers/edge3/src/airflow/providers/edge3/plugins/edge_executor_plugin.py
index d37f30805cd..110e9efb7fc 100644
--- 
a/providers/edge3/src/airflow/providers/edge3/plugins/edge_executor_plugin.py
+++ 
b/providers/edge3/src/airflow/providers/edge3/plugins/edge_executor_plugin.py
@@ -66,8 +66,7 @@ else:
     from sqlalchemy import select
 
     from airflow.auth.managers.models.resource_details import AccessView
-    from airflow.models.taskinstance import TaskInstanceState
-    from airflow.utils.state import State
+    from airflow.utils.state import State, TaskInstanceState
     from airflow.utils.yaml import safe_load
     from airflow.www.auth import has_access_view
 
diff --git 
a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/_v2_compat.py 
b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/_v2_compat.py
index 14889bd94ad..046dbea7b63 100644
--- 
a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/_v2_compat.py
+++ 
b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/_v2_compat.py
@@ -68,6 +68,7 @@ else:
         HTTP_204_NO_CONTENT = 204
         HTTP_400_BAD_REQUEST = 400
         HTTP_403_FORBIDDEN = 403
+        HTTP_404_NOT_FOUND = 404
         HTTP_500_INTERNAL_SERVER_ERROR = 500
 
     class HTTPException(ProblemException):  # type: ignore[no-redef]
diff --git 
a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py 
b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py
index 4f1804fa9b6..a162ffc9db7 100644
--- a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py
+++ b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py
@@ -39,7 +39,6 @@ from airflow.providers.edge3.worker_api.routes._v2_compat 
import (
     status,
 )
 from airflow.stats import Stats
-from airflow.utils.sqlalchemy import with_row_locks
 from airflow.utils.state import TaskInstanceState
 
 jobs_router = AirflowRouter(tags=["Jobs"], prefix="/jobs")
@@ -78,8 +77,8 @@ def fetch(
     if body.queues:
         query = query.where(EdgeJobModel.queue.in_(body.queues))
     query = query.limit(1)
-    query = with_row_locks(query, of=EdgeJobModel, session=session, 
skip_locked=True)
-    job: EdgeJobModel = session.scalar(query)
+    query = query.with_for_update(skip_locked=True)
+    job: EdgeJobModel | None = session.scalar(query)
     if not job:
         return None
     job.state = TaskInstanceState.RUNNING
@@ -148,7 +147,7 @@ def state(
             )
             Stats.incr("edge_worker.ti.finish", tags=tags)
 
-    query = (
+    query2 = (
         update(EdgeJobModel)
         .where(
             EdgeJobModel.dag_id == dag_id,
@@ -159,4 +158,4 @@ def state(
         )
         .values(state=state, last_update=timezone.utcnow())
     )
-    session.execute(query)
+    session.execute(query2)
diff --git 
a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/ui.py 
b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/ui.py
index 11ea9b71e90..14b44deaff7 100644
--- a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/ui.py
+++ b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/ui.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 from datetime import datetime
+from typing import TYPE_CHECKING
 
 from fastapi import Depends, HTTPException, status
 from sqlalchemy import select
@@ -44,6 +45,10 @@ from airflow.providers.edge3.worker_api.datamodels_ui import 
(
     Worker,
     WorkerCollectionResponse,
 )
+from airflow.utils.state import TaskInstanceState
+
+if TYPE_CHECKING:
+    from sqlalchemy.engine import ScalarResult
 
 ui_router = AirflowRouter(tags=["UI"])
 
@@ -59,7 +64,7 @@ def worker(
 ) -> WorkerCollectionResponse:
     """Return Edge Workers."""
     query = select(EdgeWorkerModel).order_by(EdgeWorkerModel.worker_name)
-    workers: list[EdgeWorkerModel] = session.scalars(query)
+    workers: ScalarResult[EdgeWorkerModel] = session.scalars(query)
 
     result = [
         Worker(
@@ -91,7 +96,7 @@ def jobs(
 ) -> JobCollectionResponse:
     """Return Edge Jobs."""
     query = select(EdgeJobModel).order_by(EdgeJobModel.queued_dttm)
-    jobs: list[EdgeJobModel] = session.scalars(query)
+    jobs: ScalarResult[EdgeJobModel] = session.scalars(query)
 
     result = [
         Job(
@@ -100,7 +105,7 @@ def jobs(
             run_id=j.run_id,
             map_index=j.map_index,
             try_number=j.try_number,
-            state=j.state,
+            state=TaskInstanceState(j.state),
             queue=j.queue,
             queued_dttm=j.queued_dttm,
             edge_worker=j.edge_worker,
diff --git 
a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py 
b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py
index d99ae0b270b..1a696509d1e 100644
--- a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py
+++ b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py
@@ -172,7 +172,7 @@ def register(
     """Register a new worker to the backend."""
     _assert_version(body.sysinfo)
     query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == 
worker_name)
-    worker: EdgeWorkerModel = session.scalar(query)
+    worker: EdgeWorkerModel | None = session.scalar(query)
     if not worker:
         worker = EdgeWorkerModel(worker_name=worker_name, state=body.state, 
queues=body.queues)
     worker.state = redefine_state(worker.state, body.state)
@@ -194,7 +194,9 @@ def set_state(
 ) -> WorkerSetStateReturn:
     """Set state of worker and returns the current assigned queues."""
     query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == 
worker_name)
-    worker: EdgeWorkerModel = session.scalar(query)
+    worker: EdgeWorkerModel | None = session.scalar(query)
+    if not worker:
+        raise HTTPException(status.HTTP_404_NOT_FOUND, "Worker not found")
     worker.state = redefine_state(worker.state, body.state)
     worker.maintenance_comment = redefine_maintenance_comments(
         worker.maintenance_comment, body.maintenance_comments
@@ -229,7 +231,9 @@ def update_queues(
     session: SessionDep,
 ) -> None:
     query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == 
worker_name)
-    worker: EdgeWorkerModel = session.scalar(query)
+    worker: EdgeWorkerModel | None = session.scalar(query)
+    if not worker:
+        raise HTTPException(status.HTTP_404_NOT_FOUND, "Worker not found")
     if body.new_queues:
         worker.add_queues(body.new_queues)
     if body.remove_queues:

Reply via email to