This is an automated email from the ASF dual-hosted git repository.
kaxil pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a77dcb6cc96 Allow DESCRIBE/SHOW in common.ai SQLToolset read-only queries (#68102)
a77dcb6cc96 is described below
commit a77dcb6cc96a7cefba70a568f65720a87cc40e35
Author: Kaxil Naik <[email protected]>
AuthorDate: Sat Jun 6 02:50:35 2026 +0100
Allow DESCRIBE/SHOW in common.ai SQLToolset read-only queries (#68102)
In read-only mode the SQLToolset query tool only accepted SELECT-family
statements, so an agent that opened with DESCRIBE (a common first move to
learn a table's columns) hard-failed with SQLSafetyError. That made agent
runs nondeterministic: a run composing SELECTs directly succeeded while one
starting with DESCRIBE failed outright.
The query and check_query tools now also accept read-only metadata statements
(DESCRIBE/DESC and SHOW) via an opt-in allow_read_only_metadata flag on
validate_sql(). The toolset passes the connection's dialect through, so SHOW is
recognized on dialects whose sqlglot parser supports it (Snowflake, MySQL);
without a supporting dialect it parses as a generic command and stays blocked.
Data-modifying statements stay blocked, including ones wrapped behind
DESCRIBE/EXPLAIN (e.g. EXPLAIN DELETE, DESCRIBE DROP TABLE): the deep scan now
also rejects DDL nodes that became reachable through the metadata allowlist.
The SQLAlchemy-to-sqlglot dialect mapping is consolidated into a shared
resolve_sqlglot_dialect() helper (reused by LLMSQLQueryOperator) that returns
None for unknown dialects so a misdetected dialect never breaks validation.
---
providers/common/ai/docs/toolsets.rst | 21 ++++-
.../providers/common/ai/operators/llm_sql.py | 11 +--
.../airflow/providers/common/ai/toolsets/sql.py | 28 +++++-
.../providers/common/ai/utils/sql_validation.py | 75 +++++++++++++--
.../ai/tests/unit/common/ai/toolsets/test_sql.py | 62 +++++++++++++
.../unit/common/ai/utils/test_sql_validation.py | 101 ++++++++++++++++++++-
6 files changed, 276 insertions(+), 22 deletions(-)
diff --git a/providers/common/ai/docs/toolsets.rst
b/providers/common/ai/docs/toolsets.rst
index 617c63520b1..b5e868abea2 100644
--- a/providers/common/ai/docs/toolsets.rst
+++ b/providers/common/ai/docs/toolsets.rst
@@ -146,6 +146,20 @@ Curated toolset wrapping
The ``DbApiHook`` is resolved lazily from ``db_conn_id`` on first tool call
via ``BaseHook.get_connection(conn_id).get_hook()``.
+In read-only mode (``allow_writes=False``, the default) the ``query`` tool also
+accepts read-only metadata statements -- ``DESCRIBE``/``DESC`` and ``SHOW`` --
+in addition to SELECT-family queries. Agents commonly open with ``DESCRIBE`` to
+learn a table's columns, so permitting it keeps runs deterministic instead of
+hard-failing on schema discovery. The toolset passes the connection's dialect to
+the validator, so ``SHOW`` is accepted for dialects whose SQL parser recognizes
+it (Snowflake, MySQL, etc.); elsewhere -- for example PostgreSQL, or when the
+dialect cannot be determined -- ``SHOW`` stays rejected. Data-modifying
+statements remain blocked -- including ones hidden behind ``DESCRIBE``/``EXPLAIN``
+(e.g. ``EXPLAIN DELETE ...``, ``DESCRIBE DROP TABLE ...``), which the validator
+rejects by scanning the parsed statement for write operations. Like ``SELECT``,
+metadata statements are not scoped by ``allowed_tables`` (see
+:ref:`allowed-tables-limitation`) -- an agent can ``DESCRIBE`` a table outside the
+list, so rely on database permissions to restrict access.
+
Multi-schema warehouses
^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -180,7 +194,8 @@ Parameters
- ``schema``: Default schema/namespace for unqualified table listing and
introspection. Schema-qualified ``allowed_tables`` entries override it per
table.
- ``allow_writes``: Allow data-modifying SQL (INSERT, UPDATE, DELETE, etc.).
- Default ``False`` — only SELECT-family statements are permitted.
+ Default ``False`` -- only SELECT-family and read-only metadata
+ (``DESCRIBE``/``SHOW``) statements are permitted.
- ``max_rows``: Maximum rows returned from the ``query`` tool. Default ``50``.
``DataFusionToolset``
@@ -534,7 +549,9 @@ No single layer is sufficient — they work together.
- Does not restrict what arguments the agent passes to allowed methods.
* - **SQLToolset: read-only by default**
- ``allow_writes=False`` (default) validates every SQL query through
- ``validate_sql()`` and rejects INSERT, UPDATE, DELETE, DROP, etc.
+ ``validate_sql()``: SELECT-family and read-only metadata
+ (``DESCRIBE``/``SHOW``) statements pass; INSERT, UPDATE, DELETE, DROP,
+ and writes hidden behind ``EXPLAIN`` are rejected.
- Does not prevent the agent from reading sensitive data that the
database user has SELECT access to.
* - **DataFusionToolset: read-only by default**
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py
b/providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py
index 370c819fa83..7342be2b7e3 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py
@@ -25,6 +25,7 @@ from typing import TYPE_CHECKING, Any
try:
from airflow.providers.common.ai.utils.sql_validation import (
DEFAULT_ALLOWED_TYPES,
+ resolve_sqlglot_dialect,
validate_sql as _validate_sql,
)
from airflow.providers.common.sql.datafusion.engine import DataFusionEngine
@@ -44,12 +45,6 @@ if TYPE_CHECKING:
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.sdk import Context
-# SQLAlchemy dialect_name → sqlglot dialect mapping for names that differ.
-_SQLALCHEMY_TO_SQLGLOT_DIALECT: dict[str, str] = {
- "postgresql": "postgres",
- "mssql": "tsql",
-}
-
class LLMSQLQueryOperator(LLMOperator):
"""
@@ -257,6 +252,4 @@ class LLMSQLQueryOperator(LLMOperator):
raw = self.dialect
if not raw and self.db_hook and hasattr(self.db_hook, "dialect_name"):
raw = self.db_hook.dialect_name
- if raw:
- return _SQLALCHEMY_TO_SQLGLOT_DIALECT.get(raw, raw)
- return None
+ return resolve_sqlglot_dialect(raw)
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
b/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
index fca07177597..ee3128705a1 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
@@ -24,7 +24,10 @@ from contextlib import suppress
from typing import TYPE_CHECKING, Any
try:
- from airflow.providers.common.ai.utils.sql_validation import validate_sql
as _validate_sql
+ from airflow.providers.common.ai.utils.sql_validation import (
+ resolve_sqlglot_dialect,
+ validate_sql as _validate_sql,
+ )
from airflow.providers.common.sql.hooks.sql import DbApiHook
except ImportError as e:
from airflow.providers.common.compat.sdk import
AirflowOptionalProviderFeatureException
@@ -297,11 +300,23 @@ class SQLToolset(AbstractToolset[Any]):
columns = hook.get_table_schema(table, schema=schema)
return json.dumps(columns)
+ def _dialect_for_validation(self) -> str | None:
+ """
+ Resolve the hook's sqlglot dialect so DESCRIBE/SHOW validate correctly.
+
+ Returns ``None`` (dialect-agnostic parsing) when the hook has no
+ ``dialect_name`` attribute or its value is not a recognized dialect.
+ """
+ hook = self._get_db_hook()
+ return resolve_sqlglot_dialect(getattr(hook, "dialect_name", None))
+
def _query(self, sql: str) -> str:
+ hook = self._get_db_hook()
if not self._allow_writes:
- _validate_sql(sql)
+ # allow_read_only_metadata lets agents inspect schemas with
DESCRIBE/SHOW
+ # (a common first move) instead of hard-failing; the deep scan
still
+ # rejects any data-modifying statement, including EXPLAIN <write>.
+ _validate_sql(
+ sql,
+ dialect=self._dialect_for_validation(),
+ allow_read_only_metadata=True,
+ )
- hook = self._get_db_hook()
try:
rows = hook.get_records(sql)
except Exception as e:
@@ -347,8 +362,13 @@ class SQLToolset(AbstractToolset[Any]):
return False
def _check_query(self, sql: str) -> str:
+ """
+ Validate ``sql`` against the read-only rules without executing it.
+
+ DESCRIBE/SHOW metadata statements are accepted. Returns a JSON string:
+ ``{"valid": true}`` on success, or ``{"valid": false, "error": "..."}``
+ carrying the validation error message.
+ """
+ # Resolve the dialect best-effort: if the connection can't be reached we
+ # still syntax-check dialect-agnostically rather than reporting invalid.
+ dialect: str | None = None
+ with suppress(Exception):
+ dialect = self._dialect_for_validation()
try:
- _validate_sql(sql)
+ _validate_sql(sql, dialect=dialect, allow_read_only_metadata=True)
return json.dumps({"valid": True})
except Exception as e:
return json.dumps({"valid": False, "error": str(e)})
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/utils/sql_validation.py
b/providers/common/ai/src/airflow/providers/common/ai/utils/sql_validation.py
index 3ab87516ff6..a00b4dc11e6 100644
---
a/providers/common/ai/src/airflow/providers/common/ai/utils/sql_validation.py
+++
b/providers/common/ai/src/airflow/providers/common/ai/utils/sql_validation.py
@@ -26,8 +26,36 @@ from __future__ import annotations
import sqlglot
from sqlglot import exp
+from sqlglot.dialects import Dialects
from sqlglot.errors import ErrorLevel
+# Dialect names sqlglot recognizes. Used to drop unknown dialect names so a bad
+# value never breaks parsing (sqlglot raises on an unknown dialect).
+_KNOWN_SQLGLOT_DIALECTS: frozenset[str] = frozenset(d.value for d in Dialects)
+
+# SQLAlchemy ``dialect_name`` → sqlglot dialect mapping for names that differ.
+_SQLALCHEMY_TO_SQLGLOT_DIALECT: dict[str, str] = {
+ "postgresql": "postgres",
+ "mssql": "tsql",
+}
+
+
+def resolve_sqlglot_dialect(dialect_name: str | None) -> str | None:
+ """
+ Normalize a SQLAlchemy dialect name to a sqlglot dialect.
+
+ The name is matched case-insensitively, ignoring surrounding whitespace. A
+ sqlglot settings suffix (``"<dialect>, key = value, ..."``, e.g.
+ ``"mysql, normalization_strategy = case_sensitive"``) is appended unchanged
+ after the normalized name, so an explicitly configured dialect keeps its
+ settings; the settings themselves are not validated here.
+
+ Returns ``None`` (dialect-agnostic parsing) for empty, non-string, or
+ unknown dialect names, so a misdetected dialect never breaks SQL validation.
+
+ :param dialect_name: A SQLAlchemy ``dialect_name`` (e.g. ``"postgresql"``),
+ optionally followed by sqlglot dialect settings.
+ :return: The matching sqlglot dialect (e.g. ``"postgres"``), or ``None``.
+ """
+ if not isinstance(dialect_name, str):
+ return None
+ # Only the text before the first comma names the dialect; anything after it
+ # is sqlglot dialect settings and is forwarded untouched.
+ name, sep, settings = dialect_name.partition(",")
+ name = name.strip().lower()
+ mapped = _SQLALCHEMY_TO_SQLGLOT_DIALECT.get(name, name)
+ # An empty name (e.g. "" or ", key=value") means no usable dialect.
+ if not mapped or mapped not in _KNOWN_SQLGLOT_DIALECTS:
+ return None
+ return mapped + sep + settings
+
+
# Allowlist: only these top-level statement types pass validation by default.
# - Select: plain queries and CTE-wrapped queries (WITH ... AS ... SELECT is
parsed
# as Select with a `with` clause property — still a Select node at the top
level)
@@ -39,9 +67,21 @@ DEFAULT_ALLOWED_TYPES: tuple[type[exp.Expr], ...] = (
exp.Except,
)
+# Read-only metadata statements that introspect the schema without touching data:
+# - Describe: DESCRIBE / DESC <table> (and EXPLAIN on some dialects, e.g. mysql)
+# - Show: SHOW TABLES / SHOW COLUMNS / SHOW DATABASES, etc.
+# Opt-in via ``allow_read_only_metadata=True``. SHOW only parses to ``exp.Show``
+# when a dialect that supports it is passed (e.g. snowflake, mysql); without a
+# dialect sqlglot falls back to ``exp.Command``, which stays blocked.
+# Because DESCRIBE/EXPLAIN can wrap another statement (DESCRIBE DROP TABLE ...),
+# validate_sql keeps the data-modifying deep scan on when these are allowed.
+READ_ONLY_METADATA_TYPES: tuple[type[exp.Expr], ...] = (
+ exp.Describe,
+ exp.Show,
+)
+
# Denylist: expression types that mutate data or schema when found anywhere in
the AST.
# This catches data-modifying CTEs (e.g. WITH del AS (DELETE …) SELECT …),
-# SELECT INTO, and other constructs that bypass top-level type checks.
+# SELECT INTO, DDL or DML wrapped behind DESCRIBE/EXPLAIN (e.g. DESCRIBE DROP
TABLE …),
+# and other constructs that bypass top-level type checks.
# Note: exp.Command is sqlglot's fallback for any syntax it doesn't recognize.
# Including it makes the denylist fail-closed (safer), but may block legitimate
# vendor-specific SQL that sqlglot can't parse. Callers who need such syntax
can
@@ -53,6 +93,11 @@ _DATA_MODIFYING_NODES: tuple[type[exp.Expr], ...] = (
exp.Merge,
exp.Into,
exp.Command,
+ # DDL — newly reachable through the DESCRIBE/SHOW allowlist, so deny it
here too.
+ exp.Create,
+ exp.Drop,
+ exp.Alter,
+ exp.TruncateTable,
)
@@ -66,6 +111,7 @@ def validate_sql(
allowed_types: tuple[type[exp.Expr], ...] | None = None,
dialect: str | None = None,
allow_multiple_statements: bool = False,
+ allow_read_only_metadata: bool = False,
) -> list[exp.Expr]:
"""
Parse SQL and verify all statements are in the allowed types list.
@@ -78,10 +124,16 @@ def validate_sql(
:param sql: SQL string to validate.
:param allowed_types: Tuple of sqlglot expression types to permit.
- Defaults to ``(Select, Union, Intersect, Except)``.
+ Defaults to ``(Select, Union, Intersect, Except)``. When supplied, the
+ caller takes full control of the allow-list and
``allow_read_only_metadata``
+ is ignored.
:param dialect: SQL dialect for parsing (``postgres``, ``mysql``, etc.).
:param allow_multiple_statements: Whether to allow multiple
semicolon-separated
statements. Default ``False``.
+ :param allow_read_only_metadata: Also permit read-only metadata statements
+ (``DESCRIBE``/``SHOW``) on top of the default read-only allow-list.
Ignored
+ when ``allowed_types`` is supplied. Note ``SHOW`` only parses to a
metadata
+ statement when a ``dialect`` that supports it is given. Default
``False``.
:return: List of parsed sqlglot Expression objects.
:raises SQLSafetyError: If the SQL is empty, contains disallowed statement
types,
or has multiple statements when not permitted.
@@ -89,7 +141,18 @@ def validate_sql(
if not sql or not sql.strip():
raise SQLSafetyError("Empty SQL input.")
- types = allowed_types or DEFAULT_ALLOWED_TYPES
+ # A caller-supplied ``allowed_types`` is an explicit opt-out of the curated
+ # read-only defaults (and the data-modifying deep scan). Otherwise we use
the
+ # read-only defaults, optionally widened with metadata statements, and keep
+ # the deep scan on.
+ if allowed_types is None:
+ types: tuple[type[exp.Expr], ...] = DEFAULT_ALLOWED_TYPES
+ if allow_read_only_metadata:
+ types = types + READ_ONLY_METADATA_TYPES
+ run_data_modifying_scan = True
+ else:
+ types = allowed_types
+ run_data_modifying_scan = types == DEFAULT_ALLOWED_TYPES
try:
statements = sqlglot.parse(sql, dialect=dialect,
error_level=ErrorLevel.RAISE)
@@ -114,10 +177,10 @@ def validate_sql(
)
# Deep scan: reject data-modifying nodes hidden inside otherwise-allowed
statements
- # (e.g. data-modifying CTEs, SELECT INTO). Only applies when using the
default
- # read-only allowlist — callers who provide custom allowed_types have
explicitly
+ # (e.g. data-modifying CTEs, SELECT INTO, EXPLAIN <write>). Runs for the
curated
+ # read-only allow-list — callers who provide custom allowed_types have
explicitly
# opted into non-read-only operations.
- if types is DEFAULT_ALLOWED_TYPES:
+ if run_data_modifying_scan:
_check_for_data_modifying_nodes(parsed)
return parsed
diff --git a/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
b/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
index c1aae15aad5..5e425597a32 100644
--- a/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
+++ b/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
@@ -494,3 +494,65 @@ class TestSQLToolsetMultiSchema:
result = json.loads(asyncio.run(ts.call_tool("list_tables", {},
ctx=MagicMock(), tool=MagicMock())))
assert result == ["public.users"]
+
+
+class TestSQLToolsetMetadataStatements:
+ """Read-only metadata statements (DESCRIBE/SHOW) flow through the query tool."""
+
+ def test_describe_allowed_through_query(self):
+ """DESCRIBE is read-only metadata and should not be rejected as unsafe."""
+ ts = SQLToolset("pg_default")
+ ts._hook = _make_mock_db_hook(
+ records=[("id", "INTEGER"), ("name", "VARCHAR")],
+ last_description=[("column_name",), ("data_type",)],
+ )
+
+ result = asyncio.run(
+ ts.call_tool("query", {"sql": "DESCRIBE TABLE users"}, ctx=MagicMock(), tool=MagicMock())
+ )
+ data = json.loads(result)
+ assert "rows" in data
+ ts._hook.get_records.assert_called_once_with("DESCRIBE TABLE users")
+
+ def test_show_allowed_with_snowflake_dialect(self):
+ """SHOW parses to a metadata statement once the hook's dialect is passed through."""
+ ts = SQLToolset("sf_default")
+ ts._hook = _make_mock_db_hook(records=[("USERS",)], last_description=[("name",)])
+ ts._hook.dialect_name = "snowflake"
+
+ result = asyncio.run(ts.call_tool("query", {"sql": "SHOW TABLES"}, ctx=MagicMock(), tool=MagicMock()))
+ data = json.loads(result)
+ assert "rows" in data
+ ts._hook.get_records.assert_called_once_with("SHOW TABLES")
+
+ @pytest.mark.parametrize(
+ "sql",
+ # sqlglot's Postgres dialect parses SHOW as a generic Command (blocked);
+ # DELETE is a write.
+ ["SHOW TABLES", "DELETE FROM users"],
+ ids=["show_without_dialect_support", "write"],
+ )
+ def test_query_blocks_disallowed_statements(self, sql):
+ """Statements outside the read-only allow-list raise SQLSafetyError."""
+ ts = SQLToolset("pg_default")
+ ts._hook = _make_mock_db_hook()
+ ts._hook.dialect_name = "postgresql"
+
+ with pytest.raises(SQLSafetyError, match="not allowed"):
+ asyncio.run(ts.call_tool("query", {"sql": sql}, ctx=MagicMock(), tool=MagicMock()))
+
+ def test_check_query_accepts_describe(self):
+ """check_query reports DESCRIBE as valid."""
+ ts = SQLToolset("pg_default")
+ ts._hook = _make_mock_db_hook()
+
+ result = asyncio.run(
+ ts.call_tool("check_query", {"sql": "DESCRIBE TABLE users"}, ctx=MagicMock(), tool=MagicMock())
+ )
+ assert json.loads(result)["valid"] is True
+
+ def test_check_query_handles_unresolvable_connection(self):
+ """check_query stays usable (dialect-agnostic) when the connection can't be resolved."""
+ ts = SQLToolset("missing_conn")
+ with patch.object(ts, "_get_db_hook", side_effect=RuntimeError("no such connection")):
+ result = asyncio.run(
+ ts.call_tool("check_query", {"sql": "SELECT 1"}, ctx=MagicMock(), tool=MagicMock())
+ )
+ assert json.loads(result)["valid"] is True
diff --git
a/providers/common/ai/tests/unit/common/ai/utils/test_sql_validation.py
b/providers/common/ai/tests/unit/common/ai/utils/test_sql_validation.py
index fe27379aebe..9ca6604ba5e 100644
--- a/providers/common/ai/tests/unit/common/ai/utils/test_sql_validation.py
+++ b/providers/common/ai/tests/unit/common/ai/utils/test_sql_validation.py
@@ -19,7 +19,11 @@ from __future__ import annotations
import pytest
from sqlglot import exp
-from airflow.providers.common.ai.utils.sql_validation import SQLSafetyError,
validate_sql
+from airflow.providers.common.ai.utils.sql_validation import (
+ SQLSafetyError,
+ resolve_sqlglot_dialect,
+ validate_sql,
+)
class TestValidateSQLAllowed:
@@ -213,3 +217,98 @@ class TestDataModifyingNodeDetection:
allowed_types=DEFAULT_ALLOWED_TYPES,
dialect="postgres",
)
+
+
+class TestReadOnlyMetadata:
+ """Read-only metadata statements (DESCRIBE/SHOW) with ``allow_read_only_metadata``."""
+
+ @pytest.mark.parametrize(
+ ("sql", "kwargs"),
+ [
+ ("DESCRIBE TABLE users", {}),
+ ("SHOW TABLES", {"dialect": "snowflake"}),
+ ],
+ ids=["describe", "show"],
+ )
+ def test_metadata_blocked_without_flag(self, sql, kwargs):
+ """Without the flag, metadata statements are rejected by the default allow-list."""
+ with pytest.raises(SQLSafetyError, match="not allowed"):
+ validate_sql(sql, **kwargs)
+
+ @pytest.mark.parametrize(
+ ("sql", "dialect", "expected_type"),
+ [
+ # DESCRIBE/DESC parse to exp.Describe even without a dialect.
+ ("DESCRIBE TABLE users", None, exp.Describe),
+ ("DESC users", None, exp.Describe),
+ # SHOW only parses to exp.Show when a supporting dialect is passed.
+ ("SHOW TABLES", "snowflake", exp.Show),
+ ("SHOW COLUMNS IN users", "snowflake", exp.Show),
+ ],
+ )
+ def test_metadata_allowed_with_flag(self, sql, dialect, expected_type):
+ """With the flag, each statement validates and parses to the expected node type."""
+ result = validate_sql(sql, dialect=dialect, allow_read_only_metadata=True)
+ assert len(result) == 1
+ assert isinstance(result[0], expected_type)
+
+ def test_show_blocked_without_supporting_dialect(self):
+ """Without a dialect that supports SHOW, sqlglot falls back to exp.Command, still blocked."""
+ with pytest.raises(SQLSafetyError, match="Command.*not allowed"):
+ validate_sql("SHOW TABLES", allow_read_only_metadata=True)
+
+ def test_explain_wrapped_write_still_blocked(self):
+ """EXPLAIN <write> parses to exp.Describe but the deep scan rejects the inner write."""
+ with pytest.raises(SQLSafetyError, match="Data-modifying operation 'Delete'"):
+ validate_sql("EXPLAIN DELETE FROM users", dialect="mysql", allow_read_only_metadata=True)
+
+ @pytest.mark.parametrize(
+ ("sql", "node"),
+ [
+ ("DESCRIBE CREATE TABLE t (a int)", "Create"),
+ ("DESCRIBE DROP TABLE users", "Drop"),
+ ("DESCRIBE TRUNCATE TABLE users", "TruncateTable"),
+ ("DESCRIBE DELETE FROM users", "Delete"),
+ ],
+ )
+ def test_describe_wrapped_ddl_or_dml_blocked(self, sql, node):
+ """DESCRIBE <DDL/DML> parses to exp.Describe; the deep scan rejects the inner write."""
+ with pytest.raises(SQLSafetyError, match=f"Data-modifying operation '{node}'"):
+ validate_sql(sql, allow_read_only_metadata=True)
+
+ def test_metadata_flag_ignored_when_custom_types_supplied(self):
+ """When the caller supplies allowed_types it controls the allow-list; the flag is ignored."""
+ with pytest.raises(SQLSafetyError, match="Describe.*not allowed"):
+ validate_sql(
+ "DESCRIBE TABLE users",
+ allowed_types=(exp.Select,),
+ allow_read_only_metadata=True,
+ )
+
+ def test_select_still_allowed_with_flag(self):
+ """Enabling the flag keeps the default SELECT-family types allowed."""
+ result = validate_sql("SELECT 1", allow_read_only_metadata=True)
+ assert isinstance(result[0], exp.Select)
+
+ def test_writes_still_blocked_with_flag(self):
+ """Enabling the flag does not admit top-level writes."""
+ with pytest.raises(SQLSafetyError, match="Delete.*not allowed"):
+ validate_sql("DELETE FROM users WHERE id = 1", allow_read_only_metadata=True)
+
+
+class TestResolveSqlglotDialect:
+ """``resolve_sqlglot_dialect`` normalizes/validates SQLAlchemy dialect names."""
+
+ @pytest.mark.parametrize(
+ ("dialect_name", "expected"),
+ [
+ # SQLAlchemy names that differ from sqlglot's are mapped.
+ ("postgresql", "postgres"),
+ ("mssql", "tsql"),
+ # Names sqlglot already knows pass through unchanged.
+ ("mysql", "mysql"),
+ ("snowflake", "snowflake"),
+ ("sqlite", "sqlite"),
+ # Empty, non-string, or unrecognized values fall back to None
+ # (dialect-agnostic parsing) instead of breaking validation.
+ (None, None),
+ ("", None),
+ ("default", None),
+ ("not_a_real_dialect", None),
+ (123, None),
+ ],
+ )
+ def test_resolution(self, dialect_name, expected):
+ """Each input resolves to the expected sqlglot dialect (or None)."""
+ assert resolve_sqlglot_dialect(dialect_name) == expected