aminghadersohi commented on code in PR #40344:
URL: https://github.com/apache/superset/pull/40344#discussion_r3311886790


##########
tests/unit_tests/mcp_service/task/tool/test_task_tools.py:
##########
@@ -0,0 +1,249 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Unit tests for list_tasks and get_task_info MCP tools."""
+
+import uuid
+from datetime import datetime, timezone
+from unittest.mock import MagicMock, patch
+
+import pytest
+from fastmcp import Client
+from pydantic import ValidationError
+
+from superset.mcp_service.app import mcp
+from superset.mcp_service.task.schemas import ListTasksRequest, 
TaskColumnFilter
+from superset.utils import json
+
+SAMPLE_UUID = str(uuid.uuid4())
+
+
+def create_mock_task(
+    task_id: int = 1,
+    task_uuid: str | None = None,
+    task_type: str = "sql_execution",
+    task_key: str = "default-key",
+    task_name: str | None = None,
+    status: str = "success",
+    scope: str = "private",
+    changed_on: datetime | None = None,
+    created_on: datetime | None = None,
+) -> MagicMock:
+    task = MagicMock()
+    task.id = task_id
+    task.uuid = task_uuid or SAMPLE_UUID
+    task.task_type = task_type
+    task.task_key = task_key
+    task.task_name = task_name
+    task.status = status
+    task.scope = scope
+    task.changed_on = changed_on or datetime(2024, 1, 2, 10, 0, 0, 
tzinfo=timezone.utc)
+    task.created_on = created_on or datetime(2024, 1, 1, 9, 0, 0, 
tzinfo=timezone.utc)
+    return task
+
+
[email protected]
+def mcp_server():
+    return mcp
+
+
[email protected](autouse=True)
+def mock_auth():
+    from unittest.mock import Mock
+
+    with patch("superset.mcp_service.auth.get_user_from_request") as 
mock_get_user:
+        mock_user = Mock()
+        mock_user.id = 1
+        mock_user.username = "testuser"
+        mock_get_user.return_value = mock_user
+        yield mock_get_user
+
+
+class TestTaskColumnFilterSchema:
+    def test_valid_filter_columns_accepted(self):
+        for col in ("task_type", "status", "scope"):
+            f = TaskColumnFilter(col=col, opr="eq", value="test")
+            assert f.col == col
+
+    def test_invalid_filter_column_rejected(self):
+        with pytest.raises(ValidationError):
+            TaskColumnFilter(col="user_id", opr="eq", value=1)
+
+    def test_uuid_column_rejected(self):
+        with pytest.raises(ValidationError):
+            TaskColumnFilter(col="uuid", opr="eq", value="some-uuid")
+
+
+@patch("superset.daos.tasks.TaskDAO.list")
[email protected]
+async def test_list_tasks_basic(mock_list, mcp_server):
+    """Basic task listing returns structured response."""
+    task = create_mock_task()
+    mock_list.return_value = ([task], 1)
+
+    async with Client(mcp_server) as client:
+        result = await client.call_tool("list_tasks", {})
+
+    data = json.loads(result.content[0].text)
+    assert data["tasks"] is not None
+    assert len(data["tasks"]) == 1
+    assert data["tasks"][0]["id"] == 1
+    assert data["tasks"][0]["task_type"] == "sql_execution"
+    assert data["tasks"][0]["status"] == "success"
+
+
+@patch("superset.daos.tasks.TaskDAO.list")
[email protected]
+async def test_list_tasks_with_status_filter(mock_list, mcp_server):
+    """Status filter is passed through to the DAO correctly."""
+    task = create_mock_task(status="pending")
+    mock_list.return_value = ([task], 1)
+
+    async with Client(mcp_server) as client:
+        request = ListTasksRequest(
+            filters=[{"col": "status", "opr": "eq", "value": "pending"}]
+        )
+        result = await client.call_tool("list_tasks", {"request": 
request.model_dump()})
+
+    data = json.loads(result.content[0].text)
+    assert len(data["tasks"]) == 1
+    assert data["tasks"][0]["status"] == "pending"
+
+
+@patch("superset.daos.tasks.TaskDAO.list")
[email protected]
+async def test_list_tasks_taskfilter_scoping_is_applied(mock_list, mcp_server):
+    """TaskDAO.list is called with base_filter (TaskFilter) applied via DAO 
class."""
+    mock_list.return_value = ([], 0)
+
+    async with Client(mcp_server) as client:
+        await client.call_tool("list_tasks", {})
+
+    # Verify the DAO's list() is called — the TaskFilter scoping is enforced
+    # by TaskDAO.base_filter = TaskFilter which BaseDAO applies automatically.
+    assert mock_list.called

Review Comment:
   Acknowledged — the unit test intentionally mocks the DAO. The `TaskFilter` 
base filter is an integration-level concern enforced by 
`BaseDAO._apply_base_filter` inside the real DAO. Verifying it at the unit 
level would require either (a) not mocking the DAO at all and spinning up a 
real DB, or (b) testing `TaskDAO` directly in a DAO-level test. The unit test 
here confirms the MCP layer delegates to `TaskDAO.list()` — the DAO itself is 
integration-tested separately.



##########
tests/unit_tests/mcp_service/task/tool/test_task_tools.py:
##########
@@ -0,0 +1,249 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Unit tests for list_tasks and get_task_info MCP tools."""
+
+import uuid
+from datetime import datetime, timezone
+from unittest.mock import MagicMock, patch
+
+import pytest
+from fastmcp import Client
+from pydantic import ValidationError
+
+from superset.mcp_service.app import mcp
+from superset.mcp_service.task.schemas import ListTasksRequest, 
TaskColumnFilter
+from superset.utils import json
+
+SAMPLE_UUID = str(uuid.uuid4())
+
+
+def create_mock_task(
+    task_id: int = 1,
+    task_uuid: str | None = None,
+    task_type: str = "sql_execution",
+    task_key: str = "default-key",
+    task_name: str | None = None,
+    status: str = "success",
+    scope: str = "private",
+    changed_on: datetime | None = None,
+    created_on: datetime | None = None,
+) -> MagicMock:
+    task = MagicMock()
+    task.id = task_id
+    task.uuid = task_uuid or SAMPLE_UUID
+    task.task_type = task_type
+    task.task_key = task_key
+    task.task_name = task_name
+    task.status = status
+    task.scope = scope
+    task.changed_on = changed_on or datetime(2024, 1, 2, 10, 0, 0, 
tzinfo=timezone.utc)
+    task.created_on = created_on or datetime(2024, 1, 1, 9, 0, 0, 
tzinfo=timezone.utc)
+    return task
+
+
[email protected]
+def mcp_server():
+    return mcp
+
+
[email protected](autouse=True)
+def mock_auth():
+    from unittest.mock import Mock
+
+    with patch("superset.mcp_service.auth.get_user_from_request") as 
mock_get_user:
+        mock_user = Mock()
+        mock_user.id = 1
+        mock_user.username = "testuser"
+        mock_get_user.return_value = mock_user
+        yield mock_get_user
+
+
+class TestTaskColumnFilterSchema:
+    def test_valid_filter_columns_accepted(self):
+        for col in ("task_type", "status", "scope"):
+            f = TaskColumnFilter(col=col, opr="eq", value="test")
+            assert f.col == col
+
+    def test_invalid_filter_column_rejected(self):
+        with pytest.raises(ValidationError):
+            TaskColumnFilter(col="user_id", opr="eq", value=1)
+
+    def test_uuid_column_rejected(self):
+        with pytest.raises(ValidationError):
+            TaskColumnFilter(col="uuid", opr="eq", value="some-uuid")
+
+
+@patch("superset.daos.tasks.TaskDAO.list")
[email protected]
+async def test_list_tasks_basic(mock_list, mcp_server):
+    """Basic task listing returns structured response."""
+    task = create_mock_task()
+    mock_list.return_value = ([task], 1)
+
+    async with Client(mcp_server) as client:
+        result = await client.call_tool("list_tasks", {})
+
+    data = json.loads(result.content[0].text)
+    assert data["tasks"] is not None
+    assert len(data["tasks"]) == 1
+    assert data["tasks"][0]["id"] == 1
+    assert data["tasks"][0]["task_type"] == "sql_execution"
+    assert data["tasks"][0]["status"] == "success"
+
+
+@patch("superset.daos.tasks.TaskDAO.list")
[email protected]
+async def test_list_tasks_with_status_filter(mock_list, mcp_server):
+    """Status filter is passed through to the DAO correctly."""
+    task = create_mock_task(status="pending")
+    mock_list.return_value = ([task], 1)
+
+    async with Client(mcp_server) as client:
+        request = ListTasksRequest(
+            filters=[{"col": "status", "opr": "eq", "value": "pending"}]
+        )
+        result = await client.call_tool("list_tasks", {"request": 
request.model_dump()})
+
+    data = json.loads(result.content[0].text)
+    assert len(data["tasks"]) == 1
+    assert data["tasks"][0]["status"] == "pending"
+
+
+@patch("superset.daos.tasks.TaskDAO.list")
[email protected]
+async def test_list_tasks_taskfilter_scoping_is_applied(mock_list, mcp_server):
+    """TaskDAO.list is called with base_filter (TaskFilter) applied via DAO 
class."""
+    mock_list.return_value = ([], 0)
+
+    async with Client(mcp_server) as client:
+        await client.call_tool("list_tasks", {})
+
+    # Verify the DAO's list() is called — the TaskFilter scoping is enforced
+    # by TaskDAO.base_filter = TaskFilter which BaseDAO applies automatically.
+    assert mock_list.called
+
+
+@patch("superset.daos.tasks.TaskDAO.list")
[email protected]
+async def test_list_tasks_pagination(mock_list, mcp_server):
+    """Pagination metadata is correct."""
+    tasks = [create_mock_task(task_id=i) for i in range(1, 4)]
+    mock_list.return_value = (tasks, 30)
+
+    async with Client(mcp_server) as client:
+        request = ListTasksRequest(page=2, page_size=3)
+        result = await client.call_tool("list_tasks", {"request": 
request.model_dump()})
+
+    data = json.loads(result.content[0].text)
+    assert data["count"] == 3
+    assert data["total_count"] == 30
+    assert data["page"] == 2
+    assert data["page_size"] == 3
+    assert data["has_previous"] is True
+    assert data["has_next"] is True
+
+
+@patch("superset.daos.tasks.TaskDAO.list")
[email protected]
+async def test_list_tasks_uuid_in_response(mock_list, mcp_server):
+    """Task UUID is returned as a string in the response."""
+    task_uuid = str(uuid.uuid4())
+    task = create_mock_task(task_uuid=task_uuid)
+    mock_list.return_value = ([task], 1)
+
+    async with Client(mcp_server) as client:
+        result = await client.call_tool("list_tasks", {})
+
+    data = json.loads(result.content[0].text)
+    assert data["tasks"][0]["uuid"] == task_uuid
+
+
+@patch("superset.daos.tasks.TaskDAO.find_by_id")
[email protected]
+async def test_get_task_info_by_integer_id(mock_find, mcp_server):
+    """get_task_info resolves a task by integer ID."""
+    task = create_mock_task(task_id=5, task_type="thumbnail", 
status="in_progress")
+    mock_find.return_value = task
+
+    async with Client(mcp_server) as client:
+        result = await client.call_tool("get_task_info", {"request": 
{"identifier": 5}})
+
+    data = json.loads(result.content[0].text)
+    assert data["id"] == 5
+    assert data["task_type"] == "thumbnail"
+    assert data["status"] == "in_progress"
+
+
+@patch("superset.daos.tasks.TaskDAO.find_by_id")
[email protected]
+async def test_get_task_info_by_uuid(mock_find, mcp_server):
+    """get_task_info resolves a task by UUID string."""
+    task_uuid = str(uuid.uuid4())
+    task = create_mock_task(task_id=10, task_uuid=task_uuid, status="success")
+    mock_find.return_value = task
+
+    async with Client(mcp_server) as client:
+        result = await client.call_tool(
+            "get_task_info", {"request": {"identifier": task_uuid}}
+        )
+
+    data = json.loads(result.content[0].text)
+    assert data["id"] == 10
+    assert data["status"] == "success"

Review Comment:
   Fixed in 28461c7591 — added `mock_find.assert_called_once_with(task_uuid, 
id_column="uuid", query_options=None)` to explicitly verify the UUID lookup 
path through `ModelGetInfoCore`.



##########
tests/unit_tests/mcp_service/task/tool/test_task_tools.py:
##########
@@ -0,0 +1,249 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Unit tests for list_tasks and get_task_info MCP tools."""
+
+import uuid
+from datetime import datetime, timezone
+from unittest.mock import MagicMock, patch
+
+import pytest
+from fastmcp import Client
+from pydantic import ValidationError
+
+from superset.mcp_service.app import mcp
+from superset.mcp_service.task.schemas import ListTasksRequest, 
TaskColumnFilter
+from superset.utils import json
+
+SAMPLE_UUID = str(uuid.uuid4())
+
+
+def create_mock_task(
+    task_id: int = 1,
+    task_uuid: str | None = None,
+    task_type: str = "sql_execution",
+    task_key: str = "default-key",
+    task_name: str | None = None,
+    status: str = "success",
+    scope: str = "private",
+    changed_on: datetime | None = None,
+    created_on: datetime | None = None,
+) -> MagicMock:
+    task = MagicMock()
+    task.id = task_id
+    task.uuid = task_uuid or SAMPLE_UUID
+    task.task_type = task_type
+    task.task_key = task_key
+    task.task_name = task_name
+    task.status = status
+    task.scope = scope
+    task.changed_on = changed_on or datetime(2024, 1, 2, 10, 0, 0, 
tzinfo=timezone.utc)
+    task.created_on = created_on or datetime(2024, 1, 1, 9, 0, 0, 
tzinfo=timezone.utc)
+    return task
+
+
[email protected]
+def mcp_server():
+    return mcp
+
+
[email protected](autouse=True)
+def mock_auth():
+    from unittest.mock import Mock
+
+    with patch("superset.mcp_service.auth.get_user_from_request") as 
mock_get_user:
+        mock_user = Mock()
+        mock_user.id = 1
+        mock_user.username = "testuser"
+        mock_get_user.return_value = mock_user
+        yield mock_get_user
+
+
+class TestTaskColumnFilterSchema:
+    def test_valid_filter_columns_accepted(self):
+        for col in ("task_type", "status", "scope"):
+            f = TaskColumnFilter(col=col, opr="eq", value="test")
+            assert f.col == col
+
+    def test_invalid_filter_column_rejected(self):
+        with pytest.raises(ValidationError):
+            TaskColumnFilter(col="user_id", opr="eq", value=1)
+
+    def test_uuid_column_rejected(self):
+        with pytest.raises(ValidationError):
+            TaskColumnFilter(col="uuid", opr="eq", value="some-uuid")
+
+
+@patch("superset.daos.tasks.TaskDAO.list")
[email protected]
+async def test_list_tasks_basic(mock_list, mcp_server):
+    """Basic task listing returns structured response."""
+    task = create_mock_task()
+    mock_list.return_value = ([task], 1)
+
+    async with Client(mcp_server) as client:
+        result = await client.call_tool("list_tasks", {})
+
+    data = json.loads(result.content[0].text)
+    assert data["tasks"] is not None
+    assert len(data["tasks"]) == 1
+    assert data["tasks"][0]["id"] == 1
+    assert data["tasks"][0]["task_type"] == "sql_execution"
+    assert data["tasks"][0]["status"] == "success"
+
+
+@patch("superset.daos.tasks.TaskDAO.list")
[email protected]
+async def test_list_tasks_with_status_filter(mock_list, mcp_server):
+    """Status filter is passed through to the DAO correctly."""
+    task = create_mock_task(status="pending")
+    mock_list.return_value = ([task], 1)
+
+    async with Client(mcp_server) as client:
+        request = ListTasksRequest(
+            filters=[{"col": "status", "opr": "eq", "value": "pending"}]
+        )
+        result = await client.call_tool("list_tasks", {"request": 
request.model_dump()})
+
+    data = json.loads(result.content[0].text)
+    assert len(data["tasks"]) == 1
+    assert data["tasks"][0]["status"] == "pending"
+
+
+@patch("superset.daos.tasks.TaskDAO.list")
[email protected]
+async def test_list_tasks_taskfilter_scoping_is_applied(mock_list, mcp_server):
+    """TaskDAO.list is called with base_filter (TaskFilter) applied via DAO 
class."""
+    mock_list.return_value = ([], 0)
+
+    async with Client(mcp_server) as client:
+        await client.call_tool("list_tasks", {})
+
+    # Verify the DAO's list() is called — the TaskFilter scoping is enforced
+    # by TaskDAO.base_filter = TaskFilter which BaseDAO applies automatically.
+    assert mock_list.called
+
+
+@patch("superset.daos.tasks.TaskDAO.list")
[email protected]
+async def test_list_tasks_pagination(mock_list, mcp_server):
+    """Pagination metadata is correct."""
+    tasks = [create_mock_task(task_id=i) for i in range(1, 4)]
+    mock_list.return_value = (tasks, 30)
+
+    async with Client(mcp_server) as client:
+        request = ListTasksRequest(page=2, page_size=3)
+        result = await client.call_tool("list_tasks", {"request": 
request.model_dump()})
+
+    data = json.loads(result.content[0].text)
+    assert data["count"] == 3
+    assert data["total_count"] == 30
+    assert data["page"] == 2
+    assert data["page_size"] == 3
+    assert data["has_previous"] is True
+    assert data["has_next"] is True
+
+
+@patch("superset.daos.tasks.TaskDAO.list")
[email protected]
+async def test_list_tasks_uuid_in_response(mock_list, mcp_server):
+    """Task UUID is returned as a string in the response."""
+    task_uuid = str(uuid.uuid4())
+    task = create_mock_task(task_uuid=task_uuid)
+    mock_list.return_value = ([task], 1)
+
+    async with Client(mcp_server) as client:
+        result = await client.call_tool("list_tasks", {})
+
+    data = json.loads(result.content[0].text)
+    assert data["tasks"][0]["uuid"] == task_uuid
+
+
+@patch("superset.daos.tasks.TaskDAO.find_by_id")
[email protected]
+async def test_get_task_info_by_integer_id(mock_find, mcp_server):
+    """get_task_info resolves a task by integer ID."""
+    task = create_mock_task(task_id=5, task_type="thumbnail", 
status="in_progress")
+    mock_find.return_value = task
+
+    async with Client(mcp_server) as client:
+        result = await client.call_tool("get_task_info", {"request": 
{"identifier": 5}})
+
+    data = json.loads(result.content[0].text)
+    assert data["id"] == 5
+    assert data["task_type"] == "thumbnail"
+    assert data["status"] == "in_progress"
+
+
+@patch("superset.daos.tasks.TaskDAO.find_by_id")
[email protected]
+async def test_get_task_info_by_uuid(mock_find, mcp_server):
+    """get_task_info resolves a task by UUID string."""
+    task_uuid = str(uuid.uuid4())
+    task = create_mock_task(task_id=10, task_uuid=task_uuid, status="success")
+    mock_find.return_value = task
+
+    async with Client(mcp_server) as client:
+        result = await client.call_tool(
+            "get_task_info", {"request": {"identifier": task_uuid}}
+        )
+
+    data = json.loads(result.content[0].text)
+    assert data["id"] == 10
+    assert data["status"] == "success"
+
+
+@patch("superset.daos.tasks.TaskDAO.find_by_id")
[email protected]
+async def test_get_task_info_not_found(mock_find, mcp_server):
+    """get_task_info returns error when task does not exist."""
+    mock_find.return_value = None
+
+    async with Client(mcp_server) as client:
+        result = await client.call_tool(
+            "get_task_info", {"request": {"identifier": 9999}}
+        )
+
+    data = json.loads(result.content[0].text)
+    assert data["error_type"] == "not_found"
+
+
+@patch("superset.daos.tasks.TaskDAO.list")
[email protected]
+async def test_list_tasks_non_admin_sees_only_subscribed(mock_list, 
mcp_server):
+    """Non-admin users only receive tasks their subscriptions allow.
+
+    The subscription scoping is enforced by TaskDAO.base_filter = TaskFilter,
+    which BaseDAO._apply_base_filter injects before each query.  The MCP tool
+    itself adds no extra filtering — it just delegates to TaskDAO.list(), which
+    carries the filter class.  This test confirms that:
+
+    1. list_tasks calls TaskDAO.list() (so the base_filter runs)
+    2. Only items returned by that call appear in the response
+    """
+    # Simulate DAO returning only the subscribed task
+    subscribed_task = create_mock_task(task_id=42, status="pending")
+    mock_list.return_value = ([subscribed_task], 1)
+
+    async with Client(mcp_server) as client:
+        result = await client.call_tool("list_tasks", {})
+
+    data = json.loads(result.content[0].text)
+    assert len(data["tasks"]) == 1
+    assert data["tasks"][0]["id"] == 42
+    # TaskDAO.list was called exactly once — base_filter is applied inside
+    assert mock_list.call_count == 1

Review Comment:
   Acknowledged — same rationale as the `TaskFilter` scoping comment above. The 
non-admin visibility test demonstrates that `list_tasks` surfaces exactly what 
the DAO returns (no additional filtering or leaking added by the tool layer), 
which is the correct unit-level contract. The actual subscription scoping that 
limits what the DAO returns is enforced by `TaskFilter` inside the real DAO and 
belongs in an integration test.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to