bito-code-review[bot] commented on code in PR #40346: URL: https://github.com/apache/superset/pull/40346#discussion_r3306603513
########## superset/mcp_service/saved_query/tool/get_saved_query_info.py: ########## @@ -0,0 +1,129 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Get saved query info FastMCP tool + +This module contains the FastMCP tool for getting detailed information +about a specific saved SQL query. +""" + +import logging +from datetime import datetime, timezone + +from fastmcp import Context +from superset_core.mcp.decorators import tool, ToolAnnotations + +from superset.extensions import event_logger +from superset.mcp_service.mcp_core import ModelGetInfoCore +from superset.mcp_service.saved_query.schemas import ( + GetSavedQueryInfoRequest, + SavedQueryError, + SavedQueryInfo, + serialize_saved_query_object, +) + +logger = logging.getLogger(__name__) + + +@tool( + tags=["discovery"], + class_permission_name="SavedQuery", + annotations=ToolAnnotations( + title="Get saved query info", + readOnlyHint=True, + destructiveHint=False, + ), +) +async def get_saved_query_info( + request: GetSavedQueryInfoRequest, ctx: Context +) -> SavedQueryInfo | SavedQueryError: + """Get saved query details by ID or UUID. + + Returns the full saved query including SQL text, label, database, + schema, and timestamps. + + IMPORTANT FOR LLM CLIENTS: + - Use numeric ID (e.g., 42) or UUID string (e.g., "a1b2c3d4-...") + - To find a saved query ID, use the list_saved_queries tool first + + Example usage: + ```json + { + "identifier": 42 + } + ``` + + Or with UUID: + ```json + { + "identifier": "a1b2c3d4-5678-90ab-cdef-1234567890ab" + } + ``` + """ + await ctx.info( + "Retrieving saved query information: identifier=%s" % (request.identifier,) + ) + + try: + from superset.daos.query import SavedQueryDAO + + with event_logger.log_context(action="mcp.get_saved_query_info.lookup"): + get_tool = ModelGetInfoCore( + dao_class=SavedQueryDAO, + output_schema=SavedQueryInfo, + error_schema=SavedQueryError, + serializer=serialize_saved_query_object, + supports_slug=False, + logger=logger, + ) + + result = get_tool.run_tool(request.identifier) + + if isinstance(result, SavedQueryInfo): + await ctx.info( + "Saved query information retrieved successfully: " + "saved_query_id=%s, label=%s, db_id=%s" + % ( + result.id, + result.label, + result.db_id, + ) + ) + else: + await ctx.warning( + "Saved query retrieval failed: error_type=%s, error=%s" + % (result.error_type, result.error) + ) + + return result + + except Exception as e: Review Comment: <!-- Bito Reply --> The code in `get_saved_query_info.py` currently returns raw exception text (`str(e)`) in error responses, which could expose sensitive internal details to clients. This is inconsistent with the `list_saved_queries` tool, which raises exceptions and lets the `GlobalErrorHandlerMiddleware` handle them. To maintain consistency and improve security, the error handling should be unified so that both tools either return structured error responses or raise exceptions that are handled by the middleware. **superset/mcp_service/saved_query/tool/get_saved_query_info.py** ``` except Exception as e: await ctx.error( f"Failed to get saved query info: {str(e)}" ) return SavedQueryError( error_type="get_saved_query_info_failed", error=f"Failed to get saved query info: {str(e)}" ) ``` ########## tests/unit_tests/mcp_service/query/tool/test_query_tools.py: ########## @@ -0,0 +1,302 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import logging +from unittest.mock import MagicMock, patch + +import pytest +from fastmcp import Client +from pydantic import ValidationError + +from superset.mcp_service.app import mcp +from superset.mcp_service.query.schemas import ( + ListQueriesRequest, + QueryFilter, +) +from superset.utils import json + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + + +class TestQueryFilterSchema: + """Tests for QueryFilter schema — filterable columns.""" + + def test_invalid_filter_column_rejected(self): + """Columns not in the Literal set must be rejected.""" + with pytest.raises(ValidationError): + QueryFilter(col="not_a_real_column", opr="eq", value="test") + + def test_user_id_is_rejected_as_filter_column(self): + """user_id is an internal field and should not be a filter column.""" + with pytest.raises(ValidationError): + QueryFilter(col="user_id", opr="eq", value=1) + + def test_valid_status_filter_accepted(self): + """status is a valid filter column.""" + f = QueryFilter(col="status", opr="eq", value="success") + assert f.col == "status" + + def test_valid_database_id_filter_accepted(self): + """database_id is a valid filter column.""" + f = QueryFilter(col="database_id", opr="eq", value=1) + assert f.col == "database_id" + + def test_valid_schema_filter_accepted(self): + """schema is a valid filter column.""" + f = QueryFilter(col="schema", opr="eq", value="public") + assert f.col == "schema" + + +def create_mock_query( + query_id: int = 1, + sql: str = "SELECT * FROM table", + status: str = "success", + start_time: float = 1700000000.0, + end_time: float = 1700000001.0, + rows: int = 100, + database_id: int = 1, + schema: str = "public", + tab_name: str = "SQL Lab 1", + error_message: str | None = None, + client_id: str = "abc123", +) -> MagicMock: + """Factory function to create mock query objects with sensible defaults.""" + query = MagicMock() + query.id = query_id + query.sql = sql + query.status = status + query.start_time = start_time + query.end_time = end_time + query.rows = rows + query.database_id = database_id + query.schema = schema + query.tab_name = tab_name + query.error_message = error_message + query.client_id = client_id + query.limit = 1000 + query.progress = 100 + query.changed_on = None + return query + + [email protected] +def mcp_server(): + return mcp + + [email protected](autouse=True) +def mock_auth(): + """Mock authentication for all tests.""" + from unittest.mock import Mock, patch + + with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user: + mock_user = Mock() + mock_user.id = 1 + mock_user.username = "admin" + mock_get_user.return_value = mock_user + yield mock_get_user + + +@patch("superset.daos.query.QueryDAO.list") [email protected] +async def test_list_queries_basic(mock_list, mcp_server): + """Test basic query listing functionality.""" + query = create_mock_query() + query._mapping = { + "id": query.id, + "sql": query.sql, + "status": query.status, + "start_time": query.start_time, + "database_id": query.database_id, + "schema": query.schema, + } + mock_list.return_value = ([query], 1) + async with Client(mcp_server) as client: + request = ListQueriesRequest(page=1, page_size=10) + result = await client.call_tool( + "list_queries", {"request": request.model_dump()} + ) + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["queries"] is not None + assert len(data["queries"]) == 1 + assert data["queries"][0]["id"] == 1 + assert data["queries"][0]["status"] == "success" + + +@patch("superset.daos.query.QueryDAO.list") [email protected] +async def test_list_queries_with_status_filter(mock_list, mcp_server): + """Test query listing with status filter.""" + query = create_mock_query(status="failed", error_message="Syntax error") + query._mapping = { + "id": query.id, + "sql": query.sql, + "status": query.status, + "error_message": query.error_message, + } + mock_list.return_value = ([query], 1) + async with Client(mcp_server) as client: + request = ListQueriesRequest( + page=1, + page_size=10, + filters=[ + {"col": "status", "opr": "eq", "value": "failed"}, + ], + ) + result = await client.call_tool( + "list_queries", {"request": request.model_dump()} + ) + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["queries"] is not None + assert len(data["queries"]) == 1 + assert data["queries"][0]["status"] == "failed" + + +@patch("superset.daos.query.QueryDAO.list") [email protected] +async def test_list_queries_default_page_size(mock_list, mcp_server): + """Test that default page size is 25 for query history.""" + mock_list.return_value = ([], 0) + async with Client(mcp_server) as client: + result = await client.call_tool("list_queries", {}) + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["page_size"] == 25 + + +def test_list_queries_request_rejects_both_search_and_filters(): + """Cannot use search and filters simultaneously.""" + with pytest.raises(ValidationError): + ListQueriesRequest( + search="test", + filters=[{"col": "status", "opr": "eq", "value": "success"}], + ) + + +@patch("superset.daos.query.QueryDAO.find_by_id") [email protected] +async def test_get_query_info_basic(mock_find, mcp_server): + """Test basic get query info functionality.""" + query = create_mock_query() + mock_find.return_value = query + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_query_info", {"request": {"identifier": 1}} + ) + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["id"] == 1 + assert data["status"] == "success" + assert data["database_id"] == 1 + + +@patch("superset.daos.query.QueryDAO.find_by_id") [email protected] +async def test_get_query_info_not_found(mock_find, mcp_server): + """Test get query info when query does not exist.""" + mock_find.return_value = None + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_query_info", {"request": {"identifier": 999}} + ) + assert result.data["error_type"] == "not_found" + + +@patch("superset.daos.query.QueryDAO.list") [email protected] +async def test_list_queries_empty(mock_list, mcp_server): + """Test query listing returns empty list when no results.""" + mock_list.return_value = ([], 0) + async with Client(mcp_server) as client: + request = ListQueriesRequest(page=1, page_size=10) + result = await client.call_tool( + "list_queries", {"request": request.model_dump()} + ) + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["queries"] == [] + assert data["count"] == 0 + assert data["total_count"] == 0 + + +@patch("superset.daos.query.QueryDAO.list") [email protected] +async def test_list_queries_pagination_info(mock_list, mcp_server): + """Test that pagination info is correctly returned.""" + queries = [create_mock_query(query_id=i) for i in range(1, 4)] + for q in queries: + q._mapping = {"id": q.id, "sql": q.sql, "status": q.status} + mock_list.return_value = (queries, 100) + async with Client(mcp_server) as client: + request = ListQueriesRequest(page=1, page_size=3) + result = await client.call_tool( + "list_queries", {"request": request.model_dump()} + ) + data = json.loads(result.content[0].text) + assert data["total_count"] == 100 + assert data["page_size"] == 3 + assert data["has_next"] is True + assert data["has_previous"] is False + + +@patch("superset.daos.query.QueryDAO.list") [email protected] +async def test_list_queries_default_order_is_start_time_desc(mock_list, mcp_server): + """Test that default ordering is start_time descending.""" + mock_list.return_value = ([], 0) + async with Client(mcp_server) as client: + result = await client.call_tool("list_queries", {}) + assert result.content is not None + mock_list.assert_called_once() + call_kwargs = mock_list.call_args + assert call_kwargs.kwargs.get("order_column") == "start_time" + assert call_kwargs.kwargs.get("order_direction") == "desc" + + +@patch("superset.daos.query.QueryDAO.list") [email protected] +async def test_list_queries_select_columns_projects_fields(mock_list, mcp_server): + """select_columns limits which fields appear in each query result.""" + query = create_mock_query() + query._mapping = {"id": query.id, "status": query.status} + mock_list.return_value = ([query], 1) + async with Client(mcp_server) as client: + request = ListQueriesRequest( + page=1, page_size=10, select_columns=["id", "status"] + ) + result = await client.call_tool( + "list_queries", {"request": request.model_dump()} + ) + data = json.loads(result.content[0].text) + assert data["queries"] is not None + q = data["queries"][0] + assert set(q.keys()) == {"id", "status"} + assert q["id"] == 1 + assert q["status"] == "success" + + [email protected] +async def test_list_queries_invalid_order_column_raises(mcp_server): + """order_column not in SORTABLE_QUERY_COLUMNS must be rejected.""" + request = ListQueriesRequest(page=1, page_size=10, order_column="tab_name") + async with Client(mcp_server) as client: + with pytest.raises(Exception, match="Invalid order_column"): Review Comment: <!-- Bito Reply --> The comment suggests that the error response in the code includes raw exception text (`str(e)`), which could expose internal implementation details to clients. The recommended fix is to return a generic user-facing message instead and log the full exception details on the server. This change would help prevent sensitive information from being disclosed to external clients. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
