richardfogaca commented on code in PR #40344: URL: https://github.com/apache/superset/pull/40344#discussion_r3320155795
########## superset/mcp_service/action_log/schemas.py: ########## @@ -0,0 +1,305 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Pydantic schemas for action-log MCP tools.""" + +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Annotated, Any, Literal + +from pydantic import ( + BaseModel, + ConfigDict, + Field, + field_validator, + model_serializer, + model_validator, + PositiveInt, +) + +from superset.daos.base import ColumnOperator, ColumnOperatorEnum +from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE +from superset.mcp_service.system.schemas import PaginationInfo +from superset.mcp_service.utils import sanitize_for_llm_context +from superset.mcp_service.utils.schema_utils import ( + parse_json_or_list, + parse_json_or_model_list, +) + +DEFAULT_LOG_COLUMNS: list[str] = ["id", "action", "user_id", "dttm"] +ALL_LOG_COLUMNS: list[str] = [ + "id", + "action", + "user_id", + "dttm", + "dashboard_id", + "slice_id", + "json", +] +LOG_SORTABLE_COLUMNS: list[str] = ["id", "dttm"] + + +class ActionLogFilter(ColumnOperator): + """Filter object for action-log listing. + + col: Column to filter on. + opr: Operator to use. + value: Value to filter by. 
+ """ + + col: Literal["action", "user_id", "dashboard_id", "slice_id", "dttm"] = Field( + ..., + description="Column to filter on.", + ) + opr: ColumnOperatorEnum = Field(..., description="Operator to use.") + value: ( + str | int | float | bool | datetime | list[str | int | float | bool | datetime] + ) = Field(..., description="Value to filter by") + + @model_validator(mode="after") + def normalize_dttm_value(self) -> "ActionLogFilter": + """Normalize string dttm values to datetime to avoid VARCHAR bind mismatch. + + Pydantic's left-to-right union matching keeps ISO strings as str when + str appears before datetime in the union. This validator parses them so + the DAO always receives a typed datetime for TIMESTAMP column comparisons. + Both scalar and list values are normalized so dttm IN (...) is also safe. + + Replaces a trailing 'Z' with '+00:00' before parsing because + datetime.fromisoformat does not accept the 'Z' suffix on Python < 3.11. + """ + + def _parse(val: str) -> datetime | str: + try: + s = val[:-1] + "+00:00" if val.endswith("Z") else val + parsed = datetime.fromisoformat(s) + return parsed if parsed.tzinfo else parsed.replace(tzinfo=timezone.utc) + except ValueError: + return val + + if self.col == "dttm": + if isinstance(self.value, str): + self.value = _parse(self.value) + elif isinstance(self.value, list): + self.value = [ + _parse(v) if isinstance(v, str) else v for v in self.value + ] + return self + + +class ActionLogInfo(BaseModel): + id: int | None = Field(None, description="Log entry ID") + action: str | None = Field(None, description="Action name") + user_id: int | None = Field( + None, description="ID of the user who performed the action" + ) + dttm: str | datetime | None = Field(None, description="Timestamp of the action") + dashboard_id: int | None = Field(None, description="Associated dashboard ID") + slice_id: int | None = Field(None, description="Associated chart/slice ID") + json: Any = Field( + None, description="JSON payload 
of the action (user-controlled, sanitized)" + ) + + model_config = ConfigDict( + from_attributes=True, + ser_json_timedelta="iso8601", + populate_by_name=True, + ) + + def model_post_init(self, __context: Any) -> None: + if isinstance(self.dttm, datetime) and self.dttm.tzinfo is None: + from datetime import timezone + + object.__setattr__(self, "dttm", self.dttm.replace(tzinfo=timezone.utc)) + + @model_serializer(mode="wrap") + def _filter_fields_by_context(self, serializer: Any, info: Any) -> dict[str, Any]: + data = serializer(self) + if info.context and isinstance(info.context, dict): + select_columns = info.context.get("select_columns") + if select_columns: + requested_fields = set(select_columns) + return {k: v for k, v in data.items() if k in requested_fields} + return data + + +class ActionLogList(BaseModel): + action_logs: list[ActionLogInfo] + count: int + total_count: int + page: int + page_size: int + total_pages: int + has_previous: bool + has_next: bool + columns_requested: list[str] = Field(default_factory=list) + columns_loaded: list[str] = Field(default_factory=list) + columns_available: list[str] = Field(default_factory=list) + sortable_columns: list[str] = Field(default_factory=list) + filters_applied: list[ActionLogFilter] = Field(default_factory=list) + pagination: PaginationInfo | None = None + timestamp: datetime | None = None + model_config = ConfigDict(ser_json_timedelta="iso8601") + + +class ListActionLogsRequest(BaseModel): + """Request schema for list_action_logs.""" + + filters: Annotated[ + list[ActionLogFilter], + Field( + default_factory=list, + description=( + "List of filter objects (col, opr, value). " + "Filter columns: action, user_id, dashboard_id, slice_id, dttm. " + "Cannot be used with 'search'." + ), + ), + ] + select_columns: Annotated[ + list[str], + Field( + default_factory=list, + description="Columns to return. 
Defaults to common columns.", + ), + ] + search: Annotated[ + str | None, + Field( + default=None, + description=( + "Text search string matched against action. " + "Cannot be used together with 'filters'." + ), + ), + ] + order_column: Annotated[ + str | None, + Field(default=None, description="Column to sort by (default: dttm)"), + ] + order_direction: Annotated[ + Literal["asc", "desc"], + Field(default="desc", description="Sort direction ('asc' or 'desc')"), + ] + page: Annotated[ + PositiveInt, + Field(default=1, description="Page number (1-based)"), + ] + page_size: Annotated[ + int, + Field( + default=DEFAULT_PAGE_SIZE, + gt=0, + le=MAX_PAGE_SIZE, + description=f"Items per page (max {MAX_PAGE_SIZE})", + ), + ] + + @field_validator("filters", mode="before") + @classmethod + def parse_filters(cls, v: Any) -> list[ActionLogFilter]: + return parse_json_or_model_list(v, ActionLogFilter, "filters") + + @field_validator("select_columns", mode="before") + @classmethod + def parse_columns(cls, v: Any) -> list[str]: + return parse_json_or_list(v, "select_columns") + + @model_validator(mode="after") + def validate_search_and_filters(self) -> "ListActionLogsRequest": + if self.search and self.filters: + raise ValueError( + "Cannot use both 'search' and 'filters' simultaneously. " + "Use 'search' for text matching on action, or 'filters' for " + "column-based filtering, but not both." 
+ ) + return self + + +class ActionLogError(BaseModel): + error: str = Field(..., description="Error message") + error_type: str = Field(..., description="Error type") + timestamp: str | datetime | None = Field(None, description="Error timestamp") + model_config = ConfigDict(ser_json_timedelta="iso8601") + + @classmethod + def create(cls, error: str, error_type: str) -> "ActionLogError": + from datetime import timezone + + return cls( + error=error, + error_type=error_type, + timestamp=datetime.now(timezone.utc), + ) + + +class GetActionLogInfoRequest(BaseModel): + """Request schema for get_action_log_info (ID-only lookup).""" + + identifier: Annotated[ + int, + Field(description="Log entry ID (integer)"), + ] + + +def _sanitize_log_json(raw: Any) -> Any: + """Parse the stored log JSON string and sanitize string leaves. + + Preserves the JSON shape so callers can inspect individual fields; wraps + every string leaf in UNTRUSTED-CONTENT delimiters so the payload cannot + inject instructions into the LLM context. Falls back to sanitizing the + raw string when it is not valid JSON. + + Passes excluded_field_names=frozenset() so that no field name is exempted + from wrapping — the entire blob is user-controlled and must be treated as + untrusted, including fields like 'url', 'schema', and 'uuid' that the + default exclusion list would otherwise only escape rather than wrap. + """ + if raw is None: + return None + if isinstance(raw, str): + try: + from superset.utils import json as json_utils # noqa: PLC0415 + + parsed = json_utils.loads(raw) + except (ValueError, TypeError): + parsed = raw + else: + parsed = raw + return sanitize_for_llm_context( Review Comment: `_sanitize_log_json()` parses the stored log JSON and passes the parsed object to `sanitize_for_llm_context()`, but that helper only wraps string values. 
Dict keys are only delimiter-escaped, so a payload like `{"ignore previous instructions": "..."}` still leaves the key as raw user-controlled text in the MCP response. Since this code explicitly treats the whole log JSON blob as untrusted, including field names, this looks like an LLM-context injection gap. What do you think — could we also wrap/sanitize JSON object keys on this action-log path (for example, by recursively rebuilding each dict with its keys passed through the same wrapping as string values), and add a regression test with a malicious JSON key? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
