This is an automated email from the ASF dual-hosted git repository.

kaxil pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c0911ae7571 Add `LLMRetryPolicy` to common-ai provider (#65451)
c0911ae7571 is described below

commit c0911ae7571297e45ea7f36f827a4ad5fbccb2d8
Author: Kaxil Naik <[email protected]>
AuthorDate: Mon May 18 18:27:10 2026 +0100

    Add `LLMRetryPolicy` to common-ai provider (#65451)
    
    Uses PydanticAIHook to call any LLM for error classification with
    structured output. Timeout via pydantic-ai ModelSettings (default 30s).
    Falls back to declarative fallback_rules when LLM call fails.
    Gated for Airflow 3.3+. RST docs, example DAG, 12 tests.
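    The classification-to-decision mapping and the fallback path can be
    sketched in plain Python. This is a minimal, self-contained illustration
    under stated assumptions, not the provider's code: `Classification` and
    `Decision` are hypothetical stand-ins for `ErrorClassification` and
    `RetryDecision`, and `classify` stands in for the PydanticAIHook agent
    call. Any exception it raises, such as a timeout, plays the role of a
    failed LLM call.

```python
from dataclasses import dataclass
from datetime import timedelta
from typing import Callable, Optional


@dataclass
class Classification:
    """Hypothetical stand-in for ErrorClassification (same four fields)."""

    category: str
    should_retry: bool
    suggested_delay_seconds: int = 0
    reasoning: str = ""


@dataclass
class Decision:
    """Hypothetical stand-in for RetryDecision: action is "retry", "fail" or "default"."""

    action: str
    delay: Optional[timedelta] = None
    reason: Optional[str] = None


def decide(
    exc: BaseException,
    classify: Callable[[BaseException], Classification],
    fallback: Optional[Callable[[BaseException], Decision]] = None,
) -> Decision:
    """Map an LLM classification to a retry decision, falling back if the call fails."""
    try:
        c = classify(exc)
    except Exception:
        # The LLM call failed (timeout, provider down, bad credentials).
        if fallback is not None:
            return fallback(exc)
        return Decision("default")  # let the task's standard retry logic decide
    reason = f"{c.category}: {c.reasoning}"
    if not c.should_retry:
        return Decision("fail", reason=reason)
    # Only a positive delay overrides the task's retry_delay; 0 or negative -> None.
    delay = timedelta(seconds=c.suggested_delay_seconds) if c.suggested_delay_seconds > 0 else None
    return Decision("retry", delay=delay, reason=reason)


# Usage with two fake classifiers in place of a real LLM:
def fake_llm(exc: BaseException) -> Classification:
    if isinstance(exc, PermissionError):
        return Classification("auth", False, 0, "credentials rejected")
    return Classification("rate_limit", True, 60, "HTTP 429")


def broken_llm(exc: BaseException) -> Classification:
    raise TimeoutError("no LLM response within 30s")


print(decide(PermissionError("403"), fake_llm))
print(decide(RuntimeError("429"), fake_llm))
print(decide(ValueError("bad"), broken_llm))
```

    As in `_classify`, a non-positive suggested delay means "no override",
    so the task keeps its own `retry_delay`; with no fallback configured, a
    failed classification yields the DEFAULT action.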
    
    Add a Custom Instructions section showing how to override the default
    classifier prompt with domain-specific guidance. Uses Snowflake as the
    example since its transient errors (Warehouse suspended, JWT expired,
    Statement queued) need backend-specific knowledge to classify correctly.
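    The docs write the Snowflake guidance as one hand-written string. A
    hypothetical helper (`build_instructions` is not part of the provider)
    can render the same kind of prompt from a pattern table, so that every
    line names a known category and the prompt mentions the
    `ErrorClassification` fields explicitly, as the docs recommend. The
    patterns and delays are taken from the docs example; conditional rules
    such as "AFTER multiple retries" would still need a hand-written line.

```python
from typing import Dict, Tuple

# Same category names as DEFAULT_INSTRUCTIONS in policies/retry.py.
CATEGORIES = ("rate_limit", "auth", "network", "data", "resource", "transient", "permanent")

# Hypothetical rule table: error substring -> (category, retry delay in seconds; 0 = do not retry).
SNOWFLAKE_RULES: Dict[str, Tuple[str, int]] = {
    "Statement queued": ("rate_limit", 120),
    "JWT token expired": ("transient", 30),
    "Column does not exist": ("data", 0),
    "Warehouse suspended": ("transient", 30),
}


def build_instructions(backend: str, rules: Dict[str, Tuple[str, int]]) -> str:
    """Render a classifier prompt with one concrete line per known error pattern."""
    lines = [
        f"You are an error classifier for {backend}-backed data pipelines.",
        "Classify the error into one of: " + ", ".join(CATEGORIES) + ".",
        "Fill in category, should_retry, suggested_delay_seconds and reasoning.",
        "",
        f"{backend}-specific guidance:",
    ]
    for pattern, (category, delay) in rules.items():
        # Reject typos early instead of sending an unknown category to the LLM.
        if category not in CATEGORIES:
            raise ValueError(f"unknown category {category!r} for pattern {pattern!r}")
        action = f"retry after {delay}s" if delay > 0 else "do NOT retry"
        lines.append(f"- '{pattern}' -> {category}, {action}")
    lines.append("")
    lines.append("Set suggested_delay_seconds to 0 for errors that should not retry.")
    return "\n".join(lines)


SNOWFLAKE_INSTRUCTIONS = build_instructions("Snowflake", SNOWFLAKE_RULES)
print(SNOWFLAKE_INSTRUCTIONS)
```

    The resulting string would be passed as
    `instructions=SNOWFLAKE_INSTRUCTIONS` to `LLMRetryPolicy`, the same way
    the docs pass the hand-written version.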
    
    * Fix LLMRetryPolicy CI: mypy, ruff, spellcheck, compat tests
    
    - Type-annotate llm_policy as Optional in the example DAG so the
      fallback branch satisfies mypy on Airflow versions without RetryPolicy.
    - Wrap pydanticai* and claude-haiku* connection examples in
      double-backticks so RST autoapi treats them as code (skips spellcheck).
    - Gate test_retry.py with pytest.importorskip so compat-3.2.1 CI on
      older Airflow versions skips the module instead of erroring on import.
    - Apply ruff and ruff-format auto-fixes.
    
    * Fix LLMRetryPolicy CI: license header, sphinx duplicate llm_policy
    
    - Add Apache license header to the empty policies test __init__.py
      (insert-license + end-of-file-fixer were failing on it).
    - Move the @dag definition inside the try block in the example DAG so
      llm_policy is only assigned in one branch. Sphinx autoapi was
      treating the upfront None declaration plus the in-try LLMRetryPolicy
      assignment as two separate object descriptions, failing the docs
      build with a duplicate-object warning (treated as error). Without
      the upfront declaration, mypy is happy because the variable only
      exists when the import succeeds.
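    The single-binding pattern from the last bullet, reduced to plain
    Python: everything that needs the optional import, including the object
    later used as `retry_policy=`, is defined inside the `try` block, and the
    `except ImportError` branch defines nothing. `_hypothetical_retry_policy`
    is a made-up module name standing in for
    `airflow.sdk.definitions.retry_policy`, so here the import fails and
    neither name exists afterwards.

```python
from datetime import timedelta

try:
    # Made-up module standing in for airflow.sdk.definitions.retry_policy on Airflow < 3.3;
    # it does not exist, so this import raises ImportError (ModuleNotFoundError).
    from _hypothetical_retry_policy import RetryRule  # type: ignore[import-not-found]

    # Bound in exactly one place, and only if the import succeeded; there is no
    # upfront `llm_policy = None`, so docs tooling sees a single definition.
    llm_policy = RetryRule(exception=ConnectionError, retry_delay=timedelta(seconds=10))

    def example_dag() -> object:
        # Everything that references llm_policy lives in the same branch.
        return llm_policy

except ImportError:
    # Older runtime: skip the whole example instead of defining placeholders.
    pass

print("llm_policy" in globals(), "example_dag" in globals())
```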
---
 providers/common/ai/docs/index.rst                 |   1 +
 providers/common/ai/docs/retry_policies.rst        | 170 ++++++++++++++++++
 .../ai/example_dags/example_llm_retry_policy.py    |  72 ++++++++
 .../providers/common/ai/policies/__init__.py       |  16 ++
 .../airflow/providers/common/ai/policies/retry.py  | 183 +++++++++++++++++++
 .../ai/tests/unit/common/ai/policies/__init__.py   |  16 ++
 .../ai/tests/unit/common/ai/policies/test_retry.py | 197 +++++++++++++++++++++
 7 files changed, 655 insertions(+)

diff --git a/providers/common/ai/docs/index.rst b/providers/common/ai/docs/index.rst
index e96ba4cfd27..a5cd4196f7a 100644
--- a/providers/common/ai/docs/index.rst
+++ b/providers/common/ai/docs/index.rst
@@ -39,6 +39,7 @@
     Hooks <hooks/pydantic_ai>
     Toolsets <toolsets>
     Operators <operators/index>
+    Retry Policies <retry_policies>
     HITL Review <hitl_review>
 
 .. toctree::
diff --git a/providers/common/ai/docs/retry_policies.rst b/providers/common/ai/docs/retry_policies.rst
new file mode 100644
index 00000000000..036bc8e5908
--- /dev/null
+++ b/providers/common/ai/docs/retry_policies.rst
@@ -0,0 +1,170 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+LLM Retry Policies
+===================
+
+.. versionadded:: 3.3.0
+
+The ``LLMRetryPolicy`` uses an LLM to classify task errors and make intelligent
+retry decisions. It works with any LLM provider supported by pydantic-ai
+(OpenAI, Anthropic, Bedrock, Vertex, Ollama, etc.).
+
+For the core retry policy concepts, see :doc:`apache-airflow:core-concepts/tasks`.
+
+Setup
+-----
+
+1. Install the provider with the LLM backend you need:
+
+   .. code-block:: bash
+
+       pip install 'apache-airflow-providers-common-ai[anthropic]'
+
+2. Create a connection (``Admin > Connections``):
+
+   - **Connection Id**: ``pydanticai_default``
+   - **Connection Type**: ``Pydantic AI``
+   - **Password**: Your API key
+   - **Extra**: ``{"model": "anthropic:claude-haiku-4-5-20251001"}``
+
+Usage
+-----
+
+.. code-block:: python
+
+    from datetime import timedelta
+
+    from airflow.providers.common.ai.policies.retry import LLMRetryPolicy
+    from airflow.sdk import task
+    from airflow.sdk.definitions.retry_policy import RetryAction, RetryRule
+
+    llm_policy = LLMRetryPolicy(
+        llm_conn_id="pydanticai_default",
+        timeout=30.0,  # max seconds to wait for LLM response
+        fallback_rules=[  # used when LLM call fails
+            RetryRule(exception=ConnectionError, action=RetryAction.RETRY, retry_delay=timedelta(seconds=10)),
+            RetryRule(exception=PermissionError, action=RetryAction.FAIL),
+        ],
+    )
+
+
+    @task(retries=5, retry_policy=llm_policy)
+    def call_external_api(): ...
+
+How it works
+------------
+
+When a task fails, ``LLMRetryPolicy``:
+
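+1. Sends the exception type and message, plus the attempt number, to the
+   configured LLM.
+2. The LLM classifies the error into a category (``rate_limit``, ``auth``,
+   ``network``, ``data``, ``resource``, ``transient``, ``permanent``).
+3. Based on the classification, returns RETRY or FAIL. A positive
+   ``suggested_delay_seconds`` overrides the task's ``retry_delay``; zero or a
+   negative value keeps the task's own delay.
+4. The category, retry decision, delay, and reasoning are logged in the task
+   logs.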
+
+If the LLM call fails (provider down, timeout, bad credentials), the policy
+falls back to ``fallback_rules`` if configured. If no fallback rules are set,
+or none of them matches the exception, the task's standard retry behaviour
+applies.
+
+Custom instructions
+-------------------
+
+The default classifier handles generic categories. For domain-specific
+behaviour, override ``instructions`` to inject your own taxonomy. The LLM still
+returns an :class:`~airflow.providers.common.ai.policies.retry.ErrorClassification`
+(``category``, ``should_retry``, ``suggested_delay_seconds``, ``reasoning``)
+-- only the prompt changes.
+
+.. code-block:: python
+
+    SNOWFLAKE_INSTRUCTIONS = (
+        "You are an error classifier for Snowflake-backed data pipelines. "
+        "Classify the error into one of: rate_limit, auth, network, data, "
+        "transient, permanent.\n\n"
+        "Snowflake-specific guidance:\n"
+        "- 'Statement queued' or 'concurrency limit' -> rate_limit, retry after 120s\n"
+        "- 'JWT token expired' -> transient (token rotates), retry after 30s\n"
+        "- 'Authentication token has expired' AFTER multiple retries -> auth, do NOT retry\n"
+        "- 'Column does not exist' -> data, do NOT retry (schema drift needs human fix)\n"
+        "- 'Warehouse suspended' -> transient, retry after 30s (auto-resume)\n\n"
+        "Set suggested_delay_seconds based on the error type. "
+        "Set 0 for errors that should not retry."
+    )
+
+    snowflake_policy = LLMRetryPolicy(
+        llm_conn_id="pydanticai_default",
+        instructions=SNOWFLAKE_INSTRUCTIONS,
+        fallback_rules=[
+            RetryRule(
+                exception=ConnectionError,
+                action=RetryAction.RETRY,
+                retry_delay=timedelta(seconds=30),
+            ),
+        ],
+    )
+
+
+    @task(retries=5, retry_policy=snowflake_policy)
+    def query_snowflake(): ...
+
+When writing custom instructions:
+
+- The LLM must return the same ``ErrorClassification`` schema (``category``,
+  ``should_retry``, ``suggested_delay_seconds``, ``reasoning``). Mention the
+  fields explicitly so the model fills them.
+- Be concrete with examples (``"'Warehouse suspended' -> transient"``) rather
+  than vague rules ("treat warehouse issues as recoverable").
+- ``retry_reason`` is truncated to 500 chars in the audit log -- keep
+  ``reasoning`` outputs concise.
+
+Parameters
+----------
+
+.. list-table::
+   :header-rows: 1
+   :widths: 20 15 65
+
+   * - Parameter
+     - Default
+     - Description
+   * - ``llm_conn_id``
+     - (required)
+     - Airflow connection ID for the LLM provider.
+   * - ``model_id``
+     - None
+     - Override the model from the connection (e.g., ``"openai:gpt-4o-mini"``).
+   * - ``instructions``
+     - (built-in)
+     - Custom system prompt for error classification.
+   * - ``fallback_rules``
+     - None
+     - List of ``RetryRule`` objects used when the LLM call fails.
+   * - ``timeout``
+     - 30.0
+     - Max seconds to wait for the LLM response before falling back.
+
+Local LLM support
+-----------------
+
+For environments where exception data must not leave the infrastructure, point
+to a local model via Ollama or vLLM:
+
+.. code-block:: python
+
+    LLMRetryPolicy(
+        llm_conn_id="ollama_local",  # host=http://localhost:11434
+        model_id="ollama:llama3.2",
+    )
diff --git a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_retry_policy.py b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_retry_policy.py
new file mode 100644
index 00000000000..bdd2528dd28
--- /dev/null
+++ b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_retry_policy.py
@@ -0,0 +1,72 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example DAG demonstrating LLM-powered retry policies.
+
+Uses an LLM (via PydanticAIHook) to classify errors and decide whether
+to retry, fail immediately, or retry with a custom delay.
+
+Prerequisites:
+  - Connection ``pydanticai_default`` with ``conn_type='pydanticai'``,
+    ``password=<API key>``, ``extra='{"model": "anthropic:claude-haiku-4-5-20251001"}'``
+  - ``pip install 'apache-airflow-providers-common-ai[anthropic]'``
+"""
+
+from __future__ import annotations
+
+from datetime import timedelta
+
+from airflow.providers.common.compat.sdk import dag, task
+
+try:
+    from airflow.providers.common.ai.policies.retry import LLMRetryPolicy
+    from airflow.sdk.definitions.retry_policy import RetryAction, RetryRule
+
+    llm_policy = LLMRetryPolicy(
+        llm_conn_id="pydanticai_default",
+        timeout=30.0,
+        fallback_rules=[
+            RetryRule(exception=ConnectionError, action=RetryAction.RETRY, retry_delay=timedelta(seconds=10)),
+            RetryRule(exception=PermissionError, action=RetryAction.FAIL),
+        ],
+    )
+
+    @dag(catchup=False, tags=["example", "retry_policy", "llm"])
+    def example_llm_retry_policy():
+        @task(retries=3, retry_delay=timedelta(minutes=1), retry_policy=llm_policy)
+        def task_auth_error():
+            """LLM should classify as auth -> FAIL immediately."""
+            raise PermissionError("403 Forbidden: API key expired for service account [email protected]")
+
+        @task(retries=3, retry_delay=timedelta(minutes=1), retry_policy=llm_policy)
+        def task_rate_limit():
+            """LLM should classify as rate_limit -> RETRY with ~60s delay."""
+            raise RuntimeError("429 Too Many Requests: Rate limit exceeded. Retry after 60 seconds.")
+
+        @task(retries=3, retry_delay=timedelta(minutes=1), retry_policy=llm_policy)
+        def task_data_error():
+            """LLM should classify as data -> FAIL immediately."""
+            raise ValueError("Column 'user_id' expected type INT but got STRING in row 42.")
+
+        task_auth_error()
+        task_rate_limit()
+        task_data_error()
+
+    example_llm_retry_policy()
+except ImportError:
+    # RetryPolicy requires Airflow 3.3+; example DAG is skipped on older versions.
+    pass
diff --git a/providers/common/ai/src/airflow/providers/common/ai/policies/__init__.py b/providers/common/ai/src/airflow/providers/common/ai/policies/__init__.py
new file mode 100644
index 00000000000..13a83393a91
--- /dev/null
+++ b/providers/common/ai/src/airflow/providers/common/ai/policies/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/providers/common/ai/src/airflow/providers/common/ai/policies/retry.py b/providers/common/ai/src/airflow/providers/common/ai/policies/retry.py
new file mode 100644
index 00000000000..f92e4e0d64f
--- /dev/null
+++ b/providers/common/ai/src/airflow/providers/common/ai/policies/retry.py
@@ -0,0 +1,183 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+LLM-powered retry policy using pydantic-ai for error classification.
+
+Requires Airflow 3.3+ (RetryPolicy was added in AIP-105).
+"""
+
+from __future__ import annotations
+
+import logging
+from datetime import timedelta
+from typing import TYPE_CHECKING
+
+from pydantic import BaseModel
+
+try:
+    from airflow.sdk.definitions.retry_policy import (
+        ExceptionRetryPolicy,
+        RetryDecision,
+        RetryPolicy,
+    )
+except ImportError:
+    raise ImportError(
+        "LLMRetryPolicy requires Airflow 3.3+ which includes RetryPolicy support. "
+        "Please upgrade apache-airflow-core."
+    ) from None
+
+if TYPE_CHECKING:
+    from airflow.sdk.definitions.context import Context
+    from airflow.sdk.definitions.retry_policy import RetryRule
+
+log = logging.getLogger(__name__)
+
+__all__ = ["ErrorClassification", "LLMRetryPolicy"]
+
+DEFAULT_INSTRUCTIONS = (
+    "You are an error classifier for a data pipeline system. "
+    "Given an error message from a failed task, classify it into one of these categories:\n\n"
+    "- rate_limit: API throttling or quota exceeded. Should retry after a delay.\n"
+    "- auth: Credentials invalid, expired, or missing permissions. Should NOT retry.\n"
+    "- network: Transient connectivity issue. Should retry quickly.\n"
+    "- data: Schema validation, type mismatch, or bad input data. Should NOT retry.\n"
+    "- resource: Resource not found or unavailable (e.g., missing table, bucket). Should NOT retry.\n"
+    "- transient: Temporary issue likely to resolve on its own. Should retry.\n"
+    "- permanent: Problem that won't resolve without code or config changes. Should NOT retry.\n\n"
+    "Set suggested_delay_seconds based on the error type: "
+    "60 for rate limits, 10 for network, 30 for transient. "
+    "Set 0 for errors that should not retry."
+)
+
+
+class ErrorClassification(BaseModel):
+    """Structured LLM output for error classification."""
+
+    category: str
+    """One of: rate_limit, auth, network, data, resource, transient, permanent."""
+    should_retry: bool
+    """Whether the operation should be retried."""
+    suggested_delay_seconds: int = 0
+    """How long to wait before retrying (0 if should_retry is False)."""
+    reasoning: str
+    """Brief explanation of the classification decision."""
+
+
+class LLMRetryPolicy(RetryPolicy):
+    """
+    Retry policy that uses an LLM to classify errors and decide retry behaviour.
+
+    Uses :class:`~airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIHook`
+    to call any configured LLM provider (OpenAI, Anthropic, Bedrock, Vertex,
+    Ollama, etc.) for error classification with structured output.
+
+    When the LLM call itself fails, the policy falls back to ``fallback_rules``
+    (if provided) or returns DEFAULT to use the task's standard retry logic.
+
+    :param llm_conn_id: Airflow connection ID for the LLM provider.
+    :param model_id: Model identifier override (e.g. ``"openai:gpt-4o-mini"``
+        for cost efficiency). If not set, uses the model from the connection.
+    :param instructions: Custom system prompt for classification.
+        Defaults to a general-purpose error classifier.
+    :param fallback_rules: Optional list of
+        :class:`~airflow.sdk.definitions.retry_policy.RetryRule` applied when the
+        LLM call fails. Provides a deterministic safety net.
+    :param timeout: Maximum seconds to wait for the LLM response before
+        falling back.  Defaults to 30s.  The LLM provider's own timeout
+        (e.g. 600s for Anthropic) is much longer; this keeps the retry
+        decision path fast even when the provider is degraded.
+    """
+
+    def __init__(
+        self,
+        llm_conn_id: str,
+        model_id: str | None = None,
+        instructions: str | None = None,
+        fallback_rules: list[RetryRule] | None = None,
+        timeout: float = 30.0,
+    ) -> None:
+        self.llm_conn_id = llm_conn_id
+        self.model_id = model_id
+        self.instructions = instructions or DEFAULT_INSTRUCTIONS
+        self.fallback_rules = fallback_rules
+        self.timeout = timeout
+
+    def evaluate(
+        self,
+        exception: BaseException,
+        try_number: int,
+        max_tries: int,
+        context: Context | None = None,
+    ) -> RetryDecision:
+        try:
+            return self._classify(exception, try_number, max_tries)
+        except Exception:
+            log.exception("LLM retry classification failed, using fallback")
+            if self.fallback_rules:
+                return ExceptionRetryPolicy(rules=self.fallback_rules).evaluate(
+                    exception, try_number, max_tries, context
+                )
+            return RetryDecision.default()
+
+    def _classify(
+        self,
+        exception: BaseException,
+        try_number: int,
+        max_tries: int,
+    ) -> RetryDecision:
+        from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook
+
+        hook = PydanticAIHook(llm_conn_id=self.llm_conn_id, model_id=self.model_id)
+        agent = hook.create_agent(
+            output_type=ErrorClassification,
+            instructions=self.instructions,
+        )
+
+        prompt = (
+            f"Classify this error from a data pipeline task "
+            f"(attempt {try_number} of {max_tries}):\n\n"
+            f"{type(exception).__name__}: {exception}"
+        )
+
+        from pydantic_ai.settings import ModelSettings
+
+        result = agent.run_sync(
+            prompt,
+            model_settings=ModelSettings(timeout=self.timeout),
+        )
+        classification = result.output
+
+        log.info(
+            "LLM error classification: category=%s, should_retry=%s, delay=%ds, reasoning=%s",
+            classification.category,
+            classification.should_retry,
+            classification.suggested_delay_seconds,
+            classification.reasoning,
+        )
+
+        if not classification.should_retry:
+            return RetryDecision.fail(reason=f"{classification.category}: {classification.reasoning}")
+
+        delay = (
+            timedelta(seconds=classification.suggested_delay_seconds)
+            if classification.suggested_delay_seconds > 0
+            else None
+        )
+        return RetryDecision.retry(
+            delay=delay,
+            reason=f"{classification.category}: {classification.reasoning}",
+        )
diff --git a/providers/common/ai/tests/unit/common/ai/policies/__init__.py b/providers/common/ai/tests/unit/common/ai/policies/__init__.py
new file mode 100644
index 00000000000..13a83393a91
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/policies/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/providers/common/ai/tests/unit/common/ai/policies/test_retry.py b/providers/common/ai/tests/unit/common/ai/policies/test_retry.py
new file mode 100644
index 00000000000..6f9d976d6f1
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/policies/test_retry.py
@@ -0,0 +1,197 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from datetime import timedelta
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+# LLMRetryPolicy depends on the RetryPolicy ABC introduced in Airflow 3.3 (AIP-105).
+# Skip the entire test module on older Airflow versions tested in compat CI.
+pytest.importorskip("airflow.sdk.definitions.retry_policy", reason="RetryPolicy requires Airflow 3.3+")
+
+from airflow.providers.common.ai.policies.retry import (
+    ErrorClassification,
+    LLMRetryPolicy,
+)
+from airflow.sdk.definitions.retry_policy import RetryAction, RetryRule
+
+
+def _make_mock_agent(category, should_retry, delay=0, reasoning="test"):
+    """Create a mock agent that returns a canned ErrorClassification."""
+    mock_result = MagicMock()
+    mock_result.output = ErrorClassification(
+        category=category,
+        should_retry=should_retry,
+        suggested_delay_seconds=delay,
+        reasoning=reasoning,
+    )
+    mock_agent = MagicMock()
+    mock_agent.run_sync.return_value = mock_result
+    return mock_agent
+
+
+class TestLLMClassifyDecisions:
+    """Test that _classify maps LLM classification to correct RetryDecisions."""
+
+    @patch("airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIHook", autospec=True)
+    def test_auth_error_returns_fail(self, mock_hook_cls):
+        mock_hook_cls.return_value.create_agent.return_value = _make_mock_agent(
+            "auth", should_retry=False, reasoning="API key expired"
+        )
+        policy = LLMRetryPolicy(llm_conn_id="test")
+        decision = policy.evaluate(PermissionError("403"), try_number=1, max_tries=3)
+
+        assert decision.action == RetryAction.FAIL
+        assert "auth" in decision.reason
+        assert "API key expired" in decision.reason
+
+    @patch("airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIHook", autospec=True)
+    def test_rate_limit_returns_retry_with_delay(self, mock_hook_cls):
+        mock_hook_cls.return_value.create_agent.return_value = _make_mock_agent(
+            "rate_limit", should_retry=True, delay=60, reasoning="429"
+        )
+        policy = LLMRetryPolicy(llm_conn_id="test")
+        decision = policy.evaluate(RuntimeError("429"), try_number=1, max_tries=3)
+
+        assert decision.action == RetryAction.RETRY
+        assert decision.retry_delay == timedelta(seconds=60)
+
+    @patch("airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIHook", autospec=True)
+    def test_transient_retry_with_zero_delay_uses_default(self, mock_hook_cls):
+        """suggested_delay_seconds=0 means use the task's default delay, not override."""
+        mock_hook_cls.return_value.create_agent.return_value = _make_mock_agent(
+            "transient", should_retry=True, delay=0
+        )
+        policy = LLMRetryPolicy(llm_conn_id="test")
+        decision = policy.evaluate(RuntimeError("glitch"), try_number=1, max_tries=3)
+
+        assert decision.action == RetryAction.RETRY
+        assert decision.retry_delay is None  # None = use task's default
+
+    @patch("airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIHook", autospec=True)
+    def test_negative_delay_treated_as_no_override(self, mock_hook_cls):
+        """Negative delay from LLM should not produce a negative timedelta."""
+        mock_hook_cls.return_value.create_agent.return_value = _make_mock_agent(
+            "transient", should_retry=True, delay=-5
+        )
+        policy = LLMRetryPolicy(llm_conn_id="test")
+        decision = policy.evaluate(RuntimeError("x"), try_number=1, max_tries=3)
+
+        assert decision.action == RetryAction.RETRY
+        assert decision.retry_delay is None
+
+    @patch("airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIHook", autospec=True)
+    def test_prompt_includes_exception_type_and_message(self, mock_hook_cls):
+        mock_agent = _make_mock_agent("data", should_retry=False)
+        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+
+        policy = LLMRetryPolicy(llm_conn_id="test")
+        policy.evaluate(ValueError("bad column type"), try_number=2, max_tries=5)
+
+        prompt = mock_agent.run_sync.call_args[0][0]
+        assert "ValueError: bad column type" in prompt
+        assert "attempt 2 of 5" in prompt
+
+    @patch("airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIHook", autospec=True)
+    def test_custom_instructions_forwarded_to_agent(self, mock_hook_cls):
+        mock_hook_cls.return_value.create_agent.return_value = _make_mock_agent("x", False)
+
+        policy = LLMRetryPolicy(llm_conn_id="test", instructions="My custom prompt")
+        policy.evaluate(ValueError("x"), try_number=1, max_tries=3)
+
+        mock_hook_cls.return_value.create_agent.assert_called_once_with(
+            output_type=ErrorClassification,
+            instructions="My custom prompt",
+        )
+
+    @patch("airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIHook", autospec=True)
+    def test_timeout_passed_via_model_settings(self, mock_hook_cls):
+        mock_agent = _make_mock_agent("auth", False)
+        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+
+        policy = LLMRetryPolicy(llm_conn_id="test", timeout=15.0)
+        policy.evaluate(ValueError("x"), try_number=1, max_tries=3)
+
+        model_settings = mock_agent.run_sync.call_args.kwargs["model_settings"]
+        assert model_settings["timeout"] == 15.0
+
+
+class TestLLMFallbackBehaviour:
+    """Test fallback when the LLM call itself fails."""
+
+    def test_falls_back_to_rules_when_connection_missing(self):
+        policy = LLMRetryPolicy(
+            llm_conn_id="nonexistent",
+            fallback_rules=[
+                RetryRule(
+                    exception=ConnectionError, action=RetryAction.RETRY, retry_delay=timedelta(seconds=10)
+                ),
+                RetryRule(exception=PermissionError, action=RetryAction.FAIL, reason="auth fallback"),
+            ],
+        )
+        d = policy.evaluate(ConnectionError("refused"), try_number=1, max_tries=3)
+        assert d.action == RetryAction.RETRY
+        assert d.retry_delay == timedelta(seconds=10)
+
+        d = policy.evaluate(PermissionError("denied"), try_number=1, max_tries=3)
+        assert d.action == RetryAction.FAIL
+
+    def test_falls_back_to_default_when_no_rules(self):
+        policy = LLMRetryPolicy(llm_conn_id="nonexistent")
+        d = policy.evaluate(ValueError("bad"), try_number=1, max_tries=3)
+        assert d.action == RetryAction.DEFAULT
+
+    def test_fallback_rules_no_match_returns_default(self):
+        """When fallback rules exist but none match, DEFAULT is returned."""
+        policy = LLMRetryPolicy(
+            llm_conn_id="nonexistent",
+            fallback_rules=[
+                RetryRule(exception=PermissionError, action=RetryAction.FAIL),
+            ],
+        )
+        # ValueError doesn't match the PermissionError rule
+        d = policy.evaluate(ValueError("bad"), try_number=1, max_tries=3)
+        assert d.action == RetryAction.DEFAULT
+
+    @patch("airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIHook", autospec=True)
+    def test_agent_run_sync_failure_triggers_fallback(self, mock_hook_cls):
+        """Failure during run_sync (not hook creation) still triggers fallback."""
+        mock_agent = MagicMock()
+        mock_agent.run_sync.side_effect = RuntimeError("network error mid-call")
+        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+
+        policy = LLMRetryPolicy(
+            llm_conn_id="test",
+            fallback_rules=[RetryRule(exception=ValueError, action=RetryAction.FAIL, reason="fallback")],
+        )
+        d = policy.evaluate(ValueError("x"), try_number=1, max_tries=3)
+        assert d.action == RetryAction.FAIL
+        assert d.reason == "fallback"
+
+    @patch("airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIHook", autospec=True)
+    def test_hook_creation_failure_triggers_fallback(self, mock_hook_cls):
+        """Failure during hook.create_agent still triggers fallback."""
+        mock_hook_cls.return_value.create_agent.side_effect = RuntimeError("unexpected")
+
+        policy = LLMRetryPolicy(
+            llm_conn_id="test",
+            fallback_rules=[RetryRule(exception=ValueError, action=RetryAction.FAIL, reason="caught")],
+        )
+        d = policy.evaluate(ValueError("x"), try_number=1, max_tries=3)
+        assert d.action == RetryAction.FAIL
