This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0d98e2b0520 Migrate Edge calls for Worker to FastAPI part 4 - Cleanup (#44434)
0d98e2b0520 is described below
commit 0d98e2b052066c92b88a7b7d16449f4dc36d1b2a
Author: Jens Scheffler <[email protected]>
AuthorDate: Sun Dec 1 12:50:06 2024 +0100
Migrate Edge calls for Worker to FastAPI part 4 - Cleanup (#44434)
* Remove internal API bindings after migration to FastAPI
* Move import to function preventing module import errors
---
providers/src/airflow/providers/edge/CHANGELOG.rst | 8 ++
providers/src/airflow/providers/edge/__init__.py | 2 +-
.../src/airflow/providers/edge/cli/api_client.py | 4 +-
.../src/airflow/providers/edge/models/edge_job.py | 105 +--------------
.../src/airflow/providers/edge/models/edge_logs.py | 81 ------------
.../airflow/providers/edge/models/edge_worker.py | 75 +----------
.../providers/edge/openapi/edge_worker_api_v1.yaml | 2 +-
providers/src/airflow/providers/edge/provider.yaml | 2 +-
.../providers/edge/worker_api/routes/_v2_routes.py | 147 ++-------------------
.../providers/edge/worker_api/routes/jobs.py | 2 +-
providers/tests/edge/models/test_edge_job.py | 121 -----------------
11 files changed, 26 insertions(+), 523 deletions(-)
diff --git a/providers/src/airflow/providers/edge/CHANGELOG.rst
b/providers/src/airflow/providers/edge/CHANGELOG.rst
index 05c3d728246..ecf411a1d50 100644
--- a/providers/src/airflow/providers/edge/CHANGELOG.rst
+++ b/providers/src/airflow/providers/edge/CHANGELOG.rst
@@ -27,6 +27,14 @@
Changelog
---------
+0.9.0pre0
+.........
+
+Misc
+~~~~
+
+* ``Remove dependency to Internal API after migration to FastAPI.``
+
0.8.2pre0
.........
diff --git a/providers/src/airflow/providers/edge/__init__.py
b/providers/src/airflow/providers/edge/__init__.py
index d826c633ead..5c207bef66a 100644
--- a/providers/src/airflow/providers/edge/__init__.py
+++ b/providers/src/airflow/providers/edge/__init__.py
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
__all__ = ["__version__"]
-__version__ = "0.8.2pre0"
+__version__ = "0.9.0pre0"
if
packaging.version.parse(packaging.version.parse(airflow_version).base_version)
< packaging.version.parse(
"2.10.0"
diff --git a/providers/src/airflow/providers/edge/cli/api_client.py
b/providers/src/airflow/providers/edge/cli/api_client.py
index c0a0144f5fe..942577e86d2 100644
--- a/providers/src/airflow/providers/edge/cli/api_client.py
+++ b/providers/src/airflow/providers/edge/cli/api_client.py
@@ -110,7 +110,7 @@ def worker_register(
def worker_set_state(
hostname: str, state: EdgeWorkerState, jobs_active: int, queues: list[str]
| None, sysinfo: dict
) -> list[str] | None:
- """Register worker with the Edge API."""
+ """Update the state of the worker in the central site and thereby
implicitly heartbeat."""
return _make_generic_request(
"PATCH",
f"worker/{quote(hostname)}",
@@ -123,7 +123,7 @@ def worker_set_state(
def jobs_fetch(hostname: str, queues: list[str] | None, free_concurrency: int)
-> EdgeJobFetched | None:
"""Fetch a job to execute on the edge worker."""
result = _make_generic_request(
- "GET",
+ "POST",
f"jobs/fetch/{quote(hostname)}",
WorkerQueuesBody(queues=queues,
free_concurrency=free_concurrency).model_dump_json(
exclude_unset=True
diff --git a/providers/src/airflow/providers/edge/models/edge_job.py
b/providers/src/airflow/providers/edge/models/edge_job.py
index c591b6d3052..818e7bf7f9f 100644
--- a/providers/src/airflow/providers/edge/models/edge_job.py
+++ b/providers/src/airflow/providers/edge/models/edge_job.py
@@ -16,32 +16,21 @@
# under the License.
from __future__ import annotations
-from ast import literal_eval
from datetime import datetime
-from typing import TYPE_CHECKING, Optional
-from pydantic import BaseModel, ConfigDict
from sqlalchemy import (
Column,
Index,
Integer,
String,
- select,
text,
)
-from airflow.api_internal.internal_api_call import internal_api_call
from airflow.models.base import Base, StringID
from airflow.models.taskinstancekey import TaskInstanceKey
-from airflow.serialization.serialized_objects import
add_pydantic_class_type_mapping
from airflow.utils import timezone
from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import UtcDateTime, with_row_locks
-from airflow.utils.state import TaskInstanceState
-
-if TYPE_CHECKING:
- from sqlalchemy.orm.session import Session
+from airflow.utils.sqlalchemy import UtcDateTime
class EdgeJobModel(Base, LoggingMixin):
@@ -103,95 +92,3 @@ class EdgeJobModel(Base, LoggingMixin):
@property
def last_update_t(self) -> float:
return self.last_update.timestamp()
-
-
-class EdgeJob(BaseModel, LoggingMixin):
- """Accessor for edge jobs as logical model."""
-
- dag_id: str
- task_id: str
- run_id: str
- map_index: int
- try_number: int
- state: TaskInstanceState
- queue: str
- concurrency_slots: int
- command: list[str]
- queued_dttm: datetime
- edge_worker: Optional[str] # noqa: UP007 - prevent Sphinx failing
- last_update: Optional[datetime] # noqa: UP007 - prevent Sphinx failing
- model_config = ConfigDict(from_attributes=True,
arbitrary_types_allowed=True)
-
- @property
- def key(self) -> TaskInstanceKey:
- return TaskInstanceKey(self.dag_id, self.task_id, self.run_id,
self.try_number, self.map_index)
-
- @staticmethod
- @internal_api_call
- @provide_session
- def reserve_task(
- worker_name: str,
- free_concurrency: int,
- queues: list[str] | None = None,
- session: Session = NEW_SESSION,
- ) -> EdgeJob | None:
- query = (
- select(EdgeJobModel)
- .where(
- EdgeJobModel.state == TaskInstanceState.QUEUED,
- EdgeJobModel.concurrency_slots <= free_concurrency,
- )
- .order_by(EdgeJobModel.queued_dttm)
- )
- if queues:
- query = query.where(EdgeJobModel.queue.in_(queues))
- query = query.limit(1)
- query = with_row_locks(query, of=EdgeJobModel, session=session,
skip_locked=True)
- job: EdgeJobModel = session.scalar(query)
- if not job:
- return None
- job.state = TaskInstanceState.RUNNING
- job.edge_worker = worker_name
- job.last_update = timezone.utcnow()
- session.commit()
- return EdgeJob(
- dag_id=job.dag_id,
- task_id=job.task_id,
- run_id=job.run_id,
- map_index=job.map_index,
- try_number=job.try_number,
- state=job.state,
- queue=job.queue,
- concurrency_slots=job.concurrency_slots,
- command=literal_eval(job.command),
- queued_dttm=job.queued_dttm,
- edge_worker=job.edge_worker,
- last_update=job.last_update,
- )
-
- @staticmethod
- @internal_api_call
- @provide_session
- def set_state(task: TaskInstanceKey | tuple, state: TaskInstanceState,
session: Session = NEW_SESSION):
- if isinstance(task, tuple):
- task = TaskInstanceKey(*task)
- query = select(EdgeJobModel).where(
- EdgeJobModel.dag_id == task.dag_id,
- EdgeJobModel.task_id == task.task_id,
- EdgeJobModel.run_id == task.run_id,
- EdgeJobModel.map_index == task.map_index,
- EdgeJobModel.try_number == task.try_number,
- )
- job: EdgeJobModel = session.scalar(query)
- if job:
- job.state = state
- job.last_update = timezone.utcnow()
- session.commit()
-
- def __hash__(self):
- return
f"{self.dag_id}|{self.task_id}|{self.run_id}|{self.map_index}|{self.try_number}".__hash__()
-
-
-EdgeJob.model_rebuild()
-
-add_pydantic_class_type_mapping("edge_job", EdgeJobModel, EdgeJob)
diff --git a/providers/src/airflow/providers/edge/models/edge_logs.py
b/providers/src/airflow/providers/edge/models/edge_logs.py
index 49a340540b3..c2f90282228 100644
--- a/providers/src/airflow/providers/edge/models/edge_logs.py
+++ b/providers/src/airflow/providers/edge/models/edge_logs.py
@@ -17,11 +17,7 @@
from __future__ import annotations
from datetime import datetime
-from functools import lru_cache
-from pathlib import Path
-from typing import TYPE_CHECKING
-from pydantic import BaseModel, ConfigDict
from sqlalchemy import (
Column,
Integer,
@@ -30,19 +26,10 @@ from sqlalchemy import (
)
from sqlalchemy.dialects.mysql import MEDIUMTEXT
-from airflow.api_internal.internal_api_call import internal_api_call
-from airflow.configuration import conf
from airflow.models.base import Base, StringID
-from airflow.models.taskinstance import TaskInstance
-from airflow.models.taskinstancekey import TaskInstanceKey
-from airflow.serialization.serialized_objects import
add_pydantic_class_type_mapping
from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime
-if TYPE_CHECKING:
- from sqlalchemy.orm.session import Session
-
class EdgeLogsModel(Base, LoggingMixin):
"""
@@ -84,71 +71,3 @@ class EdgeLogsModel(Base, LoggingMixin):
self.log_chunk_time = log_chunk_time
self.log_chunk_data = log_chunk_data
super().__init__()
-
-
-class EdgeLogs(BaseModel, LoggingMixin):
- """Deprecated Internal API for Edge Worker instances as logical model."""
-
- dag_id: str
- task_id: str
- run_id: str
- map_index: int
- try_number: int
- log_chunk_time: datetime
- log_chunk_data: str
- model_config = ConfigDict(from_attributes=True,
arbitrary_types_allowed=True)
-
- @staticmethod
- @internal_api_call
- @provide_session
- def push_logs(
- task: TaskInstanceKey | tuple,
- log_chunk_time: datetime,
- log_chunk_data: str,
- session: Session = NEW_SESSION,
- ) -> None:
- """Push an incremental log chunk from Edge Worker to central site."""
- if isinstance(task, tuple):
- task = TaskInstanceKey(*task)
- log_chunk = EdgeLogsModel(
- dag_id=task.dag_id,
- task_id=task.task_id,
- run_id=task.run_id,
- map_index=task.map_index,
- try_number=task.try_number,
- log_chunk_time=log_chunk_time,
- log_chunk_data=log_chunk_data,
- )
- session.add(log_chunk)
- # Write logs to local file to make them accessible
- logfile_path = EdgeLogs.logfile_path(task)
- if not logfile_path.exists():
- new_folder_permissions = int(
- conf.get("logging",
"file_task_handler_new_folder_permissions", fallback="0o775"), 8
- )
- logfile_path.parent.mkdir(parents=True, exist_ok=True,
mode=new_folder_permissions)
- with logfile_path.open("a") as logfile:
- logfile.write(log_chunk_data)
-
- @staticmethod
- @lru_cache
- def logfile_path(task: TaskInstanceKey) -> Path:
- """Elaborate the path and filename to expect from task execution."""
- from airflow.utils.log.file_task_handler import FileTaskHandler
-
- ti = TaskInstance.get_task_instance(
- dag_id=task.dag_id,
- run_id=task.run_id,
- task_id=task.task_id,
- map_index=task.map_index,
- )
- if TYPE_CHECKING:
- assert ti
- assert isinstance(ti, TaskInstance)
- base_log_folder = conf.get("logging", "base_log_folder", fallback="NOT
AVAILABLE")
- return Path(base_log_folder,
FileTaskHandler(base_log_folder)._render_filename(ti, task.try_number))
-
-
-EdgeLogs.model_rebuild()
-
-add_pydantic_class_type_mapping("edge_logs", EdgeLogsModel, EdgeLogs)
diff --git a/providers/src/airflow/providers/edge/models/edge_worker.py
b/providers/src/airflow/providers/edge/models/edge_worker.py
index 656d7539d07..2f8115507d8 100644
--- a/providers/src/airflow/providers/edge/models/edge_worker.py
+++ b/providers/src/airflow/providers/edge/models/edge_worker.py
@@ -20,24 +20,16 @@ import ast
import json
from datetime import datetime
from enum import Enum
-from typing import TYPE_CHECKING, Optional
-from pydantic import BaseModel, ConfigDict
-from sqlalchemy import Column, Integer, String, select
+from sqlalchemy import Column, Integer, String
-from airflow.api_internal.internal_api_call import internal_api_call
from airflow.exceptions import AirflowException
from airflow.models.base import Base
-from airflow.serialization.serialized_objects import
add_pydantic_class_type_mapping
from airflow.stats import Stats
from airflow.utils import timezone
from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime
-if TYPE_CHECKING:
- from sqlalchemy.orm.session import Session
-
class EdgeWorkerVersionException(AirflowException):
"""Signal a version mismatch between core and Edge Site."""
@@ -170,68 +162,3 @@ def reset_metrics(worker_name: str) -> None:
free_concurrency=-1,
queues=None,
)
-
-
-class EdgeWorker(BaseModel, LoggingMixin):
- """Deprecated Edge Worker internal API, keeping for one minor for graceful
migration."""
-
- worker_name: str
- state: EdgeWorkerState
- queues: Optional[list[str]] # noqa: UP007 - prevent Sphinx failing
- first_online: datetime
- last_update: Optional[datetime] = None # noqa: UP007 - prevent Sphinx
failing
- jobs_active: int
- jobs_taken: int
- jobs_success: int
- jobs_failed: int
- sysinfo: str
- model_config = ConfigDict(from_attributes=True,
arbitrary_types_allowed=True)
-
- @staticmethod
- @internal_api_call
- @provide_session
- def set_state(
- worker_name: str,
- state: EdgeWorkerState,
- jobs_active: int,
- sysinfo: dict[str, str],
- session: Session = NEW_SESSION,
- ) -> list[str] | None:
- """Set state of worker and returns the current assigned queues."""
- query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name ==
worker_name)
- worker: EdgeWorkerModel = session.scalar(query)
- worker.state = state
- worker.jobs_active = jobs_active
- worker.sysinfo = json.dumps(sysinfo)
- worker.last_update = timezone.utcnow()
- session.commit()
- Stats.incr(f"edge_worker.heartbeat_count.{worker_name}", 1, 1)
- Stats.incr("edge_worker.heartbeat_count", 1, 1, tags={"worker_name":
worker_name})
- set_metrics(
- worker_name=worker_name,
- state=state,
- jobs_active=jobs_active,
- concurrency=int(sysinfo["concurrency"]),
- free_concurrency=int(sysinfo["free_concurrency"]),
- queues=worker.queues,
- )
- raise EdgeWorkerVersionException(
- "Edge Worker runs on an old version. Rejecting access due to
difference."
- )
-
- @staticmethod
- @internal_api_call
- def register_worker(
- worker_name: str,
- state: EdgeWorkerState,
- queues: list[str] | None,
- sysinfo: dict[str, str],
- ) -> EdgeWorker:
- raise EdgeWorkerVersionException(
- "Edge Worker runs on an old version. Rejecting access due to
difference."
- )
-
-
-EdgeWorker.model_rebuild()
-
-add_pydantic_class_type_mapping("edge_worker", EdgeWorkerModel, EdgeWorker)
diff --git
a/providers/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml
b/providers/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml
index ef1f24a2288..8e361ec0982 100644
--- a/providers/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml
+++ b/providers/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml
@@ -179,7 +179,7 @@ paths:
tags:
- Worker
/jobs/fetch/{worker_name}:
- get:
+ post:
description: Fetch a job to execute on the edge worker.
x-openapi-router-controller:
airflow.providers.edge.worker_api.routes._v2_routes
operationId: job_fetch_v2
diff --git a/providers/src/airflow/providers/edge/provider.yaml
b/providers/src/airflow/providers/edge/provider.yaml
index 229f1ad68e4..1377279ab76 100644
--- a/providers/src/airflow/providers/edge/provider.yaml
+++ b/providers/src/airflow/providers/edge/provider.yaml
@@ -27,7 +27,7 @@ source-date-epoch: 1729683247
# note that those versions are maintained by release manager - do not update
them manually
versions:
- - 0.8.2pre0
+ - 0.9.0pre0
dependencies:
- apache-airflow>=2.10.0
diff --git
a/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py
b/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py
index b00f5cf41b9..2b68879531d 100644
--- a/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py
+++ b/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py
@@ -20,8 +20,7 @@ from __future__ import annotations
import json
import logging
-from functools import cache
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any
from uuid import uuid4
from flask import Response, request
@@ -50,137 +49,6 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)
-@cache
-def _initialize_method_map() -> dict[str, Callable]:
- # Note: This is a copy of the (removed) AIP-44 implementation from
- # airflow/api_internal/endpoints/rpc_api_endpoint.py
- # for compatibility with Airflow 2.10-line.
- # Methods are potentially not existing more on main branch for
Airflow 3.
- from airflow.api.common.trigger_dag import trigger_dag
- from airflow.cli.commands.task_command import _get_ti_db_access # type:
ignore[attr-defined]
- from airflow.dag_processing.manager import DagFileProcessorManager
- from airflow.dag_processing.processor import DagFileProcessor
-
- # Airflow 2.10 compatibility
- from airflow.datasets import expand_alias_to_datasets # type:
ignore[attr-defined]
- from airflow.datasets.manager import DatasetManager # type:
ignore[attr-defined]
- from airflow.jobs.job import Job, most_recent_job
- from airflow.models import Trigger, Variable, XCom
- from airflow.models.dag import DAG, DagModel
- from airflow.models.dagcode import DagCode
- from airflow.models.dagrun import DagRun
- from airflow.models.dagwarning import DagWarning
- from airflow.models.renderedtifields import RenderedTaskInstanceFields
- from airflow.models.serialized_dag import SerializedDagModel
- from airflow.models.skipmixin import SkipMixin
- from airflow.models.taskinstance import (
- TaskInstance,
- _add_log,
- _defer_task,
- _get_template_context,
- _handle_failure,
- _handle_reschedule,
- _record_task_map_for_downstreams,
- _update_rtif,
- _xcom_pull,
- )
- from airflow.models.xcom_arg import _get_task_map_length
- from airflow.providers.edge.models.edge_job import EdgeJob
- from airflow.providers.edge.models.edge_logs import EdgeLogs
- from airflow.providers.edge.models.edge_worker import EdgeWorker
- from airflow.secrets.metastore import MetastoreBackend
- from airflow.sensors.base import _orig_start_date
- from airflow.utils.cli_action_loggers import _default_action_log_internal
# type: ignore[attr-defined]
- from airflow.utils.log.file_task_handler import FileTaskHandler
-
- functions: list[Callable] = [
- _default_action_log_internal,
- _defer_task,
- _get_template_context,
- _get_ti_db_access,
- _get_task_map_length,
- _update_rtif,
- _orig_start_date,
- _handle_failure,
- _handle_reschedule,
- _add_log,
- _xcom_pull,
- _record_task_map_for_downstreams,
- trigger_dag,
- DagCode.remove_deleted_code,
- DagModel.deactivate_deleted_dags,
- DagModel.get_paused_dag_ids,
- DagModel.get_current,
- DagFileProcessor._execute_task_callbacks,
- DagFileProcessor.execute_callbacks,
- DagFileProcessor.execute_callbacks_without_dag,
- # Airflow 2.10 compatibility
- DagFileProcessor.manage_slas, # type: ignore[attr-defined]
- DagFileProcessor.save_dag_to_db,
- DagFileProcessor.update_import_errors,
- DagFileProcessor._validate_task_pools_and_update_dag_warnings,
- DagFileProcessorManager._fetch_callbacks,
- DagFileProcessorManager._get_priority_filelocs,
- DagFileProcessorManager.clear_nonexistent_import_errors,
- DagFileProcessorManager.deactivate_stale_dags,
- DagWarning.purge_inactive_dag_warnings,
- expand_alias_to_datasets,
- DatasetManager.register_dataset_change,
- FileTaskHandler._render_filename_db_access, # type:
ignore[attr-defined]
- Job._add_to_db,
- Job._fetch_from_db,
- Job._kill,
- Job._update_heartbeat,
- Job._update_in_db,
- most_recent_job,
- # Airflow 2.10 compatibility
- MetastoreBackend._fetch_connection, # type: ignore[attr-defined]
- MetastoreBackend._fetch_variable, # type: ignore[attr-defined]
- XCom.get_value,
- XCom.get_one,
- # XCom.get_many, # Not supported because it returns query
- XCom.clear,
- XCom.set,
- Variable._set,
- Variable._update,
- Variable._delete,
- DAG.fetch_callback,
- DAG.fetch_dagrun,
- DagRun.fetch_task_instances,
- DagRun.get_previous_dagrun,
- DagRun.get_previous_scheduled_dagrun,
- DagRun.get_task_instances,
- DagRun.fetch_task_instance,
- DagRun._get_log_template,
- RenderedTaskInstanceFields._update_runtime_evaluated_template_fields,
- SerializedDagModel.get_serialized_dag,
- SerializedDagModel.remove_deleted_dags,
- SkipMixin._skip,
- SkipMixin._skip_all_except,
- TaskInstance._check_and_change_state_before_execution,
- TaskInstance.get_task_instance,
- TaskInstance._get_dagrun,
- TaskInstance._set_state,
- TaskInstance.save_to_db,
- TaskInstance._clear_xcom_data,
- Trigger.from_object,
- Trigger.bulk_fetch,
- Trigger.clean_unused,
- Trigger.submit_event,
- Trigger.submit_failure,
- Trigger.ids_for_triggerer,
- Trigger.assign_unassigned,
- # Additional things from EdgeExecutor
- # These are removed in follow-up PRs as being in transition to FastAPI
- EdgeJob.reserve_task,
- EdgeJob.set_state,
- EdgeLogs.push_logs,
- EdgeWorker.register_worker,
- EdgeWorker.set_state,
- ]
- return {f"{func.__module__}.{func.__qualname__}": func for func in
functions}
-
-
def error_response(message: str, status: int):
"""Log the error and return the response as JSON object."""
error_id = uuid4()
@@ -195,6 +63,11 @@ def rpcapi_v2(body: dict[str, Any]) -> APIResponse:
# Note: Except the method map this _was_ a 100% copy of internal API module
#
airflow.api_internal.endpoints.rpc_api_endpoint.internal_airflow_api()
# As of rework for FastAPI in Airflow 3.0, this is updated and to be
removed in the future.
+ from airflow.api_internal.endpoints.rpc_api_endpoint import ( # type:
ignore[attr-defined]
+ # Note: This is just for compatibility with Airflow 2.10, not working
for Airflow 3 / main as removed
+ initialize_method_map,
+ )
+
try:
if request.headers.get("Content-Type", "") != "application/json":
raise HTTPException(status.HTTP_403_FORBIDDEN, "Expected
Content-Type: application/json")
@@ -207,7 +80,7 @@ def rpcapi_v2(body: dict[str, Any]) -> APIResponse:
raise error_response("Expected jsonrpc 2.0 request.",
status.HTTP_400_BAD_REQUEST)
log.debug("Got request for %s", request_obj.method)
- methods_map = _initialize_method_map()
+ methods_map = initialize_method_map()
if request_obj.method not in methods_map:
raise error_response(f"Unrecognized method:
{request_obj.method}.", status.HTTP_400_BAD_REQUEST)
@@ -276,15 +149,15 @@ def set_state_v2(worker_name: str, body: dict[str, Any],
session=NEW_SESSION) ->
@provide_session
-def job_fetch_v2(worker_name: str, body: dict[str, Any] | None = None,
session=NEW_SESSION) -> Any:
+def job_fetch_v2(worker_name: str, body: dict[str, Any], session=NEW_SESSION)
-> Any:
"""Handle Edge Worker API `/edge_worker/v1/jobs/fetch/{worker_name}`
endpoint for Airflow 2.10."""
from flask import request
try:
auth = request.headers.get("Authorization", "")
jwt_token_authorization(request.path, auth)
- queues = body["queues"] if body else None
- free_concurrency = body["free_concurrency"] if body else 1
+ queues = body.get("queues")
+ free_concurrency = body.get("free_concurrency", 1)
request_obj = WorkerQueuesBody(queues=queues,
free_concurrency=free_concurrency)
job: EdgeJobFetched | None = fetch(worker_name, request_obj, session)
return job.model_dump() if job is not None else None
diff --git a/providers/src/airflow/providers/edge/worker_api/routes/jobs.py
b/providers/src/airflow/providers/edge/worker_api/routes/jobs.py
index 289fc3eed99..99d6607b572 100644
--- a/providers/src/airflow/providers/edge/worker_api/routes/jobs.py
+++ b/providers/src/airflow/providers/edge/worker_api/routes/jobs.py
@@ -44,7 +44,7 @@ from airflow.utils.state import TaskInstanceState
jobs_router = AirflowRouter(tags=["Jobs"], prefix="/jobs")
-@jobs_router.get(
+@jobs_router.post(
"/fetch/{worker_name}",
dependencies=[Depends(jwt_token_authorization_rest)],
responses=create_openapi_http_exception_doc(
diff --git a/providers/tests/edge/models/test_edge_job.py
b/providers/tests/edge/models/test_edge_job.py
deleted file mode 100644
index bca1025a823..00000000000
--- a/providers/tests/edge/models/test_edge_job.py
+++ /dev/null
@@ -1,121 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-from typing import TYPE_CHECKING
-
-import pytest
-
-from airflow.providers.edge.models.edge_job import EdgeJob, EdgeJobModel
-from airflow.utils import timezone
-from airflow.utils.state import TaskInstanceState
-
-if TYPE_CHECKING:
- from sqlalchemy.orm import Session
-
-pytestmark = pytest.mark.db_test
-pytest.importorskip("pydantic", minversion="2.0.0")
-
-
-class TestEdgeJob:
- @pytest.fixture(autouse=True)
- def setup_test_cases(self, session: Session):
- session.query(EdgeJobModel).delete()
-
- def test_reserve_task_no_job(self):
- job = EdgeJob.reserve_task("worker", free_concurrency=10)
- assert job is None
-
- @pytest.mark.parametrize(
- "concurrency_slots, free_concurrency, expected_job",
- [
- pytest.param(10, 9, False, id="less_free_concurrency"),
- pytest.param(10, 10, True, id="equal_free_concurrency"),
- pytest.param(10, 11, True, id="more_free_concurrency"),
- ],
- )
- def test_reserve_task_has_one(self, concurrency_slots, free_concurrency,
expected_job, session: Session):
- rjm = EdgeJobModel(
- dag_id="test_dag",
- task_id="test_task",
- run_id="test_run",
- map_index=-1,
- try_number=1,
- state=TaskInstanceState.QUEUED,
- queue="default",
- concurrency_slots=concurrency_slots,
- command=str(["hello", "world"]),
- queued_dttm=timezone.utcnow(),
- )
- session.add(rjm)
- session.commit()
-
- job = EdgeJob.reserve_task("worker", free_concurrency=free_concurrency)
- if expected_job:
- assert job
- assert job.edge_worker == "worker"
- assert job.queue == "default"
- assert job.dag_id == "test_dag"
- assert job.task_id == "test_task"
- assert job.run_id == "test_run"
- assert job.concurrency_slots == concurrency_slots
- else:
- assert job is None
-
- jobs: list[EdgeJobModel] = session.query(EdgeJobModel).all()
-
- assert len(jobs) == 1
- assert jobs[0].queue == "default"
- assert jobs[0].dag_id == "test_dag"
- assert jobs[0].task_id == "test_task"
- assert jobs[0].run_id == "test_run"
- assert jobs[0].concurrency_slots == concurrency_slots
-
- if expected_job:
- assert jobs[0].state == TaskInstanceState.RUNNING
- assert jobs[0].edge_worker == "worker"
- else:
- assert jobs[0].state == TaskInstanceState.QUEUED
- assert jobs[0].edge_worker is None
-
- def test_set_state(self, session: Session):
- rjm = EdgeJobModel(
- dag_id="test_dag",
- task_id="test_task",
- run_id="test_run",
- map_index=-1,
- try_number=1,
- state=TaskInstanceState.RUNNING,
- queue="default",
- concurrency_slots=5,
- command=str(["hello", "world"]),
- queued_dttm=timezone.utcnow(),
- )
- session.add(rjm)
- session.commit()
-
- EdgeJob.set_state(rjm.key, TaskInstanceState.FAILED)
-
- jobs: list[EdgeJobModel] = session.query(EdgeJobModel).all()
- assert len(jobs) == 1
- assert jobs[0].state == TaskInstanceState.FAILED
- assert jobs[0].last_update
- assert jobs[0].queue == "default"
- assert jobs[0].dag_id == "test_dag"
- assert jobs[0].task_id == "test_task"
- assert jobs[0].run_id == "test_run"
- assert jobs[0].concurrency_slots == 5