This is an automated email from the ASF dual-hosted git repository.

jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0d98e2b0520 Migrate Edge calls for Worker to FastAPI part 4 - Cleanup (#44434)
0d98e2b0520 is described below

commit 0d98e2b052066c92b88a7b7d16449f4dc36d1b2a
Author: Jens Scheffler <[email protected]>
AuthorDate: Sun Dec 1 12:50:06 2024 +0100

    Migrate Edge calls for Worker to FastAPI part 4 - Cleanup (#44434)
    
    * Remove internal API bindings after migration to FastAPI
    
    * Move import into function to prevent module import errors
---
 providers/src/airflow/providers/edge/CHANGELOG.rst |   8 ++
 providers/src/airflow/providers/edge/__init__.py   |   2 +-
 .../src/airflow/providers/edge/cli/api_client.py   |   4 +-
 .../src/airflow/providers/edge/models/edge_job.py  | 105 +--------------
 .../src/airflow/providers/edge/models/edge_logs.py |  81 ------------
 .../airflow/providers/edge/models/edge_worker.py   |  75 +----------
 .../providers/edge/openapi/edge_worker_api_v1.yaml |   2 +-
 providers/src/airflow/providers/edge/provider.yaml |   2 +-
 .../providers/edge/worker_api/routes/_v2_routes.py | 147 ++-------------------
 .../providers/edge/worker_api/routes/jobs.py       |   2 +-
 providers/tests/edge/models/test_edge_job.py       | 121 -----------------
 11 files changed, 26 insertions(+), 523 deletions(-)

diff --git a/providers/src/airflow/providers/edge/CHANGELOG.rst 
b/providers/src/airflow/providers/edge/CHANGELOG.rst
index 05c3d728246..ecf411a1d50 100644
--- a/providers/src/airflow/providers/edge/CHANGELOG.rst
+++ b/providers/src/airflow/providers/edge/CHANGELOG.rst
@@ -27,6 +27,14 @@
 Changelog
 ---------
 
+0.9.0pre0
+.........
+
+Misc
+~~~~
+
+* ``Remove dependency on Internal API after migration to FastAPI.``
+
 0.8.2pre0
 .........
 
diff --git a/providers/src/airflow/providers/edge/__init__.py 
b/providers/src/airflow/providers/edge/__init__.py
index d826c633ead..5c207bef66a 100644
--- a/providers/src/airflow/providers/edge/__init__.py
+++ b/providers/src/airflow/providers/edge/__init__.py
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
 
 __all__ = ["__version__"]
 
-__version__ = "0.8.2pre0"
+__version__ = "0.9.0pre0"
 
 if 
packaging.version.parse(packaging.version.parse(airflow_version).base_version) 
< packaging.version.parse(
     "2.10.0"
diff --git a/providers/src/airflow/providers/edge/cli/api_client.py 
b/providers/src/airflow/providers/edge/cli/api_client.py
index c0a0144f5fe..942577e86d2 100644
--- a/providers/src/airflow/providers/edge/cli/api_client.py
+++ b/providers/src/airflow/providers/edge/cli/api_client.py
@@ -110,7 +110,7 @@ def worker_register(
 def worker_set_state(
     hostname: str, state: EdgeWorkerState, jobs_active: int, queues: list[str] 
| None, sysinfo: dict
 ) -> list[str] | None:
-    """Register worker with the Edge API."""
+    """Update the state of the worker in the central site and thereby 
implicitly heartbeat."""
     return _make_generic_request(
         "PATCH",
         f"worker/{quote(hostname)}",
@@ -123,7 +123,7 @@ def worker_set_state(
 def jobs_fetch(hostname: str, queues: list[str] | None, free_concurrency: int) 
-> EdgeJobFetched | None:
     """Fetch a job to execute on the edge worker."""
     result = _make_generic_request(
-        "GET",
+        "POST",
         f"jobs/fetch/{quote(hostname)}",
         WorkerQueuesBody(queues=queues, 
free_concurrency=free_concurrency).model_dump_json(
             exclude_unset=True
diff --git a/providers/src/airflow/providers/edge/models/edge_job.py 
b/providers/src/airflow/providers/edge/models/edge_job.py
index c591b6d3052..818e7bf7f9f 100644
--- a/providers/src/airflow/providers/edge/models/edge_job.py
+++ b/providers/src/airflow/providers/edge/models/edge_job.py
@@ -16,32 +16,21 @@
 # under the License.
 from __future__ import annotations
 
-from ast import literal_eval
 from datetime import datetime
-from typing import TYPE_CHECKING, Optional
 
-from pydantic import BaseModel, ConfigDict
 from sqlalchemy import (
     Column,
     Index,
     Integer,
     String,
-    select,
     text,
 )
 
-from airflow.api_internal.internal_api_call import internal_api_call
 from airflow.models.base import Base, StringID
 from airflow.models.taskinstancekey import TaskInstanceKey
-from airflow.serialization.serialized_objects import 
add_pydantic_class_type_mapping
 from airflow.utils import timezone
 from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import UtcDateTime, with_row_locks
-from airflow.utils.state import TaskInstanceState
-
-if TYPE_CHECKING:
-    from sqlalchemy.orm.session import Session
+from airflow.utils.sqlalchemy import UtcDateTime
 
 
 class EdgeJobModel(Base, LoggingMixin):
@@ -103,95 +92,3 @@ class EdgeJobModel(Base, LoggingMixin):
     @property
     def last_update_t(self) -> float:
         return self.last_update.timestamp()
-
-
-class EdgeJob(BaseModel, LoggingMixin):
-    """Accessor for edge jobs as logical model."""
-
-    dag_id: str
-    task_id: str
-    run_id: str
-    map_index: int
-    try_number: int
-    state: TaskInstanceState
-    queue: str
-    concurrency_slots: int
-    command: list[str]
-    queued_dttm: datetime
-    edge_worker: Optional[str]  # noqa: UP007 - prevent Sphinx failing
-    last_update: Optional[datetime]  # noqa: UP007 - prevent Sphinx failing
-    model_config = ConfigDict(from_attributes=True, 
arbitrary_types_allowed=True)
-
-    @property
-    def key(self) -> TaskInstanceKey:
-        return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, 
self.try_number, self.map_index)
-
-    @staticmethod
-    @internal_api_call
-    @provide_session
-    def reserve_task(
-        worker_name: str,
-        free_concurrency: int,
-        queues: list[str] | None = None,
-        session: Session = NEW_SESSION,
-    ) -> EdgeJob | None:
-        query = (
-            select(EdgeJobModel)
-            .where(
-                EdgeJobModel.state == TaskInstanceState.QUEUED,
-                EdgeJobModel.concurrency_slots <= free_concurrency,
-            )
-            .order_by(EdgeJobModel.queued_dttm)
-        )
-        if queues:
-            query = query.where(EdgeJobModel.queue.in_(queues))
-        query = query.limit(1)
-        query = with_row_locks(query, of=EdgeJobModel, session=session, 
skip_locked=True)
-        job: EdgeJobModel = session.scalar(query)
-        if not job:
-            return None
-        job.state = TaskInstanceState.RUNNING
-        job.edge_worker = worker_name
-        job.last_update = timezone.utcnow()
-        session.commit()
-        return EdgeJob(
-            dag_id=job.dag_id,
-            task_id=job.task_id,
-            run_id=job.run_id,
-            map_index=job.map_index,
-            try_number=job.try_number,
-            state=job.state,
-            queue=job.queue,
-            concurrency_slots=job.concurrency_slots,
-            command=literal_eval(job.command),
-            queued_dttm=job.queued_dttm,
-            edge_worker=job.edge_worker,
-            last_update=job.last_update,
-        )
-
-    @staticmethod
-    @internal_api_call
-    @provide_session
-    def set_state(task: TaskInstanceKey | tuple, state: TaskInstanceState, 
session: Session = NEW_SESSION):
-        if isinstance(task, tuple):
-            task = TaskInstanceKey(*task)
-        query = select(EdgeJobModel).where(
-            EdgeJobModel.dag_id == task.dag_id,
-            EdgeJobModel.task_id == task.task_id,
-            EdgeJobModel.run_id == task.run_id,
-            EdgeJobModel.map_index == task.map_index,
-            EdgeJobModel.try_number == task.try_number,
-        )
-        job: EdgeJobModel = session.scalar(query)
-        if job:
-            job.state = state
-            job.last_update = timezone.utcnow()
-            session.commit()
-
-    def __hash__(self):
-        return 
f"{self.dag_id}|{self.task_id}|{self.run_id}|{self.map_index}|{self.try_number}".__hash__()
-
-
-EdgeJob.model_rebuild()
-
-add_pydantic_class_type_mapping("edge_job", EdgeJobModel, EdgeJob)
diff --git a/providers/src/airflow/providers/edge/models/edge_logs.py 
b/providers/src/airflow/providers/edge/models/edge_logs.py
index 49a340540b3..c2f90282228 100644
--- a/providers/src/airflow/providers/edge/models/edge_logs.py
+++ b/providers/src/airflow/providers/edge/models/edge_logs.py
@@ -17,11 +17,7 @@
 from __future__ import annotations
 
 from datetime import datetime
-from functools import lru_cache
-from pathlib import Path
-from typing import TYPE_CHECKING
 
-from pydantic import BaseModel, ConfigDict
 from sqlalchemy import (
     Column,
     Integer,
@@ -30,19 +26,10 @@ from sqlalchemy import (
 )
 from sqlalchemy.dialects.mysql import MEDIUMTEXT
 
-from airflow.api_internal.internal_api_call import internal_api_call
-from airflow.configuration import conf
 from airflow.models.base import Base, StringID
-from airflow.models.taskinstance import TaskInstance
-from airflow.models.taskinstancekey import TaskInstanceKey
-from airflow.serialization.serialized_objects import 
add_pydantic_class_type_mapping
 from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.sqlalchemy import UtcDateTime
 
-if TYPE_CHECKING:
-    from sqlalchemy.orm.session import Session
-
 
 class EdgeLogsModel(Base, LoggingMixin):
     """
@@ -84,71 +71,3 @@ class EdgeLogsModel(Base, LoggingMixin):
         self.log_chunk_time = log_chunk_time
         self.log_chunk_data = log_chunk_data
         super().__init__()
-
-
-class EdgeLogs(BaseModel, LoggingMixin):
-    """Deprecated Internal API for Edge Worker instances as logical model."""
-
-    dag_id: str
-    task_id: str
-    run_id: str
-    map_index: int
-    try_number: int
-    log_chunk_time: datetime
-    log_chunk_data: str
-    model_config = ConfigDict(from_attributes=True, 
arbitrary_types_allowed=True)
-
-    @staticmethod
-    @internal_api_call
-    @provide_session
-    def push_logs(
-        task: TaskInstanceKey | tuple,
-        log_chunk_time: datetime,
-        log_chunk_data: str,
-        session: Session = NEW_SESSION,
-    ) -> None:
-        """Push an incremental log chunk from Edge Worker to central site."""
-        if isinstance(task, tuple):
-            task = TaskInstanceKey(*task)
-        log_chunk = EdgeLogsModel(
-            dag_id=task.dag_id,
-            task_id=task.task_id,
-            run_id=task.run_id,
-            map_index=task.map_index,
-            try_number=task.try_number,
-            log_chunk_time=log_chunk_time,
-            log_chunk_data=log_chunk_data,
-        )
-        session.add(log_chunk)
-        # Write logs to local file to make them accessible
-        logfile_path = EdgeLogs.logfile_path(task)
-        if not logfile_path.exists():
-            new_folder_permissions = int(
-                conf.get("logging", 
"file_task_handler_new_folder_permissions", fallback="0o775"), 8
-            )
-            logfile_path.parent.mkdir(parents=True, exist_ok=True, 
mode=new_folder_permissions)
-        with logfile_path.open("a") as logfile:
-            logfile.write(log_chunk_data)
-
-    @staticmethod
-    @lru_cache
-    def logfile_path(task: TaskInstanceKey) -> Path:
-        """Elaborate the path and filename to expect from task execution."""
-        from airflow.utils.log.file_task_handler import FileTaskHandler
-
-        ti = TaskInstance.get_task_instance(
-            dag_id=task.dag_id,
-            run_id=task.run_id,
-            task_id=task.task_id,
-            map_index=task.map_index,
-        )
-        if TYPE_CHECKING:
-            assert ti
-            assert isinstance(ti, TaskInstance)
-        base_log_folder = conf.get("logging", "base_log_folder", fallback="NOT 
AVAILABLE")
-        return Path(base_log_folder, 
FileTaskHandler(base_log_folder)._render_filename(ti, task.try_number))
-
-
-EdgeLogs.model_rebuild()
-
-add_pydantic_class_type_mapping("edge_logs", EdgeLogsModel, EdgeLogs)
diff --git a/providers/src/airflow/providers/edge/models/edge_worker.py 
b/providers/src/airflow/providers/edge/models/edge_worker.py
index 656d7539d07..2f8115507d8 100644
--- a/providers/src/airflow/providers/edge/models/edge_worker.py
+++ b/providers/src/airflow/providers/edge/models/edge_worker.py
@@ -20,24 +20,16 @@ import ast
 import json
 from datetime import datetime
 from enum import Enum
-from typing import TYPE_CHECKING, Optional
 
-from pydantic import BaseModel, ConfigDict
-from sqlalchemy import Column, Integer, String, select
+from sqlalchemy import Column, Integer, String
 
-from airflow.api_internal.internal_api_call import internal_api_call
 from airflow.exceptions import AirflowException
 from airflow.models.base import Base
-from airflow.serialization.serialized_objects import 
add_pydantic_class_type_mapping
 from airflow.stats import Stats
 from airflow.utils import timezone
 from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.sqlalchemy import UtcDateTime
 
-if TYPE_CHECKING:
-    from sqlalchemy.orm.session import Session
-
 
 class EdgeWorkerVersionException(AirflowException):
     """Signal a version mismatch between core and Edge Site."""
@@ -170,68 +162,3 @@ def reset_metrics(worker_name: str) -> None:
         free_concurrency=-1,
         queues=None,
     )
-
-
-class EdgeWorker(BaseModel, LoggingMixin):
-    """Deprecated Edge Worker internal API, keeping for one minor for graceful 
migration."""
-
-    worker_name: str
-    state: EdgeWorkerState
-    queues: Optional[list[str]]  # noqa: UP007 - prevent Sphinx failing
-    first_online: datetime
-    last_update: Optional[datetime] = None  # noqa: UP007 - prevent Sphinx 
failing
-    jobs_active: int
-    jobs_taken: int
-    jobs_success: int
-    jobs_failed: int
-    sysinfo: str
-    model_config = ConfigDict(from_attributes=True, 
arbitrary_types_allowed=True)
-
-    @staticmethod
-    @internal_api_call
-    @provide_session
-    def set_state(
-        worker_name: str,
-        state: EdgeWorkerState,
-        jobs_active: int,
-        sysinfo: dict[str, str],
-        session: Session = NEW_SESSION,
-    ) -> list[str] | None:
-        """Set state of worker and returns the current assigned queues."""
-        query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == 
worker_name)
-        worker: EdgeWorkerModel = session.scalar(query)
-        worker.state = state
-        worker.jobs_active = jobs_active
-        worker.sysinfo = json.dumps(sysinfo)
-        worker.last_update = timezone.utcnow()
-        session.commit()
-        Stats.incr(f"edge_worker.heartbeat_count.{worker_name}", 1, 1)
-        Stats.incr("edge_worker.heartbeat_count", 1, 1, tags={"worker_name": 
worker_name})
-        set_metrics(
-            worker_name=worker_name,
-            state=state,
-            jobs_active=jobs_active,
-            concurrency=int(sysinfo["concurrency"]),
-            free_concurrency=int(sysinfo["free_concurrency"]),
-            queues=worker.queues,
-        )
-        raise EdgeWorkerVersionException(
-            "Edge Worker runs on an old version. Rejecting access due to 
difference."
-        )
-
-    @staticmethod
-    @internal_api_call
-    def register_worker(
-        worker_name: str,
-        state: EdgeWorkerState,
-        queues: list[str] | None,
-        sysinfo: dict[str, str],
-    ) -> EdgeWorker:
-        raise EdgeWorkerVersionException(
-            "Edge Worker runs on an old version. Rejecting access due to 
difference."
-        )
-
-
-EdgeWorker.model_rebuild()
-
-add_pydantic_class_type_mapping("edge_worker", EdgeWorkerModel, EdgeWorker)
diff --git 
a/providers/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml 
b/providers/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml
index ef1f24a2288..8e361ec0982 100644
--- a/providers/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml
+++ b/providers/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml
@@ -179,7 +179,7 @@ paths:
       tags:
       - Worker
   /jobs/fetch/{worker_name}:
-    get:
+    post:
       description: Fetch a job to execute on the edge worker.
       x-openapi-router-controller: 
airflow.providers.edge.worker_api.routes._v2_routes
       operationId: job_fetch_v2
diff --git a/providers/src/airflow/providers/edge/provider.yaml 
b/providers/src/airflow/providers/edge/provider.yaml
index 229f1ad68e4..1377279ab76 100644
--- a/providers/src/airflow/providers/edge/provider.yaml
+++ b/providers/src/airflow/providers/edge/provider.yaml
@@ -27,7 +27,7 @@ source-date-epoch: 1729683247
 
 # note that those versions are maintained by release manager - do not update 
them manually
 versions:
-  - 0.8.2pre0
+  - 0.9.0pre0
 
 dependencies:
   - apache-airflow>=2.10.0
diff --git 
a/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py 
b/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py
index b00f5cf41b9..2b68879531d 100644
--- a/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py
+++ b/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py
@@ -20,8 +20,7 @@ from __future__ import annotations
 
 import json
 import logging
-from functools import cache
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any
 from uuid import uuid4
 
 from flask import Response, request
@@ -50,137 +49,6 @@ if TYPE_CHECKING:
 log = logging.getLogger(__name__)
 
 
-@cache
-def _initialize_method_map() -> dict[str, Callable]:
-    # Note: This is a copy of the (removed) AIP-44 implementation from
-    #       airflow/api_internal/endpoints/rpc_api_endpoint.py
-    #       for compatibility with Airflow 2.10-line.
-    #       Methods are potentially not existing more on main branch for 
Airflow 3.
-    from airflow.api.common.trigger_dag import trigger_dag
-    from airflow.cli.commands.task_command import _get_ti_db_access  # type: 
ignore[attr-defined]
-    from airflow.dag_processing.manager import DagFileProcessorManager
-    from airflow.dag_processing.processor import DagFileProcessor
-
-    # Airflow 2.10 compatibility
-    from airflow.datasets import expand_alias_to_datasets  # type: 
ignore[attr-defined]
-    from airflow.datasets.manager import DatasetManager  # type: 
ignore[attr-defined]
-    from airflow.jobs.job import Job, most_recent_job
-    from airflow.models import Trigger, Variable, XCom
-    from airflow.models.dag import DAG, DagModel
-    from airflow.models.dagcode import DagCode
-    from airflow.models.dagrun import DagRun
-    from airflow.models.dagwarning import DagWarning
-    from airflow.models.renderedtifields import RenderedTaskInstanceFields
-    from airflow.models.serialized_dag import SerializedDagModel
-    from airflow.models.skipmixin import SkipMixin
-    from airflow.models.taskinstance import (
-        TaskInstance,
-        _add_log,
-        _defer_task,
-        _get_template_context,
-        _handle_failure,
-        _handle_reschedule,
-        _record_task_map_for_downstreams,
-        _update_rtif,
-        _xcom_pull,
-    )
-    from airflow.models.xcom_arg import _get_task_map_length
-    from airflow.providers.edge.models.edge_job import EdgeJob
-    from airflow.providers.edge.models.edge_logs import EdgeLogs
-    from airflow.providers.edge.models.edge_worker import EdgeWorker
-    from airflow.secrets.metastore import MetastoreBackend
-    from airflow.sensors.base import _orig_start_date
-    from airflow.utils.cli_action_loggers import _default_action_log_internal  
# type: ignore[attr-defined]
-    from airflow.utils.log.file_task_handler import FileTaskHandler
-
-    functions: list[Callable] = [
-        _default_action_log_internal,
-        _defer_task,
-        _get_template_context,
-        _get_ti_db_access,
-        _get_task_map_length,
-        _update_rtif,
-        _orig_start_date,
-        _handle_failure,
-        _handle_reschedule,
-        _add_log,
-        _xcom_pull,
-        _record_task_map_for_downstreams,
-        trigger_dag,
-        DagCode.remove_deleted_code,
-        DagModel.deactivate_deleted_dags,
-        DagModel.get_paused_dag_ids,
-        DagModel.get_current,
-        DagFileProcessor._execute_task_callbacks,
-        DagFileProcessor.execute_callbacks,
-        DagFileProcessor.execute_callbacks_without_dag,
-        # Airflow 2.10 compatibility
-        DagFileProcessor.manage_slas,  # type: ignore[attr-defined]
-        DagFileProcessor.save_dag_to_db,
-        DagFileProcessor.update_import_errors,
-        DagFileProcessor._validate_task_pools_and_update_dag_warnings,
-        DagFileProcessorManager._fetch_callbacks,
-        DagFileProcessorManager._get_priority_filelocs,
-        DagFileProcessorManager.clear_nonexistent_import_errors,
-        DagFileProcessorManager.deactivate_stale_dags,
-        DagWarning.purge_inactive_dag_warnings,
-        expand_alias_to_datasets,
-        DatasetManager.register_dataset_change,
-        FileTaskHandler._render_filename_db_access,  # type: 
ignore[attr-defined]
-        Job._add_to_db,
-        Job._fetch_from_db,
-        Job._kill,
-        Job._update_heartbeat,
-        Job._update_in_db,
-        most_recent_job,
-        # Airflow 2.10 compatibility
-        MetastoreBackend._fetch_connection,  # type: ignore[attr-defined]
-        MetastoreBackend._fetch_variable,  # type: ignore[attr-defined]
-        XCom.get_value,
-        XCom.get_one,
-        # XCom.get_many, # Not supported because it returns query
-        XCom.clear,
-        XCom.set,
-        Variable._set,
-        Variable._update,
-        Variable._delete,
-        DAG.fetch_callback,
-        DAG.fetch_dagrun,
-        DagRun.fetch_task_instances,
-        DagRun.get_previous_dagrun,
-        DagRun.get_previous_scheduled_dagrun,
-        DagRun.get_task_instances,
-        DagRun.fetch_task_instance,
-        DagRun._get_log_template,
-        RenderedTaskInstanceFields._update_runtime_evaluated_template_fields,
-        SerializedDagModel.get_serialized_dag,
-        SerializedDagModel.remove_deleted_dags,
-        SkipMixin._skip,
-        SkipMixin._skip_all_except,
-        TaskInstance._check_and_change_state_before_execution,
-        TaskInstance.get_task_instance,
-        TaskInstance._get_dagrun,
-        TaskInstance._set_state,
-        TaskInstance.save_to_db,
-        TaskInstance._clear_xcom_data,
-        Trigger.from_object,
-        Trigger.bulk_fetch,
-        Trigger.clean_unused,
-        Trigger.submit_event,
-        Trigger.submit_failure,
-        Trigger.ids_for_triggerer,
-        Trigger.assign_unassigned,
-        # Additional things from EdgeExecutor
-        # These are removed in follow-up PRs as being in transition to FastAPI
-        EdgeJob.reserve_task,
-        EdgeJob.set_state,
-        EdgeLogs.push_logs,
-        EdgeWorker.register_worker,
-        EdgeWorker.set_state,
-    ]
-    return {f"{func.__module__}.{func.__qualname__}": func for func in 
functions}
-
-
 def error_response(message: str, status: int):
     """Log the error and return the response as JSON object."""
     error_id = uuid4()
@@ -195,6 +63,11 @@ def rpcapi_v2(body: dict[str, Any]) -> APIResponse:
     # Note: Except the method map this _was_ a 100% copy of internal API module
     #       
airflow.api_internal.endpoints.rpc_api_endpoint.internal_airflow_api()
     # As of rework for FastAPI in Airflow 3.0, this is updated and to be 
removed in the future.
+    from airflow.api_internal.endpoints.rpc_api_endpoint import (  # type: 
ignore[attr-defined]
+        # Note: This is just for compatibility with Airflow 2.10, not working 
for Airflow 3 / main as removed
+        initialize_method_map,
+    )
+
     try:
         if request.headers.get("Content-Type", "") != "application/json":
             raise HTTPException(status.HTTP_403_FORBIDDEN, "Expected 
Content-Type: application/json")
@@ -207,7 +80,7 @@ def rpcapi_v2(body: dict[str, Any]) -> APIResponse:
             raise error_response("Expected jsonrpc 2.0 request.", 
status.HTTP_400_BAD_REQUEST)
 
         log.debug("Got request for %s", request_obj.method)
-        methods_map = _initialize_method_map()
+        methods_map = initialize_method_map()
         if request_obj.method not in methods_map:
             raise error_response(f"Unrecognized method: 
{request_obj.method}.", status.HTTP_400_BAD_REQUEST)
 
@@ -276,15 +149,15 @@ def set_state_v2(worker_name: str, body: dict[str, Any], 
session=NEW_SESSION) ->
 
 
 @provide_session
-def job_fetch_v2(worker_name: str, body: dict[str, Any] | None = None, 
session=NEW_SESSION) -> Any:
+def job_fetch_v2(worker_name: str, body: dict[str, Any], session=NEW_SESSION) 
-> Any:
     """Handle Edge Worker API `/edge_worker/v1/jobs/fetch/{worker_name}` 
endpoint for Airflow 2.10."""
     from flask import request
 
     try:
         auth = request.headers.get("Authorization", "")
         jwt_token_authorization(request.path, auth)
-        queues = body["queues"] if body else None
-        free_concurrency = body["free_concurrency"] if body else 1
+        queues = body.get("queues")
+        free_concurrency = body.get("free_concurrency", 1)
         request_obj = WorkerQueuesBody(queues=queues, 
free_concurrency=free_concurrency)
         job: EdgeJobFetched | None = fetch(worker_name, request_obj, session)
         return job.model_dump() if job is not None else None
diff --git a/providers/src/airflow/providers/edge/worker_api/routes/jobs.py 
b/providers/src/airflow/providers/edge/worker_api/routes/jobs.py
index 289fc3eed99..99d6607b572 100644
--- a/providers/src/airflow/providers/edge/worker_api/routes/jobs.py
+++ b/providers/src/airflow/providers/edge/worker_api/routes/jobs.py
@@ -44,7 +44,7 @@ from airflow.utils.state import TaskInstanceState
 jobs_router = AirflowRouter(tags=["Jobs"], prefix="/jobs")
 
 
-@jobs_router.get(
+@jobs_router.post(
     "/fetch/{worker_name}",
     dependencies=[Depends(jwt_token_authorization_rest)],
     responses=create_openapi_http_exception_doc(
diff --git a/providers/tests/edge/models/test_edge_job.py 
b/providers/tests/edge/models/test_edge_job.py
deleted file mode 100644
index bca1025a823..00000000000
--- a/providers/tests/edge/models/test_edge_job.py
+++ /dev/null
@@ -1,121 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-from typing import TYPE_CHECKING
-
-import pytest
-
-from airflow.providers.edge.models.edge_job import EdgeJob, EdgeJobModel
-from airflow.utils import timezone
-from airflow.utils.state import TaskInstanceState
-
-if TYPE_CHECKING:
-    from sqlalchemy.orm import Session
-
-pytestmark = pytest.mark.db_test
-pytest.importorskip("pydantic", minversion="2.0.0")
-
-
-class TestEdgeJob:
-    @pytest.fixture(autouse=True)
-    def setup_test_cases(self, session: Session):
-        session.query(EdgeJobModel).delete()
-
-    def test_reserve_task_no_job(self):
-        job = EdgeJob.reserve_task("worker", free_concurrency=10)
-        assert job is None
-
-    @pytest.mark.parametrize(
-        "concurrency_slots, free_concurrency, expected_job",
-        [
-            pytest.param(10, 9, False, id="less_free_concurrency"),
-            pytest.param(10, 10, True, id="equal_free_concurrency"),
-            pytest.param(10, 11, True, id="more_free_concurrency"),
-        ],
-    )
-    def test_reserve_task_has_one(self, concurrency_slots, free_concurrency, 
expected_job, session: Session):
-        rjm = EdgeJobModel(
-            dag_id="test_dag",
-            task_id="test_task",
-            run_id="test_run",
-            map_index=-1,
-            try_number=1,
-            state=TaskInstanceState.QUEUED,
-            queue="default",
-            concurrency_slots=concurrency_slots,
-            command=str(["hello", "world"]),
-            queued_dttm=timezone.utcnow(),
-        )
-        session.add(rjm)
-        session.commit()
-
-        job = EdgeJob.reserve_task("worker", free_concurrency=free_concurrency)
-        if expected_job:
-            assert job
-            assert job.edge_worker == "worker"
-            assert job.queue == "default"
-            assert job.dag_id == "test_dag"
-            assert job.task_id == "test_task"
-            assert job.run_id == "test_run"
-            assert job.concurrency_slots == concurrency_slots
-        else:
-            assert job is None
-
-        jobs: list[EdgeJobModel] = session.query(EdgeJobModel).all()
-
-        assert len(jobs) == 1
-        assert jobs[0].queue == "default"
-        assert jobs[0].dag_id == "test_dag"
-        assert jobs[0].task_id == "test_task"
-        assert jobs[0].run_id == "test_run"
-        assert jobs[0].concurrency_slots == concurrency_slots
-
-        if expected_job:
-            assert jobs[0].state == TaskInstanceState.RUNNING
-            assert jobs[0].edge_worker == "worker"
-        else:
-            assert jobs[0].state == TaskInstanceState.QUEUED
-            assert jobs[0].edge_worker is None
-
-    def test_set_state(self, session: Session):
-        rjm = EdgeJobModel(
-            dag_id="test_dag",
-            task_id="test_task",
-            run_id="test_run",
-            map_index=-1,
-            try_number=1,
-            state=TaskInstanceState.RUNNING,
-            queue="default",
-            concurrency_slots=5,
-            command=str(["hello", "world"]),
-            queued_dttm=timezone.utcnow(),
-        )
-        session.add(rjm)
-        session.commit()
-
-        EdgeJob.set_state(rjm.key, TaskInstanceState.FAILED)
-
-        jobs: list[EdgeJobModel] = session.query(EdgeJobModel).all()
-        assert len(jobs) == 1
-        assert jobs[0].state == TaskInstanceState.FAILED
-        assert jobs[0].last_update
-        assert jobs[0].queue == "default"
-        assert jobs[0].dag_id == "test_dag"
-        assert jobs[0].task_id == "test_task"
-        assert jobs[0].run_id == "test_run"
-        assert jobs[0].concurrency_slots == 5

Reply via email to