This is an automated email from the ASF dual-hosted git repository.
o-nikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f90ebd0835b Add --team-name support to pool CLI commands (#68110)
f90ebd0835b is described below
commit f90ebd0835bead9b61db675009dfe8d32b575740
Author: Niko Oliveira <[email protected]>
AuthorDate: Mon Jun 8 09:47:38 2026 -0700
Add --team-name support to pool CLI commands (#68110)
Add the ability to assign pools to teams via the CLI:
- airflow pools set accepts a new --team-name <team> option
- Pool import/export JSON supports a team_name field (omitted from the export for pools that have no team)
- Pool list/get output includes a team_name column
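Providing a team_name raises a ValueError unless [core] multi_team is enabled. Because pools set and pools import overwrite existing pools, omitting team_name for an existing pool clears its team assignment.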
---
.../src/airflow/api/client/local_client.py | 14 +-
airflow-core/src/airflow/cli/cli_config.py | 4 +-
.../src/airflow/cli/commands/pool_command.py | 17 ++-
airflow-core/src/airflow/models/pool.py | 17 ++-
.../tests/unit/cli/commands/test_pool_command.py | 152 ++++++++++++++++++++-
5 files changed, 194 insertions(+), 10 deletions(-)
diff --git a/airflow-core/src/airflow/api/client/local_client.py
b/airflow-core/src/airflow/api/client/local_client.py
index 2eeb5d434b5..057d6d99c7c 100644
--- a/airflow-core/src/airflow/api/client/local_client.py
+++ b/airflow-core/src/airflow/api/client/local_client.py
@@ -78,12 +78,12 @@ class Client:
pool = Pool.get_pool(pool_name=name)
if not pool:
raise PoolNotFound(f"Pool {name} not found")
- return pool.pool, pool.slots, pool.description, pool.include_deferred
+ return pool.pool, pool.slots, pool.description, pool.include_deferred,
pool.team_name
def get_pools(self):
- return [(p.pool, p.slots, p.description, p.include_deferred) for p in
Pool.get_pools()]
+ return [(p.pool, p.slots, p.description, p.include_deferred,
p.team_name) for p in Pool.get_pools()]
- def create_pool(self, name, slots, description, include_deferred):
+ def create_pool(self, name, slots, description, include_deferred,
team_name=None):
if not (name and name.strip()):
raise AirflowBadRequest("Pool name shouldn't be empty")
pool_name_length = Pool.pool.property.columns[0].type.length
@@ -94,9 +94,13 @@ class Client:
except ValueError:
raise AirflowBadRequest(f"Invalid value for `slots`: {slots}")
pool = Pool.create_or_update_pool(
- name=name, slots=slots, description=description,
include_deferred=include_deferred
+ name=name,
+ slots=slots,
+ description=description,
+ include_deferred=include_deferred,
+ team_name=team_name,
)
- return pool.pool, pool.slots, pool.description
+ return pool.pool, pool.slots, pool.description, pool.team_name
def delete_pool(self, name):
pool = Pool.delete_pool(name=name)
diff --git a/airflow-core/src/airflow/cli/cli_config.py
b/airflow-core/src/airflow/cli/cli_config.py
index e9b026dd296..99d7faf30bc 100644
--- a/airflow-core/src/airflow/cli/cli_config.py
+++ b/airflow-core/src/airflow/cli/cli_config.py
@@ -596,6 +596,7 @@ ARG_POOL_DESCRIPTION = Arg(("description",), help="Pool
description")
ARG_POOL_INCLUDE_DEFERRED = Arg(
("--include-deferred",), help="Include deferred tasks in calculations for
Pool", action="store_true"
)
+ARG_POOL_TEAM_NAME = Arg(("--team-name",), help="Team name to assign the pool
to (requires multi_team mode)")
ARG_POOL_IMPORT = Arg(
("file",),
metavar="FILEPATH",
@@ -605,7 +606,7 @@ ARG_POOL_IMPORT = Arg(
"""
{
"pool_1": {"slots": 5, "description": "", "include_deferred":
true},
- "pool_2": {"slots": 10, "description": "test",
"include_deferred": false}
+ "pool_2": {"slots": 10, "description": "test",
"include_deferred": false, "team_name": "my_team"}
}"""
),
" " * 4,
@@ -1575,6 +1576,7 @@ POOLS_COMMANDS = (
ARG_POOL_SLOTS,
ARG_POOL_DESCRIPTION,
ARG_POOL_INCLUDE_DEFERRED,
+ ARG_POOL_TEAM_NAME,
ARG_OUTPUT,
ARG_VERBOSE,
),
diff --git a/airflow-core/src/airflow/cli/commands/pool_command.py
b/airflow-core/src/airflow/cli/commands/pool_command.py
index 4856a6a4bc3..c2e624d23a6 100644
--- a/airflow-core/src/airflow/cli/commands/pool_command.py
+++ b/airflow-core/src/airflow/cli/commands/pool_command.py
@@ -40,6 +40,7 @@ def _show_pools(pools, output):
"slots": x[1],
"description": x[2],
"include_deferred": x[3],
+ "team_name": x[4],
},
)
@@ -72,7 +73,11 @@ def pool_set(args):
"""Create new pool with a given name and slots."""
api_client = get_current_api_client()
api_client.create_pool(
- name=args.pool, slots=args.slots, description=args.description,
include_deferred=args.include_deferred
+ name=args.pool,
+ slots=args.slots,
+ description=args.description,
+ include_deferred=args.include_deferred,
+ team_name=args.team_name,
)
print(f"Pool {args.pool} created")
@@ -130,6 +135,7 @@ def pool_import_helper(filepath):
slots=v["slots"],
description=v["description"],
include_deferred=v.get("include_deferred", False),
+ team_name=v.get("team_name"),
)
)
else:
@@ -143,7 +149,14 @@ def pool_export_helper(filepath):
pool_dict = {}
pools = api_client.get_pools()
for pool in pools:
- pool_dict[pool[0]] = {"slots": pool[1], "description": pool[2],
"include_deferred": pool[3]}
+ entry = {
+ "slots": pool[1],
+ "description": pool[2],
+ "include_deferred": pool[3],
+ }
+ if pool[4] is not None:
+ entry["team_name"] = pool[4]
+ pool_dict[pool[0]] = entry
with open(filepath, "w") as poolfile:
poolfile.write(json.dumps(pool_dict, sort_keys=True, indent=4))
return pools
diff --git a/airflow-core/src/airflow/models/pool.py
b/airflow-core/src/airflow/models/pool.py
index 08eff959ea2..d6a4915ea2b 100644
--- a/airflow-core/src/airflow/models/pool.py
+++ b/airflow-core/src/airflow/models/pool.py
@@ -129,20 +129,35 @@ class Pool(Base):
description: str,
include_deferred: bool,
*,
+ team_name: str | None = None,
session: Session = NEW_SESSION,
) -> Pool:
"""Create a pool with given parameters or update it if it already
exists."""
+ from airflow.configuration import conf
+
if not name:
raise ValueError("Pool name must not be empty")
+ if team_name and not conf.getboolean("core", "multi_team"):
+ raise ValueError(
+ "team_name cannot be set when multi_team mode is disabled.
Please contact your administrator."
+ )
+
pool = session.scalar(select(Pool).filter_by(pool=name))
if pool is None:
- pool = Pool(pool=name, slots=slots, description=description,
include_deferred=include_deferred)
+ pool = Pool(
+ pool=name,
+ slots=slots,
+ description=description,
+ include_deferred=include_deferred,
+ team_name=team_name,
+ )
session.add(pool)
else:
pool.slots = slots
pool.description = description
pool.include_deferred = include_deferred
+ pool.team_name = team_name
return pool
diff --git a/airflow-core/tests/unit/cli/commands/test_pool_command.py
b/airflow-core/tests/unit/cli/commands/test_pool_command.py
index 8fea33d7a7f..6b5dc2a5417 100644
--- a/airflow-core/tests/unit/cli/commands/test_pool_command.py
+++ b/airflow-core/tests/unit/cli/commands/test_pool_command.py
@@ -131,7 +131,11 @@ class TestCliPools:
pool_export_file_path = tmp_path / "pools_export.json"
pool_config_input = {
"foo": {"description": "foo_test", "slots": 1, "include_deferred":
True},
- "default_pool": {"description": "Default pool", "slots": 128,
"include_deferred": False},
+ "default_pool": {
+ "description": "Default pool",
+ "slots": 128,
+ "include_deferred": False,
+ },
"baz": {"description": "baz_test", "slots": 2, "include_deferred":
False},
}
with open(pool_import_file_path, mode="w") as file:
@@ -146,3 +150,149 @@ class TestCliPools:
with open(pool_export_file_path) as file:
pool_config_output = json.load(file)
assert pool_config_input == pool_config_output, "Input and output
pool files are not same"
+
+ def test_pool_set_with_team_name(self):
+ """Test that pool_set with --team-name assigns the pool to the team
when multi_team is enabled."""
+ from airflow.models.team import Team
+
+ from tests_common.test_utils.config import conf_vars
+
+ # Create the team first
+ team = Team(name="test_team")
+ self.session.add(team)
+ self.session.commit()
+
+ try:
+ with conf_vars({("core", "multi_team"): "True"}):
+ pool_command.pool_set(
+ self.parser.parse_args(
+ ["pools", "set", "team_pool", "5", "team pool",
"--team-name", "test_team"]
+ )
+ )
+
+ pool = self.session.scalar(select(Pool).where(Pool.pool ==
"team_pool"))
+ assert pool is not None
+ assert pool.team_name == "test_team"
+ assert pool.slots == 5
+ finally:
+ self.session.execute(delete(Pool).where(Pool.pool == "team_pool"))
+ self.session.execute(delete(Team).where(Team.name == "test_team"))
+ self.session.commit()
+
+ def test_pool_set_team_name_rejected_when_multi_team_disabled(self):
+ """Test that pool_set with --team-name raises when multi_team is
disabled."""
+ from airflow.models.team import Team
+
+ from tests_common.test_utils.config import conf_vars
+
+ team = Team(name="test_team")
+ self.session.add(team)
+ self.session.commit()
+
+ try:
+ with conf_vars({("core", "multi_team"): "False"}):
+ with pytest.raises(
+ ValueError, match="team_name cannot be set when multi_team
mode is disabled"
+ ):
+ pool_command.pool_set(
+ self.parser.parse_args(
+ ["pools", "set", "team_pool", "5", "team pool",
"--team-name", "test_team"]
+ )
+ )
+ finally:
+ self.session.execute(delete(Pool).where(Pool.pool == "team_pool"))
+ self.session.execute(delete(Team).where(Team.name == "test_team"))
+ self.session.commit()
+
+ def test_pool_set_without_team_name(self):
+ """Test that pool_set without --team-name leaves team_name as None."""
+ pool_command.pool_set(self.parser.parse_args(["pools", "set",
"no_team_pool", "3", "no team"]))
+
+ pool = self.session.scalar(select(Pool).where(Pool.pool ==
"no_team_pool"))
+ assert pool is not None
+ assert pool.team_name is None
+
+ def test_pool_import_export_with_team_name(self, tmp_path):
+ """Test that import/export round-trips the team_name field."""
+ from airflow.models.team import Team
+
+ from tests_common.test_utils.config import conf_vars
+
+ team = Team(name="import_team")
+ self.session.add(team)
+ self.session.commit()
+
+ pool_import_file_path = tmp_path / "pools_import_team.json"
+ pool_export_file_path = tmp_path / "pools_export_team.json"
+ pool_config_input = {
+ "team_pool_a": {
+ "slots": 10,
+ "description": "team pool",
+ "include_deferred": False,
+ "team_name": "import_team",
+ },
+ "global_pool": {
+ "slots": 5,
+ "description": "global pool",
+ "include_deferred": False,
+ },
+ }
+
+ with open(pool_import_file_path, mode="w") as file:
+ json.dump(pool_config_input, file)
+
+ try:
+ with conf_vars({("core", "multi_team"): "True"}):
+ pool_command.pool_import(
+ self.parser.parse_args(["pools", "import",
str(pool_import_file_path)])
+ )
+
+ # Verify team assignment
+ pool = self.session.scalar(select(Pool).where(Pool.pool ==
"team_pool_a"))
+ assert pool is not None
+ assert pool.team_name == "import_team"
+
+ global_pool = self.session.scalar(select(Pool).where(Pool.pool ==
"global_pool"))
+ assert global_pool is not None
+ assert global_pool.team_name is None
+
+ # Export and verify
+ pool_command.pool_export(self.parser.parse_args(["pools",
"export", str(pool_export_file_path)]))
+
+ with open(pool_export_file_path) as file:
+ pool_config_output = json.load(file)
+
+ assert pool_config_output["team_pool_a"]["team_name"] ==
"import_team"
+ assert "team_name" not in pool_config_output["global_pool"]
+ finally:
+
self.session.execute(delete(Pool).where(Pool.pool.in_(["team_pool_a",
"global_pool"])))
+ self.session.execute(delete(Team).where(Team.name ==
"import_team"))
+ self.session.commit()
+
+ def test_pool_list_shows_team_name(self, stdout_capture):
+ """Test that pool list output includes the team_name column."""
+ from airflow.models.team import Team
+
+ from tests_common.test_utils.config import conf_vars
+
+ team = Team(name="list_team")
+ self.session.add(team)
+ self.session.commit()
+
+ try:
+ with conf_vars({("core", "multi_team"): "True"}):
+ pool_command.pool_set(
+ self.parser.parse_args(
+ ["pools", "set", "list_pool", "5", "desc",
"--team-name", "list_team"]
+ )
+ )
+
+ with stdout_capture as stdout:
+ pool_command.pool_list(self.parser.parse_args(["pools",
"list"]))
+
+ output = stdout.getvalue()
+ assert "list_team" in output
+ finally:
+ self.session.execute(delete(Pool).where(Pool.pool == "list_pool"))
+ self.session.execute(delete(Team).where(Team.name == "list_team"))
+ self.session.commit()