This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 490f6a5b122 Fix mypy warnings for SQLA2 migration (#56989)
490f6a5b122 is described below
commit 490f6a5b1222993ebb9abc57ee7d33cd5e3fc118
Author: Jens Scheffler <[email protected]>
AuthorDate: Mon Oct 27 14:14:03 2025 +0100
Fix mypy warnings for SQLA2 migration (#56989)
* Fix mypy warnings for SQLA2 migration
* Revert sourcing of TaskInstanceState from common.compat.sdk
---
.../providers/edge3/executors/edge_executor.py | 10 ++++----
.../src/airflow/providers/edge3/models/edge_job.py | 2 +-
.../airflow/providers/edge3/models/edge_worker.py | 28 ++++++++++++++++------
.../edge3/plugins/edge_executor_plugin.py | 3 +--
.../edge3/worker_api/routes/_v2_compat.py | 1 +
.../providers/edge3/worker_api/routes/jobs.py | 9 ++++---
.../providers/edge3/worker_api/routes/ui.py | 11 ++++++---
.../providers/edge3/worker_api/routes/worker.py | 10 +++++---
8 files changed, 49 insertions(+), 25 deletions(-)
diff --git
a/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py
b/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py
index 6382756b14a..ea818b6fb70 100644
--- a/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py
+++ b/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py
@@ -30,7 +30,7 @@ from sqlalchemy.orm import Session
from airflow.cli.cli_config import GroupCommand
from airflow.configuration import conf
from airflow.executors.base_executor import BaseExecutor
-from airflow.models.taskinstance import TaskInstance, TaskInstanceState
+from airflow.models.taskinstance import TaskInstance
from airflow.providers.common.compat.sdk import timezone
from airflow.providers.edge3.cli.edge_command import EDGE_COMMANDS
from airflow.providers.edge3.models.edge_job import EdgeJobModel
@@ -40,6 +40,7 @@ from airflow.providers.edge3.version_compat import
AIRFLOW_V_3_0_PLUS
from airflow.stats import Stats
from airflow.utils.db import DBLocks, create_global_lock
from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.state import TaskInstanceState
if TYPE_CHECKING:
import argparse
@@ -68,7 +69,8 @@ class EdgeExecutor(BaseExecutor):
"""
Check if already existing table matches the newest table schema.
- workaround till Airflow 3.0.0, then it is possible to use alembic also
for provider distributions.
+ workaround till support for Airflow 2.x is dropped,
+ then it is possible to use alembic also for provider distributions.
"""
inspector = inspect(engine)
edge_job_columns = None
@@ -78,7 +80,7 @@ class EdgeExecutor(BaseExecutor):
edge_job_columns = [column["name"] for column in edge_job_schema]
for column in edge_job_schema:
if column["name"] == "command":
- edge_job_command_len = column["type"].length
+ edge_job_command_len = column["type"].length # type:
ignore[attr-defined]
# version 0.6.0rc1 added new column concurrency_slots
if edge_job_columns and "concurrency_slots" not in edge_job_columns:
@@ -284,7 +286,7 @@ class EdgeExecutor(BaseExecutor):
map_index=job.map_index,
session=session,
)
- job.state = ti.state if ti else TaskInstanceState.REMOVED
+ job.state = ti.state if ti and ti.state else
TaskInstanceState.REMOVED
if job.state != TaskInstanceState.RUNNING:
# Edge worker does not backport emitted Airflow metrics, so
export some metrics
diff --git a/providers/edge3/src/airflow/providers/edge3/models/edge_job.py
b/providers/edge3/src/airflow/providers/edge3/models/edge_job.py
index e836716fe2e..ccf21de848f 100644
--- a/providers/edge3/src/airflow/providers/edge3/models/edge_job.py
+++ b/providers/edge3/src/airflow/providers/edge3/models/edge_job.py
@@ -94,4 +94,4 @@ class EdgeJobModel(Base, LoggingMixin):
@property
def last_update_t(self) -> float:
- return self.last_update.timestamp()
+ return self.last_update.timestamp() if self.last_update else
datetime.now().timestamp()
diff --git a/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py
b/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py
index 9eda47517ed..7d4d53d39f1 100644
--- a/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py
+++ b/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py
@@ -230,7 +230,9 @@ def request_maintenance(
) -> None:
"""Write maintenance request to the db."""
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name ==
worker_name)
- worker: EdgeWorkerModel = session.scalar(query)
+ worker: EdgeWorkerModel | None = session.scalar(query)
+ if not worker:
+ raise ValueError(f"Edge Worker {worker_name} not found in list of
registered workers")
worker.state = EdgeWorkerState.MAINTENANCE_REQUEST
worker.maintenance_comment = maintenance_comment
@@ -239,7 +241,9 @@ def request_maintenance(
def exit_maintenance(worker_name: str, session: Session = NEW_SESSION) -> None:
"""Write maintenance exit to the db."""
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name ==
worker_name)
- worker: EdgeWorkerModel = session.scalar(query)
+ worker: EdgeWorkerModel | None = session.scalar(query)
+ if not worker:
+ raise ValueError(f"Edge Worker {worker_name} not found in list of
registered workers")
worker.state = EdgeWorkerState.MAINTENANCE_EXIT
worker.maintenance_comment = None
@@ -248,7 +252,9 @@ def exit_maintenance(worker_name: str, session: Session =
NEW_SESSION) -> None:
def remove_worker(worker_name: str, session: Session = NEW_SESSION) -> None:
"""Remove a worker that is offline or just gone from DB."""
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name ==
worker_name)
- worker: EdgeWorkerModel = session.scalar(query)
+ worker: EdgeWorkerModel | None = session.scalar(query)
+ if not worker:
+ raise ValueError(f"Edge Worker {worker_name} not found in list of
registered workers")
if worker.state in (
EdgeWorkerState.OFFLINE,
EdgeWorkerState.OFFLINE_MAINTENANCE,
@@ -267,7 +273,9 @@ def change_maintenance_comment(
) -> None:
"""Write maintenance comment in the db."""
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name ==
worker_name)
- worker: EdgeWorkerModel = session.scalar(query)
+ worker: EdgeWorkerModel | None = session.scalar(query)
+ if not worker:
+ raise ValueError(f"Edge Worker {worker_name} not found in list of
registered workers")
if worker.state in (
EdgeWorkerState.MAINTENANCE_MODE,
EdgeWorkerState.MAINTENANCE_PENDING,
@@ -285,7 +293,9 @@ def change_maintenance_comment(
def request_shutdown(worker_name: str, session: Session = NEW_SESSION) -> None:
"""Request to shutdown the edge worker."""
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name ==
worker_name)
- worker: EdgeWorkerModel = session.scalar(query)
+ worker: EdgeWorkerModel | None = session.scalar(query)
+ if not worker:
+ raise ValueError(f"Edge Worker {worker_name} not found in list of
registered workers")
if worker.state not in (
EdgeWorkerState.OFFLINE,
EdgeWorkerState.OFFLINE_MAINTENANCE,
@@ -298,7 +308,9 @@ def request_shutdown(worker_name: str, session: Session =
NEW_SESSION) -> None:
def add_worker_queues(worker_name: str, queues: list[str], session: Session =
NEW_SESSION) -> None:
"""Add queues to an edge worker."""
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name ==
worker_name)
- worker: EdgeWorkerModel = session.scalar(query)
+ worker: EdgeWorkerModel | None = session.scalar(query)
+ if not worker:
+ raise ValueError(f"Edge Worker {worker_name} not found in list of
registered workers")
if worker.state in (
EdgeWorkerState.OFFLINE,
EdgeWorkerState.OFFLINE_MAINTENANCE,
@@ -314,7 +326,9 @@ def add_worker_queues(worker_name: str, queues: list[str],
session: Session = NE
def remove_worker_queues(worker_name: str, queues: list[str], session: Session
= NEW_SESSION) -> None:
"""Remove queues from an edge worker."""
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name ==
worker_name)
- worker: EdgeWorkerModel = session.scalar(query)
+ worker: EdgeWorkerModel | None = session.scalar(query)
+ if not worker:
+ raise ValueError(f"Edge Worker {worker_name} not found in list of
registered workers")
if worker.state in (
EdgeWorkerState.OFFLINE,
EdgeWorkerState.OFFLINE_MAINTENANCE,
diff --git
a/providers/edge3/src/airflow/providers/edge3/plugins/edge_executor_plugin.py
b/providers/edge3/src/airflow/providers/edge3/plugins/edge_executor_plugin.py
index d37f30805cd..110e9efb7fc 100644
---
a/providers/edge3/src/airflow/providers/edge3/plugins/edge_executor_plugin.py
+++
b/providers/edge3/src/airflow/providers/edge3/plugins/edge_executor_plugin.py
@@ -66,8 +66,7 @@ else:
from sqlalchemy import select
from airflow.auth.managers.models.resource_details import AccessView
- from airflow.models.taskinstance import TaskInstanceState
- from airflow.utils.state import State
+ from airflow.utils.state import State, TaskInstanceState
from airflow.utils.yaml import safe_load
from airflow.www.auth import has_access_view
diff --git
a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/_v2_compat.py
b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/_v2_compat.py
index 14889bd94ad..046dbea7b63 100644
---
a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/_v2_compat.py
+++
b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/_v2_compat.py
@@ -68,6 +68,7 @@ else:
HTTP_204_NO_CONTENT = 204
HTTP_400_BAD_REQUEST = 400
HTTP_403_FORBIDDEN = 403
+ HTTP_404_NOT_FOUND = 404
HTTP_500_INTERNAL_SERVER_ERROR = 500
class HTTPException(ProblemException): # type: ignore[no-redef]
diff --git
a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py
b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py
index 4f1804fa9b6..a162ffc9db7 100644
--- a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py
+++ b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py
@@ -39,7 +39,6 @@ from airflow.providers.edge3.worker_api.routes._v2_compat
import (
status,
)
from airflow.stats import Stats
-from airflow.utils.sqlalchemy import with_row_locks
from airflow.utils.state import TaskInstanceState
jobs_router = AirflowRouter(tags=["Jobs"], prefix="/jobs")
@@ -78,8 +77,8 @@ def fetch(
if body.queues:
query = query.where(EdgeJobModel.queue.in_(body.queues))
query = query.limit(1)
- query = with_row_locks(query, of=EdgeJobModel, session=session,
skip_locked=True)
- job: EdgeJobModel = session.scalar(query)
+ query = query.with_for_update(skip_locked=True)
+ job: EdgeJobModel | None = session.scalar(query)
if not job:
return None
job.state = TaskInstanceState.RUNNING
@@ -148,7 +147,7 @@ def state(
)
Stats.incr("edge_worker.ti.finish", tags=tags)
- query = (
+ query2 = (
update(EdgeJobModel)
.where(
EdgeJobModel.dag_id == dag_id,
@@ -159,4 +158,4 @@ def state(
)
.values(state=state, last_update=timezone.utcnow())
)
- session.execute(query)
+ session.execute(query2)
diff --git
a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/ui.py
b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/ui.py
index 11ea9b71e90..14b44deaff7 100644
--- a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/ui.py
+++ b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/ui.py
@@ -18,6 +18,7 @@
from __future__ import annotations
from datetime import datetime
+from typing import TYPE_CHECKING
from fastapi import Depends, HTTPException, status
from sqlalchemy import select
@@ -44,6 +45,10 @@ from airflow.providers.edge3.worker_api.datamodels_ui import
(
Worker,
WorkerCollectionResponse,
)
+from airflow.utils.state import TaskInstanceState
+
+if TYPE_CHECKING:
+ from sqlalchemy.engine import ScalarResult
ui_router = AirflowRouter(tags=["UI"])
@@ -59,7 +64,7 @@ def worker(
) -> WorkerCollectionResponse:
"""Return Edge Workers."""
query = select(EdgeWorkerModel).order_by(EdgeWorkerModel.worker_name)
- workers: list[EdgeWorkerModel] = session.scalars(query)
+ workers: ScalarResult[EdgeWorkerModel] = session.scalars(query)
result = [
Worker(
@@ -91,7 +96,7 @@ def jobs(
) -> JobCollectionResponse:
"""Return Edge Jobs."""
query = select(EdgeJobModel).order_by(EdgeJobModel.queued_dttm)
- jobs: list[EdgeJobModel] = session.scalars(query)
+ jobs: ScalarResult[EdgeJobModel] = session.scalars(query)
result = [
Job(
@@ -100,7 +105,7 @@ def jobs(
run_id=j.run_id,
map_index=j.map_index,
try_number=j.try_number,
- state=j.state,
+ state=TaskInstanceState(j.state),
queue=j.queue,
queued_dttm=j.queued_dttm,
edge_worker=j.edge_worker,
diff --git
a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py
b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py
index d99ae0b270b..1a696509d1e 100644
--- a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py
+++ b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py
@@ -172,7 +172,7 @@ def register(
"""Register a new worker to the backend."""
_assert_version(body.sysinfo)
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name ==
worker_name)
- worker: EdgeWorkerModel = session.scalar(query)
+ worker: EdgeWorkerModel | None = session.scalar(query)
if not worker:
worker = EdgeWorkerModel(worker_name=worker_name, state=body.state,
queues=body.queues)
worker.state = redefine_state(worker.state, body.state)
@@ -194,7 +194,9 @@ def set_state(
) -> WorkerSetStateReturn:
"""Set state of worker and returns the current assigned queues."""
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name ==
worker_name)
- worker: EdgeWorkerModel = session.scalar(query)
+ worker: EdgeWorkerModel | None = session.scalar(query)
+ if not worker:
+ raise HTTPException(status.HTTP_404_NOT_FOUND, "Worker not found")
worker.state = redefine_state(worker.state, body.state)
worker.maintenance_comment = redefine_maintenance_comments(
worker.maintenance_comment, body.maintenance_comments
@@ -229,7 +231,9 @@ def update_queues(
session: SessionDep,
) -> None:
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name ==
worker_name)
- worker: EdgeWorkerModel = session.scalar(query)
+ worker: EdgeWorkerModel | None = session.scalar(query)
+ if not worker:
+ raise HTTPException(status.HTTP_404_NOT_FOUND, "Worker not found")
if body.new_queues:
worker.add_queues(body.new_queues)
if body.remove_queues: