This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch sqlglot-ctas-cvas
in repository https://gitbox.apache.org/repos/asf/superset.git

commit bb1a7dc412e315cce1893102ca5bcc6d44393677
Author: Beto Dealmeida <[email protected]>
AuthorDate: Mon May 19 18:58:23 2025 -0400

    feat: implement CVAS/CTAS in sqlglot
---
 superset/sql/parse.py               |  78 +++++++++++++++++++++
 tests/unit_tests/sql/parse_tests.py | 133 ++++++++++++++++++++++++++++++++++++
 2 files changed, 211 insertions(+)

diff --git a/superset/sql/parse.py b/superset/sql/parse.py
index 1ca100975f..0457befce3 100644
--- a/superset/sql/parse.py
+++ b/superset/sql/parse.py
@@ -108,6 +108,11 @@ class LimitMethod(enum.Enum):
     FETCH_MANY = enum.auto()
 
 
+class CTASMethod(enum.Enum):
+    TABLE = enum.auto()
+    VIEW = enum.auto()
+
+
 class RLSMethod(enum.Enum):
     """
     Methods for enforcing RLS.
@@ -381,6 +386,12 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
         """
         raise NotImplementedError()
 
+    def is_select(self) -> bool:
+        """
+        Check if the statement is a `SELECT` statement.
+        """
+        raise NotImplementedError()
+
     def is_mutating(self) -> bool:
         """
         Check if the statement mutates data (DDL/DML).
@@ -437,6 +448,16 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
         """
         raise NotImplementedError()
 
+    def as_create_table(self, table: Table, method: CTASMethod) -> SQLStatement:
+        """
+        Rewrite the statement as a `CREATE TABLE AS` statement.
+
+        :param table: The table to create.
+        :param method: The method to use for creating the table.
+        :return: A new SQLStatement with the CTE.
+        """
+        raise NotImplementedError()
+
     def apply_rls(
         self,
         catalog: str | None,
@@ -572,6 +593,12 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
         dialect = SQLGLOT_DIALECTS.get(engine)
         return extract_tables_from_statement(parsed, dialect)
 
+    def is_select(self) -> bool:
+        """
+        Check if the statement is a `SELECT` statement.
+        """
+        return isinstance(self._parsed, exp.Select)
+
     def is_mutating(self) -> bool:
         """
         Check if the statement mutates data (DDL/DML).
@@ -733,6 +760,22 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
             engine=self.engine,
         )
 
+    def as_create_table(self, table: Table, method: CTASMethod) -> SQLStatement:
+        """
+        Rewrite the statement as a `CREATE TABLE AS` statement.
+
+        :param table: The table to create.
+        :param method: The method to use for creating the table.
+        :return: A new SQLStatement with the CTE.
+        """
+        create_table = exp.Create(
+            this=sqlglot.parse_one(str(table), into=exp.Table),
+            kind=method.name,
+            expression=self._parsed.copy(),
+        )
+
+        return SQLStatement(ast=create_table, engine=self.engine)
+
     def apply_rls(
         self,
         catalog: str | None,
@@ -988,6 +1031,12 @@ class KustoKQLStatement(BaseSQLStatement[str]):
 
         return {}
 
+    def is_select(self) -> bool:
+        """
+        Check if the statement is a `SELECT` statement.
+        """
+        return not self._parsed.startswith(".")
+
     def is_mutating(self) -> bool:
         """
         Check if the statement mutates data (DDL/DML).
@@ -1142,6 +1191,35 @@ class SQLScript:
             for statement in self.statements
         )
 
+    def is_valid_ctas(self) -> bool:
+        """
+        Check if the script contains a valid CTAS statement.
+
+        CTAS (`CREATE TABLE AS SELECT`) can only be run with scripts where the last
+        statement is a `SELECT`.
+        """
+        # `sqlglot` parses comments after a semicolon into their own statement
+        valid_statements = [
+            statement
+            for statement in self.statements
+            if not isinstance(statement._parsed, exp.Semicolon)
+        ]
+        return valid_statements[-1].is_select()
+
+    def is_valid_cvas(self) -> bool:
+        """
+        Check if the script contains a valid CVAS statement.
+
+        CVAS (`CREATE VIEW AS SELECT`) can only be run with scripts with a single
+        `SELECT` statement.
+        """
+        valid_statements = [
+            statement
+            for statement in self.statements
+            if not isinstance(statement._parsed, exp.Semicolon)
+        ]
+        return len(valid_statements) == 1 and valid_statements[0].is_select()
+
 
 def extract_tables_from_statement(
     statement: exp.Expression,
diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py
index 72907f92c7..f5cf81b9b9 100644
--- a/tests/unit_tests/sql/parse_tests.py
+++ b/tests/unit_tests/sql/parse_tests.py
@@ -22,6 +22,7 @@ from sqlglot import Dialects, parse_one
 
 from superset.exceptions import SupersetParseError
 from superset.sql.parse import (
+    CTASMethod,
     extract_tables_from_statement,
     KustoKQLStatement,
     LimitMethod,
@@ -2247,3 +2248,135 @@ def test_rls_predicate_transformer(
         RLSMethod.AS_PREDICATE,
     )
     assert statement.format() == expected
+
+
+@pytest.mark.parametrize(
+    "sql, table, expected",
+    [
+        (
+            "SELECT * FROM some_table",
+            Table("some_table"),
+            """
+CREATE TABLE some_table AS
+SELECT
+  *
+FROM some_table
+            """.strip(),
+        ),
+        (
+            "SELECT * FROM some_table",
+            Table("some_table", "schema1", "catalog1"),
+            """
+CREATE TABLE catalog1.schema1.some_table AS
+SELECT
+  *
+FROM some_table
+            """.strip(),
+        ),
+    ],
+)
+def test_as_create_table(sql: str, table: Table, expected: str) -> None:
+    """
+    Test the `as_create_table` method.
+    """
+    statement = SQLStatement(sql)
+    create_table = statement.as_create_table(table, CTASMethod.TABLE)
+    assert create_table.format() == expected
+
+
+@pytest.mark.parametrize(
+    "sql, engine, expected",
+    [
+        ("SELECT * FROM table", "postgresql", True),
+        (
+            """
+-- comment
+SELECT * FROM table
+-- comment 2
+            """,
+            "mysql",
+            True,
+        ),
+        (
+            """
+-- comment
+SET @value = 42;
+SELECT @value as foo;
+-- comment 2
+            """,
+            "mysql",
+            True,
+        ),
+        (
+            """
+-- comment
+EXPLAIN SELECT * FROM table
+-- comment 2
+            """,
+            "mysql",
+            False,
+        ),
+        (
+            """
+SELECT * FROM table;
+INSERT INTO TABLE (foo) VALUES (42);
+            """,
+            "mysql",
+            False,
+        ),
+    ],
+)
+def test_is_valid_ctas(sql: str, engine: str, expected: bool) -> None:
+    """
+    Test the `is_valid_ctas` method.
+    """
+    assert SQLScript(sql, engine).is_valid_ctas() == expected
+
+
+@pytest.mark.parametrize(
+    "sql, engine, expected",
+    [
+        ("SELECT * FROM table", "postgresql", True),
+        (
+            """
+-- comment
+SELECT * FROM table
+-- comment 2
+            """,
+            "mysql",
+            True,
+        ),
+        (
+            """
+-- comment
+SET @value = 42;
+SELECT @value as foo;
+-- comment 2
+            """,
+            "mysql",
+            False,
+        ),
+        (
+            """
+-- comment
+SELECT value as foo;
+-- comment 2
+            """,
+            "mysql",
+            True,
+        ),
+        (
+            """
+SELECT * FROM table;
+INSERT INTO TABLE (foo) VALUES (42);
+            """,
+            "mysql",
+            False,
+        ),
+    ],
+)
+def test_is_valid_cvas(sql: str, engine: str, expected: bool) -> None:
+    """
+    Test the `is_valid_cvas` method.
+    """
+    assert SQLScript(sql, engine).is_valid_cvas() == expected

Reply via email to