kaxil commented on code in PR #63358:
URL: https://github.com/apache/airflow/pull/63358#discussion_r2921568006
##########
providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py:
##########
@@ -165,6 +167,68 @@ def test_allows_writes_when_enabled(self):
data = json.loads(result)
assert "rows" in data
+ def test_raises_model_retry_when_query_fails_with_retryable_error(self):
+ """When the query fails with a retryable error, raise ModelRetry so
the model retries."""
+ ts = SQLToolset("pg_default")
+ ts._hook = _make_mock_db_hook()
+ ts._hook.conn_type = "sqlite"
+ ts._hook.get_records.side_effect = sqlite3.OperationalError("no such
column: nonexistent")
+
+ with pytest.raises(ModelRetry) as exc_info:
+ asyncio.run(
+ ts.call_tool(
+ "query",
+ {"sql": "SELECT id, nonexistent FROM users"},
+ ctx=MagicMock(),
+ tool=MagicMock(),
+ )
+ )
+ assert "nonexistent" in exc_info.value.message
+ assert "get_schema" in exc_info.value.message
+ assert "list_tables" in exc_info.value.message
+
+ def test_model_retry_message_includes_schema_hint(self):
+ """ModelRetry message tells the model to use get_schema and
list_tables for more details."""
+ ts = SQLToolset("pg_default")
+ ts._hook = _make_mock_db_hook()
+ ts._hook.conn_type = "sqlite"
+ ts._hook.get_records.side_effect = sqlite3.OperationalError("no such
table: missing_table")
+
+ with pytest.raises(ModelRetry) as exc_info:
+ asyncio.run(
+ ts.call_tool("query", {"sql": "SELECT foo FROM x"},
ctx=MagicMock(), tool=MagicMock())
+ )
+ assert "get_schema" in exc_info.value.message
+ assert "list_tables" in exc_info.value.message
+
+ def test_non_retryable_error_is_propagated(self):
+ ts = SQLToolset("pg_default")
+ ts._hook = _make_mock_db_hook()
+ ts._hook.conn_type = "sqlite"
+ ts._hook.get_records.side_effect = sqlite3.OperationalError("database
is locked")
+
+ with pytest.raises(sqlite3.OperationalError, match="database is
locked"):
+ asyncio.run(ts.call_tool("query", {"sql": "SELECT 1"},
ctx=MagicMock(), tool=MagicMock()))
+
+ def test_error_propagates_when_hook_conn_type_not_supported(self):
+ ts = SQLToolset("pg_default")
+ ts._hook = _make_mock_db_hook()
+ ts._hook.conn_type = "mysql"
+ ts._hook.get_records.side_effect = RuntimeError("unexpected db error")
+
+ with pytest.raises(RuntimeError, match="unexpected db error"):
+ asyncio.run(ts.call_tool("query", {"sql": "SELECT 1"},
ctx=MagicMock(), tool=MagicMock()))
+
+ def test_error_propagates_when_hook_has_no_conn_type(self):
+ ts = SQLToolset("pg_default")
+ mock_hook = MagicMock(spec=["get_records", "last_description"])
+ mock_hook.get_records.side_effect = RuntimeError("hook error")
+ type(mock_hook).last_description = PropertyMock(return_value=[])
+ ts._hook = mock_hook
+
+ with pytest.raises(RuntimeError, match="hook error"):
+ asyncio.run(ts.call_tool("query", {"sql": "SELECT 1"},
ctx=MagicMock(), tool=MagicMock()))
+
class TestSQLToolsetCheckQuery:
def test_valid_select(self):
Review Comment:
All five tests use either `sqlite3.OperationalError` or `RuntimeError`. The
SQLAlchemy `ProgrammingError` path and the psycopg2
`UndefinedColumn`/`UndefinedTable` path are untested — and those are the most
common paths in production (Airflow hooks use SQLAlchemy by default).
Worth adding at least:
1. A test with a mock `sqlalchemy.exc.ProgrammingError` (with `.orig` set to
a `psycopg2.errors.UndefinedColumn`) to exercise the SQLAlchemy wrapping path.
2. A test with a non-retryable `sqlalchemy.exc.ProgrammingError` (e.g.,
`.orig` is `InsufficientPrivilege`) to confirm it's NOT retried once
`_is_retryable_query_error` unwraps `.orig` before classifying the error.
##########
providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py:
##########
@@ -223,6 +257,23 @@ def _query(self, sql: str) -> str:
output["max_rows"] = self._max_rows
return json.dumps(output, default=str)
Review Comment:
The SQLAlchemy `ProgrammingError` check runs first and short-circuits before
the per-DB narrowing. In practice, `DbApiHook.get_records()` uses SQLAlchemy,
so all database errors arrive wrapped — a `psycopg2.errors.UndefinedColumn`
becomes `sqlalchemy.exc.ProgrammingError`, which is caught here. But so does
`psycopg2.errors.InsufficientPrivilege` (also SQLSTATE class 42, which maps to
`ProgrammingError`), which can't be fixed by rewriting SQL.
The careful psycopg2 `UndefinedColumn`/`UndefinedTable` narrowing never
fires when SQLAlchemy is in play, which is the common case.
Fix: unwrap via `orig` first, then check DB-specific lists:
```python
@staticmethod
def _is_retryable_query_error(hook: DbApiHook, error: Exception) -> bool:
    """
    Return True when ``error`` is a query failure worth retrying with rewritten SQL.

    The DB-specific checks run against the underlying driver error (taken from
    ``error.orig`` when that attribute exists), so the narrow postgres/sqlite
    rules still apply when the exception arrives wrapped by SQLAlchemy.

    :param hook: Hook the query ran on; its ``conn_type`` selects the rule set.
    :param error: Exception raised while executing the query.
    :return: True if the error matches the retryable set for the hook's DB.
    """
    # Unwrap SQLAlchemy wrapper to check the original DB-specific error
    check_error = getattr(error, "orig", error)
    conn_type = getattr(hook, "conn_type", None)

    if conn_type == "postgres":
        # An empty exception tuple (driver not importable) means nothing is retryable.
        return bool(_POSTGRES_RETRYABLE_EXCEPTIONS) and isinstance(
            check_error, _POSTGRES_RETRYABLE_EXCEPTIONS
        )

    if conn_type == "sqlite":
        import sqlite3

        # sqlite3 reports many unrelated failures as OperationalError, so the
        # message text is what separates schema mistakes from the rest.
        if isinstance(check_error, sqlite3.OperationalError):
            message = str(check_error).lower()
            return "no such column" in message or "no such table" in message
        return False

    # Fallback for unsupported DBs: trust SQLAlchemy's classification
    if _SQLALCHEMY_RETRYABLE_EXCEPTIONS and isinstance(error, _SQLALCHEMY_RETRYABLE_EXCEPTIONS):
        return True
    return False
```
This way postgres/sqlite get the narrow checks even through SQLAlchemy, and
unknown DBs still get the broad fallback.
##########
providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py:
##########
@@ -223,6 +257,23 @@ def _query(self, sql: str) -> str:
output["max_rows"] = self._max_rows
return json.dumps(output, default=str)
+ @staticmethod
+ def _is_retryable_query_error(hook: DbApiHook, error: Exception) -> bool:
+ if _SQLALCHEMY_RETRYABLE_EXCEPTIONS and isinstance(error,
_SQLALCHEMY_RETRYABLE_EXCEPTIONS):
+ return True
+ conn_type = getattr(hook, "conn_type", None)
+ if conn_type == "postgres":
+ return bool(_POSTGRES_RETRYABLE_EXCEPTIONS) and isinstance(error,
_POSTGRES_RETRYABLE_EXCEPTIONS)
+ if conn_type == "sqlite":
+ with suppress(ImportError):
Review Comment:
Nit: `sqlite3` is a stdlib module — it's always available, so the
`suppress(ImportError)` wrapper is unnecessary. Fine to just `import sqlite3`
at module level alongside the psycopg2/psycopg3 imports, or inline without the
guard.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]