This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new ee6d8dd7aae AIP 67 - Multi-Team: Update Edge Executor to support multi
team (#61646)
ee6d8dd7aae is described below
commit ee6d8dd7aae24fc0dfce0c1c98f391c8d4b72130
Author: Jeongwoo Do <[email protected]>
AuthorDate: Tue Mar 31 05:58:01 2026 +0900
AIP 67 - Multi-Team: Update Edge Executor to support multi team (#61646)
* Update Edge Executor to support multi team
* fix logics for multi-team
* add the contents of multi-team in docs
* fix query logic
* fix lint
* fix for test
* fix docs and db manager logic
* fix logic
* fix migration order
* fix mypy test error
* fix migration order
* fix for compatibility
* update docs
* fix logic
* fix migration version
* fix
---
providers/edge3/docs/architecture.rst | 3 +-
providers/edge3/docs/deployment.rst | 15 +-
providers/edge3/docs/edge_executor.rst | 60 ++++++
providers/edge3/docs/migrations-ref.rst | 20 +-
.../src/airflow/providers/edge3/cli/api_client.py | 27 ++-
.../src/airflow/providers/edge3/cli/definition.py | 8 +
.../airflow/providers/edge3/cli/edge_command.py | 3 +-
.../src/airflow/providers/edge3/cli/worker.py | 75 +++++---
.../providers/edge3/executors/edge_executor.py | 41 ++--
.../versions/0004_3_4_0_add_team_name_column.py | 60 ++++++
.../edge3/src/airflow/providers/edge3/models/db.py | 1 +
.../src/airflow/providers/edge3/models/edge_job.py | 3 +
.../airflow/providers/edge3/models/edge_worker.py | 27 ++-
.../src/airflow/providers/edge3/version_compat.py | 2 +
.../providers/edge3/worker_api/datamodels.py | 7 +
.../providers/edge3/worker_api/routes/jobs.py | 1 +
.../providers/edge3/worker_api/routes/ui.py | 23 +--
.../providers/edge3/worker_api/routes/worker.py | 5 +-
.../edge3/worker_api/v2-edge-generated.yaml | 21 +++
.../edge3/tests/unit/edge3/cli/test_definition.py | 19 +-
.../edge3/tests/unit/edge3/cli/test_worker.py | 43 +++--
.../unit/edge3/executors/test_edge_executor.py | 207 +++++++++++++++++++++
providers/edge3/tests/unit/edge3/models/test_db.py | 8 +-
.../unit/edge3/worker_api/routes/test_jobs.py | 103 +++++++++-
.../unit/edge3/worker_api/routes/test_worker.py | 76 +++++++-
25 files changed, 754 insertions(+), 104 deletions(-)
diff --git a/providers/edge3/docs/architecture.rst
b/providers/edge3/docs/architecture.rst
index 5a1f7bb070a..e9e586183cf 100644
--- a/providers/edge3/docs/architecture.rst
+++ b/providers/edge3/docs/architecture.rst
@@ -150,7 +150,8 @@ The current version of the EdgeExecutor is released with
known limitations. It w
The following features are known missing and will be implemented in increments:
-- API token per worker: Today there is a global API token available only
+- Per-worker or per-team authentication tokens: Today a single shared secret
is used for all
+ workers and teams, so multi-team isolation is logical only (see
:ref:`edge_executor:multi_team`).
- Edge Worker Plugin
- Overview about queues / jobs per queue
diff --git a/providers/edge3/docs/deployment.rst
b/providers/edge3/docs/deployment.rst
index 6c78033c0a4..c843349c50e 100644
--- a/providers/edge3/docs/deployment.rst
+++ b/providers/edge3/docs/deployment.rst
@@ -106,7 +106,11 @@ run on the central Airflow instance:
airflow db migrate
To kick off a worker, you need to setup Airflow and kick off the worker
-subcommand
+subcommand.
+
+If your Airflow deployment uses Multi-Team mode, assign the worker to its team
with
+the ``--team-name`` option so it only picks up jobs for that team. See
+:ref:`edge_executor:multi_team` for setup details and security considerations.
.. code-block:: bash
@@ -126,6 +130,12 @@ subcommand
2025-09-27T12:28:33.171525Z [info ] No new job to process
+To start a worker assigned to a specific team:
+
+.. code-block:: bash
+
+ airflow edge worker --team-name team_a -q remote,wisconsin_site
+
You can also start this worker in the background by running
it as a daemonized process. Additionally, you can redirect stdout
and stderr to their respective files.
@@ -245,3 +255,6 @@ instance. The commands are:
- ``airflow edge add-worker-queues``: Add queues to an edge worker
- ``airflow edge remove-worker-queues``: Remove queues from an edge worker
- ``airflow edge set-worker-concurrency``: Set the concurrency of a running
remote edge worker
+
+Workers are identified by hostname. See the :doc:`cli-ref` for the full list of
+arguments.
diff --git a/providers/edge3/docs/edge_executor.rst
b/providers/edge3/docs/edge_executor.rst
index 64380f0bf38..73182dd16c1 100644
--- a/providers/edge3/docs/edge_executor.rst
+++ b/providers/edge3/docs/edge_executor.rst
@@ -59,6 +59,63 @@ When using EdgeExecutor in addition to other executors and
EdgeExecutor not bein
as the executor at task or Dag level in addition to the queues you are
targeting.
For more details on multiple executors please see
:ref:`apache-airflow:using-multiple-executors-concurrently`.
+.. _edge_executor:multi_team:
+
+Multi-Team Support
+------------------
+
+When multiple teams share a single Airflow deployment, each team may need its
own
+set of edge workers — for example, separate on-premise sites, different
geographic
+regions, or isolated execution environments. The EdgeExecutor integrates with
+Airflow's Multi-Team mode so that each team's edge jobs and workers are kept
separate.
+
+To use multi-team with the EdgeExecutor, first enable Multi-Team mode in your
+Airflow deployment and create the teams you need. Then configure the
+EdgeExecutor for each team in your ``airflow.cfg``:
+
+.. code-block:: ini
+
+ [core]
+ multi_team = True
+ executor = EdgeExecutor;team_a=EdgeExecutor;team_b=EdgeExecutor
+
+With this configuration, the scheduler runs a dedicated EdgeExecutor instance
per
+team. Each instance only schedules and monitors jobs belonging to its own
team, and
+each worker only picks up jobs assigned to its team.
+
+**Starting a worker for a specific team:**
+
+.. code-block:: bash
+
+ airflow edge worker --team-name team_a -q queue1,queue2
+
+When ``--team-name`` is omitted, the worker operates without team isolation —
the
+same behavior as a single-team deployment. Existing workers continue to work
without
+any changes.
+
+**Per-team configuration overrides:**
+
+Each team's EdgeExecutor can have its own settings. Use environment variables
with
+the ``AIRFLOW__<TEAM_NAME>___<SECTION>__<KEY>`` pattern (triple underscore
between
+team name and section):
+
+.. code-block:: bash
+
+ # Set a longer heartbeat interval for team_a's edge workers
+ export AIRFLOW__TEAM_A___EDGE__HEARTBEAT_INTERVAL=30
+
+ # Point team_b's workers to a different API endpoint
+ export
AIRFLOW__TEAM_B___EDGE__API_URL=https://team-b-api.example.com/edge_worker/v1/rpcapi
+
+.. warning::
+
+ **Security limitation:** Multi-team in the EdgeExecutor provides **logical
+ isolation only**. Worker management CLI commands (maintenance, shutdown,
remove,
+ etc.) operate without team distinction — any administrator can manage any
+ worker regardless of its team. Treat multi-team as an organizational
separation
+ for trusted administrators, not as a security boundary. Per-team
authentication
+ tokens are planned for a future release.
+
.. _edge_executor:concurrency_slots:
Concurrency slot handling
@@ -115,3 +172,6 @@ Current Limitations Edge Executor
- Performance: No extensive performance assessment and scaling tests have
been made. The edge executor package is
optimized for stability. This will be incrementally improved in future
releases. Setups have reported stable
operation with ~80 workers until now. Note that executed tasks require
more api-server / webserver API capacity.
+ - Multi-team isolation is logical only — all teams share a single
authentication secret. A worker
+ administrator could change the team name and access another team's jobs.
See
+ :ref:`edge_executor:multi_team` for details and planned improvements.
diff --git a/providers/edge3/docs/migrations-ref.rst
b/providers/edge3/docs/migrations-ref.rst
index a89544f56ff..f1b0c08573f 100644
--- a/providers/edge3/docs/migrations-ref.rst
+++ b/providers/edge3/docs/migrations-ref.rst
@@ -31,15 +31,17 @@ Here's the list of all the Database Migrations that are
executed via when you ru
.. All table elements are scraped from migration files
.. Beginning of auto-generated table
-+-------------------------+------------------+-----------------+----------------------------------------------+
-| Revision ID | Revises ID | Edge3 Version | Description
|
-+=========================+==================+=================+==============================================+
-| ``8c275b6fbaa8`` (head) | ``b3c4d5e6f7a8`` | ``3.2.0`` | Fix migration
file/ORM inconsistencies. |
-+-------------------------+------------------+-----------------+----------------------------------------------+
-| ``b3c4d5e6f7a8`` | ``9d34dfc2de06`` | ``3.2.0`` | Add
concurrency column to edge_worker table. |
-+-------------------------+------------------+-----------------+----------------------------------------------+
-| ``9d34dfc2de06`` (base) | ``None`` | ``3.0.0`` | Create Edge
tables if missing. |
-+-------------------------+------------------+-----------------+----------------------------------------------+
++-------------------------+------------------+-----------------+----------------------------------------------------------+
+| Revision ID | Revises ID | Edge3 Version | Description
|
++=========================+==================+=================+==========================================================+
+| ``a09c3ee8e1d3`` (head) | ``8c275b6fbaa8`` | ``3.4.0`` | Add team_name
column to edge_job and edge_worker tables. |
++-------------------------+------------------+-----------------+----------------------------------------------------------+
+| ``8c275b6fbaa8`` | ``b3c4d5e6f7a8`` | ``3.2.0`` | Fix migration
file/ORM inconsistencies. |
++-------------------------+------------------+-----------------+----------------------------------------------------------+
+| ``b3c4d5e6f7a8`` | ``9d34dfc2de06`` | ``3.2.0`` | Add
concurrency column to edge_worker table. |
++-------------------------+------------------+-----------------+----------------------------------------------------------+
+| ``9d34dfc2de06`` (base) | ``None`` | ``3.0.0`` | Create Edge
tables if missing. |
++-------------------------+------------------+-----------------+----------------------------------------------------------+
.. End of auto-generated table
diff --git a/providers/edge3/src/airflow/providers/edge3/cli/api_client.py
b/providers/edge3/src/airflow/providers/edge3/cli/api_client.py
index a5124800ba5..1fd53245e19 100644
--- a/providers/edge3/src/airflow/providers/edge3/cli/api_client.py
+++ b/providers/edge3/src/airflow/providers/edge3/cli/api_client.py
@@ -112,16 +112,20 @@ async def _make_generic_request(method: str, rest_path:
str, data: str | None =
async def worker_register(
- hostname: str, state: EdgeWorkerState, queues: list[str] | None, sysinfo:
dict
+ hostname: str,
+ state: EdgeWorkerState,
+ queues: list[str] | None,
+ sysinfo: dict,
+ team_name: str | None = None,
) -> WorkerRegistrationReturn:
"""Register worker with the Edge API."""
try:
result = await _make_generic_request(
"POST",
f"worker/{quote(hostname)}",
- WorkerStateBody(state=state, jobs_active=0, queues=queues,
sysinfo=sysinfo).model_dump_json(
- exclude_unset=True
- ),
+ WorkerStateBody(
+ state=state, jobs_active=0, queues=queues, sysinfo=sysinfo,
team_name=team_name
+ ).model_dump_json(exclude_unset=True),
)
except ClientResponseError as e:
if e.status == HTTPStatus.BAD_REQUEST:
@@ -142,6 +146,7 @@ async def worker_set_state(
queues: list[str] | None,
sysinfo: dict,
maintenance_comments: str | None = None,
+ team_name: str | None = None,
) -> WorkerSetStateReturn:
"""Update the state of the worker in the central site and thereby
implicitly heartbeat."""
try:
@@ -154,6 +159,7 @@ async def worker_set_state(
queues=queues,
sysinfo=sysinfo,
maintenance_comments=maintenance_comments,
+ team_name=team_name,
).model_dump_json(exclude_unset=True),
)
except ClientResponseError as e:
@@ -163,14 +169,19 @@ async def worker_set_state(
return WorkerSetStateReturn(**result)
-async def jobs_fetch(hostname: str, queues: list[str] | None,
free_concurrency: int) -> EdgeJobFetched | None:
+async def jobs_fetch(
+ hostname: str,
+ queues: list[str] | None,
+ free_concurrency: int,
+ team_name: str | None = None,
+) -> EdgeJobFetched | None:
"""Fetch a job to execute on the edge worker."""
result = await _make_generic_request(
"POST",
f"jobs/fetch/{quote(hostname)}",
- WorkerQueuesBody(queues=queues,
free_concurrency=free_concurrency).model_dump_json(
- exclude_unset=True
- ),
+ WorkerQueuesBody(
+ queues=queues, free_concurrency=free_concurrency,
team_name=team_name
+ ).model_dump_json(exclude_unset=True),
)
if result:
return EdgeJobFetched(**result)
diff --git a/providers/edge3/src/airflow/providers/edge3/cli/definition.py
b/providers/edge3/src/airflow/providers/edge3/cli/definition.py
index ea8dd30de2e..037cc5d7884 100644
--- a/providers/edge3/src/airflow/providers/edge3/cli/definition.py
+++ b/providers/edge3/src/airflow/providers/edge3/cli/definition.py
@@ -35,6 +35,13 @@ ARG_QUEUES = Arg(
("-q", "--queues"),
help="Comma delimited list of queues to serve, serve all queues if not
provided.",
)
+ARG_TEAM_NAME = Arg(
+ (
+ "-t",
+ "--team-name",
+ ),
+ help="Team name for multi-team setups. If not provided, worker operates
without team isolation.",
+)
ARG_EDGE_HOSTNAME = Arg(
("-H", "--edge-hostname"),
help="Set the hostname of worker if you have multiple workers on a single
machine",
@@ -121,6 +128,7 @@ EDGE_COMMANDS: list[ActionCommand] = [
args=(
ARG_CONCURRENCY,
ARG_QUEUES,
+ ARG_TEAM_NAME,
ARG_EDGE_HOSTNAME,
ARG_PID,
ARG_VERBOSE,
diff --git a/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py
b/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py
index 07aa8cf5566..3b3fb6de23d 100644
--- a/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py
+++ b/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py
@@ -95,9 +95,8 @@ def _launch_worker(args):
hostname=args.edge_hostname or getfqdn(),
queues=args.queues.split(",") if args.queues else None,
concurrency=args.concurrency,
- job_poll_interval=conf.getint("edge", "job_poll_interval"),
- heartbeat_interval=conf.getint("edge", "heartbeat_interval"),
daemon=args.daemon,
+ team_name=getattr(args, "team_name", None),
)
asyncio.run(edge_worker.start())
diff --git a/providers/edge3/src/airflow/providers/edge3/cli/worker.py
b/providers/edge3/src/airflow/providers/edge3/cli/worker.py
index f0cffa96c05..de0ebc1432a 100644
--- a/providers/edge3/src/airflow/providers/edge3/cli/worker.py
+++ b/providers/edge3/src/airflow/providers/edge3/cli/worker.py
@@ -23,7 +23,7 @@ import sys
import traceback
from asyncio import Task, create_task, get_running_loop, sleep
from datetime import datetime
-from functools import cache
+from functools import cached_property
from http import HTTPStatus
from multiprocessing import Process, Queue
from pathlib import Path
@@ -56,16 +56,15 @@ from airflow.providers.edge3.models.edge_worker import (
EdgeWorkerState,
EdgeWorkerVersionException,
)
+from airflow.providers.edge3.version_compat import AIRFLOW_V_3_2_PLUS
from airflow.utils.net import getfqdn
from airflow.utils.state import TaskInstanceState
if TYPE_CHECKING:
+ from airflow.configuration import AirflowConfigParser
from airflow.executors.workloads import ExecuteTask
logger = logging.getLogger(__name__)
-base_log_folder = conf.get("logging", "base_log_folder", fallback="NOT
AVAILABLE")
-push_logs = conf.getboolean("edge", "push_logs")
-push_log_chunk_size = conf.getint("edge", "push_log_chunk_size")
if sys.platform == "darwin":
setproctitle = lambda title: logger.debug("Mac OS detected, skipping
setproctitle")
@@ -78,18 +77,6 @@ def _edge_hostname() -> str:
return os.environ.get("HOSTNAME", getfqdn())
-@cache
-def _execution_api_server_url() -> str:
- """Get the execution api server url from config or environment."""
- execution_api_server_url = conf.get("core", "execution_api_server_url",
fallback="")
- if not execution_api_server_url:
- # Derive execution api url from edge api url as fallback
- api_url = conf.get("edge", "api_url")
- execution_api_server_url = api_url.replace("edge_worker/v1/rpcapi",
"execution")
- logger.info("Using execution api server url: %s", execution_api_server_url)
- return execution_api_server_url
-
-
class EdgeWorker:
"""Runner instance which executes the Edge Worker."""
@@ -109,17 +96,47 @@ class EdgeWorker:
hostname: str,
queues: list[str] | None,
concurrency: int,
- job_poll_interval: int,
- heartbeat_interval: int,
daemon: bool = False,
+ team_name: str | None = None,
):
self.pid_file_path = pid_file_path
- self.job_poll_interval = job_poll_interval
- self.hb_interval = heartbeat_interval
self.hostname = hostname
self.queues = queues
self.concurrency = concurrency
self.daemon = daemon
+ self.team_name = team_name
+
+ if TYPE_CHECKING:
+ self.conf: ExecutorConf | AirflowConfigParser
+
+ if AIRFLOW_V_3_2_PLUS:
+ from airflow.executors.base_executor import ExecutorConf
+
+ self.conf = ExecutorConf(team_name)
+
+ else:
+ self.conf = conf
+
+ self.job_poll_interval = self.conf.getint("edge", "job_poll_interval")
+ self.hb_interval = self.conf.getint("edge", "heartbeat_interval")
+ self.base_log_folder: str = (
+ self.conf.get("logging", "base_log_folder", fallback="NOT
AVAILABLE") or ""
+ )
+ self.push_logs = self.conf.getboolean("edge", "push_logs")
+ self.push_log_chunk_size = self.conf.getint("edge",
"push_log_chunk_size")
+
+ @cached_property
+ def _execution_api_server_url(self) -> str | None:
+ """Get the execution api server url from config or environment."""
+ execution_api_server_url = self.conf.get("core",
"execution_api_server_url", fallback="")
+ if not execution_api_server_url:
+ # Derive execution api url from edge api url as fallback
+ api_url = self.conf.get("edge", "api_url")
+ execution_api_server_url = (
+ api_url.replace("edge_worker/v1/rpcapi", "execution") if
api_url is not None else None
+ )
+ logger.info("Using execution api server url: %s",
execution_api_server_url)
+ return execution_api_server_url
@property
def free_concurrency(self) -> int:
@@ -217,7 +234,7 @@ class EdgeWorker:
dag_rel_path=workload.dag_rel_path,
bundle_info=workload.bundle_info,
token=workload.token,
- server=_execution_api_server_url(),
+ server=self._execution_api_server_url,
log_path=workload.log_path,
)
return 0
@@ -239,7 +256,7 @@ class EdgeWorker:
async def _push_logs_in_chunks(self, job: Job):
aio_logfile = anyio.Path(job.logfile)
- if push_logs and await aio_logfile.exists() and (await
aio_logfile.stat()).st_size > job.logsize:
+ if self.push_logs and await aio_logfile.exists() and (await
aio_logfile.stat()).st_size > job.logsize:
async with aio_open(job.logfile, mode="rb") as logf:
await logf.seek(job.logsize, os.SEEK_SET)
read_data = await logf.read()
@@ -248,8 +265,8 @@ class EdgeWorker:
# replace null with question mark to fix issue during DB push
log_data =
read_data.decode(errors="backslashreplace").replace("\x00", "\ufffd")
while True:
- chunk_data = log_data[:push_log_chunk_size]
- log_data = log_data[push_log_chunk_size:]
+ chunk_data = log_data[: self.push_log_chunk_size]
+ log_data = log_data[self.push_log_chunk_size :]
if not chunk_data:
break
@@ -262,7 +279,9 @@ class EdgeWorker:
async def start(self):
"""Start the execution in a loop until terminated."""
try:
- await worker_register(self.hostname, EdgeWorkerState.STARTING,
self.queues, self._get_sysinfo())
+ await worker_register(
+ self.hostname, EdgeWorkerState.STARTING, self.queues,
self._get_sysinfo(), self.team_name
+ )
except EdgeWorkerVersionException as e:
logger.info("Version mismatch of Edge worker and Core. Shutting
down worker.")
raise SystemExit(str(e))
@@ -296,6 +315,7 @@ class EdgeWorker:
0,
self.queues,
self._get_sysinfo(),
+ team_name=self.team_name,
)
except EdgeWorkerVersionException:
logger.info("Version mismatch of Edge worker and Core.
Quitting worker anyway.")
@@ -333,7 +353,7 @@ class EdgeWorker:
async def fetch_and_run_job(self) -> None:
"""Fetch, start and monitor a new job."""
logger.debug("Attempting to fetch a new job...")
- edge_job = await jobs_fetch(self.hostname, self.queues,
self.free_concurrency)
+ edge_job = await jobs_fetch(self.hostname, self.queues,
self.free_concurrency, self.team_name)
if not edge_job:
logger.info(
"No new job to process%s",
@@ -347,7 +367,7 @@ class EdgeWorker:
process, results_queue = self._launch_job(workload)
if TYPE_CHECKING:
assert workload.log_path # We need to assume this is defined in
here
- logfile = Path(base_log_folder, workload.log_path)
+ logfile = Path(self.base_log_folder, workload.log_path)
job = Job(edge_job, process, logfile)
self.jobs.append(job)
await jobs_set_state(edge_job.key, TaskInstanceState.RUNNING)
@@ -398,6 +418,7 @@ class EdgeWorker:
self.queues,
sysinfo,
new_maintenance_comments,
+ team_name=self.team_name,
)
self.queues = worker_info.queues
if worker_info.concurrency is not None and worker_info.concurrency
!= self.concurrency:
diff --git
a/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py
b/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py
index bef49278c6a..22c93528fb2 100644
--- a/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py
+++ b/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py
@@ -27,7 +27,7 @@ from sqlalchemy import delete, select
from airflow.executors import workloads
from airflow.executors.base_executor import BaseExecutor
from airflow.models.taskinstance import TaskInstance
-from airflow.providers.common.compat.sdk import Stats, conf, timezone
+from airflow.providers.common.compat.sdk import Stats, timezone
from airflow.providers.edge3.models.db import EdgeDBManager,
check_db_manager_config
from airflow.providers.edge3.models.edge_job import EdgeJobModel
from airflow.providers.edge3.models.edge_logs import EdgeLogsModel
@@ -47,17 +47,31 @@ if TYPE_CHECKING:
# Task tuple to send to be executed
TaskTuple = tuple[TaskInstanceKey, CommandType, str | None, Any | None]
-PARALLELISM: int = conf.getint("core", "PARALLELISM")
-DEFAULT_QUEUE: str = conf.get_mandatory_value("operators", "default_queue")
-
class EdgeExecutor(BaseExecutor):
"""Implementation of the EdgeExecutor to distribute work to Edge Workers
via HTTP."""
- def __init__(self, parallelism: int = PARALLELISM):
- super().__init__(parallelism=parallelism)
+ supports_multi_team: bool = True
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
self.last_reported_state: dict[TaskInstanceKey, TaskInstanceState] = {}
+ # Check if self has the ExecutorConf set on the self.conf attribute
with all required methods.
+ # In Airflow 2.x, ExecutorConf exists but lacks methods like getint,
getboolean, getsection, etc.
+ # In such cases, fall back to the global configuration object.
+ # This allows the changes to be backwards compatible with older
versions of Airflow.
+ # Can be removed when minimum supported provider version is equal to
the version of core airflow
+ # which introduces multi-team configuration (3.2+).
+ if not hasattr(self, "conf") or not hasattr(self.conf, "getint"):
+ from airflow.configuration import conf as global_conf
+
+ self.conf = global_conf
+ # Also set team_name to None if it doesn't exist, since the Celery app
creation expects it to be
+ # there (even if it's None)
+ if not hasattr(self, "team_name"):
+ self.team_name = None
+
@provide_session
def start(self, session: Session = NEW_SESSION):
"""If EdgeExecutor provider is loaded first time, ensure table
exists."""
@@ -109,6 +123,7 @@ class EdgeExecutor(BaseExecutor):
existing_job.queue = task_instance.queue
existing_job.concurrency_slots = task_instance.pool_slots
existing_job.command = workload.model_dump_json()
+ existing_job.team_name = self.team_name
else:
session.add(
EdgeJobModel(
@@ -121,17 +136,19 @@ class EdgeExecutor(BaseExecutor):
queue=task_instance.queue,
concurrency_slots=task_instance.pool_slots,
command=workload.model_dump_json(),
+ team_name=self.team_name,
)
)
def _check_worker_liveness(self, session: Session) -> bool:
"""Reset worker state if heartbeat timed out."""
changed = False
- heartbeat_interval: int = conf.getint("edge", "heartbeat_interval")
+ heartbeat_interval: int = self.conf.getint("edge",
"heartbeat_interval")
lifeless_workers: Sequence[EdgeWorkerModel] = session.scalars(
select(EdgeWorkerModel)
.with_for_update(skip_locked=True)
.where(
+ EdgeWorkerModel.team_name == self.team_name,
EdgeWorkerModel.state.not_in(
[
EdgeWorkerState.UNKNOWN,
@@ -162,11 +179,12 @@ class EdgeExecutor(BaseExecutor):
def _update_orphaned_jobs(self, session: Session) -> bool:
"""Update status ob jobs when workers die and don't update anymore."""
- heartbeat_interval: int = conf.getint("scheduler",
"task_instance_heartbeat_timeout")
+ heartbeat_interval: int = self.conf.getint("scheduler",
"task_instance_heartbeat_timeout")
lifeless_jobs: Sequence[EdgeJobModel] = session.scalars(
select(EdgeJobModel)
.with_for_update(skip_locked=True)
.where(
+ EdgeJobModel.team_name == self.team_name,
EdgeJobModel.state == TaskInstanceState.RUNNING,
EdgeJobModel.last_update < (timezone.utcnow() -
timedelta(seconds=heartbeat_interval)),
)
@@ -202,12 +220,13 @@ class EdgeExecutor(BaseExecutor):
def _purge_jobs(self, session: Session) -> bool:
"""Clean finished jobs."""
purged_marker = False
- job_success_purge = conf.getint("edge", "job_success_purge")
- job_fail_purge = conf.getint("edge", "job_fail_purge")
+ job_success_purge = self.conf.getint("edge", "job_success_purge")
+ job_fail_purge = self.conf.getint("edge", "job_fail_purge")
jobs: Sequence[EdgeJobModel] = session.scalars(
select(EdgeJobModel)
.with_for_update(skip_locked=True)
.where(
+ EdgeJobModel.team_name == self.team_name,
EdgeJobModel.state.in_(
[
TaskInstanceState.RUNNING,
@@ -217,7 +236,7 @@ class EdgeExecutor(BaseExecutor):
TaskInstanceState.RESTARTING,
TaskInstanceState.UP_FOR_RETRY,
]
- )
+ ),
)
).all()
diff --git
a/providers/edge3/src/airflow/providers/edge3/migrations/versions/0004_3_4_0_add_team_name_column.py
b/providers/edge3/src/airflow/providers/edge3/migrations/versions/0004_3_4_0_add_team_name_column.py
new file mode 100644
index 00000000000..19a9fd40587
--- /dev/null
+++
b/providers/edge3/src/airflow/providers/edge3/migrations/versions/0004_3_4_0_add_team_name_column.py
@@ -0,0 +1,60 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Add team_name column to edge_job and edge_worker tables.
+
+Revision ID: a09c3ee8e1d3
+Revises: 8c275b6fbaa8
+Create Date: 2026-02-07 00:00:00.000000
+"""
+
+from __future__ import annotations
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "a09c3ee8e1d3"
+down_revision = "8c275b6fbaa8"
+branch_labels = None
+depends_on = None
+edge3_version = "3.4.0"
+
+
+def upgrade() -> None:
+ conn = op.get_bind()
+ inspector = sa.inspect(conn)
+
+ edge_job_cols = {c["name"] for c in inspector.get_columns("edge_job")}
+ if "team_name" not in edge_job_cols:
+ with op.batch_alter_table("edge_job") as batch_op:
+ batch_op.add_column(sa.Column("team_name", sa.String(length=64),
nullable=True))
+
+ edge_worker_cols = {c["name"] for c in
inspector.get_columns("edge_worker")}
+ if "team_name" not in edge_worker_cols:
+ with op.batch_alter_table("edge_worker") as batch_op:
+ batch_op.add_column(sa.Column("team_name", sa.String(length=64),
nullable=True))
+
+
+def downgrade() -> None:
+ with op.batch_alter_table("edge_worker") as batch_op:
+ batch_op.drop_column("team_name")
+
+ with op.batch_alter_table("edge_job") as batch_op:
+ batch_op.drop_column("team_name")
diff --git a/providers/edge3/src/airflow/providers/edge3/models/db.py
b/providers/edge3/src/airflow/providers/edge3/models/db.py
index 1637cca2a13..f564b1d117b 100644
--- a/providers/edge3/src/airflow/providers/edge3/models/db.py
+++ b/providers/edge3/src/airflow/providers/edge3/models/db.py
@@ -45,6 +45,7 @@ PACKAGE_DIR = Path(__file__).parents[1]
_REVISION_HEADS_MAP: dict[str, str] = {
"3.0.0": "9d34dfc2de06",
"3.2.0": "8c275b6fbaa8",
+ "3.4.0": "a09c3ee8e1d3",
}
diff --git a/providers/edge3/src/airflow/providers/edge3/models/edge_job.py
b/providers/edge3/src/airflow/providers/edge3/models/edge_job.py
index d98f3076a32..79576031112 100644
--- a/providers/edge3/src/airflow/providers/edge3/models/edge_job.py
+++ b/providers/edge3/src/airflow/providers/edge3/models/edge_job.py
@@ -56,6 +56,7 @@ class EdgeJobModel(Base, LoggingMixin):
queued_dttm: Mapped[datetime | None] = mapped_column(UtcDateTime)
edge_worker: Mapped[str | None] = mapped_column(String(64))
last_update: Mapped[datetime | None] = mapped_column(UtcDateTime)
+ team_name: Mapped[str | None] = mapped_column(String(64), nullable=True)
def __init__(
self,
@@ -71,6 +72,7 @@ class EdgeJobModel(Base, LoggingMixin):
queued_dttm: datetime | None = None,
edge_worker: str | None = None,
last_update: datetime | None = None,
+ team_name: str | None = None,
):
self.dag_id = dag_id
self.task_id = task_id
@@ -84,6 +86,7 @@ class EdgeJobModel(Base, LoggingMixin):
self.queued_dttm = queued_dttm or timezone.utcnow()
self.edge_worker = edge_worker
self.last_update = last_update
+ self.team_name = team_name
super().__init__()
__table_args__ = (Index("rj_order", state, queued_dttm, queue),)
diff --git a/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py
b/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py
index 5b037f2903c..cf72300681c 100644
--- a/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py
+++ b/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py
@@ -103,6 +103,7 @@ class EdgeWorkerModel(Base, LoggingMixin):
jobs_success: Mapped[int] = mapped_column(Integer, default=0)
jobs_failed: Mapped[int] = mapped_column(Integer, default=0)
sysinfo: Mapped[str | None] = mapped_column(String(256))
+ team_name: Mapped[str | None] = mapped_column(String(64), nullable=True)
concurrency: Mapped[int | None] = mapped_column(Integer, nullable=True)
def __init__(
@@ -113,6 +114,7 @@ class EdgeWorkerModel(Base, LoggingMixin):
first_online: datetime | None = None,
last_update: datetime | None = None,
maintenance_comment: str | None = None,
+ team_name: str | None = None,
):
self.worker_name = worker_name
self.state = EdgeWorkerState(state)
@@ -120,6 +122,7 @@ class EdgeWorkerModel(Base, LoggingMixin):
self.first_online = first_online or timezone.utcnow()
self.last_update = last_update
self.maintenance_comment = maintenance_comment
+ self.team_name = team_name
super().__init__()
@property
@@ -257,10 +260,16 @@ def reset_metrics(worker_name: str) -> None:
)
+def get_query_filter_by_worker_name(worker_name: str):
+ return select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name ==
worker_name)
+
+
@providers_configuration_loaded
@provide_session
def _fetch_edge_hosts_from_db(
- hostname: str | None = None, states: list | None = None, session: Session
= NEW_SESSION
+ hostname: str | None = None,
+ states: list | None = None,
+ session: Session = NEW_SESSION,
) -> Sequence[EdgeWorkerModel]:
query = select(EdgeWorkerModel)
if states:
@@ -282,8 +291,8 @@ def request_maintenance(
worker_name: str, maintenance_comment: str | None, session: Session =
NEW_SESSION
) -> None:
"""Write maintenance request to the db."""
- query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name ==
worker_name)
- worker: EdgeWorkerModel | None = session.scalar(query)
+ query = get_query_filter_by_worker_name(worker_name=worker_name)
+ worker = session.scalar(query)
if not worker:
raise ValueError(f"Edge Worker {worker_name} not found in list of
registered workers")
worker.state = EdgeWorkerState.MAINTENANCE_REQUEST
@@ -293,7 +302,7 @@ def request_maintenance(
@provide_session
def exit_maintenance(worker_name: str, session: Session = NEW_SESSION) -> None:
"""Write maintenance exit to the db."""
- query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name ==
worker_name)
+ query = get_query_filter_by_worker_name(worker_name)
worker: EdgeWorkerModel | None = session.scalar(query)
if not worker:
raise ValueError(f"Edge Worker {worker_name} not found in list of
registered workers")
@@ -304,7 +313,7 @@ def exit_maintenance(worker_name: str, session: Session =
NEW_SESSION) -> None:
@provide_session
def remove_worker(worker_name: str, session: Session = NEW_SESSION) -> None:
"""Remove a worker that is offline or just gone from DB."""
- query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name ==
worker_name)
+ query = get_query_filter_by_worker_name(worker_name)
worker: EdgeWorkerModel | None = session.scalar(query)
if not worker:
raise ValueError(f"Edge Worker {worker_name} not found in list of
registered workers")
@@ -325,7 +334,7 @@ def change_maintenance_comment(
worker_name: str, maintenance_comment: str | None, session: Session =
NEW_SESSION
) -> None:
"""Write maintenance comment in the db."""
- query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name ==
worker_name)
+ query = get_query_filter_by_worker_name(worker_name)
worker: EdgeWorkerModel | None = session.scalar(query)
if not worker:
raise ValueError(f"Edge Worker {worker_name} not found in list of
registered workers")
@@ -345,7 +354,7 @@ def change_maintenance_comment(
@provide_session
def request_shutdown(worker_name: str, session: Session = NEW_SESSION) -> None:
"""Request to shutdown the edge worker."""
- query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name ==
worker_name)
+ query = get_query_filter_by_worker_name(worker_name)
worker: EdgeWorkerModel | None = session.scalar(query)
if not worker:
raise ValueError(f"Edge Worker {worker_name} not found in list of
registered workers")
@@ -360,7 +369,7 @@ def request_shutdown(worker_name: str, session: Session =
NEW_SESSION) -> None:
@provide_session
def add_worker_queues(worker_name: str, queues: list[str], session: Session =
NEW_SESSION) -> None:
"""Add queues to an edge worker."""
- query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name ==
worker_name)
+ query = get_query_filter_by_worker_name(worker_name)
worker: EdgeWorkerModel | None = session.scalar(query)
if not worker:
raise ValueError(f"Edge Worker {worker_name} not found in list of
registered workers")
@@ -378,7 +387,7 @@ def add_worker_queues(worker_name: str, queues: list[str],
session: Session = NE
@provide_session
def remove_worker_queues(worker_name: str, queues: list[str], session: Session
= NEW_SESSION) -> None:
"""Remove queues from an edge worker."""
- query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name ==
worker_name)
+ query = get_query_filter_by_worker_name(worker_name)
worker: EdgeWorkerModel | None = session.scalar(query)
if not worker:
raise ValueError(f"Edge Worker {worker_name} not found in list of
registered workers")
diff --git a/providers/edge3/src/airflow/providers/edge3/version_compat.py
b/providers/edge3/src/airflow/providers/edge3/version_compat.py
index 27070ab292b..61b31ae45a2 100644
--- a/providers/edge3/src/airflow/providers/edge3/version_compat.py
+++ b/providers/edge3/src/airflow/providers/edge3/version_compat.py
@@ -33,7 +33,9 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
AIRFLOW_V_3_1_PLUS = get_base_airflow_version_tuple() >= (3, 1, 0)
+AIRFLOW_V_3_2_PLUS = get_base_airflow_version_tuple() >= (3, 2, 0)
__all__ = [
"AIRFLOW_V_3_1_PLUS",
+ "AIRFLOW_V_3_2_PLUS",
]
diff --git
a/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py
b/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py
index fc780b87662..1a523550545 100644
--- a/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py
+++ b/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py
@@ -121,6 +121,13 @@ class WorkerQueuesBase(BaseModel):
description="List of queues the worker is pulling jobs from. If
not provided, worker pulls from all queues.",
),
]
+ team_name: Annotated[
+ str | None,
+ Field(
+ None,
+ description="Team name for multi-team setups. If not provided,
worker operates without team isolation.",
+ ),
+ ] = None
class WorkerQueuesBody(WorkerQueuesBase):
diff --git
a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py
b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py
index 1191d068014..e05159c2731 100644
--- a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py
+++ b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py
@@ -80,6 +80,7 @@ def fetch(
)
if body.queues:
query = query.where(EdgeJobModel.queue.in_(body.queues))
+ query = query.where(EdgeJobModel.team_name == body.team_name)
query = query.limit(1)
query = query.with_for_update(skip_locked=True)
job: EdgeJobModel | None = session.scalar(query)
diff --git
a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/ui.py
b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/ui.py
index 996a3a26128..67963260140 100644
--- a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/ui.py
+++ b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/ui.py
@@ -34,6 +34,7 @@ from airflow.providers.edge3.models.edge_worker import (
add_worker_queues,
change_maintenance_comment,
exit_maintenance,
+ get_query_filter_by_worker_name,
remove_worker,
remove_worker_queues,
request_maintenance,
@@ -164,15 +165,13 @@ def request_worker_maintenance(
user: GetUserDep,
) -> None:
"""Put a worker into maintenance mode."""
- # Check if worker exists first
- worker_query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name
== worker_name)
+ worker_query = get_query_filter_by_worker_name(worker_name)
worker = session.scalar(worker_query)
if not worker:
raise HTTPException(status.HTTP_404_NOT_FOUND, detail=f"Worker
{worker_name} not found")
if not maintenance_request.maintenance_comment:
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail="Maintenance
comment is required")
- # Format the comment with timestamp and username (username will be added
by plugin layer)
formatted_comment = f"[{datetime.now().strftime('%Y-%m-%d %H:%M')}] -
{user.get_name()} put node into maintenance mode\nComment:
{maintenance_request.maintenance_comment}"
try:
@@ -194,15 +193,13 @@ def update_worker_maintenance(
user: GetUserDep,
) -> None:
"""Update maintenance comments for a worker."""
- # Check if worker exists first
- worker_query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name
== worker_name)
+ worker_query = get_query_filter_by_worker_name(worker_name)
worker = session.scalar(worker_query)
if not worker:
raise HTTPException(status.HTTP_404_NOT_FOUND, detail=f"Worker
{worker_name} not found")
if not maintenance_request.maintenance_comment:
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail="Maintenance
comment is required")
- # Format the comment with timestamp and username (username will be added
by plugin layer)
first_line = worker.maintenance_comment.split("\n", 1)[0] if
worker.maintenance_comment else ""
formatted_comment = f"{first_line}\n[{datetime.now().strftime('%Y-%m-%d
%H:%M')}] - {user.get_name()} updated
comment:\n{maintenance_request.maintenance_comment}"
@@ -223,8 +220,7 @@ def exit_worker_maintenance(
session: SessionDep,
) -> None:
"""Exit a worker from maintenance mode."""
- # Check if worker exists first
- worker_query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name
== worker_name)
+ worker_query = get_query_filter_by_worker_name(worker_name)
worker = session.scalar(worker_query)
if not worker:
raise HTTPException(status.HTTP_404_NOT_FOUND, detail=f"Worker
{worker_name} not found")
@@ -246,8 +242,7 @@ def request_worker_shutdown(
session: SessionDep,
) -> None:
"""Request shutdown of a worker."""
- # Check if worker exists first
- worker_query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name
== worker_name)
+ worker_query = get_query_filter_by_worker_name(worker_name)
worker = session.scalar(worker_query)
if not worker:
raise HTTPException(status.HTTP_404_NOT_FOUND, detail=f"Worker
{worker_name} not found")
@@ -269,8 +264,7 @@ def delete_worker(
session: SessionDep,
) -> None:
"""Delete a worker record from the system."""
- # Check if worker exists first
- worker_query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name
== worker_name)
+ worker_query = get_query_filter_by_worker_name(worker_name)
worker = session.scalar(worker_query)
if not worker:
raise HTTPException(status.HTTP_404_NOT_FOUND, detail=f"Worker
{worker_name} not found")
@@ -294,7 +288,7 @@ def add_worker_queue(
) -> None:
"""Add a queue to a worker."""
# Check if worker exists first
- worker_query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name
== worker_name)
+ worker_query = get_query_filter_by_worker_name(worker_name)
worker = session.scalar(worker_query)
if not worker:
raise HTTPException(status.HTTP_404_NOT_FOUND, detail=f"Worker
{worker_name} not found")
@@ -317,8 +311,7 @@ def remove_worker_queue(
session: SessionDep,
) -> None:
"""Remove a queue from a worker."""
- # Check if worker exists first
- worker_query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name
== worker_name)
+ worker_query = get_query_filter_by_worker_name(worker_name)
worker = session.scalar(worker_query)
if not worker:
raise HTTPException(status.HTTP_404_NOT_FOUND, detail=f"Worker
{worker_name} not found")
diff --git
a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py
b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py
index 34368c98c4a..b66be595c0c 100644
--- a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py
+++ b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py
@@ -173,7 +173,9 @@ def register(
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name ==
worker_name)
worker: EdgeWorkerModel | None = session.scalar(query)
if not worker:
- worker = EdgeWorkerModel(worker_name=worker_name, state=body.state,
queues=body.queues)
+ worker = EdgeWorkerModel(
+ worker_name=worker_name, state=body.state, queues=body.queues,
team_name=body.team_name
+ )
else:
# Prevent duplicate workers unless the existing worker is in offline
or unknown state
allowed_states_for_reuse = {
@@ -194,6 +196,7 @@ def register(
worker.queues = body.queues
worker.sysinfo = json.dumps(body.sysinfo)
worker.last_update = timezone.utcnow()
+ worker.team_name = body.team_name
session.add(worker)
return WorkerRegistrationReturn(last_update=worker.last_update)
diff --git
a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
index 2904a2e0d3d..e6babb7d3dc 100644
---
a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
+++
b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
@@ -1316,6 +1316,13 @@ components:
title: Queues
description: List of queues the worker is pulling jobs from. If not
provided,
worker pulls from all queues.
+ team_name:
+ anyOf:
+ - type: string
+ - type: 'null'
+ title: Team Name
+ description: Team name for multi-team setups. If not provided,
worker operates
+ without team isolation.
state:
$ref: '#/components/schemas/EdgeWorkerState'
description: State of the worker from the view of the worker.
@@ -1419,6 +1426,13 @@ components:
title: Queues
description: List of queues the worker is pulling jobs from. If not
provided,
worker pulls from all queues.
+ team_name:
+ anyOf:
+ - type: string
+ - type: 'null'
+ title: Team Name
+ description: Team name for multi-team setups. If not provided,
worker operates
+ without team isolation.
free_concurrency:
type: integer
title: Free Concurrency
@@ -1484,6 +1498,13 @@ components:
title: Queues
description: List of queues the worker is pulling jobs from. If not
provided,
worker pulls from all queues.
+ team_name:
+ anyOf:
+ - type: string
+ - type: 'null'
+ title: Team Name
+ description: Team name for multi-team setups. If not provided,
worker operates
+ without team isolation.
state:
$ref: '#/components/schemas/EdgeWorkerState'
description: State of the worker from the view of the worker.
diff --git a/providers/edge3/tests/unit/edge3/cli/test_definition.py
b/providers/edge3/tests/unit/edge3/cli/test_definition.py
index 0c1dac8e6d3..f64b7ca281a 100644
--- a/providers/edge3/tests/unit/edge3/cli/test_definition.py
+++ b/providers/edge3/tests/unit/edge3/cli/test_definition.py
@@ -93,11 +93,20 @@ class TestEdgeCliDefinition:
"4",
"--edge-hostname",
"edge-worker-1",
+ "--team-name",
+ "team_x",
]
args = self.arg_parser.parse_args(params)
assert args.queues == "queue1,queue2"
assert args.concurrency == 4
assert args.edge_hostname == "edge-worker-1"
+ assert args.team_name == "team_x"
+
+    def test_worker_command_args_without_team_name(self):
+        """Test that the worker command defaults team_name to None when --team-name is omitted."""
+ params = ["edge", "worker"]
+ args = self.arg_parser.parse_args(params)
+ assert args.team_name is None
def test_status_command_args(self):
"""Test status command with pid argument."""
@@ -135,7 +144,15 @@ class TestEdgeCliDefinition:
def test_list_workers_command_args(self):
"""Test list-workers command with output format and state filter."""
- params = ["edge", "list-workers", "--output", "json", "--state",
"running", "maintenance"]
+ params = [
+ "edge",
+ "list-workers",
+ "--output",
+ "json",
+ "--state",
+ "running",
+ "maintenance",
+ ]
args = self.arg_parser.parse_args(params)
assert args.output == "json"
assert args.state == ["running", "maintenance"]
diff --git a/providers/edge3/tests/unit/edge3/cli/test_worker.py
b/providers/edge3/tests/unit/edge3/cli/test_worker.py
index b2b97034b36..d3b721c1e32 100644
--- a/providers/edge3/tests/unit/edge3/cli/test_worker.py
+++ b/providers/edge3/tests/unit/edge3/cli/test_worker.py
@@ -34,9 +34,9 @@ from yarl import URL
from airflow.cli import cli_parser
from airflow.providers.common.compat.sdk import timezone
-from airflow.providers.edge3.cli import edge_command, worker as worker_module
+from airflow.providers.edge3.cli import edge_command
from airflow.providers.edge3.cli.dataclasses import Job
-from airflow.providers.edge3.cli.worker import EdgeWorker,
_execution_api_server_url
+from airflow.providers.edge3.cli.worker import EdgeWorker
from airflow.providers.edge3.models.edge_worker import (
EdgeWorkerModel,
EdgeWorkerState,
@@ -106,6 +106,11 @@ class TestEdgeWorker:
importlib.reload(cli_parser)
self.parser = cli_parser.get_parser()
+ @pytest.fixture
+ def cli_worker_with_team(self, tmp_path: Path) -> EdgeWorker:
+ test_worker = EdgeWorker(str(tmp_path / "mock.pid"), "mock", None, 8,
team_name="team_a")
+ return test_worker
+
@pytest.fixture
def mock_joblist(self, tmp_path: Path) -> list[Job]:
logfile = tmp_path / "file.log"
@@ -130,7 +135,7 @@ class TestEdgeWorker:
@pytest.fixture
def worker_with_job(self, tmp_path: Path, mock_joblist: list[Job]) ->
EdgeWorker:
- test_worker = EdgeWorker(str(tmp_path / "mock.pid"), "mock", None, 8,
5, 5)
+ test_worker = EdgeWorker(str(tmp_path / "mock.pid"), "mock", None, 8)
EdgeWorker.jobs = mock_joblist
return test_worker
@@ -143,6 +148,13 @@ class TestEdgeWorker:
)
return test_edgeworker
+ def test_worker_with_team_name(self, cli_worker_with_team: EdgeWorker):
+ assert cli_worker_with_team.team_name == "team_a"
+
+ def test_worker_without_team_name(self, tmp_path: Path):
+ worker = EdgeWorker(str(tmp_path / "mock.pid"), "mock", None, 8)
+ assert worker.team_name is None
+
@pytest.mark.parametrize(
("configs", "expected_url"),
[
@@ -171,24 +183,21 @@ class TestEdgeWorker:
self,
configs,
expected_url,
+ tmp_path,
):
with conf_vars(configs):
- _execution_api_server_url.cache_clear()
- url = _execution_api_server_url()
+ test_worker = EdgeWorker(str(tmp_path / "mock.pid"), "mock", None,
8)
+ url = test_worker._execution_api_server_url
assert url == expected_url
- @patch(
- "airflow.providers.edge3.cli.worker._execution_api_server_url",
- return_value="https://mock-execution-api",
- )
@patch("airflow.sdk.execution_time.supervisor.supervise")
@pytest.mark.asyncio
async def test_supervise_launch(
self,
mock_supervise,
- mock_execution_api_url,
worker_with_job: EdgeWorker,
):
+ worker_with_job.__dict__["_execution_api_server_url"] =
"https://mock-server/execution"
edge_job = worker_with_job.jobs.pop().edge_job
q = mock.MagicMock()
result = worker_with_job._run_job_via_supervisor(edge_job.command, q)
@@ -196,16 +205,11 @@ class TestEdgeWorker:
assert result == 0
q.put.assert_not_called()
- @patch(
- "airflow.providers.edge3.cli.worker._execution_api_server_url",
- return_value="https://mock-execution-api",
- )
@patch("airflow.sdk.execution_time.supervisor.supervise")
@pytest.mark.asyncio
async def test_supervise_launch_fail(
self,
mock_supervise,
- mock_execution_api_url,
worker_with_job: EdgeWorker,
):
mock_supervise.side_effect = Exception("Supervise failed")
@@ -268,6 +272,8 @@ class TestEdgeWorker:
await worker_with_job.fetch_and_run_job()
mock_jobs_fetch.assert_called_once()
+ fetch_args = mock_jobs_fetch.call_args
+ assert fetch_args.args[3] is None # team_name should be None
mock_launch_job.assert_called_once()
assert mock_jobs_set_state.call_count == 2
mock_push_log_chunks.assert_called_once()
@@ -349,11 +355,13 @@ class TestEdgeWorker:
@time_machine.travel(datetime.now(), tick=False)
@patch("airflow.providers.edge3.cli.worker.logs_push")
- @patch.object(worker_module, "push_log_chunk_size", 4)
@pytest.mark.asyncio
async def test_check_running_jobs_log_push_chunks(self, mock_logs_push,
worker_with_job: EdgeWorker):
+ worker_with_job.push_log_chunk_size = 4
+
job = EdgeWorker.jobs[0]
job.logfile.write_bytes("log1log2ülog3".encode("latin-1"))
+
with conf_vars({("edge", "api_url"):
"https://invalid-api-test-endpoint"}):
await worker_with_job._push_logs_in_chunks(job)
assert len(EdgeWorker.jobs) == 1
@@ -401,6 +409,7 @@ class TestEdgeWorker:
with conf_vars({("edge", "api_url"):
"https://invalid-api-test-endpoint"}):
await worker_with_job.heartbeat()
assert mock_set_state.call_args.args[1] == expected_state
+ assert mock_set_state.call_args.kwargs.get("team_name") is None
queue_list = worker_with_job.queues or []
assert len(queue_list) == 2
assert "queue1" in (queue_list)
@@ -486,6 +495,8 @@ class TestEdgeWorker:
await worker_with_job.start()
mock_register.assert_called_once()
+ register_args = mock_register.call_args.args
+ assert register_args[4] is None # team_name should be None
mock_loop.assert_called_once()
assert mock_set_state.call_count == 1
diff --git a/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py
b/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py
index 4bc0fb2f07c..2ef41ece0a5 100644
--- a/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py
+++ b/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py
@@ -16,7 +16,9 @@
# under the License.
from __future__ import annotations
+import os
from datetime import datetime, timedelta
+from unittest import mock
from unittest.mock import MagicMock, patch
import pytest
@@ -32,6 +34,7 @@ from airflow.utils.session import create_session
from airflow.utils.state import TaskInstanceState
from tests_common.test_utils.config import conf_vars
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_2_PLUS
pytestmark = pytest.mark.db_test
@@ -346,3 +349,207 @@ class TestEdgeExecutor:
# Verify nothing breaks
assert key not in executor.running
assert key not in executor.queued_tasks
+
+
+class TestEdgeExecutorMultiTeam:
+ """Tests for multi-team (AIP-67) support in EdgeExecutor."""
+
+ @pytest.fixture(autouse=True)
+ def setup_test_cases(self):
+ with create_session() as session:
+ session.execute(delete(EdgeJobModel))
+ session.execute(delete(EdgeWorkerModel))
+ session.commit()
+
+ def test_global_executor_without_team_name(self):
+ """Test that global executor (no team) works correctly."""
+ executor = EdgeExecutor()
+
+ assert hasattr(executor, "conf")
+ assert executor.team_name is None
+ if AIRFLOW_V_3_2_PLUS:
+ assert executor.conf.team_name is None
+
+ @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="The tests should be
skipped for Airflow < 3.2")
+ def test_executor_with_team_name(self):
+ """Test that executor with team_name has correct conf setup."""
+ team_name = "test_team"
+ executor = EdgeExecutor(team_name=team_name)
+
+ assert hasattr(executor, "conf")
+ assert executor.team_name == team_name
+ assert executor.conf.team_name == team_name
+
+ @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="The tests should be
skipped for Airflow < 3.2")
+ def test_multiple_team_executors_isolation(self):
+ """Test that multiple team executors can coexist with isolated
resources."""
+ team_a_executor = EdgeExecutor(parallelism=2, team_name="team_a")
+ team_b_executor = EdgeExecutor(parallelism=3, team_name="team_b")
+
+ assert team_a_executor.running is not team_b_executor.running
+ assert team_a_executor.queued_tasks is not team_b_executor.queued_tasks
+ assert team_a_executor.last_reported_state is not
team_b_executor.last_reported_state
+
+ if AIRFLOW_V_3_2_PLUS:
+ assert team_a_executor.conf.team_name == "team_a"
+ assert team_b_executor.conf.team_name == "team_b"
+
+ assert team_a_executor.parallelism == 2
+ assert team_b_executor.parallelism == 3
+
+ @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="The tests should be
skipped for Airflow < 3.2")
+ def test_team_config_used_in_check_worker_liveness(self):
+ """Test that _check_worker_liveness reads config from self.conf, not
global conf."""
+ team_name = "test_team"
+ executor = EdgeExecutor(team_name=team_name)
+
+ team_env_key_prefix = f"AIRFLOW__{team_name.upper()}___EDGE__"
+ test_key_values = [
+ "heartbeat_interval",
+ "task_instance_heartbeat_timeout",
+ "job_success_purge",
+ "job_fail_purge",
+ ]
+ for test_key_value in test_key_values:
+ with mock.patch.dict(os.environ,
{f"{team_env_key_prefix}{test_key_value.upper()}": "100"}):
+ value = executor.conf.getint("edge", test_key_value)
+ assert value == 100
+
+ @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="The tests should be
skipped for Airflow < 3.2")
+ def test_purge_jobs_filters_by_team_name(self):
+ """Test that _purge_jobs only purges jobs belonging to its team."""
+ executor_a = EdgeExecutor(team_name="team_a")
+
+ delta_to_purge = timedelta(minutes=conf.getint("edge",
"job_fail_purge") + 1)
+
+ with create_session() as session:
+ for team in ["team_a", "team_b"]:
+ session.add(
+ EdgeJobModel(
+ dag_id="test_dag",
+ task_id=f"task_{team}",
+ run_id="test_run",
+ map_index=-1,
+ try_number=1,
+ state=TaskInstanceState.FAILED,
+ queue="default",
+ command="mock",
+ concurrency_slots=1,
+ last_update=timezone.utcnow() - delta_to_purge,
+ team_name=team,
+ )
+ )
+ session.commit()
+
+ with create_session() as session:
+ executor_a._purge_jobs(session)
+ session.commit()
+
+ with create_session() as session:
+ remaining_jobs = session.scalars(select(EdgeJobModel)).all()
+ assert len(remaining_jobs) == 1
+ assert remaining_jobs[0].team_name == "team_b"
+
+ @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="The tests should be
skipped for Airflow < 3.2")
+ def test_update_orphaned_jobs_filters_by_team_name(self):
+ """Test that _update_orphaned_jobs only checks jobs belonging to its
team."""
+ executor_a = EdgeExecutor(team_name="team_a")
+
+ heartbeat_timeout = conf.getint("scheduler",
"task_instance_heartbeat_timeout")
+ delta_to_orphaned = timedelta(seconds=heartbeat_timeout + 1)
+
+ with create_session() as session:
+ for team in ["team_a", "team_b"]:
+ session.add(
+ EdgeJobModel(
+ dag_id="test_dag",
+ task_id=f"task_{team}",
+ run_id="test_run",
+ map_index=-1,
+ try_number=1,
+ state=TaskInstanceState.RUNNING,
+ queue="default",
+ command="mock",
+ concurrency_slots=1,
+ last_update=timezone.utcnow() - delta_to_orphaned,
+ team_name=team,
+ )
+ )
+ session.commit()
+
+ with create_session() as session:
+ executor_a._update_orphaned_jobs(session)
+ session.commit()
+
+ with create_session() as session:
+ jobs = session.scalars(select(EdgeJobModel)).all()
+ jobs_by_team = {job.team_name: job for job in jobs}
+ assert jobs_by_team["team_a"].state != TaskInstanceState.RUNNING
+ assert jobs_by_team["team_b"].state == TaskInstanceState.RUNNING
+
+ @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="The tests should be
skipped for Airflow < 3.2")
+ def test_check_worker_liveness_filters_by_team_name(self):
+ """Test that _check_worker_liveness only resets workers belonging to
its team."""
+ executor_a = EdgeExecutor(team_name="team_a")
+
+ with create_session() as session:
+ for worker_name, team in [
+ ("worker_team_a", "team_a"),
+ ("worker_team_b", "team_b"),
+ ]:
+ session.add(
+ EdgeWorkerModel(
+ worker_name=worker_name,
+ state=EdgeWorkerState.IDLE,
+ last_update=datetime(2023, 1, 1, 0, 0, 0,
tzinfo=timezone.utc),
+ queues=["default"],
+ first_online=timezone.utcnow(),
+ team_name=team,
+ )
+ )
+ session.commit()
+
+ with time_machine.travel(datetime(2023, 1, 1, 1, 0, 0,
tzinfo=timezone.utc), tick=False):
+ with conf_vars({("edge", "heartbeat_interval"): "10"}):
+ with create_session() as session:
+ executor_a._check_worker_liveness(session)
+ session.commit()
+
+ with create_session() as session:
+ workers = {w.worker_name: w for w in
session.scalars(select(EdgeWorkerModel)).all()}
+ assert workers["worker_team_a"].state == EdgeWorkerState.UNKNOWN
+ assert workers["worker_team_b"].state == EdgeWorkerState.IDLE
+
+ @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="The tests should be
skipped for Airflow < 3.2")
+    def test_no_team_executor_processes_all_jobs(self):
+        """Test that an executor without team_name processes only jobs that have no
team_name."""
+ executor = EdgeExecutor()
+
+ delta_to_purge = timedelta(minutes=conf.getint("edge",
"job_fail_purge") + 1)
+
+ with create_session() as session:
+ for team in ["team_a", "team_b", None]:
+ session.add(
+ EdgeJobModel(
+ dag_id="test_dag",
+ task_id=f"task_{team}",
+ run_id="test_run",
+ map_index=-1,
+ try_number=1,
+ state=TaskInstanceState.FAILED,
+ queue="default",
+ command="mock",
+ concurrency_slots=1,
+ last_update=timezone.utcnow() - delta_to_purge,
+ team_name=team,
+ )
+ )
+ session.commit()
+
+ with create_session() as session:
+ executor._purge_jobs(session)
+ session.commit()
+
+ with create_session() as session:
+ remaining_jobs = session.scalars(select(EdgeJobModel)).all()
+ assert len(remaining_jobs) == 2
diff --git a/providers/edge3/tests/unit/edge3/models/test_db.py
b/providers/edge3/tests/unit/edge3/models/test_db.py
index 360e092788d..3f852185a7d 100644
--- a/providers/edge3/tests/unit/edge3/models/test_db.py
+++ b/providers/edge3/tests/unit/edge3/models/test_db.py
@@ -212,9 +212,13 @@ class TestEdgeDBManager:
assert "3.0.0" in _REVISION_HEADS_MAP
assert _REVISION_HEADS_MAP["3.0.0"] == "9d34dfc2de06"
+
assert "3.2.0" in _REVISION_HEADS_MAP
assert _REVISION_HEADS_MAP["3.2.0"] == "8c275b6fbaa8"
+ assert "3.4.0" in _REVISION_HEADS_MAP
+ assert _REVISION_HEADS_MAP["3.4.0"] == "a09c3ee8e1d3"
+
def
test_initdb_stamps_and_upgrades_when_tables_exist_without_version(self,
session):
"""Test that initdb runs incremental migrations when tables exist but
alembic version table does not."""
from sqlalchemy import inspect, text
@@ -244,8 +248,9 @@ class TestEdgeDBManager:
version = conn.execute(text("SELECT version_num FROM
alembic_version_edge3")).scalar()
columns = {col["name"] for col in
inspect(conn).get_columns("edge_worker")}
- assert version == "8c275b6fbaa8"
+ assert version == "a09c3ee8e1d3"
assert "concurrency" in columns
+ assert "team_name" in columns
def test_migration_adds_concurrency_column(self, session):
"""Test that upgrading from 3.0.0 actually adds the concurrency
column."""
@@ -281,6 +286,7 @@ class TestEdgeDBManager:
columns = {col["name"] for col in
inspector.get_columns("edge_worker")}
assert "concurrency" in columns, "Migration 0002 should have added the
concurrency column"
+        assert "team_name" in columns, "Migration 0004 should have added the
team_name column"
def test_drop_tables_handles_missing_tables(self, session):
"""Test that drop_tables handles missing tables gracefully."""
diff --git a/providers/edge3/tests/unit/edge3/worker_api/routes/test_jobs.py
b/providers/edge3/tests/unit/edge3/worker_api/routes/test_jobs.py
index 667ec184c0a..dde158ebc6e 100644
--- a/providers/edge3/tests/unit/edge3/worker_api/routes/test_jobs.py
+++ b/providers/edge3/tests/unit/edge3/worker_api/routes/test_jobs.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
+import json
from typing import TYPE_CHECKING
from unittest.mock import patch
@@ -23,7 +24,8 @@ import pytest
from sqlalchemy import delete, select
from airflow.providers.edge3.models.edge_job import EdgeJobModel
-from airflow.providers.edge3.worker_api.routes.jobs import state
+from airflow.providers.edge3.worker_api.datamodels import WorkerQueuesBody
+from airflow.providers.edge3.worker_api.routes.jobs import fetch, state
from airflow.utils.session import create_session
from airflow.utils.state import TaskInstanceState
@@ -49,6 +51,29 @@ TASK_ID = "my_task"
RUN_ID = "manual__2024-11-24T21:03:01+01:00"
QUEUE = "test"
+MOCK_COMMAND_STR = json.dumps(
+ {
+ "token": "mock",
+ "type": "ExecuteTask",
+ "ti": {
+ "id": "4d828a62-a417-4936-a7a6-2b3fabacecab",
+ "task_id": "mock",
+ "dag_id": "mock",
+ "run_id": "mock",
+ "try_number": 1,
+ "dag_version_id": "01234567-89ab-cdef-0123-456789abcdef",
+ "pool_slots": 1,
+ "queue": "default",
+ "priority_weight": 1,
+ "start_date": "2023-01-01T00:00:00+00:00",
+ "map_index": -1,
+ },
+ "dag_rel_path": "mock.py",
+ "log_path": "mock.log",
+ "bundle_info": {"name": "hello", "version": "abc"},
+ }
+)
+
class TestJobsApiRoutes:
@pytest.fixture(autouse=True)
@@ -108,3 +133,79 @@ class TestJobsApiRoutes:
db_job: EdgeJobModel | None = session.scalar(select(EdgeJobModel))
assert db_job is not None
assert db_job.state == TaskInstanceState.SUCCESS
+
+ def test_fetch_filters_by_team_name(self, session: Session):
+ with create_session() as session:
+ job_team_a = EdgeJobModel(
+ dag_id="dag_a",
+ task_id="task_a",
+ run_id=RUN_ID,
+ try_number=1,
+ map_index=-1,
+ state=TaskInstanceState.QUEUED,
+ queue=QUEUE,
+ concurrency_slots=1,
+ command=MOCK_COMMAND_STR,
+ team_name="team_a",
+ )
+ job_team_b = EdgeJobModel(
+ dag_id="dag_b",
+ task_id="task_b",
+ run_id=RUN_ID,
+ try_number=1,
+ map_index=-1,
+ state=TaskInstanceState.QUEUED,
+ queue=QUEUE,
+ concurrency_slots=1,
+ command=MOCK_COMMAND_STR,
+ team_name="team_b",
+ )
+ session.add_all([job_team_a, job_team_b])
+ session.commit()
+
+ body = WorkerQueuesBody(free_concurrency=1, queues=[QUEUE],
team_name="team_a")
+ result = fetch("worker1", body, session)
+ assert result is not None
+ assert result.dag_id == "dag_a"
+ assert result.task_id == "task_a"
+
+ def test_fetch_without_team_name_returns_any_team(self, session: Session):
+ """When team_name is None, no team filter is applied so any queued job
can be returned."""
+ with create_session() as session:
+ job_team_a = EdgeJobModel(
+ dag_id="dag_a",
+ task_id="task_a",
+ run_id=RUN_ID,
+ try_number=1,
+ map_index=-1,
+ state=TaskInstanceState.QUEUED,
+ queue=QUEUE,
+ concurrency_slots=1,
+ command=MOCK_COMMAND_STR,
+ team_name="team_a",
+ )
+ job_no_team = EdgeJobModel(
+ dag_id="dag_b",
+ task_id="task_b",
+ run_id=RUN_ID,
+ try_number=1,
+ map_index=-1,
+ state=TaskInstanceState.QUEUED,
+ queue=QUEUE,
+ concurrency_slots=1,
+ command=MOCK_COMMAND_STR,
+ team_name=None,
+ )
+ session.add_all([job_team_a, job_no_team])
+ session.commit()
+
+ body1 = WorkerQueuesBody(free_concurrency=2, queues=[QUEUE],
team_name="team_a")
+ result1 = fetch("worker1", body1, session)
+ assert result1 is not None
+ body2 = WorkerQueuesBody(free_concurrency=2, queues=[QUEUE],
team_name=None)
+ result2 = fetch("worker1", body2, session)
+ assert result2 is not None
+ result3 = fetch("worker1", body2, session)
+ assert result3 is None
+ fetched_dag_ids = {result1.dag_id, result2.dag_id}
+ assert fetched_dag_ids == {"dag_a", "dag_b"}
diff --git a/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py
b/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py
index fc5cd111a1d..080b019948e 100644
--- a/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py
+++ b/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py
@@ -48,7 +48,7 @@ pytestmark = pytest.mark.db_test
class TestWorkerApiRoutes:
@pytest.fixture
def cli_worker(self, tmp_path: Path) -> EdgeWorker:
- test_worker = EdgeWorker(str(tmp_path / "mock.pid"), "mock", None, 8,
5, 5)
+ test_worker = EdgeWorker(str(tmp_path / "mock.pid"), "mock", None, 8)
return test_worker
@pytest.fixture(autouse=True)
@@ -96,11 +96,85 @@ class TestWorkerApiRoutes:
worker: Sequence[EdgeWorkerModel] =
session.scalars(select(EdgeWorkerModel)).all()
assert len(worker) == 1
assert worker[0].worker_name == "test_worker"
+ assert worker[0].team_name is None
if input_queues:
assert worker[0].queues == input_queues
else:
assert worker[0].queues is None
+ def test_register_with_team_name(self, session: Session, cli_worker:
EdgeWorker):
+ body = WorkerStateBody(
+ state=EdgeWorkerState.STARTING,
+ jobs_active=0,
+ queues=["default"],
+ sysinfo=cli_worker._get_sysinfo(),
+ team_name="team_a",
+ )
+ register("test_worker", body, session)
+ session.commit()
+
+ worker: Sequence[EdgeWorkerModel] =
session.scalars(select(EdgeWorkerModel)).all()
+ assert len(worker) == 1
+ assert worker[0].worker_name == "test_worker"
+ assert worker[0].team_name == "team_a"
+
+ def test_register_same_name_different_team_rejects_when_active(
+ self, session: Session, cli_worker: EdgeWorker
+ ):
+ """A physical worker (hostname) can only have one identity.
Registering the same
+ worker_name with a different team_name while the existing one is
active should be rejected."""
+ existing_worker = EdgeWorkerModel(
+ worker_name="test_worker",
+ state=EdgeWorkerState.RUNNING,
+ queues=["default"],
+ first_online=timezone.utcnow(),
+ team_name="team_a",
+ )
+ session.add(existing_worker)
+ session.commit()
+
+ body = WorkerStateBody(
+ state=EdgeWorkerState.STARTING,
+ jobs_active=0,
+ queues=["default"],
+ sysinfo=cli_worker._get_sysinfo(),
+ team_name="team_b",
+ )
+ with pytest.raises(HTTPException) as exc_info:
+ register("test_worker", body, session)
+ assert exc_info.value.status_code == 409
+
+ def test_register_same_name_different_team_reuses_when_offline(
+ self, session: Session, cli_worker: EdgeWorker
+ ):
+ """When an existing worker with the same name is offline,
re-registering with a
+ different team_name should succeed and update the team_name."""
+ existing_worker = EdgeWorkerModel(
+ worker_name="test_worker",
+ state=EdgeWorkerState.OFFLINE,
+ queues=["default"],
+ first_online=timezone.utcnow(),
+ team_name="team_a",
+ )
+ session.add(existing_worker)
+ session.commit()
+
+ body = WorkerStateBody(
+ state=EdgeWorkerState.STARTING,
+ jobs_active=0,
+ queues=["default"],
+ sysinfo=cli_worker._get_sysinfo(),
+ team_name="team_b",
+ )
+ register("test_worker", body, session)
+ session.commit()
+
+ worker = session.execute(
+ select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name ==
"test_worker")
+ ).scalar_one_or_none()
+ assert worker is not None
+ assert worker.team_name == "team_b"
+
@pytest.mark.parametrize(
("existing_state", "should_raise"),
[