Copilot commented on code in PR #40132: URL: https://github.com/apache/superset/pull/40132#discussion_r3245123717
########## tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py: ########## @@ -0,0 +1,433 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Unit tests for create_dataset MCP tool.""" + +import logging +from unittest.mock import MagicMock, Mock, patch + +import pytest +from fastmcp import Client +from fastmcp.exceptions import ToolError + +from superset.mcp_service.app import mcp +from superset.utils import json + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + +# Patch at source so lazy imports inside the tool function are intercepted. 
+_CMD_PATH = "superset.commands.dataset.create.CreateDatasetCommand" +_DAO_PATH = "superset.mcp_service.dataset.tool.create_dataset.DatasetDAO" +_SEC_PATH = "superset.mcp_service.dataset.tool.create_dataset.security_manager" + + +def _make_mock_dataset( + dataset_id: int = 42, + table_name: str = "orders", + schema: str = "public", + database_name: str = "main_db", +) -> MagicMock: + dataset = MagicMock() + dataset.id = dataset_id + dataset.table_name = table_name + dataset.schema = schema + dataset.description = None + dataset.certified_by = None + dataset.certification_details = None + dataset.changed_by = None + dataset.changed_on = None + dataset.changed_on_humanized = None + dataset.created_by = None + dataset.created_on = None + dataset.created_on_humanized = None + dataset.tags = [] + dataset.owners = [] + dataset.is_virtual = False + dataset.database_id = 1 + dataset.schema_perm = f"[{database_name}].[{schema}]" + dataset.database = MagicMock() + dataset.database.database_name = database_name + dataset.sql = None + dataset.main_dttm_col = None + dataset.offset = 0 + dataset.cache_timeout = 0 + dataset.params = None + dataset.template_params = None + dataset.extra = None + dataset.uuid = f"dataset-uuid-{dataset_id}" + dataset.columns = [] + dataset.metrics = [] + return dataset + + +@pytest.fixture +def mcp_server(): + return mcp + + +@pytest.fixture(autouse=True) +def mock_auth(): + with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user: + mock_user = Mock() + mock_user.id = 1 + mock_user.username = "admin" + mock_get_user.return_value = mock_user + yield mock_get_user + + +class TestCreateDataset: + """Tests for the create_dataset MCP tool.""" + + @pytest.fixture(autouse=True) + def mock_dao_and_security(self): + """Default: valid database exists and access is granted. + + Patches the pre-command access check so individual tests that only care + about command behavior don't need to replicate this setup. 
+ """ + with patch(_DAO_PATH) as mock_dao, patch(_SEC_PATH) as mock_sec: + mock_dao.get_database_by_id.return_value = MagicMock( + id=1, database_name="test_db" + ) + yield mock_dao, mock_sec + + @patch(_CMD_PATH) + @pytest.mark.asyncio + async def test_create_dataset_success(self, mock_command_class, mcp_server): + """Happy path: tool creates dataset and returns DatasetInfo.""" + mock_dataset = _make_mock_dataset() + mock_command = MagicMock() + mock_command.run.return_value = mock_dataset + mock_command_class.return_value = mock_command + + async with Client(mcp_server) as client: + result = await client.call_tool( + "create_dataset", + { + "request": { + "database_id": 1, + "schema": "public", + "table_name": "orders", + } + }, + ) + + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["id"] == 42 + assert data["table_name"] == "orders" + assert data["schema_name"] == "public" Review Comment: The tool returns a `DatasetInfo` object, whose serializer normalizes the schema field to the `schema` key (see `DatasetInfo.schema_name` alias + model_serializer). These assertions should check `data["schema"]` rather than `data["schema_name"]` (and similarly elsewhere in this test file), otherwise the test will fail against the actual JSON output shape used by other dataset tools. ########## superset/mcp_service/dataset/tool/create_dataset.py: ########## @@ -0,0 +1,143 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import logging +from typing import Any + +from fastmcp import Context +from superset_core.mcp.decorators import tool, ToolAnnotations + +from superset.daos.dataset import DatasetDAO +from superset.exceptions import SupersetSecurityException +from superset.extensions import event_logger, security_manager +from superset.mcp_service.dataset.schemas import ( + CreateDatasetRequest, + DatasetError, + DatasetInfo, + serialize_dataset_object, +) +from superset.sql.parse import Table + +logger = logging.getLogger(__name__) + + +@tool( + tags=["mutate"], + class_permission_name="Dataset", + method_permission_name="write", + annotations=ToolAnnotations( + title="Register a physical table as a dataset", + readOnlyHint=False, + destructiveHint=False, + ), +) +async def create_dataset( + request: CreateDatasetRequest, ctx: Context +) -> DatasetInfo | DatasetError: + """Register an existing physical table as a Superset dataset. + + Use this tool when the user wants to make a physical database table available + for charting or exploration. The table must already exist in the target database. + + Workflow: + 1. Call list_databases to find the correct database_id + 2. Call this tool with database_id, schema, and table_name + 3. Use the returned id as dataset_id in generate_chart or generate_explore_link + + Returns DatasetInfo on success or DatasetError with error_type on failure. 
+ """ + await ctx.info( + "Registering physical table as dataset: database_id=%s, schema=%r, table=%r" + % (request.database_id, request.schema, request.table_name) + ) + + # Verify the database exists and the caller has table-level access before + # registering. Mirrors the check in DatabaseRestApi.table_metadata(). + database = DatasetDAO.get_database_by_id(request.database_id) + if database is None: + await ctx.warning("Database %s not found" % request.database_id) + return DatasetError.create( + error=f"Database {request.database_id} not found", + error_type="DatabaseNotFoundError", + ) + + table = Table(request.table_name, request.schema, request.catalog) + try: + security_manager.raise_for_access(database=database, table=table) + except SupersetSecurityException as exc: + await ctx.warning("Access denied for table %r: %s" % (str(table), str(exc))) + return DatasetError.create(error=str(exc), error_type="AccessDeniedError") + + try: + from superset.commands.dataset.create import CreateDatasetCommand + from superset.commands.dataset.exceptions import ( + DatasetCreateFailedError, + DatasetExistsValidationError, + DatasetInvalidError, + TableNotFoundValidationError, + ) + + dataset_properties: dict[str, Any] = { + k: v + for k, v in { + "database": request.database_id, + "table_name": request.table_name, + "schema": request.schema, + "catalog": request.catalog, + "owners": request.owners, + }.items() + if v is not None + } + + with event_logger.log_context(action="mcp.create_dataset.create"): + dataset = CreateDatasetCommand(dataset_properties).run() + + result = serialize_dataset_object(dataset) + if result is None: + return DatasetError.create( + error="Dataset was created but could not be serialized", + error_type="InternalError", + ) + + await ctx.info( + "Dataset registered: id=%s, table=%r" % (dataset.id, dataset.table_name) + ) + return result + + except DatasetInvalidError as exc: + # CreateDatasetCommand.validate() aggregates individual validation 
errors + # into DatasetInvalidError; inspect them for specific error types. + if any(isinstance(e, DatasetExistsValidationError) for e in exc._exceptions): + await ctx.warning("Dataset already exists: %s" % str(exc)) + return DatasetError.create(error=str(exc), error_type="DatasetExistsError") + if any(isinstance(e, TableNotFoundValidationError) for e in exc._exceptions): + await ctx.warning("Table not found: %s" % str(exc)) + return DatasetError.create(error=str(exc), error_type="TableNotFoundError") Review Comment: This implementation inspects `DatasetInvalidError._exceptions` to detect specific validation errors. `_exceptions` is a private attribute of `CommandInvalidError` and isn’t part of a stable public API, so this is brittle. Prefer using the public helpers on the exception (e.g. `get_list_classnames()` and/or `normalized_messages()`) to classify the error without reaching into private state. ########## superset/mcp_service/dataset/schemas.py: ########## @@ -323,6 +323,42 @@ class GetDatasetInfoRequest(MetadataCacheControl): ] +class CreateDatasetRequest(BaseModel): + """Request schema for create_dataset to register a physical table as a dataset.""" + + database_id: Annotated[ + int, + Field( + description="ID of the database connection to register the table against" + ), + ] + schema: Annotated[ + str | None, + Field( + default=None, + description="Schema where the table lives (optional).", + ), + ] + table_name: Annotated[ + str, + Field(description="Name of the physical table to register as a dataset"), + ] + catalog: Annotated[ + str | None, + Field( + default=None, + description="Catalog where the table lives (optional).", + ), + ] + owners: Annotated[ + List[int] | None, + Field( + default=None, + description="Optional list of owner user IDs. 
Defaults to calling user.", + ), + ] + + Review Comment: `table_name` (and optional `schema`/`catalog`) accept whitespace-only strings today, which will pass Pydantic validation but will later fail (or behave unexpectedly) deeper in `CreateDatasetCommand`. Other request schemas in this package (e.g. `CreateVirtualDatasetRequest`) strip and validate non-empty strings. Consider adding field validators to strip inputs and reject empty `table_name`, and normalize empty `schema`/`catalog` to `None`. ########## tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py: ########## @@ -0,0 +1,433 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Unit tests for create_dataset MCP tool.""" + +import logging +from unittest.mock import MagicMock, Mock, patch + +import pytest +from fastmcp import Client +from fastmcp.exceptions import ToolError + +from superset.mcp_service.app import mcp +from superset.utils import json + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + +# Patch at source so lazy imports inside the tool function are intercepted. 
+_CMD_PATH = "superset.commands.dataset.create.CreateDatasetCommand" +_DAO_PATH = "superset.mcp_service.dataset.tool.create_dataset.DatasetDAO" +_SEC_PATH = "superset.mcp_service.dataset.tool.create_dataset.security_manager" + + +def _make_mock_dataset( + dataset_id: int = 42, + table_name: str = "orders", + schema: str = "public", + database_name: str = "main_db", +) -> MagicMock: + dataset = MagicMock() + dataset.id = dataset_id + dataset.table_name = table_name + dataset.schema = schema + dataset.description = None + dataset.certified_by = None + dataset.certification_details = None + dataset.changed_by = None + dataset.changed_on = None + dataset.changed_on_humanized = None + dataset.created_by = None + dataset.created_on = None + dataset.created_on_humanized = None + dataset.tags = [] + dataset.owners = [] + dataset.is_virtual = False + dataset.database_id = 1 + dataset.schema_perm = f"[{database_name}].[{schema}]" + dataset.database = MagicMock() + dataset.database.database_name = database_name + dataset.sql = None + dataset.main_dttm_col = None + dataset.offset = 0 + dataset.cache_timeout = 0 + dataset.params = None + dataset.template_params = None + dataset.extra = None + dataset.uuid = f"dataset-uuid-{dataset_id}" + dataset.columns = [] + dataset.metrics = [] + return dataset + + +@pytest.fixture +def mcp_server(): + return mcp + + +@pytest.fixture(autouse=True) +def mock_auth(): + with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user: + mock_user = Mock() + mock_user.id = 1 + mock_user.username = "admin" + mock_get_user.return_value = mock_user + yield mock_get_user + + +class TestCreateDataset: + """Tests for the create_dataset MCP tool.""" + + @pytest.fixture(autouse=True) + def mock_dao_and_security(self): + """Default: valid database exists and access is granted. + + Patches the pre-command access check so individual tests that only care + about command behavior don't need to replicate this setup. 
+ """ + with patch(_DAO_PATH) as mock_dao, patch(_SEC_PATH) as mock_sec: + mock_dao.get_database_by_id.return_value = MagicMock( + id=1, database_name="test_db" + ) + yield mock_dao, mock_sec + + @patch(_CMD_PATH) + @pytest.mark.asyncio + async def test_create_dataset_success(self, mock_command_class, mcp_server): + """Happy path: tool creates dataset and returns DatasetInfo.""" + mock_dataset = _make_mock_dataset() + mock_command = MagicMock() + mock_command.run.return_value = mock_dataset + mock_command_class.return_value = mock_command + + async with Client(mcp_server) as client: + result = await client.call_tool( + "create_dataset", + { + "request": { + "database_id": 1, + "schema": "public", + "table_name": "orders", + } + }, + ) + + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["id"] == 42 + assert data["table_name"] == "orders" + assert data["schema_name"] == "public" + + call_kwargs = mock_command_class.call_args[0][0] + assert call_kwargs["database"] == 1 + assert call_kwargs["schema"] == "public" + assert call_kwargs["table_name"] == "orders" + assert "owners" not in call_kwargs + + @patch(_CMD_PATH) + @pytest.mark.asyncio + async def test_create_dataset_with_owners(self, mock_command_class, mcp_server): + """Owners list is forwarded to the command when supplied.""" + mock_dataset = _make_mock_dataset() + mock_command = MagicMock() + mock_command.run.return_value = mock_dataset + mock_command_class.return_value = mock_command + + async with Client(mcp_server) as client: + result = await client.call_tool( + "create_dataset", + { + "request": { + "database_id": 2, + "schema": "sales", + "table_name": "transactions", + "owners": [5, 10], + } + }, + ) + + data = json.loads(result.content[0].text) + assert data["id"] == 42 + + call_kwargs = mock_command_class.call_args[0][0] + assert call_kwargs["owners"] == [5, 10] + + @patch(_CMD_PATH) + @pytest.mark.asyncio + async def test_create_dataset_already_exists(self, 
mock_command_class, mcp_server): + """Returns DatasetExistsError when a dataset for the table already exists. + + CreateDatasetCommand.validate() wraps DatasetExistsValidationError inside + DatasetInvalidError, so simulate the real command shape. + """ + from superset.commands.dataset.exceptions import ( + DatasetExistsValidationError, + DatasetInvalidError, + ) + from superset.sql.parse import Table + + exc = DatasetInvalidError() + exc.append(DatasetExistsValidationError(Table("orders", "public", None))) + + mock_command = MagicMock() + mock_command.run.side_effect = exc + mock_command_class.return_value = mock_command + + async with Client(mcp_server) as client: + result = await client.call_tool( + "create_dataset", + { + "request": { + "database_id": 1, + "schema": "public", + "table_name": "orders", + } + }, + ) + + data = json.loads(result.content[0].text) + assert data["error_type"] == "DatasetExistsError" + assert "error" in data + + @patch(_CMD_PATH) + @pytest.mark.asyncio + async def test_create_dataset_table_not_found(self, mock_command_class, mcp_server): + """Returns TableNotFoundError when the physical table does not exist in the DB. + + CreateDatasetCommand.validate() wraps TableNotFoundValidationError inside + DatasetInvalidError, so simulate the real command shape. 
+ """ + from superset.commands.dataset.exceptions import ( + DatasetInvalidError, + TableNotFoundValidationError, + ) + from superset.sql.parse import Table + + exc = DatasetInvalidError() + exc.append(TableNotFoundValidationError(Table("missing_table", "public", None))) + + mock_command = MagicMock() + mock_command.run.side_effect = exc + mock_command_class.return_value = mock_command + + async with Client(mcp_server) as client: + result = await client.call_tool( + "create_dataset", + { + "request": { + "database_id": 1, + "schema": "public", + "table_name": "missing_table", + } + }, + ) + + data = json.loads(result.content[0].text) + assert data["error_type"] == "TableNotFoundError" + + @patch(_CMD_PATH) + @pytest.mark.asyncio + async def test_create_dataset_unexpected_error( + self, mock_command_class, mcp_server + ): + """Unexpected exceptions are caught and returned as InternalError.""" + mock_command = MagicMock() + mock_command.run.side_effect = RuntimeError("DB connection lost") + mock_command_class.return_value = mock_command + + async with Client(mcp_server) as client: + result = await client.call_tool( + "create_dataset", + { + "request": { + "database_id": 1, + "schema": "public", + "table_name": "orders", + } + }, + ) + + data = json.loads(result.content[0].text) + assert data["error_type"] == "InternalError" + assert "DB connection lost" in data["error"] + + @pytest.mark.asyncio + async def test_create_dataset_missing_required_fields(self, mcp_server): + """Missing required fields raise a validation error before the tool runs.""" + async with Client(mcp_server) as client: + with pytest.raises(ToolError): + await client.call_tool( + "create_dataset", + { + "request": { + # database_id and table_name are omitted intentionally + "schema": "public", + } + }, + ) + + @patch(_CMD_PATH) + @pytest.mark.asyncio + async def test_create_dataset_returns_full_dataset_info( + self, mock_command_class, mcp_server + ): + """The returned DatasetInfo includes columns, 
metrics, and all core fields.""" + mock_dataset = _make_mock_dataset( + dataset_id=99, table_name="sales", schema="dw" + ) + + col = MagicMock() + col.column_name = "amount" + col.verbose_name = "Amount" + col.type = "NUMERIC" + col.is_dttm = False + col.groupby = True + col.filterable = True + col.description = "Sale amount" + mock_dataset.columns = [col] + + metric = MagicMock() + metric.metric_name = "total_sales" + metric.verbose_name = "Total Sales" + metric.expression = "SUM(amount)" + metric.description = "Sum of amounts" + metric.d3format = None + mock_dataset.metrics = [metric] + + mock_command = MagicMock() + mock_command.run.return_value = mock_dataset + mock_command_class.return_value = mock_command + + async with Client(mcp_server) as client: + result = await client.call_tool( + "create_dataset", + { + "request": { + "database_id": 1, + "schema": "dw", + "table_name": "sales", + } + }, + ) + + data = json.loads(result.content[0].text) + assert data["id"] == 99 + assert data["table_name"] == "sales" + assert data["schema_name"] == "dw" Review Comment: Same as earlier: `DatasetInfo` JSON output uses the `schema` key (not `schema_name`) due to aliasing/serializer normalization. Update this assertion to check `data["schema"]` so the test matches the tool’s actual response shape. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
