aminghadersohi commented on code in PR #36933:
URL: https://github.com/apache/superset/pull/36933#discussion_r2894971316


##########
tests/unit_tests/mcp_service/embedded_chart/tool/test_get_embeddable_chart.py:
##########
@@ -0,0 +1,548 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Unit tests for MCP get_embeddable_chart tool
+"""
+
+from datetime import datetime, timezone
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from superset.mcp_service.chart.schemas import (
+    ColumnRef,
+    FilterConfig,
+    TableChartConfig,
+    XYChartConfig,
+)
+from superset.mcp_service.embedded_chart.schemas import (
+    GetEmbeddableChartRequest,
+    GetEmbeddableChartResponse,
+)
+from superset.mcp_service.embedded_chart.tool.get_embeddable_chart import (
+    _ensure_guest_role_permissions,
+)
+
+
+class TestGetEmbeddableChartSchemas:
+    """Tests for get_embeddable_chart schemas."""
+
+    def test_request_with_xy_config(self):
+        """Test request with XY chart config (same as generate_chart)."""
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="genre"),
+            y=[ColumnRef(name="sales", aggregate="SUM")],
+            kind="bar",
+        )
+        request = GetEmbeddableChartRequest(
+            datasource_id=22,
+            config=config,
+        )
+        assert request.datasource_id == 22
+        assert request.config.chart_type == "xy"
+        assert request.config.x.name == "genre"
+        assert request.config.y[0].aggregate == "SUM"
+        assert request.config.kind == "bar"
+        assert request.ttl_minutes == 60  # default
+        assert request.height == 400  # default
+
+    def test_request_with_table_config(self):
+        """Test request with table chart config."""
+        config = TableChartConfig(
+            chart_type="table",
+            columns=[
+                ColumnRef(name="genre"),
+                ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
+            ],
+        )
+        request = GetEmbeddableChartRequest(
+            datasource_id="dataset-uuid-here",
+            config=config,
+            ttl_minutes=120,
+            height=600,
+        )
+        assert request.datasource_id == "dataset-uuid-here"
+        assert request.config.chart_type == "table"
+        assert len(request.config.columns) == 2
+        assert request.ttl_minutes == 120
+        assert request.height == 600
+
+    def test_request_with_rls_rules(self):
+        """Test request with row-level security rules."""
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="date"),
+            y=[ColumnRef(name="count", aggregate="COUNT")],
+            kind="line",
+        )
+        request = GetEmbeddableChartRequest(
+            datasource_id=123,
+            config=config,
+            rls_rules=[
+                {"dataset": 123, "clause": "tenant_id = 42"},
+                {"dataset": 123, "clause": "region = 'US'"},
+            ],
+        )
+        assert len(request.rls_rules) == 2
+        assert request.rls_rules[0]["clause"] == "tenant_id = 42"
+
+    def test_request_with_allowed_domains(self):
+        """Test request with allowed_domains for iframe security."""
+        config = TableChartConfig(
+            chart_type="table",
+            columns=[ColumnRef(name="col1")],
+        )
+        request = GetEmbeddableChartRequest(
+            datasource_id=1,
+            config=config,
+            allowed_domains=["https://example.com", "https://app.example.com"],
+        )
+        assert len(request.allowed_domains) == 2
+        assert "https://example.com" in request.allowed_domains
+
+    def test_request_ttl_validation(self):
+        """Test TTL minutes validation bounds."""
+        config = TableChartConfig(
+            chart_type="table",
+            columns=[ColumnRef(name="col1")],
+        )
+        # Valid min TTL
+        request = GetEmbeddableChartRequest(
+            datasource_id=1,
+            config=config,
+            ttl_minutes=1,
+        )
+        assert request.ttl_minutes == 1
+
+        # Valid max TTL (1 week)
+        request = GetEmbeddableChartRequest(
+            datasource_id=1,
+            config=config,
+            ttl_minutes=10080,
+        )
+        assert request.ttl_minutes == 10080
+
+        # Invalid TTL should raise
+        with pytest.raises(ValueError, match="greater than or equal to 1"):
+            GetEmbeddableChartRequest(
+                datasource_id=1,
+                config=config,
+                ttl_minutes=0,  # below min
+            )
+
+        with pytest.raises(ValueError, match="less than or equal to 10080"):
+            GetEmbeddableChartRequest(
+                datasource_id=1,
+                config=config,
+                ttl_minutes=10081,  # above max
+            )
+
+    def test_request_height_validation(self):
+        """Test height validation bounds."""
+        config = TableChartConfig(
+            chart_type="table",
+            columns=[ColumnRef(name="col1")],
+        )
+        # Valid min height
+        request = GetEmbeddableChartRequest(
+            datasource_id=1,
+            config=config,
+            height=100,
+        )
+        assert request.height == 100
+
+        # Valid max height
+        request = GetEmbeddableChartRequest(
+            datasource_id=1,
+            config=config,
+            height=2000,
+        )
+        assert request.height == 2000
+
+        # Invalid heights should raise
+        with pytest.raises(ValueError, match="greater than or equal to 100"):
+            GetEmbeddableChartRequest(
+                datasource_id=1,
+                config=config,
+                height=99,  # below min
+            )
+
+        with pytest.raises(ValueError, match="less than or equal to 2000"):
+            GetEmbeddableChartRequest(
+                datasource_id=1,
+                config=config,
+                height=2001,  # above max
+            )

Review Comment:
   Thanks for the catch — this was already fixed in a follow-up commit. The 
test expectations were updated from ValueError to ValidationError to align with 
Pydantic v2's exception handling for constraint violations.



##########
superset/security/manager.py:
##########
@@ -2834,6 +2842,17 @@ def validate_guest_token_resources(resources: GuestTokenResources) -> None:
                     embedded = EmbeddedDashboardDAO.find_by_id(str(resource["id"]))
                     if not embedded:
                         raise EmbeddedDashboardNotFoundError()
+            elif resource["type"] == GuestTokenResourceType.CHART_PERMALINK.value:
+                # Validate that the chart permalink exists
+                permalink_key = str(resource["id"])
+                try:
+                    permalink_value = GetExplorePermalinkCommand(permalink_key).run()
+                    if not permalink_value:
+                        raise EmbeddedChartPermalinkNotFoundError()
+                except EmbeddedChartPermalinkNotFoundError:
+                    raise
+                except Exception:
+                    raise EmbeddedChartPermalinkNotFoundError() from None

Review Comment:
   Thanks for confirming the fix looks good.



##########
superset/security/manager.py:
##########
@@ -2834,6 +2838,17 @@ def validate_guest_token_resources(resources: GuestTokenResources) -> None:
                     embedded = EmbeddedDashboardDAO.find_by_id(str(resource["id"]))
                     if not embedded:
                         raise EmbeddedDashboardNotFoundError()
+            elif resource["type"] == GuestTokenResourceType.CHART_PERMALINK.value:
+                # Validate that the chart permalink exists
+                permalink_key = str(resource["id"])
+                try:
+                    permalink_value = GetExplorePermalinkCommand(permalink_key).run()
+                    if not permalink_value:
+                        raise EmbeddedChartPermalinkNotFoundError()
+                except EmbeddedChartPermalinkNotFoundError:
+                    raise
+                except Exception:
+                    raise EmbeddedChartPermalinkNotFoundError()

Review Comment:
   Appreciate the thorough review. Glad the specific exception handling 
approach aligns well.



##########
superset-frontend/webpack.config.js:
##########
@@ -300,6 +300,7 @@ const config = {
     menu: addPreamble('src/views/menu.tsx'),
     spa: addPreamble('/src/views/index.tsx'),
     embedded: addPreamble('/src/embedded/index.tsx'),
+    embeddedChart: addPreamble('/src/embeddedChart/index.tsx'),
   },

Review Comment:
   Thanks for verifying — correct, no leading slash needed to match the 
existing entries.



##########
superset/embedded_chart/api.py:
##########
@@ -0,0 +1,298 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import logging
+from datetime import datetime, timedelta, timezone
+from typing import Any
+
+from flask import g, request, Response
+from flask_appbuilder.api import expose, protect, safe
+
+from superset.commands.explore.permalink.create import CreateExplorePermalinkCommand
+from superset.daos.key_value import KeyValueDAO
+from superset.embedded_chart.exceptions import (
+    EmbeddedChartAccessDeniedError,
+    EmbeddedChartPermalinkNotFoundError,
+)
+from superset.explore.permalink.schemas import ExplorePermalinkSchema
+from superset.extensions import event_logger, security_manager
+from superset.key_value.shared_entries import get_permalink_salt
+from superset.key_value.types import (
+    KeyValueResource,
+    MarshmallowKeyValueCodec,
+    SharedKey,
+)
+from superset.key_value.utils import decode_permalink_id
+from superset.security.guest_token import (
+    GuestTokenResource,
+    GuestTokenResourceType,
+    GuestTokenRlsRule,
+    GuestTokenUser,
+    GuestUser,
+)
+from superset.views.base_api import BaseSupersetApi, statsd_metrics
+
+logger = logging.getLogger(__name__)
+
+
+class EmbeddedChartRestApi(BaseSupersetApi):
+    """REST API for embedded chart data retrieval."""
+
+    resource_name = "embedded_chart"
+    allow_browser_login = True
+    openapi_spec_tag = "Embedded Chart"
+
+    def _validate_guest_token_access(self, permalink_key: str) -> bool:
+        """
+        Validate that the guest token grants access to this permalink.
+
+        Guest tokens contain a list of resources the user can access.
+        For embedded charts, we check that the permalink_key is in that list.
+        """
+        user = g.user
+        if not isinstance(user, GuestUser):
+            return False
+
+        for resource in user.resources:
+            if (
+                resource.get("type") == GuestTokenResourceType.CHART_PERMALINK.value
+                and str(resource.get("id")) == permalink_key
+            ):
+                return True
+        return False
+
+    def _get_permalink_value(self, permalink_key: str) -> dict[str, Any] | None:
+        """
+        Get permalink value without access checks.
+
+        For embedded charts, access is controlled via guest token validation,
+        so we skip the normal dataset/chart access checks.
+        """
+        # Use the same salt, resource, and codec as the explore permalink command
+        salt = get_permalink_salt(SharedKey.EXPLORE_PERMALINK_SALT)
+        codec = MarshmallowKeyValueCodec(ExplorePermalinkSchema())
+        key = decode_permalink_id(permalink_key, salt=salt)
+        return KeyValueDAO.get_value(
+            KeyValueResource.EXPLORE_PERMALINK,
+            key,
+            codec,
+        )
+
+    @expose("/<permalink_key>", methods=("GET",))
+    @safe
+    @statsd_metrics
+    @event_logger.log_this_with_context(
+        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.get",
+        log_to_statsd=False,
+    )
+    def get(self, permalink_key: str) -> Response:
+        """Get chart form_data from permalink key.
+        ---
+        get:
+          summary: Get embedded chart configuration
+          description: >-
+            Retrieves the form_data for rendering an embedded chart.
+            This endpoint is used by the embedded chart iframe to load
+            the chart configuration.
+          parameters:
+          - in: path
+            schema:
+              type: string
+            name: permalink_key
+            description: The chart permalink key
+            required: true
+          responses:
+            200:
+              description: Chart permalink state
+              content:
+                application/json:
+                  schema:
+                    type: object
+                    properties:
+                      state:
+                        type: object
+                        properties:
+                          formData:
+                            type: object
+                            description: The chart configuration formData
+                          allowedDomains:
+                            type: array
+                            items:
+                              type: string
+                            description: Domains allowed to embed this chart
+            401:
+              $ref: '#/components/responses/401'
+            404:
+              $ref: '#/components/responses/404'
+            500:
+              $ref: '#/components/responses/500'
+        """
+        try:
+            # Validate guest token grants access to this permalink
+            if not self._validate_guest_token_access(permalink_key):
+                raise EmbeddedChartAccessDeniedError()
+
+            # Get permalink value without access checks (guest token already validated)
+            permalink_value = self._get_permalink_value(permalink_key)
+            if not permalink_value:
+                raise EmbeddedChartPermalinkNotFoundError()
+
+            # Return state in the format expected by the frontend:
+            # { state: { formData: {...}, allowedDomains: [...] } }
+            state = permalink_value.get("state", {})
+
+            return self.response(
+                200,
+                state=state,
+            )
+        except EmbeddedChartAccessDeniedError:
+            return self.response_401()
+        except EmbeddedChartPermalinkNotFoundError:
+            return self.response_404()
+        except Exception:
+            logger.exception("Error fetching embedded chart")
+            return self.response_500()
+
+    @expose("/", methods=("POST",))
+    @protect()
+    @safe
+    @statsd_metrics
+    @event_logger.log_this_with_context(
+        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.post",
+        log_to_statsd=False,
+    )
+    def post(self) -> Response:
+        """Create an embeddable chart with guest token.
+        ---
+        post:
+          summary: Create embeddable chart
+          description: >-
+            Creates an embeddable chart configuration with a guest token.
+            The returned iframe_url and guest_token can be used to embed
+            the chart in external applications.
+          requestBody:
+            required: true
+            content:
+              application/json:
+                schema:
+                  type: object
+                  required:
+                    - form_data
+                  properties:
+                    form_data:
+                      type: object
+                      description: Chart form_data configuration
+                    allowed_domains:
+                      type: array
+                      items:
+                        type: string
+                      description: Domains allowed to embed this chart
+                    ttl_minutes:
+                      type: integer
+                      default: 60
+                      description: Time-to-live for the embed in minutes
+          responses:
+            200:
+              description: Embeddable chart created
+              content:
+                application/json:
+                  schema:
+                    type: object
+                    properties:
+                      iframe_url:
+                        type: string
+                        description: URL to use in iframe src
+                      guest_token:
+                        type: string
+                        description: Guest token for authentication
+                      permalink_key:
+                        type: string
+                        description: Permalink key for the chart
+                      expires_at:
+                        type: string
+                        format: date-time
+                        description: When the embed expires
+            400:
+              $ref: '#/components/responses/400'
+            401:
+              $ref: '#/components/responses/401'
+            500:
+              $ref: '#/components/responses/500'
+        """
+        try:
+            body = request.json or {}
+            form_data = body.get("form_data", {})
+            allowed_domains: list[str] = body.get("allowed_domains", [])
+            ttl_minutes: int = body.get("ttl_minutes", 60)
+
+            if not form_data:

Review Comment:
   Thanks for confirming the validation logic looks correct.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to