This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 318bb25ffc7 AIP-99: Add `AgentOperator` and `@task.agent` for agentic LLM workflows (#62825)
318bb25ffc7 is described below

commit 318bb25ffc78aa2fbe921042417ece99b83f3665
Author: Kaxil Naik <[email protected]>
AuthorDate: Wed Mar 4 12:52:49 2026 +0000

    AIP-99: Add `AgentOperator` and `@task.agent` for agentic LLM workflows (#62825)
    
    Docs for HookToolset (generic hook→tools adapter) and SQLToolset
    (curated 4-tool DB toolset). Includes defense layers table,
    allowed_tables limitation, HookToolset guidelines, recommended
    configurations, and production checklist.
    
    AgentOperator runs a pydantic-ai Agent with tools and multi-turn
    reasoning. The operator builds the agent from an Airflow connection
    (llm_conn_id) and optional toolsets (HookToolset, SQLToolset, etc.),
    keeping credentials in the secret backend.
    
    @task.agent decorator wraps AgentOperator — the decorated function
    returns the prompt string, all other params are passed through.
    
    Includes docs with security section (defense layers, allowed_tables
    limitation, HookToolset guidelines, production checklist), example
    DAGs, and unit tests.
    
    Help users choose between LLMOperator, LLMBranchOperator,
    LLMSQLQueryOperator, and AgentOperator with a comparison table
    and short descriptions of when to use each.
    
    closes #62826
---
 docs/spelling_wordlist.txt                         |   1 +
 providers/common/ai/docs/operators/agent.rst       | 138 ++++++++++++++++
 providers/common/ai/docs/operators/index.rst       |  40 +++++
 providers/common/ai/provider.yaml                  |   4 +
 .../providers/common/ai/decorators/agent.py        | 123 ++++++++++++++
 .../common/ai/example_dags/example_agent.py        | 178 +++++++++++++++++++++
 .../providers/common/ai/get_provider_info.py       |   3 +
 .../airflow/providers/common/ai/operators/agent.py | 108 +++++++++++++
 .../tests/unit/common/ai/decorators/test_agent.py  | 131 +++++++++++++++
 .../tests/unit/common/ai/operators/test_agent.py   | 138 ++++++++++++++++
 10 files changed, 864 insertions(+)

diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 81fd9cc567e..2cbae4da50e 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1760,6 +1760,7 @@ stackdriver
 stacklevel
 stacktrace
 starttls
+stateful
 StatefulSet
 StatefulSets
 statics
diff --git a/providers/common/ai/docs/operators/agent.rst 
b/providers/common/ai/docs/operators/agent.rst
new file mode 100644
index 00000000000..bafd2114589
--- /dev/null
+++ b/providers/common/ai/docs/operators/agent.rst
@@ -0,0 +1,138 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+.. _howto/operator:agent:
+
+``AgentOperator`` & ``@task.agent``
+===================================
+
+Use :class:`~airflow.providers.common.ai.operators.agent.AgentOperator` or
+the ``@task.agent`` decorator to run an LLM agent with **tools** — the agent
+reasons about the prompt, calls tools (database queries, API calls, etc.) in
+a multi-turn loop, and returns a final answer.
+
+This is different from
+:class:`~airflow.providers.common.ai.operators.llm.LLMOperator`, which sends
+a single prompt and returns the output. ``AgentOperator`` manages a stateful
+tool-call loop where the LLM decides which tools to call and when to stop.
+
+.. seealso::
+    :ref:`Connection configuration <howto/connection:pydantic_ai>`
+
+
+SQL Agent
+---------
+
+The most common pattern: give an agent access to a database so it can answer
+questions by writing and executing SQL.
+
+.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_agent.py
+    :language: python
+    :start-after: [START howto_operator_agent_sql]
+    :end-before: [END howto_operator_agent_sql]
+
+The ``SQLToolset`` provides four tools to the agent:
+
+.. list-table::
+   :header-rows: 1
+   :widths: 20 50
+
+   * - Tool
+     - Description
+   * - ``list_tables``
+     - Lists available table names (filtered by ``allowed_tables`` if set)
+   * - ``get_schema``
+     - Returns column names and types for a table
+   * - ``query``
+     - Executes a SQL query and returns rows as JSON
+   * - ``check_query``
+     - Validates SQL syntax without executing it
+
+
+Hook-based Tools
+----------------
+
+Wrap any Airflow Hook's methods as agent tools using ``HookToolset``. Only
+methods you explicitly list are exposed — there is no auto-discovery.
+
+.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_agent.py
+    :language: python
+    :start-after: [START howto_operator_agent_hook]
+    :end-before: [END howto_operator_agent_hook]
+
+
+TaskFlow Decorator
+------------------
+
+The ``@task.agent`` decorator wraps ``AgentOperator``. The function returns
+the prompt string; all other parameters are passed to the operator.
+
+.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_agent.py
+    :language: python
+    :start-after: [START howto_decorator_agent]
+    :end-before: [END howto_decorator_agent]
+
+
+Structured Output
+-----------------
+
+Set ``output_type`` to a Pydantic ``BaseModel`` subclass to get structured
+data back. The result is serialized via ``model_dump()`` for XCom.
+
+.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_agent.py
+    :language: python
+    :start-after: [START howto_decorator_agent_structured]
+    :end-before: [END howto_decorator_agent_structured]
+
+
+Chaining with Downstream Tasks
+-------------------------------
+
+The agent's output is pushed to XCom like any other operator, so downstream
+tasks can consume it.
+
+.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_agent.py
+    :language: python
+    :start-after: [START howto_agent_chain]
+    :end-before: [END howto_agent_chain]
+
+
+Parameters
+----------
+
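+- ``prompt``: The prompt to send to the agent (operator) or the return value
+  of the decorated function (decorator). Supports Jinja templating; with the
+  decorator, the returned string is rendered as a template too, so any
+  ``{{ ... }}`` it contains is evaluated.
+- ``llm_conn_id``: Airflow connection ID for the LLM provider.
+- ``model_id``: Model identifier (e.g. ``"openai:gpt-5"``). Overrides the
+  model stored in the connection's extra field.
+- ``system_prompt``: System-level instructions for the agent. Supports Jinja
+  templating.
+- ``output_type``: Expected output type (default: ``str``). Set to a Pydantic
+  ``BaseModel`` for structured output.
+- ``toolsets``: List of pydantic-ai toolsets (``SQLToolset``, ``HookToolset``,
+  etc.).
+- ``agent_params``: Additional keyword arguments passed to the pydantic-ai
+  ``Agent`` constructor (e.g. ``retries``, ``model_settings``). Do not put
+  ``output_type`` or ``instructions`` here: the operator already passes them,
+  so the agent creation fails with a ``TypeError``. Pass tools through the
+  ``toolsets`` parameter; a ``toolsets`` key in ``agent_params`` is replaced
+  whenever ``toolsets`` is also set.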
+
+
+Security
+--------
+
+.. seealso::
+    :ref:`Toolsets — Security <howto/toolsets>` for defense layers,
+    ``allowed_tables`` limitations, ``HookToolset`` guidelines, recommended
+    configurations, and the production checklist.
diff --git a/providers/common/ai/docs/operators/index.rst 
b/providers/common/ai/docs/operators/index.rst
index 5ca15266335..e961931925a 100644
--- a/providers/common/ai/docs/operators/index.rst
+++ b/providers/common/ai/docs/operators/index.rst
@@ -18,6 +18,46 @@
 Common AI Operators
 ===================
 
+Choosing the right operator
+---------------------------
+
+The common-ai provider ships several operators; the four below each have a matching ``@task``
+decorator. Use this table to pick the one that fits your use case:
+
+.. list-table::
+   :header-rows: 1
+   :widths: 40 30 30
+
+   * - Need
+     - Operator
+     - Decorator
+   * - Single prompt → text or structured output
+     - :class:`~airflow.providers.common.ai.operators.llm.LLMOperator`
+     - ``@task.llm``
+   * - LLM picks which downstream task runs
+     - :class:`~airflow.providers.common.ai.operators.llm_branch.LLMBranchOperator`
+     - ``@task.llm_branch``
+   * - Natural-language → SQL generation (no execution)
+     - :class:`~airflow.providers.common.ai.operators.llm_sql.LLMSQLQueryOperator`
+     - ``@task.llm_sql``
+   * - Multi-turn reasoning with tools (DB queries, API calls, etc.)
+     - :class:`~airflow.providers.common.ai.operators.agent.AgentOperator`
+     - ``@task.agent``
+
+**LLMOperator / @task.llm** — stateless, single-turn calls. Use this for classification,
+summarization, extraction, or any prompt that produces one response. Supports structured output
+via a ``response_format`` Pydantic model.
+
+**AgentOperator / @task.agent** — multi-turn tool-calling loop. The model decides which tools to
+invoke and when to stop. Use this when the LLM needs to take actions (query databases, call APIs,
+read files) to produce its answer. You configure available tools through ``toolsets``.
+
+AgentOperator *works* without toolsets — pydantic-ai supports tool-less agents for multi-turn
+reasoning — but if you don't need tools, ``LLMOperator`` is simpler and more explicit.
+
+Operator guides
+---------------
+
 .. toctree::
     :maxdepth: 1
     :glob:
diff --git a/providers/common/ai/provider.yaml 
b/providers/common/ai/provider.yaml
index 7e2cc85bf19..24507cd9277 100644
--- a/providers/common/ai/provider.yaml
+++ b/providers/common/ai/provider.yaml
@@ -32,6 +32,7 @@ integrations:
   - integration-name: Common AI
     external-doc-url: 
https://airflow.apache.org/docs/apache-airflow-providers-common-ai/
     how-to-guide:
+      - /docs/apache-airflow-providers-common-ai/operators/agent.rst
       - /docs/apache-airflow-providers-common-ai/operators/llm.rst
       - /docs/apache-airflow-providers-common-ai/operators/llm_branch.rst
       - /docs/apache-airflow-providers-common-ai/operators/llm_sql.rst
@@ -70,12 +71,15 @@ connection-types:
 operators:
   - integration-name: Common AI
     python-modules:
+      - airflow.providers.common.ai.operators.agent
       - airflow.providers.common.ai.operators.llm
       - airflow.providers.common.ai.operators.llm_branch
       - airflow.providers.common.ai.operators.llm_sql
       - airflow.providers.common.ai.operators.llm_schema_compare
 
 task-decorators:
+  - class-name: airflow.providers.common.ai.decorators.agent.agent_task
+    name: agent
   - class-name: airflow.providers.common.ai.decorators.llm.llm_task
     name: llm
   - class-name: 
airflow.providers.common.ai.decorators.llm_branch.llm_branch_task
diff --git 
a/providers/common/ai/src/airflow/providers/common/ai/decorators/agent.py 
b/providers/common/ai/src/airflow/providers/common/ai/decorators/agent.py
new file mode 100644
index 00000000000..40c55f630c0
--- /dev/null
+++ b/providers/common/ai/src/airflow/providers/common/ai/decorators/agent.py
@@ -0,0 +1,123 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+TaskFlow decorator for agentic LLM workflows.
+
+The user writes a function that **returns the prompt string**. The decorator
+handles hook creation, agent configuration with toolsets, multi-turn reasoning,
+and output serialization.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Callable, Collection, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, ClassVar
+
+from airflow.providers.common.ai.operators.agent import AgentOperator
+from airflow.providers.common.compat.sdk import (
+    DecoratedOperator,
+    TaskDecorator,
+    context_merge,
+    determine_kwargs,
+    task_decorator_factory,
+)
+from airflow.sdk.definitions._internal.types import SET_DURING_EXECUTION
+
+if TYPE_CHECKING:
+    from airflow.sdk import Context
+
+
+class _AgentDecoratedOperator(DecoratedOperator, AgentOperator):
+    """
+    Wraps a callable that returns a prompt for an agentic LLM workflow.
+
+    The user function is called at execution time to produce the prompt string.
+    All other parameters (``llm_conn_id``, ``toolsets``, ``system_prompt``, 
etc.)
+    are passed through to 
:class:`~airflow.providers.common.ai.operators.agent.AgentOperator`.
+
+    :param python_callable: A reference to a callable that returns the prompt 
string.
+    :param op_args: Positional arguments for the callable.
+    :param op_kwargs: Keyword arguments for the callable.
+    """
+
+    template_fields: Sequence[str] = (
+        *DecoratedOperator.template_fields,
+        *AgentOperator.template_fields,
+    )
+    template_fields_renderers: ClassVar[dict[str, str]] = {
+        **DecoratedOperator.template_fields_renderers,
+    }
+
+    custom_operator_name: str = "@task.agent"
+
+    def __init__(
+        self,
+        *,
+        python_callable: Callable,
+        op_args: Collection[Any] | None = None,
+        op_kwargs: Mapping[str, Any] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            python_callable=python_callable,
+            op_args=op_args,
+            op_kwargs=op_kwargs,
+            prompt=SET_DURING_EXECUTION,
+            **kwargs,
+        )
+
+    def execute(self, context: Context) -> Any:
+        context_merge(context, self.op_kwargs)
+        kwargs = determine_kwargs(self.python_callable, self.op_args, context)
+
+        self.prompt = self.python_callable(*self.op_args, **kwargs)
+
+        if not isinstance(self.prompt, str) or not self.prompt.strip():
+            raise TypeError("The returned value from the @task.agent callable 
must be a non-empty string.")
+
+        self.render_template_fields(context)
+        return AgentOperator.execute(self, context)
+
+
+def agent_task(
+    python_callable: Callable | None = None,
+    **kwargs,
+) -> TaskDecorator:
+    """
+    Wrap a function that returns a prompt into an agentic LLM task.
+
+    The function body constructs the prompt (can use Airflow context, XCom, 
etc.).
+    The decorator handles hook creation, agent configuration with toolsets,
+    multi-turn reasoning, and output serialization.
+
+    Usage::
+
+        @task.agent(
+            llm_conn_id="pydantic_ai_default",
+            system_prompt="You are a data analyst.",
+            toolsets=[SQLToolset(db_conn_id="postgres_default")],
+        )
+        def analyze(question: str):
+            return f"Answer: {question}"
+
+    :param python_callable: Function to decorate.
+    """
+    return task_decorator_factory(
+        python_callable=python_callable,
+        decorated_operator_class=_AgentDecoratedOperator,
+        **kwargs,
+    )
diff --git 
a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent.py
 
b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent.py
new file mode 100644
index 00000000000..985d1019818
--- /dev/null
+++ 
b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent.py
@@ -0,0 +1,178 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Example DAGs demonstrating AgentOperator, @task.agent, and toolsets."""
+
+from __future__ import annotations
+
+from airflow.providers.common.ai.operators.agent import AgentOperator
+from airflow.providers.common.ai.toolsets.hook import HookToolset
+from airflow.providers.common.ai.toolsets.sql import SQLToolset
+from airflow.providers.common.compat.sdk import dag, task
+
+# ---------------------------------------------------------------------------
+# 1. SQL Agent: answer a question using database tools
+# ---------------------------------------------------------------------------
+
+
+# [START howto_operator_agent_sql]
+@dag
+def example_agent_operator_sql():
+    AgentOperator(
+        task_id="analyst",
+        prompt="What are the top 5 customers by order count?",
+        llm_conn_id="pydantic_ai_default",
+        system_prompt=(
+            "You are a SQL analyst. Use the available tools to explore "
+            "the schema and answer the question with data."
+        ),
+        toolsets=[
+            SQLToolset(
+                db_conn_id="postgres_default",
+                allowed_tables=["customers", "orders"],
+                max_rows=20,
+            )
+        ],
+    )
+
+
+# [END howto_operator_agent_sql]
+
+example_agent_operator_sql()
+
+
+# ---------------------------------------------------------------------------
+# 2. Hook-based tools: wrap an existing hook for the agent
+# ---------------------------------------------------------------------------
+
+
+# [START howto_operator_agent_hook]
+@dag
+def example_agent_operator_hook():
+    from airflow.providers.http.hooks.http import HttpHook
+
+    http_hook = HttpHook(http_conn_id="my_api")
+
+    AgentOperator(
+        task_id="api_explorer",
+        prompt="What endpoints are available and what does /status return?",
+        llm_conn_id="pydantic_ai_default",
+        system_prompt="You are an API explorer. Use the tools to discover and 
call endpoints.",
+        toolsets=[
+            HookToolset(
+                http_hook,
+                allowed_methods=["run"],
+                tool_name_prefix="http_",
+            )
+        ],
+    )
+
+
+# [END howto_operator_agent_hook]
+
+example_agent_operator_hook()
+
+
+# ---------------------------------------------------------------------------
+# 3. @task.agent decorator with dynamic prompt
+# ---------------------------------------------------------------------------
+
+
+# [START howto_decorator_agent]
+@dag
+def example_agent_decorator():
+    @task.agent(
+        llm_conn_id="pydantic_ai_default",
+        system_prompt="You are a data analyst. Use tools to answer questions.",
+        toolsets=[
+            SQLToolset(
+                db_conn_id="postgres_default",
+                allowed_tables=["orders"],
+            )
+        ],
+    )
+    def analyze(question: str):
+        return f"Answer this question about our orders data: {question}"
+
+    analyze("What was our total revenue last month?")
+
+
+# [END howto_decorator_agent]
+
+example_agent_decorator()
+
+
+# ---------------------------------------------------------------------------
+# 4. Structured output — agent returns a Pydantic model
+# ---------------------------------------------------------------------------
+
+
+# [START howto_decorator_agent_structured]
+@dag
+def example_agent_structured_output():
+    from pydantic import BaseModel
+
+    class Analysis(BaseModel):
+        summary: str
+        top_items: list[str]
+        row_count: int
+
+    @task.agent(
+        llm_conn_id="pydantic_ai_default",
+        system_prompt="You are a data analyst. Return structured results.",
+        output_type=Analysis,
+        toolsets=[SQLToolset(db_conn_id="postgres_default")],
+    )
+    def analyze(question: str):
+        return f"Analyze: {question}"
+
+    analyze("What are the trending products this week?")
+
+
+# [END howto_decorator_agent_structured]
+
+example_agent_structured_output()
+
+
+# ---------------------------------------------------------------------------
+# 5. Chaining: agent output feeds into downstream tasks via XCom
+# ---------------------------------------------------------------------------
+
+
+# [START howto_agent_chain]
+@dag
+def example_agent_chain():
+    @task.agent(
+        llm_conn_id="pydantic_ai_default",
+        system_prompt="You are a SQL analyst.",
+        toolsets=[SQLToolset(db_conn_id="postgres_default", 
allowed_tables=["orders"])],
+    )
+    def investigate(question: str):
+        return f"Investigate: {question}"
+
+    @task
+    def send_report(analysis: str):
+        """Send the agent's analysis to a downstream system."""
+        print(f"Report: {analysis}")
+        return analysis
+
+    result = investigate("Summarize order trends for last quarter")
+    send_report(result)
+
+
+# [END howto_agent_chain]
+
+example_agent_chain()
diff --git 
a/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py 
b/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py
index 77a3c1b86c0..e5113d7fb3d 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py
@@ -31,6 +31,7 @@ def get_provider_info():
                 "integration-name": "Common AI",
                 "external-doc-url": 
"https://airflow.apache.org/docs/apache-airflow-providers-common-ai/";,
                 "how-to-guide": [
+                    
"/docs/apache-airflow-providers-common-ai/operators/agent.rst",
                     
"/docs/apache-airflow-providers-common-ai/operators/llm.rst",
                     
"/docs/apache-airflow-providers-common-ai/operators/llm_branch.rst",
                     
"/docs/apache-airflow-providers-common-ai/operators/llm_sql.rst",
@@ -72,6 +73,7 @@ def get_provider_info():
             {
                 "integration-name": "Common AI",
                 "python-modules": [
+                    "airflow.providers.common.ai.operators.agent",
                     "airflow.providers.common.ai.operators.llm",
                     "airflow.providers.common.ai.operators.llm_branch",
                     "airflow.providers.common.ai.operators.llm_sql",
@@ -80,6 +82,7 @@ def get_provider_info():
             }
         ],
         "task-decorators": [
+            {"class-name": 
"airflow.providers.common.ai.decorators.agent.agent_task", "name": "agent"},
             {"class-name": 
"airflow.providers.common.ai.decorators.llm.llm_task", "name": "llm"},
             {
                 "class-name": 
"airflow.providers.common.ai.decorators.llm_branch.llm_branch_task",
diff --git 
a/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py 
b/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
new file mode 100644
index 00000000000..ca4d61c86ec
--- /dev/null
+++ b/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Operator for running pydantic-ai agents with tools and multi-turn 
reasoning."""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+from functools import cached_property
+from typing import TYPE_CHECKING, Any
+
+from pydantic import BaseModel
+
+from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook
+from airflow.providers.common.compat.sdk import BaseOperator
+
+if TYPE_CHECKING:
+    from pydantic_ai import Agent
+    from pydantic_ai.toolsets.abstract import AbstractToolset
+
+    from airflow.sdk import Context
+
+
+class AgentOperator(BaseOperator):
+    """
+    Run a pydantic-ai Agent with tools and multi-turn reasoning.
+
+    Provide ``llm_conn_id`` and optional ``toolsets`` to let the operator build
+    and run the agent. The agent reasons about the prompt, calls tools in a
+    multi-turn loop, and returns a final answer.
+
+    :param prompt: The prompt to send to the agent.
+    :param llm_conn_id: Connection ID for the LLM provider.
+    :param model_id: Model identifier (e.g. ``"openai:gpt-5"``).
+        Overrides the model stored in the connection's extra field.
+    :param system_prompt: System-level instructions for the agent.
+    :param output_type: Expected output type. Default ``str``. Set to a 
Pydantic
+        ``BaseModel`` subclass for structured output.
+    :param toolsets: List of pydantic-ai toolsets the agent can use
+        (e.g. ``SQLToolset``, ``HookToolset``).
+    :param agent_params: Additional keyword arguments passed to the pydantic-ai
+        ``Agent`` constructor (e.g. ``retries``, ``model_settings``).
+    """
+
+    template_fields: Sequence[str] = (
+        "prompt",
+        "llm_conn_id",
+        "model_id",
+        "system_prompt",
+        "agent_params",
+    )
+
+    def __init__(
+        self,
+        *,
+        prompt: str,
+        llm_conn_id: str,
+        model_id: str | None = None,
+        system_prompt: str = "",
+        output_type: type = str,
+        toolsets: list[AbstractToolset] | None = None,
+        agent_params: dict[str, Any] | None = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        self.prompt = prompt
+        self.llm_conn_id = llm_conn_id
+        self.model_id = model_id
+        self.system_prompt = system_prompt
+        self.output_type = output_type
+        self.toolsets = toolsets
+        self.agent_params = agent_params or {}
+
+    @cached_property
+    def llm_hook(self) -> PydanticAIHook:
+        """Return PydanticAIHook for the configured LLM connection."""
+        return PydanticAIHook(llm_conn_id=self.llm_conn_id, 
model_id=self.model_id)
+
+    def execute(self, context: Context) -> Any:
+        extra_kwargs = dict(self.agent_params)
+        if self.toolsets:
+            extra_kwargs["toolsets"] = self.toolsets
+        agent: Agent[None, Any] = self.llm_hook.create_agent(
+            output_type=self.output_type,
+            instructions=self.system_prompt,
+            **extra_kwargs,
+        )
+
+        result = agent.run_sync(self.prompt)
+        output = result.output
+
+        if isinstance(output, BaseModel):
+            return output.model_dump()
+        return output
diff --git a/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py 
b/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py
new file mode 100644
index 00000000000..99e2e35aafc
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py
@@ -0,0 +1,131 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow.providers.common.ai.decorators.agent import 
_AgentDecoratedOperator
+
+
+class TestAgentDecoratedOperator:
+    def test_custom_operator_name(self):
+        assert _AgentDecoratedOperator.custom_operator_name == "@task.agent"
+
+    @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", 
autospec=True)
+    def test_execute_calls_callable_and_returns_output(self, mock_hook_cls):
+        """The callable's return value becomes the agent prompt."""
+        mock_agent = MagicMock(spec=["run_sync"])
+        mock_result = MagicMock(spec=["output"])
+        mock_result.output = "The top customer is Acme Corp."
+        mock_agent.run_sync.return_value = mock_result
+        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+
+        def my_prompt():
+            return "Who is our top customer?"
+
+        op = _AgentDecoratedOperator(task_id="test", 
python_callable=my_prompt, llm_conn_id="my_llm")
+        result = op.execute(context={})
+
+        assert result == "The top customer is Acme Corp."
+        assert op.prompt == "Who is our top customer?"
+        mock_agent.run_sync.assert_called_once_with("Who is our top customer?")
+
+    @pytest.mark.parametrize(
+        "return_value",
+        [42, "", "   ", None],
+        ids=["non-string", "empty", "whitespace", "none"],
+    )
+    def test_execute_raises_on_invalid_prompt(self, return_value):
+        """TypeError when the callable returns a non-string or blank string."""
+        op = _AgentDecoratedOperator(
+            task_id="test",
+            python_callable=lambda: return_value,
+            llm_conn_id="my_llm",
+        )
+        with pytest.raises(TypeError, match="non-empty string"):
+            op.execute(context={})
+
+    @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", 
autospec=True)
+    def test_execute_merges_op_kwargs_into_callable(self, mock_hook_cls):
+        """op_kwargs are resolved by the callable to build the prompt."""
+        mock_agent = MagicMock(spec=["run_sync"])
+        mock_result = MagicMock(spec=["output"])
+        mock_result.output = "done"
+        mock_agent.run_sync.return_value = mock_result
+        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+
+        def my_prompt(topic):
+            return f"Analyze {topic}"
+
+        op = _AgentDecoratedOperator(
+            task_id="test",
+            python_callable=my_prompt,
+            llm_conn_id="my_llm",
+            op_kwargs={"topic": "revenue trends"},
+        )
+        op.execute(context={"task_instance": MagicMock()})
+
+        assert op.prompt == "Analyze revenue trends"
+        mock_agent.run_sync.assert_called_once_with("Analyze revenue trends")
+
+    @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", 
autospec=True)
+    def test_execute_passes_toolsets_through(self, mock_hook_cls):
+        """Toolsets passed to the decorator are forwarded to the agent."""
+        mock_agent = MagicMock(spec=["run_sync"])
+        mock_result = MagicMock(spec=["output"])
+        mock_result.output = "result"
+        mock_agent.run_sync.return_value = mock_result
+        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+
+        mock_toolset = MagicMock()
+
+        op = _AgentDecoratedOperator(
+            task_id="test",
+            python_callable=lambda: "Do something",
+            llm_conn_id="my_llm",
+            toolsets=[mock_toolset],
+        )
+        op.execute(context={})
+
+        create_call = mock_hook_cls.return_value.create_agent.call_args
+        assert create_call[1]["toolsets"] == [mock_toolset]
+
+    @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", 
autospec=True)
+    def test_execute_structured_output(self, mock_hook_cls):
+        """BaseModel output is serialized with model_dump."""
+        from pydantic import BaseModel
+
+        class Summary(BaseModel):
+            text: str
+
+        mock_agent = MagicMock(spec=["run_sync"])
+        mock_result = MagicMock(spec=["output"])
+        mock_result.output = Summary(text="Great results")
+        mock_agent.run_sync.return_value = mock_result
+        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+
+        op = _AgentDecoratedOperator(
+            task_id="test",
+            python_callable=lambda: "Summarize",
+            llm_conn_id="my_llm",
+            output_type=Summary,
+        )
+        result = op.execute(context={})
+
+        assert result == {"text": "Great results"}
diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_agent.py 
b/providers/common/ai/tests/unit/common/ai/operators/test_agent.py
new file mode 100644
index 00000000000..3d949854189
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_agent.py
@@ -0,0 +1,138 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+from pydantic import BaseModel
+
+from airflow.providers.common.ai.operators.agent import AgentOperator
+
+
+def _make_mock_agent(output):
+    """Create a mock agent that returns the given output."""
+    mock_result = MagicMock(spec=["output"])
+    mock_result.output = output
+    mock_agent = MagicMock(spec=["run_sync"])
+    mock_agent.run_sync.return_value = mock_result
+    return mock_agent
+
+
+class TestAgentOperatorValidation:
+    def test_requires_llm_conn_id(self):
+        with pytest.raises(TypeError):
+            AgentOperator(task_id="test", prompt="hello")
+
+
+class TestAgentOperatorTemplateFields:
+    def test_template_fields(self):
+        expected = {"prompt", "llm_conn_id", "model_id", "system_prompt", 
"agent_params"}
+        assert set(AgentOperator.template_fields) == expected
+
+
+class TestAgentOperatorExecute:
+    @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", 
autospec=True)
+    def test_execute_creates_agent_from_hook(self, mock_hook_cls):
+        mock_agent = _make_mock_agent("The answer is 42.")
+        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+
+        op = AgentOperator(
+            task_id="test",
+            prompt="What is the answer?",
+            llm_conn_id="my_llm",
+            system_prompt="You are helpful.",
+        )
+        result = op.execute(context=MagicMock())
+
+        assert result == "The answer is 42."
+        mock_hook_cls.assert_called_once_with(llm_conn_id="my_llm", 
model_id=None)
+        mock_hook_cls.return_value.create_agent.assert_called_once_with(
+            output_type=str, instructions="You are helpful."
+        )
+        mock_agent.run_sync.assert_called_once_with("What is the answer?")
+
+    @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", 
autospec=True)
+    def test_execute_passes_toolsets_in_agent_kwargs(self, mock_hook_cls):
+        """Toolsets are passed through to the agent constructor."""
+        mock_hook_cls.return_value.create_agent.return_value = 
_make_mock_agent("done")
+
+        mock_toolset = MagicMock()
+        op = AgentOperator(
+            task_id="test",
+            prompt="Do something",
+            llm_conn_id="my_llm",
+            toolsets=[mock_toolset],
+        )
+        op.execute(context=MagicMock())
+
+        create_call = mock_hook_cls.return_value.create_agent.call_args
+        assert create_call[1]["toolsets"] == [mock_toolset]
+
+    @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", 
autospec=True)
+    def test_execute_passes_agent_params(self, mock_hook_cls):
+        """agent_params are unpacked into create_agent."""
+        mock_hook_cls.return_value.create_agent.return_value = 
_make_mock_agent("ok")
+
+        op = AgentOperator(
+            task_id="test",
+            prompt="test",
+            llm_conn_id="my_llm",
+            agent_params={"retries": 3, "model_settings": {"temperature": 0}},
+        )
+        op.execute(context=MagicMock())
+
+        create_call = mock_hook_cls.return_value.create_agent.call_args
+        assert create_call[1]["retries"] == 3
+        assert create_call[1]["model_settings"] == {"temperature": 0}
+
+    @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", 
autospec=True)
+    def test_execute_structured_output(self, mock_hook_cls):
+        """Structured output via BaseModel is serialized with model_dump."""
+
+        class Summary(BaseModel):
+            text: str
+            score: float
+
+        mock_hook_cls.return_value.create_agent.return_value = 
_make_mock_agent(
+            Summary(text="Great", score=0.95)
+        )
+
+        op = AgentOperator(
+            task_id="test",
+            prompt="Analyze this",
+            llm_conn_id="my_llm",
+            output_type=Summary,
+        )
+        result = op.execute(context=MagicMock())
+
+        assert result == {"text": "Great", "score": 0.95}
+
+    @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", 
autospec=True)
+    def test_execute_with_model_id(self, mock_hook_cls):
+        """model_id is passed to PydanticAIHook."""
+        mock_hook_cls.return_value.create_agent.return_value = 
_make_mock_agent("ok")
+
+        op = AgentOperator(
+            task_id="test",
+            prompt="test",
+            llm_conn_id="my_llm",
+            model_id="openai:gpt-5",
+        )
+        op.execute(context=MagicMock())
+
+        mock_hook_cls.assert_called_once_with(llm_conn_id="my_llm", 
model_id="openai:gpt-5")


Reply via email to