This is an automated email from the ASF dual-hosted git repository. beto pushed a commit to branch sqlglot-ctas-cvas in repository https://gitbox.apache.org/repos/asf/superset.git
commit bb1a7dc412e315cce1893102ca5bcc6d44393677 Author: Beto Dealmeida <[email protected]> AuthorDate: Mon May 19 18:58:23 2025 -0400 feat: implement CVAS/CTAS in sqlglot --- superset/sql/parse.py | 78 +++++++++++++++++++++ tests/unit_tests/sql/parse_tests.py | 133 ++++++++++++++++++++++++++++++++++++ 2 files changed, 211 insertions(+) diff --git a/superset/sql/parse.py b/superset/sql/parse.py index 1ca100975f..0457befce3 100644 --- a/superset/sql/parse.py +++ b/superset/sql/parse.py @@ -108,6 +108,11 @@ class LimitMethod(enum.Enum): FETCH_MANY = enum.auto() +class CTASMethod(enum.Enum): + TABLE = enum.auto() + VIEW = enum.auto() + + class RLSMethod(enum.Enum): """ Methods for enforcing RLS. @@ -381,6 +386,12 @@ class BaseSQLStatement(Generic[InternalRepresentation]): """ raise NotImplementedError() + def is_select(self) -> bool: + """ + Check if the statement is a `SELECT` statement. + """ + raise NotImplementedError() + def is_mutating(self) -> bool: """ Check if the statement mutates data (DDL/DML). @@ -437,6 +448,16 @@ class BaseSQLStatement(Generic[InternalRepresentation]): """ raise NotImplementedError() + def as_create_table(self, table: Table, method: CTASMethod) -> SQLStatement: + """ + Rewrite the statement as a `CREATE TABLE AS` or `CREATE VIEW AS` statement. + + :param table: The table to create. + :param method: The method to use for creating the table. + :return: A new SQLStatement with the `CREATE ... AS` statement. + """ + raise NotImplementedError() + def apply_rls( self, catalog: str | None, @@ -572,6 +593,12 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): dialect = SQLGLOT_DIALECTS.get(engine) return extract_tables_from_statement(parsed, dialect) + def is_select(self) -> bool: + """ + Check if the statement is a `SELECT` statement. + """ + return isinstance(self._parsed, exp.Select) + def is_mutating(self) -> bool: """ Check if the statement mutates data (DDL/DML).
@@ -733,6 +760,22 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): engine=self.engine, ) + def as_create_table(self, table: Table, method: CTASMethod) -> SQLStatement: + """ + Rewrite the statement as a `CREATE TABLE AS` or `CREATE VIEW AS` statement. + + :param table: The table to create. + :param method: The method to use for creating the table. + :return: A new SQLStatement with the `CREATE ... AS` statement. + """ + create_table = exp.Create( + this=sqlglot.parse_one(str(table), into=exp.Table), + kind=method.name, + expression=self._parsed.copy(), + ) + + return SQLStatement(ast=create_table, engine=self.engine) + def apply_rls( self, catalog: str | None, @@ -988,6 +1031,12 @@ class KustoKQLStatement(BaseSQLStatement[str]): return {} + def is_select(self) -> bool: + """ + Check if the statement is a `SELECT` statement. + """ + return not self._parsed.startswith(".") + def is_mutating(self) -> bool: """ Check if the statement mutates data (DDL/DML). @@ -1142,6 +1191,35 @@ class SQLScript: for statement in self.statements ) + def is_valid_ctas(self) -> bool: + """ + Check if the script contains a valid CTAS statement. + + CTAS (`CREATE TABLE AS SELECT`) can only be run with scripts where the last + statement is a `SELECT`; an empty script is never valid. + """ + # `sqlglot` parses comments after a semicolon into their own statement + valid_statements = [ + statement + for statement in self.statements + if not isinstance(statement._parsed, exp.Semicolon) + ] + return bool(valid_statements) and valid_statements[-1].is_select() + + def is_valid_cvas(self) -> bool: + """ + Check if the script contains a valid CVAS statement. + + CVAS (`CREATE VIEW AS SELECT`) can only be run with scripts with a single + `SELECT` statement.
+ """ + valid_statements = [ + statement + for statement in self.statements + if not isinstance(statement._parsed, exp.Semicolon) + ] + return len(valid_statements) == 1 and valid_statements[0].is_select() + def extract_tables_from_statement( statement: exp.Expression, diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py index 72907f92c7..f5cf81b9b9 100644 --- a/tests/unit_tests/sql/parse_tests.py +++ b/tests/unit_tests/sql/parse_tests.py @@ -22,6 +22,7 @@ from sqlglot import Dialects, parse_one from superset.exceptions import SupersetParseError from superset.sql.parse import ( + CTASMethod, extract_tables_from_statement, KustoKQLStatement, LimitMethod, @@ -2247,3 +2248,135 @@ def test_rls_predicate_transformer( RLSMethod.AS_PREDICATE, ) assert statement.format() == expected + + +@pytest.mark.parametrize( + "sql, table, expected", + [ + ( + "SELECT * FROM some_table", + Table("some_table"), + """ +CREATE TABLE some_table AS +SELECT + * +FROM some_table + """.strip(), + ), + ( + "SELECT * FROM some_table", + Table("some_table", "schema1", "catalog1"), + """ +CREATE TABLE catalog1.schema1.some_table AS +SELECT + * +FROM some_table + """.strip(), + ), + ], +) +def test_as_create_table(sql: str, table: Table, expected: str) -> None: + """ + Test the `as_create_table` method.
+ """ + statement = SQLStatement(sql) + create_table = statement.as_create_table(table, CTASMethod.TABLE) + assert create_table.format() == expected + + +@pytest.mark.parametrize( + "sql, engine, expected", + [ + ("SELECT * FROM table", "postgresql", True), + ( + """ +-- comment +SELECT * FROM table +-- comment 2 + """, + "mysql", + True, + ), + ( + """ +-- comment +SET @value = 42; +SELECT @value as foo; +-- comment 2 + """, + "mysql", + True, + ), + ( + """ +-- comment +EXPLAIN SELECT * FROM table +-- comment 2 + """, + "mysql", + False, + ), + ( + """ +SELECT * FROM table; +INSERT INTO TABLE (foo) VALUES (42); + """, + "mysql", + False, + ), + ], +) +def test_is_valid_ctas(sql: str, engine: str, expected: bool) -> None: + """ + Test the `is_valid_ctas` method. + """ + assert SQLScript(sql, engine).is_valid_ctas() == expected + + +@pytest.mark.parametrize( + "sql, engine, expected", + [ + ("SELECT * FROM table", "postgresql", True), + ( + """ +-- comment +SELECT * FROM table +-- comment 2 + """, + "mysql", + True, + ), + ( + """ +-- comment +SET @value = 42; +SELECT @value as foo; +-- comment 2 + """, + "mysql", + False, + ), + ( + """ +-- comment +SELECT value as foo; +-- comment 2 + """, + "mysql", + True, + ), + ( + """ +SELECT * FROM table; +INSERT INTO TABLE (foo) VALUES (42); + """, + "mysql", + False, + ), + ], +) +def test_is_valid_cvas(sql: str, engine: str, expected: bool) -> None: + """ + Test the `is_valid_cvas` method. + """ + assert SQLScript(sql, engine).is_valid_cvas() == expected
