This is an automated email from the ASF dual-hosted git repository.
dpgaspar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 5dfbab5424 fix: adds the ability to disallow SQL functions per engine (#28639)
5dfbab5424 is described below
commit 5dfbab542422e6f68b020bc0bccf41caa3e1f248
Author: Daniel Vaz Gaspar <[email protected]>
AuthorDate: Wed May 29 10:51:28 2024 +0100
fix: adds the ability to disallow SQL functions per engine (#28639)
---
superset/config.py | 9 ++++++
superset/db_engine_specs/base.py | 7 ++++-
superset/db_engine_specs/trino.py | 11 ++++---
superset/exceptions.py | 15 +++++++++
superset/sql_parse.py | 42 ++++++++++++++++++++++++++
tests/unit_tests/db_engine_specs/test_trino.py | 24 +++++++++------
tests/unit_tests/sql_parse_tests.py | 26 ++++++++++++++++
7 files changed, 119 insertions(+), 15 deletions(-)
diff --git a/superset/config.py b/superset/config.py
index 3c92354322..aa8178d086 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -1227,6 +1227,15 @@ DB_CONNECTION_MUTATOR = None
#
DB_SQLA_URI_VALIDATOR: Callable[[URL], None] | None = None
+# A set of disallowed SQL functions per engine. This is used to restrict the use of
+# unsafe SQL functions in SQL Lab and Charts. The keys of the dictionary are the engine
+# names, and the values are sets of disallowed functions.
+DISALLOWED_SQL_FUNCTIONS: dict[str, set[str]] = {
+ "postgresql": {"version", "query_to_xml", "inet_server_addr",
"inet_client_addr"},
+ "clickhouse": {"url"},
+ "mysql": {"version"},
+}
+
# A function that intercepts the SQL to be executed and can alter it.
# A common use case for this is around adding some sort of comment header to the SQL
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 6df0dc61aa..548fb390d8 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -62,7 +62,7 @@ from superset import sql_parse
from superset.constants import TimeGrain as TimeGrainConstants
from superset.databases.utils import get_table_metadata, make_url_safe
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
-from superset.exceptions import OAuth2Error, OAuth2RedirectError
+from superset.exceptions import DisallowedSQLFunction, OAuth2Error,
OAuth2RedirectError
from superset.sql_parse import ParsedQuery, SQLScript, Table
from superset.superset_typing import (
OAuth2ClientConfig,
@@ -1818,6 +1818,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"""
if not cls.allows_sql_comments:
query = sql_parse.strip_comments_from_sql(query, engine=cls.engine)
+ disallowed_functions =
current_app.config["DISALLOWED_SQL_FUNCTIONS"].get(
+ cls.engine, set()
+ )
+ if sql_parse.check_sql_functions_exist(query, disallowed_functions,
cls.engine):
+ raise DisallowedSQLFunction(disallowed_functions)
if cls.arraysize:
cursor.arraysize = cls.arraysize
diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py
index eea00877d9..600f236b48 100644
--- a/superset/db_engine_specs/trino.py
+++ b/superset/db_engine_specs/trino.py
@@ -22,7 +22,7 @@ import threading
import time
from typing import Any, TYPE_CHECKING
-from flask import current_app
+from flask import current_app, Flask
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlalchemy.exc import NoSuchTableError
@@ -218,11 +218,14 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
execute_result: dict[str, Any] = {}
execute_event = threading.Event()
- def _execute(results: dict[str, Any], event: threading.Event) -> None:
+ def _execute(
+ results: dict[str, Any], event: threading.Event, app: Flask
+ ) -> None:
logger.debug("Query %d: Running query: %s", query_id, sql)
try:
- cls.execute(cursor, sql, query.database)
+ with app.app_context():
+ cls.execute(cursor, sql, query.database)
except Exception as ex: # pylint: disable=broad-except
results["error"] = ex
finally:
@@ -230,7 +233,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
execute_thread = threading.Thread(
target=_execute,
- args=(execute_result, execute_event),
+ args=(execute_result, execute_event,
current_app._get_current_object()), # pylint: disable=protected-access
)
execute_thread.start()
diff --git a/superset/exceptions.py b/superset/exceptions.py
index 0315ee30f4..47cd511f8f 100644
--- a/superset/exceptions.py
+++ b/superset/exceptions.py
@@ -358,6 +358,21 @@ class OAuth2Error(SupersetErrorException):
)
+class DisallowedSQLFunction(SupersetErrorException):
+ """
+ Disallowed function found on SQL statement
+ """
+
+ def __init__(self, functions: set[str]):
+ super().__init__(
+ SupersetError(
+ message=f"SQL statement contains disallowed function(s):
{functions}",
+ error_type=SupersetErrorType.SYNTAX_ERROR,
+ level=ErrorLevel.ERROR,
+ )
+ )
+
+
class CreateKeyValueDistributedLockFailedException(Exception):
"""
Exception to signalize failure to acquire lock.
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index f32647042b..192a998c3f 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -39,6 +39,7 @@ from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
from sqlparse import keywords
from sqlparse.lexer import Lexer
from sqlparse.sql import (
+ Function,
Identifier,
IdentifierList,
Parenthesis,
@@ -223,6 +224,19 @@ def get_cte_remainder_query(sql: str) -> tuple[str | None, str]:
return cte, remainder
+def check_sql_functions_exist(
+ sql: str, function_list: set[str], engine: str | None = None
+) -> bool:
+ """
+ Check if the SQL statement contains any of the specified functions.
+
+ :param sql: The SQL statement
+ :param function_list: The list of functions to search for
+ :param engine: The engine to use for parsing the SQL statement
+ """
+ return ParsedQuery(sql, engine=engine).check_functions_exist(function_list)
+
+
def strip_comments_from_sql(statement: str, engine: str | None = None) -> str:
"""
Strips comments from a SQL statement, does a simple test first
@@ -743,6 +757,34 @@ class ParsedQuery:
self._tables = self._extract_tables_from_sql()
return self._tables
+ def _check_functions_exist_in_token(
+ self, token: Token, functions: set[str]
+ ) -> bool:
+ if (
+ isinstance(token, Function)
+ and token.get_name() is not None
+ and token.get_name().lower() in functions
+ ):
+ return True
+ if hasattr(token, "tokens"):
+ for inner_token in token.tokens:
+ if self._check_functions_exist_in_token(inner_token,
functions):
+ return True
+ return False
+
+ def check_functions_exist(self, functions: set[str]) -> bool:
+ """
+ Check if the SQL statement contains any of the specified functions.
+
+ :param functions: A set of functions to search for
+ :return: True if the statement contains any of the specified functions
+ """
+ for statement in self._parsed:
+ for token in statement.tokens:
+ if self._check_functions_exist_in_token(token, functions):
+ return True
+ return False
+
def _extract_tables_from_sql(self) -> set[Table]:
"""
Extract all table references in a query.
diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py
index 88608f1e38..3a2ac91ad6 100644
--- a/tests/unit_tests/db_engine_specs/test_trino.py
+++ b/tests/unit_tests/db_engine_specs/test_trino.py
@@ -401,7 +401,7 @@ def test_handle_cursor_early_cancel(
assert cancel_query_mock.call_args is None
-def test_execute_with_cursor_in_parallel(mocker: MockerFixture):
+def test_execute_with_cursor_in_parallel(app, mocker: MockerFixture):
"""Test that `execute_with_cursor` fetches query ID from the cursor"""
from superset.db_engine_specs.trino import TrinoEngineSpec
@@ -416,16 +416,20 @@ def test_execute_with_cursor_in_parallel(mocker: MockerFixture):
mock_cursor.query_id = query_id
mock_cursor.execute.side_effect = _mock_execute
+ with patch.dict(
+ "superset.config.DISALLOWED_SQL_FUNCTIONS",
+ {},
+ clear=True,
+ ):
+ TrinoEngineSpec.execute_with_cursor(
+ cursor=mock_cursor,
+ sql="SELECT 1 FROM foo",
+ query=mock_query,
+ )
- TrinoEngineSpec.execute_with_cursor(
- cursor=mock_cursor,
- sql="SELECT 1 FROM foo",
- query=mock_query,
- )
-
- mock_query.set_extra_json_key.assert_called_once_with(
- key=QUERY_CANCEL_KEY, value=query_id
- )
+ mock_query.set_extra_json_key.assert_called_once_with(
+ key=QUERY_CANCEL_KEY, value=query_id
+ )
def test_get_columns(mocker: MockerFixture):
diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py
index 3b80b7e01d..6259d6272d 100644
--- a/tests/unit_tests/sql_parse_tests.py
+++ b/tests/unit_tests/sql_parse_tests.py
@@ -32,6 +32,7 @@ from superset.exceptions import (
)
from superset.sql_parse import (
add_table_name,
+ check_sql_functions_exist,
extract_table_references,
extract_tables_from_jinja_sql,
get_rls_for_table,
@@ -1215,6 +1216,31 @@ def test_strip_comments_from_sql() -> None:
)
+def test_check_sql_functions_exist() -> None:
+ """
+ Test that disallowed SQL functions are detected correctly.
+ """
+ assert not (
+ check_sql_functions_exist("select a, b from version", {"version"},
"postgresql")
+ )
+
+ assert check_sql_functions_exist("select version()", {"version"},
"postgresql")
+
+ assert check_sql_functions_exist(
+ "select version from version()", {"version"}, "postgresql"
+ )
+
+ assert check_sql_functions_exist(
+ "select 1, a.version from (select version from version()) as a",
+ {"version"},
+ "postgresql",
+ )
+
+ assert check_sql_functions_exist(
+ "select 1, a.version from (select version()) as a", {"version"},
"postgresql"
+ )
+
+
def test_sanitize_clause_valid():
# regular clauses
assert sanitize_clause("col = 1") == "col = 1"