This is an automated email from the ASF dual-hosted git repository.

gopidesu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d349918fa4b SQLToolset: Retry model on query errors (#63358)
d349918fa4b is described below

commit d349918fa4b24d178b77761224260638ab59fe2e
Author: GPK <[email protected]>
AuthorDate: Thu Mar 12 07:54:40 2026 +0000

    SQLToolset: Retry model on query errors (#63358)
    
    * Add ModelRetry mechanism for sqltoolset to retry using RETRYABLE_ERRORS
    
    * Move SQL retry classification into SQLToolset and narrow retryable errors
    
    * Resolve comments
    
    * fixup tests
---
 .../airflow/providers/common/ai/toolsets/sql.py    |  55 +++++++-
 .../ai/tests/unit/common/ai/toolsets/test_sql.py   | 143 +++++++++++++++++++++
 2 files changed, 197 insertions(+), 1 deletion(-)

diff --git a/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py b/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
index f60f4b621c3..0902cff99f2 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
@@ -19,6 +19,8 @@
 from __future__ import annotations
 
 import json
+import sqlite3
+from contextlib import suppress
 from typing import TYPE_CHECKING, Any
 
 try:
@@ -29,6 +31,7 @@ except ImportError as e:
 
     raise AirflowOptionalProviderFeatureException(e)
 
+from pydantic_ai.exceptions import ModelRetry
 from pydantic_ai.tools import ToolDefinition
 from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
 from pydantic_core import SchemaValidator, core_schema
@@ -70,6 +73,31 @@ _CHECK_QUERY_SCHEMA: dict[str, Any] = {
     "required": ["sql"],
 }
 
+_POSTGRES_RETRYABLE_EXCEPTIONS: tuple[type[Exception], ...] = ()
+with suppress(ImportError):
+    import psycopg2.errors as _psycopg2_errors
+
+    _POSTGRES_RETRYABLE_EXCEPTIONS += (
+        _psycopg2_errors.UndefinedColumn,
+        _psycopg2_errors.UndefinedTable,
+    )
+
+with suppress(ImportError):
+    from psycopg import errors as _psycopg3_errors
+
+    _POSTGRES_RETRYABLE_EXCEPTIONS += (
+        _psycopg3_errors.UndefinedColumn,
+        _psycopg3_errors.UndefinedTable,
+    )
+
+_SQLALCHEMY_RETRYABLE_EXCEPTIONS: tuple[type[Exception], ...] = ()
+with suppress(ImportError):
+    from sqlalchemy.exc import (
+        ProgrammingError as _SQLAlchemyProgrammingError,
+    )
+
+    _SQLALCHEMY_RETRYABLE_EXCEPTIONS = (_SQLAlchemyProgrammingError,)
+
 
 class SQLToolset(AbstractToolset[Any]):
     """
@@ -204,7 +232,14 @@ class SQLToolset(AbstractToolset[Any]):
             _validate_sql(sql)
 
         hook = self._get_db_hook()
-        rows = hook.get_records(sql)
+        try:
+            rows = hook.get_records(sql)
+        except Exception as e:
+            if self._is_retryable_query_error(hook, e):
+                raise ModelRetry(
+                    f"error: {e!s}, Use get_schema and list_tables tools for more details."
+                ) from e
+            raise
         # Fetch column names from cursor description.
         col_names: list[str] | None = None
         if hook.last_description:
@@ -223,6 +258,24 @@ class SQLToolset(AbstractToolset[Any]):
             output["max_rows"] = self._max_rows
         return json.dumps(output, default=str)
 
+    @staticmethod
+    def _is_retryable_query_error(hook: DbApiHook, error: Exception) -> bool:
+        check_error = getattr(error, "orig", error)
+        conn_type = getattr(hook, "conn_type", None)
+        if conn_type == "postgres":
+            return bool(_POSTGRES_RETRYABLE_EXCEPTIONS) and isinstance(
+                check_error, _POSTGRES_RETRYABLE_EXCEPTIONS
+            )
+        if conn_type == "sqlite":
+            if isinstance(check_error, sqlite3.OperationalError):
+                message = str(check_error).lower()
+                return "no such column" in message or "no such table" in message
+            return False
+        if _SQLALCHEMY_RETRYABLE_EXCEPTIONS and isinstance(error, _SQLALCHEMY_RETRYABLE_EXCEPTIONS):
+            return True
+        # TODO: Add support for other databases.
+        return False
+
     def _check_query(self, sql: str) -> str:
         try:
             _validate_sql(sql)
diff --git a/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py b/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
index 0573acd2a77..471b956385d 100644
--- a/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
+++ b/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
@@ -17,10 +17,13 @@
 from __future__ import annotations
 
 import asyncio
+import importlib.util
 import json
+import sqlite3
 from unittest.mock import MagicMock, PropertyMock, patch
 
 import pytest
+from pydantic_ai.exceptions import ModelRetry
 
 from airflow.providers.common.ai.toolsets.sql import SQLToolset
 from airflow.providers.common.ai.utils.sql_validation import SQLSafetyError
@@ -165,6 +168,146 @@ class TestSQLToolsetQuery:
         data = json.loads(result)
         assert "rows" in data
 
+    def test_raises_model_retry_when_query_fails_with_retryable_error(self):
+        """When the query fails with a retryable error, raise ModelRetry so the model retries."""
+        ts = SQLToolset("pg_default")
+        ts._hook = _make_mock_db_hook()
+        ts._hook.conn_type = "sqlite"
+        ts._hook.get_records.side_effect = sqlite3.OperationalError("no such column: nonexistent")
+
+        with pytest.raises(ModelRetry) as exc_info:
+            asyncio.run(
+                ts.call_tool(
+                    "query",
+                    {"sql": "SELECT id, nonexistent FROM users"},
+                    ctx=MagicMock(),
+                    tool=MagicMock(),
+                )
+            )
+        assert "nonexistent" in exc_info.value.message
+        assert "get_schema" in exc_info.value.message
+        assert "list_tables" in exc_info.value.message
+
+    def test_model_retry_message_includes_schema_hint(self):
+        """ModelRetry message tells the model to use get_schema and list_tables for more details."""
+        ts = SQLToolset("pg_default")
+        ts._hook = _make_mock_db_hook()
+        ts._hook.conn_type = "sqlite"
+        ts._hook.get_records.side_effect = sqlite3.OperationalError("no such table: missing_table")
+
+        with pytest.raises(ModelRetry) as exc_info:
+            asyncio.run(
+                ts.call_tool("query", {"sql": "SELECT foo FROM x"}, ctx=MagicMock(), tool=MagicMock())
+            )
+        assert "get_schema" in exc_info.value.message
+        assert "list_tables" in exc_info.value.message
+
+    def test_non_retryable_error_is_propagated(self):
+        ts = SQLToolset("pg_default")
+        ts._hook = _make_mock_db_hook()
+        ts._hook.conn_type = "sqlite"
+        ts._hook.get_records.side_effect = sqlite3.OperationalError("database is locked")
+
+        with pytest.raises(sqlite3.OperationalError, match="database is locked"):
+            asyncio.run(ts.call_tool("query", {"sql": "SELECT 1"}, ctx=MagicMock(), tool=MagicMock()))
+
+    def test_error_propagates_when_hook_conn_type_not_supported(self):
+        ts = SQLToolset("pg_default")
+        ts._hook = _make_mock_db_hook()
+        ts._hook.conn_type = "mysql"
+        ts._hook.get_records.side_effect = RuntimeError("unexpected db error")
+
+        with pytest.raises(RuntimeError, match="unexpected db error"):
+            asyncio.run(ts.call_tool("query", {"sql": "SELECT 1"}, ctx=MagicMock(), tool=MagicMock()))
+
+    def test_error_propagates_when_hook_has_no_conn_type(self):
+        ts = SQLToolset("pg_default")
+        mock_hook = MagicMock(spec=["get_records", "last_description"])
+        mock_hook.get_records.side_effect = RuntimeError("hook error")
+        type(mock_hook).last_description = PropertyMock(return_value=[])
+        ts._hook = mock_hook
+
+        with pytest.raises(RuntimeError, match="hook error"):
+            asyncio.run(ts.call_tool("query", {"sql": "SELECT 1"}, ctx=MagicMock(), tool=MagicMock()))
+
+    @pytest.mark.skipif(
+        importlib.util.find_spec("psycopg2") is None,
+        reason="psycopg2 is not available for lowest dependency tests",
+    )
+    def test_sqlalchemy_programming_error_with_psycopg2_undefined_column_orig_raises_model_retry_for_postgres(
+        self,
+    ):
+        from psycopg2 import errors as psycopg2_errors
+        from sqlalchemy.exc import ProgrammingError
+
+        ts = SQLToolset("pg_default")
+        ts._hook = _make_mock_db_hook()
+        ts._hook.conn_type = "postgres"
+        ts._hook.get_records.side_effect = ProgrammingError(
+            statement="SELECT id, missing FROM users",
+            params=None,
+            orig=psycopg2_errors.UndefinedColumn('column "missing" does not exist'),
+        )
+
+        with (
+            patch(
+                "airflow.providers.common.ai.toolsets.sql._POSTGRES_RETRYABLE_EXCEPTIONS",
+                (psycopg2_errors.UndefinedColumn,),
+            ),
+            patch(
+                "airflow.providers.common.ai.toolsets.sql._SQLALCHEMY_RETRYABLE_EXCEPTIONS",
+                (ProgrammingError,),
+            ),
+            pytest.raises(ModelRetry),
+        ):
+            asyncio.run(
+                ts.call_tool(
+                    "query",
+                    {"sql": "SELECT id, missing FROM users"},
+                    ctx=MagicMock(),
+                    tool=MagicMock(),
+                )
+            )
+
+    @pytest.mark.skipif(
+        importlib.util.find_spec("psycopg2") is None,
+        reason="psycopg2 is not available for lowest dependency tests",
+    )
+    def test_sqlalchemy_programming_error_with_psycopg2_insufficient_privilege_orig_is_not_retried_for_postgres(
+        self,
+    ):
+        from psycopg2 import errors as psycopg2_errors
+        from sqlalchemy.exc import ProgrammingError
+
+        ts = SQLToolset("pg_default")
+        ts._hook = _make_mock_db_hook()
+        ts._hook.conn_type = "postgres"
+        ts._hook.get_records.side_effect = ProgrammingError(
+            statement="SELECT id FROM users",
+            params=None,
+            orig=psycopg2_errors.InsufficientPrivilege("permission denied for table users"),
+        )
+
+        with (
+            patch(
+                "airflow.providers.common.ai.toolsets.sql._POSTGRES_RETRYABLE_EXCEPTIONS",
+                (psycopg2_errors.UndefinedColumn, psycopg2_errors.UndefinedTable),
+            ),
+            patch(
+                "airflow.providers.common.ai.toolsets.sql._SQLALCHEMY_RETRYABLE_EXCEPTIONS",
+                (ProgrammingError,),
+            ),
+            pytest.raises(ProgrammingError),
+        ):
+            asyncio.run(
+                ts.call_tool(
+                    "query",
+                    {"sql": "SELECT id FROM users"},
+                    ctx=MagicMock(),
+                    tool=MagicMock(),
+                )
+            )
+
 
 class TestSQLToolsetCheckQuery:
     def test_valid_select(self):

Reply via email to