gopidesupavan commented on code in PR #62963: URL: https://github.com/apache/airflow/pull/62963#discussion_r3105702926
########## providers/common/ai/src/airflow/providers/common/ai/utils/db_schema.py: ########## @@ -0,0 +1,205 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Shared database hook and schema introspection utilities. + +These helpers are used by both :class:`~airflow.providers.common.ai.operators.llm_sql.LLMSQLQueryOperator` +and :class:`~airflow.providers.common.ai.operators.llm_data_quality.LLMDataQualityOperator` to +avoid code duplication while keeping both operators decoupled from each other. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from airflow.providers.common.compat.sdk import BaseHook + +if TYPE_CHECKING: + from airflow.providers.common.sql.config import DataSourceConfig + from airflow.providers.common.sql.hooks.sql import DbApiHook + +log = logging.getLogger(__name__) + +# SQLAlchemy dialect_name → sqlglot dialect mapping for names that differ. +SQLALCHEMY_TO_SQLGLOT_DIALECT: dict[str, str] = { + "postgresql": "postgres", + "mssql": "tsql", +} + + +def get_db_hook(db_conn_id: str) -> DbApiHook: + """ + Return a :class:`~airflow.providers.common.sql.hooks.sql.DbApiHook` for *db_conn_id*. 
+ + :param db_conn_id: Airflow connection ID that resolves to a ``DbApiHook``. + :raises ValueError: If the connection does not resolve to a ``DbApiHook``. + """ + # Lazy load to avoid hard dependency on common.sql + from airflow.providers.common.sql.hooks.sql import DbApiHook + + connection = BaseHook.get_connection(db_conn_id) + hook = connection.get_hook() + if not isinstance(hook, DbApiHook): + raise ValueError( + f"Connection {db_conn_id!r} does not provide a DbApiHook. Got {type(hook).__name__}." + ) + return hook + + +def resolve_dialect(db_hook: DbApiHook | None, explicit_dialect: str | None) -> str | None: + """ + Resolve the SQL dialect from an explicit parameter or a database hook. + + Normalises SQLAlchemy dialect names to sqlglot equivalents + (e.g. ``postgresql`` → ``postgres``). + + :param db_hook: Database hook to read ``dialect_name`` from when *explicit_dialect* is absent. + :param explicit_dialect: Caller-supplied dialect string; takes priority over the hook. + :return: Resolved dialect string, or ``None`` when neither source provides one. + """ + raw = explicit_dialect + if not raw and db_hook and hasattr(db_hook, "dialect_name"): + candidate = db_hook.dialect_name + raw = candidate if isinstance(candidate, str) else None + if raw: + return SQLALCHEMY_TO_SQLGLOT_DIALECT.get(raw, raw) + return None + + +def build_schema_context( + *, + db_hook: DbApiHook | None, + table_names: list[str] | None, + schema_context: str | None, + datasource_config: DataSourceConfig | None, +) -> str: + """ + Return a schema description string suitable for inclusion in an LLM prompt. + + Resolution order: + 1. *schema_context* — returned as-is when provided (manual override). + 2. DB introspection via *db_hook* + *table_names*. + 3. Object-storage introspection via *datasource_config*. + 4. Empty string when none of the above are available. + + :param db_hook: Hook used for relational-database schema introspection. 
+ :param table_names: Table names to introspect via *db_hook*. + :param schema_context: Manual schema description; bypasses introspection when set. + :param datasource_config: DataFusion datasource config for object-storage schema. + :raises ValueError: If *table_names* are provided but none yield schema information. + """ + if schema_context: + return schema_context + + if (db_hook and table_names) or datasource_config: + return _introspect_schemas( + db_hook=db_hook, + table_names=table_names, + datasource_config=datasource_config, + ) + + return "" + + +def _introspect_schemas( + *, + db_hook: DbApiHook | None, + table_names: list[str] | None, + datasource_config: DataSourceConfig | None, +) -> str: + """Build schema context by introspecting tables and/or object-storage sources.""" + parts: list[str] = [] + table_to_columns: dict[str, list[dict[str, str]]] = {} + + if table_names and db_hook is None: + raise ValueError("table_names requires db_conn_id so table schema can be introspected.") Review Comment: Let's validate this early in the operator's __init__, instead of validating it at task run time in execute. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
