This is an automated email from the ASF dual-hosted git repository.

dpgaspar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 5dfbab5424 fix: adds the ability to disallow SQL functions per engine (#28639)
5dfbab5424 is described below

commit 5dfbab542422e6f68b020bc0bccf41caa3e1f248
Author: Daniel Vaz Gaspar <[email protected]>
AuthorDate: Wed May 29 10:51:28 2024 +0100

    fix: adds the ability to disallow SQL functions per engine (#28639)
---
 superset/config.py                             |  9 ++++++
 superset/db_engine_specs/base.py               |  7 ++++-
 superset/db_engine_specs/trino.py              | 11 ++++---
 superset/exceptions.py                         | 15 +++++++++
 superset/sql_parse.py                          | 42 ++++++++++++++++++++++++++
 tests/unit_tests/db_engine_specs/test_trino.py | 24 +++++++++------
 tests/unit_tests/sql_parse_tests.py            | 26 ++++++++++++++++
 7 files changed, 119 insertions(+), 15 deletions(-)

diff --git a/superset/config.py b/superset/config.py
index 3c92354322..aa8178d086 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -1227,6 +1227,15 @@ DB_CONNECTION_MUTATOR = None
 #
 DB_SQLA_URI_VALIDATOR: Callable[[URL], None] | None = None
 
+# A set of disallowed SQL functions per engine. This is used to restrict the use of
+# unsafe SQL functions in SQL Lab and Charts. The keys of the dictionary are the engine
+# names, and the values are sets of disallowed functions.
+DISALLOWED_SQL_FUNCTIONS: dict[str, set[str]] = {
+    "postgresql": {"version", "query_to_xml", "inet_server_addr", "inet_client_addr"},
+    "clickhouse": {"url"},
+    "mysql": {"version"},
+}
+
 
 # A function that intercepts the SQL to be executed and can alter it.
 # A common use case for this is around adding some sort of comment header to the SQL
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 6df0dc61aa..548fb390d8 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -62,7 +62,7 @@ from superset import sql_parse
 from superset.constants import TimeGrain as TimeGrainConstants
 from superset.databases.utils import get_table_metadata, make_url_safe
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
-from superset.exceptions import OAuth2Error, OAuth2RedirectError
+from superset.exceptions import DisallowedSQLFunction, OAuth2Error, OAuth2RedirectError
 from superset.sql_parse import ParsedQuery, SQLScript, Table
 from superset.superset_typing import (
     OAuth2ClientConfig,
@@ -1818,6 +1818,11 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         """
         if not cls.allows_sql_comments:
             query = sql_parse.strip_comments_from_sql(query, engine=cls.engine)
+        disallowed_functions = current_app.config["DISALLOWED_SQL_FUNCTIONS"].get(
+            cls.engine, set()
+        )
+        if sql_parse.check_sql_functions_exist(query, disallowed_functions, cls.engine):
+            raise DisallowedSQLFunction(disallowed_functions)
 
         if cls.arraysize:
             cursor.arraysize = cls.arraysize
diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py
index eea00877d9..600f236b48 100644
--- a/superset/db_engine_specs/trino.py
+++ b/superset/db_engine_specs/trino.py
@@ -22,7 +22,7 @@ import threading
 import time
 from typing import Any, TYPE_CHECKING
 
-from flask import current_app
+from flask import current_app, Flask
 from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy.engine.url import URL
 from sqlalchemy.exc import NoSuchTableError
@@ -218,11 +218,14 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
         execute_result: dict[str, Any] = {}
         execute_event = threading.Event()
 
-        def _execute(results: dict[str, Any], event: threading.Event) -> None:
+        def _execute(
+            results: dict[str, Any], event: threading.Event, app: Flask
+        ) -> None:
             logger.debug("Query %d: Running query: %s", query_id, sql)
 
             try:
-                cls.execute(cursor, sql, query.database)
+                with app.app_context():
+                    cls.execute(cursor, sql, query.database)
             except Exception as ex:  # pylint: disable=broad-except
                 results["error"] = ex
             finally:
@@ -230,7 +233,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
 
         execute_thread = threading.Thread(
             target=_execute,
-            args=(execute_result, execute_event),
+            args=(execute_result, execute_event, current_app._get_current_object()),  # pylint: disable=protected-access
         )
         execute_thread.start()
 
diff --git a/superset/exceptions.py b/superset/exceptions.py
index 0315ee30f4..47cd511f8f 100644
--- a/superset/exceptions.py
+++ b/superset/exceptions.py
@@ -358,6 +358,21 @@ class OAuth2Error(SupersetErrorException):
         )
 
 
+class DisallowedSQLFunction(SupersetErrorException):
+    """
+    Disallowed function found on SQL statement
+    """
+
+    def __init__(self, functions: set[str]):
+        super().__init__(
+            SupersetError(
+                message=f"SQL statement contains disallowed function(s): {functions}",
+                error_type=SupersetErrorType.SYNTAX_ERROR,
+                level=ErrorLevel.ERROR,
+            )
+        )
+
+
 class CreateKeyValueDistributedLockFailedException(Exception):
     """
     Exception to signalize failure to acquire lock.
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index f32647042b..192a998c3f 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -39,6 +39,7 @@ from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
 from sqlparse import keywords
 from sqlparse.lexer import Lexer
 from sqlparse.sql import (
+    Function,
     Identifier,
     IdentifierList,
     Parenthesis,
@@ -223,6 +224,19 @@ def get_cte_remainder_query(sql: str) -> tuple[str | None, str]:
     return cte, remainder
 
 
+def check_sql_functions_exist(
+    sql: str, function_list: set[str], engine: str | None = None
+) -> bool:
+    """
+    Check if the SQL statement contains any of the specified functions.
+
+    :param sql: The SQL statement
+    :param function_list: The list of functions to search for
+    :param engine: The engine to use for parsing the SQL statement
+    """
+    return ParsedQuery(sql, engine=engine).check_functions_exist(function_list)
+
+
 def strip_comments_from_sql(statement: str, engine: str | None = None) -> str:
     """
     Strips comments from a SQL statement, does a simple test first
@@ -743,6 +757,34 @@ class ParsedQuery:
             self._tables = self._extract_tables_from_sql()
         return self._tables
 
+    def _check_functions_exist_in_token(
+        self, token: Token, functions: set[str]
+    ) -> bool:
+        if (
+            isinstance(token, Function)
+            and token.get_name() is not None
+            and token.get_name().lower() in functions
+        ):
+            return True
+        if hasattr(token, "tokens"):
+            for inner_token in token.tokens:
+                if self._check_functions_exist_in_token(inner_token, functions):
+                    return True
+        return False
+
+    def check_functions_exist(self, functions: set[str]) -> bool:
+        """
+        Check if the SQL statement contains any of the specified functions.
+
+        :param functions: A set of functions to search for
+        :return: True if the statement contains any of the specified functions
+        """
+        for statement in self._parsed:
+            for token in statement.tokens:
+                if self._check_functions_exist_in_token(token, functions):
+                    return True
+        return False
+
     def _extract_tables_from_sql(self) -> set[Table]:
         """
         Extract all table references in a query.
diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py
index 88608f1e38..3a2ac91ad6 100644
--- a/tests/unit_tests/db_engine_specs/test_trino.py
+++ b/tests/unit_tests/db_engine_specs/test_trino.py
@@ -401,7 +401,7 @@ def test_handle_cursor_early_cancel(
         assert cancel_query_mock.call_args is None
 
 
-def test_execute_with_cursor_in_parallel(mocker: MockerFixture):
+def test_execute_with_cursor_in_parallel(app, mocker: MockerFixture):
     """Test that `execute_with_cursor` fetches query ID from the cursor"""
     from superset.db_engine_specs.trino import TrinoEngineSpec
 
@@ -416,16 +416,20 @@ def test_execute_with_cursor_in_parallel(mocker: MockerFixture):
         mock_cursor.query_id = query_id
 
     mock_cursor.execute.side_effect = _mock_execute
+    with patch.dict(
+        "superset.config.DISALLOWED_SQL_FUNCTIONS",
+        {},
+        clear=True,
+    ):
+        TrinoEngineSpec.execute_with_cursor(
+            cursor=mock_cursor,
+            sql="SELECT 1 FROM foo",
+            query=mock_query,
+        )
 
-    TrinoEngineSpec.execute_with_cursor(
-        cursor=mock_cursor,
-        sql="SELECT 1 FROM foo",
-        query=mock_query,
-    )
-
-    mock_query.set_extra_json_key.assert_called_once_with(
-        key=QUERY_CANCEL_KEY, value=query_id
-    )
+        mock_query.set_extra_json_key.assert_called_once_with(
+            key=QUERY_CANCEL_KEY, value=query_id
+        )
 
 
 def test_get_columns(mocker: MockerFixture):
diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py
index 3b80b7e01d..6259d6272d 100644
--- a/tests/unit_tests/sql_parse_tests.py
+++ b/tests/unit_tests/sql_parse_tests.py
@@ -32,6 +32,7 @@ from superset.exceptions import (
 )
 from superset.sql_parse import (
     add_table_name,
+    check_sql_functions_exist,
     extract_table_references,
     extract_tables_from_jinja_sql,
     get_rls_for_table,
@@ -1215,6 +1216,31 @@ def test_strip_comments_from_sql() -> None:
     )
 
 
+def test_check_sql_functions_exist() -> None:
+    """
+    Test that disallowed SQL functions are detected correctly.
+    """
+    assert not (
+        check_sql_functions_exist("select a, b from version", {"version"}, "postgresql")
+    )
+
+    assert check_sql_functions_exist("select version()", {"version"}, "postgresql")
+
+    assert check_sql_functions_exist(
+        "select version from version()", {"version"}, "postgresql"
+    )
+
+    assert check_sql_functions_exist(
+        "select 1, a.version from (select version from version()) as a",
+        {"version"},
+        "postgresql",
+    )
+
+    assert check_sql_functions_exist(
+        "select 1, a.version from (select version()) as a", {"version"}, "postgresql"
+    )
+
+
 def test_sanitize_clause_valid():
     # regular clauses
     assert sanitize_clause("col = 1") == "col = 1"

Reply via email to