kaxil commented on code in PR #63358:
URL: https://github.com/apache/airflow/pull/63358#discussion_r2921568006


##########
providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py:
##########
@@ -165,6 +167,68 @@ def test_allows_writes_when_enabled(self):
         data = json.loads(result)
         assert "rows" in data
 
+    def test_raises_model_retry_when_query_fails_with_retryable_error(self):
+        """When the query fails with a retryable error, raise ModelRetry so the model retries."""
+        ts = SQLToolset("pg_default")
+        ts._hook = _make_mock_db_hook()
+        ts._hook.conn_type = "sqlite"
+        ts._hook.get_records.side_effect = sqlite3.OperationalError("no such column: nonexistent")
+
+        with pytest.raises(ModelRetry) as exc_info:
+            asyncio.run(
+                ts.call_tool(
+                    "query",
+                    {"sql": "SELECT id, nonexistent FROM users"},
+                    ctx=MagicMock(),
+                    tool=MagicMock(),
+                )
+            )
+        assert "nonexistent" in exc_info.value.message
+        assert "get_schema" in exc_info.value.message
+        assert "list_tables" in exc_info.value.message
+
+    def test_model_retry_message_includes_schema_hint(self):
+        """ModelRetry message tells the model to use get_schema and list_tables for more details."""
+        ts = SQLToolset("pg_default")
+        ts._hook = _make_mock_db_hook()
+        ts._hook.conn_type = "sqlite"
+        ts._hook.get_records.side_effect = sqlite3.OperationalError("no such table: missing_table")
+
+        with pytest.raises(ModelRetry) as exc_info:
+            asyncio.run(
+                ts.call_tool("query", {"sql": "SELECT foo FROM x"}, ctx=MagicMock(), tool=MagicMock())
+            )
+        assert "get_schema" in exc_info.value.message
+        assert "list_tables" in exc_info.value.message
+
+    def test_non_retryable_error_is_propagated(self):
+        ts = SQLToolset("pg_default")
+        ts._hook = _make_mock_db_hook()
+        ts._hook.conn_type = "sqlite"
+        ts._hook.get_records.side_effect = sqlite3.OperationalError("database is locked")
+
+        with pytest.raises(sqlite3.OperationalError, match="database is locked"):
+            asyncio.run(ts.call_tool("query", {"sql": "SELECT 1"}, ctx=MagicMock(), tool=MagicMock()))
+
+    def test_error_propagates_when_hook_conn_type_not_supported(self):
+        ts = SQLToolset("pg_default")
+        ts._hook = _make_mock_db_hook()
+        ts._hook.conn_type = "mysql"
+        ts._hook.get_records.side_effect = RuntimeError("unexpected db error")
+
+        with pytest.raises(RuntimeError, match="unexpected db error"):
+            asyncio.run(ts.call_tool("query", {"sql": "SELECT 1"}, 
ctx=MagicMock(), tool=MagicMock()))
+
+    def test_error_propagates_when_hook_has_no_conn_type(self):
+        ts = SQLToolset("pg_default")
+        mock_hook = MagicMock(spec=["get_records", "last_description"])
+        mock_hook.get_records.side_effect = RuntimeError("hook error")
+        type(mock_hook).last_description = PropertyMock(return_value=[])
+        ts._hook = mock_hook
+
+        with pytest.raises(RuntimeError, match="hook error"):
+            asyncio.run(ts.call_tool("query", {"sql": "SELECT 1"}, ctx=MagicMock(), tool=MagicMock()))
+
 
 class TestSQLToolsetCheckQuery:
     def test_valid_select(self):

Review Comment:
   All five tests use either `sqlite3.OperationalError` or `RuntimeError`. The 
SQLAlchemy `ProgrammingError` path and the psycopg2 
`UndefinedColumn`/`UndefinedTable` path are untested — and those are the most 
common paths in production (Airflow hooks use SQLAlchemy by default).
   
   Worth adding at least:
   1. A test with a mock `sqlalchemy.exc.ProgrammingError` (with `.orig` set to 
a `psycopg2.errors.UndefinedColumn`) to exercise the SQLAlchemy wrapping path.
   2. A test with a non-retryable `sqlalchemy.exc.ProgrammingError` (e.g., 
`.orig` is `InsufficientPrivilege`) to confirm it's NOT retried after the 
`orig` fix above.



##########
providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py:
##########
@@ -223,6 +257,23 @@ def _query(self, sql: str) -> str:
             output["max_rows"] = self._max_rows
         return json.dumps(output, default=str)
 

Review Comment:
   The SQLAlchemy `ProgrammingError` check runs first and short-circuits before 
the per-DB narrowing. In practice, `DbApiHook.get_records()` uses SQLAlchemy, 
so all database errors arrive wrapped — a `psycopg2.errors.UndefinedColumn` 
becomes `sqlalchemy.exc.ProgrammingError`, which is caught here. But so does 
`psycopg2.errors.InsufficientPrivilege` (also class 42 → `ProgrammingError`), 
which can't be fixed by rewriting SQL.
   
   The careful psycopg2 `UndefinedColumn`/`UndefinedTable` narrowing never 
fires when SQLAlchemy is in play, which is the common case.
   
   Fix: unwrap via `orig` first, then check DB-specific lists:
   
   ```python
   @staticmethod
   def _is_retryable_query_error(hook: DbApiHook, error: Exception) -> bool:
       # Unwrap SQLAlchemy wrapper to check the original DB-specific error
       check_error = getattr(error, "orig", error)
   
       conn_type = getattr(hook, "conn_type", None)
       if conn_type == "postgres":
           return bool(_POSTGRES_RETRYABLE_EXCEPTIONS) and isinstance(
               check_error, _POSTGRES_RETRYABLE_EXCEPTIONS
           )
       if conn_type == "sqlite":
           import sqlite3
   
           if isinstance(check_error, sqlite3.OperationalError):
               message = str(check_error).lower()
               return "no such column" in message or "no such table" in message
           return False
   
       # Fallback for unsupported DBs: trust SQLAlchemy's classification
       if _SQLALCHEMY_RETRYABLE_EXCEPTIONS and isinstance(error, _SQLALCHEMY_RETRYABLE_EXCEPTIONS):
           return True
       return False
   ```
   
   This way postgres/sqlite get the narrow checks even through SQLAlchemy, and 
unknown DBs still get the broad fallback.



##########
providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py:
##########
@@ -223,6 +257,23 @@ def _query(self, sql: str) -> str:
             output["max_rows"] = self._max_rows
         return json.dumps(output, default=str)
 
+    @staticmethod
+    def _is_retryable_query_error(hook: DbApiHook, error: Exception) -> bool:
+        if _SQLALCHEMY_RETRYABLE_EXCEPTIONS and isinstance(error, _SQLALCHEMY_RETRYABLE_EXCEPTIONS):
+            return True
+        conn_type = getattr(hook, "conn_type", None)
+        if conn_type == "postgres":
+            return bool(_POSTGRES_RETRYABLE_EXCEPTIONS) and isinstance(error, _POSTGRES_RETRYABLE_EXCEPTIONS)
+        if conn_type == "sqlite":
+            with suppress(ImportError):

Review Comment:
   Nit: `sqlite3` is a stdlib module — it's always available, so the 
`suppress(ImportError)` wrapper is unnecessary. Fine to just `import sqlite3` 
at module level alongside the psycopg2/psycopg3 imports, or inline without the 
guard.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to