github-advanced-security[bot] commented on code in PR #35841: URL: https://github.com/apache/superset/pull/35841#discussion_r2462291698
########## superset/mcp_service/chart/schemas.py: ########## @@ -0,0 +1,1075 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Pydantic schemas for chart-related responses +""" + +from __future__ import annotations + +import html +import re +from datetime import datetime, timezone +from typing import Annotated, Any, Dict, List, Literal, Protocol + +from pydantic import ( + BaseModel, + ConfigDict, + Field, + field_validator, + model_validator, + PositiveInt, +) + +from superset.daos.base import ColumnOperator, ColumnOperatorEnum +from superset.mcp_service.common.cache_schemas import ( + CacheStatus, + FormDataCacheControl, + MetadataCacheControl, + QueryCacheControl, +) +from superset.mcp_service.system.schemas import ( + PaginationInfo, + TagInfo, + UserInfo, +) + + +class ChartLike(Protocol): + """Protocol for chart-like objects with expected attributes.""" + + id: int + slice_name: str | None + viz_type: str | None + datasource_name: str | None + datasource_type: str | None + url: str | None + description: str | None + cache_timeout: int | None + form_data: Dict[str, Any] | None + query_context: Any | None + changed_by: Any | None # User object + changed_by_name: str | None + changed_on: str | datetime | None + changed_on_humanized: str | None + created_by: Any | None # User object + created_by_name: str | None + created_on: str | datetime | None + created_on_humanized: str | None + uuid: str | None + tags: List[Any] | None + owners: List[Any] | None + + +class ChartInfo(BaseModel): + """Full chart model with all possible attributes.""" + + id: int = Field(..., description="Chart ID") + slice_name: str = Field(..., description="Chart name") + viz_type: str | None = Field(None, description="Visualization type") + datasource_name: str | None = Field(None, description="Datasource name") + datasource_type: str | None = Field(None, description="Datasource type") + url: str | None = Field(None, description="Chart URL") + description: str | None = Field(None, description="Chart description") + cache_timeout: int | None = Field(None, description="Cache timeout") + form_data: Dict[str, Any] | None = Field(None, description="Chart form data") + query_context: Any | None = Field(None, description="Query context") + changed_by: str | None = Field(None, description="Last modifier (username)") + changed_by_name: str | None = Field( + None, description="Last modifier (display name)" + ) + changed_on: str | datetime | None = Field( + None, description="Last modification timestamp" + ) + changed_on_humanized: str | None = Field( + None, description="Humanized modification time" + ) + created_by: str | None = Field(None, description="Chart creator (username)") + created_on: str | datetime | None = Field(None, description="Creation timestamp") + 
created_on_humanized: str | None = Field( + None, description="Humanized creation time" + ) + uuid: str | None = Field(None, description="Chart UUID") + tags: List[TagInfo] = Field(default_factory=list, description="Chart tags") + owners: List[UserInfo] = Field(default_factory=list, description="Chart owners") + model_config = ConfigDict(from_attributes=True, ser_json_timedelta="iso8601") + + +class GetChartAvailableFiltersRequest(BaseModel): + """ + Request schema for get_chart_available_filters tool. + + Currently has no parameters but provides consistent API for future extensibility. + """ + + model_config = ConfigDict( + extra="forbid", + str_strip_whitespace=True, + ) + + +class ChartAvailableFiltersResponse(BaseModel): + column_operators: Dict[str, Any] = Field( + ..., description="Available filter operators and metadata for each column" + ) + + +class ChartError(BaseModel): + error: str = Field(..., description="Error message") + error_type: str = Field(..., description="Type of error") + timestamp: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc), + description="Error timestamp", + ) + model_config = ConfigDict(ser_json_timedelta="iso8601") + + +class ChartCapabilities(BaseModel): + """Describes what the chart can do for LLM understanding.""" + + supports_interaction: bool = Field(description="Chart supports user interaction") + supports_real_time: bool = Field(description="Chart supports live data updates") + supports_drill_down: bool = Field( + description="Chart supports drill-down navigation" + ) + supports_export: bool = Field(description="Chart can be exported to other formats") + optimal_formats: List[str] = Field(description="Recommended preview formats") + data_types: List[str] = Field( + description="Types of data shown (time_series, categorical, etc)" + ) + + +class ChartSemantics(BaseModel): + """Semantic information for LLM reasoning.""" + + primary_insight: str = Field( + description="Main insight or pattern the chart reveals" + ) + data_story: str = Field(description="Narrative description of what the data shows") + recommended_actions: List[str] = Field( + description="Suggested next steps based on data" + ) + anomalies: List[str] = Field(description="Notable outliers or unusual patterns") + statistical_summary: Dict[str, Any] = Field( + description="Key statistics (mean, median, trends)" + ) + + +class PerformanceMetadata(BaseModel): + """Performance information for LLM cost understanding.""" + + query_duration_ms: int = Field(description="Query execution time") + estimated_cost: str | None = Field(None, description="Resource cost estimate") + cache_status: str = Field(description="Cache hit/miss status") + optimization_suggestions: List[str] = Field( + default_factory=list, description="Performance improvement tips" + ) + + +class AccessibilityMetadata(BaseModel): + """Accessibility information for inclusive visualization.""" + + color_blind_safe: bool = Field(description="Uses colorblind-safe palette") + alt_text: str = Field(description="Screen reader description") + high_contrast_available: bool = Field(description="High contrast version available") + + +class VersionedResponse(BaseModel): + """Base class for versioned API responses.""" + + schema_version: str = Field("2.0", description="Response schema version") + api_version: str = Field("v1", description="MCP API version") + + +class GetChartInfoRequest(BaseModel): + """Request schema for get_chart_info with support for ID or UUID.""" + + identifier: Annotated[ + int | str, + 
Field(description="Chart identifier - can be numeric ID or UUID string"), + ] + + +def serialize_chart_object(chart: ChartLike | None) -> ChartInfo | None: + if not chart: + return None + + # Generate MCP service screenshot URL instead of chart's native URL + from superset.mcp_service.utils.url_utils import get_chart_screenshot_url + + chart_id = getattr(chart, "id", None) + screenshot_url = None + if chart_id: + screenshot_url = get_chart_screenshot_url(chart_id) + + return ChartInfo( + id=chart_id, + slice_name=getattr(chart, "slice_name", None), + viz_type=getattr(chart, "viz_type", None), + datasource_name=getattr(chart, "datasource_name", None), + datasource_type=getattr(chart, "datasource_type", None), + url=screenshot_url, + description=getattr(chart, "description", None), + cache_timeout=getattr(chart, "cache_timeout", None), + form_data=getattr(chart, "form_data", None), + query_context=getattr(chart, "query_context", None), + changed_by=getattr(chart, "changed_by_name", None) + or (str(chart.changed_by) if getattr(chart, "changed_by", None) else None), + changed_by_name=getattr(chart, "changed_by_name", None), + changed_on=getattr(chart, "changed_on", None), + changed_on_humanized=getattr(chart, "changed_on_humanized", None), + created_by=getattr(chart, "created_by_name", None) + or (str(chart.created_by) if getattr(chart, "created_by", None) else None), + created_on=getattr(chart, "created_on", None), + created_on_humanized=getattr(chart, "created_on_humanized", None), + uuid=str(getattr(chart, "uuid", "")) if getattr(chart, "uuid", None) else None, + tags=[ + TagInfo.model_validate(tag, from_attributes=True) + for tag in getattr(chart, "tags", []) + ] + if getattr(chart, "tags", None) + else [], + owners=[ + UserInfo.model_validate(owner, from_attributes=True) + for owner in getattr(chart, "owners", []) + ] + if getattr(chart, "owners", None) + else [], + ) + + +class ChartFilter(ColumnOperator): + """ + Filter object for chart listing. + col: The column to filter on. Must be one of the allowed filter fields. + opr: The operator to use. Must be one of the supported operators. + value: The value to filter by (type depends on col and opr). + """ + + col: Literal[ + "slice_name", + "viz_type", + "datasource_name", + ] = Field( + ..., + description="Column to filter on. See get_chart_available_filters for " + "allowed values.", + ) + opr: ColumnOperatorEnum = Field( + ..., + description="Operator to use. 
See get_chart_available_filters for " + "allowed values.", + ) + value: str | int | float | bool | List[str | int | float | bool] = Field( + ..., description="Value to filter by (type depends on col and opr)" + ) + + +class ChartList(BaseModel): + charts: List[ChartInfo] + count: int + total_count: int + page: int + page_size: int + total_pages: int + has_previous: bool + has_next: bool + columns_requested: List[str] | None = None + columns_loaded: List[str] | None = None + filters_applied: List[ChartFilter] = Field( + default_factory=list, + description="List of advanced filter dicts applied to the query.", + ) + pagination: PaginationInfo | None = None + timestamp: datetime | None = None + model_config = ConfigDict(ser_json_timedelta="iso8601") + + +# --- Simplified schemas for generate_chart tool --- + + +# Common pieces +class ColumnRef(BaseModel): + name: str = Field( + ..., + description="Column name", + min_length=1, + max_length=255, + pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$", + ) + label: str | None = Field( + None, description="Display label for the column", max_length=500 + ) + dtype: str | None = Field(None, description="Data type hint") + aggregate: ( + Literal[ + "SUM", + "COUNT", + "AVG", + "MIN", + "MAX", + "COUNT_DISTINCT", + "STDDEV", + "VAR", + "MEDIAN", + "PERCENTILE", + ] + | None + ) = Field( + None, + description="SQL aggregation function. Only these validated functions are " + "supported to prevent SQL errors.", + ) + + @field_validator("name") + @classmethod + def sanitize_name(cls, v: str) -> str: + """Sanitize column name to prevent XSS and SQL injection.""" + if not v or not v.strip(): + raise ValueError("Column name cannot be empty") + + # Remove HTML tags and decode entities + sanitized = html.escape(v.strip()) + + # Check for script content + if re.search(r"<script[^>]*>.*?</script>", v, re.IGNORECASE | re.DOTALL): + raise ValueError( + "Column name contains potentially malicious script content" + ) + + # Basic SQL injection patterns (basic protection) + dangerous_patterns = [ + r"(;|\||&|\$|`)", + r"\b(DROP|DELETE|INSERT|UPDATE|CREATE|ALTER|EXEC|EXECUTE)\b", + r"--", + r"/\*.*\*/", + ] + + for pattern in dangerous_patterns: + if re.search(pattern, v, re.IGNORECASE): + raise ValueError( + "Column name contains potentially unsafe characters or SQL keywords" + ) + + return sanitized + + @field_validator("label") + @classmethod + def sanitize_label(cls, v: str | None) -> str | None: + """Sanitize display label to prevent XSS attacks.""" + if v is None: + return v + + # Strip whitespace + v = v.strip() + if not v: + return None + + # Check for dangerous HTML tags and JavaScript protocols BEFORE escaping + dangerous_patterns = [ + r"<script[^>]*>.*?</script>", # Script tags + r"<iframe[^>]*>.*?</iframe>", # Iframe tags + r"<object[^>]*>.*?</object>", # Object tags + r"<embed[^>]*>.*?</embed>", # Embed tags + r"<link[^>]*>", # Link tags + r"<meta[^>]*>", # Meta tags + r"javascript:", # JavaScript protocol + r"vbscript:", # VBScript protocol + r"data:text/html", # Data URL HTML + r"on\w+\s*=", # Event handlers (onclick, onload, etc) + ] + + for pattern in dangerous_patterns: + if re.search(pattern, v, re.IGNORECASE | re.DOTALL): + raise ValueError( + "Label contains potentially malicious content. " Review Comment: ## Bad HTML filtering regexp This regular expression does not match script end tags like </script >. 
[Show more details](https://github.com/apache/superset/security/code-scanning/2044)
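This alert and the next one (2045) flag the same `r"<script[^>]*>.*?</script>"` pattern, which is reused in `sanitize_name`, `sanitize_label`, `sanitize_column`, and `sanitize_value`. One possible hardening, sketched below, is to reject anything that merely looks like an opening *or* closing script tag, so end tags with whitespace or attributes (`</script >`, `</script foo="bar">`) are also caught. This is only an illustration; `contains_script_tag` is a hypothetical helper, not part of the PR, and simply rejecting or escaping any `<` character in names and labels would be an even stricter fix.

```python
import re

# Match any opening or closing script tag, including end tags with whitespace
# or attributes before the ">", instead of requiring a well-formed open/close pair.
SCRIPT_TAG_RE = re.compile(r"</?\s*script\b[^>]*>", re.IGNORECASE)


def contains_script_tag(value: str) -> bool:
    """Return True if the value contains anything that resembles a script tag."""
    return bool(SCRIPT_TAG_RE.search(value))


# End tags with a space or attributes are caught; plain prose is not.
assert contains_script_tag("<script>alert(1)</script >")
assert contains_script_tag('</script foo="bar">')
assert not contains_script_tag("weekly script review notes")
```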

########## superset/mcp_service/chart/schemas.py: ##########
@@ -0,0 +1,1075 @@
+        for pattern in dangerous_patterns:
+            if re.search(pattern, v, re.IGNORECASE | re.DOTALL):
+                raise ValueError(
+                    "Label contains potentially malicious content. "
+                    "HTML tags, JavaScript, and event handlers are not allowed in "
+                    "labels."
+                )
+
+        # Filter dangerous Unicode characters
+        v = re.sub(
+            r"[\u200B-\u200D\uFEFF\u0000-\u0008\u000B\u000C\u000E-\u001F]", "", v
+        )
+
+        # HTML escape the cleaned content
+        sanitized = html.escape(v)
+
+        return sanitized if sanitized else None
+
+
+class AxisConfig(BaseModel):
+    title: str | None = Field(None, description="Axis title", max_length=200)
+    scale: Literal["linear", "log"] | None = Field(
+        "linear", description="Axis scale type"
+    )
+    format: str | None = Field(
+        None, description="Format string (e.g. '$,.2f')", max_length=50
+    )
+
+
+class LegendConfig(BaseModel):
+    show: bool = Field(True, description="Whether to show legend")
+    position: Literal["top", "bottom", "left", "right"] | None = Field(
+        "right", description="Legend position"
+    )
+
+
+class FilterConfig(BaseModel):
+    column: str = Field(
+        ..., description="Column to filter on", min_length=1, max_length=255
+    )
+    op: Literal["=", ">", "<", ">=", "<=", "!="] = Field(
+        ..., description="Filter operator"
+    )
+    value: str | int | float | bool = Field(..., description="Filter value")
+
+    @field_validator("column")
+    @classmethod
+    def sanitize_column(cls, v: str) -> str:
+        """Sanitize filter column name to prevent injection attacks."""
+        if not v or not v.strip():
+            raise ValueError("Filter column name cannot be empty")
+
+        # Remove HTML tags and decode entities
+        sanitized = html.escape(v.strip())
+
+        # Check for dangerous patterns
+        if re.search(r"<script[^>]*>.*?</script>", v, re.IGNORECASE | re.DOTALL):
+            raise ValueError(
+                "Filter column contains potentially malicious script content"
+            )
+
+        return sanitized
+
+    @field_validator("value")
+    @classmethod
+    def sanitize_value(cls, v: str | int | float | bool) -> str | int | float | bool:
+        """Sanitize filter value to prevent XSS and SQL injection attacks."""
+        if isinstance(v, str):
+            # Strip whitespace
+            v = v.strip()
+
+            # Check for dangerous patterns (SQL injection, XSS, script injection)
+            dangerous_patterns = [
+                # SQL injection patterns
+                r";\s*(DROP|DELETE|INSERT|UPDATE|CREATE|ALTER|EXEC|EXECUTE)\b",
+                r"'\s*OR\s*'",
+                r"'\s*AND\s*'",
+                r"--\s*",
+                r"/\*.*?\*/",
+                r"UNION\s+SELECT",
+                r"xp_cmdshell",
+                r"sp_executesql",
+                # XSS patterns
+                r"<script[^>]*>.*?</script>",
+                r"<iframe[^>]*>.*?</iframe>",
+                r"<object[^>]*>.*?</object>",
+                r"<embed[^>]*>.*?</embed>",
+                r"javascript:",
+                r"vbscript:",
+                r"data:text/html",
+                r"on\w+\s*=",
+                # Command injection patterns
+                r"[;&|`$()]",
+                r"\\x[0-9a-fA-F]{2}",  # Hex encoding
+            ]
+
+            for pattern in dangerous_patterns:
+                if re.search(pattern, v, re.IGNORECASE | re.DOTALL):
+                    raise ValueError(
+                        "Filter value contains potentially malicious content. "
+                        "SQL injection attempts, HTML tags, JavaScript, and command "
+                        "injection "
+                        "are not allowed in filter values."
+                    )
+
+            # Additional length check for filter values
+            if len(v) > 1000:

Review Comment:

## Bad HTML filtering regexp

This regular expression does not match script end tags like </script >.

[Show more details](https://github.com/apache/superset/security/code-scanning/2045)

########## tests/unit_tests/mcp_service/chart/test_chart_utils.py: ##########
@@ -0,0 +1,460 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Tests for chart utilities module""" + +from unittest.mock import patch + +import pytest + +from superset.mcp_service.chart.chart_utils import ( + create_metric_object, + generate_chart_name, + generate_explore_link, + map_config_to_form_data, + map_filter_operator, + map_table_config, + map_xy_config, +) +from superset.mcp_service.chart.schemas import ( + AxisConfig, + ColumnRef, + FilterConfig, + LegendConfig, + TableChartConfig, + XYChartConfig, +) + + +class TestCreateMetricObject: + """Test create_metric_object function""" + + def test_create_metric_object_with_aggregate(self) -> None: + """Test creating metric object with specified aggregate""" + col = ColumnRef(name="revenue", aggregate="SUM", label="Total Revenue") + result = create_metric_object(col) + + assert result["aggregate"] == "SUM" + assert result["column"]["column_name"] == "revenue" + assert result["label"] == "Total Revenue" + assert result["optionName"] == "metric_revenue" + assert result["expressionType"] == "SIMPLE" + + def test_create_metric_object_default_aggregate(self) -> None: + """Test creating metric object with default aggregate""" + col = ColumnRef(name="orders") + result = create_metric_object(col) + + assert result["aggregate"] == "SUM" + assert result["column"]["column_name"] == "orders" + assert result["label"] == "SUM(orders)" + assert result["optionName"] == "metric_orders" + + +class TestMapFilterOperator: + """Test map_filter_operator function""" + + def test_map_filter_operators(self) -> None: + """Test mapping of various filter operators""" + assert map_filter_operator("=") == "==" + assert map_filter_operator(">") == ">" + assert map_filter_operator("<") == "<" + assert map_filter_operator(">=") == ">=" + assert map_filter_operator("<=") == "<=" + assert map_filter_operator("!=") == "!=" + + def test_map_filter_operator_unknown(self) -> None: + """Test mapping of unknown operator returns original""" + assert map_filter_operator("UNKNOWN") == "UNKNOWN" + + +class TestMapTableConfig: + """Test map_table_config function""" + + def test_map_table_config_basic(self) -> None: + """Test basic table config mapping with aggregated columns""" + config = TableChartConfig( + columns=[ + ColumnRef(name="product", aggregate="COUNT"), + ColumnRef(name="revenue", aggregate="SUM"), + ] + ) + + result = map_table_config(config) + + assert result["viz_type"] == "table" + assert result["query_mode"] == "aggregate" + # Aggregated columns should be in metrics, not all_columns + assert "all_columns" not in result + assert len(result["metrics"]) == 2 + assert result["metrics"][0]["aggregate"] == "COUNT" + assert result["metrics"][1]["aggregate"] == "SUM" + + def test_map_table_config_raw_columns(self) -> None: + """Test table config mapping with raw columns (no aggregates)""" + config = TableChartConfig( + columns=[ + ColumnRef(name="product"), + ColumnRef(name="category"), + ] + ) + + result = map_table_config(config) + + assert result["viz_type"] == "table" + assert result["query_mode"] == "raw" + # Raw columns should be in all_columns + assert result["all_columns"] == ["product", "category"] + assert 
"metrics" not in result + + def test_map_table_config_with_filters(self) -> None: + """Test table config mapping with filters""" + config = TableChartConfig( + columns=[ColumnRef(name="product")], + filters=[FilterConfig(column="status", op="=", value="active")], + ) + + result = map_table_config(config) + + assert "adhoc_filters" in result + assert len(result["adhoc_filters"]) == 1 + filter_obj = result["adhoc_filters"][0] + assert filter_obj["subject"] == "status" + assert filter_obj["operator"] == "==" + assert filter_obj["comparator"] == "active" + assert filter_obj["expressionType"] == "SIMPLE" + + def test_map_table_config_with_sort(self) -> None: + """Test table config mapping with sort""" + config = TableChartConfig( + columns=[ColumnRef(name="product")], sort_by=["product", "revenue"] + ) + + result = map_table_config(config) + assert result["order_by_cols"] == ["product", "revenue"] + + +class TestMapXYConfig: + """Test map_xy_config function""" + + def test_map_xy_config_line_chart(self) -> None: + """Test XY config mapping for line chart""" + config = XYChartConfig( + x=ColumnRef(name="date"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + kind="line", + ) + + result = map_xy_config(config) + + assert result["viz_type"] == "echarts_timeseries_line" + assert result["x_axis"] == "date" + assert len(result["metrics"]) == 1 + assert result["metrics"][0]["aggregate"] == "SUM" + + def test_map_xy_config_with_groupby(self) -> None: + """Test XY config mapping with group by""" + config = XYChartConfig( + x=ColumnRef(name="date"), + y=[ColumnRef(name="revenue")], + kind="bar", + group_by=ColumnRef(name="region"), + ) + + result = map_xy_config(config) + + assert result["viz_type"] == "echarts_timeseries_bar" + assert result["groupby"] == ["region"] + + def test_map_xy_config_with_axes(self) -> None: + """Test XY config mapping with axis configurations""" + config = XYChartConfig( + x=ColumnRef(name="date"), + y=[ColumnRef(name="revenue")], + kind="area", + x_axis=AxisConfig(title="Date", format="%Y-%m-%d"), + y_axis=AxisConfig(title="Revenue", scale="log", format="$,.2f"), + ) + + result = map_xy_config(config) + + assert result["viz_type"] == "echarts_area" + assert result["x_axis_title"] == "Date" + assert result["x_axis_format"] == "%Y-%m-%d" + assert result["y_axis_title"] == "Revenue" + assert result["y_axis_format"] == "$,.2f" + assert result["y_axis_scale"] == "log" + + def test_map_xy_config_with_legend(self) -> None: + """Test XY config mapping with legend configuration""" + config = XYChartConfig( + x=ColumnRef(name="date"), + y=[ColumnRef(name="revenue")], + kind="scatter", + legend=LegendConfig(show=False, position="top"), + ) + + result = map_xy_config(config) + + assert result["viz_type"] == "echarts_timeseries_scatter" + assert result["show_legend"] is False + assert result["legend_orientation"] == "top" + + +class TestMapConfigToFormData: + """Test map_config_to_form_data function""" + + def test_map_table_config_type(self) -> None: + """Test mapping table config type""" + config = TableChartConfig(columns=[ColumnRef(name="test")]) + result = map_config_to_form_data(config) + assert result["viz_type"] == "table" + + def test_map_xy_config_type(self) -> None: + """Test mapping XY config type""" + config = XYChartConfig( + x=ColumnRef(name="date"), y=[ColumnRef(name="revenue")], kind="line" + ) + result = map_config_to_form_data(config) + assert result["viz_type"] == "echarts_timeseries_line" + + def test_map_unsupported_config_type(self) -> None: + """Test mapping 
unsupported config type raises error"""
+        with pytest.raises(ValueError, match="Unsupported config type"):
+            map_config_to_form_data("invalid_config")  # type: ignore
+
+
+class TestGenerateChartName:
+    """Test generate_chart_name function"""
+
+    def test_generate_table_chart_name(self) -> None:
+        """Test generating name for table chart"""
+        config = TableChartConfig(
+            columns=[
+                ColumnRef(name="product"),
+                ColumnRef(name="revenue"),
+            ]
+        )
+
+        result = generate_chart_name(config)
+        assert result == "Table Chart - product, revenue"
+
+    def test_generate_xy_chart_name(self) -> None:
+        """Test generating name for XY chart"""
+        config = XYChartConfig(
+            x=ColumnRef(name="date"),
+            y=[ColumnRef(name="revenue"), ColumnRef(name="orders")],
+            kind="line",
+        )
+
+        result = generate_chart_name(config)
+        assert result == "Line Chart - date vs revenue, orders"
+
+    def test_generate_chart_name_unsupported(self) -> None:
+        """Test generating name for unsupported config type"""
+        result = generate_chart_name("invalid_config")  # type: ignore
+        assert result == "Chart"
+
+
+class TestGenerateExploreLink:
+    """Test generate_explore_link function"""
+
+    @patch("superset.mcp_service.chart.chart_utils.get_superset_base_url")
+    def test_generate_explore_link_uses_base_url(self, mock_get_base_url) -> None:
+        """Test that generate_explore_link uses the configured base URL"""
+        mock_get_base_url.return_value = "https://superset.example.com"
+        form_data = {"viz_type": "table", "metrics": ["count"]}
+
+        result = generate_explore_link("123", form_data)
+
+        # Should use the configured base URL
+        assert result.startswith("https://superset.example.com")

Review Comment:

## Incomplete URL substring sanitization

The string [https://superset.example.com](1) may be at an arbitrary position in the sanitized URL.

[Show more details](https://github.com/apache/superset/security/code-scanning/2064)
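A hedged sketch of how the flagged assertion could be tightened: parse the generated link and compare scheme and host exactly rather than relying on a `startswith` substring check, so a look-alike host such as `superset.example.com.evil.com` could never satisfy the test. The helper name and the example path below are illustrative only, not part of the PR:

```python
from urllib.parse import urlsplit


def assert_uses_base_url(
    url: str, scheme: str = "https", host: str = "superset.example.com"
) -> None:
    """Compare the scheme and host of a URL exactly instead of a substring check."""
    parts = urlsplit(url)
    assert parts.scheme == scheme
    assert parts.netloc == host


# A URL on the configured host passes; a look-alike host fails the netloc comparison.
assert_uses_base_url("https://superset.example.com/explore/?form_data_key=abc")
```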
########## superset/mcp_service/screenshot/webdriver_config.py: ##########
@@ -0,0 +1,139 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+WebDriver pool configuration defaults for Superset MCP service
+"""
+
+from typing import Any, Dict
+
+# Default WebDriver pool configuration
+DEFAULT_WEBDRIVER_POOL_CONFIG = {
+    # Maximum number of WebDriver instances to keep in the pool
+    "MAX_POOL_SIZE": 5,
+    # Maximum age of a WebDriver instance (in seconds)
+    # After this time, the driver will be destroyed and recreated
+    "MAX_AGE_SECONDS": 3600,  # 1 hour
+    # Maximum number of times a WebDriver can be reused
+    # After this many uses, the driver will be destroyed and recreated
+    "MAX_USAGE_COUNT": 50,
+    # How long a WebDriver can sit idle before being destroyed (in seconds)
+    "IDLE_TIMEOUT_SECONDS": 300,  # 5 minutes
+    # How often to perform health checks on WebDriver instances (in seconds)
+    "HEALTH_CHECK_INTERVAL": 60,  # 1 minute
+}
+
+
+def configure_webdriver_pool(app_config: Dict[str, Any]) -> None:
+    """
+    Configure WebDriver pool settings in Superset app config.
+
+    This function adds WebDriver pool configuration to the Superset app config
+    if it doesn't already exist, using sensible defaults.
+
+    Args:
+        app_config: The Superset application configuration dictionary
+    """
+    if "WEBDRIVER_POOL" not in app_config:
+        app_config["WEBDRIVER_POOL"] = DEFAULT_WEBDRIVER_POOL_CONFIG.copy()
+    else:
+        # Merge with defaults for any missing keys
+        for key, default_value in DEFAULT_WEBDRIVER_POOL_CONFIG.items():
+            if key not in app_config["WEBDRIVER_POOL"]:
+                app_config["WEBDRIVER_POOL"][key] = default_value
+
+
+def get_pool_stats_endpoint() -> Any:
+    """
+    Create a Flask endpoint to view WebDriver pool statistics.
+
+    This function can be called to register a debugging endpoint
+    that shows the current state of the WebDriver pool.
+
+    Returns:
+        Flask route function for pool statistics
+    """
+
+    def pool_stats() -> Any:
+        try:
+            from flask import jsonify
+
+            from superset.mcp_service.screenshot.webdriver_pool import (
+                get_webdriver_pool,
+            )
+
+            pool = get_webdriver_pool()
+            stats = pool.get_stats()
+
+            return jsonify({"webdriver_pool": stats, "status": "healthy"})
+        except Exception as e:
+            from flask import jsonify
+
+            return jsonify({"error": str(e), "status": "error"}), 500

Review Comment:

## Information exposure through an exception

[Stack trace information](1) flows to this location and may be exposed to an external user.

[Show more details](https://github.com/apache/superset/security/code-scanning/2065)
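A minimal sketch of one way to address this, reusing `get_webdriver_pool()` and `get_stats()` from the diff above: keep the full exception and stack trace in the server-side log, and return only a generic message to the caller. The logger setup is an assumption, not part of the PR:

```python
import logging

from flask import jsonify

logger = logging.getLogger(__name__)


def pool_stats():
    try:
        from superset.mcp_service.screenshot.webdriver_pool import get_webdriver_pool

        stats = get_webdriver_pool().get_stats()
        return jsonify({"webdriver_pool": stats, "status": "healthy"})
    except Exception:
        # Log the stack trace server-side only.
        logger.exception("Failed to collect WebDriver pool statistics")
        # Return a generic message so exception details are not exposed externally.
        return jsonify({"error": "Unable to retrieve pool statistics", "status": "error"}), 500
```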
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]