mistercrunch commented on code in PR #26767: URL: https://github.com/apache/superset/pull/26767#discussion_r1482060518
########## tests/integration_tests/core_tests.py: ########## @@ -537,7 +537,7 @@ def test_mssql_engine_spec_pymssql(self): ) def test_comments_in_sqlatable_query(self): - clean_query = "SELECT '/* val 1 */' as c1, '-- val 2' as c2 FROM tbl" + clean_query = "SELECT\n '/* val 1 */' AS c1,\n '-- val 2' AS c2\nFROM tbl" Review Comment: [suggesting] it could be good to have some sort of reusable `compare_sql(strict=False, case_sensitive=False, disregard_schema_prefix=True)` function that could be reused in unit tests. Maybe it's a method of `ParseQuery` (`is_identical` or `is_similar`) ########## superset/sql_parse.py: ########## @@ -252,6 +253,182 @@ def __eq__(self, __o: object) -> bool: return str(self) == str(__o) +def extract_tables_from_statement( + statement: exp.Expression, + dialect: Optional[Dialects], +) -> set[Table]: + """ + Extract all table references in a single statement. + + Please not that this is not trivial; consider the following queries: + + DESCRIBE some_table; + SHOW PARTITIONS FROM some_table; + WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name; + + See the unit tests for other tricky cases. + """ + sources: Iterable[exp.Table] + + if isinstance(statement, exp.Describe): + # A `DESCRIBE` query has no sources in sqlglot, so we need to explicitly + # query for all tables. + sources = statement.find_all(exp.Table) + elif isinstance(statement, exp.Command): + # Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a + # `SELECT` statetement in order to extract tables. 
+ literal = statement.find(exp.Literal) + if not literal: + return set() + + pseudo_query = parse_one(f"SELECT {literal.this}", dialect=dialect) + sources = pseudo_query.find_all(exp.Table) + else: + sources = [ + source + for scope in traverse_scope(statement) + for source in scope.sources.values() + if isinstance(source, exp.Table) and not is_cte(source, scope) + ] + + return { + Table( + source.name, + source.db if source.db != "" else None, + source.catalog if source.catalog != "" else None, + ) + for source in sources + } + + +def is_cte(source: exp.Table, scope: Scope) -> bool: + """ + Is the source a CTE? + + CTEs in the parent scope look like tables (and are represented by + exp.Table objects), but should not be considered as such; + otherwise a user with access to table `foo` could access any table + with a query like this: + + WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo + + """ + parent_sources = scope.parent.sources if scope.parent else {} + ctes_in_scope = { + name + for name, parent_scope in parent_sources.items() + if isinstance(parent_scope, Scope) and parent_scope.scope_type == ScopeType.CTE + } + + return source.name in ctes_in_scope + + +class SQLQuery: + """ + A SQL query, with 0+ statements. + """ + + def __init__( + self, + query: str, + engine: Optional[str] = None, + ): + dialect = SQLGLOT_DIALECTS.get(engine) if engine else None + + self.statements = [ + SQLStatement(statement, engine=engine) + for statement in parse(query, dialect=dialect) + if statement + ] + + def format(self, comments: bool = True) -> str: + """ + Pretty-format the SQL query. + """ + return ";\n".join(statement.format(comments) for statement in self.statements) + + def get_settings(self) -> dict[str, str]: + """ + Return the settings for the SQL query. 
+ + >>> statement = SQLQuery("SET foo = 'bar'; SET foo = 'baz'") + >>> statement.get_settings() + {"foo": "'baz'"} + + """ + settings: dict[str, str] = {} + for statement in self.statements: + settings.update(statement.get_settings()) + + return settings + + +class SQLStatement: + """ + A SQL statement. + + This class provides helper methods to manipulate and introspect SQL. + """ + + def __init__( + self, + statement: Union[str, exp.Expression], + engine: Optional[str] = None, + ): + dialect = SQLGLOT_DIALECTS.get(engine) if engine else None + + if isinstance(statement, str): + try: + self._parsed = self._parse_statement(statement, dialect) + except ParseError as ex: + raise SupersetParseError(statement, engine) from ex + else: + self._parsed = statement + + self._dialect = dialect + self.tables = extract_tables_from_statement(self._parsed, dialect) + + @staticmethod + def _parse_statement( + sql_statement: str, + dialect: Optional[Dialects], + ) -> exp.Expression: + """ + Parse a single SQL statement. + """ + statements = [ + statement + for statement in sqlglot.parse(sql_statement, dialect=dialect) + if statement + ] + if len(statements) != 1: + raise ValueError("SQLStatement should have exactly one statement") + + return statements[0] + + def format(self, comments: bool = True) -> str: + """ + Pretty-format the SQL statement. + """ + write = Dialect.get_or_raise(self._dialect) + return write.generate(self._parsed, copy=False, comments=comments, pretty=True) + + def get_settings(self) -> dict[str, str]: Review Comment: `get_settings` could have a better name, maybe `get_set_schema_setting`. I'm not sure how cross-dialect this is. This may need a `db_engine_spec` method... 
Related convo with GPT -> <img width="849" alt="Screenshot 2024-02-07 at 12 30 38 PM" src="https://github.com/apache/superset/assets/487433/a6392896-8230-44bc-a26e-d229eb2cd580"> ########## superset/sql_parse.py: ########## @@ -22,12 +22,13 @@ import urllib.parse from collections.abc import Iterable, Iterator from dataclasses import dataclass -from typing import Any, cast, Optional +from typing import Any, cast, Optional, Union Review Comment: It could be good to give this module a better name: the current one can be mistaken for being related to `sqlparse`, and overall it just isn't a good name. It could be `utils/sql.py`, or more directly `superset/sql/*` if we need to grow this into a package with multiple modules. Or perhaps `superset/sql_parser.py` (?) ########## superset/sql_parse.py: ########## @@ -252,6 +253,182 @@ def __eq__(self, __o: object) -> bool: return str(self) == str(__o) +def extract_tables_from_statement( + statement: exp.Expression, + dialect: Optional[Dialects], +) -> set[Table]: + """ + Extract all table references in a single statement. + + Please not that this is not trivial; consider the following queries: + + DESCRIBE some_table; + SHOW PARTITIONS FROM some_table; + WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name; + + See the unit tests for other tricky cases. + """ + sources: Iterable[exp.Table] + + if isinstance(statement, exp.Describe): + # A `DESCRIBE` query has no sources in sqlglot, so we need to explicitly + # query for all tables. + sources = statement.find_all(exp.Table) + elif isinstance(statement, exp.Command): + # Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a + # `SELECT` statetement in order to extract tables. 
+ literal = statement.find(exp.Literal) + if not literal: + return set() + + pseudo_query = parse_one(f"SELECT {literal.this}", dialect=dialect) + sources = pseudo_query.find_all(exp.Table) + else: + sources = [ + source + for scope in traverse_scope(statement) + for source in scope.sources.values() + if isinstance(source, exp.Table) and not is_cte(source, scope) + ] + + return { + Table( + source.name, + source.db if source.db != "" else None, + source.catalog if source.catalog != "" else None, + ) + for source in sources + } + + +def is_cte(source: exp.Table, scope: Scope) -> bool: + """ + Is the source a CTE? + + CTEs in the parent scope look like tables (and are represented by + exp.Table objects), but should not be considered as such; + otherwise a user with access to table `foo` could access any table + with a query like this: + + WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo + + """ + parent_sources = scope.parent.sources if scope.parent else {} + ctes_in_scope = { + name + for name, parent_scope in parent_sources.items() + if isinstance(parent_scope, Scope) and parent_scope.scope_type == ScopeType.CTE + } + + return source.name in ctes_in_scope + + +class SQLQuery: + """ + A SQL query, with 0+ statements. + """ + + def __init__( + self, + query: str, + engine: Optional[str] = None, + ): + dialect = SQLGLOT_DIALECTS.get(engine) if engine else None + + self.statements = [ + SQLStatement(statement, engine=engine) + for statement in parse(query, dialect=dialect) + if statement + ] + + def format(self, comments: bool = True) -> str: + """ + Pretty-format the SQL query. + """ + return ";\n".join(statement.format(comments) for statement in self.statements) + + def get_settings(self) -> dict[str, str]: + """ + Return the settings for the SQL query. 
+ + >>> statement = SQLQuery("SET foo = 'bar'; SET foo = 'baz'") + >>> statement.get_settings() + {"foo": "'baz'"} + + """ + settings: dict[str, str] = {} + for statement in self.statements: + settings.update(statement.get_settings()) + + return settings + + +class SQLStatement: Review Comment: Looking at the class design in this module, this looks like it really extracts the first statement from a set of statements. It may not be for this PR, but ideally the classes become a better representation of the real world. Maybe there's a notion of a `SqlScript` or `SqlStatementCollection`, that's an array of other SQL objects like `SqlQuery`, `SqlDMLCommand`, ... Maybe there's benefit in having inheritance defined there where a SqlQuery and DmlCommand are both derivatives of `SqlStatement` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: notifications-unsubscr...@superset.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: notifications-unsubscr...@superset.apache.org For additional commands, e-mail: notifications-h...@superset.apache.org