This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 621792df82b Add `LLMBranchOperator` and `@task.llm_branch` to 
`common.ai` provider (#62740)
621792df82b is described below

commit 621792df82b69c20cc21338fa818a3c358ec3154
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Mar 3 00:16:59 2026 +0000

    Add `LLMBranchOperator` and `@task.llm_branch` to `common.ai` provider 
(#62740)
    
    - Add type: ignore[misc] for dynamic Enum() construction (mypy
      requires a literal second arg)
    - Add explicit type annotation for branches variable to avoid
      incompatible assignment error
    - Match do_branch return type (str | Iterable[str] | None)
---
 providers/common/ai/docs/index.rst                 |   1 +
 providers/common/ai/docs/operators/llm_branch.rst  |  97 ++++++++++++
 providers/common/ai/provider.yaml                  |   4 +
 providers/common/ai/pyproject.toml                 |   4 +
 .../providers/common/ai/decorators/llm_branch.py   | 135 +++++++++++++++++
 .../common/ai/example_dags/example_llm_branch.py   | 152 +++++++++++++++++++
 .../providers/common/ai/get_provider_info.py       |   6 +
 .../providers/common/ai/operators/llm_branch.py    |  94 ++++++++++++
 .../unit/common/ai/decorators/test_llm_branch.py   | 102 +++++++++++++
 .../unit/common/ai/operators/test_llm_branch.py    | 162 +++++++++++++++++++++
 10 files changed, 757 insertions(+)

diff --git a/providers/common/ai/docs/index.rst 
b/providers/common/ai/docs/index.rst
index fffcfc494ba..06c600a7805 100644
--- a/providers/common/ai/docs/index.rst
+++ b/providers/common/ai/docs/index.rst
@@ -122,6 +122,7 @@ Dependent package
 
==================================================================================================================
  =================
 `apache-airflow-providers-common-compat 
<https://airflow.apache.org/docs/apache-airflow-providers-common-compat>`_  
``common.compat``
 `apache-airflow-providers-common-sql 
<https://airflow.apache.org/docs/apache-airflow-providers-common-sql>`_        
``common.sql``
+`apache-airflow-providers-standard 
<https://airflow.apache.org/docs/apache-airflow-providers-standard>`_           
 ``standard``
 
==================================================================================================================
  =================
 
 Downloading official packages
diff --git a/providers/common/ai/docs/operators/llm_branch.rst 
b/providers/common/ai/docs/operators/llm_branch.rst
new file mode 100644
index 00000000000..9d1bc059a5e
--- /dev/null
+++ b/providers/common/ai/docs/operators/llm_branch.rst
@@ -0,0 +1,97 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+.. _howto/operator:llm_branch:
+
+``LLMBranchOperator``
+=====================
+
+Use 
:class:`~airflow.providers.common.ai.operators.llm_branch.LLMBranchOperator`
+for LLM-driven branching — where the LLM decides which downstream task(s) to
+execute.
+
+The operator discovers downstream tasks automatically from the DAG topology
+and presents them to the LLM as a constrained enum via pydantic-ai structured
+output. No text parsing or manual validation is needed.
+
+.. seealso::
+    :ref:`Connection configuration <howto/connection:pydantic_ai>`
+
+Basic Usage
+-----------
+
+Connect the operator to downstream tasks. The LLM chooses which branch to
+execute based on the prompt:
+
+.. exampleinclude:: 
/../../ai/src/airflow/providers/common/ai/example_dags/example_llm_branch.py
+    :language: python
+    :start-after: [START howto_operator_llm_branch_basic]
+    :end-before: [END howto_operator_llm_branch_basic]
+
+Multiple Branches
+-----------------
+
+Set ``allow_multiple_branches=True`` to let the LLM select more than one
+downstream task. All selected branches run; unselected branches are skipped:
+
+.. exampleinclude:: 
/../../ai/src/airflow/providers/common/ai/example_dags/example_llm_branch.py
+    :language: python
+    :start-after: [START howto_operator_llm_branch_multi]
+    :end-before: [END howto_operator_llm_branch_multi]
+
+TaskFlow Decorator
+------------------
+
+The ``@task.llm_branch`` decorator wraps ``LLMBranchOperator``. The function
+returns the prompt string; all other parameters are passed to the operator:
+
+.. exampleinclude:: 
/../../ai/src/airflow/providers/common/ai/example_dags/example_llm_branch.py
+    :language: python
+    :start-after: [START howto_decorator_llm_branch]
+    :end-before: [END howto_decorator_llm_branch]
+
+With multiple branches:
+
+.. exampleinclude:: 
/../../ai/src/airflow/providers/common/ai/example_dags/example_llm_branch.py
+    :language: python
+    :start-after: [START howto_decorator_llm_branch_multi]
+    :end-before: [END howto_decorator_llm_branch_multi]
+
+How It Works
+------------
+
+At execution time, the operator:
+
+1. Reads ``self.downstream_task_ids`` from the DAG topology.
+2. Creates a dynamic ``Enum`` with one member per downstream task ID.
+3. Passes that enum as ``output_type`` to ``pydantic-ai``, constraining the 
LLM to
+   valid task IDs only.
+4. Converts the LLM's structured output to task ID string(s) and calls
+   ``do_branch()`` to skip non-selected downstream tasks.
+
+Parameters
+----------
+
+- ``prompt``: The prompt sent to the LLM. For the operator this is a 
constructor
+  argument; for the decorator it is the decorated function's return value.
+- ``llm_conn_id``: Airflow connection ID for the LLM provider.
+- ``model_id``: Model identifier (e.g. ``"openai:gpt-5"``). Overrides the 
model stored in the connection's extra field.
+- ``system_prompt``: System-level instructions for the agent. Supports Jinja 
templating.
+- ``allow_multiple_branches``: When ``False`` (default) the LLM returns a 
single
+  task ID. When ``True`` the LLM may return one or more task IDs.
+- ``agent_params``: Additional keyword arguments passed to the pydantic-ai 
``Agent``
+  constructor (e.g. ``retries``, ``model_settings``). Supports Jinja 
templating.
diff --git a/providers/common/ai/provider.yaml 
b/providers/common/ai/provider.yaml
index 7ef0acd7c3c..7e51945470c 100644
--- a/providers/common/ai/provider.yaml
+++ b/providers/common/ai/provider.yaml
@@ -33,6 +33,7 @@ integrations:
     external-doc-url: 
https://airflow.apache.org/docs/apache-airflow-providers-common-ai/
     how-to-guide:
       - /docs/apache-airflow-providers-common-ai/operators/llm.rst
+      - /docs/apache-airflow-providers-common-ai/operators/llm_branch.rst
       - /docs/apache-airflow-providers-common-ai/operators/llm_sql.rst
     tags: [ai]
   - integration-name: Pydantic AI
@@ -62,10 +63,13 @@ operators:
   - integration-name: Common AI
     python-modules:
       - airflow.providers.common.ai.operators.llm
+      - airflow.providers.common.ai.operators.llm_branch
       - airflow.providers.common.ai.operators.llm_sql
 
 task-decorators:
   - class-name: airflow.providers.common.ai.decorators.llm.llm_task
     name: llm
+  - class-name: 
airflow.providers.common.ai.decorators.llm_branch.llm_branch_task
+    name: llm_branch
   - class-name: airflow.providers.common.ai.decorators.llm_sql.llm_sql_task
     name: llm_sql
diff --git a/providers/common/ai/pyproject.toml 
b/providers/common/ai/pyproject.toml
index b8726b0248b..1770af6848b 100644
--- a/providers/common/ai/pyproject.toml
+++ b/providers/common/ai/pyproject.toml
@@ -79,6 +79,9 @@ dependencies = [
 "common.sql" = [
     "apache-airflow-providers-common-sql"
 ]
+"standard" = [
+    "apache-airflow-providers-standard"
+]
 
 [dependency-groups]
 dev = [
@@ -87,6 +90,7 @@ dev = [
     "apache-airflow-devel-common",
     "apache-airflow-providers-common-compat",
     "apache-airflow-providers-common-sql",
+    "apache-airflow-providers-standard",
     # Additional devel dependencies (do not remove this line and add extra 
development dependencies)
     "sqlglot>=26.0.0",
     "apache-airflow-providers-common-sql[datafusion]"
diff --git 
a/providers/common/ai/src/airflow/providers/common/ai/decorators/llm_branch.py 
b/providers/common/ai/src/airflow/providers/common/ai/decorators/llm_branch.py
new file mode 100644
index 00000000000..2dc9194638a
--- /dev/null
+++ 
b/providers/common/ai/src/airflow/providers/common/ai/decorators/llm_branch.py
@@ -0,0 +1,135 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+TaskFlow decorator for LLM-driven branching.
+
+The user writes a function that **returns the prompt string**. The decorator
+discovers downstream tasks from the DAG topology and asks the LLM to choose
+which branch(es) to execute using pydantic-ai structured output.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Callable, Collection, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, ClassVar
+
+from airflow.providers.common.ai.operators.llm_branch import LLMBranchOperator
+from airflow.providers.common.compat.sdk import (
+    DecoratedOperator,
+    TaskDecorator,
+    context_merge,
+    task_decorator_factory,
+)
+from airflow.sdk.definitions._internal.types import SET_DURING_EXECUTION
+from airflow.utils.operator_helpers import determine_kwargs
+
+if TYPE_CHECKING:
+    from airflow.sdk import Context
+
+
+class _LLMBranchDecoratedOperator(DecoratedOperator, LLMBranchOperator):
+    """
+    Wraps a callable that returns a prompt for LLM-driven branching.
+
+    The user function is called at execution time to produce the prompt string.
+    All other parameters (``llm_conn_id``, ``system_prompt``, 
``allow_multiple_branches``,
+    etc.) are passed through to
+    
:class:`~airflow.providers.common.ai.operators.llm_branch.LLMBranchOperator`.
+
+    :param python_callable: A reference to a callable that returns the prompt 
string.
+    :param op_args: Positional arguments for the callable.
+    :param op_kwargs: Keyword arguments for the callable.
+    """
+
+    template_fields: Sequence[str] = (
+        *DecoratedOperator.template_fields,
+        *LLMBranchOperator.template_fields,
+    )
+    template_fields_renderers: ClassVar[dict[str, str]] = {
+        **DecoratedOperator.template_fields_renderers,
+    }
+
+    custom_operator_name: str = "@task.llm_branch"
+
+    def __init__(
+        self,
+        *,
+        python_callable: Callable,
+        op_args: Collection[Any] | None = None,
+        op_kwargs: Mapping[str, Any] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            python_callable=python_callable,
+            op_args=op_args,
+            op_kwargs=op_kwargs,
+            prompt=SET_DURING_EXECUTION,
+            **kwargs,
+        )
+
+    def execute(self, context: Context) -> Any:
+        context_merge(context, self.op_kwargs)
+        kwargs = determine_kwargs(self.python_callable, self.op_args, context)
+
+        self.prompt = self.python_callable(*self.op_args, **kwargs)
+
+        if not isinstance(self.prompt, str) or not self.prompt.strip():
+            raise TypeError(
+                "The returned value from the @task.llm_branch callable must be 
a non-empty string."
+            )
+
+        self.render_template_fields(context)
+        return LLMBranchOperator.execute(self, context)
+
+
+def llm_branch_task(
+    python_callable: Callable | None = None,
+    **kwargs,
+) -> TaskDecorator:
+    """
+    Wrap a function that returns a prompt into an LLM-driven branching task.
+
+    The function body constructs the prompt. The decorator discovers downstream
+    tasks from the DAG topology and asks the LLM to choose which branch(es)
+    to execute.
+
+    Usage::
+
+        @task.llm_branch(
+            llm_conn_id="openai_default",
+            system_prompt="Route support tickets to the right team.",
+        )
+        def route_ticket(message: str):
+            return f"Route this ticket: {message}"
+
+    With multiple branches::
+
+        @task.llm_branch(
+            llm_conn_id="openai_default",
+            system_prompt="Select all applicable categories.",
+            allow_multiple_branches=True,
+        )
+        def classify(text: str):
+            return f"Classify this text: {text}"
+
+    :param python_callable: Function to decorate.
+    """
+    return task_decorator_factory(
+        python_callable=python_callable,
+        decorated_operator_class=_LLMBranchDecoratedOperator,
+        **kwargs,
+    )
diff --git 
a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_branch.py
 
b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_branch.py
new file mode 100644
index 00000000000..c76b68999e7
--- /dev/null
+++ 
b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_branch.py
@@ -0,0 +1,152 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Example DAGs demonstrating LLMBranchOperator and @task.llm_branch usage."""
+
+from __future__ import annotations
+
+from airflow.providers.common.ai.operators.llm_branch import LLMBranchOperator
+from airflow.providers.common.compat.sdk import dag, task
+
+
+# [START howto_operator_llm_branch_basic]
+@dag
+def example_llm_branch_operator():
+    route = LLMBranchOperator(
+        task_id="route_ticket",
+        prompt="User says: 'My password reset email never arrived.'",
+        llm_conn_id="pydantic_ai_default",
+        system_prompt="Route support tickets to the right team.",
+    )
+
+    @task
+    def handle_billing():
+        return "Handling billing issue"
+
+    @task
+    def handle_auth():
+        return "Handling auth issue"
+
+    @task
+    def handle_general():
+        return "Handling general issue"
+
+    route >> [handle_billing(), handle_auth(), handle_general()]
+
+
+# [END howto_operator_llm_branch_basic]
+
+example_llm_branch_operator()
+
+
+# [START howto_operator_llm_branch_multi]
+@dag
+def example_llm_branch_multi():
+    route = LLMBranchOperator(
+        task_id="classify",
+        prompt="This product is great but shipping was slow and the box was 
damaged.",
+        llm_conn_id="pydantic_ai_default",
+        system_prompt="Select all applicable categories for this customer 
review.",
+        allow_multiple_branches=True,
+    )
+
+    @task
+    def handle_positive():
+        return "Processing positive feedback"
+
+    @task
+    def handle_shipping():
+        return "Escalating shipping issue"
+
+    @task
+    def handle_packaging():
+        return "Escalating packaging issue"
+
+    route >> [handle_positive(), handle_shipping(), handle_packaging()]
+
+
+# [END howto_operator_llm_branch_multi]
+
+example_llm_branch_multi()
+
+
+# [START howto_decorator_llm_branch]
+@dag
+def example_llm_branch_decorator():
+    @task.llm_branch(
+        llm_conn_id="pydantic_ai_default",
+        system_prompt="Route support tickets to the right team.",
+    )
+    def route_ticket(message: str):
+        return f"Route this support ticket: {message}"
+
+    @task
+    def handle_billing():
+        return "Handling billing issue"
+
+    @task
+    def handle_auth():
+        return "Handling auth issue"
+
+    @task
+    def handle_general():
+        return "Handling general issue"
+
+    route_ticket("I was charged twice for my subscription.") >> [
+        handle_billing(),
+        handle_auth(),
+        handle_general(),
+    ]
+
+
+# [END howto_decorator_llm_branch]
+
+example_llm_branch_decorator()
+
+
+# [START howto_decorator_llm_branch_multi]
+@dag
+def example_llm_branch_decorator_multi():
+    @task.llm_branch(
+        llm_conn_id="pydantic_ai_default",
+        system_prompt="Select all applicable categories for this customer 
review.",
+        allow_multiple_branches=True,
+    )
+    def classify_review(review: str):
+        return f"Classify this review: {review}"
+
+    @task
+    def handle_positive():
+        return "Processing positive feedback"
+
+    @task
+    def handle_shipping():
+        return "Escalating shipping issue"
+
+    @task
+    def handle_packaging():
+        return "Escalating packaging issue"
+
+    classify_review("Great product but shipping was slow.") >> [
+        handle_positive(),
+        handle_shipping(),
+        handle_packaging(),
+    ]
+
+
+# [END howto_decorator_llm_branch_multi]
+
+example_llm_branch_decorator_multi()
diff --git 
a/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py 
b/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py
index 091f4b849ac..26d285e6a70 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py
@@ -32,6 +32,7 @@ def get_provider_info():
                 "external-doc-url": 
"https://airflow.apache.org/docs/apache-airflow-providers-common-ai/";,
                 "how-to-guide": [
                     
"/docs/apache-airflow-providers-common-ai/operators/llm.rst",
+                    
"/docs/apache-airflow-providers-common-ai/operators/llm_branch.rst",
                     
"/docs/apache-airflow-providers-common-ai/operators/llm_sql.rst",
                 ],
                 "tags": ["ai"],
@@ -67,12 +68,17 @@ def get_provider_info():
                 "integration-name": "Common AI",
                 "python-modules": [
                     "airflow.providers.common.ai.operators.llm",
+                    "airflow.providers.common.ai.operators.llm_branch",
                     "airflow.providers.common.ai.operators.llm_sql",
                 ],
             }
         ],
         "task-decorators": [
             {"class-name": 
"airflow.providers.common.ai.decorators.llm.llm_task", "name": "llm"},
+            {
+                "class-name": 
"airflow.providers.common.ai.decorators.llm_branch.llm_branch_task",
+                "name": "llm_branch",
+            },
             {"class-name": 
"airflow.providers.common.ai.decorators.llm_sql.llm_sql_task", "name": 
"llm_sql"},
         ],
     }
diff --git 
a/providers/common/ai/src/airflow/providers/common/ai/operators/llm_branch.py 
b/providers/common/ai/src/airflow/providers/common/ai/operators/llm_branch.py
new file mode 100644
index 00000000000..b7f3028ec92
--- /dev/null
+++ 
b/providers/common/ai/src/airflow/providers/common/ai/operators/llm_branch.py
@@ -0,0 +1,94 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""LLM-driven branching operator."""
+
+from __future__ import annotations
+
+from collections.abc import Iterable, Sequence
+from enum import Enum
+from typing import TYPE_CHECKING, Any
+
+from airflow.providers.common.ai.operators.llm import LLMOperator
+from airflow.providers.standard.operators.branch import BranchMixIn
+
+if TYPE_CHECKING:
+    from airflow.sdk import Context
+
+
+class LLMBranchOperator(LLMOperator, BranchMixIn):
+    """
+    Ask an LLM to choose which downstream task(s) to execute.
+
+    Downstream task IDs are discovered automatically from the DAG topology
+    and presented to the LLM as a constrained enum via pydantic-ai structured
+    output. No text parsing or manual validation is needed.
+
+    :param prompt: The prompt to send to the LLM.
+    :param llm_conn_id: Connection ID for the LLM provider.
+    :param model_id: Model identifier (e.g. ``"openai:gpt-5"``).
+        Overrides the model stored in the connection's extra field.
+    :param system_prompt: System-level instructions for the LLM agent.
+    :param allow_multiple_branches: When ``False`` (default) the LLM returns a
+        single task ID. When ``True`` the LLM may return one or more task IDs.
+    :param agent_params: Additional keyword arguments passed to the pydantic-ai
+        ``Agent`` constructor (e.g. ``retries``, ``model_settings``, 
``tools``).
+    """
+
+    inherits_from_skipmixin = True
+
+    template_fields: Sequence[str] = LLMOperator.template_fields
+
+    def __init__(
+        self,
+        *,
+        allow_multiple_branches: bool = False,
+        **kwargs: Any,
+    ) -> None:
+        kwargs.pop("output_type", None)
+        super().__init__(**kwargs)
+        self.allow_multiple_branches = allow_multiple_branches
+
+    def execute(self, context: Context) -> str | Iterable[str] | None:
+        if not self.downstream_task_ids:
+            raise ValueError(
+                f"{self.task_id!r} has no downstream tasks. "
+                "LLMBranchOperator requires at least one downstream task to 
branch into."
+            )
+
+        downstream_tasks_enum = Enum(  # type: ignore[misc]
+            "DownstreamTasks",
+            {task_id: task_id for task_id in self.downstream_task_ids},
+        )
+        output_type = list[downstream_tasks_enum] if 
self.allow_multiple_branches else downstream_tasks_enum
+
+        agent = self.llm_hook.create_agent(
+            output_type=output_type,
+            instructions=self.system_prompt,
+            **self.agent_params,
+        )
+        result = agent.run_sync(self.prompt)
+        output = result.output
+
+        branches: str | list[str]
+        if isinstance(output, list):
+            branches = [item.value for item in output]
+        elif isinstance(output, Enum):
+            branches = output.value
+        else:
+            branches = str(output)
+
+        return self.do_branch(context, branches)
diff --git 
a/providers/common/ai/tests/unit/common/ai/decorators/test_llm_branch.py 
b/providers/common/ai/tests/unit/common/ai/decorators/test_llm_branch.py
new file mode 100644
index 00000000000..66620426a3b
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/decorators/test_llm_branch.py
@@ -0,0 +1,102 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from enum import Enum
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow.providers.common.ai.decorators.llm_branch import 
_LLMBranchDecoratedOperator
+from airflow.providers.common.ai.operators.llm_branch import LLMBranchOperator
+
+
+class TestLLMBranchDecoratedOperator:
+    def test_custom_operator_name(self):
+        assert _LLMBranchDecoratedOperator.custom_operator_name == 
"@task.llm_branch"
+
+    @patch.object(LLMBranchOperator, "do_branch")
+    @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", 
autospec=True)
+    def test_execute_calls_callable_and_branches(self, mock_hook_cls, 
mock_do_branch):
+        """The callable's return value becomes the LLM prompt, LLM output goes 
through do_branch."""
+        downstream_enum = Enum("DownstreamTasks", {"positive": "positive", 
"negative": "negative"})
+
+        mock_agent = MagicMock(spec=["run_sync"])
+        mock_result = MagicMock(spec=["output"])
+        mock_result.output = downstream_enum.positive
+        mock_agent.run_sync.return_value = mock_result
+        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+        mock_do_branch.return_value = "positive"
+
+        def my_prompt():
+            return "Route this review"
+
+        op = _LLMBranchDecoratedOperator(
+            task_id="test",
+            python_callable=my_prompt,
+            llm_conn_id="my_llm",
+        )
+        op.downstream_task_ids = {"positive", "negative"}
+
+        result = op.execute(context={})
+
+        assert result == "positive"
+        assert op.prompt == "Route this review"
+        mock_agent.run_sync.assert_called_once_with("Route this review")
+        mock_do_branch.assert_called_once()
+
+    @pytest.mark.parametrize(
+        "return_value",
+        [42, "", "   ", None],
+        ids=["non-string", "empty", "whitespace", "none"],
+    )
+    def test_execute_raises_on_invalid_prompt(self, return_value):
+        """TypeError when the callable returns a non-string or blank string."""
+        op = _LLMBranchDecoratedOperator(
+            task_id="test",
+            python_callable=lambda: return_value,
+            llm_conn_id="my_llm",
+        )
+        with pytest.raises(TypeError, match="non-empty string"):
+            op.execute(context={})
+
+    @patch.object(LLMBranchOperator, "do_branch")
+    @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", 
autospec=True)
+    def test_execute_merges_op_kwargs_into_callable(self, mock_hook_cls, 
mock_do_branch):
+        """op_kwargs are resolved by the callable to build the prompt."""
+        downstream_enum = Enum("DownstreamTasks", {"task_a": "task_a"})
+
+        mock_agent = MagicMock(spec=["run_sync"])
+        mock_result = MagicMock(spec=["output"])
+        mock_result.output = downstream_enum.task_a
+        mock_agent.run_sync.return_value = mock_result
+        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+
+        def my_prompt(ticket_type):
+            return f"Route this {ticket_type} ticket"
+
+        op = _LLMBranchDecoratedOperator(
+            task_id="test",
+            python_callable=my_prompt,
+            llm_conn_id="my_llm",
+            op_kwargs={"ticket_type": "billing"},
+        )
+        op.downstream_task_ids = {"task_a"}
+
+        op.execute(context={"task_instance": MagicMock()})
+
+        assert op.prompt == "Route this billing ticket"
diff --git 
a/providers/common/ai/tests/unit/common/ai/operators/test_llm_branch.py 
b/providers/common/ai/tests/unit/common/ai/operators/test_llm_branch.py
new file mode 100644
index 00000000000..d94fc552178
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_llm_branch.py
@@ -0,0 +1,162 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from enum import Enum
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow.providers.common.ai.operators.llm import LLMOperator
+from airflow.providers.common.ai.operators.llm_branch import LLMBranchOperator
+
+
class TestLLMBranchOperator:
    @staticmethod
    def _wire_agent(mock_hook_cls, output):
        """Wire a fake pydantic-ai agent whose run_sync result has *output*."""
        agent = MagicMock(spec=["run_sync"])
        run_result = MagicMock(spec=["output"])
        run_result.output = output
        agent.run_sync.return_value = run_result
        mock_hook_cls.return_value.create_agent.return_value = agent
        return agent

    def test_inherits_from_skipmixin_is_true(self):
        """Branching semantics require the SkipMixin contract."""
        assert LLMBranchOperator.inherits_from_skipmixin is True

    def test_template_fields(self):
        """The branch operator templates the same fields as the base LLMOperator."""
        assert set(LLMBranchOperator.template_fields) == set(LLMOperator.template_fields)

    def test_output_type_ignored(self):
        """Passing output_type= doesn't break anything; it's silently dropped."""
        operator = LLMBranchOperator(
            task_id="test",
            prompt="pick a branch",
            llm_conn_id="my_llm",
            output_type=int,
        )
        # The real output_type is built dynamically from downstream_task_ids,
        # so the constructor falls back to the LLMOperator default of str.
        assert operator.output_type is str

    @patch.object(LLMBranchOperator, "do_branch")
    @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True)
    def test_execute_single_branch(self, mock_hook_cls, mock_do_branch):
        """LLM returns a single enum member → do_branch receives a string."""
        branch_enum = Enum("DownstreamTasks", {"task_a": "task_a", "task_b": "task_b"})
        agent = self._wire_agent(mock_hook_cls, branch_enum.task_a)
        mock_do_branch.return_value = "task_a"

        operator = LLMBranchOperator(
            task_id="test",
            prompt="Pick a branch",
            llm_conn_id="my_llm",
        )
        operator.downstream_task_ids = {"task_a", "task_b"}

        context = MagicMock()
        outcome = operator.execute(context)

        assert outcome == "task_a"
        mock_do_branch.assert_called_once_with(context, "task_a")
        agent.run_sync.assert_called_once_with("Pick a branch")

    @patch.object(LLMBranchOperator, "do_branch")
    @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True)
    def test_execute_multi_branch(self, mock_hook_cls, mock_do_branch):
        """allow_multiple_branches=True → LLM returns list of enums → do_branch receives list."""
        branch_enum = Enum(
            "DownstreamTasks", {"task_a": "task_a", "task_b": "task_b", "task_c": "task_c"}
        )
        self._wire_agent(mock_hook_cls, [branch_enum.task_a, branch_enum.task_c])
        mock_do_branch.return_value = ["task_a", "task_c"]

        operator = LLMBranchOperator(
            task_id="test",
            prompt="Pick branches",
            llm_conn_id="my_llm",
            allow_multiple_branches=True,
        )
        operator.downstream_task_ids = {"task_a", "task_b", "task_c"}

        context = MagicMock()
        outcome = operator.execute(context)

        assert outcome == ["task_a", "task_c"]
        mock_do_branch.assert_called_once_with(context, ["task_a", "task_c"])

    @patch.object(LLMBranchOperator, "do_branch")
    @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True)
    def test_system_prompt_forwarded(self, mock_hook_cls, mock_do_branch):
        """system_prompt is passed to create_agent(instructions=...)."""
        branch_enum = Enum("DownstreamTasks", {"task_a": "task_a"})
        self._wire_agent(mock_hook_cls, branch_enum.task_a)

        operator = LLMBranchOperator(
            task_id="test",
            prompt="Pick",
            llm_conn_id="my_llm",
            system_prompt="Route tickets to the right team.",
        )
        operator.downstream_task_ids = {"task_a"}

        operator.execute(MagicMock())

        create_agent_call = mock_hook_cls.return_value.create_agent.call_args
        assert create_agent_call.kwargs["instructions"] == "Route tickets to the right team."

    @patch.object(LLMBranchOperator, "do_branch")
    @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True)
    def test_downstream_task_ids_used_for_enum(self, mock_hook_cls, mock_do_branch):
        """The dynamic enum is built from self.downstream_task_ids."""
        branch_enum = Enum(
            "DownstreamTasks", {"billing": "billing", "auth": "auth", "general": "general"}
        )
        self._wire_agent(mock_hook_cls, branch_enum.billing)

        operator = LLMBranchOperator(
            task_id="test",
            prompt="Pick",
            llm_conn_id="my_llm",
        )
        operator.downstream_task_ids = {"billing", "auth", "general"}

        operator.execute(MagicMock())

        create_agent_kwargs = mock_hook_cls.return_value.create_agent.call_args.kwargs
        generated_enum = create_agent_kwargs["output_type"]
        assert {member.value for member in generated_enum} == {"billing", "auth", "general"}

    def test_execute_raises_on_no_downstream_tasks(self):
        """ValueError when the operator has no downstream tasks."""
        operator = LLMBranchOperator(
            task_id="test",
            prompt="Pick",
            llm_conn_id="my_llm",
        )
        with pytest.raises(ValueError, match="no downstream tasks"):
            operator.execute(MagicMock())


Reply via email to