gopidesupavan commented on code in PR #68487:
URL: https://github.com/apache/airflow/pull/68487#discussion_r3409195194
##########
providers/common/ai/src/airflow/providers/common/ai/utils/sql_validation.py:
##########
@@ -105,6 +107,239 @@ class SQLSafetyError(Exception):
"""Generated SQL failed safety validation."""
+def parse_sql(
+ sql: str,
+ *,
+ dialect: str | None = None,
+ allow_multiple_statements: bool = False,
+) -> list[exp.Expr]:
+ """
+ Parse SQL into statements, enforcing the empty- and multi-statement guards
only.
+
+ Shared by :func:`validate_sql` (which then applies statement-type checks)
and by
+ callers that need the parsed AST for their own analysis -- e.g.
table-reference
+ extraction for ``allowed_tables`` enforcement -- without the read-only
allow-list.
+
+ :param sql: SQL string to parse.
+ :param dialect: SQL dialect for parsing (``postgres``, ``mysql``, etc.).
+ :param allow_multiple_statements: Whether to allow multiple
semicolon-separated
+ statements. Default ``False`` -- multi-statement input can hide a
dangerous
+ operation after a benign one.
+ :return: List of parsed sqlglot Expression objects (never empty).
+ :raises SQLSafetyError: If the SQL is empty, cannot be parsed, or contains
multiple
+ statements when not permitted.
+ """
+ if not sql or not sql.strip():
+ raise SQLSafetyError("Empty SQL input.")
+
+ try:
+ statements = sqlglot.parse(sql, dialect=dialect,
error_level=ErrorLevel.RAISE)
+ except sqlglot.errors.ParseError as e:
+ raise SQLSafetyError(f"SQL parse error: {e}") from e
+
+ # sqlglot.parse can return [None] for empty input
+ parsed = [s for s in statements if s is not None]
+ if not parsed:
+ raise SQLSafetyError("Empty SQL input.")
+
+ if not allow_multiple_statements and len(parsed) > 1:
+ raise SQLSafetyError(
+ f"Multiple statements detected ({len(parsed)}). Only single
statements are allowed by default."
+ )
+ return parsed
+
+
+class TableScan(NamedTuple):
+ """Result of :func:`collect_table_references`."""
+
+ #: ``(catalog, schema, table)`` for every real base table referenced
anywhere in
+ #: the AST. ``catalog`` and ``schema`` are ``""`` when the reference omits
them.
+ #: In-scope CTE references are excluded. Catalog is reported so the caller
can
+ #: reject cross-database references (``otherdb.public.orders``) that a
+ #: ``schema.table`` allow-list cannot describe.
+ tables: list[tuple[str, str, str]]
+ #: Human-readable descriptions of constructs that cannot be checked
against an
+ #: allow-list and so must be rejected while one is active: table-valued
functions
+ #: (``dblink``), ``TABLE('name')`` row sources, ``SHOW``, dynamic SQL
+ #: (``EXEC``/``Command``), inline comments (a parser-vs-engine
differential), and
+ #: the ``TABLE <name>`` shorthand. Empty when every construct is
verifiable.
+ unverifiable_sources: list[str]
+
+
+_DML_TYPES: tuple[type[exp.Expr], ...] = (exp.Insert, exp.Update, exp.Delete,
exp.Merge)
+
+
+def _same_identifier(a: exp.Identifier, b: exp.Identifier) -> bool:
+ """
+ Compare two identifiers under standard identifier-folding rules.
+
+ Unquoted names fold (case-insensitive); quoted names are case-preserving
and
+ distinct from unquoted ones. Used to decide whether a table reference
names a CTE:
+ being *stricter* here is safe -- a near-miss falls through to the
allow-list check.
+ """
+ aq, bq = bool(a.args.get("quoted")), bool(b.args.get("quoted"))
+ if not aq and not bq:
+ return str(a.this).casefold() == str(b.this).casefold()
+ if aq and bq:
+ return str(a.this) == str(b.this)
+ return False
+
+
+def _enclosing_cte(table: exp.Expr, with_: exp.With) -> exp.CTE | None:
+ """Return the CTE of ``with_`` whose *definition* contains ``table`` (else
``None``)."""
+ node = table.parent
+ while node is not None and node is not with_.parent:
+ if isinstance(node, exp.CTE) and node.parent is with_:
+ return node
+ node = node.parent
+ return None
+
+
+def _is_in_scope_cte(table: exp.Table) -> bool:
+ """
+ Report whether ``table`` is a bare reference resolved by a CTE visible at
its scope.
+
+ Walks the ancestor chain (lexical scope) collecting CTE names from each
enclosing
+ ``WITH``. A CTE defined in a *sibling* or *inner* subquery is not an
ancestor, so a
+ real top-level table is never excluded by an unrelated same-named CTE
+ (``SELECT * FROM secret WHERE id IN (WITH secret AS (...) SELECT ...)``). A
+ non-recursive CTE is not visible inside its own definition, so
+ ``WITH secret AS (SELECT * FROM secret) ...`` still reports the real
``secret``.
+ CTE order matters too: inside one CTE's body only *earlier* siblings are
in scope
+ (forward references need ``RECURSIVE``), so ``WITH a AS (SELECT * FROM
secret),
+ secret AS (...) SELECT * FROM a`` still reports the real ``secret`` read
by ``a``.
+ """
+ ref = table.this
+ if not isinstance(ref, exp.Identifier):
+ return False
+ node: exp.Expr | None = table.parent
+ while node is not None:
+ # A WITH attaches to its owning query (Select/Union/DML) as a sibling
of the
+ # body, so the query -- an ancestor of the table -- holds it. Find it
by type
+ # rather than a fixed arg key (sqlglot has used both ``with`` and
``with_``).
+ with_ = (
+ next((v for v in node.args.values() if isinstance(v, exp.With)),
None)
+ if isinstance(node, exp.Expression)
+ else None
+ )
+ if isinstance(with_, exp.With):
+ recursive = bool(with_.args.get("recursive"))
+ ctes = list(with_.expressions)
+ enclosing = _enclosing_cte(table, with_)
+ # If the reference sits inside CTE E's own body, only CTEs defined
*before*
+ # E are visible there (plus E itself when RECURSIVE); a CTE
defined after E
+ # is not yet in scope. In the main query body every CTE is visible.
+ enclosing_idx = next((i for i, c in enumerate(ctes) if c is
enclosing), None)
+ for idx, cte in enumerate(ctes):
+ if enclosing_idx is not None:
+ if idx > enclosing_idx:
+ continue
+ if idx == enclosing_idx and not recursive:
+ continue
+ alias = cte.args.get("alias")
+ cte_ident = alias.this if isinstance(alias, exp.TableAlias)
else None
+ if isinstance(cte_ident, exp.Identifier) and
_same_identifier(cte_ident, ref):
+ return True
+ node = node.parent
+ return False
+
+
+def collect_table_references(statements: list[exp.Expr]) -> TableScan:
+ """
+ Walk parsed statements and report every real table they reach,
scope-correctly.
+
+ This is the AST half of ``allowed_tables`` enforcement: it returns the
concrete
+ base tables a query reaches (including those nested in subqueries, CTEs,
JOINs, set
+ operations, ``DESCRIBE``, and DML) as ``(catalog, schema, table)`` so the
caller can
+ check each against its allow-list, plus a list of constructs that cannot
be checked
+ and must therefore be rejected while an allow-list is active.
+
+ Handled carefully (each was a confirmed bypass before it was closed):
+
+ - **CTE references are excluded by lexical scope, not by name.** A table
is treated
+ as a CTE only when a ``WITH`` *enclosing that reference* defines the
name (see
+ :func:`_is_in_scope_cte`); a same-named CTE in a sibling/inner query no
longer
+ hides a real top-level table. CTE handling is skipped entirely for DML
statements,
+ where the target is always a real table.
+ - **Catalog-qualified references are reported with their catalog**, so the
caller
+ rejects ``otherdb.public.orders`` instead of matching it to
``public.orders``.
+ - **Unverifiable constructs are listed, not silently dropped:** nameless
+ table-valued functions (``dblink``), ``TABLE('name')`` row sources
+ (``exp.TableFromRows``), ``SHOW``, dynamic SQL (``EXEC``/``Command``),
the
+ ``TABLE <name>`` shorthand (which sqlglot parses incorrectly, leaking the
+ ``TABLE`` keyword as a column), a **quoted identifier** (case-sensitive
on the engine but
+ matched case-insensitively here, so ``"Orders"`` could otherwise reach a
table
+ distinct from the allow-listed ``orders``), and **any inline comment** --
+ comments are where parser-vs-engine differentials hide (MySQL executable
+ ``/*! ... */``, ``--`` not followed by whitespace, ``#``).
+
+ :param statements: Parsed sqlglot statements (from :func:`parse_sql`).
+ :return: A :class:`TableScan` of real table references and unverifiable
constructs.
+ """
+ tables: list[tuple[str, str, str]] = []
+ unverifiable: list[str] = []
+ for stmt in statements:
+ # SHOW enumerates objects / leaks a table's columns outside any single
table.
+ if isinstance(stmt, exp.Show):
+ unverifiable.append("a SHOW statement")
+ continue
+ # Dynamic SQL and anything sqlglot can only represent as a raw Command
reach
+ # data through text the parser cannot inspect.
+ if isinstance(stmt, (exp.Command, exp.Execute)):
+ unverifiable.append(f"a {type(stmt).__name__.lower()} statement")
+ continue
+
+ # A comment is a parser-vs-engine differential vector: sqlglot drops
it, but the
+ # engine may execute it (MySQL `/*! ... */`) or tokenize it
differently (`--`
+ # without a trailing space, `#`). sqlglot tokenizes string literals
correctly,
+ # so a `--` inside a quoted string is not flagged here.
+ if any(node.comments for node in stmt.walk()):
+ unverifiable.append("an inline comment")
+ continue
+
+ # `TABLE('name')` / `TABLE($$name$$)` name a table through a string
the parser
+ # cannot resolve; sqlglot models them as TableFromRows, not exp.Table.
+ if any(True for _ in stmt.find_all(exp.TableFromRows)):
+ unverifiable.append("a TABLE(...) row source")
+ continue
+
+ # `TABLE <name>` (Postgres/MySQL shorthand for SELECT * FROM <name>)
is not
+ # modelled by sqlglot; it parses incorrectly, leaking the reserved
word TABLE as an
+ # unquoted column identifier. No real query has an unquoted column
named TABLE.
+ if any(
+ isinstance(col.this, exp.Identifier)
+ and not col.this.args.get("quoted")
+ and str(col.this.this).upper() == "TABLE"
+ for col in stmt.find_all(exp.Column)
+ ):
+ unverifiable.append("a TABLE <name> shorthand")
+ continue
+
+ is_dml = isinstance(stmt, _DML_TYPES)
+ for table in stmt.find_all(exp.Table):
+ name = table.name
+ if not name:
+ unverifiable.append(f"table-valued function ({table.sql()})")
+ continue
+ # Bare references may be CTEs; qualified ones never are. DML
targets/sources
+ # are always real tables (you cannot read or write a CTE as a base
table).
+ if not is_dml and not table.db and not table.catalog and
_is_in_scope_cte(table):
Review Comment:
This can falsely reject DML whose source is a CTE rather than a real table,
because the `not is_dml` guard skips the `_is_in_scope_cte` check for every
table reference in a DML statement.
For example, `WITH src AS (SELECT * FROM orders) INSERT INTO orders SELECT *
FROM src` only touches the real table `orders`, but `src` is still checked
against `allowed_tables` and the query fails.
Could we handle this scenario?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]