This is an automated email from the ASF dual-hosted git repository.

o-nikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f90ebd0835b Add --team-name support to pool CLI commands (#68110)
f90ebd0835b is described below

commit f90ebd0835bead9b61db675009dfe8d32b575740
Author: Niko Oliveira <[email protected]>
AuthorDate: Mon Jun 8 09:47:38 2026 -0700

    Add --team-name support to pool CLI commands (#68110)
    
    Add the ability to assign pools to teams via the CLI:
    - `airflow pools set --team-name <team>` assigns the pool to a team
    - Pool import/export JSON supports an optional team_name field
      (omitted on export for pools that have no team)
    - Pool list/get output includes a team_name column
    
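    Creating or updating a pool with a team_name raises a ValueError unless
    [core] multi_team is enabled. Note that updating an existing pool without
    a team_name (via `pools set` or import) clears its team assignment.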
---
 .../src/airflow/api/client/local_client.py         |  14 +-
 airflow-core/src/airflow/cli/cli_config.py         |   4 +-
 .../src/airflow/cli/commands/pool_command.py       |  17 ++-
 airflow-core/src/airflow/models/pool.py            |  17 ++-
 .../tests/unit/cli/commands/test_pool_command.py   | 152 ++++++++++++++++++++-
 5 files changed, 194 insertions(+), 10 deletions(-)

diff --git a/airflow-core/src/airflow/api/client/local_client.py 
b/airflow-core/src/airflow/api/client/local_client.py
index 2eeb5d434b5..057d6d99c7c 100644
--- a/airflow-core/src/airflow/api/client/local_client.py
+++ b/airflow-core/src/airflow/api/client/local_client.py
@@ -78,12 +78,12 @@ class Client:
         pool = Pool.get_pool(pool_name=name)
         if not pool:
             raise PoolNotFound(f"Pool {name} not found")
-        return pool.pool, pool.slots, pool.description, pool.include_deferred
+        return pool.pool, pool.slots, pool.description, pool.include_deferred, 
pool.team_name
 
     def get_pools(self):
-        return [(p.pool, p.slots, p.description, p.include_deferred) for p in 
Pool.get_pools()]
+        return [(p.pool, p.slots, p.description, p.include_deferred, 
p.team_name) for p in Pool.get_pools()]
 
-    def create_pool(self, name, slots, description, include_deferred):
+    def create_pool(self, name, slots, description, include_deferred, 
team_name=None):
         if not (name and name.strip()):
             raise AirflowBadRequest("Pool name shouldn't be empty")
         pool_name_length = Pool.pool.property.columns[0].type.length
@@ -94,9 +94,13 @@ class Client:
         except ValueError:
             raise AirflowBadRequest(f"Invalid value for `slots`: {slots}")
         pool = Pool.create_or_update_pool(
-            name=name, slots=slots, description=description, 
include_deferred=include_deferred
+            name=name,
+            slots=slots,
+            description=description,
+            include_deferred=include_deferred,
+            team_name=team_name,
         )
-        return pool.pool, pool.slots, pool.description
+        return pool.pool, pool.slots, pool.description, pool.team_name
 
     def delete_pool(self, name):
         pool = Pool.delete_pool(name=name)
diff --git a/airflow-core/src/airflow/cli/cli_config.py 
b/airflow-core/src/airflow/cli/cli_config.py
index e9b026dd296..99d7faf30bc 100644
--- a/airflow-core/src/airflow/cli/cli_config.py
+++ b/airflow-core/src/airflow/cli/cli_config.py
@@ -596,6 +596,7 @@ ARG_POOL_DESCRIPTION = Arg(("description",), help="Pool 
description")
 ARG_POOL_INCLUDE_DEFERRED = Arg(
     ("--include-deferred",), help="Include deferred tasks in calculations for 
Pool", action="store_true"
 )
+ARG_POOL_TEAM_NAME = Arg(("--team-name",), help="Team name to assign the pool 
to (requires multi_team mode)")
 ARG_POOL_IMPORT = Arg(
     ("file",),
     metavar="FILEPATH",
@@ -605,7 +606,7 @@ ARG_POOL_IMPORT = Arg(
             """
             {
                 "pool_1": {"slots": 5, "description": "", "include_deferred": 
true},
-                "pool_2": {"slots": 10, "description": "test", 
"include_deferred": false}
+                "pool_2": {"slots": 10, "description": "test", 
"include_deferred": false, "team_name": "my_team"}
             }"""
         ),
         " " * 4,
@@ -1575,6 +1576,7 @@ POOLS_COMMANDS = (
             ARG_POOL_SLOTS,
             ARG_POOL_DESCRIPTION,
             ARG_POOL_INCLUDE_DEFERRED,
+            ARG_POOL_TEAM_NAME,
             ARG_OUTPUT,
             ARG_VERBOSE,
         ),
diff --git a/airflow-core/src/airflow/cli/commands/pool_command.py 
b/airflow-core/src/airflow/cli/commands/pool_command.py
index 4856a6a4bc3..c2e624d23a6 100644
--- a/airflow-core/src/airflow/cli/commands/pool_command.py
+++ b/airflow-core/src/airflow/cli/commands/pool_command.py
@@ -40,6 +40,7 @@ def _show_pools(pools, output):
             "slots": x[1],
             "description": x[2],
             "include_deferred": x[3],
+            "team_name": x[4],
         },
     )
 
@@ -72,7 +73,11 @@ def pool_set(args):
     """Create new pool with a given name and slots."""
     api_client = get_current_api_client()
     api_client.create_pool(
-        name=args.pool, slots=args.slots, description=args.description, 
include_deferred=args.include_deferred
+        name=args.pool,
+        slots=args.slots,
+        description=args.description,
+        include_deferred=args.include_deferred,
+        team_name=args.team_name,
     )
     print(f"Pool {args.pool} created")
 
@@ -130,6 +135,7 @@ def pool_import_helper(filepath):
                     slots=v["slots"],
                     description=v["description"],
                     include_deferred=v.get("include_deferred", False),
+                    team_name=v.get("team_name"),
                 )
             )
         else:
@@ -143,7 +149,14 @@ def pool_export_helper(filepath):
     pool_dict = {}
     pools = api_client.get_pools()
     for pool in pools:
-        pool_dict[pool[0]] = {"slots": pool[1], "description": pool[2], 
"include_deferred": pool[3]}
+        entry = {
+            "slots": pool[1],
+            "description": pool[2],
+            "include_deferred": pool[3],
+        }
+        if pool[4] is not None:
+            entry["team_name"] = pool[4]
+        pool_dict[pool[0]] = entry
     with open(filepath, "w") as poolfile:
         poolfile.write(json.dumps(pool_dict, sort_keys=True, indent=4))
     return pools
diff --git a/airflow-core/src/airflow/models/pool.py 
b/airflow-core/src/airflow/models/pool.py
index 08eff959ea2..d6a4915ea2b 100644
--- a/airflow-core/src/airflow/models/pool.py
+++ b/airflow-core/src/airflow/models/pool.py
@@ -129,20 +129,35 @@ class Pool(Base):
         description: str,
         include_deferred: bool,
         *,
+        team_name: str | None = None,
         session: Session = NEW_SESSION,
     ) -> Pool:
         """Create a pool with given parameters or update it if it already 
exists."""
+        from airflow.configuration import conf
+
         if not name:
             raise ValueError("Pool name must not be empty")
 
+        if team_name and not conf.getboolean("core", "multi_team"):
+            raise ValueError(
+                "team_name cannot be set when multi_team mode is disabled. 
Please contact your administrator."
+            )
+
         pool = session.scalar(select(Pool).filter_by(pool=name))
         if pool is None:
-            pool = Pool(pool=name, slots=slots, description=description, 
include_deferred=include_deferred)
+            pool = Pool(
+                pool=name,
+                slots=slots,
+                description=description,
+                include_deferred=include_deferred,
+                team_name=team_name,
+            )
             session.add(pool)
         else:
             pool.slots = slots
             pool.description = description
             pool.include_deferred = include_deferred
+            pool.team_name = team_name
 
         return pool
 
diff --git a/airflow-core/tests/unit/cli/commands/test_pool_command.py 
b/airflow-core/tests/unit/cli/commands/test_pool_command.py
index 8fea33d7a7f..6b5dc2a5417 100644
--- a/airflow-core/tests/unit/cli/commands/test_pool_command.py
+++ b/airflow-core/tests/unit/cli/commands/test_pool_command.py
@@ -131,7 +131,11 @@ class TestCliPools:
         pool_export_file_path = tmp_path / "pools_export.json"
         pool_config_input = {
             "foo": {"description": "foo_test", "slots": 1, "include_deferred": 
True},
-            "default_pool": {"description": "Default pool", "slots": 128, 
"include_deferred": False},
+            "default_pool": {
+                "description": "Default pool",
+                "slots": 128,
+                "include_deferred": False,
+            },
             "baz": {"description": "baz_test", "slots": 2, "include_deferred": 
False},
         }
         with open(pool_import_file_path, mode="w") as file:
@@ -146,3 +150,149 @@ class TestCliPools:
         with open(pool_export_file_path) as file:
             pool_config_output = json.load(file)
             assert pool_config_input == pool_config_output, "Input and output 
pool files are not same"
+
+    def test_pool_set_with_team_name(self):
+        """Test that pool_set with --team-name assigns the pool to the team 
when multi_team is enabled."""
+        from airflow.models.team import Team
+
+        from tests_common.test_utils.config import conf_vars
+
+        # Create the team first
+        team = Team(name="test_team")
+        self.session.add(team)
+        self.session.commit()
+
+        try:
+            with conf_vars({("core", "multi_team"): "True"}):
+                pool_command.pool_set(
+                    self.parser.parse_args(
+                        ["pools", "set", "team_pool", "5", "team pool", 
"--team-name", "test_team"]
+                    )
+                )
+
+            pool = self.session.scalar(select(Pool).where(Pool.pool == 
"team_pool"))
+            assert pool is not None
+            assert pool.team_name == "test_team"
+            assert pool.slots == 5
+        finally:
+            self.session.execute(delete(Pool).where(Pool.pool == "team_pool"))
+            self.session.execute(delete(Team).where(Team.name == "test_team"))
+            self.session.commit()
+
+    def test_pool_set_team_name_rejected_when_multi_team_disabled(self):
+        """Test that pool_set with --team-name raises when multi_team is 
disabled."""
+        from airflow.models.team import Team
+
+        from tests_common.test_utils.config import conf_vars
+
+        team = Team(name="test_team")
+        self.session.add(team)
+        self.session.commit()
+
+        try:
+            with conf_vars({("core", "multi_team"): "False"}):
+                with pytest.raises(
+                    ValueError, match="team_name cannot be set when multi_team 
mode is disabled"
+                ):
+                    pool_command.pool_set(
+                        self.parser.parse_args(
+                            ["pools", "set", "team_pool", "5", "team pool", 
"--team-name", "test_team"]
+                        )
+                    )
+        finally:
+            self.session.execute(delete(Pool).where(Pool.pool == "team_pool"))
+            self.session.execute(delete(Team).where(Team.name == "test_team"))
+            self.session.commit()
+
+    def test_pool_set_without_team_name(self):
+        """Test that pool_set without --team-name leaves team_name as None."""
+        pool_command.pool_set(self.parser.parse_args(["pools", "set", 
"no_team_pool", "3", "no team"]))
+
+        pool = self.session.scalar(select(Pool).where(Pool.pool == 
"no_team_pool"))
+        assert pool is not None
+        assert pool.team_name is None
+
+    def test_pool_import_export_with_team_name(self, tmp_path):
+        """Test that import/export round-trips the team_name field."""
+        from airflow.models.team import Team
+
+        from tests_common.test_utils.config import conf_vars
+
+        team = Team(name="import_team")
+        self.session.add(team)
+        self.session.commit()
+
+        pool_import_file_path = tmp_path / "pools_import_team.json"
+        pool_export_file_path = tmp_path / "pools_export_team.json"
+        pool_config_input = {
+            "team_pool_a": {
+                "slots": 10,
+                "description": "team pool",
+                "include_deferred": False,
+                "team_name": "import_team",
+            },
+            "global_pool": {
+                "slots": 5,
+                "description": "global pool",
+                "include_deferred": False,
+            },
+        }
+
+        with open(pool_import_file_path, mode="w") as file:
+            json.dump(pool_config_input, file)
+
+        try:
+            with conf_vars({("core", "multi_team"): "True"}):
+                pool_command.pool_import(
+                    self.parser.parse_args(["pools", "import", 
str(pool_import_file_path)])
+                )
+
+            # Verify team assignment
+            pool = self.session.scalar(select(Pool).where(Pool.pool == 
"team_pool_a"))
+            assert pool is not None
+            assert pool.team_name == "import_team"
+
+            global_pool = self.session.scalar(select(Pool).where(Pool.pool == 
"global_pool"))
+            assert global_pool is not None
+            assert global_pool.team_name is None
+
+            # Export and verify
+            pool_command.pool_export(self.parser.parse_args(["pools", 
"export", str(pool_export_file_path)]))
+
+            with open(pool_export_file_path) as file:
+                pool_config_output = json.load(file)
+
+            assert pool_config_output["team_pool_a"]["team_name"] == 
"import_team"
+            assert "team_name" not in pool_config_output["global_pool"]
+        finally:
+            
self.session.execute(delete(Pool).where(Pool.pool.in_(["team_pool_a", 
"global_pool"])))
+            self.session.execute(delete(Team).where(Team.name == 
"import_team"))
+            self.session.commit()
+
+    def test_pool_list_shows_team_name(self, stdout_capture):
+        """Test that pool list output includes the team_name column."""
+        from airflow.models.team import Team
+
+        from tests_common.test_utils.config import conf_vars
+
+        team = Team(name="list_team")
+        self.session.add(team)
+        self.session.commit()
+
+        try:
+            with conf_vars({("core", "multi_team"): "True"}):
+                pool_command.pool_set(
+                    self.parser.parse_args(
+                        ["pools", "set", "list_pool", "5", "desc", 
"--team-name", "list_team"]
+                    )
+                )
+
+            with stdout_capture as stdout:
+                pool_command.pool_list(self.parser.parse_args(["pools", 
"list"]))
+
+            output = stdout.getvalue()
+            assert "list_team" in output
+        finally:
+            self.session.execute(delete(Pool).where(Pool.pool == "list_pool"))
+            self.session.execute(delete(Team).where(Team.name == "list_team"))
+            self.session.commit()

Reply via email to