Copilot commented on code in PR #62963:
URL: https://github.com/apache/airflow/pull/62963#discussion_r3030608302


##########
providers/common/ai/src/airflow/providers/common/ai/utils/dq_models.py:
##########
@@ -0,0 +1,227 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Pydantic models for the LLM data quality plan and result reporting.
+
+``DQPlan`` is the structured output type requested from the LLM.  The remaining
+dataclasses hold execution results and are never serialised back to the model.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Any
+
+from pydantic import BaseModel, Field, computed_field
+
+
+class DQCheck(BaseModel):
+    """
+    A single data-quality check produced by the LLM.
+
+    :param check_name: Matches the key supplied by the user in ``prompts``.
+    :param metric_key: Column alias used in the generated SQL (e.g. 
``null_email_count``).
+        The operator reads ``row[metric_key]`` from the query result.
+    :param group_id: Logical bucket for grouping checks into a single SQL query
+        (e.g. ``customers_null_check_1``, ``orders_validity_1``).
+    :param check_category: Semantic category assigned by the LLM based on the 
check
+        description.  Used to sub-group checks within a table so that checks of
+        different natures (e.g. null-checks vs. regex validity) land in 
separate
+        SQL queries.  Allowed values: ``null_check``, ``uniqueness``, 
``validity``,
+        ``numeric_range``, ``row_count``, ``string_format``.

Review Comment:
   The `DQCheck` class docstring’s `:param check_category:` section lists 
the allowed values but omits the newly supported `row_level` category, even 
though the Field description includes it and the planner prompt uses it. Please 
add `row_level` to the docstring list so the API documentation does not mislead 
consumers or LLM prompt authors.
   ```suggestion
           ``numeric_range``, ``row_count``, ``string_format``, ``row_level``.
   ```



##########
providers/common/ai/src/airflow/providers/common/ai/operators/llm_data_quality.py:
##########
@@ -0,0 +1,616 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Operator for generating and executing data-quality checks from natural 
language using LLMs."""
+
+from __future__ import annotations
+
+import hashlib
+import json
+from collections.abc import Callable, Sequence
+from functools import cached_property
+from typing import TYPE_CHECKING, Any
+
+from airflow.providers.common.ai.operators.llm import LLMOperator
+from airflow.providers.common.ai.utils.db_schema import get_db_hook
+from airflow.providers.common.ai.utils.dq_models import (
+    DQCheckGroup,
+    DQCheckResult,
+    DQPlan,
+    DQReport,
+    RowLevelResult,
+    UnexpectedResult,
+)
+from airflow.providers.common.ai.utils.dq_validation import default_registry
+from airflow.providers.common.compat.sdk import AirflowException, Variable
+
+try:
+    from airflow.providers.common.ai.utils.dq_planner import SQLDQPlanner
+except ImportError as e:
+    from airflow.providers.common.compat.sdk import 
AirflowOptionalProviderFeatureException
+
+    raise AirflowOptionalProviderFeatureException(e)
+
+if TYPE_CHECKING:
+    from airflow.providers.common.sql.config import DataSourceConfig
+    from airflow.providers.common.sql.hooks.sql import DbApiHook
+    from airflow.sdk import Context
+
+_PLAN_VARIABLE_PREFIX = "dq_plan_"
+_PLAN_VARIABLE_KEY_MAX_LEN = 200  # stay well under Airflow Variable key 
length limit
+
+
+class LLMDataQualityOperator(LLMOperator):
+    """
+    Generate and execute data-quality checks from natural language 
descriptions.
+
+    Each entry in ``prompts`` describes **one** data-quality expectation.
+    The LLM groups related checks into optimised SQL queries, executes them
+    against the target database, and validates each metric against the
+    corresponding entry in ``validators``.  The task fails if any check
+    does not pass, gating downstream tasks on data quality.
+
+    Generated SQL plans are cached in Airflow
+    :class:`~airflow.models.variable.Variable` to avoid repeat LLM calls.
+    Set ``dry_run=True`` to preview the plan without executing it — the
+    serialised plan dict is returned without running any SQL.
+    Set ``require_approval=True`` to gate execution on human review via the
+    HITL interface: the plan is presented to the reviewer first, and SQL
+    checks run only after approval.  ``dry_run`` and ``require_approval``
+    are independent — enabling both returns the plan dict without any
+    approval prompt.
+
+    :param prompts: Mapping of ``{check_name: natural_language_description}``.
+        Each key must be unique.  Use one check per key; the operator enforces
+        a strict one-key → one-check mapping.
+    :param llm_conn_id: Connection ID for the LLM provider.
+    :param model_id: Model identifier (e.g. ``"openai:gpt-4o"``).
+        Overrides the model stored in the connection's extra field.
+    :param system_prompt: Additional instructions appended to the planning 
prompt.
+    :param agent_params: Additional keyword arguments passed to the pydantic-ai
+        ``Agent`` constructor (e.g. ``retries``, ``model_settings``).
+    :param db_conn_id: Connection ID for the database to run checks against.
+        Must resolve to a 
:class:`~airflow.providers.common.sql.hooks.sql.DbApiHook`.
+    :param table_names: Tables to include in the LLM's schema context.
+    :param schema_context: Manual schema description; bypasses DB 
introspection.
+    :param validators: Mapping of ``{check_name: callable}`` where each 
callable
+        receives the raw metric value and returns ``True`` (pass) or ``False`` 
(fail).
+        Keys must be a subset of ``prompts.keys()``.
+        Use built-in factories from
+        :mod:`~airflow.providers.common.ai.utils.dq_validation` or plain 
lambdas::
+
+            from airflow.providers.common.ai.utils.dq_validation import 
null_pct_check
+
+            validators = {
+                "email_nulls": null_pct_check(max_pct=0.05),
+                "row_check": lambda v: v >= 1000,
+            }
+
+    :param dialect: SQL dialect override (``postgres``, ``mysql``, etc.).
+        Auto-detected from *db_conn_id* when not set.
+    :param datasource_config: DataFusion datasource for object-storage schema.
+    :param dry_run: When ``True``, generate and cache the plan but skip 
execution.
+        Returns the serialised plan dict instead of a 
:class:`~airflow.providers.common.ai.utils.dq_models.DQReport`.
+    :param prompt_version: Optional version tag included in the plan cache key.
+        Bump this to invalidate cached plans when prompts change semantically
+        without changing their text.
+    :param collect_unexpected: When ``True``, the LLM generates an
+        ``unexpected_query`` for validity / string-format checks.
+        If any of those checks fail, the unexpected query is executed and
+        the resulting sample rows are included in the report.
+    :param unexpected_sample_size: Maximum number of violating rows to return
+        per failed check.  Default ``100``.
+    :param row_level_sample_size: Maximum number of rows to fetch per row-level
+        check.  ``None`` (default) performs a full table scan — every row is
+        fetched and validated.  A positive integer is passed to the LLM as a
+        ``LIMIT`` clause on the generated SELECT, bounding execution time and
+        memory usage at the cost of sampling coverage.
+    :param require_approval: When ``True``, the operator defers after 
generating
+        and caching the DQ plan.  The plan SQL is surfaced in the HITL 
interface
+        for human review; checks run only after the reviewer approves.  
Inherited
+        from :class:`~airflow.providers.common.ai.operators.llm.LLMOperator`.
+        ``dry_run=True`` takes precedence — combining both flags returns the 
plan
+        dict immediately without requesting approval.
+    """
+
+    template_fields: Sequence[str] = (
+        *LLMOperator.template_fields,
+        "prompts",
+        "db_conn_id",
+        "table_names",
+        "schema_context",
+        "prompt_version",
+        "collect_unexpected",
+        "unexpected_sample_size",
+        "row_level_sample_size",
+    )
+
+    def __init__(
+        self,
+        *,
+        prompts: dict[str, str],
+        db_conn_id: str | None = None,
+        table_names: list[str] | None = None,
+        schema_context: str | None = None,
+        validators: dict[str, Callable[[Any], bool]] | None = None,
+        dialect: str | None = None,
+        datasource_config: DataSourceConfig | None = None,
+        prompt_version: str | None = None,
+        dry_run: bool = False,
+        collect_unexpected: bool = False,
+        unexpected_sample_size: int = 100,
+        row_level_sample_size: int | None = None,
+        **kwargs: Any,
+    ) -> None:
+        kwargs.pop("output_type", None)
+        kwargs.setdefault("prompt", "LLMDataQualityOperator")
+        super().__init__(**kwargs)
+
+        self.prompts = prompts
+        self.db_conn_id = db_conn_id
+        self.table_names = table_names
+        self.schema_context = schema_context
+        self.validators = validators or {}
+        self.dialect = dialect
+        self.datasource_config = datasource_config
+        self.prompt_version = prompt_version
+        self.dq_dry_run = dry_run
+        self.collect_unexpected = collect_unexpected
+        self.unexpected_sample_size = unexpected_sample_size
+        self.row_level_sample_size = row_level_sample_size
+
+        self._validate_prompts()
+        self._validate_validator_keys()
+
+    def execute(self, context: Context) -> dict[str, Any]:
+        """
+        Generate the DQ plan (or load from cache), then execute or defer for 
approval.
+
+        When ``dry_run=True`` the serialised plan dict is returned immediately 
—
+        no SQL is executed and no approval is requested.
+        When ``require_approval=True`` the task defers, presenting the plan to 
a
+        human reviewer; data-quality checks run only after the reviewer 
approves.
+
+        :returns: Dict with keys ``plan``, ``passed``, and ``results``.  On 
success
+            ``passed=True`` and ``results`` is a list of per-check result 
dicts.
+            For row-level checks the ``value`` entry in each result dict is 
itself
+            a dict with keys ``total``, ``invalid``, ``invalid_pct``, and
+            ``sample_violations`` rather than a raw scalar.
+            When ``dry_run=True`` ``passed=None`` and ``results=None`` — no SQL
+            is executed.  The ``plan`` key is always present in all modes.
+        :raises AirflowException: If any data-quality check fails threshold 
validation.
+        :raises TaskDeferred: When ``require_approval=True``, defers for human 
review
+            before executing the checks.
+        """
+        planner = self._build_planner()
+
+        schema_ctx = planner.build_schema_context(
+            table_names=self.table_names, schema_context=self.schema_context
+        )
+
+        self.log.info("Using schema context:\n%s", schema_ctx)
+
+        plan = self._load_or_generate_plan(planner, schema_ctx)
+
+        if self.dq_dry_run:
+            self.log.info(
+                "dry_run=True — skipping execution. Plan contains %d group(s), 
%d check(s).",
+                len(plan.groups),
+                len(plan.check_names),
+            )
+            for group in plan.groups:
+                self.log.info(
+                    "Group: %s\nChecks: %s\nSQL Query:\n%s\n",
+                    group.group_id,
+                    ", ".join(c.check_name for c in group.checks),
+                    group.query,
+                )
+            return {"plan": plan.model_dump(), "passed": None, "results": None}
+
+        if self.require_approval:
+            # Defer BEFORE execution — approval gates the SQL checks.
+            self.defer_for_approval(  # type: ignore[misc]
+                context,
+                plan.model_dump_json(),
+                body=self._build_dry_run_markdown(plan),
+            )
+            return {}  # type: ignore[return-value]  # pragma: no cover
+
+        return self._run_checks_and_report(context, planner, plan)
+
+    def _build_planner(self) -> SQLDQPlanner:
+        """Construct a 
:class:`~airflow.providers.common.ai.utils.dq_planner.SQLDQPlanner` from 
operator config."""
+        return SQLDQPlanner(
+            llm_hook=self.llm_hook,
+            db_hook=self.db_hook,
+            dialect=self.dialect,
+            datasource_config=self.datasource_config,
+            system_prompt=self.system_prompt,
+            agent_params=self.agent_params,
+            collect_unexpected=self.collect_unexpected,
+            unexpected_sample_size=self.unexpected_sample_size,
+            
validator_contexts=default_registry.build_llm_context(self.validators),
+            row_validators=self._collect_row_validators(),
+            row_level_sample_size=self.row_level_sample_size,
+        )
+
+    def _run_checks_and_report(
+        self,
+        context: Context,
+        planner: SQLDQPlanner,
+        plan: DQPlan,
+    ) -> dict[str, Any]:
+        """
+        Execute *plan* against the database, apply validators, and return the 
serialised report.
+
+        :raises AirflowException: If any data-quality check fails.
+        """
+        results_map = planner.execute_plan(plan)
+        check_results = self._validate_results(results_map, plan)
+
+        # Collect unexpected rows for failed validity/format checks.
+        if self.collect_unexpected:
+            failed_names = {r.check_name for r in check_results if not 
r.passed}
+            if failed_names:
+                unexpected_map = planner.execute_unexpected_queries(plan, 
failed_names)
+                self._attach_unexpected(check_results, unexpected_map)
+
+        report = DQReport.build(check_results)
+
+        output: dict[str, Any] = {
+            "plan": plan.model_dump(),
+            "passed": report.passed,
+            "results": [
+                {
+                    "check_name": r.check_name,
+                    "metric_key": r.metric_key,
+                    # RowLevelResult is not JSON-serialisable; convert to a 
plain dict.
+                    "value": (
+                        {
+                            "total": r.value.total,
+                            "invalid": r.value.invalid,
+                            "invalid_pct": r.value.invalid_pct,
+                            "sample_violations": r.value.sample_violations,
+                        }
+                        if isinstance(r.value, RowLevelResult)
+                        else r.value
+                    ),
+                    "passed": r.passed,
+                    "failure_reason": r.failure_reason,
+                    **(
+                        {
+                            "unexpected_records": 
r.unexpected.unexpected_records,
+                            "unexpected_sample_size": r.unexpected.sample_size,
+                        }
+                        if r.unexpected
+                        else {}
+                    ),
+                }
+                for r in report.results
+            ],
+        }
+
+        if not report.passed:
+            # Push results to XCom before failing so downstream tasks
+            # (e.g. with trigger_rule=all_done) can still inspect them.
+            context["ti"].xcom_push(key="return_value", value=output)
+            raise AirflowException(report.failure_summary)
+
+        self.log.info("All %d data-quality check(s) passed.", 
len(report.results))
+        return output
+
+    def _build_dry_run_markdown(self, plan: DQPlan) -> str:
+        """
+        Build a structured markdown summary of the DQ plan for the HITL review 
body.
+
+        Aggregate groups and row-level groups are rendered in separate 
sections so
+        reviewers can immediately distinguish SQL-aggregate checks from per-row
+        validation logic.
+        """
+        aggregate_groups = [g for g in plan.groups if not any(c.row_level for 
c in g.checks)]
+        row_level_groups = [g for g in plan.groups if any(c.row_level for c in 
g.checks)]
+
+        total_checks = len(plan.check_names)
+        agg_count = sum(len(g.checks) for g in aggregate_groups)
+        row_count = sum(len(g.checks) for g in row_level_groups)
+
+        lines: list[str] = [
+            "# LLM Data Quality Plan",
+            "",
+            "| | |",
+            "|---|---|",
+            f"| **Plan hash** | `{plan.plan_hash or 'N/A'}` |",
+            f"| **Total checks** | {total_checks} |",
+            f"| **Aggregate checks** | {agg_count} ({len(aggregate_groups)} 
group{'s' if len(aggregate_groups) != 1 else ''}) |",
+            f"| **Row-level checks** | {row_count} ({len(row_level_groups)} 
group{'s' if len(row_level_groups) != 1 else ''}) |",
+            "",
+        ]
+
+        if aggregate_groups:
+            lines += [
+                "---",
+                "",
+                "## Aggregate Checks",
+                "",
+                "> Each group runs as a **single SQL query**. "
+                "Result columns are matched to check names by metric key.",
+                "",
+            ]
+            for group in aggregate_groups:
+                lines += self._render_aggregate_group(group)
+
+        if row_level_groups:
+            lines += [
+                "---",
+                "",
+                "## Row-Level Checks",
+                "",
+                "> Row-level checks fetch **raw column values** and apply 
Python-side "
+                "validation per row. The threshold controls the maximum 
allowed fraction "
+                "of invalid rows before the check fails.",
+                "",
+            ]
+            for group in row_level_groups:
+                lines += self._render_row_level_group(group)
+
+        return "\n".join(lines).rstrip()
+
+    def _render_aggregate_group(self, group: DQCheckGroup) -> list[str]:
+        """Render one aggregate SQL group as a markdown subsection."""
+        lines: list[str] = [
+            f"### `{group.group_id}`",
+            "",
+            "| Check name | Metric key | Category |",
+            "|---|---|---|",
+        ]
+        for check in group.checks:
+            category = check.check_category or "—"
+            lines.append(f"| `{check.check_name}` | `{check.metric_key}` | 
{category} |")
+
+        lines += [
+            "",
+            "```sql",
+            group.query.strip(),
+            "```",
+            "",
+        ]
+
+        # Unexpected queries — only show when present.
+        unexpected = [(c.check_name, c.unexpected_query) for c in group.checks 
if c.unexpected_query]
+        if unexpected:
+            lines += ["<details><summary>Unexpected-row queries</summary>", ""]
+            for check_name, uq in unexpected:
+                lines += [
+                    f"**`{check_name}`**",
+                    "",
+                    "```sql",
+                    (uq or "").strip(),
+                    "```",
+                    "",
+                ]
+            lines += ["</details>", ""]
+
+        return lines
+
+    def _render_row_level_group(self, group: DQCheckGroup) -> list[str]:
+        """Render one row-level group as a markdown subsection with threshold 
info."""
+        lines: list[str] = [
+            f"### `{group.group_id}`",
+            "",
+            "| Check name | Metric key | Max invalid % |",
+            "|---|---|---|",
+        ]
+        for check in group.checks:
+            validator = self.validators.get(check.check_name)
+            max_pct = getattr(validator, "_max_invalid_pct", None)
+            threshold_str = f"{max_pct:.2%}" if max_pct is not None else "—"
+            lines.append(f"| `{check.check_name}` | `{check.metric_key}` | 
{threshold_str} |")
+
+        lines += [
+            "",
+            "```sql",
+            group.query.strip(),
+            "```",
+            "",
+        ]
+        return lines
+
+    def _load_or_generate_plan(self, planner: SQLDQPlanner, schema_ctx: str) 
-> DQPlan:
+        """Return a cached plan when available, otherwise generate and cache a 
new one."""
+        if not isinstance(self.prompts, dict):
+            raise TypeError("prompts must be a dict[str, str] before 
generating a DQ plan.")
+
+        plan_hash = _compute_plan_hash(self.prompts, self.prompt_version, 
self.collect_unexpected)
+        variable_key = f"{_PLAN_VARIABLE_PREFIX}{plan_hash}"
+
+        cached_json = Variable.get(variable_key, default=None)
+        if cached_json is not None:
+            self.log.info("DQ plan cache hit — key: %r", variable_key)
+            plan = DQPlan.model_validate_json(cached_json)
+            if not plan.plan_hash:
+                plan.plan_hash = plan_hash
+            return plan
+
+        self.log.info("DQ plan cache miss — generating via LLM (key: %r).", 
variable_key)
+        plan = planner.generate_plan(self.prompts, schema_ctx)
+        plan.plan_hash = plan_hash
+        Variable.set(variable_key, plan.model_dump_json())
+        return plan
+
+    def _validate_results(
+        self,
+        results_map: dict[str, Any],
+        plan: DQPlan,
+    ) -> list[DQCheckResult]:
+        """
+        Apply validators to each metric value and return per-check results.
+
+        For aggregate checks each validator callable receives the raw metric
+        value returned by the database.  For row-level checks, where *value* is
+        a :class:`~airflow.providers.common.ai.utils.dq_models.RowLevelResult`,
+        the pass/fail decision compares ``invalid_pct`` against the validator's
+        ``_max_invalid_pct`` attribute (defaulting to ``0.0`` when absent).
+        Checks without a registered validator are logged and marked as passed.
+
+        :param results_map: ``{check_name: metric_value_or_RowLevelResult}`` as
+            returned by
+            
:meth:`~airflow.providers.common.ai.utils.dq_planner.SQLDQPlanner.execute_plan`.
+        :param plan: The DQ plan whose groups and checks drive iteration order.
+        :returns: Per-check
+            :class:`~airflow.providers.common.ai.utils.dq_models.DQCheckResult`
+            list in plan-group order.
+        """
+        check_results: list[DQCheckResult] = []
+
+        for group in plan.groups:
+            for check in group.checks:

Review Comment:
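   `_validate_results()` indexes `results_map[check.check_name]` directly. If 
the planner returns a partial results map (e.g. because a row-level group was 
skipped or a metric key is missing), this surfaces as a raw `KeyError` with no 
context. Consider checking `check.check_name in results_map` (or using 
`results_map.get(...)`). On a miss, raise a `ValueError`/`AirflowException` that 
names the missing `check_name` together with the group and query that produced 
it, so failures are diagnosable. Elsewhere the operator identifies a group via 
`group.group_id`, so the suggestion below should read that attribute rather 
than `group_name`.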
   ```suggestion
               for check in group.checks:
                   if check.check_name not in results_map:
                       group_name = getattr(group, "group_name", None)
                       query = getattr(group, "query", None)
                       group_context = f"group={group_name!r}, query={query!r}" 
if group_name is not None else f"query={query!r}"
                       raise AirflowException(
                           f"Missing result for check_name={check.check_name!r} 
while validating data quality plan "
                           f"({group_context})."
                       )
   ```



##########
providers/common/ai/src/airflow/providers/common/ai/utils/dq_planner.py:
##########
@@ -0,0 +1,836 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+SQL-based data-quality plan generation and execution.
+
+:class:`SQLDQPlanner` is the single entry-point for all SQL DQ logic.
+It is deliberately kept separate from the operator so it can be unit-tested
+without an Airflow context and later swapped for GEX/SODA planners without
+touching the operator.
+"""
+
+from __future__ import annotations
+
+import logging
+from collections.abc import Iterator, Sequence
+from contextlib import closing
+from typing import TYPE_CHECKING, Any
+
+try:
+    from airflow.providers.common.ai.utils.sql_validation import (
+        DEFAULT_ALLOWED_TYPES,
+        SQLSafetyError,
+        validate_sql as _validate_sql,
+    )
+except ImportError as e:
+    from airflow.providers.common.compat.sdk import 
AirflowOptionalProviderFeatureException
+
+    raise AirflowOptionalProviderFeatureException(e)
+
+from airflow.providers.common.ai.utils.db_schema import build_schema_context, 
resolve_dialect
+from airflow.providers.common.ai.utils.dq_models import DQCheckGroup, DQPlan, 
RowLevelResult, UnexpectedResult
+from airflow.providers.common.ai.utils.logging import log_run_summary
+
+if TYPE_CHECKING:
+    from pydantic_ai import Agent
+    from pydantic_ai.messages import ModelMessage
+
+    from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook
+    from airflow.providers.common.sql.config import DataSourceConfig
+    from airflow.providers.common.sql.datafusion.engine import DataFusionEngine
+    from airflow.providers.common.sql.hooks.sql import DbApiHook
+
+log = logging.getLogger(__name__)
+
+_MAX_CHECKS_PER_GROUP = 5
+# Maximum rows fetched from DB per chunk during row-level processing — avoids 
loading the
+# entire result set into memory at once.
+_ROW_LEVEL_CHUNK_SIZE = 10_000
+# Hard cap on violation samples stored per check — independent of SQL LIMIT 
and chunk size.
+_MAX_VIOLATION_SAMPLES = 100
+
+_PLANNING_SYSTEM_PROMPT = """\
+You are a data-quality SQL expert.
+
+Given a set of named data-quality checks and a database schema, produce a \
+DQPlan that minimises the number of SQL queries while keeping each group \
+focused and manageable.
+
+GROUPING STRATEGY (multi-dimensional):
+  Group checks by **(target_table, check_category)**.  Checks on the same table
+  that belong to different categories MUST be in separate groups.
+
+  Allowed check_category values (assign one per check based on its 
description):
+    - null_check      — null / missing value counts or percentages
+    - uniqueness      — duplicate detection, cardinality checks
+    - validity        — regex / format / pattern matching on string columns
+    - numeric_range   — range, bounds, or statistical checks on numeric columns
+    - row_count       — total row counts or existence checks
+    - string_format   — length, encoding, whitespace, or character-set checks
+    - row_level       — per-row or anomaly checks that evaluate individual 
records
+
+  Row-level checks still follow the same grouping rule: group by 
(target_table, check_category="row_level").
+  MAX {max_checks_per_group} CHECKS PER GROUP:
+    If a (table, category) pair has more than {max_checks_per_group} checks,
+    split them into sub-groups of at most {max_checks_per_group}.
+
+  GROUP-ID NAMING:
+    Use the pattern "{{table}}_{{category}}_{{part}}".
+    Examples: customers_null_check_1, orders_validity_1, orders_validity_2
+
+  RATIONALE:
+    Keeping string-column checks (validity, string_format) apart from
+    numeric-column checks (numeric_range, null_check on numbers) produces
+    simpler SQL and makes failures easier to diagnose.
+
+  CORRECT (two groups for same table, different categories):
+    Group customers_null_check_1:
+      SELECT
+        (COUNT(CASE WHEN email IS NULL THEN 1 END) * 100.0 / COUNT(*)) AS 
null_email_pct,
+        (COUNT(CASE WHEN name IS NULL THEN 1 END) * 100.0 / COUNT(*)) AS 
null_name_pct
+      FROM customers
+
+    Group customers_validity_1:
+      SELECT
+        COUNT(CASE WHEN phone NOT LIKE '+___-___-____' THEN 1 END) AS 
invalid_phone_fmt
+      FROM customers
+
+  WRONG (mixing null-check and regex-validity in one group):
+    SELECT
+      (COUNT(CASE WHEN email IS NULL THEN 1 END) * 100.0 / COUNT(*)) AS 
null_email_pct,
+      COUNT(CASE WHEN phone NOT LIKE '+___-___-____' THEN 1 END) AS 
invalid_phone_fmt
+    FROM customers
+
+OUTPUT RULES:
+  1. Each output column must be aliased to exactly the metric_key of its check.
+     Example: ... AS null_email_pct
+  2. Each check_name must exactly match the key in the prompts dict.
+  3. metric_key values must be valid SQL column aliases (snake_case, no 
spaces).
+  4. Generates only SELECT queries — no INSERT, UPDATE, DELETE, DROP, or DDL.
+  5. Use {dialect} syntax.
+  6. Each check must appear in exactly ONE group.
+  7. Each check must have a check_category from the allowed list above.
+  8. Return a valid DQPlan object. No extra commentary.
+"""
+
+# Appended to the planning system prompt (see generate_plan) when the execution
+# target is Apache DataFusion, i.e. no DB hook was given but a datasource config was.
+# The rules are plain-text guidance for the LLM.
+# NOTE(review): function/operator support differs between DataFusion releases —
+# re-check rules 2, 4, 7 and 8 against the DataFusion version actually in use.
+_DATAFUSION_SYNTAX_SECTION = """\
+
+DATAFUSION SQL SYNTAX RULES:
+  The target engine is Apache DataFusion.  Observe these syntax differences
+  from standard PostgreSQL / ANSI SQL:
+
+  1. NO "FILTER (WHERE ...)" clause.  Use CASE expressions instead:
+       WRONG:  COUNT(*) FILTER (WHERE email IS NULL)
+       RIGHT:  COUNT(CASE WHEN email IS NULL THEN 1 END)
+
+  2. Regex matching uses the tilde operator:
+       column ~ 'pattern'    (match)
+       column !~ 'pattern'   (no match)
+     Do NOT use SIMILAR TO or POSIX-style ~* (case-insensitive).
+
+  3. CAST syntax — prefer CAST(expr AS type) over :: shorthand.
+
+  4. String functions: Use CHAR_LENGTH (not LEN), SUBSTR (not SUBSTRING with FROM/FOR).
+
+  5. Integer division: DataFusion performs integer division for INT/INT.
+     Use CAST(expr AS DOUBLE) to force floating-point division.
+
+  6. Boolean literals: Use TRUE / FALSE (not 1 / 0).
+
+  7. LIMIT is supported.  OFFSET is supported.  FETCH FIRST is NOT supported.
+
+  8. NULL handling: COALESCE, NULLIF, IFNULL are all supported.
+     NVL and ISNULL are NOT supported.
+"""
+
+_UNEXPECTED_QUERY_PROMPT_SECTION = """\
+
+UNEXPECTED VALUE COLLECTION:
+  For checks whose check_category is "validity" or "string_format", also
+  generate an unexpected_query field on the DQCheck.  This query must:
+    - SELECT the primary key column(s) and the column(s) being validated
+    - WHERE the row violates the check condition (the negation of the check)
+    - LIMIT {sample_size}
+    - Use {dialect} syntax
+    - Be a standalone SELECT (not a subquery of the group query)
+
+  For all other categories (null_check, uniqueness, numeric_range, row_count),
+  set unexpected_query to null — these are aggregate checks where individual
+  violating rows are not meaningful.
+
+  Example for a phone-format validity check:
+    unexpected_query: "SELECT id, phone FROM customers WHERE phone !~ 
'^\\d{{4}}-\\d{{4}}-\\d{{4}}$' LIMIT 100"
+"""
+
+_ROW_LEVEL_PROMPT_SECTION = """
+
+ROW-LEVEL CHECKS:
+  Some checks are marked as row_level.  For these:
+    - Generate a SELECT that returns the primary key column(s) and the column
+      being validated.  Do NOT aggregate.
+    - Set row_level = true on the DQCheck entry.
+    - metric_key must be the name of the column containing the value to 
validate
+      (the Python validator will read row[metric_key] for each row).
+    - {row_level_limit_clause}
+    - Place ALL row-level checks for the same table in a single group.
+
+  Row-level check names that require this treatment: {row_level_check_names}
+"""
+
+
+class SQLDQPlanner:
+    """
+    Generates and executes a SQL-based 
:class:`~airflow.providers.common.ai.utils.dq_models.DQPlan`.
+
+    :param llm_hook: Hook used to call the LLM for plan generation.
+    :param db_hook: Hook used to execute generated SQL against the database.
+    :param dialect: SQL dialect forwarded to the LLM prompt and 
``validate_sql``.
+        Auto-detected from *db_hook* when ``None``.
+    :param max_sql_retries: Maximum number of times a failing SQL group query 
is sent
+        back to the LLM for correction before the error is re-raised.  Default 
``2``.
+    :param validator_contexts: Pre-built LLM context string from
+        
:meth:`~airflow.providers.common.ai.utils.dq_validation.ValidatorRegistry.build_llm_context`.
+        Appended to the system prompt so the LLM knows what metric format each
+        custom validator expects.
+    :param row_validators: Mapping of ``{check_name: row_level_callable}`` for
+        checks that require row-by-row Python validation.  When a check's name
+        appears here, ``execute_plan`` fetches all (or sampled) rows and 
applies
+        the callable to each value instead of reading a single aggregate 
scalar.
+    :param row_level_sample_size: Maximum number of rows to fetch for row-level
+        checks.  ``None`` (default) performs a full scan.  A positive integer
+        instructs the LLM to add ``LIMIT N`` to the generated SELECT.
+    """
+
+    def __init__(
+        self,
+        *,
+        llm_hook: PydanticAIHook,
+        db_hook: DbApiHook | None,
+        dialect: str | None = None,
+        max_sql_retries: int = 2,
+        datasource_config: DataSourceConfig | None = None,
+        system_prompt: str = "",
+        agent_params: dict[str, Any] | None = None,
+        collect_unexpected: bool = False,
+        unexpected_sample_size: int = 100,
+        validator_contexts: str = "",
+        row_validators: dict[str, Any] | None = None,
+        row_level_sample_size: int | None = None,
+    ) -> None:
+        """
+        Store the planner configuration.
+
+        Parameters not described in the class docstring:
+
+        :param datasource_config: DataFusion datasource.  When *db_hook* is ``None``
+            and this is set, the planner targets DataFusion: the prompt gains
+            DataFusion syntax rules and queries run on a DataFusion engine.
+        :param system_prompt: Extra instructions appended to the planning prompt
+            under an "Additional instructions" heading.
+        :param agent_params: Keyword arguments forwarded to ``llm_hook.create_agent``.
+        :param collect_unexpected: When ``True``, the planning prompt asks the LLM to
+            also produce an ``unexpected_query`` for validity / string-format checks.
+        :param unexpected_sample_size: ``LIMIT`` the LLM is told to use for each
+            ``unexpected_query``.
+        """
+        self._llm_hook = llm_hook
+        self._db_hook = db_hook
+        self._datasource_config = datasource_config
+        self._dialect = resolve_dialect(db_hook, dialect)
+        # DataFusion is the execution target when there is no DB hook but a
+        # datasource config was given; generate_plan then appends the
+        # DataFusion-specific syntax rules to the prompt.
+        self._is_datafusion = db_hook is None and datasource_config is not None
+        # sqlglot has no DataFusion dialect.  For DataFusion targets, validate with the
+        # PostgreSQL dialect, because DataFusion's regex operators (~, !~) are rejected
+        # by the generic SQL parser.  Otherwise validate with the resolved dialect.
+        self._validation_dialect: str | None = "postgres" if self._is_datafusion else self._dialect
+        self._max_sql_retries = max_sql_retries
+        self._extra_system_prompt = system_prompt
+        self._agent_params: dict[str, Any] = agent_params or {}
+        self._collect_unexpected = collect_unexpected
+        self._unexpected_sample_size = unexpected_sample_size
+        self._validator_contexts = validator_contexts
+        self._row_validators: dict[str, Any] = row_validators or {}
+        self._row_level_sample_size = row_level_sample_size
+        # Populated by generate_plan; used by _retry_fix_group to continue the conversation.
+        self._plan_agent: Agent[None, DQPlan] | None = None
+        self._plan_all_messages: list[ModelMessage] | None = None
+
+    def build_schema_context(
+        self,
+        table_names: list[str] | None,
+        schema_context: str | None,
+    ) -> str:
+        """
+        Return a schema description string for inclusion in the LLM prompt.
+
+        Delegates to 
:func:`~airflow.providers.common.ai.utils.db_schema.build_schema_context`.
+        """
+        return build_schema_context(
+            db_hook=self._db_hook,
+            table_names=table_names,
+            schema_context=schema_context,
+            datasource_config=self._datasource_config,
+        )
+
+    def generate_plan(self, prompts: dict[str, str], schema_context: str) -> 
DQPlan:
+        """
+        Ask the LLM to produce a 
:class:`~airflow.providers.common.ai.utils.dq_models.DQPlan`.
+
+        The LLM receives the user prompts, schema context, and planning 
instructions
+        as a structured-output call (``output_type=DQPlan``).  After 
generation the
+        method verifies that the returned ``check_names`` exactly match
+        ``prompts.keys()``.
+
+        :param prompts: ``{check_name: natural_language_description}`` dict.
+        :param schema_context: Schema description previously built via
+            :meth:`build_schema_context`.
+        :raises ValueError: If the LLM's plan does not cover every prompt key
+            exactly once.
+        """
+        dialect_label = self._dialect or ("DataFusion-compatible SQL" if 
self._is_datafusion else "SQL")
+        system_prompt = _PLANNING_SYSTEM_PROMPT.format(
+            dialect=dialect_label, max_checks_per_group=_MAX_CHECKS_PER_GROUP
+        )
+
+        if self._is_datafusion:
+            system_prompt += _DATAFUSION_SYNTAX_SECTION
+
+        if self._collect_unexpected:
+            system_prompt += _UNEXPECTED_QUERY_PROMPT_SECTION.format(
+                dialect=dialect_label, sample_size=self._unexpected_sample_size
+            )
+
+        if schema_context:
+            system_prompt += f"\nAvailable schema:\n{schema_context}\n"
+
+        if self._validator_contexts:
+            system_prompt += self._validator_contexts
+
+        if self._row_validators:
+            row_level_check_names = ", ".join(sorted(self._row_validators))
+            if self._row_level_sample_size is not None:
+                limit_clause = f"Add LIMIT {self._row_level_sample_size} to 
the query."
+            else:
+                limit_clause = "Do NOT add a LIMIT — return all rows."
+            system_prompt += _ROW_LEVEL_PROMPT_SECTION.format(
+                row_level_check_names=row_level_check_names,
+                row_level_limit_clause=limit_clause,
+            )
+
+        if self._extra_system_prompt:
+            system_prompt += f"\nAdditional 
instructions:\n{self._extra_system_prompt}\n"
+
+        user_message = self._build_user_message(prompts)
+
+        log.info("Using system prompt:\n%s", system_prompt)
+        log.info("Using user message:\n%s", user_message)
+
+        agent = self._llm_hook.create_agent(
+            output_type=DQPlan, instructions=system_prompt, 
**self._agent_params
+        )
+        result = agent.run_sync(user_message)
+        log_run_summary(log, result)
+
+        # Persist the agent and full conversation so execute_plan can continue
+        # the same chat thread when asking for SQL corrections.
+        self._plan_agent = agent
+        self._plan_all_messages = result.all_messages()
+
+        plan: DQPlan = result.output
+
+        self._validate_plan_coverage(plan, prompts)
+        self._validate_group_sizes(plan)
+        return plan
+
+    def execute_plan(self, plan: DQPlan) -> dict[str, Any]:
+        """
+        Execute every SQL group in *plan* and return a flat ``{check_name: 
value}`` map.
+
+        Each group's query is safety-validated via
+        :func:`~airflow.providers.common.ai.utils.sql_validation.validate_sql` 
before
+        execution.  The first row of each result-set is used; each column 
corresponds
+        to the ``metric_key`` of one 
:class:`~airflow.providers.common.ai.utils.dq_models.DQCheck`.
+
+        :param plan: Plan produced by :meth:`generate_plan`.
+        :raises ValueError: If neither *db_hook* nor *datasource_config* was 
supplied.
+        :raises SQLSafetyError: If a generated query fails AST validation even 
after
+            ``max_sql_retries`` LLM correction attempts.
+        :raises ValueError: If a query result does not contain an expected 
metric column.
+        """
+        if self._db_hook is None and self._datasource_config is None:
+            raise ValueError("Either db_conn_id or datasource_config is 
required to execute the DQ plan.")
+
+        datafusion_engine: DataFusionEngine | None = None
+        if self._db_hook is None:
+            datafusion_engine = self._build_datafusion_engine()
+
+        results: dict[str, Any] = {}
+
+        for raw_group in plan.groups:
+            group = self._validate_or_fix_group(raw_group)
+            log.debug("Executing DQ group %r:\n%s", group.group_id, 
group.query)
+
+            # Row-level checks and aggregate checks are mutually exclusive 
within a group
+            # because the LLM places them in separate groups based on the 
system prompt.
+            if any(check.row_level for check in group.checks):
+                row_level_results = self._execute_row_level_group(group, 
datafusion_engine)
+                results.update(row_level_results)
+                continue
+
+            if datafusion_engine is not None:
+                row = self._run_datafusion_group(datafusion_engine, group)
+            else:
+                row = self._run_db_group(group)
+
+            for check in group.checks:
+                if check.metric_key not in row:
+                    raise ValueError(
+                        f"Query for group {group.group_id!r} did not return "
+                        f"column {check.metric_key!r} required by check 
{check.check_name!r}. "
+                        f"Available columns: {list(row.keys())}"
+                    )
+                results[check.check_name] = row[check.metric_key]
+
+        return results
+
+    def _execute_row_level_group(
+        self,
+        group: DQCheckGroup,
+        datafusion_engine: DataFusionEngine | None,
+    ) -> dict[str, RowLevelResult]:
+        """
+        Apply row-level validators to every row returned by *group.query*.
+
+        :param group: A plan group whose checks all have ``row_level=True``.
+        :param datafusion_engine: Active DataFusion engine or ``None`` for DB.
+        :returns: ``{check_name: RowLevelResult}`` for every row-level check.
+        """
+        active_checks = [
+            check for check in group.checks if check.row_level and 
check.check_name in self._row_validators
+        ]
+        for check in group.checks:
+            if check.row_level and check.check_name not in 
self._row_validators:
+                log.warning("No row-level validator found for check %r — 
skipping.", check.check_name)
+
+        if not active_checks:
+            return {}
+
+        # counters[check_name] = [total, invalid, sample_violations]
+        counters: dict[str, list[Any]] = {check.check_name: [0, 0, []] for 
check in active_checks}
+
+        chunk_iter: Iterator[list[dict[str, Any]]]
+        if datafusion_engine is not None:
+            chunk_iter = self._iter_datafusion_row_chunks(datafusion_engine, 
group.query)
+        else:
+            chunk_iter = self._iter_db_row_chunks(group.query)
+
+        total_rows_seen = 0
+        for chunk in chunk_iter:
+            total_rows_seen += len(chunk)
+            for row in chunk:
+                for check in active_checks:
+                    value = row.get(check.metric_key)
+                    c = counters[check.check_name]
+                    c[0] += 1  # total
+                    try:
+                        passed = 
bool(self._row_validators[check.check_name](value))
+                    except Exception:
+                        # Treat any exception from the validator as a 
validation failure.
+                        passed = False
+                    if not passed:
+                        c[1] += 1  # invalid
+                        if len(c[2]) < _MAX_VIOLATION_SAMPLES:
+                            c[2].append(str(value))
+
+        if total_rows_seen == 0:
+            log.warning("Row-level query for group %r returned no rows.", 
group.group_id)
+
+        results: dict[str, RowLevelResult] = {}
+        for check in active_checks:
+            total, invalid, violations = counters[check.check_name]
+            invalid_pct = invalid / total if total else 0.0
+            results[check.check_name] = RowLevelResult(
+                check_name=check.check_name,
+                total=total,
+                invalid=invalid,
+                invalid_pct=invalid_pct,
+                sample_violations=violations,
+                sample_size=_MAX_VIOLATION_SAMPLES if total else 0,
+            )
+            log.info(
+                "Row-level check %r: %d/%d invalid (%.4f%%)",
+                check.check_name,
+                invalid,
+                total,
+                invalid_pct * 100,
+            )
+
+        return results
+
+    def _iter_db_row_chunks(self, query: str) -> Iterator[list[dict[str, 
Any]]]:
+        """
+        Execute *query* via the DB hook and yield chunks of row-dicts.
+
+        Uses ``cursor.fetchmany(_ROW_LEVEL_CHUNK_SIZE)`` so at most
+        :data:`_ROW_LEVEL_CHUNK_SIZE` raw rows are held in memory at once.
+        Column names are resolved from ``cursor.description``.
+        """
+        with closing(self._db_hook.get_conn()) as conn:  # type: 
ignore[union-attr]
+            with closing(conn.cursor()) as cur:
+                cur.execute(query)
+                col_names = [str(desc[0]) for desc in cur.description] if 
cur.description else []
+                while True:
+                    chunk = cur.fetchmany(_ROW_LEVEL_CHUNK_SIZE)
+                    if not chunk:
+                        break
+                    result_chunk: list[dict[str, Any]] = []
+                    for raw_row in chunk:
+                        if isinstance(raw_row, dict):
+                            result_chunk.append(raw_row)
+                        elif isinstance(raw_row, Sequence) and not isinstance(
+                            raw_row, str | bytes | bytearray
+                        ):
+                            result_chunk.append(dict(zip(col_names, raw_row)))
+                        else:
+                            result_chunk.append({})
+                    yield result_chunk

Review Comment:
   `_iter_db_row_chunks()` silently converts any unexpected row shape into an 
empty dict (`{}`) and also uses `dict(zip(col_names, raw_row))` without 
verifying that `cursor.description` exists and matches the tuple length. This 
can silently drop columns (or produce empty rows) and cause row-level 
validators to run against `None` values instead of the intended column. Please 
raise a clear exception for unsupported row types and for tuple/sequence rows 
when `cursor.description` is missing or the column count doesn’t match the row 
length, so query/driver issues don’t turn into incorrect DQ results.



##########
providers/common/ai/src/airflow/providers/common/ai/operators/llm_data_quality.py:
##########
@@ -0,0 +1,616 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Operator for generating and executing data-quality checks from natural 
language using LLMs."""
+
+from __future__ import annotations
+
+import hashlib
+import json
+from collections.abc import Callable, Sequence
+from functools import cached_property
+from typing import TYPE_CHECKING, Any
+
+from airflow.providers.common.ai.operators.llm import LLMOperator
+from airflow.providers.common.ai.utils.db_schema import get_db_hook
+from airflow.providers.common.ai.utils.dq_models import (
+    DQCheckGroup,
+    DQCheckResult,
+    DQPlan,
+    DQReport,
+    RowLevelResult,
+    UnexpectedResult,
+)
+from airflow.providers.common.ai.utils.dq_validation import default_registry
+from airflow.providers.common.compat.sdk import AirflowException, Variable
+
+try:
+    from airflow.providers.common.ai.utils.dq_planner import SQLDQPlanner
+except ImportError as e:
+    from airflow.providers.common.compat.sdk import 
AirflowOptionalProviderFeatureException
+
+    raise AirflowOptionalProviderFeatureException(e)
+
+if TYPE_CHECKING:
+    from airflow.providers.common.sql.config import DataSourceConfig
+    from airflow.providers.common.sql.hooks.sql import DbApiHook
+    from airflow.sdk import Context
+
+# Prefix of the Airflow Variable keys under which generated DQ plans are cached.
+_PLAN_VARIABLE_PREFIX = "dq_plan_"
+# Maximum length used for a plan-cache Variable key.
+_PLAN_VARIABLE_KEY_MAX_LEN = 200  # stay well under Airflow Variable key length limit
+
+
+class LLMDataQualityOperator(LLMOperator):
+    """
+    Generate and execute data-quality checks from natural language 
descriptions.
+
+    Each entry in ``prompts`` describes **one** data-quality expectation.
+    The LLM groups related checks into optimised SQL queries, executes them
+    against the target database, and validates each metric against the
+    corresponding entry in ``validators``.  The task fails if any check
+    does not pass, gating downstream tasks on data quality.
+
+    Generated SQL plans are cached in Airflow
+    :class:`~airflow.models.variable.Variable` to avoid repeat LLM calls.
+    Set ``dry_run=True`` to preview the plan without executing it — the
+    serialised plan dict is returned without running any SQL.
+    Set ``require_approval=True`` to gate execution on human review via the
+    HITL interface: the plan is presented to the reviewer first, and SQL
+    checks run only after approval.  ``dry_run`` and ``require_approval``
+    are independent — enabling both returns the plan dict without any
+    approval prompt.
+
+    :param prompts: Mapping of ``{check_name: natural_language_description}``.
+        Each key must be unique.  Use one check per key; the operator enforces
+        a strict one-key → one-check mapping.
+    :param llm_conn_id: Connection ID for the LLM provider.
+    :param model_id: Model identifier (e.g. ``"openai:gpt-4o"``).
+        Overrides the model stored in the connection's extra field.
+    :param system_prompt: Additional instructions appended to the planning 
prompt.
+    :param agent_params: Additional keyword arguments passed to the pydantic-ai
+        ``Agent`` constructor (e.g. ``retries``, ``model_settings``).
+    :param db_conn_id: Connection ID for the database to run checks against.
+        Must resolve to a 
:class:`~airflow.providers.common.sql.hooks.sql.DbApiHook`.
+    :param table_names: Tables to include in the LLM's schema context.
+    :param schema_context: Manual schema description; bypasses DB 
introspection.
+    :param validators: Mapping of ``{check_name: callable}`` where each 
callable
+        receives the raw metric value and returns ``True`` (pass) or ``False`` 
(fail).
+        Keys must be a subset of ``prompts.keys()``.
+        Use built-in factories from
+        :mod:`~airflow.providers.common.ai.utils.dq_validation` or plain 
lambdas::
+
+            from airflow.providers.common.ai.utils.dq_validation import 
null_pct_check
+
+            validators = {
+                "email_nulls": null_pct_check(max_pct=0.05),
+                "row_check": lambda v: v >= 1000,
+            }
+
+    :param dialect: SQL dialect override (``postgres``, ``mysql``, etc.).
+        Auto-detected from *db_conn_id* when not set.
+    :param datasource_config: DataFusion datasource for object-storage schema.
+    :param dry_run: When ``True``, generate and cache the plan but skip 
execution.
+        Returns the serialised plan dict instead of a 
:class:`~airflow.providers.common.ai.utils.dq_models.DQReport`.
+    :param prompt_version: Optional version tag included in the plan cache key.
+        Bump this to invalidate cached plans when prompts change semantically
+        without changing their text.
+    :param collect_unexpected: When ``True``, the LLM generates an
+        ``unexpected_query`` for validity / string-format checks.
+        If any of those checks fail, the unexpected query is executed and
+        the resulting sample rows are included in the report.
+    :param unexpected_sample_size: Maximum number of violating rows to return
+        per failed check.  Default ``100``.
+    :param row_level_sample_size: Maximum number of rows to fetch per row-level
+        check.  ``None`` (default) performs a full table scan — every row is
+        fetched and validated.  A positive integer is passed to the LLM as a
+        ``LIMIT`` clause on the generated SELECT, bounding execution time and
+        memory usage at the cost of sampling coverage.
+    :param require_approval: When ``True``, the operator defers after 
generating
+        and caching the DQ plan.  The plan SQL is surfaced in the HITL 
interface
+        for human review; checks run only after the reviewer approves.  
Inherited
+        from :class:`~airflow.providers.common.ai.operators.llm.LLMOperator`.
+        ``dry_run=True`` takes precedence — combining both flags returns the 
plan
+        dict immediately without requesting approval.
+    """
+
+    # Fields Airflow renders as templates for this operator: everything LLMOperator
+    # templates, plus the data-quality specific fields listed below.
+    template_fields: Sequence[str] = (
+        *LLMOperator.template_fields,
+        "prompts",
+        "db_conn_id",
+        "table_names",
+        "schema_context",
+        "prompt_version",
+        "collect_unexpected",
+        "unexpected_sample_size",
+        "row_level_sample_size",
+    )
+
+    def __init__(
+        self,
+        *,
+        prompts: dict[str, str],
+        db_conn_id: str | None = None,
+        table_names: list[str] | None = None,
+        schema_context: str | None = None,
+        validators: dict[str, Callable[[Any], bool]] | None = None,
+        dialect: str | None = None,
+        datasource_config: DataSourceConfig | None = None,
+        prompt_version: str | None = None,
+        dry_run: bool = False,
+        collect_unexpected: bool = False,
+        unexpected_sample_size: int = 100,
+        row_level_sample_size: int | None = None,
+        **kwargs: Any,
+    ) -> None:
+        kwargs.pop("output_type", None)
+        kwargs.setdefault("prompt", "LLMDataQualityOperator")
+        super().__init__(**kwargs)
+
+        self.prompts = prompts
+        self.db_conn_id = db_conn_id
+        self.table_names = table_names
+        self.schema_context = schema_context
+        self.validators = validators or {}
+        self.dialect = dialect
+        self.datasource_config = datasource_config
+        self.prompt_version = prompt_version
+        self.dq_dry_run = dry_run
+        self.collect_unexpected = collect_unexpected
+        self.unexpected_sample_size = unexpected_sample_size
+        self.row_level_sample_size = row_level_sample_size
+
+        self._validate_prompts()
+        self._validate_validator_keys()
+
+    def execute(self, context: Context) -> dict[str, Any]:
+        """
+        Generate the DQ plan (or load from cache), then execute or defer for 
approval.
+
+        When ``dry_run=True`` the serialised plan dict is returned immediately 
—
+        no SQL is executed and no approval is requested.
+        When ``require_approval=True`` the task defers, presenting the plan to 
a
+        human reviewer; data-quality checks run only after the reviewer 
approves.
+
+        :returns: Dict with keys ``plan``, ``passed``, and ``results``.  On 
success
+            ``passed=True`` and ``results`` is a list of per-check result 
dicts.
+            For row-level checks the ``value`` entry in each result dict is 
itself
+            a dict with keys ``total``, ``invalid``, ``invalid_pct``, and
+            ``sample_violations`` rather than a raw scalar.
+            When ``dry_run=True`` ``passed=None`` and ``results=None`` — no SQL
+            is executed.  The ``plan`` key is always present in all modes.
+        :raises AirflowException: If any data-quality check fails threshold 
validation.
+        :raises TaskDeferred: When ``require_approval=True``, defers for human 
review
+            before executing the checks.
+        """
+        planner = self._build_planner()
+
+        schema_ctx = planner.build_schema_context(
+            table_names=self.table_names, schema_context=self.schema_context
+        )
+
+        self.log.info("Using schema context:\n%s", schema_ctx)
+
+        plan = self._load_or_generate_plan(planner, schema_ctx)
+
+        if self.dq_dry_run:
+            self.log.info(
+                "dry_run=True — skipping execution. Plan contains %d group(s), 
%d check(s).",
+                len(plan.groups),
+                len(plan.check_names),
+            )
+            for group in plan.groups:
+                self.log.info(
+                    "Group: %s\nChecks: %s\nSQL Query:\n%s\n",
+                    group.group_id,
+                    ", ".join(c.check_name for c in group.checks),
+                    group.query,
+                )
+            return {"plan": plan.model_dump(), "passed": None, "results": None}
+
+        if self.require_approval:
+            # Defer BEFORE execution — approval gates the SQL checks.
+            self.defer_for_approval(  # type: ignore[misc]
+                context,
+                plan.model_dump_json(),
+                body=self._build_dry_run_markdown(plan),
+            )
+            return {}  # type: ignore[return-value]  # pragma: no cover
+
+        return self._run_checks_and_report(context, planner, plan)
+
+    def _build_planner(self) -> SQLDQPlanner:
+        """Construct a 
:class:`~airflow.providers.common.ai.utils.dq_planner.SQLDQPlanner` from 
operator config."""
+        return SQLDQPlanner(
+            llm_hook=self.llm_hook,
+            db_hook=self.db_hook,
+            dialect=self.dialect,
+            datasource_config=self.datasource_config,
+            system_prompt=self.system_prompt,
+            agent_params=self.agent_params,
+            collect_unexpected=self.collect_unexpected,
+            unexpected_sample_size=self.unexpected_sample_size,
+            
validator_contexts=default_registry.build_llm_context(self.validators),
+            row_validators=self._collect_row_validators(),
+            row_level_sample_size=self.row_level_sample_size,
+        )
+
+    def _run_checks_and_report(
+        self,
+        context: Context,
+        planner: SQLDQPlanner,
+        plan: DQPlan,
+    ) -> dict[str, Any]:
+        """
+        Execute *plan* against the database, apply validators, and return the 
serialised report.
+
+        :raises AirflowException: If any data-quality check fails.
+        """
+        results_map = planner.execute_plan(plan)
+        check_results = self._validate_results(results_map, plan)
+
+        # Collect unexpected rows for failed validity/format checks.
+        if self.collect_unexpected:
+            failed_names = {r.check_name for r in check_results if not 
r.passed}
+            if failed_names:
+                unexpected_map = planner.execute_unexpected_queries(plan, 
failed_names)
+                self._attach_unexpected(check_results, unexpected_map)
+
+        report = DQReport.build(check_results)
+
+        output: dict[str, Any] = {
+            "plan": plan.model_dump(),
+            "passed": report.passed,
+            "results": [
+                {
+                    "check_name": r.check_name,
+                    "metric_key": r.metric_key,
+                    # RowLevelResult is not JSON-serialisable; convert to a 
plain dict.
+                    "value": (
+                        {
+                            "total": r.value.total,
+                            "invalid": r.value.invalid,
+                            "invalid_pct": r.value.invalid_pct,
+                            "sample_violations": r.value.sample_violations,
+                        }
+                        if isinstance(r.value, RowLevelResult)
+                        else r.value
+                    ),
+                    "passed": r.passed,
+                    "failure_reason": r.failure_reason,
+                    **(
+                        {
+                            "unexpected_records": 
r.unexpected.unexpected_records,
+                            "unexpected_sample_size": r.unexpected.sample_size,
+                        }
+                        if r.unexpected
+                        else {}
+                    ),
+                }
+                for r in report.results
+            ],
+        }
+
+        if not report.passed:
+            # Push results to XCom before failing so downstream tasks
+            # (e.g. with trigger_rule=all_done) can still inspect them.
+            context["ti"].xcom_push(key="return_value", value=output)
+            raise AirflowException(report.failure_summary)
+
+        self.log.info("All %d data-quality check(s) passed.", 
len(report.results))
+        return output
+
+    def _build_dry_run_markdown(self, plan: DQPlan) -> str:
+        """
+        Build a structured markdown summary of the DQ plan for the HITL review 
body.
+
+        Aggregate groups and row-level groups are rendered in separate 
sections so
+        reviewers can immediately distinguish SQL-aggregate checks from per-row
+        validation logic.
+        """
+        aggregate_groups = [g for g in plan.groups if not any(c.row_level for 
c in g.checks)]
+        row_level_groups = [g for g in plan.groups if any(c.row_level for c in 
g.checks)]
+
+        total_checks = len(plan.check_names)
+        agg_count = sum(len(g.checks) for g in aggregate_groups)
+        row_count = sum(len(g.checks) for g in row_level_groups)
+
+        lines: list[str] = [
+            "# LLM Data Quality Plan",
+            "",
+            "| | |",
+            "|---|---|",
+            f"| **Plan hash** | `{plan.plan_hash or 'N/A'}` |",
+            f"| **Total checks** | {total_checks} |",
+            f"| **Aggregate checks** | {agg_count} ({len(aggregate_groups)} 
group{'s' if len(aggregate_groups) != 1 else ''}) |",
+            f"| **Row-level checks** | {row_count} ({len(row_level_groups)} 
group{'s' if len(row_level_groups) != 1 else ''}) |",
+            "",
+        ]
+
+        if aggregate_groups:
+            lines += [
+                "---",
+                "",
+                "## Aggregate Checks",
+                "",
+                "> Each group runs as a **single SQL query**. "
+                "Result columns are matched to check names by metric key.",
+                "",
+            ]
+            for group in aggregate_groups:
+                lines += self._render_aggregate_group(group)
+
+        if row_level_groups:
+            lines += [
+                "---",
+                "",
+                "## Row-Level Checks",
+                "",
+                "> Row-level checks fetch **raw column values** and apply 
Python-side "
+                "validation per row. The threshold controls the maximum 
allowed fraction "
+                "of invalid rows before the check fails.",
+                "",
+            ]
+            for group in row_level_groups:
+                lines += self._render_row_level_group(group)
+
+        return "\n".join(lines).rstrip()
+
+    def _render_aggregate_group(self, group: DQCheckGroup) -> list[str]:
+        """Render one aggregate SQL group as a markdown subsection."""
+        lines: list[str] = [
+            f"### `{group.group_id}`",
+            "",
+            "| Check name | Metric key | Category |",
+            "|---|---|---|",
+        ]
+        for check in group.checks:
+            category = check.check_category or "—"
+            lines.append(f"| `{check.check_name}` | `{check.metric_key}` | 
{category} |")
+
+        lines += [
+            "",
+            "```sql",
+            group.query.strip(),
+            "```",
+            "",
+        ]
+
+        # Unexpected queries — only show when present.
+        unexpected = [(c.check_name, c.unexpected_query) for c in group.checks 
if c.unexpected_query]
+        if unexpected:
+            lines += ["<details><summary>Unexpected-row queries</summary>", ""]
+            for check_name, uq in unexpected:
+                lines += [
+                    f"**`{check_name}`**",
+                    "",
+                    "```sql",
+                    (uq or "").strip(),
+                    "```",
+                    "",
+                ]
+            lines += ["</details>", ""]
+
+        return lines
+
+    def _render_row_level_group(self, group: DQCheckGroup) -> list[str]:
+        """Render one row-level group as a markdown subsection with threshold 
info."""
+        lines: list[str] = [
+            f"### `{group.group_id}`",
+            "",
+            "| Check name | Metric key | Max invalid % |",
+            "|---|---|---|",
+        ]
+        for check in group.checks:
+            validator = self.validators.get(check.check_name)
+            max_pct = getattr(validator, "_max_invalid_pct", None)
+            threshold_str = f"{max_pct:.2%}" if max_pct is not None else "—"
+            lines.append(f"| `{check.check_name}` | `{check.metric_key}` | 
{threshold_str} |")
+
+        lines += [
+            "",
+            "```sql",
+            group.query.strip(),
+            "```",
+            "",
+        ]
+        return lines
+
+    def _load_or_generate_plan(self, planner: SQLDQPlanner, schema_ctx: str) 
-> DQPlan:
+        """Return a cached plan when available, otherwise generate and cache a 
new one."""
+        if not isinstance(self.prompts, dict):
+            raise TypeError("prompts must be a dict[str, str] before 
generating a DQ plan.")
+
+        plan_hash = _compute_plan_hash(self.prompts, self.prompt_version, 
self.collect_unexpected)
+        variable_key = f"{_PLAN_VARIABLE_PREFIX}{plan_hash}"
+
+        cached_json = Variable.get(variable_key, default=None)
+        if cached_json is not None:
+            self.log.info("DQ plan cache hit — key: %r", variable_key)
+            plan = DQPlan.model_validate_json(cached_json)
+            if not plan.plan_hash:
+                plan.plan_hash = plan_hash
+            return plan
+
+        self.log.info("DQ plan cache miss — generating via LLM (key: %r).", 
variable_key)
+        plan = planner.generate_plan(self.prompts, schema_ctx)
+        plan.plan_hash = plan_hash
+        Variable.set(variable_key, plan.model_dump_json())
+        return plan
+
+    def _validate_results(
+        self,
+        results_map: dict[str, Any],
+        plan: DQPlan,
+    ) -> list[DQCheckResult]:
+        """
+        Apply validators to each metric value and return per-check results.
+
+        For aggregate checks each validator callable receives the raw metric
+        value returned by the database.  For row-level checks, where *value* is
+        a :class:`~airflow.providers.common.ai.utils.dq_models.RowLevelResult`,
+        the pass/fail decision compares ``invalid_pct`` against the validator's
+        ``_max_invalid_pct`` attribute (defaulting to ``0.0`` when absent).
+        Checks without a registered validator are logged and marked as passed.
+
+        :param results_map: ``{check_name: metric_value_or_RowLevelResult}`` as
+            returned by
+            
:meth:`~airflow.providers.common.ai.utils.dq_planner.SQLDQPlanner.execute_plan`.
+        :param plan: The DQ plan whose groups and checks drive iteration order.
+        :returns: Per-check
+            :class:`~airflow.providers.common.ai.utils.dq_models.DQCheckResult`
+            list in plan-group order.
+        """
+        check_results: list[DQCheckResult] = []
+
+        for group in plan.groups:
+            for check in group.checks:
+                value = results_map[check.check_name]
+                validator = self.validators.get(check.check_name)
+
+                passed = True
+                failure_reason: str | None = None
+
+                if isinstance(value, RowLevelResult):
+                    # Row-level check: evaluate threshold against invalid_pct.
+                    max_pct = getattr(validator, "_max_invalid_pct", 0.0)
+                    passed = value.invalid_pct <= max_pct
+                    if not passed:
+                        failure_reason = (
+                            f"Row-level check failed: 
{value.invalid}/{value.total} rows invalid "
+                            f"({value.invalid_pct:.4%}), threshold 
{max_pct:.4%}"
+                        )

Review Comment:
   Row-level checks are always evaluated against
`max_pct = getattr(validator, "_max_invalid_pct", 0.0)` even when `validator`
is `None`. This makes “no validator provided ⇒ check passes by default” true
for aggregate checks but **not** for row-level checks (a missing validator
effectively enforces a 0% invalid threshold). Consider explicitly handling
`validator is None` in the `RowLevelResult` branch the same way as aggregate
checks (warn + mark passed), and/or raising earlier if a row-level result is
returned for a check that has no corresponding validator callable.



##########
providers/common/ai/tests/unit/common/ai/utils/test_dq_planner.py:
##########
@@ -0,0 +1,1174 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook
+from airflow.providers.common.ai.utils.dq_models import DQCheck, DQCheckGroup, 
DQPlan
+from airflow.providers.common.ai.utils.dq_planner import SQLDQPlanner
+from airflow.providers.common.ai.utils.sql_validation import SQLSafetyError
+
+
+def _make_plan(*check_names: str) -> DQPlan:
+    """Helper: build a minimal DQPlan with one group per check."""
+    groups = [
+        DQCheckGroup(
+            group_id="numeric_aggregate",
+            query=f"SELECT COUNT(*) AS {name}_count FROM t",
+            checks=[DQCheck(check_name=name, metric_key=f"{name}_count", 
group_id="numeric_aggregate")],
+        )
+        for name in check_names
+    ]
+    return DQPlan(groups=groups)
+
+
+def _make_llm_hook(plan: DQPlan) -> MagicMock:
+    """Helper: mock PydanticAIHook that returns *plan* from agent.run_sync."""
+    mock_usage = MagicMock(requests=1, tool_calls=0, input_tokens=100, 
output_tokens=50, total_tokens=150)
+    mock_result = MagicMock(spec=["output", "all_messages", "usage", 
"response"])
+    mock_result.output = plan
+    mock_result.all_messages.return_value = []
+    mock_result.usage.return_value = mock_usage
+    mock_result.response.model_name = "test-model"
+    mock_agent = MagicMock(spec=["run_sync"])
+    mock_agent.run_sync.return_value = mock_result
+    mock_hook = MagicMock(spec=PydanticAIHook)
+    mock_hook.create_agent.return_value = mock_agent
+    return mock_hook
+
+
+class TestSQLDQPlannerBuildSchema:
+    def test_returns_manual_schema_context_verbatim(self):
+        planner = SQLDQPlanner(llm_hook=MagicMock(spec=PydanticAIHook), 
db_hook=None)
+        result = planner.build_schema_context(
+            table_names=None,
+            schema_context="Table: t\nColumns: id INT",
+        )
+        assert result == "Table: t\nColumns: id INT"
+
+    def test_introspects_via_db_hook_when_no_manual_context(self):
+        mock_db_hook = MagicMock()
+        mock_db_hook.get_table_schema.return_value = [{"name": "id", "type": 
"INT"}]
+
+        planner = SQLDQPlanner(llm_hook=MagicMock(spec=PydanticAIHook), 
db_hook=mock_db_hook)
+        result = planner.build_schema_context(

Review Comment:
   These tests construct several `MagicMock()` instances without 
`spec`/`autospec` (e.g. `mock_db_hook = MagicMock()` here). Unspec’d mocks can 
hide real interface mismatches and make refactors risky. Prefer 
`MagicMock(spec=...)` (or `patch(..., autospec=True)`) for hooks/agents/cursors 
so attribute mistakes fail fast.



##########
providers/common/ai/src/airflow/providers/common/ai/utils/dq_planner.py:
##########
@@ -0,0 +1,836 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+SQL-based data-quality plan generation and execution.
+
+:class:`SQLDQPlanner` is the single entry-point for all SQL DQ logic.
+It is deliberately kept separate from the operator so it can be unit-tested
+without an Airflow context and later swapped for GEX/SODA planners without
+touching the operator.
+"""
+
+from __future__ import annotations
+
+import logging
+from collections.abc import Iterator, Sequence
+from contextlib import closing
+from typing import TYPE_CHECKING, Any
+
+try:
+    from airflow.providers.common.ai.utils.sql_validation import (
+        DEFAULT_ALLOWED_TYPES,
+        SQLSafetyError,
+        validate_sql as _validate_sql,
+    )
+except ImportError as e:
+    from airflow.providers.common.compat.sdk import 
AirflowOptionalProviderFeatureException
+
+    raise AirflowOptionalProviderFeatureException(e)
+
+from airflow.providers.common.ai.utils.db_schema import build_schema_context, 
resolve_dialect
+from airflow.providers.common.ai.utils.dq_models import DQCheckGroup, DQPlan, 
RowLevelResult, UnexpectedResult
+from airflow.providers.common.ai.utils.logging import log_run_summary
+
+if TYPE_CHECKING:
+    from pydantic_ai import Agent
+    from pydantic_ai.messages import ModelMessage
+
+    from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook
+    from airflow.providers.common.sql.config import DataSourceConfig
+    from airflow.providers.common.sql.datafusion.engine import DataFusionEngine
+    from airflow.providers.common.sql.hooks.sql import DbApiHook
+
+log = logging.getLogger(__name__)
+
+_MAX_CHECKS_PER_GROUP = 5
+# Maximum rows fetched from DB per chunk during row-level processing — avoids 
loading the
+# entire result set into memory at once.
+_ROW_LEVEL_CHUNK_SIZE = 10_000
+# Hard cap on violation samples stored per check — independent of SQL LIMIT 
and chunk size.
+_MAX_VIOLATION_SAMPLES = 100
+
+_PLANNING_SYSTEM_PROMPT = """\
+You are a data-quality SQL expert.
+
+Given a set of named data-quality checks and a database schema, produce a \
+DQPlan that minimises the number of SQL queries while keeping each group \
+focused and manageable.
+
+GROUPING STRATEGY (multi-dimensional):
+  Group checks by **(target_table, check_category)**.  Checks on the same table
+  that belong to different categories MUST be in separate groups.
+
+  Allowed check_category values (assign one per check based on its 
description):
+    - null_check      — null / missing value counts or percentages
+    - uniqueness      — duplicate detection, cardinality checks
+    - validity        — regex / format / pattern matching on string columns
+    - numeric_range   — range, bounds, or statistical checks on numeric columns
+    - row_count       — total row counts or existence checks
+    - string_format   — length, encoding, whitespace, or character-set checks
+    - row_level       — per-row or anomaly checks that evaluate individual 
records
+
+  Row-level checks still follow the same grouping rule: group by 
(target_table, check_category="row_level").
+  MAX {max_checks_per_group} CHECKS PER GROUP:
+    If a (table, category) pair has more than {max_checks_per_group} checks,
+    split them into sub-groups of at most {max_checks_per_group}.
+
+  GROUP-ID NAMING:
+    Use the pattern "{{table}}_{{category}}_{{part}}".
+    Examples: customers_null_check_1, orders_validity_1, orders_validity_2
+
+  RATIONALE:
+    Keeping string-column checks (validity, string_format) apart from
+    numeric-column checks (numeric_range, null_check on numbers) produces
+    simpler SQL and makes failures easier to diagnose.
+
+  CORRECT (two groups for same table, different categories):
+    Group customers_null_check_1:
+      SELECT
+        (COUNT(CASE WHEN email IS NULL THEN 1 END) * 100.0 / COUNT(*)) AS 
null_email_pct,
+        (COUNT(CASE WHEN name IS NULL THEN 1 END) * 100.0 / COUNT(*)) AS 
null_name_pct
+      FROM customers
+
+    Group customers_validity_1:
+      SELECT
+        COUNT(CASE WHEN phone NOT LIKE '+___-___-____' THEN 1 END) AS 
invalid_phone_fmt
+      FROM customers
+
+  WRONG (mixing null-check and regex-validity in one group):
+    SELECT
+      (COUNT(CASE WHEN email IS NULL THEN 1 END) * 100.0 / COUNT(*)) AS 
null_email_pct,
+      COUNT(CASE WHEN phone NOT LIKE '+___-___-____' THEN 1 END) AS 
invalid_phone_fmt
+    FROM customers
+
+OUTPUT RULES:
+  1. Each output column must be aliased to exactly the metric_key of its check.
+     Example: ... AS null_email_pct
+  2. Each check_name must exactly match the key in the prompts dict.
+  3. metric_key values must be valid SQL column aliases (snake_case, no 
spaces).
+  4. Generates only SELECT queries — no INSERT, UPDATE, DELETE, DROP, or DDL.
+  5. Use {dialect} syntax.
+  6. Each check must appear in exactly ONE group.
+  7. Each check must have a check_category from the allowed list above.
+  8. Return a valid DQPlan object. No extra commentary.
+"""
+
+_DATAFUSION_SYNTAX_SECTION = """\
+
+DATAFUSION SQL SYNTAX RULES:
+  The target engine is Apache DataFusion.  Observe these syntax differences
+  from standard PostgreSQL / ANSI SQL:
+
+  1. NO "FILTER (WHERE ...)" clause.  Use CASE expressions instead:
+       WRONG:  COUNT(*) FILTER (WHERE email IS NULL)
+       RIGHT:  COUNT(CASE WHEN email IS NULL THEN 1 END)
+
+  2. Regex matching uses the tilde operator:
+       column ~ 'pattern'    (match)
+       column !~ 'pattern'   (no match)
+     Do NOT use SIMILAR TO or POSIX-style ~* (case-insensitive).
+
+  3. CAST syntax — prefer CAST(expr AS type) over :: shorthand.
+
+  4. String functions: Use CHAR_LENGTH (not LEN), SUBSTR (not SUBSTRING with 
FROM/FOR).
+
+  5. Integer division: DataFusion performs integer division for INT/INT.
+     Use CAST(expr AS DOUBLE) to force floating-point division.
+
+  6. Boolean literals: Use TRUE / FALSE (not 1 / 0).
+
+  7. LIMIT is supported.  OFFSET is supported.  FETCH FIRST is NOT supported.
+
+  8. NULL handling: COALESCE, NULLIF, IFNULL are all supported.
+     NVL and ISNULL are NOT supported.
+"""
+
+_UNEXPECTED_QUERY_PROMPT_SECTION = """\
+
+UNEXPECTED VALUE COLLECTION:
+  For checks whose check_category is "validity" or "string_format", also
+  generate an unexpected_query field on the DQCheck.  This query must:
+    - SELECT the primary key column(s) and the column(s) being validated
+    - WHERE the row violates the check condition (the negation of the check)
+    - LIMIT {sample_size}
+    - Use {dialect} syntax
+    - Be a standalone SELECT (not a subquery of the group query)
+
+  For all other categories (null_check, uniqueness, numeric_range, row_count),
+  set unexpected_query to null — these are aggregate checks where individual
+  violating rows are not meaningful.
+
+  Example for a phone-format validity check:
+    unexpected_query: "SELECT id, phone FROM customers WHERE phone !~ 
'^\\d{{4}}-\\d{{4}}-\\d{{4}}$' LIMIT 100"
+"""
+
+_ROW_LEVEL_PROMPT_SECTION = """
+
+ROW-LEVEL CHECKS:
+  Some checks are marked as row_level.  For these:
+    - Generate a SELECT that returns the primary key column(s) and the column
+      being validated.  Do NOT aggregate.
+    - Set row_level = true on the DQCheck entry.
+    - metric_key must be the name of the column containing the value to 
validate
+      (the Python validator will read row[metric_key] for each row).
+    - {row_level_limit_clause}
+    - Place ALL row-level checks for the same table in a single group.
+
+  Row-level check names that require this treatment: {row_level_check_names}
+"""
+
+
+class SQLDQPlanner:
+    """
+    Generates and executes a SQL-based 
:class:`~airflow.providers.common.ai.utils.dq_models.DQPlan`.
+
+    :param llm_hook: Hook used to call the LLM for plan generation.
+    :param db_hook: Hook used to execute generated SQL against the database.
+    :param dialect: SQL dialect forwarded to the LLM prompt and 
``validate_sql``.
+        Auto-detected from *db_hook* when ``None``.
+    :param max_sql_retries: Maximum number of times a failing SQL group query 
is sent
+        back to the LLM for correction before the error is re-raised.  Default 
``2``.
+    :param validator_contexts: Pre-built LLM context string from
+        
:meth:`~airflow.providers.common.ai.utils.dq_validation.ValidatorRegistry.build_llm_context`.
+        Appended to the system prompt so the LLM knows what metric format each
+        custom validator expects.
+    :param row_validators: Mapping of ``{check_name: row_level_callable}`` for
+        checks that require row-by-row Python validation.  When a check's name
+        appears here, ``execute_plan`` fetches all (or sampled) rows and 
applies
+        the callable to each value instead of reading a single aggregate 
scalar.
+    :param row_level_sample_size: Maximum number of rows to fetch for row-level
+        checks.  ``None`` (default) performs a full scan.  A positive integer
+        instructs the LLM to add ``LIMIT N`` to the generated SELECT.
+    """
+
+    def __init__(
+        self,
+        *,
+        llm_hook: PydanticAIHook,
+        db_hook: DbApiHook | None,
+        dialect: str | None = None,
+        max_sql_retries: int = 2,
+        datasource_config: DataSourceConfig | None = None,
+        system_prompt: str = "",
+        agent_params: dict[str, Any] | None = None,
+        collect_unexpected: bool = False,
+        unexpected_sample_size: int = 100,
+        validator_contexts: str = "",
+        row_validators: dict[str, Any] | None = None,
+        row_level_sample_size: int | None = None,
+    ) -> None:
+        self._llm_hook = llm_hook
+        self._db_hook = db_hook
+        self._datasource_config = datasource_config
+        self._dialect = resolve_dialect(db_hook, dialect)
+        # Track whether the execution target is DataFusion so the prompt can
+        # include DataFusion-specific syntax rules.  The dialect stays None
+        # (generic SQL) for sqlglot validation — sqlglot has no DataFusion 
dialect.
+        self._is_datafusion = db_hook is None and datasource_config is not None
+        # When targeting DataFusion, use PostgreSQL dialect for sqlglot 
validation
+        # because DataFusion shares regex operators (~, !~) that the generic 
SQL
+        # parser does not recognise.
+        self._validation_dialect: str | None = "postgres" if 
self._is_datafusion else self._dialect
+        self._max_sql_retries = max_sql_retries
+        self._extra_system_prompt = system_prompt
+        self._agent_params: dict[str, Any] = agent_params or {}
+        self._collect_unexpected = collect_unexpected
+        self._unexpected_sample_size = unexpected_sample_size
+        self._validator_contexts = validator_contexts
+        self._row_validators: dict[str, Any] = row_validators or {}
+        self._row_level_sample_size = row_level_sample_size
+        # Populated by generate_plan; used by _retry_fix_group to continue the 
conversation.
+        self._plan_agent: Agent[None, DQPlan] | None = None
+        self._plan_all_messages: list[ModelMessage] | None = None
+
+    def build_schema_context(
+        self,
+        table_names: list[str] | None,
+        schema_context: str | None,
+    ) -> str:
+        """
+        Return a schema description string for inclusion in the LLM prompt.
+
+        Delegates to 
:func:`~airflow.providers.common.ai.utils.db_schema.build_schema_context`.
+        """
+        return build_schema_context(
+            db_hook=self._db_hook,
+            table_names=table_names,
+            schema_context=schema_context,
+            datasource_config=self._datasource_config,
+        )
+
+    def generate_plan(self, prompts: dict[str, str], schema_context: str) -> 
DQPlan:
+        """
+        Ask the LLM to produce a 
:class:`~airflow.providers.common.ai.utils.dq_models.DQPlan`.
+
+        The LLM receives the user prompts, schema context, and planning 
instructions
+        as a structured-output call (``output_type=DQPlan``).  After 
generation the
+        method verifies that the returned ``check_names`` exactly match
+        ``prompts.keys()``.
+
+        :param prompts: ``{check_name: natural_language_description}`` dict.
+        :param schema_context: Schema description previously built via
+            :meth:`build_schema_context`.
+        :raises ValueError: If the LLM's plan does not cover every prompt key
+            exactly once.
+        """
+        dialect_label = self._dialect or ("DataFusion-compatible SQL" if 
self._is_datafusion else "SQL")
+        system_prompt = _PLANNING_SYSTEM_PROMPT.format(
+            dialect=dialect_label, max_checks_per_group=_MAX_CHECKS_PER_GROUP
+        )
+
+        if self._is_datafusion:
+            system_prompt += _DATAFUSION_SYNTAX_SECTION
+
+        if self._collect_unexpected:
+            system_prompt += _UNEXPECTED_QUERY_PROMPT_SECTION.format(
+                dialect=dialect_label, sample_size=self._unexpected_sample_size
+            )
+
+        if schema_context:
+            system_prompt += f"\nAvailable schema:\n{schema_context}\n"
+
+        if self._validator_contexts:
+            system_prompt += self._validator_contexts
+
+        if self._row_validators:
+            row_level_check_names = ", ".join(sorted(self._row_validators))
+            if self._row_level_sample_size is not None:
+                limit_clause = f"Add LIMIT {self._row_level_sample_size} to 
the query."
+            else:
+                limit_clause = "Do NOT add a LIMIT — return all rows."
+            system_prompt += _ROW_LEVEL_PROMPT_SECTION.format(
+                row_level_check_names=row_level_check_names,
+                row_level_limit_clause=limit_clause,
+            )
+
+        if self._extra_system_prompt:
+            system_prompt += f"\nAdditional 
instructions:\n{self._extra_system_prompt}\n"
+
+        user_message = self._build_user_message(prompts)
+
+        log.info("Using system prompt:\n%s", system_prompt)
+        log.info("Using user message:\n%s", user_message)
+
+        agent = self._llm_hook.create_agent(
+            output_type=DQPlan, instructions=system_prompt, 
**self._agent_params
+        )
+        result = agent.run_sync(user_message)
+        log_run_summary(log, result)
+
+        # Persist the agent and full conversation so execute_plan can continue
+        # the same chat thread when asking for SQL corrections.
+        self._plan_agent = agent
+        self._plan_all_messages = result.all_messages()
+
+        plan: DQPlan = result.output
+
+        self._validate_plan_coverage(plan, prompts)
+        self._validate_group_sizes(plan)
+        return plan
+
+    def execute_plan(self, plan: DQPlan) -> dict[str, Any]:
+        """
+        Execute every SQL group in *plan* and return a flat ``{check_name: 
value}`` map.
+
+        Each group's query is safety-validated via
+        :func:`~airflow.providers.common.ai.utils.sql_validation.validate_sql` 
before
+        execution.  The first row of each result-set is used; each column 
corresponds
+        to the ``metric_key`` of one 
:class:`~airflow.providers.common.ai.utils.dq_models.DQCheck`.
+
+        :param plan: Plan produced by :meth:`generate_plan`.
+        :raises ValueError: If neither *db_hook* nor *datasource_config* was 
supplied.
+        :raises SQLSafetyError: If a generated query fails AST validation even 
after
+            ``max_sql_retries`` LLM correction attempts.
+        :raises ValueError: If a query result does not contain an expected 
metric column.
+        """
+        if self._db_hook is None and self._datasource_config is None:
+            raise ValueError("Either db_conn_id or datasource_config is 
required to execute the DQ plan.")
+
+        datafusion_engine: DataFusionEngine | None = None
+        if self._db_hook is None:
+            datafusion_engine = self._build_datafusion_engine()
+
+        results: dict[str, Any] = {}
+
+        for raw_group in plan.groups:
+            group = self._validate_or_fix_group(raw_group)
+            log.debug("Executing DQ group %r:\n%s", group.group_id, 
group.query)
+
+            # Row-level checks and aggregate checks are mutually exclusive 
within a group
+            # because the LLM places them in separate groups based on the 
system prompt.
+            if any(check.row_level for check in group.checks):
+                row_level_results = self._execute_row_level_group(group, 
datafusion_engine)
+                results.update(row_level_results)
+                continue

Review Comment:
   `execute_plan()` treats a group as row-level if **any** of its checks has
`row_level=True`, hands the whole group to `_execute_row_level_group()`, and
then `continue`s past the aggregate execution path. If the LLM ever returns a
mixed group (some `row_level`, some aggregate checks) or marks a check
`row_level=True` without a corresponding entry in `self._row_validators`, the
aggregate checks in that group will be silently dropped. The final `results`
map may then be missing prompt keys, leading to downstream `KeyError`s (e.g. the
`results_map[check.check_name]` lookup in the operator's `_validate_results`).
Consider validating that groups are homogeneous (all row-level or none), and
raising a clear `ValueError` when a row-level group has zero runnable row
validators rather than continuing.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to