Copilot commented on code in PR #62963:
URL: https://github.com/apache/airflow/pull/62963#discussion_r3114032278
##########
providers/common/sql/src/airflow/providers/common/sql/datafusion/engine.py:
##########
@@ -119,7 +120,25 @@ def execute_query(self, query: str, max_rows: int | None = None) -> dict[str, li
return result
return df.to_pydict()
except Exception as e:
- raise QueryExecutionException(f"Error while executing query: {e}")
+ raise QueryExecutionException(f"Error while executing query: {e}") from e
+
+ def iter_query_row_chunks(self, query: str) -> Iterator[dict[str, list[Any]]]:
+ """
+ Execute *query* and yield one column-dict per RecordBatch, streaming results.
+
+ :param query: SQL SELECT query to execute.
+ :raises QueryExecutionException: On SQL execution errors.
+ """
+ try:
+ df = self.session_context.sql(query)
+ if hasattr(df, "execute_stream"):
+ for batch in df.execute_stream():
+ yield batch.to_pyarrow().to_pydict()
+ else:
+ for batch in df.collect():
+ yield batch.to_pyarrow().to_pydict()
+ except Exception as e:
+ raise QueryExecutionException(f"Error while executing query: {e}") from e
Review Comment:
`iter_query_row_chunks()` is new public behavior on `DataFusionEngine` and
is now used for row-level DQ execution, but there are no unit tests covering
its streaming/collect fallback behavior or the yielded payload shape. Since
this module already has comprehensive unit tests, please add tests for
`iter_query_row_chunks()` (including the `execute_stream` path and the
`collect` fallback).
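A minimal sketch of the tests this comment asks for. So that it runs without Airflow or DataFusion installed, it exercises `iter_row_chunks_sketch`, a hypothetical standalone copy of the method's logic, together with a stand-in `QueryExecutionException`; in the provider's test module you would instead patch `session_context.sql` on a real `DataFusionEngine` and call `iter_query_row_chunks()`. One detail worth encoding: a bare `MagicMock` reports `hasattr(df, "execute_stream")` as `True`, so the fallback test restricts the mock with `spec`.

```python
from __future__ import annotations

from collections.abc import Iterator
from typing import Any
from unittest.mock import MagicMock


class QueryExecutionException(Exception):
    """Stand-in for the provider's exception type (assumption for this sketch)."""


def iter_row_chunks_sketch(session_context: Any, query: str) -> Iterator[dict[str, list[Any]]]:
    """Hypothetical copy of DataFusionEngine.iter_query_row_chunks from the diff."""
    try:
        df = session_context.sql(query)
        if hasattr(df, "execute_stream"):
            for batch in df.execute_stream():
                yield batch.to_pyarrow().to_pydict()
        else:
            for batch in df.collect():
                yield batch.to_pyarrow().to_pydict()
    except Exception as e:
        raise QueryExecutionException(f"Error while executing query: {e}") from e


def _fake_batch(payload: dict[str, list[Any]]) -> MagicMock:
    # batch.to_pyarrow().to_pydict() returns the given column-dict.
    batch = MagicMock()
    batch.to_pyarrow.return_value.to_pydict.return_value = payload
    return batch


def test_streams_via_execute_stream() -> None:
    df = MagicMock()  # a plain MagicMock exposes execute_stream, so the streaming path runs
    df.execute_stream.return_value = iter([_fake_batch({"id": [1, 2]}), _fake_batch({"id": [3]})])
    ctx = MagicMock()
    ctx.sql.return_value = df

    chunks = list(iter_row_chunks_sketch(ctx, "SELECT id FROM t"))

    assert chunks == [{"id": [1, 2]}, {"id": [3]}]  # one column-dict per batch
    df.collect.assert_not_called()


def test_falls_back_to_collect() -> None:
    df = MagicMock(spec=["collect"])  # spec hides execute_stream from hasattr()
    df.collect.return_value = [_fake_batch({"id": [7]})]
    ctx = MagicMock()
    ctx.sql.return_value = df

    assert list(iter_row_chunks_sketch(ctx, "SELECT id FROM t")) == [{"id": [7]}]


def test_wraps_errors() -> None:
    ctx = MagicMock()
    ctx.sql.side_effect = RuntimeError("boom")
    try:
        list(iter_row_chunks_sketch(ctx, "SELECT 1"))
    except QueryExecutionException as exc:
        assert "boom" in str(exc)
        assert isinstance(exc.__cause__, RuntimeError)  # original error kept via "from e"
    else:
        raise AssertionError("expected QueryExecutionException")
```

Run with `pytest`. Because every path here uses fakes, adding at least one case against a real DataFusion `SessionContext` would also confirm that the batch objects yielded on the `collect()` path expose `to_pyarrow()`, which the fakes simply assume.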
##########
providers/common/ai/docs/operators/llm_data_quality.rst:
##########
@@ -0,0 +1,488 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+.. _howto/operator:llm_data_quality:
+
+``LLMDataQualityOperator``
+==========================
+
+Use :class:`~airflow.providers.common.ai.operators.llm_data_quality.LLMDataQualityOperator`
+to generate and execute data-quality checks from natural language using an LLM.
+
+Each entry in ``checks`` describes **one** data-quality expectation as a
+:class:`~airflow.providers.common.ai.utils.dq_models.DQCheckInput` object.
+The LLM groups related checks into optimised SQL queries, executes them against the
+target database, and validates each metric. The task fails if any check does not
+pass, gating downstream tasks on data quality.
+
+.. seealso::
+ :ref:`Connection configuration <howto/connection:pydanticai>`
+
+Basic Usage
+-----------
+
+Provide a ``checks`` list and a target ``db_conn_id``. The operator introspects the
+schema automatically when ``table_names`` is provided:
+
+.. code-block:: python
+
+ from airflow.providers.common.ai.operators.llm_data_quality import LLMDataQualityOperator
+ from airflow.providers.common.ai.utils.dq_models import DQCheckInput
+ from airflow.providers.common.ai.utils.dq_validation import null_pct_check, row_count_check
+
+ LLMDataQualityOperator(
+ task_id="validate_orders",
+ llm_conn_id="pydanticai_default",
+ db_conn_id="postgres_default",
+ table_names=["orders", "customers"],
+ checks=[
+ DQCheckInput(
+ name="row_count",
+ description="The orders table must contain at least 1000 rows.",
+ validator=row_count_check(min_count=1000),
+ ),
+ DQCheckInput(
+ name="email_nulls",
+ description="No more than 5% of customer email addresses should be null.",
+ validator=null_pct_check(max_pct=0.05),
+ ),
+ ],
+ )
+
+Validators
+----------
+
+Each :class:`~airflow.providers.common.ai.utils.dq_models.DQCheckInput` can carry
+an optional ``validator`` — a callable that receives the raw metric value returned
+by the generated SQL and returns ``True`` (pass) or ``False`` (fail).
+
+Built-in Factories
+~~~~~~~~~~~~~~~~~~
+
+:mod:`~airflow.providers.common.ai.utils.dq_validation` ships ready-made factories
+for the most common thresholds:
+
+.. list-table::
+ :header-rows: 1
+ :widths: 30 50 20
+
+ * - Factory
+ - Passes when …
+ - Example
+ * - ``null_pct_check(max_pct=…)``
+ - metric ≤ ``max_pct``
+ - ``null_pct_check(max_pct=0.05)``
+ * - ``row_count_check(min_count=…)``
+ - metric ≥ ``min_count``
+ - ``row_count_check(min_count=1000)``
+ * - ``duplicate_pct_check(max_pct=…)``
+ - metric ≤ ``max_pct``
+ - ``duplicate_pct_check(max_pct=0.0)``
+ * - ``between_check(min_val=…, max_val=…)``
+ - ``min_val`` ≤ metric ≤ ``max_val``
+ - ``between_check(min_val=0.0, max_val=1.0)``
+ * - ``exact_check(expected=…)``
+ - metric == ``expected``
+ - ``exact_check(expected=0)``
+
+You can also use plain lambdas for one-off conditions::
+
+ DQCheckInput(
+ name="stale_rows",
+ description="Count rows older than 30 days",
+ validator=lambda v: int(v) < 1000,
+ )
+
+Aggregate checks without a validator are marked as **passed** — metrics are
+still collected and included in the report, but no threshold is enforced.
+Row-level checks are stricter: each row-level check must have a corresponding
+row-level validator, otherwise execution fails fast.
+
+Custom Validators with ``register_validator``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Use the :func:`~airflow.providers.common.ai.utils.dq_validation.register_validator`
+decorator to attach an ``llm_context`` hint to your validator factory. The
+operator injects the hint into the LLM system prompt so the model generates SQL
+that returns the metric format your validator expects:
+
+.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_llm_data_quality.py
+ :language: python
+ :start-after: [START howto_operator_llm_dq_custom_validator]
+ :end-before: [END howto_operator_llm_dq_custom_validator]
+
Review Comment:
The `exampleinclude` directives reference snippet markers (e.g.
`howto_operator_llm_dq_custom_validator`, `howto_operator_llm_dq_s3_parquet`,
`howto_operator_llm_dq_custom_row_level_validator`, etc.) that do not exist in
`example_llm_data_quality.py` (which defines different `[START ...]` tags).
This will break the docs build; update the `:start-after:`/`:end-before:` tags
(or the example DAG markers) so they match.
```suggestion
.. code-block:: python

    @register_validator(llm_context="SQL must return a single percentage between 0 and 1.")
    def max_pct(max_pct: float):
        return lambda value: float(value) <= max_pct


    checks = [
        DQCheckInput(
            name="null_pct_check",
            description="Percent of rows with null values",
            validator=max_pct(0.05),
        ),
    ]
```
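A check like the following could catch this class of breakage before the Sphinx build: it collects every `:start-after:`/`:end-before:` value from the `.rst` text and reports the ones missing from the included source. `missing_markers` is a hypothetical helper, the sample example-file text below is made up, and resolving the `exampleinclude` path to an actual file is left out.

```python
from __future__ import annotations

import re

# Matches option lines such as "    :start-after: [START some_marker]".
_MARKER_OPTION = re.compile(r"^\s*:(?:start-after|end-before):\s*(.+?)\s*$", re.MULTILINE)


def missing_markers(rst_text: str, example_source: str) -> list[str]:
    """Return marker strings referenced in the .rst text but absent from the example source."""
    wanted = _MARKER_OPTION.findall(rst_text)
    unique = list(dict.fromkeys(wanted))  # de-duplicate, keep first-seen order
    return [marker for marker in unique if marker not in example_source]


rst = """
.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_llm_data_quality.py
    :language: python
    :start-after: [START howto_operator_llm_dq_custom_validator]
    :end-before: [END howto_operator_llm_dq_custom_validator]
"""
# Made-up example file content that defines a different marker pair.
example = "# [START some_other_marker]\nprint('hi')\n# [END some_other_marker]\n"
print(missing_markers(rst, example))
```

Running it over each docs page and its included example DAG would fail fast on mismatched tags, which is the same failure the docs build would otherwise surface later.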
##########
providers/common/ai/src/airflow/providers/common/ai/utils/dq_planner.py:
##########
@@ -0,0 +1,1122 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+SQL-based data-quality plan generation and execution.
+
+:class:`SQLDQPlanner` is the single entry-point for all SQL DQ logic.
+It is deliberately kept separate from the operator so it can be unit-tested
+without an Airflow context and later swapped for GEX/SODA planners without
+touching the operator.
+"""
+
+from __future__ import annotations
+
+import hashlib
+import logging
+import re
+from collections import Counter
+from collections.abc import Iterator, Sequence
+from contextlib import closing
+from typing import TYPE_CHECKING, Any
+
+try:
+ from airflow.providers.common.ai.utils.sql_validation import (
+ DEFAULT_ALLOWED_TYPES,
+ SQLSafetyError,
+ validate_sql as _validate_sql,
+ )
+except ImportError as e:
+ from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException
+
+ raise AirflowOptionalProviderFeatureException(e)
+
+from airflow.providers.common.ai.utils.db_schema import build_schema_context, resolve_dialect
+from airflow.providers.common.ai.utils.dq_models import (
+ DQCheckGroup,
+ DQCheckInput,
+ DQPlan,
+ RowLevelResult,
+ UnexpectedResult,
+)
+from airflow.providers.common.ai.utils.dq_validation import DQValidationToolset
+from airflow.providers.common.ai.utils.logging import log_run_summary
+
+if TYPE_CHECKING:
+ from pydantic_ai import Agent
+ from pydantic_ai.messages import ModelMessage
+
+ from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook
+ from airflow.providers.common.sql.config import DataSourceConfig
+ from airflow.providers.common.sql.datafusion.engine import DataFusionEngine
+ from airflow.providers.common.sql.hooks.sql import DbApiHook
+
+log = logging.getLogger(__name__)
+
+_MAX_CHECKS_PER_GROUP = 5
+# Maximum rows fetched from DB per chunk during row-level processing — avoids loading the
+# entire result set into memory at once.
+_ROW_LEVEL_CHUNK_SIZE = 10_000
+# Hard cap on violation samples stored per check — independent of SQL LIMIT and chunk size.
+_MAX_VIOLATION_SAMPLES = 100
+
+_PLANNING_SYSTEM_PROMPT = """\
+You are a data-quality SQL expert.
+
+Given a set of named data-quality checks and a database schema, produce a \
+DQPlan that minimises the number of SQL queries while keeping each group \
+focused and manageable.
+
+GROUPING STRATEGY (multi-dimensional):
+ Group checks by **(target_table, check_category)**. Checks on the same table
+ that belong to different categories MUST be in separate groups.
+
+ Allowed check_category values (assign one per check based on its description):
+ - null_check — null / missing value counts or percentages
+ - uniqueness — duplicate detection, cardinality checks
+ - validity — regex / format / pattern matching on string columns
+ - numeric_range — range, bounds, or statistical checks on numeric columns
+ - row_count — total row counts or existence checks
+ - string_format — length, encoding, whitespace, or character-set checks
+ - row_level — per-row or anomaly checks that evaluate individual records
+
+ Row-level checks still follow the same grouping rule: group by
(target_table, check_category="row_level").
+ MAX {max_checks_per_group} CHECKS PER GROUP:
+ If a (table, category) pair has more than {max_checks_per_group} checks,
+ split them into sub-groups of at most {max_checks_per_group}.
+
+ GROUP-ID NAMING:
+ Use the pattern "{{table}}_{{category}}_{{part}}".
+ Examples: customers_null_check_1, orders_validity_1, orders_validity_2
+
+ RATIONALE:
+ Keeping string-column checks (validity, string_format) apart from
+ numeric-column checks (numeric_range, null_check on numbers) produces
+ simpler SQL and makes failures easier to diagnose.
+
+ CORRECT (two groups for same table, different categories):
+ Group customers_null_check_1:
+ SELECT
+ (COUNT(CASE WHEN email IS NULL THEN 1 END) * 100.0 / COUNT(*)) AS null_email_pct,
+ (COUNT(CASE WHEN name IS NULL THEN 1 END) * 100.0 / COUNT(*)) AS null_name_pct
+ FROM customers
+
+ Group customers_validity_1:
+ SELECT
+ COUNT(CASE WHEN phone NOT LIKE '+___-___-____' THEN 1 END) AS invalid_phone_fmt
+ FROM customers
+
+ WRONG (mixing null-check and regex-validity in one group):
+ SELECT
+ (COUNT(CASE WHEN email IS NULL THEN 1 END) * 100.0 / COUNT(*)) AS null_email_pct,
+ COUNT(CASE WHEN phone NOT LIKE '+___-___-____' THEN 1 END) AS invalid_phone_fmt
+ FROM customers
+
+OUTPUT RULES:
+ 1. Each output column must be aliased to exactly the metric_key of its check.
+ Example: ... AS null_email_pct
+ 2. Each check_name must exactly match the name in the provided checks list.
+ 3. metric_key values must be valid SQL column aliases (snake_case, no spaces).
+ 4. Generate only SELECT queries — no INSERT, UPDATE, DELETE, DROP, or DDL.
+ 5. Use {dialect} syntax.
+ 6. Each check must appear in exactly ONE group.
+ 7. Each check must have a check_category from the allowed list above.
+ 8. Return a valid DQPlan object. No extra commentary.
+"""
+
+_DATAFUSION_SYNTAX_SECTION = """\
+
+DATAFUSION SQL SYNTAX RULES:
+ The target engine is Apache DataFusion. Observe these syntax differences
+ from standard PostgreSQL / ANSI SQL:
+
+ 1. NO "FILTER (WHERE ...)" clause. Use CASE expressions instead:
+ WRONG: COUNT(*) FILTER (WHERE email IS NULL)
+ RIGHT: COUNT(CASE WHEN email IS NULL THEN 1 END)
+
+ 2. Regex matching uses the tilde operator:
+ column ~ 'pattern' (match)
+ column !~ 'pattern' (no match)
+ Do NOT use SIMILAR TO or POSIX-style ~* (case-insensitive).
+
+ 3. CAST syntax — prefer CAST(expr AS type) over :: shorthand.
+
+ 4. String functions: Use CHAR_LENGTH (not LEN), SUBSTR (not SUBSTRING with FROM/FOR).
+
+ 5. Integer division: DataFusion performs integer division for INT/INT.
+ Use CAST(expr AS DOUBLE) to force floating-point division.
+
+ 6. Boolean literals: Use TRUE / FALSE (not 1 / 0).
+
+ 7. LIMIT is supported. OFFSET is supported. FETCH FIRST is NOT supported.
+
+ 8. NULL handling: COALESCE, NULLIF, IFNULL are all supported.
+ NVL and ISNULL are NOT supported.
+"""
+
+_UNEXPECTED_QUERY_PROMPT_SECTION = """\
+
+UNEXPECTED VALUE COLLECTION:
+ For checks whose check_category is "validity" or "string_format", also
+ generate an unexpected_query field on the DQCheck. This query must:
+ - SELECT the primary key column(s) and the column(s) being validated
+ - WHERE the row violates the check condition (the negation of the check)
+ - LIMIT {sample_size}
+ - Use {dialect} syntax
+ - Be a standalone SELECT (not a subquery of the group query)
+
+ For all other categories (null_check, uniqueness, numeric_range, row_count),
+ set unexpected_query to null — these are aggregate checks where individual
+ violating rows are not meaningful.
+
+ Example for a phone-format validity check:
+ unexpected_query: "SELECT id, phone FROM customers WHERE phone !~ '^\\d{{4}}-\\d{{4}}-\\d{{4}}$' LIMIT 100"
+"""
+
+_ROW_LEVEL_PROMPT_SECTION = """
+
+ROW-LEVEL CHECKS:
+ Some checks are marked as row_level. For these:
+ - Generate a SELECT that returns the primary key column(s) and the column
+ being validated. Do NOT aggregate.
+ - Set row_level = true on the DQCheck entry.
+ - metric_key must be the name of the column containing the value to validate
+ (the Python validator will read row[metric_key] for each row).
+ - {row_level_limit_clause}
+ - Place ALL row-level checks for the same table in a single group.
+
+ Row-level check names that require this treatment: {row_level_check_names}
+"""
+
+_TABLE_NAME_CONSTRAINT_SECTION = """
+TABLE NAME CONSTRAINT:
+ The ONLY table(s) you may reference in FROM clauses are: {table_names}.
+ You MUST use these exact table names in every SQL query you generate.
+ Do NOT rename, abbreviate, alias the table, or invent new table names.
+ Using any table name not listed above will cause a runtime error.
+"""
+
+
+def _extract_table_names(schema_context: str) -> list[str]:
+ """Extract table names from a schema context string produced by :func:`build_schema_context`."""
+ return re.findall(r"^Table:\s+(\S+)", schema_context, re.MULTILINE)
+
+
+class SQLDQPlanner:
+ """
+ Generates and executes a SQL-based :class:`~airflow.providers.common.ai.utils.dq_models.DQPlan`.
+
+ :param llm_hook: Hook used to call the LLM for plan generation.
+ :param db_hook: Hook used to execute generated SQL against the database.
+ :param dialect: SQL dialect forwarded to the LLM prompt and ``validate_sql``.
+ Auto-detected from *db_hook* when ``None``.
+ :param max_sql_retries: Maximum number of times a failing SQL group query is sent
+ back to the LLM for correction before the error is re-raised. Default ``2``.
+ :param validator_contexts: Pre-built LLM context string from
+ :meth:`~airflow.providers.common.ai.utils.dq_validation.ValidatorRegistry.build_llm_context`.
+ Appended to the system prompt so the LLM knows what metric format each
+ custom validator expects.
+ :param row_validators: Mapping of ``{check_name: row_level_callable}`` for
+ checks that require row-by-row Python validation. When a check's name
+ appears here, ``execute_plan`` fetches all (or sampled) rows and applies
+ the callable to each value instead of reading a single aggregate scalar.
+ :param row_level_sample_size: Maximum number of rows to fetch for row-level
+ checks. ``None`` (default) performs a full scan. A positive integer
+ instructs the LLM to add ``LIMIT N`` to the generated SELECT.
+ :param toolset: :class:`~airflow.providers.common.ai.utils.dq_validation.DQValidationToolset`
+ that exposes the validator catalog to the LLM and validates its selections.
+ Defaults to a toolset backed by the :data:`~airflow.providers.common.ai.utils.dq_validation.default_registry`.
+ :param fixed_validators: Mapping of ``{check_name: callable}`` for checks where
+ the user preselected a validator. These checks are excluded from LLM
+ validator-suggestion and the provided callables are used directly.
+ :param max_validator_retries: Maximum number of times the LLM is asked to correct
+ invalid validator selections before the plan is rejected. Default ``3``.
+ """
+
+ def __init__(
+ self,
+ *,
+ llm_hook: PydanticAIHook,
+ db_hook: DbApiHook | None,
+ dialect: str | None = None,
+ max_sql_retries: int = 2,
+ datasource_config: DataSourceConfig | None = None,
+ system_prompt: str = "",
+ agent_params: dict[str, Any] | None = None,
+ collect_unexpected: bool = False,
+ unexpected_sample_size: int = 100,
+ validator_contexts: str = "",
+ row_validators: dict[str, Any] | None = None,
+ row_level_sample_size: int | None = None,
+ toolset: DQValidationToolset | None = None,
+ fixed_validators: dict[str, Any] | None = None,
+ max_validator_retries: int = 3,
+ ) -> None:
+ self._llm_hook = llm_hook
+ self._db_hook = db_hook
+ self._datasource_config = datasource_config
+ self._dialect = resolve_dialect(db_hook, dialect)
+ # Track whether the execution target is DataFusion so the prompt can
+ # include DataFusion-specific syntax rules. The dialect stays None
+ # (generic SQL) for sqlglot validation — sqlglot has no DataFusion dialect.
+ self._is_datafusion = db_hook is None and datasource_config is not None
+ # When targeting DataFusion, use PostgreSQL dialect for sqlglot validation
+ # because DataFusion shares regex operators (~, !~) that the generic SQL
+ # parser does not recognise.
+ self._validation_dialect: str | None = "postgres" if self._is_datafusion else self._dialect
+ self._max_sql_retries = max_sql_retries
+ self._extra_system_prompt = system_prompt
+ self._agent_params: dict[str, Any] = agent_params or {}
+ self._collect_unexpected = collect_unexpected
+ self._unexpected_sample_size = unexpected_sample_size
+ self._validator_contexts = validator_contexts
+ self._row_validators: dict[str, Any] = row_validators or {}
+ self._row_level_sample_size = row_level_sample_size
+ self._toolset: DQValidationToolset = toolset if toolset is not None else DQValidationToolset()
+ self._fixed_validators: dict[str, Any] = fixed_validators or {}
+ self._max_validator_retries = max_validator_retries
+ self._cached_datafusion_engine: DataFusionEngine | None = None
+ self._plan_agent: Agent[None, DQPlan] | None = None
+ self._plan_all_messages: list[ModelMessage] | None = None
+
+ def build_schema_context(
+ self,
+ table_names: list[str] | None,
+ schema_context: str | None,
+ ) -> str:
+ """
+ Return a schema description string for inclusion in the LLM prompt.
+
+ Delegates to :func:`~airflow.providers.common.ai.utils.db_schema.build_schema_context`.
+ """
+ return build_schema_context(
+ db_hook=self._db_hook,
+ table_names=table_names,
+ schema_context=schema_context,
+ datasource_config=self._datasource_config,
+ )
+
+ def build_catalog_hash(self) -> str:
+ """
+ Return a 16-char SHA-256 fingerprint of the current validator catalog.
+
+ Used by the operator to include a catalog change fingerprint in the plan
+ cache key, ensuring that adding or removing a validator invalidates stale
+ cached plans without requiring a new ``prompt_version``.
+ """
+ section = self._toolset.build_system_prompt_section()
+ return hashlib.sha256(section.encode()).hexdigest()[:16]
+
+ def generate_plan(self, checks: list[DQCheckInput], schema_context: str) -> DQPlan:
+ """
+ Ask the LLM to produce a :class:`~airflow.providers.common.ai.utils.dq_models.DQPlan`.
+
+ The LLM receives the check definitions, schema context, validator catalog,
+ and planning instructions as a structured-output call (``output_type=DQPlan``).
+ In a single response the model produces both the SQL plan and validator
+ selections for each check.
+
+ After generation the method:
+
+ 1. Verifies that the returned ``check_names`` exactly match the input names.
+ 2. Validates each LLM-suggested validator via the toolset.
+ 3. On validation failures, retries up to ``max_validator_retries`` times by
+ continuing the same conversation thread.
+ 4. Instantiates validated validator suggestions and merges them with any
+ user-fixed validators.
+
+ :param checks: A list of :class:`~airflow.providers.common.ai.utils.dq_models.DQCheckInput`
+ objects describing the data-quality expectations.
+ :param schema_context: Schema description previously built via
+ :meth:`build_schema_context`.
+ :raises ValueError: If the LLM's plan does not cover every check name exactly once.
+ :raises ValueError: If validator suggestions cannot be corrected within the retry budget.
+ """
+ check_descriptions = {c.name: c.description for c in checks}
+ dialect_label = self._dialect or ("DataFusion-compatible SQL" if self._is_datafusion else "SQL")
+ system_prompt = _PLANNING_SYSTEM_PROMPT.format(
+ dialect=dialect_label, max_checks_per_group=_MAX_CHECKS_PER_GROUP
+ )
+
+ if self._is_datafusion:
+ system_prompt += _DATAFUSION_SYNTAX_SECTION
+
+ if self._collect_unexpected:
+ system_prompt += _UNEXPECTED_QUERY_PROMPT_SECTION.format(
+ dialect=dialect_label, sample_size=self._unexpected_sample_size
+ )
+
+ if schema_context:
+ system_prompt += f"\nAvailable schema:\n{schema_context}\n"
+ table_names = _extract_table_names(schema_context)
+ if table_names:
+ system_prompt += _TABLE_NAME_CONSTRAINT_SECTION.format(table_names=", ".join(table_names))
+
+ # Inject validator catalog for LLM-driven selection.
+ system_prompt += self._toolset.build_system_prompt_section()
+
+ # Legacy per-validator context (still injected for fixed validators that carry it).
+ if self._validator_contexts:
+ system_prompt += self._validator_contexts
+
+ if self._row_validators:
+ row_level_check_names = ", ".join(sorted(self._row_validators))
+ if self._row_level_sample_size is not None:
+ limit_clause = f"Add LIMIT {self._row_level_sample_size} to the query."
+ else:
+ limit_clause = "Do NOT add a LIMIT — return all rows."
+ system_prompt += _ROW_LEVEL_PROMPT_SECTION.format(
+ row_level_check_names=row_level_check_names,
+ row_level_limit_clause=limit_clause,
+ )
+
+ if self._extra_system_prompt:
+ system_prompt += f"\nAdditional instructions:\n{self._extra_system_prompt}\n"
+
+ user_message = self._build_user_message(checks, fixed_validator_names=set(self._fixed_validators))
+ prompt_fingerprint = hashlib.sha256(f"{system_prompt}\n{user_message}".encode()).hexdigest()[:12]
+
+ log.info(
+ "Generating DQ plan with %d check(s) (system_prompt_chars=%d, user_message_chars=%d, "
+ "prompt_fingerprint=%s).",
+ len(checks),
+ len(system_prompt),
+ len(user_message),
+ prompt_fingerprint,
+ )
+ log.debug("Using system prompt:\n%s", system_prompt)
+ log.debug("Using user message:\n%s", user_message)
+
+ agent = self._llm_hook.create_agent(
+ output_type=DQPlan, instructions=system_prompt, **self._agent_params
+ )
+ result = agent.run_sync(user_message)
+ log_run_summary(log, result)
+
+ # Persist the agent and full conversation so execute_plan and validator
+ # correction can continue the same chat thread.
+ self._plan_agent = agent
+ self._plan_all_messages = result.all_messages()
+
+ plan: DQPlan = result.output
+
+ self._validate_plan_coverage(plan, check_descriptions)
+ self._validate_group_sizes(plan)
+
+ # Validate and potentially retry LLM validator suggestions.
+ plan = self._validate_and_fix_validator_suggestions(plan)
+
+ return plan
+
+ def _validate_and_fix_validator_suggestions(self, plan: DQPlan) -> DQPlan:
+ """
+ Validate LLM-suggested validator names and arguments for each check.
+
+ Checks with a user-fixed validator (in ``self._fixed_validators``) are
+ skipped. For the rest, each suggestion is tested via
+ :meth:`~airflow.providers.common.ai.utils.dq_validation.DQValidationToolset.validate_suggestion`.
+
+ On failure the LLM is sent a correction prompt in the same conversation
+ thread (up to ``max_validator_retries`` attempts). If the LLM still
+ returns invalid selections after all retries, an informative
+ :class:`ValueError` is raised.
+
+ :param plan: Plan returned by the initial LLM call.
+ :returns: The plan with validated (and possibly corrected) validator fields.
+ :raises ValueError: If validator suggestions cannot be corrected.
+ """
+ current_plan = plan
+ for attempt in range(1, self._max_validator_retries + 1):
+ errors = self._collect_validator_errors(current_plan)
+ if not errors:
+ break
+
+ # Log all failed checks as errors so they are visible in task logs.
+ for check_name, error_msg in errors.items():
+ log.error(
+ "Validator suggestion invalid for check %r (attempt %d/%d): %s",
+ check_name,
+ attempt,
+ self._max_validator_retries,
+ error_msg,
+ )
+
+ if self._plan_agent is None or self._plan_all_messages is None:
+ # No conversation thread to continue — fail immediately.
+ break
+
+ correction_prompt = self._build_validator_correction_prompt(errors)
+ log.info(
+ "Sending validator correction prompt (attempt %d/%d).",
+ attempt,
+ self._max_validator_retries,
+ )
+ result = self._plan_agent.run_sync(
+ correction_prompt,
+ message_history=self._plan_all_messages,
+ )
+ self._plan_all_messages = result.all_messages()
+ current_plan = result.output
+
+ # Final check after loop.
+ errors = self._collect_validator_errors(current_plan)
+ if errors:
+ error_lines = [f" - {name}: {msg}" for name, msg in sorted(errors.items())]
+ raise ValueError(
+ f"LLM validator suggestions could not be corrected after "
+ f"{self._max_validator_retries} attempt(s). "
+ f"Invalid suggestions:\n" + "\n".join(error_lines)
+ )
+
+ return current_plan
+
+ def _collect_validator_errors(self, plan: DQPlan) -> dict[str, str]:
+ """
+ Return a ``{check_name: error_message}`` dict for every invalid validator suggestion.
+
+ Checks with user-fixed validators are excluded from validation.
+ """
+ errors: dict[str, str] = {}
+ for group in plan.groups:
+ for check in group.checks:
+ if check.check_name in self._fixed_validators:
+ continue
+ # For non-fixed checks, missing validator_name is invalid and must
+ # trigger correction/failure.
+ suggested_name = check.validator_name or ""
+ ok, msg = self._toolset.validate_suggestion(
+ check.check_name,
+ suggested_name,
+ check.validator_args,
+ )
Review Comment:
`_collect_validator_errors()` flags missing/"none" validator_name as invalid
for every non-fixed check, which forces the LLM to always select a validator
even though the prompt text allows "none" and the operator treats aggregate
checks without validators as pass-through metric collection. This will either
cause repeated correction loops or hard failures for checks where "none" is
appropriate. Consider allowing aggregate checks to omit a validator (and only
requiring a validator for row-level checks).
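A sketch of the relaxation suggested above, assuming a check object with `check_name`, `validator_name`, `validator_args`, and `row_level` fields (the prompt sets `row_level = true` on row-level `DQCheck` entries) and a `validate_suggestion(check_name, name, args) -> (ok, msg)` callable shaped like the toolset call in the diff. `CheckStub` and `collect_validator_errors` are stand-ins, and the sketch iterates a flat list rather than `plan.groups`.

```python
from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any


@dataclass
class CheckStub:
    """Stand-in for the planner's DQCheck; only the fields this sketch reads."""

    check_name: str
    validator_name: str | None = None
    validator_args: dict[str, Any] = field(default_factory=dict)
    row_level: bool = False


def collect_validator_errors(
    checks: list[CheckStub],
    fixed_validators: set[str],
    validate_suggestion: Callable[[str, str, dict[str, Any]], tuple[bool, str]],
) -> dict[str, str]:
    """Hypothetical relaxed version: aggregate checks may omit a validator or use "none"."""
    errors: dict[str, str] = {}
    for check in checks:
        if check.check_name in fixed_validators:
            continue  # user-fixed validators are never second-guessed
        name = (check.validator_name or "").strip()
        if name.lower() in ("", "none"):
            if check.row_level:
                errors[check.check_name] = "row-level checks require a validator"
            # Aggregate check without a validator: metric is collected, no threshold enforced.
            continue
        ok, msg = validate_suggestion(check.check_name, name, check.validator_args)
        if not ok:
            errors[check.check_name] = msg
    return errors


def fake_validate(check_name: str, name: str, args: dict[str, Any]) -> tuple[bool, str]:
    # Pretend the catalog only knows one validator.
    return name == "null_pct_check", f"unknown validator {name!r}"


checks = [
    CheckStub("row_count"),  # aggregate, no validator: allowed
    CheckStub("email_nulls", "null_pct_check", {"max_pct": 0.05}),
    CheckStub("bad_rows", row_level=True),  # row-level, no validator: error
    CheckStub("typo", "nul_pct_check"),  # unknown name: error
]
print(collect_validator_errors(checks, set(), fake_validate))
```

This keeps the fail-fast rule that the docs state for row-level checks while letting aggregate checks without a validator act as pass-through metric collection, which is how the operator already treats them.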
##########
providers/common/ai/src/airflow/providers/common/ai/decorators/llm_data_quality.py:
##########
@@ -0,0 +1,164 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+TaskFlow decorator for LLM-driven data-quality checks.
+
+The user writes a function that **returns a list of
+:class:`~airflow.providers.common.ai.utils.dq_models.DQCheckInput`** objects.
+The decorator handles LLM plan generation, plan caching, SQL execution against
+the target database, and metric validation.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Callable, Collection, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, ClassVar
+
+from airflow.providers.common.ai.operators.llm_data_quality import LLMDataQualityOperator
+from airflow.providers.common.compat.sdk import (
+ DecoratedOperator,
+ TaskDecorator,
+ context_merge,
+ determine_kwargs,
+ task_decorator_factory,
+)
+from airflow.sdk.definitions._internal.types import SET_DURING_EXECUTION
+
+if TYPE_CHECKING:
+ from airflow.sdk import Context
+
+
+class _LLMDQDecoratedOperator(DecoratedOperator, LLMDataQualityOperator):
+ """
+ Wraps a callable that returns a list of :class:`DQCheckInput` for LLM data-quality checks.
+
+ The user function is called at execution time to produce the checks list.
+ All other parameters (``llm_conn_id``, ``db_conn_id``, ``table_names``,
+ etc.) are passed through to
+ :class:`~airflow.providers.common.ai.operators.llm_data_quality.LLMDataQualityOperator`.
+
+ :param python_callable: A callable that returns a
+ ``list[DQCheckInput]``.
+ :param op_args: Positional arguments for the callable.
+ :param op_kwargs: Keyword arguments for the callable.
+ """
+
+ template_fields: Sequence[str] = (
+ *DecoratedOperator.template_fields,
+ *LLMDataQualityOperator.template_fields,
+ )
+ template_fields_renderers: ClassVar[dict[str, str]] = {
+ **DecoratedOperator.template_fields_renderers,
+ }
+
+ custom_operator_name: str = "@task.llm_dq"
+
+ def __init__(
+ self,
+ *,
+ python_callable: Callable,
+ op_args: Collection[Any] | None = None,
+ op_kwargs: Mapping[str, Any] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ python_callable=python_callable,
+ op_args=op_args,
+ op_kwargs=op_kwargs,
+ checks=SET_DURING_EXECUTION,
+ **kwargs,
+ )
+
+ def execute(self, context: Context) -> Any:
+ context_merge(context, self.op_kwargs)
+ kwargs = determine_kwargs(self.python_callable, self.op_args, context)
+
+ checks = self.python_callable(*self.op_args, **kwargs)
+
+ if not isinstance(checks, list) or not checks:
+ raise TypeError(
+ "The returned value from the @task.llm_dq callable must be a non-empty list[DQCheckInput]."
+ )
+
+ self.checks = checks
+ self.render_template_fields(context)
+ return LLMDataQualityOperator.execute(self, context)
Review Comment:
`_LLMDQDecoratedOperator.execute()` assigns `self.checks` directly from the
user callable without coercing dicts to `DQCheckInput` (unlike
`LLMDataQualityOperator.__init__`) and without running the
duplicate-name/empty-field validation that the operator normally does. This can
lead to runtime attribute errors (e.g. when computing the plan hash) or allow
duplicate check names through. Consider coercing each item via
`DQCheckInput.coerce(...)` and calling the operator’s `_validate_checks()`
after setting `self.checks`.
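A minimal sketch of the fix this comment suggests. `DQCheckInputStub`, `validate_checks` and `prepare_checks` are hypothetical stand-ins written for illustration. They only approximate what `DQCheckInput.coerce(...)` and the operator's `_validate_checks()` are described as doing above; the real implementations in the provider may differ.

```python
from __future__ import annotations

from collections import Counter
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class DQCheckInputStub:
    """Hypothetical stand-in for DQCheckInput (name, description, optional validator)."""

    name: str
    description: str
    validator: Any = None

    @classmethod
    def coerce(cls, item: DQCheckInputStub | dict[str, Any]) -> DQCheckInputStub:
        # Accept either an existing instance or a plain dict returned by the user callable.
        if isinstance(item, cls):
            return item
        if isinstance(item, dict):
            return cls(**item)
        raise TypeError(f"Cannot coerce {type(item).__name__} to a DQ check input")


def validate_checks(checks: list[DQCheckInputStub]) -> None:
    """Hypothetical mirror of the operator's empty-field / duplicate-name validation."""
    for check in checks:
        if not check.name or not check.description:
            raise ValueError(f"Check {check!r} must have a non-empty name and description")
    counts = Counter(check.name for check in checks)
    duplicates = sorted(name for name, n in counts.items() if n > 1)
    if duplicates:
        raise ValueError(f"Duplicate check names: {duplicates}")


def prepare_checks(returned: Any) -> list[DQCheckInputStub]:
    """What the decorated operator's execute() could do with the callable's return value."""
    if not isinstance(returned, list) or not returned:
        raise TypeError("The @task.llm_dq callable must return a non-empty list[DQCheckInput].")
    # Coerce first so dicts become objects with attributes, then validate the full list.
    checks = [DQCheckInputStub.coerce(item) for item in returned]
    validate_checks(checks)
    return checks
```

Inside `_LLMDQDecoratedOperator.execute()`, the equivalent change would presumably be to assign `self.checks = [DQCheckInput.coerce(c) for c in checks]` and call `self._validate_checks()` before `render_template_fields()`. That way, dict-returning callables and duplicate names behave the same way they do with the plain operator.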
##########
providers/common/ai/src/airflow/providers/common/ai/utils/db_schema.py:
##########
@@ -0,0 +1,205 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Shared database hook and schema introspection utilities.
+
+These helpers are used by both :class:`~airflow.providers.common.ai.operators.llm_sql.LLMSQLQueryOperator`
+and :class:`~airflow.providers.common.ai.operators.llm_data_quality.LLMDataQualityOperator` to
+avoid code duplication while keeping both operators decoupled from each other.
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING
+
+from airflow.providers.common.compat.sdk import BaseHook
+
+if TYPE_CHECKING:
+ from airflow.providers.common.sql.config import DataSourceConfig
+ from airflow.providers.common.sql.hooks.sql import DbApiHook
+
+log = logging.getLogger(__name__)
+
+# SQLAlchemy dialect_name → sqlglot dialect mapping for names that differ.
+SQLALCHEMY_TO_SQLGLOT_DIALECT: dict[str, str] = {
+ "postgresql": "postgres",
+ "mssql": "tsql",
+}
+
+
+def get_db_hook(db_conn_id: str) -> DbApiHook:
+ """
+ Return a :class:`~airflow.providers.common.sql.hooks.sql.DbApiHook` for *db_conn_id*.
+
+ :param db_conn_id: Airflow connection ID that resolves to a ``DbApiHook``.
+ :raises ValueError: If the connection does not resolve to a ``DbApiHook``.
+ """
+ # Lazy load to avoid hard dependency on common.sql
+ from airflow.providers.common.sql.hooks.sql import DbApiHook
+
+ connection = BaseHook.get_connection(db_conn_id)
+ hook = connection.get_hook()
+ if not isinstance(hook, DbApiHook):
+ raise ValueError(
+ f"Connection {db_conn_id!r} does not provide a DbApiHook. Got {type(hook).__name__}."
+ )
+ return hook
+
+
+def resolve_dialect(db_hook: DbApiHook | None, explicit_dialect: str | None) -> str | None:
+ """
+ Resolve the SQL dialect from an explicit parameter or a database hook.
+
+ Normalises SQLAlchemy dialect names to sqlglot equivalents
+ (e.g. ``postgresql`` → ``postgres``).
+
+ :param db_hook: Database hook to read ``dialect_name`` from when *explicit_dialect* is absent.
+ :param explicit_dialect: Caller-supplied dialect string; takes priority over the hook.
+ :return: Resolved dialect string, or ``None`` when neither source provides one.
+ """
+ raw = explicit_dialect
+ if not raw and db_hook and hasattr(db_hook, "dialect_name"):
+ candidate = db_hook.dialect_name
+ raw = candidate if isinstance(candidate, str) else None
+ if raw:
+ return SQLALCHEMY_TO_SQLGLOT_DIALECT.get(raw, raw)
+ return None
+
+
+def build_schema_context(
+ *,
+ db_hook: DbApiHook | None,
+ table_names: list[str] | None,
+ schema_context: str | None,
+ datasource_config: DataSourceConfig | None,
+) -> str:
+ """
+ Return a schema description string suitable for inclusion in an LLM prompt.
+
+ Resolution order:
+ 1. *schema_context* — returned as-is when provided (manual override).
+ 2. DB introspection via *db_hook* + *table_names*.
+ 3. Object-storage introspection via *datasource_config*.
+ 4. Empty string when none of the above are available.
+
+ :param db_hook: Hook used for relational-database schema introspection.
+ :param table_names: Table names to introspect via *db_hook*.
+ :param schema_context: Manual schema description; bypasses introspection when set.
+ :param datasource_config: DataFusion datasource config for object-storage schema.
+ :raises ValueError: If *table_names* are provided but none yield schema information.
+ """
+ if schema_context:
+ return schema_context
+
+ if (db_hook and table_names) or datasource_config:
+ return _introspect_schemas(
+ db_hook=db_hook,
+ table_names=table_names,
+ datasource_config=datasource_config,
+ )
+
+ return ""
Review Comment:
`build_schema_context()` returns an empty string when `table_names` is
provided but `db_hook` and `datasource_config` are both missing, so a
misconfiguration silently disables schema introspection. Since
`_introspect_schemas()` explicitly raises for `table_names` without a DB hook,
consider raising a `ValueError` in `build_schema_context()` whenever
`table_names` is set but `db_hook` is None (unless `schema_context` is
provided).
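A minimal sketch of the fail-fast behaviour proposed above, assuming the existing resolution order is otherwise kept. The `introspect` parameter is a hypothetical stand-in for `_introspect_schemas()` so the function can run without a database, and the hook and config types are loosened to `Any` for the same reason.

```python
from __future__ import annotations

from collections.abc import Callable
from typing import Any


def build_schema_context(
    *,
    db_hook: Any | None,
    table_names: list[str] | None,
    schema_context: str | None,
    datasource_config: Any | None,
    # Hypothetical injection point standing in for _introspect_schemas().
    introspect: Callable[..., str] = lambda **_: "",
) -> str:
    # 1. A manual override always wins, even when table_names is also set.
    if schema_context:
        return schema_context

    # 2. Fail fast on the misconfiguration from the review: tables were
    #    requested but there is no DB hook to introspect them with.
    if table_names and db_hook is None:
        raise ValueError(
            "table_names was provided but no db_hook is available; "
            "set db_conn_id or pass schema_context explicitly."
        )

    # 3. Otherwise introspect via the DB hook and/or the object-storage config.
    if (db_hook and table_names) or datasource_config:
        return introspect(
            db_hook=db_hook,
            table_names=table_names,
            datasource_config=datasource_config,
        )

    # 4. Nothing to describe.
    return ""
```

Because `schema_context` is checked first, users who describe the schema by hand can still pass `table_names` without configuring a connection.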
##########
providers/common/ai/src/airflow/providers/common/ai/utils/dq_validation.py:
##########
@@ -0,0 +1,632 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Built-in and custom validator factories for :class:`~airflow.providers.common.ai.operators.llm_data_quality.LLMDataQualityOperator`.
+
+Each factory returns a ``Callable[[Any], bool]`` and can be passed as the
+``validator`` argument of a
:class:`~airflow.providers.common.ai.utils.dq_models.DQCheckInput`
+to force a specific validator for that check, bypassing LLM selection.
+Factories are intentionally decoupled from the operator so they can be tested and
+composed independently.
+
+Custom validators registered with :func:`register_validator` are exposed to the LLM
+via the validator catalog so the model can select them automatically.
+
+Usage::
+
+ from airflow.providers.common.ai.utils.dq_models import DQCheckInput
+ from airflow.providers.common.ai.utils.dq_validation import (
+ null_pct_check,
+ row_count_check,
+ register_validator,
+ )
+
+ # Fixed validators — LLM is not asked to select a validator for these checks.
+ checks = [
+ DQCheckInput(
+ name="email_nulls",
+ description="Check for null emails",
+ validator=null_pct_check(max_pct=0.05),
+ ),
+ DQCheckInput(
+ name="min_customers",
+ description="Ensure at least 1000 rows",
+ validator=row_count_check(min_count=1000),
+ ),
+ ]
+
+
+ # Custom validator with LLM context — LLM can select this automatically.
+ @register_validator(
+ "freshness_check",
+ llm_context=(
+ "Compute hours since the most recent row. "
+ "SQL pattern: EXTRACT(EPOCH FROM (NOW() - MAX(ts_col))) / 3600.0. "
+ "Returns a DOUBLE representing hours elapsed."
+ ),
+ check_category="freshness",
+ )
+ def freshness_check(*, max_hours: float):
+ def _check(value):
+ return float(value) <= max_hours
+
+ return _check
+"""
+
+from __future__ import annotations
+
+import functools
+import inspect
+import logging
+from collections.abc import Callable
+from dataclasses import dataclass
+from typing import Any
+
+log = logging.getLogger(__name__)
+
+
+@dataclass(frozen=True)
+class ValidatorEntry:
+ """
+ Metadata for a registered validator factory.
+
+ :param factory: Callable that returns a ``Callable[[Any], bool]`` validator.
+ :param llm_context: Optional hint injected into the LLM system prompt so
+ the model knows what SQL metric format this validator expects.
+ :param check_category: Optional custom check category. When set, the LLM
+ is instructed to use this category for grouping.
+ :param row_level: When ``True`` the LLM is instructed to generate a plain
+ ``SELECT pk, col FROM table`` (no aggregation). The planner fetches
+ every row and applies the validator callable to each column value,
+ then reports ``{total, invalid, invalid_pct, sample_violations}``.
+ """
+
+ factory: Callable[..., Callable[[Any], bool]]
+ llm_context: str = ""
+ check_category: str = ""
+ row_level: bool = False
+
+
+class ValidatorRegistry:
+ """
+ Registry for reusable validator factories with optional LLM context.
+
+ Validators registered here can carry an ``llm_context`` string that the
+ operator automatically injects into the LLM system prompt, guiding the
+ model to produce SQL that returns the metric format the validator expects.
+
+ A module-level :data:`default_registry` instance is available. Use the
+ convenience decorator :func:`register_validator` to register into it.
+ """
+
+ def __init__(self) -> None:
+ self._entries: dict[str, ValidatorEntry] = {}
+
+ def register(
+ self,
+ name: str,
+ *,
+ llm_context: str = "",
+ check_category: str = "",
+ row_level: bool = False,
+ ) -> Callable[[Callable[..., Callable[[Any], bool]]], Callable[..., Callable[[Any], bool]]]:
+ """
+ Return a decorator that registers a validator factory under *name*.
+
+ :param name: Unique name for this validator.
+ :param llm_context: SQL generation hint injected into the LLM prompt.
+ :param check_category: Custom check category for LLM grouping.
+ :param row_level: When ``True``, the LLM generates a plain SELECT
+ returning raw row values instead of an aggregate query. The
+ planner applies the validator to each row and aggregates results.
+ :raises ValueError: If *name* is already registered.
+ """
+ if name in self._entries:
+ raise ValueError(
+ f"Validator {name!r} is already registered. "
+ "Use a different name or unregister the existing one first."
+ )
+
+ def _decorator(
+ factory: Callable[..., Callable[[Any], bool]],
+ ) -> Callable[..., Callable[[Any], bool]]:
+ # Wrap the factory so every closure it returns carries introspection
+ # attributes used by the operator and planner.
+ @functools.wraps(factory)
+ def _wrapped_factory(*args: Any, **kwargs: Any) -> Callable[[Any], bool]:
+ closure = factory(*args, **kwargs)
+ arg_parts = [repr(a) for a in args]
+ kwarg_parts = [f"{k}={v!r}" for k, v in sorted(kwargs.items())]
+ call_str = f"{name}({', '.join(arg_parts + kwarg_parts)})"
+ if not hasattr(closure, "_validator_name"):
+ closure._validator_name = name # type: ignore[attr-defined]
+ if not hasattr(closure, "_row_level"):
+ closure._row_level = row_level # type: ignore[attr-defined]
+ if not hasattr(closure, "_validator_display"):
+ closure._validator_display = call_str # type: ignore[attr-defined]
+ for k, v in sorted(kwargs.items()):
+ if not hasattr(closure, f"_{k}"):
+ setattr(closure, f"_{k}", v) # e.g. _max_pct, _min_count
+ return closure
+
+ _wrapped_factory._validator_name = name # type: ignore[attr-defined]
+ _wrapped_factory._llm_context = llm_context # type: ignore[attr-defined]
+ _wrapped_factory._check_category = check_category # type: ignore[attr-defined]
+ _wrapped_factory._row_level = row_level # type: ignore[attr-defined]
+ _wrapped_factory._validator_display = name # type: ignore[attr-defined]
+ _wrapped_factory.__name__ = factory.__name__
+ _wrapped_factory.__qualname__ = factory.__qualname__
+ _wrapped_factory.__doc__ = factory.__doc__
+
+ self._entries[name] = ValidatorEntry(
+ factory=_wrapped_factory,
+ llm_context=llm_context,
+ check_category=check_category,
+ row_level=row_level,
+ )
+ return _wrapped_factory
+
+ return _decorator
+
+ def get(self, name: str) -> ValidatorEntry:
+ """
+ Return the :class:`ValidatorEntry` for *name*.
+
+ :raises KeyError: If *name* is not registered.
+ """
+ try:
+ return self._entries[name]
+ except KeyError:
+ raise KeyError(
+ f"Validator {name!r} is not registered. Available validators: {sorted(self._entries)}"
+ ) from None
+
+ def list_validators(self) -> list[str]:
+ """Return sorted list of all registered validator names."""
+ return sorted(self._entries)
+
+ def is_row_level(self, validator: Callable[[Any], bool]) -> bool:
+ """
+ Return ``True`` when *validator* was produced by a row-level factory.
+
+ Checks the ``_row_level`` attribute set by the factory closure and,
+ as a fallback, the registry entry for the factory name.
+ """
+ if hasattr(validator, "_row_level"):
+ return bool(validator._row_level)
+ factory_name: str | None = getattr(validator, "_validator_name", None)
+ if factory_name and factory_name in self._entries:
+ return self._entries[factory_name].row_level
+ return False
+
+ def build_llm_context(self, validators: dict[str, Callable[[Any], bool]]) -> str:
+ """
+ Collect LLM context strings from all validators that carry one.
+
+ Aggregate and row-level validators are emitted in separate sections so
+ the LLM knows which checks require raw-row SELECTs vs. aggregate queries.
+
+ Checks three sources in order for each validator callable:
+
+ 1. Registry entry (if the callable's factory was registered).
+ 2. ``_llm_context`` attribute on the callable itself.
+ 3. ``llm_context`` attribute on the callable itself.
+
+ :param validators: The ``{check_name: callable}`` dict from the operator.
+ :returns: Combined context string ready for injection into the system prompt,
+ or empty string if no validator carries context.
+ """
+ aggregate_lines: list[str] = []
+ row_level_lines: list[str] = []
+
+ for check_name, validator in validators.items():
+ context = self._resolve_llm_context(validator)
+ if not context:
+ continue
+ if self.is_row_level(validator):
+ row_level_lines.append(f" - {check_name}: {context}")
+ else:
+ aggregate_lines.append(f" - {check_name}: {context}")
+
+ if not aggregate_lines and not row_level_lines:
+ return ""
+
+ parts: list[str] = []
+ if aggregate_lines:
+ parts.append(
+ "\nCUSTOM VALIDATOR CONTEXT:\n"
+ " The following checks have specific metric requirements.\n"
+ " Generate SQL that returns values matching these descriptions:\n"
+ + "\n".join(aggregate_lines)
+ )
+ if row_level_lines:
+ parts.append(
+ "\nROW-LEVEL CHECKS:\n"
+ " The following checks require ROW-LEVEL validation.\n"
+ " For each, generate a SELECT that returns the primary key column(s)\n"
+ " and the column(s) to validate — do NOT aggregate.\n"
+ " Set check.row_level = true on these DQCheck entries.\n"
+ " The Python-side validator will inspect each returned value:\n" + "\n".join(row_level_lines)
+ )
+ return "\n".join(parts) + "\n"
+
+ def _resolve_llm_context(self, validator: Callable[[Any], bool]) -> str:
+ """Resolve LLM context from registry entries or callable attributes."""
+ # Check registry by factory name attribute.
+ factory_name: str | None = getattr(validator, "_validator_name", None)
+ if factory_name and factory_name in self._entries:
+ entry = self._entries[factory_name]
+ if entry.llm_context:
+ return entry.llm_context
+
+ # Fallback: attribute on the callable itself.
+ for attr in ("_llm_context", "llm_context"):
+ context = getattr(validator, attr, None)
+ if context and isinstance(context, str):
+ return context
+
+ return ""
+
+ def unregister(self, name: str) -> None:
+ """
+ Remove a validator from the registry.
+
+ :raises KeyError: If *name* is not registered.
+ """
+ try:
+ del self._entries[name]
+ except KeyError:
+ raise KeyError(f"Validator {name!r} is not registered.") from None
+
+
+default_registry = ValidatorRegistry()
+
+
+def register_validator(
+ name: str,
+ *,
+ llm_context: str = "",
+ check_category: str = "",
+ row_level: bool = False,
+) -> Callable[[Callable[..., Callable[[Any], bool]]], Callable[..., Callable[[Any], bool]]]:
+ """
+ Register a validator factory in the :data:`default_registry`.
+
+ Use as a decorator on a factory function::
+
+ @register_validator(
+ "freshness_check",
+ llm_context="Compute hours since MAX(updated_at). Returns DOUBLE.",
+ check_category="freshness",
+ )
+ def freshness_check(*, max_hours: float):
+ def _check(value):
+ return float(value) <= max_hours
+
+ return _check
+
+ For row-level validation (e.g. TCKN formula)::
+
+ @register_validator(
+ "tckn_check",
+ llm_context="ROW-LEVEL: SELECT pk, tckn_col FROM table. No aggregation.",
+ check_category="row_level",
+ row_level=True,
+ )
+ def tckn_check(*, max_invalid_pct: float = 0.0):
+ def _check_row(value): ...
+
+ return _check_row
+
+ :param name: Unique name for this validator.
+ :param llm_context: SQL generation hint injected into the LLM prompt.
+ :param check_category: Custom check category for LLM grouping.
+ :param row_level: When ``True``, the LLM generates a plain SELECT returning
+ raw column values. The planner validates each row with the callable.
+ """
+ return default_registry.register(
+ name, llm_context=llm_context, check_category=check_category, row_level=row_level
+ )
+
+
+@register_validator(
+ "null_pct_check",
+ llm_context="Returns null percentage as a float between 0.0 and 1.0. SQL pattern: COUNT(CASE WHEN col IS NULL THEN 1 END) * 1.0 / COUNT(*).",
+ check_category="null_check",
+)
+def null_pct_check(*, max_pct: float) -> Callable[[Any], bool]:
+ """
+ Return a validator that passes when ``value <= max_pct``.
+
+ :param max_pct: Maximum allowed null percentage (0.0 – 1.0).
+ :raises TypeError: When the metric value cannot be converted to ``float``.
+ """
+
+ def _check(value: Any) -> bool:
+ try:
+ return float(value) <= max_pct
+ except (TypeError, ValueError) as exc:
+ raise TypeError(
+ f"null_pct_check(max_pct={max_pct!r}): expected a numeric value, got {value!r}"
+ ) from exc
+
+ return _check
+
+
+@register_validator(
+ "row_count_check",
+ llm_context="Returns an integer row count. SQL pattern: COUNT(*).",
+ check_category="row_count",
+)
+def row_count_check(*, min_count: int) -> Callable[[Any], bool]:
+ """
+ Return a validator that passes when ``value >= min_count``.
+
+ :param min_count: Minimum required row count.
+ :raises TypeError: When the metric value cannot be converted to ``int``.
+ """
+
+ def _check(value: Any) -> bool:
+ try:
+ return int(value) >= min_count
+ except (TypeError, ValueError) as exc:
+ raise TypeError(
+ f"row_count_check(min_count={min_count!r}): expected an integer value, got {value!r}"
+ ) from exc
+
+ return _check
+
+
+@register_validator(
+ "duplicate_pct_check",
+ llm_context="Returns duplicate percentage as a float between 0.0 and 1.0. SQL pattern: (COUNT(*) - COUNT(DISTINCT col)) * 1.0 / COUNT(*).",
+ check_category="uniqueness",
+)
+def duplicate_pct_check(*, max_pct: float) -> Callable[[Any], bool]:
+ """
+ Return a validator that passes when ``value <= max_pct``.
+
+ :param max_pct: Maximum allowed duplicate percentage (0.0 – 1.0).
+ :raises TypeError: When the metric value cannot be converted to ``float``.
+ """
+
+ def _check(value: Any) -> bool:
+ try:
+ return float(value) <= max_pct
+ except (TypeError, ValueError) as exc:
+ raise TypeError(
+ f"duplicate_pct_check(max_pct={max_pct!r}): expected a numeric value, got {value!r}"
+ ) from exc
+
+ return _check
+
+
+@register_validator(
+ "between_check",
+ llm_context="Returns a numeric value that will be compared against inclusive bounds. SQL should return a single DOUBLE or INTEGER metric.",
+ check_category="numeric_range",
+)
+def between_check(*, min_val: float, max_val: float) -> Callable[[Any], bool]:
+ """
+ Return a validator that passes when ``min_val <= value <= max_val``.
+
+ :param min_val: Inclusive lower bound.
+ :param max_val: Inclusive upper bound.
+ :raises ValueError: When *min_val* > *max_val*.
+ :raises TypeError: When the metric value cannot be converted to ``float``.
+ """
+ if min_val > max_val:
+ raise ValueError(f"between_check: min_val ({min_val!r}) must be <= max_val ({max_val!r})")
+
+ def _check(value: Any) -> bool:
+ try:
+ return min_val <= float(value) <= max_val
+ except (TypeError, ValueError) as exc:
+ raise TypeError(
+ f"between_check(min_val={min_val!r}, max_val={max_val!r}): "
+ f"expected a numeric value, got {value!r}"
+ ) from exc
+
+ return _check
+
+
+@register_validator(
+ "exact_check",
+ llm_context="Returns a value that must exactly equal an expected constant. SQL should return a single scalar metric.",
+ check_category="validity",
+)
+def exact_check(*, expected: Any) -> Callable[[Any], bool]:
+ """
+ Return a validator that passes when ``value == expected``.
+
+ .. note::
+ Comparison uses Python's ``==`` operator without type coercion.
+ ``Decimal(0) == 0`` passes (Python numeric promotion), but
+ ``"0" == 0`` does not. The behaviour depends on the DB driver's
+ Python type for the returned column.
+
+ :param expected: The exact value the metric must equal.
+ """
+
+ def _check(value: Any) -> bool:
+ return value == expected
+
+ return _check
+
+
+class DQValidationToolset:
+ """
+ Validator catalog manager that exposes registered validators to the LLM.
+
+ Serves two purposes:
+
+ 1. **Prompt generation** — :meth:`build_system_prompt_section` produces a
+ text block that is injected into the LLM system prompt so the model knows
+ which validators are available, their parameters, and what SQL metric each
+ one expects.
+
+ 2. **Suggestion validation** — :meth:`validate_suggestion` verifies that the
+ name and arguments proposed by the LLM can actually instantiate a callable
+ without error, before the plan is accepted.
+
+ :param registry: Validator registry to expose. Defaults to
+ :data:`default_registry`.
+ """
+
+ def __init__(self, registry: ValidatorRegistry | None = None) -> None:
+ self._registry: ValidatorRegistry = registry if registry is not None else default_registry
+
+ def build_system_prompt_section(self) -> str:
+ """
+ Return a system-prompt block listing all registered validators.
+
+ The block is appended to the SQL-planning prompt and tells the LLM
+ which validators exist, their parameter signatures, and what SQL
+ metric each one expects. The LLM fills ``validator_name`` and
+ ``validator_args`` on each :class:`~airflow.providers.common.ai.utils.dq_models.DQCheck`
+ based on this catalogue.
+ """
+ lines: list[str] = [
+ "",
+ "AVAILABLE VALIDATORS:",
+ " For each DQCheck, select the most appropriate validator from the list below.",
+ " Fill the validator_name field with the exact name and validator_args with the",
+ " required keyword arguments as a JSON object.",
+ ' If no validator is appropriate for a check, set validator_name to "none".',
+ " If the check has a [FIXED VALIDATOR] annotation, leave validator_name as null.",
+ "",
+ ]
+
+ aggregate_entries: list[str] = []
+ row_level_entries: list[str] = []
+
+ for name in self._registry.list_validators():
+ entry = self._registry.get(name)
+ sig_str = self._format_signature(entry.factory)
+ category = entry.check_category or "—"
+ block = (
+ f' - name: "{name}"\n'
+ f" category: {category}\n"
+ f" row_level: {str(entry.row_level).lower()}\n"
+ f" parameters: {sig_str}\n"
+ f" description: {entry.llm_context or '(no description)'}"
+ )
+ if entry.row_level:
+ row_level_entries.append(block)
+ else:
+ aggregate_entries.append(block)
+
+ if aggregate_entries:
+ lines.append(" Aggregate validators (SQL returns a single scalar metric):")
+ lines.extend(aggregate_entries)
+ lines.append("")
+
+ if row_level_entries:
+ lines.append(" Row-level validators (SQL returns raw column values per row, no aggregation):")
+ lines.extend(row_level_entries)
+ lines.append("")
+
+ lines += [
+ " IMPORTANT: validator_args must contain ONLY the keyword argument names shown",
+ " above — no extra keys, no positional args. The argument values must match the",
+ " expected Python types (float, int, etc.).",
+ "",
+ ]
+ return "\n".join(lines)
+
+ def validate_suggestion(
+ self,
+ check_name: str,
+ validator_name: str,
+ validator_args: dict[str, Any],
+ ) -> tuple[bool, str]:
+ """
+ Verify that *validator_name* exists and can be instantiated with *validator_args*.
+
+ :param check_name: The check this suggestion belongs to (used in error messages).
+ :param validator_name: Name of the validator factory as registered.
+ :param validator_args: Keyword arguments to pass to the factory.
+ :returns: ``(True, "")`` on success; ``(False, error_message)`` on failure.
+ """
+ if not validator_name or validator_name.lower() == "none":
+ return False, (
+ f"Check {check_name!r}: LLM did not suggest a validator (validator_name is null or 'none')."
+ )
Review Comment:
`validate_suggestion()` currently treats `validator_name` being empty or
"none" as an error. This contradicts the system-prompt text in
`build_system_prompt_section()` (which instructs the LLM to set "none" when no
validator fits) and the operator/docs behavior that aggregate checks may have
no validator. Consider treating "none"/empty as a valid "no validator" case
(and only enforcing non-empty validator selection for row-level checks).
```suggestion
Empty values and ``"none"`` are treated as a valid "no validator selected"
case and do not produce an error.

:param check_name: The check this suggestion belongs to (used in error messages).
:param validator_name: Name of the validator factory as registered, or ``"none"``.
:param validator_args: Keyword arguments to pass to the factory.
:returns: ``(True, "")`` on success; ``(False, error_message)`` on failure.
"""
if not validator_name or validator_name.lower() == "none":
return True, ""
```
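A minimal standalone sketch of the behaviour this suggestion aims for. It makes one assumption beyond the patch: following the comment's note, it still rejects a missing validator when the check is row-level. The plain `registry` dict and the `row_level` keyword are hypothetical. The real method reads factories from `ValidatorRegistry` and takes no `row_level` argument.

```python
from __future__ import annotations

from collections.abc import Callable
from typing import Any


def validate_suggestion(
    registry: dict[str, Callable[..., Callable[[Any], bool]]],
    check_name: str,
    validator_name: str | None,
    validator_args: dict[str, Any],
    *,
    row_level: bool = False,  # hypothetical flag, not part of the real signature
) -> tuple[bool, str]:
    """Hypothetical standalone variant of DQValidationToolset.validate_suggestion()."""
    if not validator_name or validator_name.lower() == "none":
        # Aggregate checks may legitimately have no validator; row-level checks
        # need a per-row callable, so only they treat "none" as an error.
        if row_level:
            return False, f"Check {check_name!r}: row-level checks require a validator."
        return True, ""
    factory = registry.get(validator_name)
    if factory is None:
        return False, f"Check {check_name!r}: unknown validator {validator_name!r}."
    try:
        # Instantiating the factory surfaces wrong or missing keyword arguments early.
        factory(**validator_args)
    except (TypeError, ValueError) as exc:
        return False, f"Check {check_name!r}: cannot build {validator_name!r}: {exc}"
    return True, ""
```

Keeping the "none" handling separate from the lookup means the catalog prompt and the validation step agree. An explicit "no validator" answer passes, while typos in `validator_name` or bad `validator_args` are still rejected before the plan is accepted.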
##########
providers/common/ai/src/airflow/providers/common/ai/operators/llm_data_quality.py:
##########
@@ -0,0 +1,786 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Operator for generating and executing data-quality checks from natural language using LLMs."""
+
+from __future__ import annotations
+
+import hashlib
+import json
+from collections import Counter
+from collections.abc import Callable, Sequence
+from functools import cached_property
+from typing import TYPE_CHECKING, Any
+
+from airflow.providers.common.ai.operators.llm import LLMOperator
+from airflow.providers.common.ai.utils.db_schema import get_db_hook
+from airflow.providers.common.ai.utils.dq_models import (
+ DQCheck,
+ DQCheckFailedError,
+ DQCheckGroup,
+ DQCheckInput,
+ DQCheckResult,
+ DQPlan,
+ DQReport,
+ RowLevelResult,
+ UnexpectedResult,
+)
+from airflow.providers.common.ai.utils.dq_validation import DQValidationToolset, default_registry
+from airflow.providers.common.compat.sdk import Variable
+
+try:
+ from airflow.providers.common.ai.utils.dq_planner import SQLDQPlanner
+except ImportError as e:
+ from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException
+
+ raise AirflowOptionalProviderFeatureException(e)
+
+if TYPE_CHECKING:
+ from airflow.providers.common.sql.config import DataSourceConfig
+ from airflow.providers.common.sql.hooks.sql import DbApiHook
+ from airflow.sdk import Context
+
+_PLAN_VARIABLE_PREFIX = "dq_plan_"
+_PLAN_VARIABLE_KEY_MAX_LEN = 200 # stay well under Airflow Variable key length limit
+
+
+def _describe_validator(validator: Callable[[Any], bool]) -> str:
+ """Return a human-readable validator label for failure messages."""
+ display = getattr(validator, "_validator_display", None)
+ if isinstance(display, str) and display:
+ return display
+ validator_name = getattr(validator, "_validator_name", None)
+ if isinstance(validator_name, str) and validator_name:
+ return validator_name
+ validator_name = getattr(validator, "__name__", None)
+ if isinstance(validator_name, str) and validator_name:
+ return validator_name
+ return repr(validator)
+
+
+class LLMDataQualityOperator(LLMOperator):
+ """
+ Generate and execute data-quality checks from natural language descriptions.
+
+ Each entry in ``checks`` describes **one** data-quality expectation. The LLM
+ groups related checks into optimised SQL queries, selects the most appropriate
+ validator for each check from the registered catalog, executes the SQL against
+ the target database, and applies the validators. The task fails if any check
+ does not pass, gating downstream tasks on data quality.
+
+ Optionally, supply a fixed ``validator`` on a :class:`~airflow.providers.common.ai.utils.dq_models.DQCheckInput`
+ to bypass LLM validator selection for that specific check.
+
+ Generated SQL plans (including LLM-chosen validators) are cached in Airflow
+ :class:`~airflow.models.variable.Variable` to avoid repeat LLM calls.
+ Set ``dry_run=True`` to preview the plan without executing it.
+ Set ``require_approval=True`` to gate execution on human review via the
+ HITL interface.
+
+ :param checks: List of :class:`~airflow.providers.common.ai.utils.dq_models.DQCheckInput`
+ objects (or plain dicts with ``name``, ``description``, and optional ``validator`` keys).
+ Each entry describes one data-quality expectation. Names must be unique.
+ Example::
+
+ from airflow.providers.common.ai.utils.dq_models import DQCheckInput
+ from airflow.providers.common.ai.utils.dq_validation import row_count_check
+
+ checks = [
+ DQCheckInput(name="email_nulls", description="Check for null email addresses"),
+ DQCheckInput(
+ name="row_count",
+ description="Ensure at least 1000 rows exist",
+ validator=row_count_check(min_count=1000), # fixed — LLM skips this one
+ ),
+ ]
+
+ :param llm_conn_id: Connection ID for the LLM provider.
+ :param model_id: Model identifier (e.g. ``"openai:gpt-4o"``).
+ Overrides the model stored in the connection's extra field.
+ :param system_prompt: Additional instructions appended to the planning prompt.
+ :param agent_params: Additional keyword arguments passed to the pydantic-ai
+ ``Agent`` constructor (e.g. ``retries``, ``model_settings``).
+ :param db_conn_id: Connection ID for the database to run checks against.
+ Must resolve to a :class:`~airflow.providers.common.sql.hooks.sql.DbApiHook`.
+ :param table_names: Tables to include in the LLM's schema context.
+ :param schema_context: Manual schema description; bypasses DB introspection.
+ :param dialect: SQL dialect override (``postgres``, ``mysql``, etc.).
+ Auto-detected from *db_conn_id* when not set.
+ :param datasource_config: DataFusion datasource for object-storage schema.
+ :param dry_run: When ``True``, generate and cache the plan but skip execution.
+ Returns the serialised plan dict instead of a
+ :class:`~airflow.providers.common.ai.utils.dq_models.DQReport`.
+ :param prompt_version: Optional version tag included in the plan cache key.
+ Bump this to invalidate cached plans when checks change semantically
+ without changing their text.
+ :param collect_unexpected: When ``True``, the LLM generates an
+ ``unexpected_query`` for validity / string-format checks.
+ If any of those checks fail, the unexpected query is executed and
+ the resulting sample rows are included in the report.
+ :param unexpected_sample_size: Maximum number of violating rows to return
+ per failed check. Default ``100``.
+ :param row_level_sample_size: Maximum number of rows to fetch per row-level
+ check. ``None`` (default) performs a full table scan.
+ :param require_approval: When ``True``, the operator defers after generating
+ and caching the DQ plan. The plan SQL is surfaced in the HITL interface
+ for human review; checks run only after the reviewer approves.
+ ``dry_run=True`` takes precedence.
+
+ When approval is granted, Airflow resumes the task by calling
+ :meth:`execute_complete` with the approved plan JSON.
+ """
+
+ template_fields: Sequence[str] = (
+ *LLMOperator.template_fields,
+ "checks",
+ "db_conn_id",
+ "table_names",
+ "schema_context",
+ "prompt_version",
+ "collect_unexpected",
+ "unexpected_sample_size",
+ "row_level_sample_size",
+ )
+
+ def __init__(
+ self,
+ *,
+ checks: list[DQCheckInput | dict[str, Any]],
+ db_conn_id: str | None = None,
+ table_names: list[str] | None = None,
+ schema_context: str | None = None,
+ dialect: str | None = None,
+ datasource_config: DataSourceConfig | None = None,
+ prompt_version: str | None = None,
+ dry_run: bool = False,
+ collect_unexpected: bool = False,
+ unexpected_sample_size: int = 100,
+ row_level_sample_size: int | None = None,
+ **kwargs: Any,
+ ) -> None:
+ kwargs.pop("output_type", None)
+ kwargs.setdefault("prompt", "LLMDataQualityOperator")
+ super().__init__(**kwargs)
+
+ self.checks: list[DQCheckInput] = (
+ checks if not isinstance(checks, list) else [DQCheckInput.coerce(c) for c in checks]
+ )
+ self.db_conn_id = db_conn_id
+ self.table_names = table_names
+ self.schema_context = schema_context
+ self.dialect = dialect
+ self.datasource_config = datasource_config
+ self.prompt_version = prompt_version
+ self.dq_dry_run = dry_run
+ self.collect_unexpected = collect_unexpected
+ self.unexpected_sample_size = unexpected_sample_size
+ self.row_level_sample_size = row_level_sample_size
+
+ self._validate_checks()
+
+ def execute(self, context: Context) -> dict[str, Any]:
+ """
+ Generate the DQ plan (or load it from cache), then execute or defer for approval.
+
+ The plan is generated with a single LLM call that simultaneously selects
+ validators from the registry **and** produces the SQL for each check.
+ Checks that have a user-supplied fixed validator bypass LLM selection.
+
+ When ``dry_run=True``, the serialised plan dict is returned immediately;
+ no SQL is executed and no approval is requested.
+ When ``require_approval=True``, the task defers, presenting the plan to a
+ human reviewer; data-quality checks run only after the reviewer approves.
+
+ :returns: Dict with keys ``plan``, ``passed``, and ``results``.
+ :raises DQCheckFailedError: If any data-quality check fails threshold validation.
+ :raises TaskDeferred: When ``require_approval=True``, defers for human review.
+ """
+ planner = self._build_planner()
+
+ schema_ctx = planner.build_schema_context(
+ table_names=self.table_names, schema_context=self.schema_context
+ )
+
+ self.log.info("Using schema context:\n%s", schema_ctx)
+
+ plan = self._load_or_generate_plan(planner, schema_ctx)
+
+ if self.dq_dry_run:
+ self.log.info(
+ "dry_run=True; skipping execution. Plan contains %d group(s), %d check(s).",
+ len(plan.groups),
+ len(plan.check_names),
+ )
+ for group in plan.groups:
+ self.log.info(
+ "Group: %s\nChecks: %s\nSQL Query:\n%s\n",
+ group.group_id,
+ ", ".join(c.check_name for c in group.checks),
+ group.query,
+ )
+ return {"plan": plan.model_dump(), "passed": None, "results": None}
+
+ if self.require_approval:
+ self.defer_for_approval( # type: ignore[misc]
+ context,
+ plan.model_dump_json(),
+ body=self._build_dry_run_markdown(plan),
+ )
+ return {} # type: ignore[return-value] # pragma: no cover
+
+ return self._run_checks_and_report(context, planner, plan)
+
+ def _build_planner(self) -> SQLDQPlanner:
+ """Construct a :class:`~airflow.providers.common.ai.utils.dq_planner.SQLDQPlanner` from operator config."""
+ fixed_validators = self._collect_fixed_validators()
+ return SQLDQPlanner(
+ llm_hook=self.llm_hook,
+ db_hook=self.db_hook,
+ dialect=self.dialect,
+ datasource_config=self.datasource_config,
+ system_prompt=self.system_prompt,
+ agent_params=self.agent_params,
+ collect_unexpected=self.collect_unexpected,
+ unexpected_sample_size=self.unexpected_sample_size,
+ validator_contexts=self.validator_contexts,
+ row_validators=self._collect_row_validators(),
+ row_level_sample_size=self.row_level_sample_size,
+ fixed_validators=fixed_validators,
+ )
+
+ @cached_property
+ def validator_contexts(self) -> str:
+ """Return validator-specific LLM context rendered from fixed validators only."""
+ fixed = self._collect_fixed_validators()
+ return default_registry.build_llm_context(fixed)
+
+ def _run_checks_and_report(
+ self,
+ context: Context,
+ planner: SQLDQPlanner,
+ plan: DQPlan,
+ ) -> dict[str, Any]:
+ """
+ Execute *plan* against the database, apply validators, and return the serialised report.
+
+ :raises DQCheckFailedError: If any data-quality check fails.
+ """
+ effective_validators = self._resolve_effective_validators_from_plan(plan)
+ planner.set_row_validators(
+ {name: fn for name, fn in effective_validators.items() if default_registry.is_row_level(fn)}
+ )
+ results_map = planner.execute_plan(plan)
+ check_results = self._validate_results(results_map, plan, effective_validators)
+
+ if self.collect_unexpected:
+ failed_names = {r.check_name for r in check_results if not r.passed}
+ if failed_names:
+ unexpected_map = planner.execute_unexpected_queries(plan, failed_names)
+ self._attach_unexpected(check_results, unexpected_map)
+
+ report = DQReport.build(check_results)
+
+ output: dict[str, Any] = {
+ "plan": plan.model_dump(),
+ "passed": report.passed,
+ "results": [
+ {
+ "check_name": r.check_name,
+ "metric_key": r.metric_key,
+ "value": r.value.to_dict() if isinstance(r.value, RowLevelResult) else r.value,
+ "passed": r.passed,
+ "failure_reason": r.failure_reason,
+ **(
+ {
+ "unexpected_records": r.unexpected.unexpected_records,
+ "unexpected_sample_size": r.unexpected.sample_size,
+ }
+ if r.unexpected
+ else {}
+ ),
+ }
+ for r in report.results
+ ],
+ }
+
+ if not report.passed:
+ # Push results to XCom before failing so downstream tasks
+ # (e.g. with trigger_rule=all_done) can still inspect them.
+ context["ti"].xcom_push(key="return_value", value=output)
+ raise DQCheckFailedError(report.failure_summary)
+
+ self.log.info("All %d data-quality check(s) passed.", len(report.results))
+ return output
+
+ def _build_dry_run_markdown(self, plan: DQPlan) -> str:
+ """
+ Build a structured markdown summary of the DQ plan for the HITL review body.
+
+ Aggregate groups and row-level groups are rendered in separate sections so
+ reviewers can immediately distinguish SQL-aggregate checks from per-row
+ validation logic.
+ """
+ aggregate_groups = [g for g in plan.groups if not any(c.row_level for c in g.checks)]
+ row_level_groups = [g for g in plan.groups if any(c.row_level for c in g.checks)]
+
+ total_checks = len(plan.check_names)
+ agg_count = sum(len(g.checks) for g in aggregate_groups)
+ row_count = sum(len(g.checks) for g in row_level_groups)
+
+ lines: list[str] = [
+ "# LLM Data Quality Plan",
+ "",
+ "| | |",
+ "|---|---|",
+ f"| **Plan hash** | `{plan.plan_hash or 'N/A'}` |",
+ f"| **Total checks** | {total_checks} |",
+ f"| **Aggregate checks** | {agg_count} ({len(aggregate_groups)} group{'s' if len(aggregate_groups) != 1 else ''}) |",
+ f"| **Row-level checks** | {row_count} ({len(row_level_groups)} group{'s' if len(row_level_groups) != 1 else ''}) |",
+ "",
+ ]
+
+ if aggregate_groups:
+ lines += [
+ "---",
+ "",
+ "## Aggregate Checks",
+ "",
+ "> Each group runs as a **single SQL query**. "
+ "Result columns are matched to check names by metric key.",
+ "",
+ ]
+ for group in aggregate_groups:
+ lines += self._render_aggregate_group(group)
+
+ if row_level_groups:
+ lines += [
+ "---",
+ "",
+ "## Row-Level Checks",
+ "",
+ "> Row-level checks fetch **raw column values** and apply Python-side "
+ "validation per row. The threshold controls the maximum allowed fraction "
+ "of invalid rows before the check fails.",
+ "",
+ ]
+ for group in row_level_groups:
+ lines += self._render_row_level_group(group)
+
+ return "\n".join(lines).rstrip()
+
+ def _render_aggregate_group(self, group: DQCheckGroup) -> list[str]:
+ """Render one aggregate SQL group as a markdown subsection."""
+ lines: list[str] = [
+ f"### `{group.group_id}`",
+ "",
+ "| Check name | Metric key | Category | Validator |",
+ "|---|---|---|---|",
+ ]
+ for check in group.checks:
+ category = check.check_category or "—"
+ validator_label = self._describe_validator_for_check(check)
+ lines.append(f"| `{check.check_name}` | `{check.metric_key}` | {category} | {validator_label} |")
+
+ lines += [
+ "",
+ "```sql",
+ group.query.strip(),
+ "```",
+ "",
+ ]
+
+ # Unexpected queries — only show when present.
+ unexpected = [(c.check_name, c.unexpected_query) for c in group.checks if c.unexpected_query]
+ if unexpected:
+ lines += ["<details><summary>Unexpected-row queries</summary>", ""]
+ for check_name, uq in unexpected:
+ lines += [
+ f"**`{check_name}`**",
+ "",
+ "```sql",
+ (uq or "").strip(),
+ "```",
+ "",
+ ]
+ lines += ["</details>", ""]
+
+ return lines
+
+ def _render_row_level_group(self, group: DQCheckGroup) -> list[str]:
+ """Render one row-level group as a markdown subsection with threshold info."""
+ all_validators = self._resolve_effective_validators()
+ lines: list[str] = [
+ f"### `{group.group_id}`",
+ "",
+ "| Check name | Metric key | Max invalid % | Validator |",
+ "|---|---|---|---|",
+ ]
+ for check in group.checks:
+ validator = all_validators.get(check.check_name)
+ max_pct = (
+ self._resolve_row_level_max_invalid_pct(
+ check.check_name,
+ validator,
+ default_when_missing=None,
+ warn_on_missing=False,
+ )
+ if validator is not None
+ else None
+ )
+ threshold_str = f"{max_pct:.2%}" if max_pct is not None else "—"
+ validator_label = self._describe_validator_for_check(check)
+ lines.append(
Review Comment:
In the HITL markdown for row-level groups, `Max invalid %` is derived only
from user-fixed validator callables (`_resolve_effective_validators()`), so for
LLM-suggested row-level validators the threshold will display as `—` even
though the plan includes `validator_args` (e.g. `max_invalid_pct`). Consider
deriving the displayed threshold from `check.validator_args` when the callable
isn’t available yet, so reviewers can see the effective threshold before
approving.
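A minimal sketch of that fallback, under stated assumptions: `check.validator_args` is a plain mapping, the key is `max_invalid_pct` (the example named above), and its value is a fraction in [0, 1], matching the existing `:.2%` formatting. `resolve_display_threshold` and `MAX_INVALID_PCT_KEY` are hypothetical names, and `resolve_from_callable` stands in for `_resolve_row_level_max_invalid_pct`.

```python
from __future__ import annotations

from collections.abc import Callable, Mapping
from typing import Any

# Hypothetical key name taken from the review example; the real plan schema may differ.
MAX_INVALID_PCT_KEY = "max_invalid_pct"


def resolve_display_threshold(
    validator: Callable[..., Any] | None,
    validator_args: Mapping[str, Any] | None,
    resolve_from_callable: Callable[[Callable[..., Any]], float | None],
) -> float | None:
    """Return the max-invalid fraction to display for a row-level check.

    Prefers the threshold bound to a user-fixed validator callable, then falls
    back to the LLM-planned ``validator_args`` so LLM-suggested checks still
    show a threshold. Returns ``None`` when neither source has a usable value.
    """
    # 1. A user-fixed validator callable wins when it exposes a threshold.
    if validator is not None:
        pct = resolve_from_callable(validator)
        if pct is not None:
            return pct

    # 2. Otherwise fall back to the arguments the LLM put in the plan.
    if not validator_args:
        return None
    raw = validator_args.get(MAX_INVALID_PCT_KEY)
    if raw is None:
        return None
    try:
        return float(raw)  # assumed to already be a fraction in [0, 1]
    except (TypeError, ValueError):
        # Malformed LLM output: show no threshold instead of failing the render.
        return None
```

In `_render_row_level_group`, the `max_pct = (...)` expression could then call this with `getattr(check, "validator_args", None)` and a lambda wrapping `self._resolve_row_level_max_invalid_pct(check.check_name, v, default_when_missing=None, warn_on_missing=False)`; the existing `threshold_str` line would stay unchanged.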