This is an automated email from the ASF dual-hosted git repository. msyavuz pushed a commit to branch msyavuz/feat/confidence-check in repository https://gitbox.apache.org/repos/asf/superset.git
commit 7af529969424c79664c82c4d5a9b8025d05a9eb3 Author: Mehmet Salih Yavuz <[email protected]> AuthorDate: Thu Dec 18 16:01:29 2025 +0300 feat: confidence check --- superset/commands/database_analyzer/analyze.py | 139 ++++++++++-- superset/commands/database_analyzer/llm_service.py | 242 +++++++++++++++++---- superset/databases/analyzer_api.py | 9 + ...e4e65062d_add_confidence_fields_to_database_.py | 56 +++++ superset/models/database_analyzer.py | 4 + superset/tasks/database_analyzer.py | 2 + 6 files changed, 400 insertions(+), 52 deletions(-) diff --git a/superset/commands/database_analyzer/analyze.py b/superset/commands/database_analyzer/analyze.py index 67f99eed84..f533f50224 100644 --- a/superset/commands/database_analyzer/analyze.py +++ b/superset/commands/database_analyzer/analyze.py @@ -30,10 +30,8 @@ from superset.models.core import Database from superset.models.database_analyzer import ( AnalyzedColumn, AnalyzedTable, - Cardinality, DatabaseSchemaReport, InferredJoin, - JoinType, TableType, ) from superset.utils import json @@ -70,9 +68,13 @@ class AnalyzeDatabaseSchemaCommand(BaseCommand): # Infer joins using AI self._infer_joins_with_ai() + # Validate the analysis confidence + self._validate_analysis_confidence() + return { "tables_count": len(self.report.tables), "joins_count": len(self.report.joins), + "confidence_score": self.report.confidence_score, } def validate(self) -> None: @@ -178,17 +180,25 @@ class AnalyzeDatabaseSchemaCommand(BaseCommand): result = engine.execute(text(sample_sql)) for row in result: sample_rows.append(dict(row)) - logger.debug("Fetched %d sample rows from %s", len(sample_rows), table_name) - except Exception as e: + logger.debug( + "Fetched %d sample rows from %s", len(sample_rows), table_name + ) + except Exception: # Fallback to regular LIMIT if RANDOM() not supported try: fallback_sql = f'SELECT * FROM "{schema}"."{table_name}" LIMIT 3' # noqa: S608, E501 result = engine.execute(text(fallback_sql)) for row in result: sample_rows.append(dict(row)) - logger.debug("Fetched %d sample rows from %s (fallback)", len(sample_rows), table_name) + logger.debug( + "Fetched %d sample rows from %s (fallback)", + len(sample_rows), + table_name, + ) except Exception as e2: - logger.warning("Could not fetch sample data for %s: %s", table_name, str(e2)) + logger.warning( + "Could not fetch sample data for %s: %s", table_name, str(e2) + ) # Get row count (try reltuples first, fallback to actual count) row_count = None @@ -201,7 +211,7 @@ class AnalyzeDatabaseSchemaCommand(BaseCommand): result = engine.execute(text(count_sql)) row = result.fetchone() row_count = row[0] if row and row[0] >= 0 else None - + # If reltuples is -1 or None, get actual count for small tables if row_count is None or row_count < 0: actual_count_sql = f'SELECT COUNT(*) FROM "{schema}"."{table_name}"' # noqa: S608 @@ -296,7 +306,7 @@ class AnalyzeDatabaseSchemaCommand(BaseCommand): return max_workers = min(10, len(tables)) - + # Capture the current Flask app context app = current_app._get_current_object() @@ -317,7 +327,7 @@ class AnalyzeDatabaseSchemaCommand(BaseCommand): str(e), ) - def _augment_table_with_ai_context(self, app, table: AnalyzedTable) -> None: + def _augment_table_with_ai_context(self, app: Any, table: AnalyzedTable) -> None: """Wrapper to provide Flask context to the AI description thread""" with app.app_context(): self._augment_table_with_ai(table) @@ -458,10 +468,10 @@ class AnalyzeDatabaseSchemaCommand(BaseCommand): # Debug logging to see actual data being stored for i, join_data in enumerate(inferred_joins): logger.debug( - "Join %d data: join_type=%s, cardinality=%s", - i, - join_data.get("join_type"), - join_data.get("cardinality") + "Join %d data: join_type=%s, cardinality=%s", + i, + join_data.get("join_type"), + join_data.get("cardinality"), ) for join_data in inferred_joins: @@ -495,3 +505,106 @@ class AnalyzeDatabaseSchemaCommand(BaseCommand): db.session.add(join) db.session.commit() # pylint: disable=consider-using-transaction + + def _validate_analysis_confidence(self) -> None: + """Use another LLM to validate the confidence of the analysis""" + assert self.report is not None + + logger.info("Validating analysis confidence using LLM") + + if not self.llm_service.is_available(): + logger.warning("LLM service not available, skipping confidence validation") + return + + try: + # Prepare data for validation + tables_data = [] + for table in self.report.tables: + extra_json = json.loads(table.extra_json or "{}") + tables_data.append( + { + "name": table.table_name, + "type": table.table_type.value if table.table_type else "table", + "columns_count": len(table.columns), + "row_count": extra_json.get("row_count_estimate"), + "ai_description": table.ai_description, + "has_description": bool( + table.ai_description or table.db_comment + ), + } + ) + + joins_data = [] + for join in self.report.joins: + extra_json = json.loads(join.extra_json or "{}") + joins_data.append( + { + "source_table": join.source_table.table_name, + "source_columns": json.loads(join.source_columns), + "target_table": join.target_table.table_name, + "target_columns": json.loads(join.target_columns), + "join_type": join.join_type.value + if join.join_type + else "inner", + "cardinality": join.cardinality.value + if join.cardinality + else "N:1", + "confidence_score": extra_json.get("confidence_score", 0.5), + "semantic_context": join.semantic_context, + } + ) + + # Collect all AI descriptions + ai_descriptions = { + "tables": { + table.table_name: table.ai_description + for table in self.report.tables + if table.ai_description + }, + "columns": {}, + } + for table in self.report.tables: + for col in table.columns: + if col.ai_description: + key = f"{table.table_name}.{col.column_name}" + ai_descriptions["columns"][key] = col.ai_description + + # Call validation service + validation_result = self.llm_service.validate_analysis_confidence( + schema_name=self.report.schema_name, + tables=tables_data, + joins=joins_data, + ai_descriptions=ai_descriptions, + ) + + # Store validation results + self.report.confidence_score = validation_result.get( + "overall_confidence", 0.5 + ) + self.report.confidence_breakdown = json.dumps( + validation_result.get("confidence_breakdown", {}) + ) + + # Combine recommendations and potential issues + all_recommendations = [] + all_recommendations.extend(validation_result.get("recommendations", [])) + all_recommendations.extend(validation_result.get("potential_issues", [])) + + self.report.confidence_recommendations = json.dumps(all_recommendations) + self.report.confidence_validation_notes = validation_result.get( + "validation_notes", "" + ) + + db.session.commit() # pylint: disable=consider-using-transaction + + logger.info( + "Confidence validation complete. Score: %.2f", + self.report.confidence_score or 0.5, + ) + + except Exception as e: + logger.error("Error validating analysis confidence: %s", str(e)) + # Set default confidence if validation fails + self.report.confidence_score = 0.5 + self.report.confidence_validation_notes = f"Validation error: {str(e)}" + db.session.commit() # pylint: disable=consider-using-transaction diff --git a/superset/commands/database_analyzer/llm_service.py b/superset/commands/database_analyzer/llm_service.py index a491928763..182b6328a5 100644 --- a/superset/commands/database_analyzer/llm_service.py +++ b/superset/commands/database_analyzer/llm_service.py @@ -31,14 +31,29 @@ class LLMService: def __init__(self) -> None: import os + # Try environment variables first, then fall back to config - self.api_key = os.environ.get("SUPERSET_LLM_API_KEY") or current_app.config.get("LLM_API_KEY") - self.model = os.environ.get("SUPERSET_LLM_MODEL") or current_app.config.get("LLM_MODEL", "gpt-4o") - self.temperature = float(os.environ.get("SUPERSET_LLM_TEMPERATURE", current_app.config.get("LLM_TEMPERATURE", 0.3))) - self.max_tokens = int(os.environ.get("SUPERSET_LLM_MAX_TOKENS", current_app.config.get("LLM_MAX_TOKENS", 4096))) - self.base_url = os.environ.get("SUPERSET_LLM_BASE_URL") or current_app.config.get( - "LLM_BASE_URL", "https://api.openai.com/v1" + self.api_key = os.environ.get("SUPERSET_LLM_API_KEY") or current_app.config.get( + "LLM_API_KEY" + ) + self.model = os.environ.get("SUPERSET_LLM_MODEL") or current_app.config.get( + "LLM_MODEL", "gpt-4o" + ) + self.temperature = float( + os.environ.get( + "SUPERSET_LLM_TEMPERATURE", + current_app.config.get("LLM_TEMPERATURE", 0.3), + ) ) + self.max_tokens = int( + os.environ.get( + "SUPERSET_LLM_MAX_TOKENS", + current_app.config.get("LLM_MAX_TOKENS", 4096), + ) + ) + self.base_url = os.environ.get( + "SUPERSET_LLM_BASE_URL" + ) or current_app.config.get("LLM_BASE_URL", "https://api.openai.com/v1") def is_available(self) -> bool: """Check if LLM service is configured and available""" @@ -218,52 +233,59 @@ Return the response as JSON array: def _call_llm(self, prompt: str) -> str: """Call the LLM API with the given prompt""" import requests - + if not self.api_key: logger.warning("No API key configured for LLM service") return json.dumps({}) - + headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", } - + # For OpenRouter, we might need additional headers if "openrouter" in self.base_url.lower(): - headers["HTTP-Referer"] = "http://localhost:8088" # Optional but recommended + headers["HTTP-Referer"] = ( + "http://localhost:8088" # Optional but recommended + ) headers["X-Title"] = "Superset Database Analyzer" # Optional - + data = { "model": self.model, "messages": [ { "role": "system", - "content": "You are a database expert helping to document and understand database schemas. Respond only with valid JSON as requested." + "content": ( + "You are a database expert helping to document and " + "understand database schemas. Respond only with valid " + "JSON as requested." + ), }, - { - "role": "user", - "content": prompt - } + {"role": "user", "content": prompt}, ], "temperature": self.temperature, "max_tokens": self.max_tokens, } - + try: response = requests.post( f"{self.base_url}/chat/completions", headers=headers, json=data, - timeout=60 + timeout=60, ) - + if response.status_code != 200: - logger.error(f"LLM API error: {response.status_code} - {response.text}") + logger.error( + "LLM API error: %s - %s", response.status_code, response.text + ) return json.dumps({}) - + result = response.json() - content = result.get("choices", [{}])[0].get("message", {}).get("content", "") - + content = ( + result.get("choices", [{}])[0].get("message", {}).get("content", "") + ) + # Try to extract JSON from the response if content: # Clean up the response - sometimes LLMs add markdown formatting @@ -275,14 +297,14 @@ Return the response as JSON array: if content.endswith("```"): content = content[:-3] content = content.strip() - + return content - + except requests.exceptions.RequestException as e: - logger.error(f"Error calling LLM API: {e}") + logger.error("Error calling LLM API: %s", e) return json.dumps({}) except Exception as e: - logger.error(f"Unexpected error in LLM call: {e}") + logger.error("Unexpected error in LLM call: %s", e) return json.dumps({}) def _parse_table_description_response(self, response: str) -> dict[str, Any]: @@ -309,17 +331,17 @@ Return the response as JSON array: valid_joins = [] for i, join in enumerate(joins): logger.debug( - "Raw join %d: join_type=%s, cardinality=%s", - i, - join.get("join_type"), - join.get("cardinality") + "Raw join %d: join_type=%s, cardinality=%s", + i, + join.get("join_type"), + join.get("cardinality"), ) if self._validate_join(join): logger.debug( - "Validated join %d: join_type=%s, cardinality=%s", - i, - join.get("join_type"), - join.get("cardinality") + "Validated join %d: join_type=%s, cardinality=%s", + i, + join.get("join_type"), + join.get("cardinality"), ) valid_joins.append(join) else: @@ -353,18 +375,18 @@ Return the response as JSON array: # Normalize join_type to lowercase if "join_type" in join: join["join_type"] = str(join["join_type"]).lower() - + # Normalize cardinality to use enum values if "cardinality" in join: cardinality_map = { "ONE_TO_ONE": "1:1", "1:1": "1:1", - "ONE_TO_MANY": "1:N", + "ONE_TO_MANY": "1:N", "1:N": "1:N", "MANY_TO_ONE": "N:1", - "N:1": "N:1", + "N:1": "N:1", "MANY_TO_MANY": "N:M", - "N:M": "N:M" + "N:M": "N:M", } raw_cardinality = str(join["cardinality"]).upper() join["cardinality"] = cardinality_map.get(raw_cardinality, "N:1") @@ -376,3 +398,145 @@ Return the response as JSON array: join.setdefault("suggested_by", "ai_inference") return True + + def validate_analysis_confidence( + self, + schema_name: str, + tables: list[dict[str, Any]], + joins: list[dict[str, Any]], + ai_descriptions: dict[str, Any], + ) -> dict[str, Any]: + """ + Validate the confidence of the analysis using another LLM. + + :param schema_name: Name of the schema analyzed + :param tables: List of analyzed tables with metadata + :param joins: List of inferred joins + :param ai_descriptions: AI-generated descriptions + :return: Dict with confidence scores and recommendations + """ + if not self.is_available(): + return { + "overall_confidence": 0.5, + "confidence_breakdown": {}, + "recommendations": [], + "validation_notes": "LLM validation not available", + } + + prompt = self._build_confidence_validation_prompt( + schema_name, tables, joins, ai_descriptions + ) + + try: + response = self._call_llm(prompt) + return self._parse_confidence_validation_response(response) + except Exception as e: + logger.error("Error calling LLM for confidence validation: %s", str(e)) + return { + "overall_confidence": 0.5, + "confidence_breakdown": {}, + "recommendations": [], + "validation_notes": f"Validation failed: {str(e)}", + } + + def _build_confidence_validation_prompt( + self, + schema_name: str, + tables: list[dict[str, Any]], + joins: list[dict[str, Any]], + ai_descriptions: dict[str, Any], + ) -> str: + """Build prompt for validating analysis confidence""" + prompt = ( + "You are a database analysis quality auditor. Review the following " + "database schema analysis and provide confidence scores.\n\n" + f"Schema: {schema_name}\n" + f"Number of tables: {len(tables)}\n" + f"Number of inferred joins: {len(joins)}\n\n" + "Analysis Summary:\n" + ) + + # Add table information + prompt += "\nTables analyzed:\n" + for table in tables[:10]: # Limit to first 10 tables for context + prompt += f"- {table.get('name', 'Unknown')}: " + prompt += f"{table.get('columns_count', 0)} columns, " + prompt += f"{table.get('row_count', 'unknown')} rows\n" + if table.get("ai_description"): + prompt += f" Description: {table['ai_description'][:100]}...\n" + + # Add join information + if joins: + prompt += f"\nInferred joins ({len(joins)} total):\n" + for join in joins[:5]: # Show first 5 joins + prompt += ( + f"- {join.get('source_table')}.{join.get('source_columns')} -> " + ) + prompt += f"{join.get('target_table')}.{join.get('target_columns')}\n" + prompt += f" Type: {join.get('join_type', 'unknown')}, " + prompt += f"Cardinality: {join.get('cardinality', 'unknown')}, " + prompt += f"Confidence: {join.get('confidence_score', 0)}\n" + + prompt += """ +Please evaluate the quality and completeness of this analysis and provide: + +1. overall_confidence: A score from 0.0 to 1.0 indicating overall confidence +2. confidence_breakdown: Individual confidence scores for different aspects: + - table_descriptions: How accurate/complete are the table descriptions + - column_descriptions: How accurate/complete are the column descriptions + - join_inference: How accurate are the inferred joins + - schema_coverage: How complete is the schema analysis +3. recommendations: List of specific recommendations to improve the analysis +4. potential_issues: Any red flags or concerns noticed +5. validation_notes: Additional context about the validation + +Consider factors like: +- Consistency of descriptions +- Plausibility of inferred relationships +- Completeness of metadata +- Semantic accuracy + +Return the response as JSON: +{ + "overall_confidence": 0.75, + "confidence_breakdown": { + "table_descriptions": 0.8, + "column_descriptions": 0.7, + "join_inference": 0.65, + "schema_coverage": 0.9 + }, + "recommendations": [ + "Review join between X and Y - seems unlikely", + "Table Z description needs more detail" + ], + "potential_issues": [ + "Missing relationships between related tables", + "Some descriptions are too generic" + ], + "validation_notes": "Analysis appears mostly complete with minor gaps" +} +""" + return prompt + + def _parse_confidence_validation_response(self, response: str) -> dict[str, Any]: + """Parse the LLM response for confidence validation""" + try: + result = json.loads(response) + + # Ensure all required fields exist with defaults + return { + "overall_confidence": float(result.get("overall_confidence", 0.5)), + "confidence_breakdown": result.get("confidence_breakdown", {}), + "recommendations": result.get("recommendations", []), + "potential_issues": result.get("potential_issues", []), + "validation_notes": result.get("validation_notes", ""), + } + except (json.JSONDecodeError, ValueError) as e: + logger.error("Failed to parse confidence validation response: %s", str(e)) + return { + "overall_confidence": 0.5, + "confidence_breakdown": {}, + "recommendations": [], + "potential_issues": [], + "validation_notes": "Failed to parse validation response", + } diff --git a/superset/databases/analyzer_api.py b/superset/databases/analyzer_api.py index 75e04d2fd2..3b939e431f 100644 --- a/superset/databases/analyzer_api.py +++ b/superset/databases/analyzer_api.py @@ -30,6 +30,7 @@ from superset.tasks.database_analyzer import ( check_analysis_status, kickstart_analysis, ) +from superset.utils import json from superset.views.base_api import BaseSupersetApi, requires_json, statsd_metrics logger = logging.getLogger(__name__) @@ -76,6 +77,8 @@ class CheckStatusResponseSchema(Schema): error_message = fields.String(allow_none=True) tables_count = fields.Integer(allow_none=True) joins_count = fields.Integer(allow_none=True) + confidence_score = fields.Float(allow_none=True) + confidence_validation_notes = fields.String(allow_none=True) class DatasourceAnalyzerRestApi(BaseSupersetApi): @@ -257,6 +260,12 @@ class DatasourceAnalyzerRestApi(BaseSupersetApi): "created_at": report.created_on.isoformat() if report.created_on else None, + "confidence_score": report.confidence_score, + "confidence_breakdown": json.loads(report.confidence_breakdown or "{}"), + "confidence_recommendations": json.loads( + report.confidence_recommendations or "[]" + ), + "confidence_validation_notes": report.confidence_validation_notes, "tables": [], "joins": [], } diff --git a/superset/migrations/versions/2025-12-18_14-27_45be4e65062d_add_confidence_fields_to_database_.py b/superset/migrations/versions/2025-12-18_14-27_45be4e65062d_add_confidence_fields_to_database_.py new file mode 100644 index 0000000000..fa619de243 --- /dev/null +++ b/superset/migrations/versions/2025-12-18_14-27_45be4e65062d_add_confidence_fields_to_database_.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""add_confidence_fields_to_database_schema_report + +Revision ID: 45be4e65062d +Revises: 4a032c8dbc11 +Create Date: 2025-12-18 14:27:18.506951 + +""" + +# revision identifiers, used by Alembic. +revision = "45be4e65062d" +down_revision = "4a032c8dbc11" + +import sqlalchemy as sa +from alembic import op + + +def upgrade(): + op.add_column( + "database_schema_report", + sa.Column("confidence_score", sa.Float(), nullable=True), + ) + op.add_column( + "database_schema_report", + sa.Column("confidence_breakdown", sa.Text(), nullable=True), + ) + op.add_column( + "database_schema_report", + sa.Column("confidence_recommendations", sa.Text(), nullable=True), + ) + op.add_column( + "database_schema_report", + sa.Column("confidence_validation_notes", sa.Text(), nullable=True), + ) + + +def downgrade(): + op.drop_column("database_schema_report", "confidence_validation_notes") + op.drop_column("database_schema_report", "confidence_recommendations") + op.drop_column("database_schema_report", "confidence_breakdown") + op.drop_column("database_schema_report", "confidence_score") diff --git a/superset/models/database_analyzer.py b/superset/models/database_analyzer.py index a3a1df6724..365aeaa080 100644 --- a/superset/models/database_analyzer.py +++ b/superset/models/database_analyzer.py @@ -57,6 +57,10 @@ class DatabaseSchemaReport(Model, AuditMixinNullable, UUIDMixin): start_dttm = sa.Column(sa.DateTime, nullable=True) end_dttm = sa.Column(sa.DateTime, nullable=True) error_message = sa.Column(sa.Text, nullable=True) + confidence_score = sa.Column(sa.Float, nullable=True) + confidence_breakdown = sa.Column(sa.Text, nullable=True) # JSON dict + confidence_recommendations = sa.Column(sa.Text, nullable=True) # JSON array + confidence_validation_notes = sa.Column(sa.Text, nullable=True) extra_json = sa.Column(sa.Text, nullable=True) # Relationships diff --git a/superset/tasks/database_analyzer.py b/superset/tasks/database_analyzer.py index a4abab620f..1964d37e83 100644 --- a/superset/tasks/database_analyzer.py +++ b/superset/tasks/database_analyzer.py @@ -227,6 +227,8 @@ def check_analysis_status(run_id: str) -> dict[str, Any]: ) result["tables_count"] = len(report.tables) result["joins_count"] = len(report.joins) + result["confidence_score"] = report.confidence_score + result["confidence_validation_notes"] = report.confidence_validation_notes elif report.status == AnalysisStatus.FAILED: result["error_message"] = report.error_message
