This is an automated email from the ASF dual-hosted git repository.
kaxil pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d88504564c5 Support multi-schema introspection in common.ai SQLToolset
(#68103)
d88504564c5 is described below
commit d88504564c506e7edfbd64cb0589f6ac9d6263c5
Author: Kaxil Naik <[email protected]>
AuthorDate: Sat Jun 6 01:17:53 2026 +0100
Support multi-schema introspection in common.ai SQLToolset (#68103)
SQLToolset's metadata tools (list_tables, get_schema) operated against a
single schema, so an agent over a multi-schema warehouse (common on
Snowflake) could not discover tables across schemas. With no schema set and
schema-qualified tables, list_tables introspected a literal "None" schema
(SHOW TABLES IN SCHEMA "DB"."None") and failed outright.
allowed_tables entries may now be schema-qualified ("SCHEMA.TABLE").
list_tables introspects each referenced schema and returns the matching
tables fully qualified, and get_schema routes each qualified name to its own
schema. Unqualified entries and the allow-all case keep the previous
single-schema behaviour using the default schema. Table-name matching is
case-insensitive, because databases reflect identifiers in their own case
(Snowflake reflects unquoted names lowercased) and a byte-exact match would
silently return nothing.
Results are de-duplicated by (schema, table) so a table reachable both
qualified and via the default schema is listed once.
---
providers/common/ai/docs/toolsets.rst | 29 ++++-
.../airflow/providers/common/ai/toolsets/sql.py | 84 +++++++++++--
.../ai/tests/unit/common/ai/toolsets/test_sql.py | 132 ++++++++++++++++++++-
3 files changed, 230 insertions(+), 15 deletions(-)
diff --git a/providers/common/ai/docs/toolsets.rst
b/providers/common/ai/docs/toolsets.rst
index a9f454fbc89..617c63520b1 100644
--- a/providers/common/ai/docs/toolsets.rst
+++ b/providers/common/ai/docs/toolsets.rst
@@ -146,14 +146,39 @@ Curated toolset wrapping
The ``DbApiHook`` is resolved lazily from ``db_conn_id`` on first tool call
via ``BaseHook.get_connection(conn_id).get_hook()``.
+Multi-schema warehouses
+^^^^^^^^^^^^^^^^^^^^^^^^^
+
+When an agent's tables live in several schemas of one database -- common on
+Snowflake -- list them with schema-qualified ``allowed_tables`` entries:
+
+.. code-block:: python
+
+ SQLToolset(
+ db_conn_id="snowflake_hq",
+ allowed_tables=["MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS",
"MODEL_CRM.SF_ASTRO_ORGS"],
+ )
+
+``list_tables`` then introspects each referenced schema and returns the matching
+tables fully qualified (e.g. ``MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS``), and
+``get_schema`` routes each qualified name to its own schema. Without this, a
+single ``schema`` only covers one namespace, and leaving ``schema`` unset made
+introspection query a literal ``"None"`` schema and fail. Unqualified entries
+fall back to ``schema``, and table-name matching is case-insensitive (databases
+reflect identifiers in their own case). For tables in a different *database*, use
+a separate toolset whose connection points at that database.
+
Parameters
^^^^^^^^^^
- ``db_conn_id``: Airflow connection ID for the database.
- ``allowed_tables``: Restrict which tables the agent can discover via
- ``list_tables`` and ``get_schema``. ``None`` (default) exposes all tables.
+ ``list_tables`` and ``get_schema``. ``None`` (default) exposes all tables in
+ ``schema``. Entries may be schema-qualified (``"SCHEMA.TABLE"``) to span
+ multiple schemas; see above. Matching is case-insensitive.
See :ref:`allowed-tables-limitation` for an important caveat.
-- ``schema``: Database schema/namespace for table listing and introspection.
+- ``schema``: Default schema/namespace for unqualified table listing and
+ introspection. Schema-qualified ``allowed_tables`` entries override it per
table.
- ``allow_writes``: Allow data-modifying SQL (INSERT, UPDATE, DELETE, etc.).
Default ``False`` — only SELECT-family statements are permitted.
- ``max_rows``: Maximum rows returned from the ``query`` tool. Default ``50``.
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
b/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
index 0902cff99f2..fca07177597 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
@@ -111,7 +111,13 @@ class SQLToolset(AbstractToolset[Any]):
:param db_conn_id: Airflow connection ID for the database.
:param allowed_tables: Restrict which tables the agent can discover via
- ``list_tables`` and ``get_schema``. ``None`` (default) exposes all
tables.
+ ``list_tables`` and ``get_schema``. ``None`` (default) exposes all
tables
+ in ``schema``. Entries may be schema-qualified (``"SCHEMA.TABLE"``) to
span
+ multiple schemas in one database -- common on warehouses such as
Snowflake.
+ ``list_tables`` then introspects each referenced schema and returns the
+ matching tables fully qualified, and ``get_schema`` routes to the
table's
+ own schema. Unqualified entries use ``schema``. Matching is
+ case-insensitive, since databases reflect identifiers in their own
case.
.. note::
``allowed_tables`` controls metadata visibility only. It does
**not**
@@ -120,7 +126,10 @@ class SQLToolset(AbstractToolset[Any]):
restrictions, use database-level permissions (e.g. a read-only role
with grants limited to specific tables).
- :param schema: Database schema/namespace for table listing and
introspection.
+ :param schema: Default schema/namespace for table listing and
introspection,
+ used for unqualified ``allowed_tables`` entries and unqualified
+ ``get_schema`` calls. Schema-qualified ``allowed_tables`` entries
override
+ it per table.
:param allow_writes: Allow data-modifying SQL (INSERT, UPDATE, DELETE,
etc.).
Default ``False`` — only SELECT-family statements are permitted.
:param max_rows: Maximum number of rows returned from the ``query`` tool.
@@ -138,11 +147,43 @@ class SQLToolset(AbstractToolset[Any]):
) -> None:
self._db_conn_id = db_conn_id
self._allowed_tables: frozenset[str] | None =
frozenset(allowed_tables) if allowed_tables else None
+ # Case-folded view for membership tests: databases reflect identifiers
in
+ # their own case (Snowflake stores unquoted names uppercase but
reflects
+ # them lowercased), so a byte-exact match against the user's entries
would
+ # silently miss. allowed_tables is a visibility hint, not access
control,
+ # so case-insensitive matching is safe.
+ self._allowed_tables_ci: frozenset[str] | None = (
+ frozenset(t.casefold() for t in self._allowed_tables)
+ if self._allowed_tables is not None
+ else None
+ )
self._schema = schema
self._allow_writes = allow_writes
self._max_rows = max_rows
self._hook: DbApiHook | None = None
+ # Derive which schemas to introspect from schema-qualified
allowed_tables.
+ # Qualified entries ("SCHEMA.TABLE") are listed under their own schema
and
+ # returned fully qualified; unqualified entries (and allow-all) use the
+ # default ``schema``.
+ self._qualified_schemas: frozenset[str] = frozenset()
+ self._include_default_schema: bool = True
+ if self._allowed_tables is not None:
+ qualified_schemas: set[str] = set()
+ include_default = False
+ for entry in self._allowed_tables:
+ entry_schema, sep, _ = entry.rpartition(".")
+ if sep:
+ qualified_schemas.add(entry_schema)
+ else:
+ include_default = True
+ self._qualified_schemas = frozenset(qualified_schemas)
+ self._include_default_schema = include_default
+
+ def _is_table_allowed(self, name: str) -> bool:
+ """Case-insensitive membership test against ``allowed_tables``
(allow-all when unset)."""
+ return self._allowed_tables_ci is None or name.casefold() in
self._allowed_tables_ci
+
@property
def id(self) -> str:
return f"sql-{self._db_conn_id}"
@@ -213,18 +254,47 @@ class SQLToolset(AbstractToolset[Any]):
# Tool implementations
# ------------------------------------------------------------------
+ def _split_table_identifier(self, table_name: str) -> tuple[str | None,
str]:
+ """Split ``"SCHEMA.TABLE"`` into ``(schema, table)``; unqualified uses
the default schema."""
+ schema, sep, table = table_name.rpartition(".")
+ if not sep:
+ return self._schema, table_name
+ return schema, table
+
def _list_tables(self) -> str:
hook = self._get_db_hook()
- tables: list[str] = hook.inspector.get_table_names(schema=self._schema)
- if self._allowed_tables is not None:
- tables = [t for t in tables if t in self._allowed_tables]
+ tables: list[str] = []
+ # Dedupe by (schema, table) so a table reachable both qualified and
via the
+ # default schema (e.g. "public.users" and "users" with
schema="public") is
+ # listed once. Case-folded because databases reflect identifiers in
their case.
+ seen: set[tuple[str | None, str]] = set()
+
+ def add(schema: str | None, name: str, display: str) -> None:
+ key = (schema.casefold() if schema else None, name.casefold())
+ if self._is_table_allowed(display) and key not in seen:
+ seen.add(key)
+ tables.append(display)
+
+ # Schemas referenced by qualified allowed_tables entries: introspect
each
+ # and return matching tables fully qualified so they round-trip to
get_schema.
+ for schema in sorted(self._qualified_schemas):
+ for name in hook.inspector.get_table_names(schema=schema):
+ add(schema, name, f"{schema}.{name}")
+
+ # Default schema: used for allow-all and unqualified allowed_tables
entries.
+ # Names stay bare to preserve the single-schema behaviour.
+ if self._include_default_schema:
+ for name in hook.inspector.get_table_names(schema=self._schema):
+ add(self._schema, name, name)
+
return json.dumps(tables)
def _get_schema(self, table_name: str) -> str:
- if self._allowed_tables is not None and table_name not in
self._allowed_tables:
+ if not self._is_table_allowed(table_name):
return json.dumps({"error": f"Table {table_name!r} is not in the
allowed tables list."})
hook = self._get_db_hook()
- columns = hook.get_table_schema(table_name, schema=self._schema)
+ schema, table = self._split_table_identifier(table_name)
+ columns = hook.get_table_schema(table, schema=schema)
return json.dumps(columns)
def _query(self, sql: str) -> str:
diff --git a/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
b/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
index 471b956385d..c1aae15aad5 100644
--- a/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
+++ b/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
@@ -27,6 +27,7 @@ from pydantic_ai.exceptions import ModelRetry
from airflow.providers.common.ai.toolsets.sql import SQLToolset
from airflow.providers.common.ai.utils.sql_validation import SQLSafetyError
+from airflow.providers.common.sql.hooks.sql import DbApiHook
def _make_mock_db_hook(
@@ -36,8 +37,6 @@ def _make_mock_db_hook(
last_description: list[tuple] | None = None,
):
"""Create a mock DbApiHook with sensible defaults."""
- from airflow.providers.common.sql.hooks.sql import DbApiHook
-
mock = MagicMock(spec=DbApiHook)
mock.inspector = MagicMock()
mock.inspector.get_table_names.return_value = table_names or ["users",
"orders"]
@@ -335,8 +334,6 @@ class TestSQLToolsetCheckQuery:
class TestSQLToolsetHookResolution:
@patch("airflow.providers.common.ai.toolsets.sql.BaseHook", autospec=True)
def test_lazy_resolves_db_hook(self, mock_base_hook):
- from airflow.providers.common.sql.hooks.sql import DbApiHook
-
mock_hook = MagicMock(spec=DbApiHook)
mock_conn = MagicMock(spec=["get_hook"])
mock_conn.get_hook.return_value = mock_hook
@@ -361,8 +358,6 @@ class TestSQLToolsetHookResolution:
@patch("airflow.providers.common.ai.toolsets.sql.BaseHook", autospec=True)
def test_caches_hook_after_first_resolution(self, mock_base_hook):
- from airflow.providers.common.sql.hooks.sql import DbApiHook
-
mock_hook = MagicMock(spec=DbApiHook)
mock_conn = MagicMock(spec=["get_hook"])
mock_conn.get_hook.return_value = mock_hook
@@ -374,3 +369,128 @@ class TestSQLToolsetHookResolution:
# Only called once because result is cached.
mock_base_hook.get_connection.assert_called_once()
+
+
+class TestSQLToolsetMultiSchema:
+ """Schema-qualified allowed_tables span multiple schemas in one
database."""
+
+ @staticmethod
+ def _schema_aware_hook(tables_by_schema: dict[str | None, list[str]]):
+ hook = MagicMock(spec=DbApiHook)
+ hook.inspector = MagicMock()
+ hook.inspector.get_table_names.side_effect = lambda schema=None:
tables_by_schema.get(schema, [])
+ hook.get_table_schema.return_value = [{"name": "id", "type":
"INTEGER"}]
+ return hook
+
+ def test_list_tables_spans_multiple_schemas(self):
+ ts = SQLToolset(
+ "sf",
+ allowed_tables=["MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS",
"MODEL_CRM.SF_ASTRO_ORGS"],
+ )
+ ts._hook = self._schema_aware_hook(
+ {
+ "MODEL_ASTRO": ["DEPLOYMENT_IMAGE_DETAILS", "OTHER_TABLE"],
+ "MODEL_CRM": ["SF_ASTRO_ORGS"],
+ }
+ )
+
+ result = json.loads(asyncio.run(ts.call_tool("list_tables", {},
ctx=MagicMock(), tool=MagicMock())))
+ assert result == ["MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS",
"MODEL_CRM.SF_ASTRO_ORGS"]
+
+ def
test_list_tables_never_introspects_none_schema_when_all_qualified(self):
+ """Regression for the 'SHOW TABLES IN SCHEMA "DB"."None"' failure."""
+ ts = SQLToolset("sf", allowed_tables=["MODEL_ASTRO.X", "MODEL_CRM.Y"])
+ ts._hook = self._schema_aware_hook({"MODEL_ASTRO": ["X"], "MODEL_CRM":
["Y"]})
+
+ asyncio.run(ts.call_tool("list_tables", {}, ctx=MagicMock(),
tool=MagicMock()))
+
+ called_schemas = {c.kwargs.get("schema") for c in
ts._hook.inspector.get_table_names.call_args_list}
+ assert called_schemas == {"MODEL_ASTRO", "MODEL_CRM"}
+ assert None not in called_schemas
+
+ def test_list_tables_mixed_qualified_and_default(self):
+ ts = SQLToolset("pg", allowed_tables=["users", "MODEL_ASTRO.X"],
schema="public")
+ ts._hook = self._schema_aware_hook({"public": ["users", "orders"],
"MODEL_ASTRO": ["X", "Z"]})
+
+ result = json.loads(asyncio.run(ts.call_tool("list_tables", {},
ctx=MagicMock(), tool=MagicMock())))
+ # Qualified schemas listed first (sorted), then the default schema.
+ assert result == ["MODEL_ASTRO.X", "users"]
+
+ def test_get_schema_routes_to_qualified_schema(self):
+ ts = SQLToolset("sf",
allowed_tables=["MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS"])
+ ts._hook = self._schema_aware_hook({"MODEL_ASTRO":
["DEPLOYMENT_IMAGE_DETAILS"]})
+
+ result = json.loads(
+ asyncio.run(
+ ts.call_tool(
+ "get_schema",
+ {"table_name": "MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS"},
+ ctx=MagicMock(),
+ tool=MagicMock(),
+ )
+ )
+ )
+ assert result == [{"name": "id", "type": "INTEGER"}]
+
ts._hook.get_table_schema.assert_called_once_with("DEPLOYMENT_IMAGE_DETAILS",
schema="MODEL_ASTRO")
+
+ def test_get_schema_blocks_table_outside_allowed_schema(self):
+ ts = SQLToolset("sf", allowed_tables=["MODEL_ASTRO.X"])
+ ts._hook = self._schema_aware_hook({"MODEL_ASTRO": ["X"]})
+
+ result = json.loads(
+ asyncio.run(
+ ts.call_tool(
+ "get_schema", {"table_name": "SECRETS.PASSWORDS"},
ctx=MagicMock(), tool=MagicMock()
+ )
+ )
+ )
+ assert "error" in result
+ ts._hook.get_table_schema.assert_not_called()
+
+ def test_get_schema_unqualified_uses_default_schema(self):
+ ts = SQLToolset("pg", schema="public")
+ ts._hook = self._schema_aware_hook({"public": ["users"]})
+
+ asyncio.run(ts.call_tool("get_schema", {"table_name": "users"},
ctx=MagicMock(), tool=MagicMock()))
+ ts._hook.get_table_schema.assert_called_once_with("users",
schema="public")
+
+ def test_list_tables_matches_case_insensitively(self):
+ """Snowflake reflects unquoted names lowercased; uppercase
allowed_tables still match."""
+ ts = SQLToolset(
+ "sf",
+ allowed_tables=["MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS",
"MODEL_CRM.SF_ASTRO_ORGS"],
+ )
+ ts._hook = self._schema_aware_hook(
+ {
+ "MODEL_ASTRO": ["deployment_image_details", "other"],
+ "MODEL_CRM": ["sf_astro_orgs"],
+ }
+ )
+
+ result = json.loads(asyncio.run(ts.call_tool("list_tables", {},
ctx=MagicMock(), tool=MagicMock())))
+ assert result == ["MODEL_ASTRO.deployment_image_details",
"MODEL_CRM.sf_astro_orgs"]
+
+ def test_get_schema_matches_case_insensitively(self):
+ ts = SQLToolset("sf",
allowed_tables=["MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS"])
+ ts._hook = self._schema_aware_hook({"MODEL_ASTRO":
["deployment_image_details"]})
+
+ result = json.loads(
+ asyncio.run(
+ ts.call_tool(
+ "get_schema",
+ {"table_name": "MODEL_ASTRO.deployment_image_details"},
+ ctx=MagicMock(),
+ tool=MagicMock(),
+ )
+ )
+ )
+ assert "error" not in result
+
ts._hook.get_table_schema.assert_called_once_with("deployment_image_details",
schema="MODEL_ASTRO")
+
+ def test_list_tables_deduplicates_same_table(self):
+ """A table listed both qualified and unqualified appears once."""
+ ts = SQLToolset("pg", allowed_tables=["public.users", "users"],
schema="public")
+ ts._hook = self._schema_aware_hook({"public": ["users"]})
+
+ result = json.loads(asyncio.run(ts.call_tool("list_tables", {},
ctx=MagicMock(), tool=MagicMock())))
+ assert result == ["public.users"]