aminghadersohi commented on code in PR #38414: URL: https://github.com/apache/superset/pull/38414#discussion_r2894760602
########## tests/unit_tests/mcp_service/sql_lab/tool/test_save_sql_query.py: ########## @@ -0,0 +1,443 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Unit tests for save_sql_query MCP tool schemas and logic. +""" + +import importlib +import sys +import types +from unittest.mock import MagicMock, Mock, patch + +import pytest +from pydantic import ValidationError + +from superset.mcp_service.sql_lab.schemas import ( + SaveSqlQueryRequest, + SaveSqlQueryResponse, +) + + +class TestSaveSqlQueryRequest: + """Test SaveSqlQueryRequest schema validation.""" + + def test_valid_request(self) -> None: + req = SaveSqlQueryRequest( + database_id=1, + label="Revenue Query", + sql="SELECT SUM(revenue) FROM sales", + ) + assert req.database_id == 1 + assert req.label == "Revenue Query" + assert req.sql == "SELECT SUM(revenue) FROM sales" + + def test_with_optional_fields(self) -> None: + req = SaveSqlQueryRequest( + database_id=1, + label="Revenue Query", + sql="SELECT 1", + schema="public", + catalog="main", + description="Sums revenue", + ) + assert req.schema_name == "public" + assert req.catalog == "main" + assert req.description == "Sums revenue" + + def test_empty_sql_fails(self) -> None: + with pytest.raises(ValidationError, 
match="SQL query cannot be empty"): + SaveSqlQueryRequest(database_id=1, label="test", sql=" ") + + def test_empty_label_fails(self) -> None: + with pytest.raises(ValidationError, match="Label cannot be empty"): + SaveSqlQueryRequest(database_id=1, label=" ", sql="SELECT 1") + + def test_sql_is_stripped(self) -> None: + req = SaveSqlQueryRequest(database_id=1, label="test", sql=" SELECT 1 ") + assert req.sql == "SELECT 1" + + def test_label_is_stripped(self) -> None: + req = SaveSqlQueryRequest(database_id=1, label=" My Query ", sql="SELECT 1") + assert req.label == "My Query" + + def test_label_max_length(self) -> None: + with pytest.raises(ValidationError, match="String should have at most 256"): + SaveSqlQueryRequest(database_id=1, label="a" * 257, sql="SELECT 1") + + def test_schema_alias(self) -> None: + """The field accepts 'schema' as alias for 'schema_name'.""" + req = SaveSqlQueryRequest( + database_id=1, + label="test", + sql="SELECT 1", + schema="public", + ) + assert req.schema_name == "public" + + +class TestSaveSqlQueryResponse: + """Test SaveSqlQueryResponse schema.""" + + def test_response_fields(self) -> None: + resp = SaveSqlQueryResponse( + id=42, + label="Revenue", + sql="SELECT 1", + database_id=1, + url="/sqllab?savedQueryId=42", + ) + assert resp.id == 42 + assert resp.label == "Revenue" + assert resp.url == "/sqllab?savedQueryId=42" + + def test_response_with_optional_fields(self) -> None: + resp = SaveSqlQueryResponse( + id=42, + label="Revenue", + sql="SELECT 1", + database_id=1, + schema="public", + description="A query", + url="/sqllab?savedQueryId=42", + ) + assert resp.schema_name == "public" + assert resp.description == "A query" + + +def _force_passthrough_decorators(): + """Force superset_core.api.mcp.tool to be a passthrough decorator. + + In CI, superset_core is fully installed and the real @tool decorator + includes authentication middleware. 
For unit tests we want to bypass + auth and test the tool logic directly, so we always replace the + decorator with a passthrough regardless of installation state. + + Returns a dict of original sys.modules entries so they can be restored. + """ + + def _passthrough_tool(func=None, **kwargs): + if func is not None: + return func + return lambda f: f + + mock_mcp = MagicMock() + mock_mcp.tool = _passthrough_tool + + mock_api = MagicMock() + mock_api.mcp = mock_mcp + + # Save original modules so we can restore them later + saved_modules: dict[str, types.ModuleType] = {} + superset_core_keys = [k for k in sys.modules if k.startswith("superset_core")] + for key in superset_core_keys: + saved_modules[key] = sys.modules.pop(key) + + # Mock all possible import paths for superset_core + sys.modules["superset_core"] = MagicMock() + sys.modules["superset_core.api"] = mock_api + sys.modules["superset_core.api.mcp"] = mock_mcp + sys.modules["superset_core.mcp"] = mock_mcp + sys.modules.setdefault("superset_core.api.types", MagicMock()) + + return saved_modules + + +def _restore_modules(saved_modules: dict[str, types.ModuleType]) -> None: + """Restore original sys.modules entries after passthrough mocking.""" + # Remove mock entries + for key in list(sys.modules.keys()): + if key.startswith("superset_core"): + del sys.modules[key] + # Restore originals Review Comment: Great catch\! You're right — `_restore_modules` wasn't cleaning up the tool modules imported under the patched decorators. I've updated it to also remove `superset.mcp_service.sql_lab.tool.*` entries from `sys.modules` before restoring originals. Fixed in the latest commit. Thank you\! 
########## superset/mcp_service/sql_lab/schemas.py: ########## @@ -115,6 +115,64 @@ class ExecuteSqlResponse(BaseModel): ) +class SaveSqlQueryRequest(BaseModel): + """Request schema for saving a SQL query.""" + + database_id: int = Field( + ..., description="Database connection ID the query runs against" + ) + label: str = Field( + ..., + description="Name for the saved query (shown in Saved Queries list)", + min_length=1, + max_length=256, + ) + sql: str = Field( + ..., + description="SQL query text to save", + ) + schema_name: str | None = Field( + None, + description="Schema the query targets", + alias="schema", + ) + catalog: str | None = Field(None, description="Catalog name (if applicable)") + description: str | None = Field( + None, description="Optional description of the query" + ) + + @field_validator("sql") + @classmethod + def sql_not_empty(cls, v: str) -> str: + if not v or not v.strip(): + raise ValueError("SQL query cannot be empty") + return v.strip() + + @field_validator("label") + @classmethod + def label_not_empty(cls, v: str) -> str: + if not v or not v.strip(): + raise ValueError("Label cannot be empty") + return v.strip() + + +class SaveSqlQueryResponse(BaseModel): + """Response schema for a saved SQL query.""" + + id: int = Field(..., description="Saved query ID") + label: str = Field(..., description="Query name") + sql: str = Field(..., description="SQL query text") + database_id: int = Field(..., description="Database ID") + schema_name: str | None = Field(None, description="Schema name", alias="schema") + description: str | None = Field(None, description="Query description") Review Comment: Good suggestion — the `SavedQuery` model does include a `catalog` field, and it makes sense to expose it in the response. I've added `catalog` to `SaveSqlQueryResponse` and also included it in the tool's response construction. Fixed in the latest commit. Thanks! -- This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
