mistercrunch commented on code in PR #26767:
URL: https://github.com/apache/superset/pull/26767#discussion_r1482060518


##########
tests/integration_tests/core_tests.py:
##########
@@ -537,7 +537,7 @@ def test_mssql_engine_spec_pymssql(self):
         )
 
     def test_comments_in_sqlatable_query(self):
-        clean_query = "SELECT '/* val 1 */' as c1, '-- val 2' as c2 FROM tbl"
+        clean_query = "SELECT\n  '/* val 1 */' AS c1,\n  '-- val 2' AS 
c2\nFROM tbl"

Review Comment:
   [suggesting] it could be good to have some sort of 
`compare_sql(strict=False, case_sensitive=False, disregard_schema_prefix=True)` 
function that could be reused in unit tests. Maybe it's a method of 
`ParsedQuery` (`is_identical` or `is_similar`)



##########
superset/sql_parse.py:
##########
@@ -252,6 +253,182 @@ def __eq__(self, __o: object) -> bool:
         return str(self) == str(__o)
 
 
+def extract_tables_from_statement(
+    statement: exp.Expression,
+    dialect: Optional[Dialects],
+) -> set[Table]:
+    """
+    Extract all table references in a single statement.
+
+    Please not that this is not trivial; consider the following queries:
+
+        DESCRIBE some_table;
+        SHOW PARTITIONS FROM some_table;
+        WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM 
masked_name;
+
+    See the unit tests for other tricky cases.
+    """
+    sources: Iterable[exp.Table]
+
+    if isinstance(statement, exp.Describe):
+        # A `DESCRIBE` query has no sources in sqlglot, so we need to 
explicitly
+        # query for all tables.
+        sources = statement.find_all(exp.Table)
+    elif isinstance(statement, exp.Command):
+        # Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a
+        # `SELECT` statetement in order to extract tables.
+        literal = statement.find(exp.Literal)
+        if not literal:
+            return set()
+
+        pseudo_query = parse_one(f"SELECT {literal.this}", dialect=dialect)
+        sources = pseudo_query.find_all(exp.Table)
+    else:
+        sources = [
+            source
+            for scope in traverse_scope(statement)
+            for source in scope.sources.values()
+            if isinstance(source, exp.Table) and not is_cte(source, scope)
+        ]
+
+    return {
+        Table(
+            source.name,
+            source.db if source.db != "" else None,
+            source.catalog if source.catalog != "" else None,
+        )
+        for source in sources
+    }
+
+
+def is_cte(source: exp.Table, scope: Scope) -> bool:
+    """
+    Is the source a CTE?
+
+    CTEs in the parent scope look like tables (and are represented by
+    exp.Table objects), but should not be considered as such;
+    otherwise a user with access to table `foo` could access any table
+    with a query like this:
+
+        WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo
+
+    """
+    parent_sources = scope.parent.sources if scope.parent else {}
+    ctes_in_scope = {
+        name
+        for name, parent_scope in parent_sources.items()
+        if isinstance(parent_scope, Scope) and parent_scope.scope_type == 
ScopeType.CTE
+    }
+
+    return source.name in ctes_in_scope
+
+
+class SQLQuery:
+    """
+    A SQL query, with 0+ statements.
+    """
+
+    def __init__(
+        self,
+        query: str,
+        engine: Optional[str] = None,
+    ):
+        dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
+
+        self.statements = [
+            SQLStatement(statement, engine=engine)
+            for statement in parse(query, dialect=dialect)
+            if statement
+        ]
+
+    def format(self, comments: bool = True) -> str:
+        """
+        Pretty-format the SQL query.
+        """
+        return ";\n".join(statement.format(comments) for statement in 
self.statements)
+
+    def get_settings(self) -> dict[str, str]:
+        """
+        Return the settings for the SQL query.
+
+            >>> statement = SQLQuery("SET foo = 'bar'; SET foo = 'baz'")
+            >>> statement.get_settings()
+            {"foo": "'baz'"}
+
+        """
+        settings: dict[str, str] = {}
+        for statement in self.statements:
+            settings.update(statement.get_settings())
+
+        return settings
+
+
+class SQLStatement:
+    """
+    A SQL statement.
+
+    This class provides helper methods to manipulate and introspect SQL.
+    """
+
+    def __init__(
+        self,
+        statement: Union[str, exp.Expression],
+        engine: Optional[str] = None,
+    ):
+        dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
+
+        if isinstance(statement, str):
+            try:
+                self._parsed = self._parse_statement(statement, dialect)
+            except ParseError as ex:
+                raise SupersetParseError(statement, engine) from ex
+        else:
+            self._parsed = statement
+
+        self._dialect = dialect
+        self.tables = extract_tables_from_statement(self._parsed, dialect)
+
+    @staticmethod
+    def _parse_statement(
+        sql_statement: str,
+        dialect: Optional[Dialects],
+    ) -> exp.Expression:
+        """
+        Parse a single SQL statement.
+        """
+        statements = [
+            statement
+            for statement in sqlglot.parse(sql_statement, dialect=dialect)
+            if statement
+        ]
+        if len(statements) != 1:
+            raise ValueError("SQLStatement should have exactly one statement")
+
+        return statements[0]
+
+    def format(self, comments: bool = True) -> str:
+        """
+        Pretty-format the SQL statement.
+        """
+        write = Dialect.get_or_raise(self._dialect)
+        return write.generate(self._parsed, copy=False, comments=comments, 
pretty=True)
+
+    def get_settings(self) -> dict[str, str]:

Review Comment:
   `get_settings` could have a better name, maybe `get_set_schema_setting`. I'm 
not sure how cross-dialect this is. This may need a `db_engine_spec` method... 
Related convo with GPT ->
   
   <img width="849" alt="Screenshot 2024-02-07 at 12 30 38 PM" 
src="https://github.com/apache/superset/assets/487433/a6392896-8230-44bc-a26e-d229eb2cd580">
   



##########
superset/sql_parse.py:
##########
@@ -22,12 +22,13 @@
 import urllib.parse
 from collections.abc import Iterable, Iterator
 from dataclasses import dataclass
-from typing import Any, cast, Optional
+from typing import Any, cast, Optional, Union

Review Comment:
   It could be good to rename this module to a better name; the current name 
can be confused with having a relationship with `sqlparse`, and overall just 
isn't a good name. It could be `utils/sql.py`, or more directly 
`superset/sql/*` if we need to grow this into a package with multiple modules. 
`superset/sql_parser.py` (?)



##########
superset/sql_parse.py:
##########
@@ -252,6 +253,182 @@ def __eq__(self, __o: object) -> bool:
         return str(self) == str(__o)
 
 
+def extract_tables_from_statement(
+    statement: exp.Expression,
+    dialect: Optional[Dialects],
+) -> set[Table]:
+    """
+    Extract all table references in a single statement.
+
+    Please not that this is not trivial; consider the following queries:
+
+        DESCRIBE some_table;
+        SHOW PARTITIONS FROM some_table;
+        WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM 
masked_name;
+
+    See the unit tests for other tricky cases.
+    """
+    sources: Iterable[exp.Table]
+
+    if isinstance(statement, exp.Describe):
+        # A `DESCRIBE` query has no sources in sqlglot, so we need to 
explicitly
+        # query for all tables.
+        sources = statement.find_all(exp.Table)
+    elif isinstance(statement, exp.Command):
+        # Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a
+        # `SELECT` statetement in order to extract tables.
+        literal = statement.find(exp.Literal)
+        if not literal:
+            return set()
+
+        pseudo_query = parse_one(f"SELECT {literal.this}", dialect=dialect)
+        sources = pseudo_query.find_all(exp.Table)
+    else:
+        sources = [
+            source
+            for scope in traverse_scope(statement)
+            for source in scope.sources.values()
+            if isinstance(source, exp.Table) and not is_cte(source, scope)
+        ]
+
+    return {
+        Table(
+            source.name,
+            source.db if source.db != "" else None,
+            source.catalog if source.catalog != "" else None,
+        )
+        for source in sources
+    }
+
+
+def is_cte(source: exp.Table, scope: Scope) -> bool:
+    """
+    Is the source a CTE?
+
+    CTEs in the parent scope look like tables (and are represented by
+    exp.Table objects), but should not be considered as such;
+    otherwise a user with access to table `foo` could access any table
+    with a query like this:
+
+        WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo
+
+    """
+    parent_sources = scope.parent.sources if scope.parent else {}
+    ctes_in_scope = {
+        name
+        for name, parent_scope in parent_sources.items()
+        if isinstance(parent_scope, Scope) and parent_scope.scope_type == 
ScopeType.CTE
+    }
+
+    return source.name in ctes_in_scope
+
+
+class SQLQuery:
+    """
+    A SQL query, with 0+ statements.
+    """
+
+    def __init__(
+        self,
+        query: str,
+        engine: Optional[str] = None,
+    ):
+        dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
+
+        self.statements = [
+            SQLStatement(statement, engine=engine)
+            for statement in parse(query, dialect=dialect)
+            if statement
+        ]
+
+    def format(self, comments: bool = True) -> str:
+        """
+        Pretty-format the SQL query.
+        """
+        return ";\n".join(statement.format(comments) for statement in 
self.statements)
+
+    def get_settings(self) -> dict[str, str]:
+        """
+        Return the settings for the SQL query.
+
+            >>> statement = SQLQuery("SET foo = 'bar'; SET foo = 'baz'")
+            >>> statement.get_settings()
+            {"foo": "'baz'"}
+
+        """
+        settings: dict[str, str] = {}
+        for statement in self.statements:
+            settings.update(statement.get_settings())
+
+        return settings
+
+
+class SQLStatement:

Review Comment:
   Looking at the class design in this module, this looks like it really 
extracts the first statement from a set of statements. It may not be for this 
PR, but ideally the classes become a better representation of the real world. 
Maybe there's a notion of a `SqlScript` or `SqlStatementCollection`, that's an 
array of other SQL objects like `SqlQuery`, `SqlDMLCommand`, ...
   
   Maybe there's benefit in having inheritance defined there where a SqlQuery 
and DmlCommand are both derivatives of `SqlStatement`



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: notifications-unsubscr...@superset.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: notifications-unsubscr...@superset.apache.org
For additional commands, e-mail: notifications-h...@superset.apache.org

Reply via email to