This is an automated email from the ASF dual-hosted git repository.

arivero pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 5cd829f13c8 fix(mcp): handle more chart types in get_chart_data fallback query construction (#37969)
5cd829f13c8 is described below

commit 5cd829f13c8f7f71858fcbe761e636fa84d16f21
Author: Amin Ghadersohi <[email protected]>
AuthorDate: Tue Feb 17 07:02:42 2026 -0500

    fix(mcp): handle more chart types in get_chart_data fallback query construction (#37969)
    
    Co-authored-by: Claude Opus 4.6 <[email protected]>
---
 superset/mcp_service/chart/tool/get_chart_data.py  | 107 ++++-
 .../mcp_service/chart/tool/test_get_chart_data.py  | 437 ++++++++++++++++++---
 2 files changed, 486 insertions(+), 58 deletions(-)

diff --git a/superset/mcp_service/chart/tool/get_chart_data.py b/superset/mcp_service/chart/tool/get_chart_data.py
index a99df31dd94..327052b423f 100644
--- a/superset/mcp_service/chart/tool/get_chart_data.py
+++ b/superset/mcp_service/chart/tool/get_chart_data.py
@@ -165,21 +165,90 @@ async def get_chart_data(  # noqa: C901
                     or current_app.config["ROW_LIMIT"]
                 )
 
-                # Handle different chart types that have different form_data structures
-                # Some charts use "metric" (singular), not "metrics" (plural):
-                # - big_number, big_number_total
-                # - pop_kpi (BigNumberPeriodOverPeriod)
-                # These charts also don't have groupby columns
+                # Handle different chart types that have different form_data
+                # structures.  Chart types that exclusively use "metric"
+                # (singular) with no groupby:
+                #   big_number, big_number_total, pop_kpi
+                # Chart types that use "metric" (singular) but may have
+                # groupby-like fields (entity, series, columns):
+                #   world_map, treemap_v2, sunburst_v2, gauge_chart
+                # Bubble charts use x/y/size as separate metric fields.
                 viz_type = chart.viz_type or ""
-                if viz_type in ("big_number", "big_number_total", "pop_kpi"):
+
+                singular_metric_no_groupby = (
+                    "big_number",
+                    "big_number_total",
+                    "pop_kpi",
+                )
+                singular_metric_types = (
+                    *singular_metric_no_groupby,
+                    "world_map",
+                    "treemap_v2",
+                    "sunburst_v2",
+                    "gauge_chart",
+                )
+
+                if viz_type == "bubble":
+                    # Bubble charts store metrics in x, y, size fields
+                    bubble_metrics = []
+                    for field in ("x", "y", "size"):
+                        m = form_data.get(field)
+                        if m:
+                            bubble_metrics.append(m)
+                    metrics = bubble_metrics
+                    groupby_columns: list[str] = list(
+                        form_data.get("entity", None) and [form_data["entity"]] or []
+                    )
+                    series_field = form_data.get("series")
+                    if series_field and series_field not in groupby_columns:
+                        groupby_columns.append(series_field)
+                elif viz_type in singular_metric_types:
                     # These chart types use "metric" (singular)
                     metric = form_data.get("metric")
                     metrics = [metric] if metric else []
-                    groupby_columns: list[str] = []  # These charts don't group by
+                    if viz_type in singular_metric_no_groupby:
+                        groupby_columns = []
+                    else:
+                        # Some singular-metric charts use groupby, entity,
+                        # series, or columns for dimensional breakdown
+                        groupby_columns = list(form_data.get("groupby") or [])
+                        entity = form_data.get("entity")
+                        if entity and entity not in groupby_columns:
+                            groupby_columns.append(entity)
+                        series = form_data.get("series")
+                        if series and series not in groupby_columns:
+                            groupby_columns.append(series)
+                        form_columns = form_data.get("columns")
+                        if form_columns and isinstance(form_columns, list):
+                            for col in form_columns:
+                                if isinstance(col, str) and col not in groupby_columns:
+                                    groupby_columns.append(col)
                 else:
                     # Standard charts use "metrics" (plural) and "groupby"
                     metrics = form_data.get("metrics", [])
-                    groupby_columns = form_data.get("groupby") or []
+                    groupby_columns = list(form_data.get("groupby") or [])
+                    # Some chart types use "columns" instead of "groupby"
+                    if not groupby_columns:
+                        form_columns = form_data.get("columns")
+                        if form_columns and isinstance(form_columns, list):
+                            for col in form_columns:
+                                if isinstance(col, str):
+                                    groupby_columns.append(col)
+
+                # Fallback: if metrics is still empty, try singular "metric"
+                if not metrics:
+                    fallback_metric = form_data.get("metric")
+                    if fallback_metric:
+                        metrics = [fallback_metric]
+
+                # Fallback: try entity/series if groupby is still empty
+                if not groupby_columns:
+                    entity = form_data.get("entity")
+                    if entity:
+                        groupby_columns.append(entity)
+                    series = form_data.get("series")
+                    if series and series not in groupby_columns:
+                        groupby_columns.append(series)
 
                 # Build query columns list: include both x_axis and groupby
                 x_axis_config = form_data.get("x_axis")
@@ -192,6 +261,28 @@ async def get_chart_data(  # noqa: C901
                     if col_name and col_name not in query_columns:
                         query_columns.insert(0, col_name)
 
+                # Safety net: if we could not extract any metrics or
+                # columns, return a clear error instead of the cryptic
+                # "Empty query?" that comes from deeper in the stack.
+                if not metrics and not query_columns:
+                    await ctx.error(
+                        "Cannot construct fallback query for chart %s "
+                        "(viz_type=%s): no metrics, columns, or groupby "
+                        "could be extracted from form_data. "
+                        "Re-save the chart to populate query_context."
+                        % (chart.id, viz_type)
+                    )
+                    return ChartError(
+                        error=(
+                            f"Chart {chart.id} (type: {viz_type}) has no "
+                            f"saved query_context and its form_data does "
+                            f"not contain recognizable metrics or columns. "
+                            f"Please open this chart in Superset and "
+                            f"re-save it to generate a query_context."
+                        ),
+                        error_type="MissingQueryContext",
+                    )
+
                 query_context = factory.create(
                     datasource={
                         "id": chart.datasource_id,
diff --git a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py
index d276691425f..7850ef82a9a 100644
--- a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py
+++ b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py
@@ -16,28 +16,114 @@
 # under the License.
 
 """
-Tests for the get_chart_data request schema and big_number chart handling.
+Tests for the get_chart_data request schema and chart type fallback handling.
 """
 
+from typing import Any
+
 import pytest
 
 from superset.mcp_service.chart.schemas import GetChartDataRequest
 
 
+def _collect_groupby_extras(
+    form_data: dict[str, Any],
+    groupby_columns: list[str],
+) -> None:
+    """Append entity/series/columns from form_data into groupby_columns."""
+    entity = form_data.get("entity")
+    if entity and entity not in groupby_columns:
+        groupby_columns.append(entity)
+    series = form_data.get("series")
+    if series and series not in groupby_columns:
+        groupby_columns.append(series)
+    form_columns = form_data.get("columns")
+    if form_columns and isinstance(form_columns, list):
+        for col in form_columns:
+            if isinstance(col, str) and col not in groupby_columns:
+                groupby_columns.append(col)
+
+
+def _extract_bubble(
+    form_data: dict[str, Any],
+) -> tuple[list[Any], list[str]]:
+    """Extract metrics and groupby for bubble charts."""
+    metrics: list[Any] = []
+    for field in ("x", "y", "size"):
+        m = form_data.get(field)
+        if m:
+            metrics.append(m)
+    entity = form_data.get("entity")
+    groupby: list[str] = [entity] if entity else []
+    series_field = form_data.get("series")
+    if series_field and series_field not in groupby:
+        groupby.append(series_field)
+    return metrics, groupby
+
+
+_SINGULAR_METRIC_NO_GROUPBY = (
+    "big_number",
+    "big_number_total",
+    "pop_kpi",
+)
+_SINGULAR_METRIC_TYPES = (
+    *_SINGULAR_METRIC_NO_GROUPBY,
+    "world_map",
+    "treemap_v2",
+    "sunburst_v2",
+    "gauge_chart",
+)
+
+
+def _extract_metrics_and_groupby(
+    form_data: dict[str, Any],
+) -> tuple[list[Any], list[str]]:
+    """Mirror the fallback metric/groupby extraction logic from get_chart_data.py."""
+    viz_type = form_data.get("viz_type", "")
+
+    groupby_columns: list[str]
+    if viz_type == "bubble":
+        metrics, groupby_columns = _extract_bubble(form_data)
+    elif viz_type in _SINGULAR_METRIC_TYPES:
+        metric = form_data.get("metric")
+        metrics = [metric] if metric else []
+        if viz_type in _SINGULAR_METRIC_NO_GROUPBY:
+            groupby_columns = []
+        else:
+            groupby_columns = list(form_data.get("groupby") or [])
+            _collect_groupby_extras(form_data, groupby_columns)
+    else:
+        metrics = form_data.get("metrics", [])
+        groupby_columns = list(form_data.get("groupby") or [])
+        if not groupby_columns:
+            form_columns = form_data.get("columns")
+            if form_columns and isinstance(form_columns, list):
+                groupby_columns = [c for c in form_columns if isinstance(c, str)]
+
+    # Fallback: try singular metric if metrics still empty
+    if not metrics:
+        fallback_metric = form_data.get("metric")
+        if fallback_metric:
+            metrics = [fallback_metric]
+
+    # Fallback: try entity/series if groupby still empty
+    if not groupby_columns:
+        _collect_groupby_extras(form_data, groupby_columns)
+
+    return metrics, groupby_columns
+
+
 class TestBigNumberChartFallback:
     """Tests for big_number chart fallback query construction."""
 
     def test_big_number_uses_singular_metric(self):
         """Test that big_number charts use 'metric' (singular) from form_data."""
-        # Mock form_data for big_number chart
         form_data = {
             "metric": {"label": "Count", "expressionType": "SIMPLE", "column": None},
             "viz_type": "big_number",
         }
 
-        # Verify the metric extraction logic
-        metric = form_data.get("metric")
-        metrics = [metric] if metric else []
+        metrics, groupby = _extract_metrics_and_groupby(form_data)
 
         assert len(metrics) == 1
         assert metrics[0]["label"] == "Count"
@@ -49,8 +135,7 @@ class TestBigNumberChartFallback:
             "viz_type": "big_number_total",
         }
 
-        metric = form_data.get("metric")
-        metrics = [metric] if metric else []
+        metrics, groupby = _extract_metrics_and_groupby(form_data)
 
         assert len(metrics) == 1
         assert metrics[0]["label"] == "Total Sales"
@@ -62,8 +147,7 @@ class TestBigNumberChartFallback:
             "viz_type": "big_number",
         }
 
-        metric = form_data.get("metric")
-        metrics = [metric] if metric else []
+        metrics, groupby = _extract_metrics_and_groupby(form_data)
 
         assert len(metrics) == 0
 
@@ -75,13 +159,9 @@ class TestBigNumberChartFallback:
             "groupby": ["should_be_ignored"],  # This should be ignored
         }
 
-        viz_type = form_data.get("viz_type", "")
-        if viz_type.startswith("big_number"):
-            groupby_columns: list[str] = []  # big_number charts don't group by
-        else:
-            groupby_columns = form_data.get("groupby", [])
+        metrics, groupby = _extract_metrics_and_groupby(form_data)
 
-        assert groupby_columns == []
+        assert groupby == []
 
     def test_standard_chart_uses_plural_metrics(self):
         """Test that non-big_number charts use 'metrics' (plural)."""
@@ -94,41 +174,43 @@ class TestBigNumberChartFallback:
             "viz_type": "table",
         }
 
-        viz_type = form_data.get("viz_type", "")
-        if viz_type.startswith("big_number"):
-            metric = form_data.get("metric")
-            metrics = [metric] if metric else []
-            groupby_columns: list[str] = []
-        else:
-            metrics = form_data.get("metrics", [])
-            groupby_columns = form_data.get("groupby", [])
+        metrics, groupby = _extract_metrics_and_groupby(form_data)
 
         assert len(metrics) == 2
-        assert len(groupby_columns) == 2
+        assert len(groupby) == 2
 
     def test_viz_type_detection_for_single_metric_charts(self):
         """Test viz_type detection handles all single-metric chart types."""
-        # Chart types that use "metric" (singular) instead of "metrics" (plural)
-        single_metric_types = ("big_number", "pop_kpi")
-
-        # big_number variants match via startswith
-        big_number_types = ["big_number", "big_number_total"]
-        for viz_type in big_number_types:
-            is_single_metric = (
-                viz_type.startswith("big_number") or viz_type in single_metric_types
-            )
-            assert is_single_metric is True
+        singular_metric_types = (
+            "big_number",
+            "big_number_total",
+            "pop_kpi",
+            "world_map",
+            "treemap_v2",
+            "sunburst_v2",
+            "gauge_chart",
+        )
 
-        # pop_kpi (BigNumberPeriodOverPeriod) matches via exact match
-        assert "pop_kpi" in single_metric_types
+        for viz_type in singular_metric_types:
+            form_data = {
+                "metric": {"label": "test_metric"},
+                "viz_type": viz_type,
+            }
+            metrics, _ = _extract_metrics_and_groupby(form_data)
+            assert len(metrics) == 1, f"{viz_type} should extract singular metric"
 
-        # Verify standard chart types don't match
+        # Verify standard chart types don't use singular metric path
         other_types = ["table", "line", "bar", "pie", "echarts_timeseries"]
         for viz_type in other_types:
-            is_single_metric = (
-                viz_type.startswith("big_number") or viz_type in single_metric_types
+            form_data = {
+                "metric": {"label": "should_be_ignored"},
+                "metrics": [{"label": "plural_metric"}],
+                "viz_type": viz_type,
+            }
+            metrics, _ = _extract_metrics_and_groupby(form_data)
+            assert metrics == [{"label": "plural_metric"}], (
+                f"{viz_type} should use plural metrics"
             )
-            assert is_single_metric is False
 
     def test_pop_kpi_uses_singular_metric(self):
         """Test that pop_kpi (BigNumberPeriodOverPeriod) uses singular metric."""
@@ -137,19 +219,274 @@ class TestBigNumberChartFallback:
             "viz_type": "pop_kpi",
         }
 
-        viz_type = form_data.get("viz_type", "")
-        single_metric_types = ("big_number", "pop_kpi")
-        if viz_type.startswith("big_number") or viz_type in single_metric_types:
-            metric = form_data.get("metric")
-            metrics = [metric] if metric else []
-            groupby_columns: list[str] = []
-        else:
-            metrics = form_data.get("metrics", [])
-            groupby_columns = form_data.get("groupby", [])
+        metrics, groupby = _extract_metrics_and_groupby(form_data)
 
         assert len(metrics) == 1
         assert metrics[0]["label"] == "Period Comparison"
-        assert groupby_columns == []
+        assert groupby == []
+
+
+class TestWorldMapChartFallback:
+    """Tests for world_map chart fallback query construction."""
+
+    def test_world_map_uses_singular_metric(self):
+        """Test that world_map charts use 'metric' (singular)."""
+        form_data = {
+            "metric": {"label": "Population"},
+            "entity": "country_code",
+            "viz_type": "world_map",
+        }
+
+        metrics, groupby = _extract_metrics_and_groupby(form_data)
+
+        assert len(metrics) == 1
+        assert metrics[0]["label"] == "Population"
+
+    def test_world_map_extracts_entity_as_groupby(self):
+        """Test that world_map entity field becomes groupby."""
+        form_data = {
+            "metric": {"label": "Population"},
+            "entity": "country_code",
+            "viz_type": "world_map",
+        }
+
+        metrics, groupby = _extract_metrics_and_groupby(form_data)
+
+        assert "country_code" in groupby
+
+    def test_world_map_extracts_series(self):
+        """Test that world_map series field is added to groupby."""
+        form_data = {
+            "metric": {"label": "Population"},
+            "entity": "country_code",
+            "series": "region",
+            "viz_type": "world_map",
+        }
+
+        metrics, groupby = _extract_metrics_and_groupby(form_data)
+
+        assert "country_code" in groupby
+        assert "region" in groupby
+
+
+class TestTreemapAndSunburstFallback:
+    """Tests for treemap_v2 and sunburst_v2 chart fallback query construction."""
+
+    def test_treemap_v2_uses_singular_metric(self):
+        """Test that treemap_v2 charts use 'metric' (singular)."""
+        form_data = {
+            "metric": {"label": "Revenue"},
+            "groupby": ["category", "sub_category"],
+            "viz_type": "treemap_v2",
+        }
+
+        metrics, groupby = _extract_metrics_and_groupby(form_data)
+
+        assert len(metrics) == 1
+        assert metrics[0]["label"] == "Revenue"
+        assert groupby == ["category", "sub_category"]
+
+    def test_sunburst_v2_uses_singular_metric(self):
+        """Test that sunburst_v2 charts use 'metric' (singular)."""
+        form_data = {
+            "metric": {"label": "Count"},
+            "columns": ["level1", "level2", "level3"],
+            "viz_type": "sunburst_v2",
+        }
+
+        metrics, groupby = _extract_metrics_and_groupby(form_data)
+
+        assert len(metrics) == 1
+        assert metrics[0]["label"] == "Count"
+        # columns should be picked up as groupby alternatives
+        assert "level1" in groupby
+        assert "level2" in groupby
+        assert "level3" in groupby
+
+    def test_treemap_with_columns_field(self):
+        """Test that treemap_v2 uses columns field when groupby is missing."""
+        form_data = {
+            "metric": {"label": "Revenue"},
+            "columns": ["region", "product"],
+            "viz_type": "treemap_v2",
+        }
+
+        metrics, groupby = _extract_metrics_and_groupby(form_data)
+
+        assert len(metrics) == 1
+        assert "region" in groupby
+        assert "product" in groupby
+
+
+class TestGaugeChartFallback:
+    """Tests for gauge_chart fallback query construction."""
+
+    def test_gauge_chart_uses_singular_metric(self):
+        """Test that gauge_chart uses 'metric' (singular)."""
+        form_data = {
+            "metric": {"label": "Completion %"},
+            "viz_type": "gauge_chart",
+        }
+
+        metrics, groupby = _extract_metrics_and_groupby(form_data)
+
+        assert len(metrics) == 1
+        assert metrics[0]["label"] == "Completion %"
+
+    def test_gauge_chart_with_groupby(self):
+        """Test that gauge_chart respects groupby if present."""
+        form_data = {
+            "metric": {"label": "Completion %"},
+            "groupby": ["department"],
+            "viz_type": "gauge_chart",
+        }
+
+        metrics, groupby = _extract_metrics_and_groupby(form_data)
+
+        assert len(metrics) == 1
+        assert groupby == ["department"]
+
+
+class TestBubbleChartFallback:
+    """Tests for bubble chart fallback query construction."""
+
+    def test_bubble_extracts_x_y_size_as_metrics(self):
+        """Test that bubble charts extract x, y, size as separate metrics."""
+        form_data = {
+            "x": {"label": "GDP"},
+            "y": {"label": "Life Expectancy"},
+            "size": {"label": "Population"},
+            "entity": "country",
+            "viz_type": "bubble",
+        }
+
+        metrics, groupby = _extract_metrics_and_groupby(form_data)
+
+        assert len(metrics) == 3
+        assert metrics[0]["label"] == "GDP"
+        assert metrics[1]["label"] == "Life Expectancy"
+        assert metrics[2]["label"] == "Population"
+
+    def test_bubble_extracts_entity_as_groupby(self):
+        """Test that bubble charts use entity as groupby."""
+        form_data = {
+            "x": {"label": "GDP"},
+            "y": {"label": "Life Expectancy"},
+            "size": {"label": "Population"},
+            "entity": "country",
+            "viz_type": "bubble",
+        }
+
+        metrics, groupby = _extract_metrics_and_groupby(form_data)
+
+        assert "country" in groupby
+
+    def test_bubble_extracts_series(self):
+        """Test that bubble charts include series in groupby."""
+        form_data = {
+            "x": {"label": "GDP"},
+            "y": {"label": "Life Expectancy"},
+            "size": {"label": "Population"},
+            "entity": "country",
+            "series": "continent",
+            "viz_type": "bubble",
+        }
+
+        metrics, groupby = _extract_metrics_and_groupby(form_data)
+
+        assert "country" in groupby
+        assert "continent" in groupby
+
+    def test_bubble_partial_metrics(self):
+        """Test bubble chart with only some metric fields set."""
+        form_data = {
+            "x": {"label": "GDP"},
+            "y": None,
+            "size": {"label": "Population"},
+            "entity": "country",
+            "viz_type": "bubble",
+        }
+
+        metrics, groupby = _extract_metrics_and_groupby(form_data)
+
+        assert len(metrics) == 2
+        labels = [m["label"] for m in metrics]
+        assert "GDP" in labels
+        assert "Population" in labels
+
+
+class TestFallbackMetricExtraction:
+    """Tests for the fallback singular metric extraction."""
+
+    def test_standard_chart_falls_back_to_singular_metric(self):
+        """Test that standard charts try singular metric if plural is empty."""
+        form_data = {
+            "metric": {"label": "Fallback Metric"},
+            "metrics": [],
+            "groupby": ["region"],
+            "viz_type": "bar",
+        }
+
+        metrics, groupby = _extract_metrics_and_groupby(form_data)
+
+        assert len(metrics) == 1
+        assert metrics[0]["label"] == "Fallback Metric"
+
+    def test_standard_chart_no_metrics_at_all(self):
+        """Test standard chart with neither metrics nor metric."""
+        form_data = {
+            "groupby": ["region"],
+            "viz_type": "bar",
+        }
+
+        metrics, groupby = _extract_metrics_and_groupby(form_data)
+
+        assert len(metrics) == 0
+        assert groupby == ["region"]
+
+    def test_standard_chart_uses_columns_as_groupby_fallback(self):
+        """Test that standard charts use columns field when groupby is empty."""
+        form_data = {
+            "metrics": [{"label": "Count"}],
+            "columns": ["col_a", "col_b"],
+            "viz_type": "table",
+        }
+
+        metrics, groupby = _extract_metrics_and_groupby(form_data)
+
+        assert len(metrics) == 1
+        assert "col_a" in groupby
+        assert "col_b" in groupby
+
+    def test_entity_series_fallback_for_unknown_chart(self):
+        """Test that entity/series are used as groupby fallback."""
+        form_data = {
+            "metric": {"label": "Some Metric"},
+            "entity": "name_col",
+            "series": "type_col",
+            "viz_type": "some_unknown_type",
+        }
+
+        metrics, groupby = _extract_metrics_and_groupby(form_data)
+
+        assert len(metrics) == 1
+        assert "name_col" in groupby
+        assert "type_col" in groupby
+
+
+class TestSafetyNetEmptyQuery:
+    """Tests for the safety net when no metrics/columns can be extracted."""
+
+    def test_completely_empty_form_data_yields_empty(self):
+        """Test that form_data with nothing extractable returns empty."""
+        form_data = {
+            "viz_type": "mystery_chart",
+        }
+
+        metrics, groupby = _extract_metrics_and_groupby(form_data)
+
+        assert metrics == []
+        assert groupby == []
 
 
 class TestXAxisInQueryContext:

Reply via email to