This is an automated email from the ASF dual-hosted git repository.
vatsrahul1001 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 6877aea6353 Revert airflowctl dependency from airflow-core (#68856)
6877aea6353 is described below
commit 6877aea6353cd0d54a622a6d7dcb4157888db060
Author: Bugra Ozturk <[email protected]>
AuthorDate: Wed Jun 24 05:24:06 2026 +0200
Revert airflowctl dependency from airflow-core (#68856)
* Do not warn users about CLI deprecations
* Update the significant newsfragment to reflect the new direction
* Rename the significant newsfragment to reflect the warning changes
* Remove docs for unused methods and remove the significant newsfragment file
* Update documentation to explain CLI CTL relation better
* Update documentation to explain CLI CTL relation better
* Revert routing airflow CLI commands through the airflowctl client
Restores the previous in-process client for the airflow dags/pools/assets
commands, reverting #68175 (and removing the auth-manager get_cli_user token
support it added). These commands talk to the metadata DB through the local
client again, so they no longer require a reachable API server. The
maintainer-only deprecated_for_airflowctl marker and its command
* Revert installing airflowctl into airflow core
* Add back partition date to test model
---
airflow-core/pyproject.toml | 2 -
.../src/airflow/api/client/__init__.py | 22 +-
.../src/airflow/api/client/local_client.py | 107 ++++++
.../api_fastapi/auth/managers/base_auth_manager.py | 16 -
.../auth/managers/simple/simple_auth_manager.py | 3 -
airflow-core/src/airflow/cli/api_client.py | 129 --------
.../src/airflow/cli/commands/asset_command.py | 66 ++--
.../src/airflow/cli/commands/dag_command.py | 60 ++--
.../src/airflow/cli/commands/pool_command.py | 99 +++---
.../tests/unit/cli/commands/test_asset_command.py | 146 ++++++---
.../tests/unit/cli/commands/test_dag_command.py | 364 ++++++++++++---------
.../tests/unit/cli/commands/test_pool_command.py | 350 +++++++++++---------
airflow-core/tests/unit/cli/conftest.py | 20 --
airflow-core/tests/unit/cli/test_api_client.py | 140 --------
.../providers/fab/auth_manager/fab_auth_manager.py | 24 --
.../unit/fab/auth_manager/test_fab_auth_manager.py | 26 --
.../keycloak/auth_manager/keycloak_auth_manager.py | 30 +-
.../auth_manager/test_keycloak_auth_manager.py | 21 +-
scripts/ci/prek/known_airflow_exceptions.txt | 2 +-
uv.lock | 2 -
20 files changed, 740 insertions(+), 889 deletions(-)
diff --git a/airflow-core/pyproject.toml b/airflow-core/pyproject.toml
index 31a88734c88..f126c18bd3a 100644
--- a/airflow-core/pyproject.toml
+++ b/airflow-core/pyproject.toml
@@ -154,7 +154,6 @@ dependencies = [
"universal-pathlib>=0.3.8",
"uuid6>=2024.7.10",
"apache-airflow-task-sdk<1.4.0,>=1.3.0",
- "apache-airflow-ctl<0.1.6,>=0.1.5",
# pre-installed providers
"apache-airflow-providers-common-compat>=1.7.4",
"apache-airflow-providers-common-io>=1.6.3",
@@ -328,7 +327,6 @@ required-version = ">=0.11.8"
[tool.uv.sources]
apache-airflow-core = {workspace = true}
-apache-airflow-ctl = {workspace = true}
apache-airflow-devel-common = { workspace = true }
[tool.airflow]
diff --git
a/airflow-e2e-tests/tests/airflow_e2e_tests/basic_tests/test_airflowctl_imports.py
b/airflow-core/src/airflow/api/client/__init__.py
similarity index 62%
rename from
airflow-e2e-tests/tests/airflow_e2e_tests/basic_tests/test_airflowctl_imports.py
rename to airflow-core/src/airflow/api/client/__init__.py
index f7957f1a576..f0d236b9019 100644
---
a/airflow-e2e-tests/tests/airflow_e2e_tests/basic_tests/test_airflowctl_imports.py
+++ b/airflow-core/src/airflow/api/client/__init__.py
@@ -1,3 +1,4 @@
+#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -14,25 +15,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+"""API Client that allows interacting with Airflow API."""
from __future__ import annotations
-import subprocess
-import sys
+from airflow.api.client.local_client import Client
-def test_airflowctl_is_importable():
- # checks if airflowctl imports correctly
- result = subprocess.run(
- [
- sys.executable,
- "-c",
- "import airflowctl; print('airflowctl imported successfully')",
- ],
- capture_output=True,
- text=True,
- check=False,
- )
- assert result.returncode == 0, (
- f"airflowctl import failed!\nstdout: {result.stdout}\nstderr:
{result.stderr}"
- )
+def get_current_api_client() -> Client:
+ return Client()
diff --git a/airflow-core/src/airflow/api/client/local_client.py
b/airflow-core/src/airflow/api/client/local_client.py
new file mode 100644
index 00000000000..057d6d99c7c
--- /dev/null
+++ b/airflow-core/src/airflow/api/client/local_client.py
@@ -0,0 +1,107 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Local client API."""
+
+from __future__ import annotations
+
+import httpx
+
+from airflow.api.common import delete_dag, trigger_dag
+from airflow.exceptions import AirflowBadRequest, PoolNotFound
+from airflow.models.pool import Pool
+from airflow.utils.types import DagRunTriggeredByType
+
+
+class Client:
+ """Local API client implementation."""
+
+ def __init__(self, auth=None, session: httpx.Client | None = None):
+ self._session: httpx.Client = session or httpx.Client()
+ if auth:
+ self._session.auth = auth
+
+ def trigger_dag(
+ self,
+ dag_id,
+ run_id=None,
+ conf=None,
+ logical_date=None,
+ triggering_user_name=None,
+ replace_microseconds=True,
+ ) -> dict | None:
+ dag_run = trigger_dag.trigger_dag(
+ dag_id=dag_id,
+ triggered_by=DagRunTriggeredByType.CLI,
+ triggering_user_name=triggering_user_name,
+ run_id=run_id,
+ conf=conf,
+ logical_date=logical_date,
+ replace_microseconds=replace_microseconds,
+ )
+ if dag_run:
+ return {
+ "conf": dag_run.conf,
+ "dag_id": dag_run.dag_id,
+ "dag_run_id": dag_run.run_id,
+ "data_interval_start": dag_run.data_interval_start,
+ "data_interval_end": dag_run.data_interval_end,
+ "end_date": dag_run.end_date,
+ "last_scheduling_decision": dag_run.last_scheduling_decision,
+ "logical_date": dag_run.logical_date,
+ "run_type": dag_run.run_type,
+ "start_date": dag_run.start_date,
+ "state": dag_run.state,
+ "triggering_user_name": dag_run.triggering_user_name,
+ }
+ return dag_run
+
+ def delete_dag(self, dag_id):
+ count = delete_dag.delete_dag(dag_id)
+ return f"Removed {count} record(s)"
+
+ def get_pool(self, name):
+ pool = Pool.get_pool(pool_name=name)
+ if not pool:
+ raise PoolNotFound(f"Pool {name} not found")
+ return pool.pool, pool.slots, pool.description, pool.include_deferred,
pool.team_name
+
+ def get_pools(self):
+ return [(p.pool, p.slots, p.description, p.include_deferred,
p.team_name) for p in Pool.get_pools()]
+
+ def create_pool(self, name, slots, description, include_deferred,
team_name=None):
+ if not (name and name.strip()):
+ raise AirflowBadRequest("Pool name shouldn't be empty")
+ pool_name_length = Pool.pool.property.columns[0].type.length
+ if len(name) > pool_name_length:
+ raise AirflowBadRequest(f"Pool name cannot be more than
{pool_name_length} characters")
+ try:
+ slots = int(slots)
+ except ValueError:
+ raise AirflowBadRequest(f"Invalid value for `slots`: {slots}")
+ pool = Pool.create_or_update_pool(
+ name=name,
+ slots=slots,
+ description=description,
+ include_deferred=include_deferred,
+ team_name=team_name,
+ )
+ return pool.pool, pool.slots, pool.description, pool.team_name
+
+ def delete_pool(self, name):
+ pool = Pool.delete_pool(name=name)
+ return pool.pool, pool.slots, pool.description
diff --git
a/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py
b/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py
index 7f62d8af27c..4ba08e5b447 100644
--- a/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py
+++ b/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py
@@ -180,22 +180,6 @@ class BaseAuthManager(Generic[T], LoggingMixin,
metaclass=ABCMeta):
self.serialize_user(user)
)
- def get_cli_user(self) -> T:
- """
- Return the user the local CLI acts as when calling the API server.
-
- The Airflow CLI mints a short-lived JWT for this user (via
:meth:`generate_jwt`)
- so it can talk to the API server without persisting any credentials. A
generic
- auth manager cannot know which user is authorized for local CLI
access, so the
- default raises. Auth managers that support local CLI usage should
override this
- to return an administrative user. Otherwise, operators must provide a
token via
- the ``AIRFLOW_CLI_TOKEN`` environment variable.
- """
- raise NotImplementedError(
- f"{type(self).__name__} does not support minting a local CLI
token. "
- "Set the AIRFLOW_CLI_TOKEN environment variable with a valid API
token instead."
- )
-
@abstractmethod
def get_url_login(self, **kwargs) -> str:
"""Return the login page url."""
diff --git
a/airflow-core/src/airflow/api_fastapi/auth/managers/simple/simple_auth_manager.py
b/airflow-core/src/airflow/api_fastapi/auth/managers/simple/simple_auth_manager.py
index 0deaadf4034..0559a388156 100644
---
a/airflow-core/src/airflow/api_fastapi/auth/managers/simple/simple_auth_manager.py
+++
b/airflow-core/src/airflow/api_fastapi/auth/managers/simple/simple_auth_manager.py
@@ -238,9 +238,6 @@ class
SimpleAuthManager(BaseAuthManager[SimpleAuthManagerUser]):
def serialize_user(self, user: SimpleAuthManagerUser) -> dict[str, Any]:
return {"sub": user.username, "role": user.role, "teams": user.teams}
- def get_cli_user(self) -> SimpleAuthManagerUser:
- return SimpleAuthManagerUser(username="cli",
role=SimpleAuthManagerRole.ADMIN.name)
-
def is_authorized_configuration(
self,
*,
diff --git a/airflow-core/src/airflow/cli/api_client.py
b/airflow-core/src/airflow/cli/api_client.py
deleted file mode 100644
index d1ff5e2ddd9..00000000000
--- a/airflow-core/src/airflow/cli/api_client.py
+++ /dev/null
@@ -1,129 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-Provide the :mod:`airflowctl` HTTP API client to the local Airflow CLI.
-
-The local CLI talks to the API server through the same typed client that
``airflowctl``
-uses, but without the keyring-backed credential store. For each invocation it
mints a
-short-lived JWT **in memory** (via the active auth manager) and builds a
client with it;
-nothing is persisted. Set the ``AIRFLOW_CLI_TOKEN`` environment variable to
supply a token
-explicitly (required for auth managers whose tokens cannot be minted locally,
such as
-Keycloak, or when targeting a remote API server).
-"""
-
-from __future__ import annotations
-
-import atexit
-import os
-from collections.abc import Callable
-from functools import wraps
-from typing import TYPE_CHECKING, TypeVar
-
-import httpx
-
-# Re-exported so command modules import the client surface from a single place.
-from airflowctl.api.client import NEW_API_CLIENT, Client, ClientKind
-
-from airflow.configuration import conf
-from airflow.typing_compat import ParamSpec
-
-if TYPE_CHECKING:
- from airflow.api_fastapi.auth.managers.base_auth_manager import
BaseAuthManager
-
-__all__ = [
- "NEW_API_CLIENT",
- "Client",
- "ClientKind",
- "get_cli_api_client",
- "provide_api_client",
-]
-
-PS = ParamSpec("PS")
-RT = TypeVar("RT")
-
-# Validity of the in-memory CLI token. It only needs to outlive a single CLI
command
-# (including the client's request retries) and is never persisted or logged.
-_CLI_TOKEN_VALID_FOR_SECONDS = 300
-
-_api_client: Client | None = None
-
-
-def _resolve_base_url() -> str:
- """Resolve the API server base URL from configuration."""
- base_url = conf.get("api", "base_url", fallback=None)
- if base_url:
- return base_url
- host = conf.get("api", "host", fallback="localhost") or "localhost"
- port = conf.get("api", "port", fallback="8080") or "8080"
- return f"http://{host}:{port}"
-
-
-def _mint_cli_token() -> str:
- """
- Return a token for the CLI to authenticate against the API server.
-
- Prefers an explicit ``AIRFLOW_CLI_TOKEN`` (the universal override),
otherwise mints a
- short-lived JWT through the active auth manager. The token lives only in
this process.
- """
- if token := os.environ.get("AIRFLOW_CLI_TOKEN"):
- return token
-
- from airflow.api_fastapi.app import get_auth_manager, init_auth_manager
-
- # The CLI runs outside the API server, so the auth manager singleton is
usually not
- # initialized yet; initialize it on demand. ``init_auth_manager`` reuses
the cached
- # instance when one already exists, so this is safe to call here.
- try:
- auth_manager: BaseAuthManager = get_auth_manager()
- except RuntimeError:
- auth_manager = init_auth_manager()
- return auth_manager.generate_jwt(
- auth_manager.get_cli_user(),
- expiration_time_in_seconds=_CLI_TOKEN_VALID_FOR_SECONDS,
- )
-
-
-def get_cli_api_client() -> Client:
- """Return the process-wide singleton airflowctl client for the local
CLI."""
- global _api_client
- if _api_client is None:
- _api_client = Client(
- base_url=_resolve_base_url(),
- token=_mint_cli_token(),
- kind=ClientKind.CLI,
- limits=httpx.Limits(max_keepalive_connections=1,
max_connections=1),
- )
- atexit.register(_api_client.close)
- return _api_client
-
-
-def provide_api_client(func: Callable[PS, RT]) -> Callable[PS, RT]:
- """
- Provide the CLI API client to the decorated command function.
-
- Injects ``api_client=get_cli_api_client()`` when the caller does not pass
one. Tests
- (or callers that already hold a client) pass ``api_client=`` explicitly to
bypass it.
- """
-
- @wraps(func)
- def wrapper(*args, **kwargs) -> RT:
- if "api_client" not in kwargs:
- kwargs["api_client"] = get_cli_api_client()
- return func(*args, **kwargs)
-
- return wrapper
diff --git a/airflow-core/src/airflow/cli/commands/asset_command.py
b/airflow-core/src/airflow/cli/commands/asset_command.py
index 29c1025958a..3897e42b4fd 100644
--- a/airflow-core/src/airflow/cli/commands/asset_command.py
+++ b/airflow-core/src/airflow/cli/commands/asset_command.py
@@ -17,17 +17,22 @@
from __future__ import annotations
+import logging
import typing
from sqlalchemy import select
+from airflow.api.common.trigger_dag import trigger_dag
from airflow.api_fastapi.core_api.datamodels.assets import AssetAliasResponse,
AssetResponse
-from airflow.cli.api_client import NEW_API_CLIENT, Client, provide_api_client
+from airflow.api_fastapi.core_api.datamodels.dag_run import DAGRunResponse
from airflow.cli.simple_table import AirflowConsole
from airflow.cli.utils import deprecated_for_airflowctl
-from airflow.models.asset import AssetAliasModel, AssetModel
+from airflow.exceptions import AirflowConfigException
+from airflow.models.asset import AssetAliasModel, AssetModel,
TaskOutletAssetReference
from airflow.utils import cli as cli_utils
+from airflow.utils.platform import getuser
from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.types import DagRunTriggeredByType, DagRunType
if typing.TYPE_CHECKING:
from typing import Any
@@ -36,6 +41,8 @@ if typing.TYPE_CHECKING:
from airflow.api_fastapi.core_api.base import BaseModel
+log = logging.getLogger(__name__)
+
def _list_asset_aliases(args, *, session: Session) -> tuple[Any,
type[BaseModel]]:
aliases =
session.scalars(select(AssetAliasModel).order_by(AssetAliasModel.name))
@@ -43,13 +50,7 @@ def _list_asset_aliases(args, *, session: Session) ->
tuple[Any, type[BaseModel]
def _list_assets(args, *, session: Session) -> tuple[Any, type[BaseModel]]:
- assets =
session.scalars(select(AssetModel).order_by(AssetModel.name)).all()
- for asset in assets:
- for watcher in asset.watchers:
- # ``AssetWatcherModel`` has no ``created_date`` column; like the
public API
- # serializer, derive it from the watcher's trigger so
``AssetResponse`` validation
- # succeeds. Set on the instance so ``model_validate`` reads it via
``from_attributes``.
- watcher.created_date = watcher.trigger.created_date
+ assets = session.scalars(select(AssetModel).order_by(AssetModel.name))
return assets, AssetResponse
@@ -123,39 +124,50 @@ def asset_details(args, *, session: Session =
NEW_SESSION) -> None:
AirflowConsole().print_as(data=data, output=args.output)
-@cli_utils.action_cli
@deprecated_for_airflowctl("airflowctl assets materialize")
-@provide_api_client
-def asset_materialize(args, api_client: Client = NEW_API_CLIENT) -> None:
+@cli_utils.action_cli
+@provide_session
+def asset_materialize(args, *, session: Session = NEW_SESSION) -> None:
"""
Materialize the specified asset.
This is done by finding the DAG with the asset defined as outlet, and
create
- a run for that DAG. Resolving the DAG and creating the run is handled by
the API
- server; the asset is identified here by its name and/or URI.
+ a run for that DAG.
"""
if not args.name and not args.uri:
raise SystemExit("Either --name or --uri is required")
+ stmt =
select(TaskOutletAssetReference.dag_id).join(TaskOutletAssetReference.asset)
select_message_parts = []
if args.name:
+ stmt = stmt.where(AssetModel.name == args.name)
select_message_parts.append(f"name {args.name}")
if args.uri:
+ stmt = stmt.where(AssetModel.uri == args.uri)
select_message_parts.append(f"URI {args.uri}")
+ dag_id_it =
iter(session.scalars(stmt.group_by(TaskOutletAssetReference.dag_id).limit(2)))
select_message = " and ".join(select_message_parts)
- matches = [
- asset
- for asset in api_client.assets.list().assets
- if (not args.name or asset.name == args.name) and (not args.uri or
asset.uri == args.uri)
- ]
- if not matches:
+ if (dag_id := next(dag_id_it, None)) is None:
raise SystemExit(f"Asset with {select_message} does not exist.")
- if len(matches) > 1:
- raise SystemExit(f"More than one asset exists with {select_message}.")
-
- dag_run = api_client.assets.materialize(asset_id=str(matches[0].id))
- AirflowConsole().print_as(
- data=[dag_run.model_dump(mode="json")],
- output=args.output,
+ if next(dag_id_it, None) is not None:
+ raise SystemExit(f"More than one DAG materializes asset with
{select_message}.")
+
+ try:
+ user = getuser()
+ except AirflowConfigException as e:
+ log.warning("Failed to get user name from os: %s, not setting the
triggering user", e)
+ user = None
+ dagrun = trigger_dag(
+ dag_id=dag_id,
+ triggered_by=DagRunTriggeredByType.CLI,
+ run_type=DagRunType.ASSET_MATERIALIZATION,
+ triggering_user_name=user,
+ session=session,
)
+ if dagrun is not None:
+ data = [DAGRunResponse.model_validate(dagrun).model_dump(mode="json")]
+ else:
+ data = []
+
+ AirflowConsole().print_as(data=data, output=args.output)
diff --git a/airflow-core/src/airflow/cli/commands/dag_command.py
b/airflow-core/src/airflow/cli/commands/dag_command.py
index 39bdc7a0470..756048af14d 100644
--- a/airflow-core/src/airflow/cli/commands/dag_command.py
+++ b/airflow-core/src/airflow/cli/commands/dag_command.py
@@ -33,15 +33,14 @@ from typing import TYPE_CHECKING, cast
from sqlalchemy import func, select
from airflow._shared.timezones import timezone
-from airflow.api_fastapi.core_api.datamodels.dag_run import
TriggerDAGRunPostBody
+from airflow.api.client import get_current_api_client
from airflow.api_fastapi.core_api.datamodels.dags import DAGResponse
-from airflow.cli.api_client import NEW_API_CLIENT, Client, provide_api_client
from airflow.cli.simple_table import AirflowConsole
from airflow.cli.utils import deprecated_for_airflowctl,
fetch_dag_run_from_run_id_or_logical_date_string
from airflow.dag_processing.bundles.base import unpack_bundle_version
from airflow.dag_processing.bundles.manager import DagBundlesManager
from airflow.dag_processing.dagbag import BundleDagBag, DagBag, sync_bag_to_db
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowConfigException, AirflowException
from airflow.jobs.job import Job
from airflow.models import DagModel, DagRun, TaskInstance
from airflow.models.errors import ParseImportError
@@ -57,6 +56,7 @@ from airflow.utils.cli import (
)
from airflow.utils.dot_renderer import render_dag, render_dag_dependencies
from airflow.utils.helpers import ask_yesno, chunks
+from airflow.utils.platform import getuser
from airflow.utils.providers_configuration_loader import
providers_configuration_loaded
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.state import DagRunState, TaskInstanceState
@@ -80,42 +80,50 @@ log = logging.getLogger(__name__)
_RUN_CHUNK_SIZE = 500
-@cli_utils.action_cli
@deprecated_for_airflowctl("airflowctl dags trigger")
+@cli_utils.action_cli
@providers_configuration_loaded
-@provide_api_client
-def dag_trigger(args, api_client: Client = NEW_API_CLIENT) -> None:
+def dag_trigger(args) -> None:
"""Create a dag run for the specified dag."""
- run_conf = json.loads(args.conf) if args.conf is not None else None
- if run_conf is not None and not isinstance(run_conf, dict):
- raise ValueError("DagRun conf must be a JSON object or null")
- # The core_api request models are the source of truth; they are
wire-compatible with
- # the airflowctl client's generated models (the API server uses
populate_by_name).
- trigger_body = TriggerDAGRunPostBody(
- dag_run_id=args.run_id,
- conf=run_conf,
- logical_date=args.logical_date,
- )
- dag_run = api_client.dags.trigger(dag_id=args.dag_id,
trigger_dag_run=trigger_body) # type: ignore[arg-type]
- AirflowConsole().print_as(
- data=[dag_run.model_dump(mode="json")],
- output=args.output,
- )
+ api_client = get_current_api_client()
+ try:
+ user = getuser()
+ except AirflowConfigException as e:
+ log.warning("Failed to get user name from os: %s, not setting the
triggering user", e)
+ user = None
+ try:
+ message = api_client.trigger_dag(
+ dag_id=args.dag_id,
+ run_id=args.run_id,
+ conf=args.conf,
+ logical_date=args.logical_date,
+ triggering_user_name=user,
+ replace_microseconds=args.replace_microseconds,
+ )
+ AirflowConsole().print_as(
+ data=[message] if message is not None else [],
+ output=args.output,
+ )
+ except OSError as err:
+ raise AirflowException(err)
-@cli_utils.action_cli
@deprecated_for_airflowctl("airflowctl dags delete")
+@cli_utils.action_cli
@providers_configuration_loaded
-@provide_api_client
-def dag_delete(args, api_client: Client = NEW_API_CLIENT) -> None:
+def dag_delete(args) -> None:
"""Delete all DB records related to the specified dag."""
+ api_client = get_current_api_client()
if (
args.yes
or input("This will drop all existing records related to the specified
DAG. Proceed? (y/n)").upper()
== "Y"
):
- api_client.dags.delete(dag_id=args.dag_id)
- print(f"Removed DAG {args.dag_id}")
+ try:
+ message = api_client.delete_dag(dag_id=args.dag_id)
+ print(message)
+ except OSError as err:
+ raise AirflowException(err)
else:
print("Cancelled")
diff --git a/airflow-core/src/airflow/cli/commands/pool_command.py
b/airflow-core/src/airflow/cli/commands/pool_command.py
index a51351f73bd..0d1f087e377 100644
--- a/airflow-core/src/airflow/cli/commands/pool_command.py
+++ b/airflow-core/src/airflow/cli/commands/pool_command.py
@@ -23,12 +23,10 @@ import json
import os
from json import JSONDecodeError
-from airflowctl.api.operations import ServerResponseError
-
-from airflow.api_fastapi.core_api.datamodels.pools import PoolBody
-from airflow.cli.api_client import NEW_API_CLIENT, Client, provide_api_client
+from airflow.api.client import get_current_api_client
from airflow.cli.simple_table import AirflowConsole
from airflow.cli.utils import deprecated_for_airflowctl
+from airflow.exceptions import PoolNotFound
from airflow.utils import cli as cli_utils
from airflow.utils.cli import suppress_logs_and_warning
from airflow.utils.providers_configuration_loader import
providers_configuration_loaded
@@ -39,11 +37,11 @@ def _show_pools(pools, output):
data=pools,
output=output,
mapper=lambda x: {
- "pool": x.name,
- "slots": x.slots,
- "description": x.description,
- "include_deferred": x.include_deferred,
- "team_name": x.team_name,
+ "pool": x[0],
+ "slots": x[1],
+ "description": x[2],
+ "include_deferred": x[3],
+ "team_name": x[4],
},
)
@@ -51,66 +49,59 @@ def _show_pools(pools, output):
@deprecated_for_airflowctl("airflowctl pools list")
@suppress_logs_and_warning
@providers_configuration_loaded
-@provide_api_client
-def pool_list(args, api_client: Client = NEW_API_CLIENT):
+def pool_list(args):
"""Display info of all the pools."""
- pools = api_client.pools.list().pools
+ api_client = get_current_api_client()
+ pools = api_client.get_pools()
_show_pools(pools=pools, output=args.output)
@deprecated_for_airflowctl("airflowctl pools get")
@suppress_logs_and_warning
@providers_configuration_loaded
-@provide_api_client
-def pool_get(args, api_client: Client = NEW_API_CLIENT):
+def pool_get(args):
"""Display pool info by a given name."""
+ api_client = get_current_api_client()
try:
- pools = [api_client.pools.get(pool_name=args.pool)]
+ pools = [api_client.get_pool(name=args.pool)]
_show_pools(pools=pools, output=args.output)
- except ServerResponseError as e:
- if e.response.status_code == 404:
- raise SystemExit(f"Pool {args.pool} does not exist")
- raise
+ except PoolNotFound:
+ raise SystemExit(f"Pool {args.pool} does not exist")
-@cli_utils.action_cli
@deprecated_for_airflowctl("airflowctl pools create")
+@cli_utils.action_cli
@suppress_logs_and_warning
@providers_configuration_loaded
-@provide_api_client
-def pool_set(args, api_client: Client = NEW_API_CLIENT):
+def pool_set(args):
"""Create new pool with a given name and slots."""
- # core_api PoolBody is the source of truth and is wire-compatible with the
airflowctl
- # client's generated model (the API server uses populate_by_name).
- pool_body = PoolBody(
+ api_client = get_current_api_client()
+ api_client.create_pool(
name=args.pool,
slots=args.slots,
description=args.description,
include_deferred=args.include_deferred,
team_name=args.team_name,
)
- api_client.pools.create(pool=pool_body) # type: ignore[arg-type]
print(f"Pool {args.pool} created")
-@cli_utils.action_cli
@deprecated_for_airflowctl("airflowctl pools delete")
+@cli_utils.action_cli
@suppress_logs_and_warning
@providers_configuration_loaded
-@provide_api_client
-def pool_delete(args, api_client: Client = NEW_API_CLIENT):
+def pool_delete(args):
"""Delete pool by a given name."""
+ api_client = get_current_api_client()
try:
- api_client.pools.delete(pool=args.pool)
+ api_client.delete_pool(name=args.pool)
print(f"Pool {args.pool} deleted")
- except ServerResponseError as e:
- if e.response.status_code == 404:
- raise SystemExit(f"Pool {args.pool} does not exist")
- raise
+ except PoolNotFound:
+ raise SystemExit(f"Pool {args.pool} does not exist")
-@cli_utils.action_cli
@deprecated_for_airflowctl("airflowctl pools import")
+@cli_utils.action_cli
@suppress_logs_and_warning
@providers_configuration_loaded
def pool_import(args):
@@ -131,9 +122,10 @@ def pool_export(args):
print(f"Exported {len(pools)} pools to {args.file}")
-@provide_api_client
-def pool_import_helper(filepath, api_client: Client = NEW_API_CLIENT):
+def pool_import_helper(filepath):
"""Help import pools from the json file."""
+ api_client = get_current_api_client()
+
with open(filepath) as poolfile:
data = poolfile.read()
try:
@@ -144,33 +136,34 @@ def pool_import_helper(filepath, api_client: Client =
NEW_API_CLIENT):
failed = []
for k, v in pools_json.items():
if isinstance(v, dict) and "slots" in v and "description" in v:
- pool_body = PoolBody(
- name=k,
- slots=v["slots"],
- description=v["description"],
- include_deferred=v.get("include_deferred", False),
- team_name=v.get("team_name"),
+ pools.append(
+ api_client.create_pool(
+ name=k,
+ slots=v["slots"],
+ description=v["description"],
+ include_deferred=v.get("include_deferred", False),
+ team_name=v.get("team_name"),
+ )
)
- pools.append(api_client.pools.create(pool=pool_body)) # type:
ignore[arg-type]
else:
failed.append(k)
return pools, failed
-@provide_api_client
-def pool_export_helper(filepath, api_client: Client = NEW_API_CLIENT):
+def pool_export_helper(filepath):
"""Help export all the pools to the json file."""
+ api_client = get_current_api_client()
pool_dict = {}
- pools = api_client.pools.list().pools
+ pools = api_client.get_pools()
for pool in pools:
entry = {
- "slots": pool.slots,
- "description": pool.description,
- "include_deferred": pool.include_deferred,
+ "slots": pool[1],
+ "description": pool[2],
+ "include_deferred": pool[3],
}
- if pool.team_name is not None:
- entry["team_name"] = pool.team_name
- pool_dict[pool.name] = entry
+ if pool[4] is not None:
+ entry["team_name"] = pool[4]
+ pool_dict[pool[0]] = entry
with open(filepath, "w") as poolfile:
poolfile.write(json.dumps(pool_dict, sort_keys=True, indent=4))
return pools
diff --git a/airflow-core/tests/unit/cli/commands/test_asset_command.py
b/airflow-core/tests/unit/cli/commands/test_asset_command.py
index d30e36eb04d..a7329a96251 100644
--- a/airflow-core/tests/unit/cli/commands/test_asset_command.py
+++ b/airflow-core/tests/unit/cli/commands/test_asset_command.py
@@ -21,7 +21,7 @@ from __future__ import annotations
import json
import os
import typing
-from types import SimpleNamespace
+from unittest import mock
import pytest
@@ -37,9 +37,7 @@ if typing.TYPE_CHECKING:
pytestmark = [pytest.mark.db_test]
-# Not autouse: only the DB-backed tests below request it, so the mocked
(non-DB)
-# ``assets materialize`` tests stay free of any database access.
[email protected](scope="module")
[email protected](scope="module", autouse=True)
def prepare_examples():
with conf_vars({("core", "load_examples"): "True"}):
parse_and_sync_to_db(os.devnull)
@@ -48,12 +46,17 @@ def prepare_examples():
clear_db_dags()
[email protected](autouse=True)
+def clear_runs():
+ clear_db_runs()
+
+
@pytest.fixture(scope="module")
def parser() -> ArgumentParser:
return cli_parser.get_parser()
-def test_cli_assets_list(prepare_examples, parser: ArgumentParser,
stdout_capture) -> None:
+def test_cli_assets_list(parser: ArgumentParser, stdout_capture) -> None:
args = parser.parse_args(["assets", "list", "--output=json"])
with stdout_capture as capture:
asset_command.asset_list(args)
@@ -64,7 +67,7 @@ def test_cli_assets_list(prepare_examples, parser:
ArgumentParser, stdout_captur
assert any(asset["uri"] == "s3://dag1/output_1.txt" for asset in
asset_list), asset_list
-def test_cli_assets_alias_list(prepare_examples, parser: ArgumentParser,
stdout_capture) -> None:
+def test_cli_assets_alias_list(parser: ArgumentParser, stdout_capture) -> None:
args = parser.parse_args(["assets", "list", "--alias", "--output=json"])
with stdout_capture as capture:
asset_command.asset_list(args)
@@ -75,7 +78,7 @@ def test_cli_assets_alias_list(prepare_examples, parser:
ArgumentParser, stdout_
assert any(alias["name"] == "example-alias" for alias in alias_list),
alias_list
-def test_cli_assets_details(prepare_examples, parser: ArgumentParser,
stdout_capture) -> None:
+def test_cli_assets_details(parser: ArgumentParser, stdout_capture) -> None:
args = parser.parse_args(["assets", "details", "--name=asset1_producer",
"--output=json"])
with stdout_capture as capture:
asset_command.asset_details(args)
@@ -104,7 +107,7 @@ def test_cli_assets_details(prepare_examples, parser:
ArgumentParser, stdout_cap
}
-def test_cli_assets_alias_details(prepare_examples, parser: ArgumentParser,
stdout_capture) -> None:
+def test_cli_assets_alias_details(parser: ArgumentParser, stdout_capture) ->
None:
args = parser.parse_args(["assets", "details", "--alias",
"--name=example-alias", "--output=json"])
with stdout_capture as capture:
asset_command.asset_details(args)
@@ -121,46 +124,87 @@ def test_cli_assets_alias_details(prepare_examples,
parser: ArgumentParser, stdo
}
[email protected]_db_test_override
-class TestCliAssetsMaterialize:
- """`assets materialize` goes through the airflowctl client; mocked here
(no DB/server)."""
-
- def test_materialize(self, parser: ArgumentParser, mock_cli_api_client,
stdout_capture) -> None:
- mock_cli_api_client.assets.list.return_value.assets = [
- SimpleNamespace(id=7, name="asset1_producer",
uri="s3://bucket/asset1_producer"),
- SimpleNamespace(id=8, name="other", uri="s3://bucket/other"),
- ]
-
mock_cli_api_client.assets.materialize.return_value.model_dump.return_value = {
- "dag_id": "asset1_producer",
- "run_type": "asset_materialization",
- "state": "queued",
- }
- args = parser.parse_args(["assets", "materialize",
"--name=asset1_producer", "--output=json"])
- with stdout_capture as capture:
- asset_command.asset_materialize(args)
-
- run_list = json.loads(capture.getvalue())
- assert len(run_list) == 1
- assert run_list[0]["dag_id"] == "asset1_producer"
- # The asset is resolved to its id and materialization is delegated to
the API server.
-
mock_cli_api_client.assets.materialize.assert_called_once_with(asset_id="7")
-
- def test_materialize_requires_name_or_uri(self, parser: ArgumentParser,
mock_cli_api_client) -> None:
- with pytest.raises(SystemExit, match="Either --name or --uri is
required"):
- asset_command.asset_materialize(parser.parse_args(["assets",
"materialize"]))
- mock_cli_api_client.assets.materialize.assert_not_called()
-
- def test_materialize_missing(self, parser: ArgumentParser,
mock_cli_api_client) -> None:
- mock_cli_api_client.assets.list.return_value.assets = []
- with pytest.raises(SystemExit, match="Asset with name nope does not
exist"):
- asset_command.asset_materialize(parser.parse_args(["assets",
"materialize", "--name=nope"]))
- mock_cli_api_client.assets.materialize.assert_not_called()
-
- def test_materialize_ambiguous(self, parser: ArgumentParser,
mock_cli_api_client) -> None:
- mock_cli_api_client.assets.list.return_value.assets = [
- SimpleNamespace(id=1, name="dup", uri="s3://a"),
- SimpleNamespace(id=2, name="dup", uri="s3://b"),
- ]
- with pytest.raises(SystemExit, match="More than one asset exists with
name dup"):
- asset_command.asset_materialize(parser.parse_args(["assets",
"materialize", "--name=dup"]))
- mock_cli_api_client.assets.materialize.assert_not_called()
[email protected]("airflow.api_fastapi.core_api.datamodels.dag_versions.hasattr")
+def test_cli_assets_materialize(mock_hasattr, parser: ArgumentParser,
stdout_capture) -> None:
+ mock_hasattr.return_value = False
+ args = parser.parse_args(["assets", "materialize",
"--name=asset1_producer", "--output=json"])
+ with stdout_capture as capture:
+ asset_command.asset_materialize(args)
+
+ output = capture.getvalue()
+
+ # Check if output is empty first
+ assert output, "No output captured from asset_materialize command"
+
+ run_list = json.loads(output)
+ assert len(run_list) == 1
+
+ # No good way to statically compare these.
+ undeterministic: dict = {
+ "dag_run_id": None,
+ "dag_versions": [],
+ "data_interval_end": None,
+ "data_interval_start": None,
+ "logical_date": None,
+ "queued_at": None,
+ "run_after": "2025-02-12T19:27:59.066046Z",
+ }
+
+ assert run_list[0] | undeterministic == undeterministic | {
+ "conf": {},
+ "bundle_version": None,
+ "dag_display_name": "asset1_producer",
+ "dag_id": "asset1_producer",
+ "end_date": None,
+ "duration": None,
+ "last_scheduling_decision": None,
+ "note": None,
+ "partition_date": None,
+ "partition_key": None,
+ "run_type": "asset_materialization",
+ "start_date": None,
+ "state": "queued",
+ "triggered_by": "cli",
+ "triggering_user_name": "root",
+ "run_after": "2025-02-12T19:27:59.066046Z",
+ }
+
+
+def test_cli_assets_materialize_with_view_url_template(parser: ArgumentParser,
stdout_capture) -> None:
+ args = parser.parse_args(["assets", "materialize",
"--name=asset1_producer", "--output=json"])
+ with stdout_capture as capture:
+ asset_command.asset_materialize(args)
+
+ output = capture.getvalue()
+ run_list = json.loads(output)
+ assert len(run_list) == 1
+
+ # No good way to statically compare these.
+ undeterministic: dict = {
+ "dag_run_id": None,
+ "dag_versions": [],
+ "data_interval_end": None,
+ "data_interval_start": None,
+ "logical_date": None,
+ "queued_at": None,
+ "run_after": "2025-02-12T19:27:59.066046Z",
+ }
+
+ assert run_list[0] | undeterministic == undeterministic | {
+ "conf": {},
+ "bundle_version": None,
+ "dag_display_name": "asset1_producer",
+ "dag_id": "asset1_producer",
+ "end_date": None,
+ "duration": None,
+ "last_scheduling_decision": None,
+ "note": None,
+ "partition_date": None,
+ "partition_key": None,
+ "run_type": "asset_materialization",
+ "start_date": None,
+ "state": "queued",
+ "triggered_by": "cli",
+ "triggering_user_name": "root",
+ "run_after": "2025-02-12T19:27:59.066046Z",
+ }
diff --git a/airflow-core/tests/unit/cli/commands/test_dag_command.py
b/airflow-core/tests/unit/cli/commands/test_dag_command.py
index 8ba0a64156a..059ce1d2f26 100644
--- a/airflow-core/tests/unit/cli/commands/test_dag_command.py
+++ b/airflow-core/tests/unit/cli/commands/test_dag_command.py
@@ -25,14 +25,13 @@ from datetime import datetime, timedelta
from unittest import mock
from unittest.mock import MagicMock
-import httpx
import msgspec
import pendulum
import pytest
import time_machine
-from airflowctl.api.operations import ServerResponseError
-from sqlalchemy import select
+from sqlalchemy import func, select
+from airflow import settings
from airflow._shared.timezones import timezone
from airflow.cli import cli_parser
from airflow.cli.commands import dag_command
@@ -486,19 +485,21 @@ class TestCliDags:
assert str(path_to_parse) in log_output
assert "[0 100 * * *] is not acceptable, out of range" in log_output
- def test_cli_list_dag_runs(self, dag_maker):
- # Seed a run directly in the DB; ``dags trigger`` now goes through the
API server
- # (airflowctl client) and cannot be used as an in-process fixture here.
- with dag_maker("test_list_dag_runs", start_date=DEFAULT_DATE,
serialized=True):
- EmptyOperator(task_id="t1")
- dag_maker.create_dagrun(state=DagRunState.SUCCESS,
logical_date=DEFAULT_DATE)
- dag_maker.sync_dagbag_to_db()
-
+ def test_cli_list_dag_runs(self):
+ dag_command.dag_trigger(
+ self.parser.parse_args(
+ [
+ "dags",
+ "trigger",
+ "example_bash_operator",
+ ]
+ )
+ )
args = self.parser.parse_args(
[
"dags",
"list-runs",
- "test_list_dag_runs",
+ "example_bash_operator",
"--no-backfill",
"--start-date",
DEFAULT_DATE.isoformat(),
@@ -591,6 +592,206 @@ class TestCliDags:
out = temp_stdout.splitlines()[-1]
assert out == "No unpaused DAGs were found"
+ def test_trigger_dag(self):
+ dag_command.dag_trigger(
+ self.parser.parse_args(
+ [
+ "dags",
+ "trigger",
+ "example_bash_operator",
+ "--run-id=test_trigger_dag",
+ '--conf={"foo": "bar"}',
+ ],
+ ),
+ )
+ with create_session() as session:
+ dagrun = session.scalars(select(DagRun).where(DagRun.run_id ==
"test_trigger_dag")).one()
+
+ assert dagrun, "DagRun not created"
+ assert dagrun.run_type == DagRunType.MANUAL
+ assert dagrun.conf == {"foo": "bar"}
+
+ # logical_date is None as it's not provided
+ assert dagrun.logical_date is None
+
+ # data_interval is None as logical_date is None
+ assert dagrun.data_interval_start is None
+ assert dagrun.data_interval_end is None
+
+ def test_trigger_dag_empty_object_conf(self):
+ dag_command.dag_trigger(
+ self.parser.parse_args(
+ [
+ "dags",
+ "trigger",
+ "example_bash_operator",
+ "--run-id=test_trigger_dag_empty_object_conf",
+ "--conf={}",
+ ],
+ ),
+ )
+ with create_session() as session:
+ dagrun = session.scalars(
+ select(DagRun).where(DagRun.run_id ==
"test_trigger_dag_empty_object_conf")
+ ).one()
+
+ assert dagrun.conf == {}
+
+ def test_trigger_dag_json_null_conf(self):
+ dag_command.dag_trigger(
+ self.parser.parse_args(
+ [
+ "dags",
+ "trigger",
+ "example_bash_operator",
+ "--run-id=test_trigger_dag_json_null_conf",
+ "--conf=null",
+ ],
+ ),
+ )
+ with create_session() as session:
+ dagrun = session.scalars(
+ select(DagRun).where(DagRun.run_id ==
"test_trigger_dag_json_null_conf")
+ ).one()
+
+ assert dagrun.conf == {}
+
+ def test_trigger_dag_with_microseconds(self):
+ dag_command.dag_trigger(
+ self.parser.parse_args(
+ [
+ "dags",
+ "trigger",
+ "example_bash_operator",
+ "--run-id=test_trigger_dag_with_micro",
+ "--logical-date=2021-06-04T09:00:00.000001+08:00",
+ "--no-replace-microseconds",
+ ],
+ )
+ )
+
+ with create_session() as session:
+ dagrun = session.scalars(
+ select(DagRun).where(DagRun.run_id ==
"test_trigger_dag_with_micro")
+ ).one()
+
+ assert dagrun, "DagRun not created"
+ assert dagrun.run_type == DagRunType.MANUAL
+ assert dagrun.logical_date.isoformat(timespec="microseconds") ==
"2021-06-04T01:00:00.000001+00:00"
+
+ @pytest.mark.parametrize("conf", ["NOT JSON", ""])
+ def test_trigger_dag_invalid_conf(self, conf):
+ with pytest.raises(ValueError, match=r"Expecting value: line \d+
column \d+ \(char \d+\)"):
+ dag_command.dag_trigger(
+ self.parser.parse_args(
+ [
+ "dags",
+ "trigger",
+ "example_bash_operator",
+ "--run-id",
+ "trigger_dag_xxx",
+ "--conf",
+ conf,
+ ]
+ ),
+ )
+
+ @pytest.mark.parametrize("conf", ["[]", '"str"', "1", "false"])
+ def test_trigger_dag_rejects_non_object_conf(self, conf):
+ with pytest.raises(ValueError, match="DagRun conf must be a JSON
object or null"):
+ dag_command.dag_trigger(
+ self.parser.parse_args(
+ [
+ "dags",
+ "trigger",
+ "example_bash_operator",
+ "--run-id",
+ "trigger_dag_xxx",
+ "--conf",
+ conf,
+ ]
+ ),
+ )
+
+ def test_trigger_dag_output_as_json(self, stdout_capture):
+ args = self.parser.parse_args(
+ [
+ "dags",
+ "trigger",
+ "example_bash_operator",
+ "--run-id",
+ "trigger_dag_xxx",
+ "--conf",
+ '{"conf1": "val1", "conf2": "val2"}',
+ "--output=json",
+ ]
+ )
+ with stdout_capture as temp_stdout:
+ dag_command.dag_trigger(args)
+ # get the last line from the logs ignoring all logging lines
+ out = temp_stdout.getvalue().strip().splitlines()[-1]
+ parsed_out = json.loads(out)
+
+ assert len(parsed_out) == 1
+ assert parsed_out[0]["dag_id"] == "example_bash_operator"
+ assert parsed_out[0]["dag_run_id"] == "trigger_dag_xxx"
+ assert parsed_out[0]["conf"] == {"conf1": "val1", "conf2": "val2"}
+
+ def test_delete_dag(self):
+ DM = DagModel
+ key = "my_dag_id"
+ session = settings.Session()
+ session.add(DM(dag_id=key, bundle_name="dags-folder"))
+ session.commit()
+ dag_command.dag_delete(self.parser.parse_args(["dags", "delete", key,
"--yes"]))
+ assert
session.scalar(select(func.count()).select_from(DM).where(DM.dag_id == key)) == 0
+ with pytest.raises(AirflowException):
+ dag_command.dag_delete(
+ self.parser.parse_args(["dags", "delete",
"does_not_exist_dag", "--yes"]),
+ )
+
+ def test_dag_delete_when_backfill_and_dagrun_exist(self):
+ # Test to check that the DAG should be deleted even if
+ # there are backfill records associated with it.
+ from airflow.models.backfill import Backfill
+
+ DM = DagModel
+ key = "my_dag_id"
+ session = settings.Session()
+ session.add(DM(dag_id=key, bundle_name="dags-folder"))
+ _backfill = Backfill(dag_id=key, from_date=DEFAULT_DATE,
to_date=DEFAULT_DATE + timedelta(days=1))
+ session.add(_backfill)
+ # To create the backfill_id in DagRun
+ session.flush()
+ session.add(
+ DagRun(
+ dag_id=key,
+ run_id="backfill__" + key,
+ state=DagRunState.SUCCESS,
+ run_type="backfill",
+ backfill_id=_backfill.id,
+ )
+ )
+ session.commit()
+ dag_command.dag_delete(self.parser.parse_args(["dags", "delete", key,
"--yes"]))
+ assert
session.scalar(select(func.count()).select_from(DM).where(DM.dag_id == key)) == 0
+ with pytest.raises(AirflowException):
+ dag_command.dag_delete(
+ self.parser.parse_args(["dags", "delete",
"does_not_exist_dag", "--yes"]),
+ )
+
+ def test_delete_dag_existing_file(self, tmp_path):
+ # Test to check that the DAG should be deleted even if
+ # the file containing it is not deleted
+ path = tmp_path / "testfile"
+ DM = DagModel
+ key = "my_dag_id"
+ session = settings.Session()
+ session.add(DM(dag_id=key, bundle_name="dags-folder",
fileloc=os.fspath(path)))
+ session.commit()
+ dag_command.dag_delete(self.parser.parse_args(["dags", "delete", key,
"--yes"]))
+ assert
session.scalar(select(func.count()).select_from(DM).where(DM.dag_id == key)) == 0
+
def test_cli_list_jobs(self):
args = self.parser.parse_args(["dags", "list-jobs"])
dag_command.dag_list_jobs(args)
@@ -1908,142 +2109,3 @@ class TestDagDetailsIsBackfillable:
)
dag_details = dag_command._get_dagbag_dag_details(dag)
assert dag_details["is_backfillable"] is expected
-
-
-def _server_error(status_code: int) -> ServerResponseError:
- request = httpx.Request("DELETE", "http://testserver/api/v2/dags/foo")
- response = httpx.Response(status_code, request=request, json={"detail":
"boom"})
- return ServerResponseError(message="boom", request=request,
response=response)
-
-
[email protected]_db_test_override
-class TestCliDagsApiClientCommands:
- """Dag CLI commands that talk to the API server through the airflowctl
client.
-
- These are unit tests: the airflowctl client is mocked so no API server (or
- database) is required.
- """
-
- @classmethod
- def setup_class(cls):
- cls.parser = cli_parser.get_parser()
-
- @pytest.fixture(autouse=True)
- def _default_trigger_response(self, mock_cli_api_client):
- """Give the mocked ``dags.trigger`` a dict response so ``print_as``
can render it."""
- mock_cli_api_client.dags.trigger.return_value.model_dump.return_value
= {
- "dag_id": "example_bash_operator",
- "dag_run_id": "test_run",
- }
-
- def test_trigger_dag(self, mock_cli_api_client):
- dag_command.dag_trigger(
- self.parser.parse_args(
- [
- "dags",
- "trigger",
- "example_bash_operator",
- "--run-id=test_trigger_dag",
- '--conf={"foo": "bar"}',
- ]
- ),
- )
- mock_cli_api_client.dags.trigger.assert_called_once()
- call = mock_cli_api_client.dags.trigger.call_args
- assert call.kwargs["dag_id"] == "example_bash_operator"
- body = call.kwargs["trigger_dag_run"]
- assert body.dag_run_id == "test_trigger_dag"
- assert body.conf == {"foo": "bar"}
- # logical_date is None as it's not provided
- assert body.logical_date is None
-
- def test_trigger_dag_empty_object_conf(self, mock_cli_api_client):
- dag_command.dag_trigger(
- self.parser.parse_args(
- ["dags", "trigger", "example_bash_operator",
"--run-id=empty_conf", "--conf={}"]
- ),
- )
- body =
mock_cli_api_client.dags.trigger.call_args.kwargs["trigger_dag_run"]
- assert body.conf == {}
-
- def test_trigger_dag_json_null_conf(self, mock_cli_api_client):
- dag_command.dag_trigger(
- self.parser.parse_args(
- ["dags", "trigger", "example_bash_operator",
"--run-id=null_conf", "--conf=null"]
- ),
- )
- # ``null`` conf resolves to None on the client; the API server coerces
it to {}.
- body =
mock_cli_api_client.dags.trigger.call_args.kwargs["trigger_dag_run"]
- assert body.conf is None
-
- def test_trigger_dag_with_microseconds(self, mock_cli_api_client):
- dag_command.dag_trigger(
- self.parser.parse_args(
- [
- "dags",
- "trigger",
- "example_bash_operator",
- "--run-id=micro",
- "--logical-date=2021-06-04T09:00:00.000001+08:00",
- ]
- )
- )
- body =
mock_cli_api_client.dags.trigger.call_args.kwargs["trigger_dag_run"]
- assert body.logical_date.isoformat(timespec="microseconds") ==
"2021-06-04T09:00:00.000001+08:00"
-
- @pytest.mark.parametrize("conf", ["NOT JSON", ""])
- def test_trigger_dag_invalid_conf(self, mock_cli_api_client, conf):
- with pytest.raises(ValueError, match=r"Expecting value: line \d+
column \d+ \(char \d+\)"):
- dag_command.dag_trigger(
- self.parser.parse_args(
- ["dags", "trigger", "example_bash_operator", "--run-id",
"xxx", "--conf", conf]
- ),
- )
- mock_cli_api_client.dags.trigger.assert_not_called()
-
- @pytest.mark.parametrize("conf", ["[]", '"str"', "1", "false"])
- def test_trigger_dag_rejects_non_object_conf(self, mock_cli_api_client,
conf):
- with pytest.raises(ValueError, match="DagRun conf must be a JSON
object or null"):
- dag_command.dag_trigger(
- self.parser.parse_args(
- ["dags", "trigger", "example_bash_operator", "--run-id",
"xxx", "--conf", conf]
- ),
- )
- mock_cli_api_client.dags.trigger.assert_not_called()
-
- def test_trigger_dag_output_as_json(self, mock_cli_api_client,
stdout_capture):
- mock_cli_api_client.dags.trigger.return_value.model_dump.return_value
= {
- "dag_id": "example_bash_operator",
- "dag_run_id": "trigger_dag_xxx",
- "conf": {"conf1": "val1", "conf2": "val2"},
- }
- args = self.parser.parse_args(
- [
- "dags",
- "trigger",
- "example_bash_operator",
- "--run-id",
- "trigger_dag_xxx",
- "--conf",
- '{"conf1": "val1", "conf2": "val2"}',
- "--output=json",
- ]
- )
- with stdout_capture as temp_stdout:
- dag_command.dag_trigger(args)
- out = temp_stdout.getvalue().strip().splitlines()[-1]
- parsed_out = json.loads(out)
-
- assert len(parsed_out) == 1
- assert parsed_out[0]["dag_id"] == "example_bash_operator"
- assert parsed_out[0]["dag_run_id"] == "trigger_dag_xxx"
- assert parsed_out[0]["conf"] == {"conf1": "val1", "conf2": "val2"}
-
- def test_delete_dag(self, mock_cli_api_client):
- dag_command.dag_delete(self.parser.parse_args(["dags", "delete",
"my_dag_id", "--yes"]))
-
mock_cli_api_client.dags.delete.assert_called_once_with(dag_id="my_dag_id")
-
- def test_delete_dag_missing(self, mock_cli_api_client):
- mock_cli_api_client.dags.delete.side_effect = _server_error(404)
- with pytest.raises(ServerResponseError):
- dag_command.dag_delete(self.parser.parse_args(["dags", "delete",
"does_not_exist_dag", "--yes"]))
diff --git a/airflow-core/tests/unit/cli/commands/test_pool_command.py
b/airflow-core/tests/unit/cli/commands/test_pool_command.py
index 28e2d761812..ad6951567fb 100644
--- a/airflow-core/tests/unit/cli/commands/test_pool_command.py
+++ b/airflow-core/tests/unit/cli/commands/test_pool_command.py
@@ -18,235 +18,281 @@
from __future__ import annotations
import json
-from types import SimpleNamespace
-import httpx
import pytest
-from airflowctl.api.operations import ServerResponseError
+from sqlalchemy import delete, func, select
+from airflow import models, settings
from airflow.cli import cli_parser
from airflow.cli.commands import pool_command
+from airflow.models import Pool
+from airflow.settings import Session
+from airflow.utils.db import add_default_pool_if_not_exists
-from tests_common.test_utils.config import conf_vars
-
-
-def _pool(name, slots, description="", include_deferred=False, team_name=None):
- """Build a stand-in for the airflowctl ``PoolResponse`` returned by the
client."""
- return SimpleNamespace(
- name=name,
- slots=slots,
- description=description,
- include_deferred=include_deferred,
- team_name=team_name,
- )
-
-
-def _server_error(status_code: int) -> ServerResponseError:
- request = httpx.Request("GET", "http://testserver/api/v2/pools/foo")
- response = httpx.Response(status_code, request=request, json={"detail":
"boom"})
- return ServerResponseError(message="boom", request=request,
response=response)
+pytestmark = pytest.mark.db_test
class TestCliPools:
@classmethod
def setup_class(cls):
+ cls.dagbag = models.DagBag()
cls.parser = cli_parser.get_parser()
-
- def test_pool_list(self, mock_cli_api_client, stdout_capture):
- mock_cli_api_client.pools.list.return_value.pools = [_pool("foo", 1,
"test")]
+ settings.configure_orm()
+ cls.session = Session
+ cls._cleanup()
+
+ def teardown_method(self):  # pytest hook: runs after each test (tearDown is unittest-only)
+ self._cleanup()
+
+ @staticmethod
+ def _cleanup(session=None):
+ if session is None:
+ session = Session()
+ session.execute(delete(Pool).where(Pool.pool !=
Pool.DEFAULT_POOL_NAME))
+ session.commit()
+ add_default_pool_if_not_exists()
+ session.close()
+
+ def test_pool_list(self, stdout_capture):
+ pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo",
"1", "test"]))
with stdout_capture as stdout:
pool_command.pool_list(self.parser.parse_args(["pools", "list"]))
assert "foo" in stdout.getvalue()
- mock_cli_api_client.pools.list.assert_called_once()
- def test_pool_list_with_args(self, mock_cli_api_client):
- mock_cli_api_client.pools.list.return_value.pools = [_pool("foo", 1,
"test")]
+ def test_pool_list_with_args(self):
pool_command.pool_list(self.parser.parse_args(["pools", "list",
"--output", "json"]))
- def test_pool_create(self, mock_cli_api_client):
+ def test_pool_create(self):
pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo",
"1", "test"]))
+ assert self.session.scalar(select(func.count()).select_from(Pool)) == 2
- mock_cli_api_client.pools.create.assert_called_once()
- body = mock_cli_api_client.pools.create.call_args.kwargs["pool"]
- # core_api PoolBody exposes the name via the ``pool`` attribute (alias
``name``).
- assert body.pool == "foo"
- assert body.slots == 1
- assert body.description == "test"
- assert body.include_deferred is False
+ def test_pool_update_deferred(self):
+ pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo",
"1", "test"]))
+ assert self.session.scalar(select(Pool).where(Pool.pool ==
"foo")).include_deferred is False
- def test_pool_create_include_deferred(self, mock_cli_api_client):
pool_command.pool_set(
self.parser.parse_args(["pools", "set", "foo", "1", "test",
"--include-deferred"])
)
+ assert self.session.scalar(select(Pool).where(Pool.pool ==
"foo")).include_deferred is True
- body = mock_cli_api_client.pools.create.call_args.kwargs["pool"]
- assert body.include_deferred is True
-
- def test_pool_get(self, mock_cli_api_client, stdout_capture):
- mock_cli_api_client.pools.get.return_value = _pool("foo", 1, "test")
- with stdout_capture as stdout:
- pool_command.pool_get(self.parser.parse_args(["pools", "get",
"foo"]))
-
- assert "foo" in stdout.getvalue()
- mock_cli_api_client.pools.get.assert_called_once_with(pool_name="foo")
-
- def test_pool_get_missing(self, mock_cli_api_client):
- mock_cli_api_client.pools.get.side_effect = _server_error(404)
- with pytest.raises(SystemExit, match="Pool foo does not exist"):
- pool_command.pool_get(self.parser.parse_args(["pools", "get",
"foo"]))
+ pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo",
"1", "test"]))
+ assert self.session.scalar(select(Pool).where(Pool.pool ==
"foo")).include_deferred is False
- def test_pool_get_other_error_reraised(self, mock_cli_api_client):
- mock_cli_api_client.pools.get.side_effect = _server_error(500)
- with pytest.raises(ServerResponseError):
- pool_command.pool_get(self.parser.parse_args(["pools", "get",
"foo"]))
+ def test_pool_get(self):
+ pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo",
"1", "test"]))
+ pool_command.pool_get(self.parser.parse_args(["pools", "get", "foo"]))
- def test_pool_delete(self, mock_cli_api_client):
+ def test_pool_delete(self):
+ pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo",
"1", "test"]))
pool_command.pool_delete(self.parser.parse_args(["pools", "delete",
"foo"]))
- mock_cli_api_client.pools.delete.assert_called_once_with(pool="foo")
-
- def test_pool_delete_missing(self, mock_cli_api_client):
- mock_cli_api_client.pools.delete.side_effect = _server_error(404)
- with pytest.raises(SystemExit, match="Pool foo does not exist"):
- pool_command.pool_delete(self.parser.parse_args(["pools",
"delete", "foo"]))
+ assert self.session.scalar(select(func.count()).select_from(Pool)) == 1
- def test_pool_import_nonexistent(self, mock_cli_api_client):
+ def test_pool_import_nonexistent(self):
with pytest.raises(SystemExit):
pool_command.pool_import(self.parser.parse_args(["pools",
"import", "nonexistent.json"]))
- def test_pool_import_invalid_json(self, mock_cli_api_client, tmp_path):
+ def test_pool_import_invalid_json(self, tmp_path):
invalid_pool_import_file_path = tmp_path / "pools_import_invalid.json"
- invalid_pool_import_file_path.write_text("not valid json")
+ with open(invalid_pool_import_file_path, mode="w") as file:
+ file.write("not valid json")
with pytest.raises(SystemExit):
pool_command.pool_import(
self.parser.parse_args(["pools", "import",
str(invalid_pool_import_file_path)])
)
- def test_pool_import_invalid_pools(self, mock_cli_api_client, tmp_path):
+ def test_pool_import_invalid_pools(self, tmp_path):
invalid_pool_import_file_path = tmp_path / "pools_import_invalid.json"
- # Missing ``slots`` makes the entry invalid.
pool_config_input = {"foo": {"description": "foo_test",
"include_deferred": False}}
- invalid_pool_import_file_path.write_text(json.dumps(pool_config_input))
+ with open(invalid_pool_import_file_path, mode="w") as file:
+ json.dump(pool_config_input, file)
with pytest.raises(SystemExit):
pool_command.pool_import(
self.parser.parse_args(["pools", "import",
str(invalid_pool_import_file_path)])
)
- def test_pool_import(self, mock_cli_api_client, tmp_path):
+ def test_pool_import_backwards_compatibility(self, tmp_path):
pool_import_file_path = tmp_path / "pools_import.json"
pool_config_input = {
- "foo": {"description": "foo_test", "slots": 1, "include_deferred":
True},
- # JSON before version 2.7.0 does not contain ``include_deferred``.
- "bar": {"description": "bar_test", "slots": 2},
+ # JSON before version 2.7.0 does not contain `include_deferred`
+ "foo": {"description": "foo_test", "slots": 1},
}
- pool_import_file_path.write_text(json.dumps(pool_config_input))
+ with open(pool_import_file_path, mode="w") as file:
+ json.dump(pool_config_input, file)
pool_command.pool_import(self.parser.parse_args(["pools", "import",
str(pool_import_file_path)]))
- assert mock_cli_api_client.pools.create.call_count == 2
- bodies = {
- call.kwargs["pool"].pool: call.kwargs["pool"]
- for call in mock_cli_api_client.pools.create.call_args_list
- }
- assert bodies["foo"].include_deferred is True
- # Missing ``include_deferred`` defaults to False (backwards
compatibility).
- assert bodies["bar"].include_deferred is False
+ assert self.session.scalar(select(Pool).where(Pool.pool ==
"foo")).include_deferred is False
- def test_pool_export(self, mock_cli_api_client, tmp_path):
+ def test_pool_import_export(self, tmp_path):
+ pool_import_file_path = tmp_path / "pools_import.json"
pool_export_file_path = tmp_path / "pools_export.json"
- mock_cli_api_client.pools.list.return_value.pools = [
- _pool("foo", 1, "foo_test", include_deferred=True),
- _pool("baz", 2, "baz_test", include_deferred=False),
- ]
+ pool_config_input = {
+ "foo": {"description": "foo_test", "slots": 1, "include_deferred":
True},
+ "default_pool": {
+ "description": "Default pool",
+ "slots": 128,
+ "include_deferred": False,
+ },
+ "baz": {"description": "baz_test", "slots": 2, "include_deferred":
False},
+ }
+ with open(pool_import_file_path, mode="w") as file:
+ json.dump(pool_config_input, file)
+
+ # Import json
+ pool_command.pool_import(self.parser.parse_args(["pools", "import",
str(pool_import_file_path)]))
+ # Export json
pool_command.pool_export(self.parser.parse_args(["pools", "export",
str(pool_export_file_path)]))
- exported = json.loads(pool_export_file_path.read_text())
- assert exported == {
- "foo": {"slots": 1, "description": "foo_test", "include_deferred":
True},
- "baz": {"slots": 2, "description": "baz_test", "include_deferred":
False},
- }
+ with open(pool_export_file_path) as file:
+ pool_config_output = json.load(file)
+ assert pool_config_input == pool_config_output, "Input and output
pool files are not same"
- def test_pool_set_with_team_name(self, mock_cli_api_client):
- """``--team-name`` is forwarded to the airflowctl client when
multi_team is enabled."""
- with conf_vars({("core", "multi_team"): "True"}):
- pool_command.pool_set(
- self.parser.parse_args(
- ["pools", "set", "team_pool", "5", "team pool",
"--team-name", "test_team"]
- )
- )
+ def test_pool_set_with_team_name(self):
+ """Test that pool_set with --team-name assigns the pool to the team
when multi_team is enabled."""
+ from airflow.models.team import Team
- body = mock_cli_api_client.pools.create.call_args.kwargs["pool"]
- assert body.team_name == "test_team"
+ from tests_common.test_utils.config import conf_vars
- def test_pool_set_team_name_rejected_when_multi_team_disabled(self,
mock_cli_api_client):
- """``PoolBody`` rejects a team_name (client-side) when multi_team is
disabled."""
- with conf_vars({("core", "multi_team"): "False"}):
- with pytest.raises(ValueError, match="team_name cannot be set when
multi_team mode is disabled"):
+ # Create the team first
+ team = Team(name="test_team")
+ self.session.add(team)
+ self.session.commit()
+
+ try:
+ with conf_vars({("core", "multi_team"): "True"}):
pool_command.pool_set(
self.parser.parse_args(
["pools", "set", "team_pool", "5", "team pool",
"--team-name", "test_team"]
)
)
- mock_cli_api_client.pools.create.assert_not_called()
- def test_pool_set_without_team_name(self, mock_cli_api_client):
- """Without ``--team-name`` the forwarded body has ``team_name`` as
None."""
+ pool = self.session.scalar(select(Pool).where(Pool.pool ==
"team_pool"))
+ assert pool is not None
+ assert pool.team_name == "test_team"
+ assert pool.slots == 5
+ finally:
+ self.session.execute(delete(Pool).where(Pool.pool == "team_pool"))
+ self.session.execute(delete(Team).where(Team.name == "test_team"))
+ self.session.commit()
+
+ def test_pool_set_team_name_rejected_when_multi_team_disabled(self):
+ """Test that pool_set with --team-name raises when multi_team is
disabled."""
+ from airflow.models.team import Team
+
+ from tests_common.test_utils.config import conf_vars
+
+ team = Team(name="test_team")
+ self.session.add(team)
+ self.session.commit()
+
+ try:
+ with conf_vars({("core", "multi_team"): "False"}):
+ with pytest.raises(
+ ValueError, match="team_name cannot be set when multi_team
mode is disabled"
+ ):
+ pool_command.pool_set(
+ self.parser.parse_args(
+ ["pools", "set", "team_pool", "5", "team pool",
"--team-name", "test_team"]
+ )
+ )
+ finally:
+ self.session.execute(delete(Pool).where(Pool.pool == "team_pool"))
+ self.session.execute(delete(Team).where(Team.name == "test_team"))
+ self.session.commit()
+
+ def test_pool_set_without_team_name(self):
+ """Test that pool_set without --team-name leaves team_name as None."""
pool_command.pool_set(self.parser.parse_args(["pools", "set",
"no_team_pool", "3", "no team"]))
- body = mock_cli_api_client.pools.create.call_args.kwargs["pool"]
- assert body.team_name is None
+ pool = self.session.scalar(select(Pool).where(Pool.pool ==
"no_team_pool"))
+ assert pool is not None
+ assert pool.team_name is None
- def test_pool_import_forwards_team_name(self, mock_cli_api_client,
tmp_path):
- """Import forwards each pool's ``team_name`` (or None) to the
airflowctl client."""
- pool_import_file_path = tmp_path / "pools_import_team.json"
- pool_import_file_path.write_text(
- json.dumps(
- {
- "team_pool_a": {
- "slots": 10,
- "description": "team pool",
- "include_deferred": False,
- "team_name": "import_team",
- },
- "global_pool": {"slots": 5, "description": "global pool",
"include_deferred": False},
- }
- )
- )
+ def test_pool_import_export_with_team_name(self, tmp_path):
+ """Test that import/export round-trips the team_name field."""
+ from airflow.models.team import Team
- with conf_vars({("core", "multi_team"): "True"}):
- pool_command.pool_import(self.parser.parse_args(["pools",
"import", str(pool_import_file_path)]))
+ from tests_common.test_utils.config import conf_vars
- bodies = {
- call.kwargs["pool"].pool: call.kwargs["pool"]
- for call in mock_cli_api_client.pools.create.call_args_list
- }
- assert bodies["team_pool_a"].team_name == "import_team"
- assert bodies["global_pool"].team_name is None
+ team = Team(name="import_team")
+ self.session.add(team)
+ self.session.commit()
- def test_pool_export_includes_team_name(self, mock_cli_api_client,
tmp_path):
- """Export writes ``team_name`` only for pools that have one."""
+ pool_import_file_path = tmp_path / "pools_import_team.json"
pool_export_file_path = tmp_path / "pools_export_team.json"
- mock_cli_api_client.pools.list.return_value.pools = [
- _pool("team_pool_a", 10, "team pool", team_name="import_team"),
- _pool("global_pool", 5, "global pool"),
- ]
+ pool_config_input = {
+ "team_pool_a": {
+ "slots": 10,
+ "description": "team pool",
+ "include_deferred": False,
+ "team_name": "import_team",
+ },
+ "global_pool": {
+ "slots": 5,
+ "description": "global pool",
+ "include_deferred": False,
+ },
+ }
- pool_command.pool_export(self.parser.parse_args(["pools", "export",
str(pool_export_file_path)]))
+ with open(pool_import_file_path, mode="w") as file:
+ json.dump(pool_config_input, file)
- exported = json.loads(pool_export_file_path.read_text())
- assert exported["team_pool_a"]["team_name"] == "import_team"
- assert "team_name" not in exported["global_pool"]
+ try:
+ with conf_vars({("core", "multi_team"): "True"}):
+ pool_command.pool_import(
+ self.parser.parse_args(["pools", "import",
str(pool_import_file_path)])
+ )
- def test_pool_list_shows_team_name(self, mock_cli_api_client,
stdout_capture):
- """Pool list output includes the team_name column."""
- mock_cli_api_client.pools.list.return_value.pools = [
- _pool("list_pool", 5, "desc", team_name="list_team")
- ]
+ # Verify team assignment
+ pool = self.session.scalar(select(Pool).where(Pool.pool ==
"team_pool_a"))
+ assert pool is not None
+ assert pool.team_name == "import_team"
- with stdout_capture as stdout:
- pool_command.pool_list(self.parser.parse_args(["pools", "list"]))
+ global_pool = self.session.scalar(select(Pool).where(Pool.pool ==
"global_pool"))
+ assert global_pool is not None
+ assert global_pool.team_name is None
+
+ # Export and verify
+ pool_command.pool_export(self.parser.parse_args(["pools",
"export", str(pool_export_file_path)]))
+
+ with open(pool_export_file_path) as file:
+ pool_config_output = json.load(file)
+
+ assert pool_config_output["team_pool_a"]["team_name"] ==
"import_team"
+ assert "team_name" not in pool_config_output["global_pool"]
+ finally:
+
self.session.execute(delete(Pool).where(Pool.pool.in_(["team_pool_a",
"global_pool"])))
+ self.session.execute(delete(Team).where(Team.name ==
"import_team"))
+ self.session.commit()
+
+ def test_pool_list_shows_team_name(self, stdout_capture):
+ """Test that pool list output includes the team_name column."""
+ from airflow.models.team import Team
+
+ from tests_common.test_utils.config import conf_vars
+
+ team = Team(name="list_team")
+ self.session.add(team)
+ self.session.commit()
+
+ try:
+ with conf_vars({("core", "multi_team"): "True"}):
+ pool_command.pool_set(
+ self.parser.parse_args(
+ ["pools", "set", "list_pool", "5", "desc",
"--team-name", "list_team"]
+ )
+ )
+
+ with stdout_capture as stdout:
+ pool_command.pool_list(self.parser.parse_args(["pools",
"list"]))
- assert "list_team" in stdout.getvalue()
+ output = stdout.getvalue()
+ assert "list_team" in output
+ finally:
+ self.session.execute(delete(Pool).where(Pool.pool == "list_pool"))
+ self.session.execute(delete(Team).where(Team.name == "list_team"))
+ self.session.commit()
diff --git a/airflow-core/tests/unit/cli/conftest.py
b/airflow-core/tests/unit/cli/conftest.py
index d9d2ae341eb..7676a103b53 100644
--- a/airflow-core/tests/unit/cli/conftest.py
+++ b/airflow-core/tests/unit/cli/conftest.py
@@ -18,7 +18,6 @@
from __future__ import annotations
import sys
-from unittest import mock
import pytest
@@ -69,25 +68,6 @@ def parser():
# log messages
-@pytest.fixture
-def mock_cli_api_client():
- """Mock the CLI airflowctl client and neutralize ``action_cli``'s DB touch
points.
-
- CLI commands that go through the airflowctl client only need the mocked
client; the
- ``@action_cli`` audit logging and log-template sync would otherwise open a
database
- session. Patching them lets these command tests run without a database or
API server.
- """
- client = mock.MagicMock()
- with (
- mock.patch("airflow.cli.api_client.get_cli_api_client",
return_value=client),
- mock.patch("airflow.utils.cli_action_loggers.on_pre_execution"),
- mock.patch("airflow.utils.cli_action_loggers.on_post_execution"),
- mock.patch("airflow.utils.db.synchronize_log_template"),
- mock.patch("airflow.utils.db.check_and_run_migrations"),
- ):
- yield client
-
-
@pytest.fixture
def stdout_capture(request):
"""Fixture that captures stdout only."""
diff --git a/airflow-core/tests/unit/cli/test_api_client.py
b/airflow-core/tests/unit/cli/test_api_client.py
deleted file mode 100644
index 3ef813cebda..00000000000
--- a/airflow-core/tests/unit/cli/test_api_client.py
+++ /dev/null
@@ -1,140 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-from unittest import mock
-
-import pytest
-
-from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager
-from airflow.api_fastapi.auth.managers.simple.simple_auth_manager import
SimpleAuthManager
-from airflow.cli import api_client as cli_api_client
-
-from tests_common.test_utils.config import conf_vars
-
-
-@pytest.fixture(autouse=True)
-def _reset_singleton():
- """Reset the process-wide client singleton around each test."""
- cli_api_client._api_client = None
- yield
- cli_api_client._api_client = None
-
-
-class TestResolveBaseUrl:
- @conf_vars({("api", "base_url"): "https://airflow.example.com:9999"})
- def test_explicit_base_url(self):
- assert cli_api_client._resolve_base_url() ==
"https://airflow.example.com:9999"
-
- @conf_vars({("api", "base_url"): "", ("api", "host"): "myhost", ("api",
"port"): "1234"})
- def test_host_port_fallback(self):
- assert cli_api_client._resolve_base_url() == "http://myhost:1234"
-
-
-class TestMintCliToken:
- def test_uses_env_token(self, monkeypatch):
- monkeypatch.setenv("AIRFLOW_CLI_TOKEN", "tok-123")
- with mock.patch("airflow.api_fastapi.app.get_auth_manager") as
get_auth_manager:
- assert cli_api_client._mint_cli_token() == "tok-123"
- # The auth manager is never consulted when a token is supplied
explicitly.
- get_auth_manager.assert_not_called()
-
- def test_mints_via_auth_manager(self, monkeypatch):
- monkeypatch.delenv("AIRFLOW_CLI_TOKEN", raising=False)
- auth_manager = mock.MagicMock()
- auth_manager.get_cli_user.return_value = "cli-user"
- auth_manager.generate_jwt.return_value = "signed-jwt"
- with mock.patch("airflow.api_fastapi.app.get_auth_manager",
return_value=auth_manager):
- assert cli_api_client._mint_cli_token() == "signed-jwt"
-
- auth_manager.generate_jwt.assert_called_once()
- assert auth_manager.generate_jwt.call_args.args[0] == "cli-user"
- # Token must be short-lived.
- assert
auth_manager.generate_jwt.call_args.kwargs["expiration_time_in_seconds"] > 0
-
- def test_initializes_auth_manager_when_not_initialized(self, monkeypatch):
- # In the CLI the auth manager singleton is usually not initialized
yet, so
- # ``get_auth_manager`` raises and we must initialize it on demand.
- monkeypatch.delenv("AIRFLOW_CLI_TOKEN", raising=False)
- auth_manager = mock.MagicMock()
- auth_manager.get_cli_user.return_value = "cli-user"
- auth_manager.generate_jwt.return_value = "signed-jwt"
- with (
- mock.patch(
- "airflow.api_fastapi.app.get_auth_manager",
- side_effect=RuntimeError("Auth Manager has not been
initialized yet."),
- ),
- mock.patch(
- "airflow.api_fastapi.app.init_auth_manager",
return_value=auth_manager
- ) as init_auth_manager,
- ):
- assert cli_api_client._mint_cli_token() == "signed-jwt"
-
- init_auth_manager.assert_called_once()
- auth_manager.generate_jwt.assert_called_once()
-
-
-class TestGetCliApiClient:
- def test_builds_singleton(self):
- with (
- mock.patch.object(cli_api_client, "_resolve_base_url",
return_value="http://h:8080"),
- mock.patch.object(cli_api_client, "_mint_cli_token",
return_value="tok"),
- mock.patch.object(cli_api_client, "Client") as client_cls,
- ):
- first = cli_api_client.get_cli_api_client()
- second = cli_api_client.get_cli_api_client()
-
- assert first is second
- client_cls.assert_called_once()
- kwargs = client_cls.call_args.kwargs
- assert kwargs["base_url"] == "http://h:8080"
- assert kwargs["token"] == "tok"
- assert kwargs["kind"] == cli_api_client.ClientKind.CLI
-
-
-class TestProvideApiClient:
- def test_injects_when_missing(self):
- with mock.patch.object(cli_api_client, "get_cli_api_client",
return_value="CLIENT"):
-
- @cli_api_client.provide_api_client
- def command(args, api_client=None):
- return api_client
-
- assert command("args") == "CLIENT"
-
- def test_uses_explicit_client(self):
- with mock.patch.object(cli_api_client, "get_cli_api_client") as
get_client:
-
- @cli_api_client.provide_api_client
- def command(args, api_client=None):
- return api_client
-
- assert command("args", api_client="EXPLICIT") == "EXPLICIT"
- get_client.assert_not_called()
-
-
-class TestGetCliUser:
- def test_base_default_raises(self):
- # The generic auth manager cannot know which user is authorized for
the CLI.
- with pytest.raises(NotImplementedError, match="AIRFLOW_CLI_TOKEN"):
- BaseAuthManager.get_cli_user(mock.Mock())
-
- def test_simple_auth_manager_returns_admin(self):
- user = SimpleAuthManager.get_cli_user(mock.Mock())
- assert user.get_id() == "cli"
- assert user.role == "ADMIN"
diff --git
a/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py
b/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py
index 8ea13d44a7e..b6a79d032ae 100644
--- a/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py
+++ b/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py
@@ -301,30 +301,6 @@ class FabAuthManager(BaseAuthManager[User]):
def serialize_user(self, user: User) -> dict[str, Any]:
return {"sub": str(user.id)}
- def get_cli_user(self) -> User:
- """
- Return an existing ``Admin`` user for the local CLI to mint a token
for.
-
- The Airflow CLI mints a short-lived, in-memory JWT for this user so it
can talk to
- the API server. FAB tokens reference a real database user, so we reuse
an existing
- ``Admin`` user rather than fabricating one. If none exists, the
operator must
- create one or provide a token via the ``AIRFLOW_CLI_TOKEN``
environment variable.
- """
- from airflow.utils.session import create_session
-
- with create_session() as session:
- user =
session.scalars(select(User).join(User.roles).where(Role.name ==
"Admin").limit(1)).first()
- if user is None:
- raise AirflowConfigException(
- "No user with the 'Admin' role exists in the FAB database.
Create one "
- "(e.g. `airflow fab create-user --role Admin ...`) or set
the "
- "AIRFLOW_CLI_TOKEN environment variable with a valid API
token."
- )
- # Detach so attributes stay accessible after the session closes
(and is not
- # expired on commit) while the CLI serializes the user to mint the
token.
- session.expunge(user)
- return user
-
def is_logged_in(self) -> bool:
"""Return whether the user is logged in."""
user = self.get_user()
diff --git a/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py
b/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py
index 5b6d27cecaf..62f98567012 100644
--- a/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py
+++ b/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py
@@ -1332,29 +1332,3 @@ class TestFabAuthManagerSessionCleanupErrorHandling:
mock_session.remove.assert_called()
mock_log.warning.assert_called()
assert response is not None
-
-
-class TestFabGetCliUser:
- """``get_cli_user`` reuses an existing ``Admin`` user for the local CLI
token."""
-
- @mock.patch("airflow.utils.session.create_session")
- def test_returns_admin_user(self, mock_create_session, auth_manager):
- admin_user = MagicMock()
- session = MagicMock()
- session.scalars.return_value.first.return_value = admin_user
- mock_create_session.return_value.__enter__.return_value = session
-
- result = auth_manager.get_cli_user()
-
- assert result is admin_user
- # The user is detached so its attributes survive the session closing.
- session.expunge.assert_called_once_with(admin_user)
-
- @mock.patch("airflow.utils.session.create_session")
- def test_raises_when_no_admin_user(self, mock_create_session,
auth_manager):
- session = MagicMock()
- session.scalars.return_value.first.return_value = None
- mock_create_session.return_value.__enter__.return_value = session
-
- with pytest.raises(AirflowConfigException, match="Admin"):
- auth_manager.get_cli_user()
diff --git
a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py
b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py
index bc9ebc305c3..a8cd683ef46 100644
---
a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py
+++
b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py
@@ -34,7 +34,7 @@ from urllib3.util import Retry
from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX
from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager
-from airflow.exceptions import AirflowConfigException,
AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowProviderDeprecationWarning
try:
from airflow.api_fastapi.auth.managers.base_auth_manager import
ExtendedResourceMethod
@@ -141,34 +141,6 @@ class
KeycloakAuthManager(BaseAuthManager[KeycloakAuthManagerUser]):
"refresh_token": user.refresh_token,
}
- def get_cli_user(self) -> KeycloakAuthManagerUser:
- """
- Return a service-account user for the local CLI to mint a token for.
-
- Keycloak tokens are issued by the external Keycloak server, so they
cannot be
- forged locally. The Keycloak client is already configured for Airflow
to talk to
- Keycloak, so we reuse it to obtain a service-account token through the
- ``client_credentials`` flow. The service account's effective
permissions are
- governed by the Keycloak deployment. If the client credentials are not
usable, the
- operator must provide a token via the ``AIRFLOW_CLI_TOKEN``
environment variable.
- """
- try:
- tokens =
self.get_keycloak_client().token(grant_type="client_credentials")
- except Exception as e:
- raise AirflowConfigException(
- "Could not obtain a Keycloak service-account token for the CLI
via the "
- "client_credentials flow. Set the AIRFLOW_CLI_TOKEN
environment variable "
- f"with a valid API token instead. Original error: {e}"
- ) from e
- return KeycloakAuthManagerUser(
- user_id="airflow-cli",
- name="airflow-cli",
- access_token=tokens["access_token"],
- # No refresh token is issued for the client_credentials flow (RFC
6749 §4.4.3),
- # which marks this as a service account in
refresh_user/refresh_tokens.
- refresh_token=tokens.get("refresh_token"),
- )
-
def get_url_login(self, **kwargs) -> str:
base_url = conf.get("api", "base_url", fallback="/")
return urljoin(base_url, f"{AUTH_MANAGER_FASTAPI_APP_PREFIX}/login")
diff --git
a/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py
b/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py
index e4b6a5c294d..9c8ed9dd5b6 100644
---
a/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py
+++
b/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py
@@ -45,7 +45,7 @@ if AIRFLOW_V_3_2_PLUS:
else:
TeamDetails = None # type: ignore[assignment,misc]
from airflow.api_fastapi.common.types import MenuItem
-from airflow.exceptions import AirflowConfigException,
AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowProviderDeprecationWarning
try:
from airflow.providers.common.compat.sdk import AirflowException
@@ -121,25 +121,6 @@ def _clear_filter_cache():
class TestKeycloakAuthManager:
- @patch.object(KeycloakAuthManager, "get_keycloak_client")
- def test_get_cli_user(self, mock_get_keycloak_client, auth_manager):
- # client_credentials (service account) flow returns an access token
and no refresh token.
- mock_get_keycloak_client.return_value.token.return_value =
{"access_token": "svc-token"}
-
- user = auth_manager.get_cli_user()
-
- assert user.get_id() == "airflow-cli"
- assert user.access_token == "svc-token"
- assert user.refresh_token is None
-
mock_get_keycloak_client.return_value.token.assert_called_once_with(grant_type="client_credentials")
-
- @patch.object(KeycloakAuthManager, "get_keycloak_client")
- def test_get_cli_user_raises_when_credentials_unusable(self,
mock_get_keycloak_client, auth_manager):
- mock_get_keycloak_client.return_value.token.side_effect =
Exception("boom")
-
- with pytest.raises(AirflowConfigException, match="AIRFLOW_CLI_TOKEN"):
- auth_manager.get_cli_user()
-
def test_deserialize_user(self, auth_manager):
result = auth_manager.deserialize_user(
{
diff --git a/scripts/ci/prek/known_airflow_exceptions.txt
b/scripts/ci/prek/known_airflow_exceptions.txt
index 262f5d6ce54..fec615e061f 100644
--- a/scripts/ci/prek/known_airflow_exceptions.txt
+++ b/scripts/ci/prek/known_airflow_exceptions.txt
@@ -1,7 +1,7 @@
airflow-core/src/airflow/api/common/delete_dag.py::1
airflow-core/src/airflow/api_fastapi/core_api/app.py::1
airflow-core/src/airflow/cli/cli_parser.py::1
-airflow-core/src/airflow/cli/commands/dag_command.py::1
+airflow-core/src/airflow/cli/commands/dag_command.py::3
airflow-core/src/airflow/cli/commands/db_command.py::1
airflow-core/src/airflow/config_templates/airflow_local_settings.py::1
airflow-core/src/airflow/dag_processing/dagbag.py::1
diff --git a/uv.lock b/uv.lock
index 865064ec1fa..709e8ee883e 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1917,7 +1917,6 @@ dependencies = [
{ name = "a2wsgi" },
{ name = "aiosqlite" },
{ name = "alembic" },
- { name = "apache-airflow-ctl" },
{ name = "apache-airflow-providers-common-compat" },
{ name = "apache-airflow-providers-common-io" },
{ name = "apache-airflow-providers-common-sql" },
@@ -2046,7 +2045,6 @@ requires-dist = [
{ name = "aiosqlite", specifier = ">=0.20.0,<0.22.0" },
{ name = "alembic", specifier = ">=1.13.1,<2.0" },
{ name = "apache-airflow-core", extras = ["graphviz", "gunicorn",
"kerberos", "otel", "statsd"], marker = "extra == 'all'", editable =
"airflow-core" },
- { name = "apache-airflow-ctl", editable = "airflow-ctl" },
{ name = "apache-airflow-providers-common-compat", editable =
"providers/common/compat" },
{ name = "apache-airflow-providers-common-io", editable =
"providers/common/io" },
{ name = "apache-airflow-providers-common-sql", editable =
"providers/common/sql" },