Copilot commented on code in PR #62963: URL: https://github.com/apache/airflow/pull/62963#discussion_r3467353983
########## providers/common/ai/tests/unit/common/ai/operators/test_llm_data_quality.py: ########## @@ -0,0 +1,440 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Tests for LLMDataQualityOperator.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook +from airflow.providers.common.ai.operators.llm import LLMOperator +from airflow.providers.common.ai.operators.llm_data_quality import LLMDataQualityOperator +from airflow.providers.common.ai.toolsets.dataquality.base import BaseDQToolset +from airflow.providers.common.ai.utils.dataquality.models import ( + DQCheckFailedError, + DQCheckInput, + DQCheckPlan, + DQCheckResult, + DQPlan, + DQReport, +) + +_CHECKS = [ + DQCheckInput(name="null_emails", description="Check for null email addresses"), + DQCheckInput(name="dup_ids", description="Check for duplicate customer IDs"), +] + + +class _TaskInstanceLike: + def xcom_push(self, key: str, value: Any) -> None: + pass + + +class _UsageLike: + requests = 1 + tool_calls = 0 + input_tokens = 1 + output_tokens = 1 + total_tokens = 2 + + +class _ResponseLike: + model_name = "test-model" + + 
+class _AgentResultLike: + def __init__(self, output: Any) -> None: + self.output = output + self.response = _ResponseLike() + + def usage(self) -> _UsageLike: + return _UsageLike() + + def all_messages(self) -> list[Any]: + return [] + + +def _make_context() -> Any: + task_instance = MagicMock(spec=_TaskInstanceLike) + return {"task_instance": task_instance, "ti": task_instance} + + +def _make_operator(**overrides: Any) -> LLMDataQualityOperator: + defaults: dict[str, Any] = { + "task_id": "test_dq", + "checks": _CHECKS, + "llm_conn_id": "pydantic_ai_default", + "db_conn_id": "postgres_default", + } + defaults.update(overrides) + op = LLMDataQualityOperator(**defaults) + op.llm_hook = MagicMock(spec=PydanticAIHook) + return op + + +def _passing_report() -> DQReport: + return DQReport( + results=[ + DQCheckResult(check_name="null_emails", passed=True), + DQCheckResult(check_name="dup_ids", passed=True), + ], + passed=True, + ) + + +def _failing_report() -> DQReport: + return DQReport( + results=[ + DQCheckResult(check_name="null_emails", passed=False, failure_reason="100 nulls found"), + DQCheckResult(check_name="dup_ids", passed=True), + ], + passed=False, + failure_summary="null_emails: 100 nulls found", + ) + + +def _mock_agent_result(output: Any) -> _AgentResultLike: + return _AgentResultLike(output) + + +class TestLLMDataQualityOperatorInit: + def test_requires_llm_conn_id(self): + with pytest.raises(TypeError): + LLMDataQualityOperator( + task_id="test_dq", + checks=_CHECKS, + db_conn_id="postgres_default", + ) + + def test_raises_when_no_toolsets_and_no_db_conn_id(self): + with pytest.raises(ValueError, match="Either toolsets or db_conn_id"): + LLMDataQualityOperator( + task_id="test_dq", + checks=_CHECKS, + llm_conn_id="pydantic_ai_default", + ) + + def test_raises_when_empty_toolsets_and_no_db_conn_id(self): + with pytest.raises(ValueError, match="Either toolsets or db_conn_id"): + LLMDataQualityOperator( + task_id="test_dq", + checks=_CHECKS, + 
llm_conn_id="pydantic_ai_default", + toolsets=[], + ) + + def test_empty_checks_raises_value_error(self): + with pytest.raises(ValueError, match="checks must not be empty"): + _make_operator(checks=[]) + + def test_duplicate_check_names_raises(self): + with pytest.raises(ValueError, match="duplicate"): + _make_operator( + checks=[ + DQCheckInput(name="dup", description="first"), + DQCheckInput(name="dup", description="second"), + ] + ) + + def test_dict_checks_auto_coerced(self): + op = _make_operator( + checks=[ + {"name": "null_emails", "description": "Check nulls"}, + {"name": "dup_ids", "description": "Check dups"}, + ] + ) + assert all(isinstance(c, DQCheckInput) for c in op.checks) + + def test_template_fields_include_required_keys(self): + op = _make_operator() + tf = set(op.template_fields) + assert {"checks", "system_prompt", "agent_params", "db_conn_id", "table_names"} <= tf + + +class TestResolveToolsets: + def test_explicit_toolsets_returned_unchanged(self): + mock_dq = MagicMock(spec=BaseDQToolset) + mock_dq.output_mode = "execute" + op = _make_operator(toolsets=[mock_dq]) + resolved = op._resolve_toolsets() + assert resolved is op.toolsets Review Comment: This test uses `MagicMock(spec=BaseDQToolset)`, but `LLMDataQualityOperator._resolve_toolsets()` detects DQ toolsets via `isinstance(ts, BaseDQToolset)`, so a MagicMock will not be recognized. Also `_resolve_toolsets()` returns a new list (`list(self.toolsets)`), so asserting identity (`is`) will fail even with a real toolset instance. ########## providers/common/ai/src/airflow/providers/common/ai/toolsets/dataquality/sql.py: ########## @@ -0,0 +1,399 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""SQL data-quality toolset for LLMDataQualityOperator.""" + +from __future__ import annotations + +import inspect +import json +import logging +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Literal + +try: + from airflow.providers.common.ai.utils.dataquality.validation import ValidatorRegistry, default_registry +except ImportError as e: + from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException + + raise AirflowOptionalProviderFeatureException(e) + +from pydantic_ai.tools import ToolDefinition +from pydantic_ai.toolsets.abstract import ToolsetTool + +from airflow.providers.common.ai.toolsets.dataquality.base import _PASSTHROUGH_VALIDATOR, BaseDQToolset +from airflow.providers.common.ai.utils.dataquality.models import RowLevelResult + +if TYPE_CHECKING: + from pydantic_ai._run_context import RunContext + +_log = logging.getLogger(__name__) + +# JSON Schemas for the three SQL-DQ-specific tools. 
+_LIST_VALIDATORS_SCHEMA: dict[str, Any] = { + "type": "object", + "properties": {}, +} + +_APPLY_VALIDATOR_SCHEMA: dict[str, Any] = { + "type": "object", + "properties": { + "check_name": {"type": "string", "description": "Name of the DQ check being evaluated."}, + "value": {"description": "Metric value returned by the SQL query for this check."}, + "validator_name": { + "type": "string", + "description": "Registered validator name (from list_validators). Pass 'none' to skip.", + }, + "validator_args": { + "type": "object", + "description": 'Keyword arguments for the validator factory (e.g. {"max_pct": 0.05}).', + }, + }, + "required": ["check_name", "value", "validator_name", "validator_args"], +} + +_ROW_LEVEL_SAMPLE_LIMIT = 20 + + +class SQLDQToolset(BaseDQToolset): + """ + Data-quality toolset for SQL-based checks. + + Provides three tools on top of + :class:`~airflow.providers.common.ai.toolsets.dataquality.base.BaseDQToolset`: + + * ``list_validators`` — exposes the registered validator catalog so the LLM + can choose an appropriate validator for each check. + * ``apply_validator`` — instantiates a validator and applies it to a metric + value, returning ``{"passed": bool, "reason": ...}``. + + Use this toolset alongside :class:`~airflow.providers.common.ai.toolsets.sql.SQLToolset` + or :class:`~airflow.providers.common.ai.toolsets.datafusion.DataFusionToolset` in + :class:`~airflow.providers.common.ai.operators.llm_data_quality.LLMDataQualityOperator`. + The data-source toolset handles schema discovery and query execution; this + toolset handles the DQ-specific logic:: + + LLMDataQualityOperator( + checks=[...], + llm_conn_id="pydanticai_default", + toolsets=[ + SQLToolset(db_conn_id="postgres_default"), + SQLDQToolset(), + ], + ) + + :param validator_registry: Validator registry to expose to the LLM. + Defaults to :data:`~airflow.providers.common.ai.utils.dataquality.validation.default_registry`. 
+ """ + + def __init__(self, *, validator_registry: ValidatorRegistry | None = None) -> None: + super().__init__() + self._registry = validator_registry if validator_registry is not None else default_registry + + @property + def id(self) -> str: + return "dq-sql" + + @property + def output_mode(self) -> Literal["execute", "generate"]: + return "execute" + + # ------------------------------------------------------------------ + # AbstractToolset interface + # ------------------------------------------------------------------ + + async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]: + tools = await super().get_tools(ctx) + + for name, description, schema in ( + ( + "list_validators", + "List available validator names, parameters, and descriptions.", + _LIST_VALIDATORS_SCHEMA, + ), + ( + "apply_validator", + "Apply a registered validator to a metric value and return pass/fail.", + _APPLY_VALIDATOR_SCHEMA, + ), + ): + # In planning mode (two-phase approval flow) the agent must NOT call + # apply_validator — it only selects a validator by name/args and records + # that choice in the DQPlan output. apply_validator runs in Phase 2 + # (pure Python) after the human reviewer approves. 
+ if name == "apply_validator" and getattr(self, "_planning_mode", False): + continue + tool_def = ToolDefinition( + name=name, + description=description, + parameters_json_schema=schema, + sequential=True, + ) + tools[name] = ToolsetTool( + toolset=self, + tool_def=tool_def, + max_retries=1, + args_validator=_PASSTHROUGH_VALIDATOR, + ) + return tools + + async def call_tool( + self, + name: str, + tool_args: dict[str, Any], + ctx: RunContext[Any], + tool: ToolsetTool[Any], + ) -> Any: + if name == "list_validators": + result = self._list_validators() + _log.info("list_validators result: %s", result) + return result + if name == "apply_validator": + check_name = tool_args["check_name"] + validator_name = tool_args["validator_name"] + validator_args = tool_args.get("validator_args") or {} + value = tool_args["value"] + + # Resolve the display name: for fixed validators, show the actual + # validator's name/signature followed by "(fixed)" so logs are + # meaningful instead of just showing "fixed". 
+ if validator_name.lower() == "fixed" or self._has_fixed_validator(check_name): + fixed = self._get_fixed_validator(check_name) + display_name = getattr(fixed, "_validator_display", None) or getattr( + fixed, "_validator_name", "fixed" + ) + log_validator = f"{display_name}(fixed)" + else: + log_validator = validator_name + + _log.info( + "apply_validator: check=%s, validator=%s, args=%s, value=%s", + check_name, + log_validator, + validator_args, + value, + ) + result = self._apply_validator(check_name, value, validator_name, validator_args) + parsed = json.loads(result) + _log.info( + "apply_validator result: check=%s, passed=%s, reason=%s", + parsed.get("check_name"), + parsed.get("passed"), + parsed.get("reason"), + ) + return result + return await super().call_tool(name, tool_args, ctx, tool) + + # ------------------------------------------------------------------ + # Tool implementations + # ------------------------------------------------------------------ + + def _list_validators(self) -> str: + entries = [] + for name in self._registry.list_validators(): + entry = self._registry.get(name) + entries.append( + { + "name": name, + "category": entry.check_category or None, + "row_level": entry.row_level, + "parameters": self._format_signature(entry.factory), + "description": entry.llm_context or None, + } + ) + return json.dumps(entries) + + def _apply_validator( + self, + check_name: str, + value: Any, + validator_name: str, + validator_args: dict[str, Any], + ) -> str: + if validator_name.lower() == "fixed" or self._has_fixed_validator(check_name): + fixed = self._get_fixed_validator(check_name) + if fixed is None: + return json.dumps( + { + "check_name": check_name, + "passed": False, + "reason": f"No fixed validator configured for check {check_name!r}.", + } + ) + is_row_level = bool(getattr(fixed, "_row_level", False)) + if is_row_level: + return self._apply_row_level_validator(check_name, value, fixed, validator_name="fixed") + + 
_log.info("apply_validator scalar: check=%s, value=%s", check_name, value) + try: + passed = bool(fixed(value)) + except Exception as exc: + return json.dumps({"check_name": check_name, "passed": False, "reason": str(exc)}) + reason = None if passed else f"Fixed validator returned False for value {value!r}" + return json.dumps({"check_name": check_name, "passed": passed, "reason": reason}) + + if not validator_name or validator_name.lower() == "none": + return json.dumps({"check_name": check_name, "passed": True, "reason": "no validator"}) + + ok, err = self._validate_suggestion(validator_name, validator_args) + if not ok: + return json.dumps({"check_name": check_name, "passed": False, "reason": err}) + + try: + entry = self._registry.get(validator_name) + validator = entry.factory(**validator_args) + except Exception as exc: + return json.dumps({"check_name": check_name, "passed": False, "reason": str(exc)}) + + if entry.row_level: + return self._apply_row_level_validator( + check_name, + value, + validator, + validator_name=validator_name, + ) + + try: + passed = bool(validator(value)) + except Exception as exc: + return json.dumps({"check_name": check_name, "passed": False, "reason": str(exc)}) + + reason = None if passed else f"{validator_name} returned False for value {value!r}" + return json.dumps({"check_name": check_name, "passed": passed, "reason": reason}) + + def _apply_row_level_validator( + self, + check_name: str, + value: Any, + validator: Callable[[Any], bool], + *, + validator_name: str, + ) -> str: + """Apply a row-level validator to each row value and return a RowLevelResult payload.""" + values = self._coerce_row_values(check_name, value) + invalid_values: list[Any] = [] + + for row_value in values: + try: + is_valid = bool(validator(row_value)) + except Exception as exc: + return json.dumps( + { + "check_name": check_name, + "passed": False, + "reason": ( + f"{validator_name} raised {exc!s} while evaluating row-level value {row_value!r}." 
+ ), + } + ) + if not is_valid: + invalid_values.append(row_value) + + summary = RowLevelResult.build(invalid_values, len(values), sample_limit=_ROW_LEVEL_SAMPLE_LIMIT) + max_invalid_pct = self._get_max_invalid_pct(validator) Review Comment: Row-level validation accumulates every invalid row value into `invalid_values` before sampling, which can become a large in-memory list when `values` is big (e.g. if `max_rows` is raised or a non-SQL toolset returns many rows). It’s enough to track an invalid count plus a bounded sample list, and build `RowLevelResult` from those fields directly. ########## providers/common/ai/tests/unit/common/ai/operators/test_llm_data_quality.py: ########## @@ -0,0 +1,440 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Tests for LLMDataQualityOperator.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook +from airflow.providers.common.ai.operators.llm import LLMOperator +from airflow.providers.common.ai.operators.llm_data_quality import LLMDataQualityOperator +from airflow.providers.common.ai.toolsets.dataquality.base import BaseDQToolset +from airflow.providers.common.ai.utils.dataquality.models import ( + DQCheckFailedError, + DQCheckInput, + DQCheckPlan, + DQCheckResult, + DQPlan, + DQReport, +) + +_CHECKS = [ + DQCheckInput(name="null_emails", description="Check for null email addresses"), + DQCheckInput(name="dup_ids", description="Check for duplicate customer IDs"), +] + + +class _TaskInstanceLike: + def xcom_push(self, key: str, value: Any) -> None: + pass + + +class _UsageLike: + requests = 1 + tool_calls = 0 + input_tokens = 1 + output_tokens = 1 + total_tokens = 2 + + +class _ResponseLike: + model_name = "test-model" + + +class _AgentResultLike: + def __init__(self, output: Any) -> None: + self.output = output + self.response = _ResponseLike() + + def usage(self) -> _UsageLike: + return _UsageLike() + + def all_messages(self) -> list[Any]: + return [] + + +def _make_context() -> Any: + task_instance = MagicMock(spec=_TaskInstanceLike) + return {"task_instance": task_instance, "ti": task_instance} + + +def _make_operator(**overrides: Any) -> LLMDataQualityOperator: + defaults: dict[str, Any] = { + "task_id": "test_dq", + "checks": _CHECKS, + "llm_conn_id": "pydantic_ai_default", + "db_conn_id": "postgres_default", + } + defaults.update(overrides) + op = LLMDataQualityOperator(**defaults) + op.llm_hook = MagicMock(spec=PydanticAIHook) + return op + + +def _passing_report() -> DQReport: + return DQReport( + results=[ + DQCheckResult(check_name="null_emails", passed=True), + DQCheckResult(check_name="dup_ids", 
passed=True), + ], + passed=True, + ) + + +def _failing_report() -> DQReport: + return DQReport( + results=[ + DQCheckResult(check_name="null_emails", passed=False, failure_reason="100 nulls found"), + DQCheckResult(check_name="dup_ids", passed=True), + ], + passed=False, + failure_summary="null_emails: 100 nulls found", + ) + + +def _mock_agent_result(output: Any) -> _AgentResultLike: + return _AgentResultLike(output) + + +class TestLLMDataQualityOperatorInit: + def test_requires_llm_conn_id(self): + with pytest.raises(TypeError): + LLMDataQualityOperator( + task_id="test_dq", + checks=_CHECKS, + db_conn_id="postgres_default", + ) + + def test_raises_when_no_toolsets_and_no_db_conn_id(self): + with pytest.raises(ValueError, match="Either toolsets or db_conn_id"): + LLMDataQualityOperator( + task_id="test_dq", + checks=_CHECKS, + llm_conn_id="pydantic_ai_default", + ) + + def test_raises_when_empty_toolsets_and_no_db_conn_id(self): + with pytest.raises(ValueError, match="Either toolsets or db_conn_id"): + LLMDataQualityOperator( + task_id="test_dq", + checks=_CHECKS, + llm_conn_id="pydantic_ai_default", + toolsets=[], + ) + + def test_empty_checks_raises_value_error(self): + with pytest.raises(ValueError, match="checks must not be empty"): + _make_operator(checks=[]) + + def test_duplicate_check_names_raises(self): + with pytest.raises(ValueError, match="duplicate"): + _make_operator( + checks=[ + DQCheckInput(name="dup", description="first"), + DQCheckInput(name="dup", description="second"), + ] + ) + + def test_dict_checks_auto_coerced(self): + op = _make_operator( + checks=[ + {"name": "null_emails", "description": "Check nulls"}, + {"name": "dup_ids", "description": "Check dups"}, + ] + ) + assert all(isinstance(c, DQCheckInput) for c in op.checks) + + def test_template_fields_include_required_keys(self): + op = _make_operator() + tf = set(op.template_fields) + assert {"checks", "system_prompt", "agent_params", "db_conn_id", "table_names"} <= tf + + +class 
TestResolveToolsets: + def test_explicit_toolsets_returned_unchanged(self): + mock_dq = MagicMock(spec=BaseDQToolset) + mock_dq.output_mode = "execute" + op = _make_operator(toolsets=[mock_dq]) + resolved = op._resolve_toolsets() + assert resolved is op.toolsets + + def test_auto_creates_sql_toolsets_from_db_conn_id(self): + op = _make_operator() + with ( + patch("airflow.providers.common.ai.toolsets.sql.SQLToolset") as mock_sql, + patch("airflow.providers.common.ai.toolsets.dataquality.sql.SQLDQToolset") as mock_dq, + ): + op._resolve_toolsets() + mock_sql.assert_called_once_with(db_conn_id="postgres_default", allowed_tables=None) + mock_dq.assert_called_once_with() + + def test_auto_creates_sql_toolsets_with_table_names(self): + op = _make_operator(table_names=["customers", "orders"]) + with ( + patch("airflow.providers.common.ai.toolsets.sql.SQLToolset") as mock_sql, + patch("airflow.providers.common.ai.toolsets.dataquality.sql.SQLDQToolset"), + ): + op._resolve_toolsets() + mock_sql.assert_called_once_with( + db_conn_id="postgres_default", allowed_tables=["customers", "orders"] + ) + + def test_auto_appends_sql_dq_toolset_when_missing(self): + mock_other = object() + op = _make_operator(toolsets=[mock_other]) + with patch("airflow.providers.common.ai.toolsets.dataquality.sql.SQLDQToolset") as mock_cls: + mock_cls.return_value = MagicMock(spec=BaseDQToolset) + resolved = op._resolve_toolsets() + mock_cls.assert_called_once_with() + assert len(resolved) == 2 + + +class TestFindDqToolset: + def test_returns_first_base_dq_toolset(self): + mock_dq = MagicMock(spec=BaseDQToolset) + mock_other = object() + result = LLMDataQualityOperator._find_dq_toolset([mock_other, mock_dq]) + assert result is mock_dq + + def test_raises_when_no_dq_toolset(self): + mock_other = object() + with pytest.raises(ValueError, match="No BaseDQToolset found"): + LLMDataQualityOperator._find_dq_toolset([mock_other]) + + +class TestExecuteMode: + def _run(self, report: DQReport, **overrides: 
Any) -> Any: + mock_dq = MagicMock(spec=BaseDQToolset) + mock_dq.output_mode = "execute" + op = _make_operator(toolsets=[mock_dq], **overrides) + op.llm_hook.create_agent.return_value.run_sync.return_value = _mock_agent_result(report) # type: ignore[attr-defined] + return op.execute(context=_make_context()) + + def test_set_checks_called_on_dq_toolset(self): + mock_dq = MagicMock(spec=BaseDQToolset) + mock_dq.output_mode = "execute" + op = _make_operator(toolsets=[mock_dq]) + op.llm_hook.create_agent.return_value.run_sync.return_value = _mock_agent_result(_passing_report()) + op.execute(context=_make_context()) + mock_dq.set_checks.assert_called_once_with(op.checks) + + def test_passing_report_returns_model_dump(self): + result = self._run(_passing_report()) + assert result["passed"] is True + assert isinstance(result["results"], list) + + def test_failing_report_raises_dq_check_failed_error(self): + with pytest.raises(DQCheckFailedError, match="null_emails"): + self._run(_failing_report()) + + def test_output_type_is_dq_report(self): + mock_dq = MagicMock(spec=BaseDQToolset) + mock_dq.output_mode = "execute" + op = _make_operator(toolsets=[mock_dq]) + op.llm_hook.create_agent.return_value.run_sync.return_value = _mock_agent_result(_passing_report()) + op.execute(context=_make_context()) + create_agent_kwargs = op.llm_hook.create_agent.call_args.kwargs + assert create_agent_kwargs["output_type"] is DQReport + + +class TestGenerateMode: + def _run_generate(self, config_str: str = "checks:\n - name: foo") -> Any: + mock_dq = MagicMock(spec=BaseDQToolset) + mock_dq.output_mode = "generate" + op = _make_operator(toolsets=[mock_dq]) + op.llm_hook.create_agent.return_value.run_sync.return_value = _mock_agent_result(config_str) # type: ignore[attr-defined] + return op.execute(context=_make_context()) Review Comment: Like the execute-mode tests, using `MagicMock(spec=BaseDQToolset)` here won’t satisfy `isinstance(..., BaseDQToolset)` checks in the operator, so the 
generate-mode path may not be exercised as intended. Use a minimal concrete `BaseDQToolset` implementation with `output_mode="generate"`. ########## providers/common/ai/tests/unit/common/ai/operators/test_llm_data_quality.py: ########## @@ -0,0 +1,440 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Tests for LLMDataQualityOperator.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook +from airflow.providers.common.ai.operators.llm import LLMOperator +from airflow.providers.common.ai.operators.llm_data_quality import LLMDataQualityOperator +from airflow.providers.common.ai.toolsets.dataquality.base import BaseDQToolset +from airflow.providers.common.ai.utils.dataquality.models import ( + DQCheckFailedError, + DQCheckInput, + DQCheckPlan, + DQCheckResult, + DQPlan, + DQReport, +) + +_CHECKS = [ + DQCheckInput(name="null_emails", description="Check for null email addresses"), + DQCheckInput(name="dup_ids", description="Check for duplicate customer IDs"), +] + + +class _TaskInstanceLike: + def xcom_push(self, key: str, value: Any) -> None: + pass + + +class _UsageLike: + requests = 1 + tool_calls = 0 + input_tokens = 1 + output_tokens = 1 + total_tokens = 2 + + +class _ResponseLike: + model_name = "test-model" + + +class _AgentResultLike: + def __init__(self, output: Any) -> None: + self.output = output + self.response = _ResponseLike() + + def usage(self) -> _UsageLike: + return _UsageLike() + + def all_messages(self) -> list[Any]: + return [] + + +def _make_context() -> Any: + task_instance = MagicMock(spec=_TaskInstanceLike) + return {"task_instance": task_instance, "ti": task_instance} + + +def _make_operator(**overrides: Any) -> LLMDataQualityOperator: + defaults: dict[str, Any] = { + "task_id": "test_dq", + "checks": _CHECKS, + "llm_conn_id": "pydantic_ai_default", + "db_conn_id": "postgres_default", + } + defaults.update(overrides) + op = LLMDataQualityOperator(**defaults) + op.llm_hook = MagicMock(spec=PydanticAIHook) + return op + + +def _passing_report() -> DQReport: + return DQReport( + results=[ + DQCheckResult(check_name="null_emails", passed=True), + DQCheckResult(check_name="dup_ids", 
passed=True), + ], + passed=True, + ) + + +def _failing_report() -> DQReport: + return DQReport( + results=[ + DQCheckResult(check_name="null_emails", passed=False, failure_reason="100 nulls found"), + DQCheckResult(check_name="dup_ids", passed=True), + ], + passed=False, + failure_summary="null_emails: 100 nulls found", + ) + + +def _mock_agent_result(output: Any) -> _AgentResultLike: + return _AgentResultLike(output) + + +class TestLLMDataQualityOperatorInit: + def test_requires_llm_conn_id(self): + with pytest.raises(TypeError): + LLMDataQualityOperator( + task_id="test_dq", + checks=_CHECKS, + db_conn_id="postgres_default", + ) + + def test_raises_when_no_toolsets_and_no_db_conn_id(self): + with pytest.raises(ValueError, match="Either toolsets or db_conn_id"): + LLMDataQualityOperator( + task_id="test_dq", + checks=_CHECKS, + llm_conn_id="pydantic_ai_default", + ) + + def test_raises_when_empty_toolsets_and_no_db_conn_id(self): + with pytest.raises(ValueError, match="Either toolsets or db_conn_id"): + LLMDataQualityOperator( + task_id="test_dq", + checks=_CHECKS, + llm_conn_id="pydantic_ai_default", + toolsets=[], + ) + + def test_empty_checks_raises_value_error(self): + with pytest.raises(ValueError, match="checks must not be empty"): + _make_operator(checks=[]) + + def test_duplicate_check_names_raises(self): + with pytest.raises(ValueError, match="duplicate"): + _make_operator( + checks=[ + DQCheckInput(name="dup", description="first"), + DQCheckInput(name="dup", description="second"), + ] + ) + + def test_dict_checks_auto_coerced(self): + op = _make_operator( + checks=[ + {"name": "null_emails", "description": "Check nulls"}, + {"name": "dup_ids", "description": "Check dups"}, + ] + ) + assert all(isinstance(c, DQCheckInput) for c in op.checks) + + def test_template_fields_include_required_keys(self): + op = _make_operator() + tf = set(op.template_fields) + assert {"checks", "system_prompt", "agent_params", "db_conn_id", "table_names"} <= tf + + +class 
TestResolveToolsets: + def test_explicit_toolsets_returned_unchanged(self): + mock_dq = MagicMock(spec=BaseDQToolset) + mock_dq.output_mode = "execute" + op = _make_operator(toolsets=[mock_dq]) + resolved = op._resolve_toolsets() + assert resolved is op.toolsets + + def test_auto_creates_sql_toolsets_from_db_conn_id(self): + op = _make_operator() + with ( + patch("airflow.providers.common.ai.toolsets.sql.SQLToolset") as mock_sql, + patch("airflow.providers.common.ai.toolsets.dataquality.sql.SQLDQToolset") as mock_dq, + ): + op._resolve_toolsets() + mock_sql.assert_called_once_with(db_conn_id="postgres_default", allowed_tables=None) + mock_dq.assert_called_once_with() + + def test_auto_creates_sql_toolsets_with_table_names(self): + op = _make_operator(table_names=["customers", "orders"]) + with ( + patch("airflow.providers.common.ai.toolsets.sql.SQLToolset") as mock_sql, + patch("airflow.providers.common.ai.toolsets.dataquality.sql.SQLDQToolset"), + ): + op._resolve_toolsets() + mock_sql.assert_called_once_with( + db_conn_id="postgres_default", allowed_tables=["customers", "orders"] + ) + + def test_auto_appends_sql_dq_toolset_when_missing(self): + mock_other = object() + op = _make_operator(toolsets=[mock_other]) + with patch("airflow.providers.common.ai.toolsets.dataquality.sql.SQLDQToolset") as mock_cls: + mock_cls.return_value = MagicMock(spec=BaseDQToolset) + resolved = op._resolve_toolsets() + mock_cls.assert_called_once_with() + assert len(resolved) == 2 + + +class TestFindDqToolset: + def test_returns_first_base_dq_toolset(self): + mock_dq = MagicMock(spec=BaseDQToolset) + mock_other = object() + result = LLMDataQualityOperator._find_dq_toolset([mock_other, mock_dq]) + assert result is mock_dq + + def test_raises_when_no_dq_toolset(self): + mock_other = object() + with pytest.raises(ValueError, match="No BaseDQToolset found"): + LLMDataQualityOperator._find_dq_toolset([mock_other]) + + +class TestExecuteMode: + def _run(self, report: DQReport, **overrides: 
Any) -> Any: + mock_dq = MagicMock(spec=BaseDQToolset) + mock_dq.output_mode = "execute" + op = _make_operator(toolsets=[mock_dq], **overrides) + op.llm_hook.create_agent.return_value.run_sync.return_value = _mock_agent_result(report) # type: ignore[attr-defined] + return op.execute(context=_make_context()) Review Comment: Here and in several other tests in this file, `MagicMock(spec=BaseDQToolset)` is passed as a toolset, but production code relies on `isinstance(..., BaseDQToolset)` for discovery. That means the operator will treat the toolsets list as missing a DQ toolset and may append a real `SQLDQToolset` or raise, so these tests won’t exercise the intended path. ########## providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py: ########## @@ -213,6 +229,20 @@ def _is_ref_allowed(self, catalog: str | None, schema: str | None, table: str) - def id(self) -> str: return f"sql-{self._db_conn_id}" + @property + def sqlglot_dialect(self) -> str | None: + """Sqlglot dialect name for this connection, used for SQL validation and LLM prompting.""" + if self._explicit_dialect is not None: Review Comment: `sqlglot_dialect` docstring says it’s used for “SQL validation and LLM prompting”, but the toolset’s validation path currently uses `_dialect_for_validation()` (hook.dialect_name) and doesn’t consult this property. Either wire this into the validation path, or adjust the docstring to avoid misleading API consumers. ########## providers/common/ai/tests/unit/common/ai/operators/test_llm_data_quality.py: ########## @@ -0,0 +1,440 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Tests for LLMDataQualityOperator.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook +from airflow.providers.common.ai.operators.llm import LLMOperator +from airflow.providers.common.ai.operators.llm_data_quality import LLMDataQualityOperator +from airflow.providers.common.ai.toolsets.dataquality.base import BaseDQToolset +from airflow.providers.common.ai.utils.dataquality.models import ( + DQCheckFailedError, + DQCheckInput, + DQCheckPlan, + DQCheckResult, + DQPlan, + DQReport, +) + +_CHECKS = [ + DQCheckInput(name="null_emails", description="Check for null email addresses"), + DQCheckInput(name="dup_ids", description="Check for duplicate customer IDs"), +] + + +class _TaskInstanceLike: + def xcom_push(self, key: str, value: Any) -> None: + pass + + +class _UsageLike: + requests = 1 + tool_calls = 0 + input_tokens = 1 + output_tokens = 1 + total_tokens = 2 + + +class _ResponseLike: + model_name = "test-model" + + +class _AgentResultLike: + def __init__(self, output: Any) -> None: + self.output = output + self.response = _ResponseLike() + + def usage(self) -> _UsageLike: + return _UsageLike() + + def all_messages(self) -> list[Any]: + return [] + + +def _make_context() -> Any: + task_instance = MagicMock(spec=_TaskInstanceLike) + return {"task_instance": task_instance, "ti": task_instance} + + +def _make_operator(**overrides: Any) -> LLMDataQualityOperator: + defaults: dict[str, Any] = { + "task_id": 
"test_dq", + "checks": _CHECKS, + "llm_conn_id": "pydantic_ai_default", + "db_conn_id": "postgres_default", + } + defaults.update(overrides) + op = LLMDataQualityOperator(**defaults) + op.llm_hook = MagicMock(spec=PydanticAIHook) + return op + + +def _passing_report() -> DQReport: + return DQReport( + results=[ + DQCheckResult(check_name="null_emails", passed=True), + DQCheckResult(check_name="dup_ids", passed=True), + ], + passed=True, + ) + + +def _failing_report() -> DQReport: + return DQReport( + results=[ + DQCheckResult(check_name="null_emails", passed=False, failure_reason="100 nulls found"), + DQCheckResult(check_name="dup_ids", passed=True), + ], + passed=False, + failure_summary="null_emails: 100 nulls found", + ) + + +def _mock_agent_result(output: Any) -> _AgentResultLike: + return _AgentResultLike(output) + + +class TestLLMDataQualityOperatorInit: + def test_requires_llm_conn_id(self): + with pytest.raises(TypeError): + LLMDataQualityOperator( + task_id="test_dq", + checks=_CHECKS, + db_conn_id="postgres_default", + ) + + def test_raises_when_no_toolsets_and_no_db_conn_id(self): + with pytest.raises(ValueError, match="Either toolsets or db_conn_id"): + LLMDataQualityOperator( + task_id="test_dq", + checks=_CHECKS, + llm_conn_id="pydantic_ai_default", + ) + + def test_raises_when_empty_toolsets_and_no_db_conn_id(self): + with pytest.raises(ValueError, match="Either toolsets or db_conn_id"): + LLMDataQualityOperator( + task_id="test_dq", + checks=_CHECKS, + llm_conn_id="pydantic_ai_default", + toolsets=[], + ) + + def test_empty_checks_raises_value_error(self): + with pytest.raises(ValueError, match="checks must not be empty"): + _make_operator(checks=[]) + + def test_duplicate_check_names_raises(self): + with pytest.raises(ValueError, match="duplicate"): + _make_operator( + checks=[ + DQCheckInput(name="dup", description="first"), + DQCheckInput(name="dup", description="second"), + ] + ) + + def test_dict_checks_auto_coerced(self): + op = 
_make_operator( + checks=[ + {"name": "null_emails", "description": "Check nulls"}, + {"name": "dup_ids", "description": "Check dups"}, + ] + ) + assert all(isinstance(c, DQCheckInput) for c in op.checks) + + def test_template_fields_include_required_keys(self): + op = _make_operator() + tf = set(op.template_fields) + assert {"checks", "system_prompt", "agent_params", "db_conn_id", "table_names"} <= tf + + +class TestResolveToolsets: + def test_explicit_toolsets_returned_unchanged(self): + mock_dq = MagicMock(spec=BaseDQToolset) + mock_dq.output_mode = "execute" + op = _make_operator(toolsets=[mock_dq]) + resolved = op._resolve_toolsets() + assert resolved is op.toolsets + + def test_auto_creates_sql_toolsets_from_db_conn_id(self): + op = _make_operator() + with ( + patch("airflow.providers.common.ai.toolsets.sql.SQLToolset") as mock_sql, + patch("airflow.providers.common.ai.toolsets.dataquality.sql.SQLDQToolset") as mock_dq, + ): + op._resolve_toolsets() + mock_sql.assert_called_once_with(db_conn_id="postgres_default", allowed_tables=None) + mock_dq.assert_called_once_with() + + def test_auto_creates_sql_toolsets_with_table_names(self): + op = _make_operator(table_names=["customers", "orders"]) + with ( + patch("airflow.providers.common.ai.toolsets.sql.SQLToolset") as mock_sql, + patch("airflow.providers.common.ai.toolsets.dataquality.sql.SQLDQToolset"), + ): + op._resolve_toolsets() + mock_sql.assert_called_once_with( + db_conn_id="postgres_default", allowed_tables=["customers", "orders"] + ) + + def test_auto_appends_sql_dq_toolset_when_missing(self): + mock_other = object() + op = _make_operator(toolsets=[mock_other]) + with patch("airflow.providers.common.ai.toolsets.dataquality.sql.SQLDQToolset") as mock_cls: + mock_cls.return_value = MagicMock(spec=BaseDQToolset) + resolved = op._resolve_toolsets() + mock_cls.assert_called_once_with() + assert len(resolved) == 2 + + +class TestFindDqToolset: + def test_returns_first_base_dq_toolset(self): + mock_dq = 
MagicMock(spec=BaseDQToolset) + mock_other = object() + result = LLMDataQualityOperator._find_dq_toolset([mock_other, mock_dq]) + assert result is mock_dq + + def test_raises_when_no_dq_toolset(self): + mock_other = object() + with pytest.raises(ValueError, match="No BaseDQToolset found"): + LLMDataQualityOperator._find_dq_toolset([mock_other]) + + +class TestExecuteMode: + def _run(self, report: DQReport, **overrides: Any) -> Any: + mock_dq = MagicMock(spec=BaseDQToolset) + mock_dq.output_mode = "execute" + op = _make_operator(toolsets=[mock_dq], **overrides) + op.llm_hook.create_agent.return_value.run_sync.return_value = _mock_agent_result(report) # type: ignore[attr-defined] + return op.execute(context=_make_context()) + + def test_set_checks_called_on_dq_toolset(self): + mock_dq = MagicMock(spec=BaseDQToolset) + mock_dq.output_mode = "execute" + op = _make_operator(toolsets=[mock_dq]) + op.llm_hook.create_agent.return_value.run_sync.return_value = _mock_agent_result(_passing_report()) + op.execute(context=_make_context()) + mock_dq.set_checks.assert_called_once_with(op.checks) + Review Comment: `set_checks` call assertion relies on a `MagicMock(spec=BaseDQToolset)` toolset, but the operator discovers DQ toolsets via `isinstance`, so this mock won’t be used as the DQ toolset. Use a real `BaseDQToolset` subclass instance and wrap `set_checks` with a spy mock if you need call assertions. ########## providers/common/ai/tests/unit/common/ai/operators/test_llm_data_quality.py: ########## @@ -0,0 +1,440 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Tests for LLMDataQualityOperator.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook +from airflow.providers.common.ai.operators.llm import LLMOperator +from airflow.providers.common.ai.operators.llm_data_quality import LLMDataQualityOperator +from airflow.providers.common.ai.toolsets.dataquality.base import BaseDQToolset +from airflow.providers.common.ai.utils.dataquality.models import ( + DQCheckFailedError, + DQCheckInput, + DQCheckPlan, + DQCheckResult, + DQPlan, + DQReport, +) + +_CHECKS = [ + DQCheckInput(name="null_emails", description="Check for null email addresses"), + DQCheckInput(name="dup_ids", description="Check for duplicate customer IDs"), +] + + +class _TaskInstanceLike: + def xcom_push(self, key: str, value: Any) -> None: + pass + + +class _UsageLike: + requests = 1 + tool_calls = 0 + input_tokens = 1 + output_tokens = 1 + total_tokens = 2 + + +class _ResponseLike: + model_name = "test-model" + + +class _AgentResultLike: + def __init__(self, output: Any) -> None: + self.output = output + self.response = _ResponseLike() + + def usage(self) -> _UsageLike: + return _UsageLike() + + def all_messages(self) -> list[Any]: + return [] + + +def _make_context() -> Any: + task_instance = MagicMock(spec=_TaskInstanceLike) + return {"task_instance": task_instance, "ti": task_instance} + + +def _make_operator(**overrides: Any) -> LLMDataQualityOperator: + defaults: dict[str, Any] = { + "task_id": 
"test_dq", + "checks": _CHECKS, + "llm_conn_id": "pydantic_ai_default", + "db_conn_id": "postgres_default", + } + defaults.update(overrides) + op = LLMDataQualityOperator(**defaults) + op.llm_hook = MagicMock(spec=PydanticAIHook) + return op + + +def _passing_report() -> DQReport: + return DQReport( + results=[ + DQCheckResult(check_name="null_emails", passed=True), + DQCheckResult(check_name="dup_ids", passed=True), + ], + passed=True, + ) + + +def _failing_report() -> DQReport: + return DQReport( + results=[ + DQCheckResult(check_name="null_emails", passed=False, failure_reason="100 nulls found"), + DQCheckResult(check_name="dup_ids", passed=True), + ], + passed=False, + failure_summary="null_emails: 100 nulls found", + ) + + +def _mock_agent_result(output: Any) -> _AgentResultLike: + return _AgentResultLike(output) + + +class TestLLMDataQualityOperatorInit: + def test_requires_llm_conn_id(self): + with pytest.raises(TypeError): + LLMDataQualityOperator( + task_id="test_dq", + checks=_CHECKS, + db_conn_id="postgres_default", + ) + + def test_raises_when_no_toolsets_and_no_db_conn_id(self): + with pytest.raises(ValueError, match="Either toolsets or db_conn_id"): + LLMDataQualityOperator( + task_id="test_dq", + checks=_CHECKS, + llm_conn_id="pydantic_ai_default", + ) + + def test_raises_when_empty_toolsets_and_no_db_conn_id(self): + with pytest.raises(ValueError, match="Either toolsets or db_conn_id"): + LLMDataQualityOperator( + task_id="test_dq", + checks=_CHECKS, + llm_conn_id="pydantic_ai_default", + toolsets=[], + ) + + def test_empty_checks_raises_value_error(self): + with pytest.raises(ValueError, match="checks must not be empty"): + _make_operator(checks=[]) + + def test_duplicate_check_names_raises(self): + with pytest.raises(ValueError, match="duplicate"): + _make_operator( + checks=[ + DQCheckInput(name="dup", description="first"), + DQCheckInput(name="dup", description="second"), + ] + ) + + def test_dict_checks_auto_coerced(self): + op = 
_make_operator( + checks=[ + {"name": "null_emails", "description": "Check nulls"}, + {"name": "dup_ids", "description": "Check dups"}, + ] + ) + assert all(isinstance(c, DQCheckInput) for c in op.checks) + + def test_template_fields_include_required_keys(self): + op = _make_operator() + tf = set(op.template_fields) + assert {"checks", "system_prompt", "agent_params", "db_conn_id", "table_names"} <= tf + + +class TestResolveToolsets: + def test_explicit_toolsets_returned_unchanged(self): + mock_dq = MagicMock(spec=BaseDQToolset) + mock_dq.output_mode = "execute" + op = _make_operator(toolsets=[mock_dq]) + resolved = op._resolve_toolsets() + assert resolved is op.toolsets + + def test_auto_creates_sql_toolsets_from_db_conn_id(self): + op = _make_operator() + with ( + patch("airflow.providers.common.ai.toolsets.sql.SQLToolset") as mock_sql, + patch("airflow.providers.common.ai.toolsets.dataquality.sql.SQLDQToolset") as mock_dq, + ): + op._resolve_toolsets() + mock_sql.assert_called_once_with(db_conn_id="postgres_default", allowed_tables=None) + mock_dq.assert_called_once_with() + + def test_auto_creates_sql_toolsets_with_table_names(self): + op = _make_operator(table_names=["customers", "orders"]) + with ( + patch("airflow.providers.common.ai.toolsets.sql.SQLToolset") as mock_sql, + patch("airflow.providers.common.ai.toolsets.dataquality.sql.SQLDQToolset"), + ): + op._resolve_toolsets() + mock_sql.assert_called_once_with( + db_conn_id="postgres_default", allowed_tables=["customers", "orders"] + ) + + def test_auto_appends_sql_dq_toolset_when_missing(self): + mock_other = object() + op = _make_operator(toolsets=[mock_other]) + with patch("airflow.providers.common.ai.toolsets.dataquality.sql.SQLDQToolset") as mock_cls: + mock_cls.return_value = MagicMock(spec=BaseDQToolset) + resolved = op._resolve_toolsets() + mock_cls.assert_called_once_with() + assert len(resolved) == 2 + + +class TestFindDqToolset: + def test_returns_first_base_dq_toolset(self): + mock_dq = 
MagicMock(spec=BaseDQToolset) + mock_other = object() + result = LLMDataQualityOperator._find_dq_toolset([mock_other, mock_dq]) + assert result is mock_dq Review Comment: `_find_dq_toolset()` uses `isinstance(toolset, BaseDQToolset)`, so `MagicMock(spec=BaseDQToolset)` won’t match and this test will fail. Use a minimal concrete `BaseDQToolset` subclass instance instead of a mock for the DQ toolset slot. ########## providers/common/ai/tests/unit/common/ai/decorators/test_llm_data_quality.py: ########## @@ -0,0 +1,141 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from airflow.providers.common.ai.decorators.llm_data_quality import ( + _LLMDQDecoratedOperator, +) +from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook +from airflow.providers.common.ai.toolsets.dataquality.base import BaseDQToolset +from airflow.providers.common.ai.utils.dataquality.models import ( + DQCheckInput, + DQCheckResult, + DQReport, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_CHECKS = [ + DQCheckInput(name="null_emails", description="Check for null email addresses"), + DQCheckInput(name="dup_ids", description="Check for duplicate customer IDs"), +] + + +def _passing_report() -> DQReport: + return DQReport( + results=[ + DQCheckResult(check_name="null_emails", passed=True), + DQCheckResult(check_name="dup_ids", passed=True), + ], + passed=True, + ) + + +def _make_op(callable_fn=None, **kwargs) -> _LLMDQDecoratedOperator: + if callable_fn is None: + + def callable_fn(): + return _CHECKS + + return _LLMDQDecoratedOperator( + task_id="test_dq", + python_callable=callable_fn, + llm_conn_id="pydantic_ai_default", + db_conn_id="postgres_default", + **kwargs, + ) + + +def _mock_agent_result(output) -> MagicMock: + result = MagicMock() + result.output = output + result.usage.return_value = MagicMock( + requests=1, tool_calls=0, input_tokens=0, output_tokens=0, total_tokens=0 + ) + result.response = MagicMock(model_name="test-model") + result.all_messages.return_value = [] + return result Review Comment: This helper constructs several `MagicMock()` objects without a `spec`/`autospec`, which can hide interface drift (e.g. `usage` vs `usage()` changes) and make breakage introduced by refactors harder to catch. Prefer `MagicMock(spec=[...])` for the agent result and response objects, consistent with other tests in this provider.
########## providers/common/ai/docs/operators/llm_data_quality.rst: ########## @@ -0,0 +1,502 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. _howto/operator:llm_data_quality: + +``LLMDataQualityOperator`` +========================== + +Use :class:`~airflow.providers.common.ai.operators.llm_data_quality.LLMDataQualityOperator` +to run natural-language data-quality checks against a live data source using an LLM agent. + +The agent: + +1. Calls ``list_checks`` to read the user's quality expectations. +2. Calls ``list_validators`` to discover available validators. +3. Uses schema-discovery tools (``list_tables``, ``get_schema``) to explore the data source. +4. Writes and executes SQL queries, applies validators, and produces a + :class:`~airflow.providers.common.ai.utils.dataquality.models.DQReport`. +5. Fails the task when any check does not pass. + +.. seealso:: + :ref:`Connection configuration <howto/connection:pydanticai>` + +Architecture — Toolsets +----------------------- + +The operator is composed of two toolsets that are passed together in ``toolsets``: + +- **Data-source toolset** — gives the agent access to the data (SQL queries, schema discovery). 
+ Examples: :class:`~airflow.providers.common.ai.toolsets.sql.SQLToolset` for relational + databases, :class:`~airflow.providers.common.ai.toolsets.datafusion.DataFusionToolset` + for object-storage formats (Parquet, CSV, Avro on S3, GCS, etc.). + +- **DQ toolset** (:class:`~airflow.providers.common.ai.toolsets.dataquality.sql.SQLDQToolset`) — + adds the ``list_validators`` and ``apply_validator`` tools so the agent can evaluate metrics + against registered validator thresholds. + +When ``toolsets`` is omitted, the operator auto-creates +:class:`~airflow.providers.common.ai.toolsets.sql.SQLToolset` and +:class:`~airflow.providers.common.ai.toolsets.dataquality.sql.SQLDQToolset` +from ``db_conn_id`` and ``table_names``. + +PostgreSQL / Relational Databases +---------------------------------- + +Pass ``SQLToolset`` (schema-discovery + SQL execution) alongside ``SQLDQToolset``: + +.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_llm_data_quality.py + :language: python + :start-after: [START howto_operator_llm_dq_postgres_basic] + :end-before: [END howto_operator_llm_dq_postgres_basic] + +Object Storage (S3 / Parquet) +------------------------------ + +Use :class:`~airflow.providers.common.ai.toolsets.datafusion.DataFusionToolset` to register +S3 (or other object-store) Parquet/CSV data as a queryable DataFusion table, then pair it with +``SQLDQToolset``: + +.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_llm_data_quality.py + :language: python + :start-after: [START howto_operator_llm_dq_object_storage] + :end-before: [END howto_operator_llm_dq_object_storage] + +.. note:: + The ``region`` key in the ``aws_default`` connection **Extra** field must match the + bucket's AWS region to avoid redirect errors, e.g. ``{"region": "eu-central-1"}``. + +Explicit Toolsets +----------------- + +Passing ``toolsets`` directly gives full control over which tables the agent can access. + +.. 
exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_llm_data_quality.py + :language: python + :start-after: [START howto_operator_llm_dq_explicit_toolsets] + :end-before: [END howto_operator_llm_dq_explicit_toolsets] + +Checks +------ + +Each entry in ``checks`` is a +:class:`~airflow.providers.common.ai.utils.dataquality.models.DQCheckInput` +(or a plain ``dict`` with ``name``, ``description``, and optional ``validator`` keys). + +.. code-block:: python + + from airflow.providers.common.ai.utils.dataquality import DQCheckInput, null_pct_check + + checks = [ + DQCheckInput( + name="null_email", + description="Check the percentage of rows where email is NULL", + validator=null_pct_check(max_pct=0.05), + ), + DQCheckInput( + name="row_count", + description="Ensure the table has at least 1000 rows", + ), + ] + +Check names must be unique within a single operator call. + +Validators +---------- + +Built-in validator factories are exported from +:mod:`~airflow.providers.common.ai.utils.dataquality`: + +.. list-table:: + :header-rows: 1 + :widths: 30 70 + + * - Factory + - Description + * - ``null_pct_check(max_pct=...)`` + - Passes when the fraction of NULL values is at or below ``max_pct``. + * - ``row_count_check(min_count=...)`` + - Passes when the row count is at least ``min_count``. + * - ``duplicate_pct_check(max_pct=...)`` + - Passes when the fraction of duplicate values is at or below ``max_pct``. + * - ``between_check(min_val=..., max_val=...)`` + - Passes when the metric value falls within ``[min_val, max_val]``. + * - ``exact_check(expected=...)`` + - Passes when the metric value equals ``expected`` exactly. + +Checks without a validator are measured and included in the report. +If no validator is applied, the check is treated as passed. 
+ +Custom Validators +----------------- + +Use :func:`~airflow.providers.common.ai.utils.dataquality.register_validator` +to register validator factories that the LLM can discover via ``list_validators`` +and select dynamically: + +.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_llm_data_quality.py + :language: python + :start-after: [START howto_operator_llm_dq_email_format_validator] + :end-before: [END howto_operator_llm_dq_email_format_validator] + +Row-Level Checks +---------------- + +Row-level validators receive a list of values (one per row) instead of a single +aggregate metric. The agent issues a plain ``SELECT <column> FROM <table>`` query +(no aggregation) and passes the column values to ``apply_validator``. + +The validator returns a +:class:`~airflow.providers.common.ai.utils.dataquality.models.RowLevelResult` with +these fields: + +.. list-table:: + :header-rows: 1 + :widths: 25 75 + + * - Field + - Description + * - ``total`` + - Number of evaluated rows. + * - ``invalid`` + - Number of rows that failed the validator predicate. + * - ``invalid_pct`` + - ``invalid / total`` (``0.0`` when ``total`` is zero). + * - ``sample_violations`` + - Sample of failing values (string form). + * - ``sample_size`` + - Length of ``sample_violations``. + +Register a row-level validator with ``row_level=True``: + +.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_llm_data_quality.py + :language: python + :start-after: [START howto_operator_llm_dq_email_format_validator] + :end-before: [END howto_operator_llm_dq_email_format_validator] + +Schema Context +-------------- + +Pass ``schema_context`` to inject a manual schema description into the system prompt. +Useful when the data source cannot be introspected at runtime, or to restrict +the agent to a specific set of tables: + +.. 
code-block:: python + + LLMDataQualityOperator( + task_id="validate_with_schema", + llm_conn_id="pydanticai_default", + db_conn_id="postgres_default", + schema_context=( + "Table: orders\n" + "Columns: id INT, customer_id INT, amount DECIMAL, created_at TIMESTAMP\n\n" + "Table: customers\n" + "Columns: id INT, email TEXT, country TEXT" + ), + checks=[ + DQCheckInput( + name="null_amount", + description="Check the percentage of orders with a NULL amount", + validator=null_pct_check(max_pct=0.0), + ), + ], + ) + +Human-in-the-Loop Approval +--------------------------- + +Set ``require_approval=True`` to gate execution on a human reviewer. +The task runs in two phases: + +1. **Phase 1** — the LLM discovers schema, writes SQL queries, and selects validators + but does *not* execute any SQL or call ``apply_validator``. The resulting plan + (SQL statements + validator choices) is surfaced in the Airflow UI. +2. **Phase 2** (after approval) — the approved SQL queries run, validators are applied + in pure Python, and the final :class:`~...DQReport` is produced. No LLM calls. + +Rejecting the plan raises +:class:`~airflow.providers.standard.exceptions.HITLRejectException`. + +.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_llm_data_quality.py + :language: python + :start-after: [START howto_operator_llm_dq_require_approval] + :end-before: [END howto_operator_llm_dq_require_approval] + +See :class:`~airflow.providers.common.ai.operators.llm.LLMOperator` for shared HITL parameters +(``approval_timeout``, ``allow_modifications``). + +Durable Execution +----------------- + +Set ``durable=True`` to enable cross-run caching of LLM and tool-call results. +When a task is retried or restarted, the operator replays cached model responses and +tool calls instead of re-running them, saving tokens and wall-clock time. + +.. 
code-block:: python + + LLMDataQualityOperator( + task_id="validate_products", + llm_conn_id="pydanticai_default", + toolsets=[ + SQLToolset(db_conn_id="postgres_default", allowed_tables=["products"]), + SQLDQToolset(), + ], + durable=True, + checks=[...], + ) + +TaskFlow Decorator +------------------ + +Use ``@task.llm_dq`` when checks are produced dynamically by a Python callable. +The function body returns the checks list; the decorator handles the LLM agent run, +SQL execution, and validator application automatically. + +.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_llm_data_quality.py + :language: python + :start-after: [START howto_operator_llm_dq_task_decorator] + :end-before: [END howto_operator_llm_dq_task_decorator] + +Output and XCom +--------------- + +The operator returns a ``dict`` representation of +:class:`~airflow.providers.common.ai.utils.dataquality.models.DQReport` as its XCom value +(``output_mode="execute"``). + +Additionally, the full report is pushed to XCom under the key ``dq_report`` on the +task instance: + +.. code-block:: python + + report = context["ti"].xcom_pull(task_ids="validate_orders", key="dq_report") + +For config-generation backends (``output_mode="generate"``), the operator returns +the generated configuration string instead of a report. + +Logging +------- + +After each LLM call, the operator logs a summary with model name, token usage, +and request count at INFO level: + +.. code-block:: text + + [INFO] Agent run summary — model: gpt-4o, requests: 6, tool_calls: 14, + input_tokens: 3201, output_tokens: 412, total_tokens: 3613 + +Parameters Reference +-------------------- + +.. list-table:: + :header-rows: 1 + :widths: 25 15 60 + + * - Parameter + - Type + - Description + * - ``checks`` + - ``list`` + - List of :class:`~...DQCheckInput` objects or plain dicts. Names must be unique. + * - ``toolsets`` + - ``list | None`` + - Pydantic-AI toolsets. 
Must include a data-source toolset and a + :class:`~...BaseDQToolset`. Auto-created from ``db_conn_id`` when ``None``. + * - ``db_conn_id`` + - ``str | None`` + - Connection ID for auto-creating ``SQLToolset``. Ignored when ``toolsets`` is set. + * - ``table_names`` + - ``list[str] | None`` + - Tables exposed to the auto-created ``SQLToolset``. Ignored when ``toolsets`` is set. + * - ``schema_context`` + - ``str | None`` + - Manual schema description injected into the system prompt. + * - ``durable`` + - ``bool`` + - Enable cross-run LLM and tool-call caching. Default ``False``. + * - ``require_approval`` + - ``bool`` + - Enable two-phase HITL approval before SQL execution. Default ``False``. + * - ``llm_conn_id`` + - ``str`` + - Pydantic-AI connection ID (inherited from + :class:`~airflow.providers.common.ai.operators.llm.LLMOperator`). + + +Each entry in ``checks`` is a +:class:`~airflow.providers.common.ai.utils.dataquality.models.DQCheckInput`. +The operator: + +1. Exposes checks and validators as tools to the model. +2. Lets the model discover schema and generate SQL via your data-source toolset. +3. Applies validators and produces a + :class:`~airflow.providers.common.ai.utils.dataquality.models.DQReport`. +4. Fails the task when any check fails. + +.. seealso:: + :ref:`Connection configuration <howto/connection:pydanticai>` + +Basic Usage +----------- + +If you pass ``db_conn_id`` (without explicit toolsets), the operator auto-creates: + +- :class:`~airflow.providers.common.ai.toolsets.sql.SQLToolset` +- :class:`~airflow.providers.common.ai.toolsets.dataquality.sql.SQLDQToolset` + Review Comment: This doc page appears to contain a duplicated second pass of major sections (e.g. “Basic Usage”, “Explicit Toolsets”, “Validators”, etc.) starting around here, repeating content already covered earlier in the page. It would be better to remove the duplicated block and keep a single coherent flow to avoid conflicting updates later. 
########## providers/common/ai/provider.yaml: ########## @@ -457,6 +459,8 @@ task-decorators: name: llm_file_analysis - class-name: airflow.providers.common.ai.decorators.llm_branch.llm_branch_task name: llm_branch + - class-name: airflow.providers.common.ai.decorators.llm_data_quality.llm_dq_task + name: llm_dq Review Comment: PR description mentions an `@task.llm_data_quality` decorator, but the implementation/registration uses `llm_dq` (e.g. task-decorator name `llm_dq` and `custom_operator_name = "@task.llm_dq"`). Please align naming across docs/description/API (either rename the decorator to `llm_data_quality` or update the PR description to match `llm_dq`). ########## providers/common/ai/src/airflow/providers/common/ai/operators/llm_data_quality.py: ########## @@ -0,0 +1,770 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Operator for running data-quality checks from natural language using an LLM agent.""" + +from __future__ import annotations + +import json +import re +from collections import Counter +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any + +from airflow.providers.common.ai.operators.llm import LLMOperator +from airflow.providers.common.ai.toolsets.dataquality.base import BaseDQToolset +from airflow.providers.common.ai.utils.dataquality.models import ( + DQCheckFailedError, + DQCheckInput, + DQPlan, + DQReport, +) +from airflow.providers.common.ai.utils.logging import log_run_summary, wrap_toolsets_for_logging + +if TYPE_CHECKING: + from pydantic_ai.toolsets.abstract import AbstractToolset + + from airflow.sdk import Context + +_DQ_SYSTEM_PROMPT = """\ +You are a data-quality expert. Evaluate the user's data-quality checks against a \ +live data source using the tools available to you. + +WORKFLOW: +1. Call ``list_checks`` to read the user's quality expectations. +2. Call ``list_validators`` to see available validators (names, parameters, descriptions). +3. Use schema-discovery tools (``list_tables``, ``get_schema``) to explore the data source. +4. For each check: + a. If the check has ``row_level: true`` (from ``list_checks``), follow the ROW-LEVEL CHECKS + procedure below instead of writing an aggregate query. + b. Otherwise, write a SELECT query that computes the relevant metric for this check. + c. Optionally call ``check_query`` to validate SQL syntax before executing. + d. Execute the query using the ``query`` tool. + e. You MUST call ``apply_validator`` for EVERY check — never skip this step, even for + fixed validators (``has_fixed_validator: true``). Use ``validator_name: "fixed"`` for those. +5. Return a ``DQReport`` with one ``DQCheckResult`` per check — every check must appear exactly once. + +SQL GENERATION RULES: +- Generate ONLY SELECT statements. NEVER use INSERT, UPDATE, DELETE, DROP, TRUNCATE, or DDL. 
+- Begin each query with a SQL comment that names the check it serves: + -- check: <check_name> + Example: -- check: null_order_id +- For conditional counts, use CASE expressions (FILTER WHERE is not universally supported): + CORRECT: COUNT(CASE WHEN col IS NULL THEN 1 END) + INCORRECT: COUNT(*) FILTER (WHERE col IS NULL) +- For null/invalid percentages: + COUNT(CASE WHEN condition THEN 1 END) * 1.0 / NULLIF(COUNT(*), 0) +- For duplicate percentages: + (COUNT(*) - COUNT(DISTINCT col)) * 1.0 / NULLIF(COUNT(*), 0) +- For float division, cast to avoid integer truncation: + CAST(numerator AS DOUBLE) / CAST(denominator AS DOUBLE) +- Give each metric column a descriptive snake_case alias (e.g. ``null_email_pct``). + +VALIDATOR SELECTION: +- ``list_validators`` returns each validator's name, parameter signature, and description. + Read the parameter names and types before calling ``apply_validator``. +- Pass the required ``validator_args`` as a JSON object matching the parameter signature exactly. +- If a parameter has no default, it is REQUIRED — always include it in ``validator_args``. + Example: ``null_pct_check`` requires ``max_pct`` — you must pass ``{"max_pct": <value>}``. + If you are unsure of a threshold, use a safe default (e.g. ``0.05`` for percentage checks). +- NEVER pass empty ``{}`` args for a validator that has required parameters. +- For checks marked ``has_fixed_validator: true`` in ``list_checks``, call ``apply_validator`` + with ``validator_name: "fixed"`` — the pre-assigned validator runs automatically, no args needed. +- If no suitable validator exists, pass ``validator_name: "none"`` — the check passes by default + but its metric value is still recorded. + +DQREPORT OUTPUT RULES: +- ``check_name`` must exactly match the name returned by ``list_checks`` (no abbreviation). +- Each check must appear in the report exactly once. +- Set ``passed: false`` and provide a clear ``failure_reason`` for every failed check. 
+- Set ``metric_key`` to the SQL column alias you used for this check's metric value. +- Set ``sql_query`` to the exact SQL statement you executed for this check (including the leading comment). +- Set ``validator_info`` to ``{"name": "<validator_name>", "args": {<validator_args>}}`` when a validator + was applied, or ``null`` when no validator was used (``validator_name: "none"``). + +CHECK CATEGORIES (use to guide SQL and validator selection): + null_check — null / missing value counts or percentages + uniqueness — duplicate detection, cardinality checks + validity — regex / format / pattern matching on string columns + numeric_range — range, bounds, or statistical checks on numeric columns + row_count — total row counts or existence checks + string_format — length, encoding, whitespace, or character-set checks + row_level — per-row anomaly checks; validator receives a list, not an aggregate + +ROW-LEVEL CHECKS (applies when ``list_checks`` returns ``row_level: true`` for a check): +- Do NOT write an aggregate query. Write a plain ``SELECT <column> FROM <table>`` that + returns one value per row — no GROUP BY, no COUNT, no CASE expressions. + Example: -- check: customer_email_format\n SELECT email FROM customers +- After executing, extract the column values into a Python list: + value = [row["<column>"] for row in result["rows"]] + where ``result`` is the JSON object returned by the ``query`` tool. +- Call ``apply_validator`` with ``value`` set to that list and ``validator_name: "fixed"``. +- The validator evaluates each item and returns a ``value`` payload containing a + ``RowLevelResult`` summary (total/invalid/invalid_pct/sample_violations/sample_size). +- Set ``metric_key`` to the SQL column name (e.g. ``"email"``). +- Set ``value`` in the DQCheckResult to the ``value`` object returned by ``apply_validator``. +""" + +_DQ_PLAN_SYSTEM_PROMPT = """\ +You are a data-quality expert. 
Your task is to plan *how* to evaluate each +data-quality check: write the SQL query and choose the validator for each check. +A human reviewer will inspect your plan before any SQL is executed. + +WORKFLOW (PLANNING MODE — no SQL execution, no apply_validator): +1. Call ``list_checks`` to read the user's quality expectations. +2. Call ``list_validators`` to see available validators (names, parameters, descriptions). +3. Use schema-discovery tools (``list_tables``, ``get_schema``) to understand the data model. +4. For each check: + a. If the check has ``row_level: true`` (from ``list_checks``), follow the ROW-LEVEL + procedure below. Otherwise, write an aggregate SELECT query for the metric. + b. Optionally call ``check_query`` to validate your SQL syntax — but do NOT execute it. + c. Select the appropriate validator name and arguments: + - For ``has_fixed_validator: true`` checks: use ``validator_name: "fixed"``, empty args. + - If no validator fits: use ``validator_name: "none"``. +5. Return a ``DQPlan`` with one ``DQCheckPlan`` per check — every check must appear exactly once. + +IMPORTANT: The ``query`` tool is NOT available. Do NOT try to execute SQL. +Do NOT call ``apply_validator`` — it is not available in planning mode. +SQL will be executed after the human reviewer approves the plan. + +SQL GENERATION RULES: +- Generate ONLY SELECT statements. NEVER use INSERT, UPDATE, DELETE, DROP, TRUNCATE, or DDL. +- Begin each query with a SQL comment that names the check it serves: + -- check: <check_name> +- For conditional counts, use CASE expressions: + COUNT(CASE WHEN col IS NULL THEN 1 END) +- For null/invalid percentages: + COUNT(CASE WHEN condition THEN 1 END) * 1.0 / NULLIF(COUNT(*), 0) +- Give each metric column a descriptive snake_case alias (e.g. ``null_email_pct``). + +ROW-LEVEL CHECKS (applies when ``list_checks`` returns ``row_level: true``): +- Do NOT write an aggregate query. Write a plain ``SELECT <column> FROM <table>``. 
+- Set ``row_level: true`` in the DQCheckPlan output. +- Set ``metric_key`` to the column name (e.g. ``"email"``). + +DQCheckPlan FIELDS: +- ``check_name``: exact name from ``list_checks`` (no abbreviation). +- ``sql_query``: the SQL statement to execute later (with leading ``-- check: <name>`` comment). +- ``metric_key``: the SQL column alias / name to read as the primary metric value. +- ``row_level``: ``true`` for row-level checks (SELECT returns one row per record), ``false`` for aggregate. +- ``validator_name``: ``"fixed"``, a registered validator name, or ``"none"``. +- ``validator_args``: kwargs for the validator factory, or ``{}`` for ``"fixed"``/``"none"``. + +VALIDATOR ARGS RULES (critical — missing args cause hard failures in Phase 2): +- ``list_validators`` returns a ``parameters`` field for each validator showing its exact + parameter names, types, and defaults. Read this carefully before setting ``validator_args``. +- If a parameter has no default (required), you MUST include it in ``validator_args``. + Example: ``null_pct_check`` requires ``max_pct`` — always set ``{"max_pct": <value>}``. +- If you are unsure of the right threshold, pick a reasonable default (e.g. 0.05 for pct checks). +- NEVER leave ``validator_args`` as ``{}`` for a validator that has required parameters. +- Only use ``{}`` when ``validator_name`` is ``"fixed"`` or ``"none"``. 
+""" + +_DIALECT_SQL_NOTES: dict[str, str] = { + "postgres": ( + " - Regex match: `col ~ 'pattern'`; not match: `col !~ 'pattern'` (case-sensitive).\n" + " Case-insensitive variants: `~*` and `!~*`.\n" + " - Float division: `numerator * 1.0 / NULLIF(denominator, 0)` (no CAST needed).\n" + " - String functions: LENGTH(), LOWER(), TRIM(), SUBSTRING().\n" + ), + "mysql": ( + " - Regex match: `col REGEXP 'pattern'`; not match: `col NOT REGEXP 'pattern'`.\n" + " - Float division: `numerator * 1.0 / NULLIF(denominator, 0)`.\n" + " - String length (multibyte-safe): CHAR_LENGTH(col); byte length: LENGTH(col).\n" + ), + "sqlite": ( + " - No native regex operator; use LIKE or GLOB:\n" + " `col NOT LIKE '%@%'` or `col GLOB '*@*.*'`.\n" + " - Float division: `CAST(numerator AS REAL) / NULLIF(denominator, 0)`.\n" + ), + "tsql": ( + " - No regex operator; use LIKE or PATINDEX:\n" + " `PATINDEX('%@%.%', col) = 0` means no match.\n" + " - Float division: `CAST(numerator AS FLOAT) / NULLIF(denominator, 0)`.\n" + " - Use TOP n instead of LIMIT n; no OFFSET without ORDER BY.\n" + " - String length: LEN(col) (not LENGTH).\n" + ), + "bigquery": ( + " - Regex: `REGEXP_CONTAINS(col, r'pattern')` — returns TRUE if match.\n" + " Negate with `NOT REGEXP_CONTAINS(...)`.\n" + " - Safe division: `SAFE_DIVIDE(numerator, denominator)` (NULL on zero denominator).\n" + " - Quote reserved names with backticks.\n" + ), + "snowflake": ( + " - Regex: `REGEXP_LIKE(col, 'pattern')` (full-string match) or `col RLIKE 'pattern'`.\n" + " - Float division: `numerator * 1.0 / NULLIF(denominator, 0)`.\n" + " - Unquoted identifiers are uppercased; use double-quotes when case matters.\n" + ), +} + + +def _extract_schema_table_names(schema_context: str) -> list[str]: + """Extract table names from a schema context string produced by build_schema_context.""" + return re.findall(r"^Table:\s+(\S+)", schema_context, re.MULTILINE) + + +class LLMDataQualityOperator(LLMOperator): + """ + Run data-quality checks described 
in natural language using an LLM agent. + + The agent discovers the database schema, writes and executes SQL queries, + applies validators, and produces a + :class:`~airflow.providers.common.ai.utils.dataquality.models.DQReport`. The task + fails when any check does not pass, gating downstream tasks on data quality. + + Supply the data-source toolset and DQ toolset together in ``toolsets``:: + + from airflow.providers.common.ai.toolsets.sql import SQLToolset + from airflow.providers.common.ai.toolsets.dataquality.sql import SQLDQToolset + + LLMDataQualityOperator( + task_id="quality_check", + checks=[ + DQCheckInput(name="email_nulls", description="Check for null emails"), + DQCheckInput( + name="row_count", + description="At least 1000 rows", + validator=row_count_check(min_count=1000), + ), + ], + llm_conn_id="pydanticai_default", + toolsets=[ + SQLToolset(db_conn_id="postgres_default", allowed_tables=["customers"]), + SQLDQToolset(), + ], + ) + + When ``toolsets`` is omitted, the operator auto-creates + :class:`~airflow.providers.common.ai.toolsets.sql.SQLToolset` and + :class:`~airflow.providers.common.ai.toolsets.dataquality.sql.SQLDQToolset` from + ``db_conn_id`` and ``table_names``. + + For config-generation backends (e.g. ``SodaDQToolset``), the operator + returns the generated config string as its XCom value instead of a report. + + :param checks: List of :class:`~airflow.providers.common.ai.utils.dataquality.models.DQCheckInput` + objects (or plain dicts with ``name``, ``description``, and optional + ``validator`` keys). Names must be unique. + :param toolsets: Pydantic-AI toolsets for the agent. Must include exactly + one :class:`~airflow.providers.common.ai.toolsets.dataquality.base.BaseDQToolset` + subclass alongside a data-source toolset. When ``None``, the operator + auto-creates toolsets from ``db_conn_id``. + :param db_conn_id: Connection ID for auto-creating an + :class:`~airflow.providers.common.ai.toolsets.sql.SQLToolset`. 
+ Ignored when ``toolsets`` is provided. + :param table_names: Tables passed to the auto-created + :class:`~airflow.providers.common.ai.toolsets.sql.SQLToolset` as + ``allowed_tables``. Ignored when ``toolsets`` is provided. + :param schema_context: Additional schema description injected into the + system prompt. Useful when the data source cannot be introspected + at runtime. + """ + + template_fields: Sequence[str] = ( + *LLMOperator.template_fields, + "checks", + "db_conn_id", + "table_names", + "schema_context", + ) + + def __init__( + self, + *, + checks: list[DQCheckInput | dict[str, Any]], + toolsets: list[AbstractToolset] | None = None, + db_conn_id: str | None = None, + table_names: list[str] | None = None, + schema_context: str | None = None, + **kwargs: Any, + ) -> None: + # Pop operator-specific params that must not reach BaseOperator.__init__. + # Using kwargs.pop (rather than named params) is necessary because Airflow's + # apply_defaults metaclass captures the full **kwargs dict and uses it for + # task-map serialization; named params in the dict are NOT removed from the + # snapshot, so they would be re-injected on deserialization and reach + # BaseOperator as unknown kwargs. + durable: bool = kwargs.pop("durable", False) + + kwargs.pop("output_type", None) + kwargs.setdefault("prompt", "Run the data-quality checks.") + super().__init__(**kwargs) + + self.checks: list[DQCheckInput] = ( + [DQCheckInput.coerce(c) for c in checks] if isinstance(checks, list) else checks # type: ignore[assignment] + ) + self.toolsets = toolsets + self.db_conn_id = db_conn_id + self.table_names = table_names + self.schema_context = schema_context + self.durable = durable + + self._validate_checks() + + if not toolsets and db_conn_id is None: + raise ValueError("Either toolsets or db_conn_id must be provided.") + + def execute(self, context: Context) -> Any: + """ + Run the LLM agent to execute all data-quality checks. 
+ + When ``require_approval=True``, the task runs in two phases: + + 1. **Phase 1** — the LLM discovers schema and selects validators *without + executing any SQL queries*. The resulting plan (check names, + descriptions, and validator assignments) is shown to a human reviewer. + 2. **Phase 2** (in :meth:`execute_complete`) — after approval, SQL queries + are executed and validators are applied in pure Python to produce the + :class:`~...DQReport`. + + :returns: :class:`~airflow.providers.common.ai.utils.dataquality.models.DQReport` + as a dict when the DQ toolset is in ``"execute"`` mode, or a config + string when in ``"generate"`` mode. + :raises DQCheckFailedError: When any check fails in ``"execute"`` mode. + """ + if self.require_approval: + plan, _ = self._run_plan_phase(context) + self.defer_for_approval( # type: ignore[misc] + context, + plan.model_dump_json(), + subject=f"Review DQ plan for task `{self.task_id}`", + body=self._build_plan_approval_body(plan), + ) + return self._run_and_report(context) + + def execute_complete(self, context: Context, generated_output: str, event: dict[str, Any]) -> Any: + """ + Phase 2: execute SQL and apply validators after human approval. + + Called automatically by Airflow when the HITL trigger fires. The base + class validates the approval decision (raises on reject or timeout). + Then each SQL query from the approved :class:`~...DQPlan` is executed in + pure Python, the metric values are extracted, and validators are applied + — no LLM calls. + + :param context: Airflow task context. + :param generated_output: The :class:`~...DQPlan` JSON that was deferred. + :param event: Trigger event payload. + :returns: Same as :meth:`execute`. + :raises HITLRejectException: If the reviewer rejected the plan. 
+ """ + approved_output = super().execute_complete(context, generated_output, event) # type: ignore[misc] + plan = DQPlan.model_validate_json(approved_output) + + toolsets = self._resolve_toolsets() + dq_toolset = self._find_dq_toolset(toolsets) + dq_toolset.set_checks(self.checks) + + report = self._execute_plan_validators(plan, dq_toolset, toolsets) + context["task_instance"].xcom_push(key="dq_report", value=report.model_dump()) + if not report.passed: + raise DQCheckFailedError(report.failure_summary) + return report.model_dump() + + # ------------------------------------------------------------------ + # Core execution helpers + # ------------------------------------------------------------------ + + def _run_plan_phase(self, context: Context) -> tuple[DQPlan, BaseDQToolset]: + """ + Phase 1 of the two-phase approval flow. + + Runs the LLM agent in planning mode: schema-discovery tools + (``list_tables``, ``get_schema``, ``check_query``) are available, but + the ``query`` and ``apply_validator`` tools are hidden. The agent + writes SQL strings and selects validators for each check, then outputs + a :class:`~...DQPlan` — without executing any SQL. + + :returns: ``(plan, dq_toolset)`` — the plan for the approval body and the + toolset (already configured with checks) for Phase 2. + """ + toolsets = self._resolve_toolsets() + dq_toolset = self._find_dq_toolset(toolsets) + dq_toolset._planning_mode = True # omit apply_validator from tool list + dq_toolset.set_checks(self.checks) + + # Hide the SQL 'query' tool so the LLM cannot execute DQ queries. 
+ for ts in toolsets: + if ts is not dq_toolset and hasattr(ts, "_query"): + ts._planning_mode = True # type: ignore[attr-defined] + + instructions = self._build_plan_system_prompt(toolsets) + logged_toolsets = wrap_toolsets_for_logging(toolsets, self.log) + storage = None + + if self.durable: + agent, counter, storage = self._build_durable_agent( + context, + DQPlan, + instructions, + logged_toolsets, + ) + else: + agent = self.llm_hook.create_agent( + output_type=DQPlan, + instructions=instructions, + toolsets=logged_toolsets, + **self.agent_params, + ) + counter = None + + result = agent.run_sync(self.prompt, usage_limits=self.usage_limits) + log_run_summary(self.log, result) + + if counter is not None and (counter.replayed_model > 0 or counter.replayed_tool > 0): + self.log.info( + "Durable cache replay (plan phase): model_steps=%d/%d, tool_steps=%d/%d", + counter.replayed_model, + counter.replayed_model + counter.cached_model, + counter.replayed_tool, + counter.replayed_tool + counter.cached_tool, + ) + + if storage is not None: + storage.cleanup() + + return result.output, dq_toolset + + def _run_and_report(self, context: Context) -> Any: + """ + Resolve toolsets, run the LLM agent, push XCom, and handle failures. + + This is the single place where the agent executes regardless of whether + the task took the direct path or the approval-gated path. + """ + toolsets = self._resolve_toolsets() + dq_toolset = self._find_dq_toolset(toolsets) + dq_toolset.set_checks(self.checks) + + output_type: type = DQReport if dq_toolset.output_mode == "execute" else str + instructions = self._build_system_prompt(dq_toolset, toolsets) + logged_toolsets = wrap_toolsets_for_logging(toolsets, self.log) Review Comment: When `dq_toolset.output_mode == "generate"`, the operator switches `output_type` to `str` but still uses the same execution-oriented system prompt (`_DQ_SYSTEM_PROMPT`) that instructs the agent to run `query`/`apply_validator` and return a `DQReport`. 
This prompt/output mismatch will likely cause models to return the wrong shape or attempt execution steps that a generate-mode toolset may not support. ########## providers/common/ai/src/airflow/providers/common/ai/operators/llm_data_quality.py: ########## @@ -0,0 +1,770 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Operator for running data-quality checks from natural language using an LLM agent.""" + +from __future__ import annotations + +import json +import re +from collections import Counter +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any + +from airflow.providers.common.ai.operators.llm import LLMOperator +from airflow.providers.common.ai.toolsets.dataquality.base import BaseDQToolset +from airflow.providers.common.ai.utils.dataquality.models import ( + DQCheckFailedError, + DQCheckInput, + DQPlan, + DQReport, +) +from airflow.providers.common.ai.utils.logging import log_run_summary, wrap_toolsets_for_logging + +if TYPE_CHECKING: + from pydantic_ai.toolsets.abstract import AbstractToolset + + from airflow.sdk import Context + +_DQ_SYSTEM_PROMPT = """\ +You are a data-quality expert. Evaluate the user's data-quality checks against a \ +live data source using the tools available to you. 
+ +WORKFLOW: +1. Call ``list_checks`` to read the user's quality expectations. +2. Call ``list_validators`` to see available validators (names, parameters, descriptions). +3. Use schema-discovery tools (``list_tables``, ``get_schema``) to explore the data source. +4. For each check: + a. If the check has ``row_level: true`` (from ``list_checks``), follow the ROW-LEVEL CHECKS + procedure below instead of writing an aggregate query. + b. Otherwise, write a SELECT query that computes the relevant metric for this check. + c. Optionally call ``check_query`` to validate SQL syntax before executing. + d. Execute the query using the ``query`` tool. + e. You MUST call ``apply_validator`` for EVERY check — never skip this step, even for + fixed validators (``has_fixed_validator: true``). Use ``validator_name: "fixed"`` for those. +5. Return a ``DQReport`` with one ``DQCheckResult`` per check — every check must appear exactly once. + +SQL GENERATION RULES: +- Generate ONLY SELECT statements. NEVER use INSERT, UPDATE, DELETE, DROP, TRUNCATE, or DDL. +- Begin each query with a SQL comment that names the check it serves: + -- check: <check_name> + Example: -- check: null_order_id +- For conditional counts, use CASE expressions (FILTER WHERE is not universally supported): + CORRECT: COUNT(CASE WHEN col IS NULL THEN 1 END) + INCORRECT: COUNT(*) FILTER (WHERE col IS NULL) +- For null/invalid percentages: + COUNT(CASE WHEN condition THEN 1 END) * 1.0 / NULLIF(COUNT(*), 0) +- For duplicate percentages: + (COUNT(*) - COUNT(DISTINCT col)) * 1.0 / NULLIF(COUNT(*), 0) +- For float division, cast to avoid integer truncation: + CAST(numerator AS DOUBLE) / CAST(denominator AS DOUBLE) +- Give each metric column a descriptive snake_case alias (e.g. ``null_email_pct``). + +VALIDATOR SELECTION: +- ``list_validators`` returns each validator's name, parameter signature, and description. + Read the parameter names and types before calling ``apply_validator``. 
+- Pass the required ``validator_args`` as a JSON object matching the parameter signature exactly. +- If a parameter has no default, it is REQUIRED — always include it in ``validator_args``. + Example: ``null_pct_check`` requires ``max_pct`` — you must pass ``{"max_pct": <value>}``. + If you are unsure of a threshold, use a safe default (e.g. ``0.05`` for percentage checks). +- NEVER pass empty ``{}`` args for a validator that has required parameters. +- For checks marked ``has_fixed_validator: true`` in ``list_checks``, call ``apply_validator`` + with ``validator_name: "fixed"`` — the pre-assigned validator runs automatically, no args needed. +- If no suitable validator exists, pass ``validator_name: "none"`` — the check passes by default + but its metric value is still recorded. + +DQREPORT OUTPUT RULES: +- ``check_name`` must exactly match the name returned by ``list_checks`` (no abbreviation). +- Each check must appear in the report exactly once. +- Set ``passed: false`` and provide a clear ``failure_reason`` for every failed check. +- Set ``metric_key`` to the SQL column alias you used for this check's metric value. +- Set ``sql_query`` to the exact SQL statement you executed for this check (including the leading comment). +- Set ``validator_info`` to ``{"name": "<validator_name>", "args": {<validator_args>}}`` when a validator + was applied, or ``null`` when no validator was used (``validator_name: "none"``). 
+ +CHECK CATEGORIES (use to guide SQL and validator selection): + null_check — null / missing value counts or percentages + uniqueness — duplicate detection, cardinality checks + validity — regex / format / pattern matching on string columns + numeric_range — range, bounds, or statistical checks on numeric columns + row_count — total row counts or existence checks + string_format — length, encoding, whitespace, or character-set checks + row_level — per-row anomaly checks; validator receives a list, not an aggregate + +ROW-LEVEL CHECKS (applies when ``list_checks`` returns ``row_level: true`` for a check): +- Do NOT write an aggregate query. Write a plain ``SELECT <column> FROM <table>`` that + returns one value per row — no GROUP BY, no COUNT, no CASE expressions. + Example: -- check: customer_email_format\n SELECT email FROM customers +- After executing, extract the column values into a Python list: + value = [row["<column>"] for row in result["rows"]] + where ``result`` is the JSON object returned by the ``query`` tool. +- Call ``apply_validator`` with ``value`` set to that list and ``validator_name: "fixed"``. +- The validator evaluates each item and returns a ``value`` payload containing a + ``RowLevelResult`` summary (total/invalid/invalid_pct/sample_violations/sample_size). +- Set ``metric_key`` to the SQL column name (e.g. ``"email"``). +- Set ``value`` in the DQCheckResult to the ``value`` object returned by ``apply_validator``. +""" + +_DQ_PLAN_SYSTEM_PROMPT = """\ +You are a data-quality expert. Your task is to plan *how* to evaluate each +data-quality check: write the SQL query and choose the validator for each check. +A human reviewer will inspect your plan before any SQL is executed. + +WORKFLOW (PLANNING MODE — no SQL execution, no apply_validator): +1. Call ``list_checks`` to read the user's quality expectations. +2. Call ``list_validators`` to see available validators (names, parameters, descriptions). +3. 
Use schema-discovery tools (``list_tables``, ``get_schema``) to understand the data model. +4. For each check: + a. If the check has ``row_level: true`` (from ``list_checks``), follow the ROW-LEVEL + procedure below. Otherwise, write an aggregate SELECT query for the metric. + b. Optionally call ``check_query`` to validate your SQL syntax — but do NOT execute it. + c. Select the appropriate validator name and arguments: + - For ``has_fixed_validator: true`` checks: use ``validator_name: "fixed"``, empty args. + - If no validator fits: use ``validator_name: "none"``. +5. Return a ``DQPlan`` with one ``DQCheckPlan`` per check — every check must appear exactly once. + +IMPORTANT: The ``query`` tool is NOT available. Do NOT try to execute SQL. +Do NOT call ``apply_validator`` — it is not available in planning mode. +SQL will be executed after the human reviewer approves the plan. + +SQL GENERATION RULES: +- Generate ONLY SELECT statements. NEVER use INSERT, UPDATE, DELETE, DROP, TRUNCATE, or DDL. +- Begin each query with a SQL comment that names the check it serves: + -- check: <check_name> +- For conditional counts, use CASE expressions: + COUNT(CASE WHEN col IS NULL THEN 1 END) +- For null/invalid percentages: + COUNT(CASE WHEN condition THEN 1 END) * 1.0 / NULLIF(COUNT(*), 0) +- Give each metric column a descriptive snake_case alias (e.g. ``null_email_pct``). + +ROW-LEVEL CHECKS (applies when ``list_checks`` returns ``row_level: true``): +- Do NOT write an aggregate query. Write a plain ``SELECT <column> FROM <table>``. +- Set ``row_level: true`` in the DQCheckPlan output. +- Set ``metric_key`` to the column name (e.g. ``"email"``). + +DQCheckPlan FIELDS: +- ``check_name``: exact name from ``list_checks`` (no abbreviation). +- ``sql_query``: the SQL statement to execute later (with leading ``-- check: <name>`` comment). +- ``metric_key``: the SQL column alias / name to read as the primary metric value. 
+- ``row_level``: ``true`` for row-level checks (SELECT returns one row per record), ``false`` for aggregate. +- ``validator_name``: ``"fixed"``, a registered validator name, or ``"none"``. +- ``validator_args``: kwargs for the validator factory, or ``{}`` for ``"fixed"``/``"none"``. + +VALIDATOR ARGS RULES (critical — missing args cause hard failures in Phase 2): +- ``list_validators`` returns a ``parameters`` field for each validator showing its exact + parameter names, types, and defaults. Read this carefully before setting ``validator_args``. +- If a parameter has no default (required), you MUST include it in ``validator_args``. + Example: ``null_pct_check`` requires ``max_pct`` — always set ``{"max_pct": <value>}``. +- If you are unsure of the right threshold, pick a reasonable default (e.g. 0.05 for pct checks). +- NEVER leave ``validator_args`` as ``{}`` for a validator that has required parameters. +- Only use ``{}`` when ``validator_name`` is ``"fixed"`` or ``"none"``. +""" + +_DIALECT_SQL_NOTES: dict[str, str] = { + "postgres": ( + " - Regex match: `col ~ 'pattern'`; not match: `col !~ 'pattern'` (case-sensitive).\n" + " Case-insensitive variants: `~*` and `!~*`.\n" + " - Float division: `numerator * 1.0 / NULLIF(denominator, 0)` (no CAST needed).\n" + " - String functions: LENGTH(), LOWER(), TRIM(), SUBSTRING().\n" + ), + "mysql": ( + " - Regex match: `col REGEXP 'pattern'`; not match: `col NOT REGEXP 'pattern'`.\n" + " - Float division: `numerator * 1.0 / NULLIF(denominator, 0)`.\n" + " - String length (multibyte-safe): CHAR_LENGTH(col); byte length: LENGTH(col).\n" + ), + "sqlite": ( + " - No native regex operator; use LIKE or GLOB:\n" + " `col NOT LIKE '%@%'` or `col GLOB '*@*.*'`.\n" + " - Float division: `CAST(numerator AS REAL) / NULLIF(denominator, 0)`.\n" + ), + "tsql": ( + " - No regex operator; use LIKE or PATINDEX:\n" + " `PATINDEX('%@%.%', col) = 0` means no match.\n" + " - Float division: `CAST(numerator AS FLOAT) / NULLIF(denominator, 
0)`.\n" + " - Use TOP n instead of LIMIT n; no OFFSET without ORDER BY.\n" + " - String length: LEN(col) (not LENGTH).\n" + ), + "bigquery": ( + " - Regex: `REGEXP_CONTAINS(col, r'pattern')` — returns TRUE if match.\n" + " Negate with `NOT REGEXP_CONTAINS(...)`.\n" + " - Safe division: `SAFE_DIVIDE(numerator, denominator)` (NULL on zero denominator).\n" + " - Quote reserved names with backticks.\n" + ), + "snowflake": ( + " - Regex: `REGEXP_LIKE(col, 'pattern')` (full-string match) or `col RLIKE 'pattern'`.\n" + " - Float division: `numerator * 1.0 / NULLIF(denominator, 0)`.\n" + " - Unquoted identifiers are uppercased; use double-quotes when case matters.\n" + ), +} + + +def _extract_schema_table_names(schema_context: str) -> list[str]: + """Extract table names from a schema context string produced by build_schema_context.""" + return re.findall(r"^Table:\s+(\S+)", schema_context, re.MULTILINE) + + +class LLMDataQualityOperator(LLMOperator): + """ + Run data-quality checks described in natural language using an LLM agent. + + The agent discovers the database schema, writes and executes SQL queries, + applies validators, and produces a + :class:`~airflow.providers.common.ai.utils.dataquality.models.DQReport`. The task + fails when any check does not pass, gating downstream tasks on data quality. 
+ + Supply the data-source toolset and DQ toolset together in ``toolsets``:: + + from airflow.providers.common.ai.toolsets.sql import SQLToolset + from airflow.providers.common.ai.toolsets.dataquality.sql import SQLDQToolset + + LLMDataQualityOperator( + task_id="quality_check", + checks=[ + DQCheckInput(name="email_nulls", description="Check for null emails"), + DQCheckInput( + name="row_count", + description="At least 1000 rows", + validator=row_count_check(min_count=1000), + ), + ], + llm_conn_id="pydanticai_default", + toolsets=[ + SQLToolset(db_conn_id="postgres_default", allowed_tables=["customers"]), + SQLDQToolset(), + ], + ) + + When ``toolsets`` is omitted, the operator auto-creates + :class:`~airflow.providers.common.ai.toolsets.sql.SQLToolset` and + :class:`~airflow.providers.common.ai.toolsets.dataquality.sql.SQLDQToolset` from + ``db_conn_id`` and ``table_names``. + + For config-generation backends (e.g. ``SodaDQToolset``), the operator + returns the generated config string as its XCom value instead of a report. + + :param checks: List of :class:`~airflow.providers.common.ai.utils.dataquality.models.DQCheckInput` + objects (or plain dicts with ``name``, ``description``, and optional + ``validator`` keys). Names must be unique. + :param toolsets: Pydantic-AI toolsets for the agent. Must include exactly + one :class:`~airflow.providers.common.ai.toolsets.dataquality.base.BaseDQToolset` + subclass alongside a data-source toolset. When ``None``, the operator + auto-creates toolsets from ``db_conn_id``. + :param db_conn_id: Connection ID for auto-creating an + :class:`~airflow.providers.common.ai.toolsets.sql.SQLToolset`. + Ignored when ``toolsets`` is provided. + :param table_names: Tables passed to the auto-created + :class:`~airflow.providers.common.ai.toolsets.sql.SQLToolset` as + ``allowed_tables``. Ignored when ``toolsets`` is provided. + :param schema_context: Additional schema description injected into the + system prompt. 
Useful when the data source cannot be introspected + at runtime. + """ + + template_fields: Sequence[str] = ( + *LLMOperator.template_fields, + "checks", + "db_conn_id", + "table_names", + "schema_context", + ) + + def __init__( + self, + *, + checks: list[DQCheckInput | dict[str, Any]], + toolsets: list[AbstractToolset] | None = None, + db_conn_id: str | None = None, + table_names: list[str] | None = None, + schema_context: str | None = None, + **kwargs: Any, + ) -> None: + # Pop operator-specific params that must not reach BaseOperator.__init__. + # Using kwargs.pop (rather than named params) is necessary because Airflow's + # apply_defaults metaclass captures the full **kwargs dict and uses it for + # task-map serialization; named params in the dict are NOT removed from the + # snapshot, so they would be re-injected on deserialization and reach + # BaseOperator as unknown kwargs. + durable: bool = kwargs.pop("durable", False) + + kwargs.pop("output_type", None) + kwargs.setdefault("prompt", "Run the data-quality checks.") + super().__init__(**kwargs) + + self.checks: list[DQCheckInput] = ( + [DQCheckInput.coerce(c) for c in checks] if isinstance(checks, list) else checks # type: ignore[assignment] + ) + self.toolsets = toolsets + self.db_conn_id = db_conn_id + self.table_names = table_names + self.schema_context = schema_context + self.durable = durable + + self._validate_checks() + + if not toolsets and db_conn_id is None: + raise ValueError("Either toolsets or db_conn_id must be provided.") + + def execute(self, context: Context) -> Any: + """ + Run the LLM agent to execute all data-quality checks. + + When ``require_approval=True``, the task runs in two phases: + + 1. **Phase 1** — the LLM discovers schema and selects validators *without + executing any SQL queries*. The resulting plan (check names, + descriptions, and validator assignments) is shown to a human reviewer. + 2. 
**Phase 2** (in :meth:`execute_complete`) — after approval, SQL queries + are executed and validators are applied in pure Python to produce the + :class:`~...DQReport`. + + :returns: :class:`~airflow.providers.common.ai.utils.dataquality.models.DQReport` + as a dict when the DQ toolset is in ``"execute"`` mode, or a config + string when in ``"generate"`` mode. + :raises DQCheckFailedError: When any check fails in ``"execute"`` mode. + """ + if self.require_approval: + plan, _ = self._run_plan_phase(context) + self.defer_for_approval( # type: ignore[misc] + context, + plan.model_dump_json(), + subject=f"Review DQ plan for task `{self.task_id}`", + body=self._build_plan_approval_body(plan), + ) + return self._run_and_report(context) + + def execute_complete(self, context: Context, generated_output: str, event: dict[str, Any]) -> Any: + """ + Phase 2: execute SQL and apply validators after human approval. + + Called automatically by Airflow when the HITL trigger fires. The base + class validates the approval decision (raises on reject or timeout). + Then each SQL query from the approved :class:`~...DQPlan` is executed in + pure Python, the metric values are extracted, and validators are applied + — no LLM calls. + + :param context: Airflow task context. + :param generated_output: The :class:`~...DQPlan` JSON that was deferred. + :param event: Trigger event payload. + :returns: Same as :meth:`execute`. + :raises HITLRejectException: If the reviewer rejected the plan. 
+ """ + approved_output = super().execute_complete(context, generated_output, event) # type: ignore[misc] + plan = DQPlan.model_validate_json(approved_output) + + toolsets = self._resolve_toolsets() + dq_toolset = self._find_dq_toolset(toolsets) + dq_toolset.set_checks(self.checks) + + report = self._execute_plan_validators(plan, dq_toolset, toolsets) + context["task_instance"].xcom_push(key="dq_report", value=report.model_dump()) + if not report.passed: + raise DQCheckFailedError(report.failure_summary) + return report.model_dump() + + # ------------------------------------------------------------------ + # Core execution helpers + # ------------------------------------------------------------------ + + def _run_plan_phase(self, context: Context) -> tuple[DQPlan, BaseDQToolset]: + """ + Phase 1 of the two-phase approval flow. + + Runs the LLM agent in planning mode: schema-discovery tools + (``list_tables``, ``get_schema``, ``check_query``) are available, but + the ``query`` and ``apply_validator`` tools are hidden. The agent + writes SQL strings and selects validators for each check, then outputs + a :class:`~...DQPlan` — without executing any SQL. + + :returns: ``(plan, dq_toolset)`` — the plan for the approval body and the + toolset (already configured with checks) for Phase 2. + """ + toolsets = self._resolve_toolsets() + dq_toolset = self._find_dq_toolset(toolsets) + dq_toolset._planning_mode = True # omit apply_validator from tool list + dq_toolset.set_checks(self.checks) + + # Hide the SQL 'query' tool so the LLM cannot execute DQ queries. 
+ for ts in toolsets: + if ts is not dq_toolset and hasattr(ts, "_query"): + ts._planning_mode = True # type: ignore[attr-defined] + + instructions = self._build_plan_system_prompt(toolsets) + logged_toolsets = wrap_toolsets_for_logging(toolsets, self.log) + storage = None + + if self.durable: + agent, counter, storage = self._build_durable_agent( + context, + DQPlan, + instructions, + logged_toolsets, + ) + else: + agent = self.llm_hook.create_agent( + output_type=DQPlan, + instructions=instructions, + toolsets=logged_toolsets, + **self.agent_params, + ) + counter = None + + result = agent.run_sync(self.prompt, usage_limits=self.usage_limits) + log_run_summary(self.log, result) + + if counter is not None and (counter.replayed_model > 0 or counter.replayed_tool > 0): + self.log.info( + "Durable cache replay (plan phase): model_steps=%d/%d, tool_steps=%d/%d", + counter.replayed_model, + counter.replayed_model + counter.cached_model, + counter.replayed_tool, + counter.replayed_tool + counter.cached_tool, + ) + + if storage is not None: + storage.cleanup() + + return result.output, dq_toolset + + def _run_and_report(self, context: Context) -> Any: + """ + Resolve toolsets, run the LLM agent, push XCom, and handle failures. + + This is the single place where the agent executes regardless of whether + the task took the direct path or the approval-gated path. 
+ """ + toolsets = self._resolve_toolsets() + dq_toolset = self._find_dq_toolset(toolsets) + dq_toolset.set_checks(self.checks) + + output_type: type = DQReport if dq_toolset.output_mode == "execute" else str + instructions = self._build_system_prompt(dq_toolset, toolsets) + logged_toolsets = wrap_toolsets_for_logging(toolsets, self.log) + storage = None + + if self.durable: + agent, counter, storage = self._build_durable_agent( + context, + output_type, + instructions, + logged_toolsets, + ) + else: + agent = self.llm_hook.create_agent( + output_type=output_type, + instructions=instructions, + toolsets=logged_toolsets, + **self.agent_params, + ) + counter = None + + result = agent.run_sync(self.prompt, usage_limits=self.usage_limits) + log_run_summary(self.log, result) + + if counter is not None and (counter.replayed_model > 0 or counter.replayed_tool > 0): + self.log.info( + "Durable cache replay: model_steps=%d/%d, tool_steps=%d/%d", + counter.replayed_model, + counter.replayed_model + counter.cached_model, + counter.replayed_tool, + counter.replayed_tool + counter.cached_tool, + ) + + if storage is not None: + storage.cleanup() + + output = result.output + + if dq_toolset.output_mode == "execute": + report: DQReport = output + context["task_instance"].xcom_push(key="dq_report", value=report.model_dump()) + if not report.passed: + raise DQCheckFailedError(report.failure_summary) + return report.model_dump() + + return output + + def _build_durable_agent( + self, + context: Context, + output_type: type, + instructions: str, + toolsets: list[AbstractToolset], + ) -> tuple[Any, Any, Any]: + """Build a pydantic-ai Agent with CachingModel and CachingToolset wrappers.""" + from pydantic_ai import Agent + + from airflow.providers.common.ai.durable.caching_model import CachingModel + from airflow.providers.common.ai.durable.caching_toolset import CachingToolset + from airflow.providers.common.ai.durable.step_counter import DurableStepCounter + from 
airflow.providers.common.ai.durable.storage import DurableStorage + + ti = context["task_instance"] + storage = DurableStorage( + dag_id=ti.dag_id, + task_id=ti.task_id, + run_id=ti.run_id, + map_index=ti.map_index if ti.map_index is not None else -1, + ) + counter = DurableStepCounter() + + wrapped_model = CachingModel(self.llm_hook.get_conn(), storage=storage, counter=counter) + cached_toolsets = [CachingToolset(ts, storage=storage, counter=counter) for ts in toolsets] + + agent = Agent( + wrapped_model, + output_type=output_type, + instructions=instructions, + toolsets=cached_toolsets, + **self.agent_params, + ) + return agent, counter, storage + + def _build_plan_approval_body(self, plan: DQPlan) -> str: + """ + Build a Markdown body for the HITL approval form. + + Shows the LLM-generated SQL queries and validator selections so the + reviewer can inspect the plan before any SQL is executed or any + validator is applied. + """ + lines = [ + "## Data Quality Check Plan \u2014 Awaiting Approval", + "", + f"**Total checks:** {len(plan.checks)}", + "", + "Review the SQL queries and validator selections below. ", + "**Approve** to execute the SQL, apply validators, and produce the quality report. 
", + "**Reject** to cancel without executing anything.", + "", + "| # | Check | Metric Key | Row Level | Validator | Args |", + "|---|-------|------------|-----------|-----------|------|", + ] + for i, cp in enumerate(plan.checks, 1): + args_display = json.dumps(cp.validator_args) if cp.validator_args else "{}" + lines.append( + f"| {i} | `{cp.check_name}` | `{cp.metric_key}` " + f"| {'Yes' if cp.row_level else 'No'} " + f"| `{cp.validator_name}` | `{args_display}` |" + ) + + lines += ["", "---", "", "### SQL Queries", ""] + for cp in plan.checks: + lines += [f"**{cp.check_name}**", "", "```sql", cp.sql_query.strip(), "```", ""] + + return "\n".join(lines) + + def _execute_plan_validators( + self, + plan: DQPlan, + dq_toolset: BaseDQToolset, + toolsets: list[AbstractToolset], + ) -> DQReport: + """ + Phase 2 of the two-phase approval flow. + + For each check in the approved :class:`~...DQPlan`: + + 1. Executes the SQL query via the data-source toolset. + 2. Extracts the metric value from the result. + 3. Applies the chosen validator in pure Python (no LLM calls). + 4. Builds the final :class:`~...DQReport`. + + :param plan: The approved :class:`~...DQPlan` from Phase 1. + :param dq_toolset: The DQ toolset (must implement ``_apply_validator``). + :param toolsets: All toolsets (used to find the SQL toolset for execution). + :raises ValueError: If no SQL-capable toolset or no validator executor found. + """ + from airflow.providers.common.ai.utils.dataquality.models import DQCheckResult + + sql_toolset = next((ts for ts in toolsets if hasattr(ts, "_query")), None) + if sql_toolset is None: + raise ValueError( + "require_approval Phase 2 requires a toolset with a _query() method " + "(e.g. SQLToolset). Add a data-source toolset to the toolsets list." + ) + + apply_fn = getattr(dq_toolset, "_apply_validator", None) + if apply_fn is None: + raise ValueError( + "require_approval two-phase execution requires a DQ toolset that " + "implements _apply_validator (e.g. 
SQLDQToolset)." + ) + + results: list[DQCheckResult] = [] + for cp in plan.checks: + try: + raw_result = sql_toolset._query(cp.sql_query) # type: ignore[union-attr] + result_data = json.loads(raw_result) + rows: list[dict[str, Any]] = result_data.get("rows", []) + except Exception as exc: + results.append( + DQCheckResult( + check_name=cp.check_name, + passed=False, + failure_reason=f"SQL execution failed: {exc}", + sql_query=cp.sql_query, + validator_info={"name": cp.validator_name, "args": cp.validator_args}, + ) + ) + continue + + if cp.row_level: + metric_value: Any = [row.get(cp.metric_key) for row in rows] + else: + metric_value = rows[0].get(cp.metric_key) if rows else None + + raw = apply_fn(cp.check_name, metric_value, cp.validator_name, cp.validator_args) + parsed = json.loads(raw) + passed = bool(parsed.get("passed", False)) + result_value = parsed.get("value", metric_value) + results.append( + DQCheckResult( + check_name=cp.check_name, + passed=passed, + value=result_value, + failure_reason=parsed.get("reason") if not passed else None, + metric_key=cp.metric_key, + sql_query=cp.sql_query, + validator_info={"name": cp.validator_name, "args": cp.validator_args}, + ) + ) + + return DQReport.build(results) + + # ------------------------------------------------------------------ + # Toolset / prompt helpers + # ------------------------------------------------------------------ + + def _resolve_toolsets(self) -> list[AbstractToolset]: + """ + Return explicit toolsets or auto-create from ``db_conn_id``. + + When ``toolsets`` is provided but contains no + :class:`BaseDQToolset`, a default + :class:`~airflow.providers.common.ai.toolsets.dataquality.sql.SQLDQToolset` + is appended automatically. 
+ """ + if self.toolsets: + has_dq = any(isinstance(ts, BaseDQToolset) for ts in self.toolsets) + if not has_dq: + from airflow.providers.common.ai.toolsets.dataquality.sql import SQLDQToolset + + return [*self.toolsets, SQLDQToolset()] + return list(self.toolsets) + + from airflow.providers.common.ai.toolsets.dataquality.sql import SQLDQToolset + from airflow.providers.common.ai.toolsets.sql import SQLToolset + + return [ + SQLToolset(db_conn_id=self.db_conn_id, allowed_tables=self.table_names), # type: ignore[arg-type] + SQLDQToolset(), + ] + + @staticmethod + def _find_dq_toolset(toolsets: list[AbstractToolset]) -> BaseDQToolset: + """Return the first :class:`BaseDQToolset` found in *toolsets*.""" + for toolset in toolsets: + if isinstance(toolset, BaseDQToolset): + return toolset + raise ValueError( + "No BaseDQToolset found in toolsets. " + "Add SQLDQToolset (or another BaseDQToolset subclass) to the toolsets list." + ) + + def _build_system_prompt(self, dq_toolset: BaseDQToolset, toolsets: list[AbstractToolset]) -> str: + """Return the full DQ system prompt for the normal (non-planning) execution path.""" + prompt = self._make_prompt(_DQ_SYSTEM_PROMPT, toolsets) + + fixed_checks = [c.name for c in self.checks if c.validator is not None] + if fixed_checks: + names = ", ".join(f'"{n}"' for n in fixed_checks) + prompt += ( + "\nFIXED VALIDATORS:\n" + f" Checks {names} have pre-assigned validators.\n" + ' For these, call apply_validator with validator_name="fixed" — ' + "the system uses the pre-configured validator automatically.\n" + ) + + if self.system_prompt: + prompt += f"\nAdditional instructions:\n{self.system_prompt}\n" + + return prompt + + def _build_plan_system_prompt(self, toolsets: list[AbstractToolset]) -> str: + """Return the full DQ system prompt for Phase 1 (planning mode, no apply_validator).""" + prompt = self._make_prompt(_DQ_PLAN_SYSTEM_PROMPT, toolsets) + + # In planning mode the FIXED VALIDATORS section uses different wording: + # the LLM 
records "fixed" in the DQPlan instead of calling apply_validator. + fixed_checks = [c.name for c in self.checks if c.validator is not None] + if fixed_checks: + names = ", ".join(f'"{n}"' for n in fixed_checks) + prompt += ( + "\nFIXED VALIDATORS:\n" + f" Checks {names} have pre-assigned validators.\n" + ' For these, set validator_name: "fixed" and validator_args: {} ' + "in your DQCheckPlan output.\n" + ) + + if self.system_prompt: + prompt += f"\nAdditional instructions:\n{self.system_prompt}\n" + + return prompt + + def _make_prompt(self, base: str, toolsets: list[AbstractToolset]) -> str: + """Inject dialect, schema context, and fixed-validator sections into *base*.""" + prompt = base + + dialect = self._detect_sql_dialect(toolsets) + if dialect: + notes = _DIALECT_SQL_NOTES.get(dialect, "") + dialect_section = f"\nSQL DIALECT: {dialect.upper()}\n" + if notes: + dialect_section += " Adapt your SQL to the following dialect-specific rules:\n" + notes + prompt += dialect_section + + if self.schema_context: + table_names = _extract_schema_table_names(self.schema_context) + if table_names: + prompt += ( + "\nTABLE NAME CONSTRAINT:\n" + f" The ONLY tables you may reference in FROM clauses are: {', '.join(table_names)}.\n" + " Use these exact names — do not rename, abbreviate, or invent new table names.\n" + ) + prompt += f"\nSchema context:\n{self.schema_context}\n" + + return prompt + + @staticmethod + def _detect_sql_dialect(toolsets: list[AbstractToolset]) -> str | None: + """Return the sqlglot dialect of the first dialect-aware toolset found.""" + for toolset in toolsets: + dialect = getattr(toolset, "sqlglot_dialect", None) + if dialect: + return dialect + return None Review Comment: `_detect_sql_dialect()` directly evaluates `toolset.sqlglot_dialect`. If that property triggers connection resolution (e.g. `BaseHook.get_connection`) and raises, prompt construction will fail even though dialect hints are optional. 
Treat dialect detection as best-effort by catching exceptions and continuing. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
